2025-06-30 09:53:49 +08:00

205 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import matplotlib.pyplot as plt
# ======================= Configuration ======================= #
GRID_SIZE = 8  # side length of the square grid; fields are GRID_SIZE x GRID_SIZE arrays
START = (1, 1)  # (x, y) cell where the agent starts; also the centre of the first Gaussian built in main()
GOAL = (6, 6)  # (x, y) target cell; centre of the second Gaussian built in main()
TOTAL_STEPS = 40  # main() iterates t = 0 .. TOTAL_STEPS inclusive
# ======================= 高斯场调度器 ======================= #
# 点状高斯源
class GaussianSchedule:
    """Point-shaped Gaussian source whose height and width evolve over time.

    Args:
        center: ``(x, y)`` grid coordinates of the Gaussian's peak.
        height_fn: Callable ``t -> height`` giving the signed amplitude at
            time ``t`` (a negative height produces a well).
        sigma: Standard deviation of the Gaussian, either a constant or a
            callable ``t -> sigma``.
    """

    def __init__(self, center, height_fn, sigma):
        self.center = np.array(center)
        self.height_fn = height_fn
        # Accept a callable as-is; wrap a constant so callers can always
        # evaluate ``sigma_fn(t)``.
        self.sigma_fn = sigma if callable(sigma) else (lambda t: sigma)

    def field_at(self, t, grid_size=None):
        """Evaluate the Gaussian on a square grid at time ``t``.

        Args:
            t: Time at which ``height_fn`` and ``sigma_fn`` are evaluated.
            grid_size: Side length of the grid.  When omitted, the
                module-level ``GRID_SIZE`` is used.

        Returns:
            Array of shape ``(grid_size, grid_size)`` indexed as ``[x, y]``.
        """
        n = GRID_SIZE if grid_size is None else grid_size
        # indexing='ij' makes the first array axis correspond to x.
        gx, gy = np.meshgrid(np.arange(n), np.arange(n), indexing='ij')
        d2 = (gx - self.center[0]) ** 2 + (gy - self.center[1]) ** 2
        height = self.height_fn(t)
        sigma = self.sigma_fn(t)
        return height * np.exp(-d2 / (2 * sigma ** 2))  # type: ignore
# 路径状高斯源(沿线段)
class GaussianPathSchedule:
    """Gaussian source laid along a line segment instead of a single point.

    The value at each cell depends on the cell's distance to the closest
    point of the segment ``start -> end``.

    Args:
        start: ``(x, y)`` coordinates of the segment's first endpoint.
        end: ``(x, y)`` coordinates of the segment's second endpoint.
        height_fn: Callable ``t -> height`` giving the signed amplitude.
        sigma: Standard deviation across the segment, either a constant or
            a callable ``t -> sigma``.
    """

    def __init__(self, start, end, height_fn, sigma):
        self.start = np.array(start)
        self.end = np.array(end)
        self.height_fn = height_fn
        # Wrap a constant sigma so it can always be called as sigma_fn(t).
        self.sigma_fn = sigma if callable(sigma) else (lambda t: sigma)

    def field_at(self, t, grid_size=None):
        """Evaluate the segment-shaped Gaussian on a square grid at time ``t``.

        Args:
            t: Time at which ``height_fn`` and ``sigma_fn`` are evaluated.
            grid_size: Side length of the grid.  When omitted, the
                module-level ``GRID_SIZE`` is used.

        Returns:
            Array of shape ``(grid_size, grid_size)`` indexed as ``[x, y]``.
        """
        n = GRID_SIZE if grid_size is None else grid_size
        gx, gy = np.meshgrid(np.arange(n), np.arange(n), indexing='ij')
        height = self.height_fn(t)
        sigma = self.sigma_fn(t)

        # Segment vector AB and, for every grid point P, the vector AP.
        segment = self.end - self.start
        seg_len_sq = np.dot(segment, segment)
        points = np.stack([gx, gy], axis=-1)  # shape (n, n, 2)
        rel = points - self.start

        if seg_len_sq == 0:
            # Degenerate segment (start == end): projecting would divide by
            # zero and fill the field with NaN, so treat it as a point source.
            proj = np.zeros(points.shape[:-1])
        else:
            # Projection factor of P onto AB, clamped to [0, 1] so the
            # nearest point never leaves the segment.
            proj = np.clip(np.sum(rel * segment, axis=-1) / seg_len_sq, 0, 1)

        # Closest point on the segment for every grid point, shape (n, n, 2).
        nearest = self.start + proj[..., None] * segment
        dist2 = np.sum((points - nearest) ** 2, axis=-1)
        return height * np.exp(-dist2 / (2 * sigma ** 2))  # type: ignore
# 带时间窗口的高斯源
class TimedGaussianSchedule:
    """Point-shaped Gaussian source that is only active inside a time window.

    Args:
        center: ``(x, y)`` grid coordinates of the Gaussian's peak.
        height_fn: Callable ``t -> height`` giving the signed amplitude.
        sigma: Standard deviation, either a constant or a callable
            ``t -> sigma``.
        t_start: First time at which the source is active (inclusive).
        t_end: Last time at which the source is active (inclusive).
        hold_final: When true, times after ``t_end`` are evaluated as if
            ``t == t_end``; otherwise the field is zero after ``t_end``.
    """

    def __init__(self, center, height_fn, sigma, t_start=0, t_end=float('inf'), hold_final=False):
        self.center = np.array(center)
        self.height_fn = height_fn
        # Wrap a constant sigma so it can always be called as sigma_fn(t).
        self.sigma_fn = sigma if callable(sigma) else (lambda t: sigma)
        self.t_start = t_start
        self.t_end = t_end
        self.hold_final = hold_final

    def field_at(self, t, grid_size=None):
        """Evaluate the windowed Gaussian on a square grid at time ``t``.

        Args:
            t: Query time.
            grid_size: Side length of the grid.  When omitted, the
                module-level ``GRID_SIZE`` is used.

        Returns:
            Array of shape ``(grid_size, grid_size)`` indexed as ``[x, y]``;
            all zeros before ``t_start`` and, unless ``hold_final`` is set,
            after ``t_end``.
        """
        n = GRID_SIZE if grid_size is None else grid_size

        if t < self.t_start:
            return np.zeros((n, n))
        if t > self.t_end:
            if not self.hold_final:
                return np.zeros((n, n))
            # Freeze the source at the value it had at the end of its window.
            t = self.t_end

        gx, gy = np.meshgrid(np.arange(n), np.arange(n), indexing='ij')
        d2 = (gx - self.center[0]) ** 2 + (gy - self.center[1]) ** 2
        height = self.height_fn(t)
        sigma = self.sigma_fn(t)
        return height * np.exp(-d2 / (2 * sigma ** 2))  # type: ignore
# 势场叠加器
class FieldScheduler:
    """Combine several schedules into one potential field.

    Args:
        schedules: Iterable of objects exposing ``field_at(t)``.
        base_height: Constant added to every cell of the summed field.
    """

    def __init__(self, schedules, base_height=10):
        self.schedules = schedules
        self.base_height = base_height

    def get_field(self, t):
        """Return the sum of all schedule fields at time ``t`` plus the base height."""
        accumulated = np.zeros((GRID_SIZE, GRID_SIZE))
        for source in self.schedules:
            contribution = source.field_at(t)
            accumulated += contribution
        shifted = accumulated + self.base_height
        return shifted
# ======================= Agent 移动策略 ======================= #
class Agent:
    """Greedy walker that moves to the best-scoring neighbouring cell each step.

    Attributes are plain instance fields:
      pos      -- current ``(x, y)`` cell as an array
      velocity -- displacement taken on the previous step (zero initially)
      path     -- list of visited cells as tuples, starting with the start cell
    """

    def __init__(self, start_pos):
        self.pos = np.array(start_pos)
        self.velocity = np.array([0, 0])
        self.path = [tuple(self.pos)]

    def step(self, Z, momentum_weight=0.8):
        """Move to the neighbouring cell with the lowest momentum-adjusted height.

        Each of the up to eight neighbours is scored as
        ``Z[nx, ny] - momentum_weight * dot(move, velocity)``, so moves that
        continue in the previous direction are favoured.  Ties go to the
        first neighbour in scan order (dx, then dy, each over -1, 0, 1).

        Args:
            Z: 2-D array of field heights indexed as ``[x, y]``.  Its shape
                defines the valid cells.
            momentum_weight: Strength of the momentum term.  The default of
                0.8 is the weight that was effectively in use.

        Raises:
            ValueError: If no neighbouring cell lies inside ``Z``.
        """
        size_x, size_y = Z.shape
        x, y = self.pos

        candidates = []
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                if dx == 0 and dy == 0:
                    continue  # the agent always moves; staying put is not an option
                nx, ny = x + dx, y + dy
                if 0 <= nx < size_x and 0 <= ny < size_y:
                    dvec = np.array([dx, dy])
                    # Negative dot product: the more aligned with the
                    # current velocity, the lower (better) the score.
                    momentum_bonus = -np.dot(dvec, self.velocity)
                    score = Z[nx, ny] + momentum_weight * momentum_bonus
                    candidates.append((score, (nx, ny), dvec))

        if not candidates:
            raise ValueError("agent has no neighbouring cell inside the field")

        # min() keeps the first minimal entry, matching a strict '<' scan.
        _, next_pos, dvec = min(candidates, key=lambda item: item[0])
        self.pos = np.array(next_pos)
        self.velocity = dvec
        self.path.append(tuple(self.pos))
# ======================= 可视化模块 ======================= #
class Visualizer:
    """Live matplotlib view of the field with the agent, start and goal marked."""

    def __init__(self, start, goal):
        self.start, self.goal = start, goal
        # Interactive mode lets show() redraw without blocking the loop.
        plt.ion()

    def show(self, Z, agent_pos, t):
        """Redraw the current figure for time ``t``.

        Args:
            Z: 2-D field indexed as ``[x, y]``; it is transposed for display
                so that x runs along the horizontal axis.
            agent_pos: Current agent position; elements 0 and 1 are plotted.
            t: Time value shown in the title.
        """
        plt.clf()
        plt.title(f"Time {t}")
        plt.imshow(Z.T, origin='lower', cmap='viridis')
        plt.colorbar(label='Height')

        # (point, matplotlib format string, legend label), drawn in this order.
        markers = (
            (agent_pos, 'ro', 'Agent'),
            (self.start, 'bs', 'Start'),
            (self.goal, 'g*', 'Goal'),
        )
        for point, fmt, label in markers:
            plt.plot(point[0], point[1], fmt, label=label)

        plt.legend()
        plt.pause(0.1)

    def close(self):
        """Switch interactive mode off and hand the figure to ``plt.show()``."""
        plt.ioff()
        plt.show()
# ======================= 主程序入口 ======================= #
def main():
    """Build the time-varying field, walk the agent through it and plot each step."""

    def start_well_height(t):
        # Well at START: -8 at t=0, rising linearly to 0 at t=20.
        return -8 * (1 - (t - 0) / 20)

    def goal_well_height(t):
        # Well at GOAL: 0 at t=20, deepening linearly to -8 at t=40.
        return -8 * ((t - 20) / 20)

    def corridor_height(t):
        # Constant-depth trench along the START -> GOAL segment.
        return -3

    start_well = TimedGaussianSchedule(
        center=START,
        height_fn=start_well_height,
        sigma=1.5,
        t_start=0,
        t_end=20,
        hold_final=True,
    )
    goal_well = TimedGaussianSchedule(
        center=GOAL,
        height_fn=goal_well_height,
        sigma=1.5,
        t_start=20,
        t_end=40,
        hold_final=True,
    )
    corridor = GaussianPathSchedule(
        start=START,
        end=GOAL,
        height_fn=corridor_height,
        sigma=1.5,
    )

    field_scheduler = FieldScheduler([start_well, goal_well, corridor], base_height=0)
    agent = Agent(START)
    vis = Visualizer(START, GOAL)

    # NOTE: there is no early exit when the agent reaches GOAL; the loop
    # always runs all TOTAL_STEPS + 1 iterations.
    for t in range(TOTAL_STEPS + 1):
        Z = field_scheduler.get_field(t)
        agent.step(Z)
        vis.show(Z, agent.pos, t)

    vis.close()
    print("Agent path:", agent.path)
# Run the simulation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()