2025-06-29 23:59:06 -07:00

302 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

try:from scripts.APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
except:from APF_global_demo import GaussianSchedule,GaussianPathSchedule,TimedGaussianSchedule,FieldScheduler,Agent,Visualizer
import numpy as np
try:from scripts.sorter import sorter
except:from sorter import sorter
import matplotlib.pyplot as plt
import json
import time
import os
plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文黑体Windows/Linux
plt.rcParams['axes.unicode_minus'] = False # 正确显示负号
class planner():
def __init__(self,acupoints_dict:dict,sorted_path,time_scale = 1.0):
self.scale = 5
self.mid_x = 4
self.raw_path = sorted_path
self.acupoints_dict = acupoints_dict
self.acupoints_dict_scaled = {
name:{
**info, # **info 内容和原来一样但把Pos项乘以一个缩放系数
"pos":tuple(np.array(info["pos"]) * self.scale)
} for name, info in self.acupoints_dict.items()
}
self.path = []
self.time_scale = time_scale
self.vis = Visualizer((8 * self.scale,26 * self.scale),[])
self.jump_pos = None
self.jump_pos_name = None
def _make_scheduler(self, start, goal, dist):
start_duration = dist * 0.45
end_duration = dist * 0.45
return FieldScheduler([
# 起点附近,早期激活、后期衰减
TimedGaussianSchedule(
center=start,
height_fn=lambda t: -8 * (1 - (t - 0) / start_duration),
sigma=1.5,
t_start=0,
t_end=start_duration,
hold_final=True
),
# 终点附近后期激活、前期为0
TimedGaussianSchedule(
center=goal,
height_fn=lambda t: -8 * ((t - end_duration) / end_duration),
sigma=1.5,
t_start=dist - end_duration,
t_end=dist,
hold_final=True
),
# 中间常驻路径引导势场
GaussianPathSchedule(
start=start,
end=goal,
height_fn=lambda t:-4,
sigma=2.0
)
])
def path_generator(self):
self.path.clear()
i = 0
# --- 判断特殊情况:起始点是左侧唯一按摩点,右侧有对应对称点 ---
if len(self.raw_path) >= 2:
start_pos = self.acupoints_dict_scaled[self.raw_path[0]]['pos']
left_coords = [
self.acupoints_dict_scaled[name]['pos']
for name in self.raw_path
if self.acupoints_dict_scaled[name]['pos'][0] < self.mid_x * self.scale
]
right_coords = [
self.acupoints_dict_scaled[name]['pos']
for name in self.raw_path
if self.acupoints_dict_scaled[name]['pos'][0] >= self.mid_x * self.scale
]
# 判断左侧是否只有一个点
if len(left_coords) == 1 and len(right_coords) >= 1:
# 判断起始点是否是左侧这个唯一点,且距离右侧至少有一个点对应
max_x_left = left_coords[0][0]
if np.allclose(start_pos, left_coords[0], atol=3.0):
print("检测到左侧只有一个点,右侧有对应对称点,采用特殊路径顺序")
# 生成路径:先左侧点停留若干(假设重复几次),再去右侧点停留
left_point = left_coords[0]
right_point = right_coords[0] # 这里可以改成更合适的点,比如对称点或第一个右侧点
path = []
# 左侧点按10次可调整
path.extend([left_point] * 10)
# 右侧点按10次
path.extend([right_point] * 10)
self.path = path
return self.path # 直接返回特殊路径,跳过默认流程
# 回到正常流程
while i < len(self.raw_path) - 1:
name_start = self.raw_path[i]
name_goal = self.raw_path[i + 1]
start = self.acupoints_dict_scaled[name_start]["pos"]
goal = self.acupoints_dict_scaled[name_goal]["pos"]
# --- 生成场和agent ---
start_array = np.array(start)
goal_array = np.array(goal)
dist = np.linalg.norm(goal_array-start_array)
field_scheduler = self._make_scheduler(start, goal, dist)
agent = Agent(start)
# --- 执行路径引导 ---
t = 0
while True:
Z = field_scheduler.get_field(t)
agent.step(Z)
if self.vis:
self.vis.show(Z=Z, agent_pos=agent.pos, t=t, start=start, goal=goal)
self.path.append(tuple(agent.pos))
if np.linalg.norm(agent.pos - np.array(goal)) <= 1.0 and t > dist*self.time_scale:
break
t += 1
# --- 最终可视化目标点(可选) ---
if self.vis:
self.vis.show(Z=Z, agent_pos=goal, t=t + 1, start=start, goal=goal)
# --- 判断是否跳跃 ---
final_pos = agent.pos
left_coords = [
self.acupoints_dict_scaled[name]["pos"]
for name in self.raw_path
if self.acupoints_dict_scaled[name]["pos"][0] < self.mid_x * self.scale
]
if left_coords:
max_x = max(p[0] for p in left_coords)
candidates = [p for p in left_coords if p[0] == max_x]
target_jump_pos = min(candidates, key=lambda p: p[1])
self.jump_pos = target_jump_pos # 将跳跃位置添加入类
print(f"flags:{target_jump_pos}")
print(f"final_pos:{final_pos}")
if np.allclose(final_pos, target_jump_pos, atol=3.0):
print(f"跳跃触发 @ {name_goal}{self.raw_path[i + 2] if i + 2 < len(self.raw_path) else '终点'}")
i += 2 # 跳过一个点
else:
i += 1
else:
i += 1
return self.path
def replanning(self):
"""
将 self.path 中的坐标映射为穴位名 + 连续停留次数的队列(动态显示 + 竖屏)
"""
# 异常处理
if not self.path:
print("无路径可规划跳过replanning")
return []
name_pos_list = [(name, np.array(info['pos'])) for name, info in self.acupoints_dict_scaled.items()]
# === 映射路径点为最近穴位名 === #
mapped_names = []
for pt in self.path:
pt = np.array(pt)
nearest_name = min(name_pos_list, key=lambda item: np.linalg.norm(item[1] - pt))[0]
mapped_names.append(nearest_name)
if np.array_equal(pt,self.jump_pos):
self.jump_pos_name = nearest_name
# === 统计连续按压次数 === #
result_queue = []
prev_name = mapped_names[0]
count = 1
for name in mapped_names[1:]:
if name == prev_name:
count += 1
else:
result_queue.append((prev_name, count))
prev_name = name
count = 1
result_queue.append((prev_name, count))
# === 动态可视化 === #
plt.ion()
fig, ax = plt.subplots(figsize=(6, 12)) # 竖屏显示
grid_w, grid_h = map(int, self.vis.grid_size)
ax.set_xlim(0, grid_w)
ax.set_ylim(grid_h, 0)
ax.invert_yaxis()
for idx, (name, count) in enumerate(result_queue):
pos = self.acupoints_dict_scaled[name]['pos']
ax.plot(pos[0], pos[1], 'ro')
ax.text(pos[0], pos[1], f"{idx+1}.{name} x{count}", fontsize=8, color='blue')
if idx > 0:
prev_pos = self.acupoints_dict_scaled[result_queue[idx - 1][0]]['pos']
ax.plot([prev_pos[0], pos[0]], [prev_pos[1], pos[1]], 'g--')
ax.set_title(f"{idx+1} 个穴位:{name} x{count}")
plt.pause(0.5) # 每步停留 0.5 秒
plt.ioff()
plt.show()
print("连续穴位按压序列:")
for item in result_queue:
print(item)
return result_queue
def convert_queue_to_task_plan(self,queue,body_part,massage_head,massage_name):
timestamp = int(time.time())
if not massage_name:
massage_name = f"{body_part}-{massage_head}-{timestamp}"
if not massage_head:
massage_head = f"finger" # 默认值
if not body_part:
body_part = f"AI诊疗按摩部位"
task_plan = []
is_line_mode = massage_head in {'thermotherapy','stone'}
for i in range(len(queue)):
name, stay_time = queue[i]
if not is_line_mode:
# 点按模式添加停留任务AA,BB...
# 添加停留任务 AA、BB、CC...
task_plan.append({
"start_point": name,
"end_point": name,
"time": stay_time,
"path": "point"
})
# 判断是否需要跳跃
# 添加移动任务 AB、BC、CD...,除了最后一个点不需要转移
if i < len(queue) - 1:
next_name = queue[i + 1][0]
# 跳过跳跃点的移动任务
if name == self.jump_pos_name:
print(f"跳过{name}->{next_name}的移动任务")
continue
task_plan.append({
"start_point": name,
"end_point": next_name,
"time": None,
"path": "line"
})
# --- 构造完整项目结构 ---
project_data = {
f"{massage_name}": {
"body_part": body_part,
"can_delete": True,
"choose_task": massage_head,
"introduction": f"{body_part}-AI诊断理疗方案",
"task_plan": task_plan
}
}
# 打印格式化的 JSON 字符串
print(json.dumps(project_data, ensure_ascii=False, indent=4))
self.save_project_json(project_data)
return project_data
def save_project_json(self,project_data: dict, save_dir: str = "saved_projects", file_prefix: str = "massage_project"):
"""
将 massage 项目字典保存为本地 JSON 文件
:param project_data: 要保存的项目数据dict
:param save_dir: 保存目录(默认 "saved_projects"
:param file_prefix: 文件名前缀(默认 "massage_project"
:return: 保存的完整文件路径
"""
# 创建保存目录(如不存在)
os.makedirs(save_dir, exist_ok=True)
# 构造文件名:前缀 + 时间戳 + .json
timestamp = int(time.time())
file_name = f"{file_prefix}_{timestamp}.json"
file_path = os.path.join(save_dir, file_name)
# 写入 JSON 文件
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(project_data, f, ensure_ascii=False, indent=4)
print(f"项目已保存到:{file_path}")
return file_path
if __name__ == "__main__":
test_text = ['风池穴', '肩井穴', '天柱穴', '大椎穴', '肩髃穴', '天宗穴']
mySorter = sorter()
sorted_path = mySorter.sort_acupoints(test_text)
print(sorted_path)
myPlanner = planner(mySorter.acupoints_metadata,sorted_path,80,'line')
myPlanner.path_generator()
myPlanner.replanning()