185 lines
6.6 KiB
Python
185 lines
6.6 KiB
Python
import numpy as np
|
||
from scipy.spatial.transform import Rotation as R, Slerp
|
||
|
||
class TrajectoryInterpolator:
    """Interpolate a pose trajectory (position + orientation) through timed waypoints.

    Position is interpolated per segment with a quintic polynomial that
    matches position, velocity and acceleration at both segment ends.
    Orientation is interpolated per segment with spherical linear
    interpolation (slerp) between the two waypoint quaternions.

    Each waypoint is a mapping with the keys:

    * ``"time"``: time stamp of the waypoint.
    * ``"position"``: 3 position components.
    * ``"orientation"``: quaternion in the layout accepted by
      ``scipy.spatial.transform.Rotation.from_quat``.
    * ``"velocity"`` / ``"acceleration"`` (optional): 3 components each,
      defaulting to zero when absent.
    """

    def __init__(self, waypoints):
        """Sort the waypoints by time and precompute all per-segment data.

        Args:
            waypoints: iterable of waypoint mappings (see class docstring).
        """
        self.waypoints = sorted(waypoints, key=lambda wp: wp["time"])
        self.times = [wp["time"] for wp in self.waypoints]
        self._compute_position_coeffs()
        self._prepare_slerp()

    @staticmethod
    def _boundary_state(wp):
        """Return ``(position, velocity, acceleration)`` arrays of one waypoint."""
        return (
            np.array(wp["position"]),
            np.array(wp.get("velocity", np.zeros(3))),
            np.array(wp.get("acceleration", np.zeros(3))),
        )

    @staticmethod
    def _eval_poly(coeffs, tau, order=0):
        """Evaluate the ``order``-th derivative of a segment polynomial at ``tau``.

        Args:
            coeffs: array of shape (6, n_dims); row ``k`` holds the
                coefficients of ``tau**k`` (ascending powers).
            tau: time elapsed since the start of the segment.
            order: derivative order (0 = value, 1 = velocity, 2 = acceleration).

        Returns:
            Array with one value per dimension.
        """
        # np.polyval / np.polyder expect the highest power first, hence [::-1].
        return np.array([
            np.polyval(np.polyder(coeffs[::-1, dim], order), tau)
            for dim in range(coeffs.shape[1])
        ])

    def _compute_position_coeffs(self):
        """Solve the quintic coefficients of every segment.

        Fills ``self.coeffs_pos`` with one ``(coeffs, T)`` tuple per segment,
        where ``coeffs`` has shape (6, 3) (one column per dimension,
        ascending powers of ``tau``) and ``T`` is the segment duration.
        """
        self.coeffs_pos = []
        for wp0, wp1 in zip(self.waypoints, self.waypoints[1:]):
            T = wp1["time"] - wp0["time"]

            # Constraint matrix in local time tau in [0, T]: rows are
            # p(0), p'(0), p''(0), p(T), p'(T), p''(T).
            A = np.array([
                [1, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 2, 0, 0, 0],
                [1, T, T**2, T**3, T**4, T**5],
                [0, 1, 2 * T, 3 * T**2, 4 * T**3, 5 * T**4],
                [0, 0, 2, 6 * T, 12 * T**2, 20 * T**3],
            ])

            # Right-hand side, shape (6, 3): start state stacked on end state.
            b = np.vstack([*self._boundary_state(wp0), *self._boundary_state(wp1)])
            self.coeffs_pos.append((np.linalg.solve(A, b), T))

    def _prepare_slerp(self):
        """Build one slerp interpolator per segment, parameterised on [0, 1].

        Fills ``self.rot_slerps``; entry ``i`` maps a normalised time
        ``alpha`` in [0, 1] to a rotation between waypoint ``i`` and ``i + 1``.
        """
        self.rot_slerps = []
        for wp0, wp1 in zip(self.waypoints, self.waypoints[1:]):
            key_rots = R.from_quat(
                np.vstack([wp0["orientation"], wp1["orientation"]])
            )
            self.rot_slerps.append(Slerp([0, 1], key_rots))

    def interpolate(self, t):
        """Evaluate the trajectory at time ``t``.

        Times at or before the first waypoint return the first waypoint,
        times at or after the last waypoint return the last waypoint.

        Args:
            t: query time.

        Returns:
            Tuple ``(x_r, v_r, a_r)``:
            ``x_r`` is position (3) followed by the orientation quaternion (4),
            ``v_r`` is linear velocity (3) followed by angular velocity (3),
            ``a_r`` is linear acceleration (3) followed by angular
            acceleration (3).
        """
        if t <= self.times[0]:
            return self._format_output(self.waypoints[0])
        if t >= self.times[-1]:
            return self._format_output(self.waypoints[-1])

        # Index of the segment containing t, and local time inside it.
        i = np.searchsorted(self.times, t, side='right') - 1
        tau = t - self.times[i]
        coeffs, T = self.coeffs_pos[i]
        alpha = tau / T

        # Position polynomial and its first two derivatives, per dimension.
        pos, vel, acc = (self._eval_poly(coeffs, tau, order) for order in range(3))

        # Orientation via slerp at the normalised segment time.
        ori = self.rot_slerps[i]([alpha])[0].as_quat()

        ang_vel, ang_acc = self._estimate_angular_derivatives(t)

        x_r = np.concatenate([pos, ori])
        v_r = np.concatenate([vel, ang_vel])
        a_r = np.concatenate([acc, ang_acc])
        return x_r, v_r, a_r

    def _format_output(self, wp):
        """Return ``(x_r, v_r, a_r)`` for a waypoint itself.

        Angular velocity and angular acceleration are reported as zero;
        missing linear velocity/acceleration default to zero.
        """
        pos, vel, acc = self._boundary_state(wp)
        ori = np.array(wp["orientation"])
        x_r = np.concatenate([pos, ori])
        v_r = np.concatenate([vel, np.zeros(3)])
        a_r = np.concatenate([acc, np.zeros(3)])
        return x_r, v_r, a_r

    def _quat_diff_to_angular_velocity(self, q1, q0, dt):
        """Return the angular velocity that rotates ``q0`` into ``q1`` in ``dt``.

        Args:
            q1: later orientation quaternion.
            q0: earlier orientation quaternion.
            dt: elapsed time between the two orientations.

        Returns:
            3-vector in rad per unit of ``dt``; zeros when ``dt`` is zero or
            the two orientations coincide.
        """
        dq = R.from_quat(q1) * R.from_quat(q0).inv()
        if dt == 0 or dq.magnitude() == 0:
            return np.zeros(3)
        # The rotation vector already equals axis * angle.
        return dq.as_rotvec() / dt

    def _estimate_angular_derivatives(self, t, delta=0.08):
        """Estimate angular velocity and acceleration by finite differences.

        The orientation is sampled at ``t - delta``, ``t`` and ``t + delta``,
        with the outer samples clamped to the segment that contains ``t``.
        Each one-sided difference is divided by the time span that was
        actually sampled, so clamping near a segment boundary does not
        scale the result.

        Args:
            t: query time.
            delta: nominal half-width of the sampling window.

        Returns:
            Tuple ``(w, a)`` of 3-vectors. ``w`` is the backward-difference
            angular velocity (the forward difference is used when the
            backward span is empty, e.g. exactly on a segment start).
            ``a`` is the change between the two one-sided velocities divided
            by the distance between their midpoints; it is zero when either
            span is empty. Both are zero outside the open trajectory interval.
        """
        if t <= self.times[0] or t >= self.times[-1]:
            return np.zeros(3), np.zeros(3)

        i = np.searchsorted(self.times, t, side='right') - 1
        t0, t1 = self.times[i], self.times[i + 1]
        T = t1 - t0

        # Clamp the sampling window to the current segment.
        t_prev = max(t0, t - delta)
        t_next = min(t1, t + delta)
        dt_prev = t - t_prev
        dt_next = t_next - t

        q_prev, q_now, q_next = self.rot_slerps[i](
            [(t_prev - t0) / T, (t - t0) / T, (t_next - t0) / T]
        ).as_quat()

        w_back = self._quat_diff_to_angular_velocity(q_now, q_prev, dt_prev)
        w_fwd = self._quat_diff_to_angular_velocity(q_next, q_now, dt_next)

        if dt_prev > 0 and dt_next > 0:
            # w_back is centred at t - dt_prev / 2 and w_fwd at t + dt_next / 2,
            # so the two estimates are (dt_prev + dt_next) / 2 apart.
            a = (w_fwd - w_back) / (0.5 * (dt_prev + dt_next))
        else:
            a = np.zeros(3)

        w = w_back if dt_prev > 0 else w_fwd
        return w, a
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Demo: a single 5-second segment between two rest-to-rest waypoints.
    # Units: metres for position, degrees for the Euler angles.
    start = {
        "time": 0.0,
        "position": [0.25, -0.135, 0.3443],
        "velocity": [0, 0, 0],
        "acceleration": [0, 0, 0],
        "orientation": R.from_euler("xyz", [0, 0, 0], degrees=True).as_quat(),
    }
    goal = {
        "time": 5.0,
        "position": [0.30, -0.135, 0.3043],
        "velocity": [0, 0, 0],
        "acceleration": [0, 0, 0],
        "orientation": R.from_euler("xyz", [0, 30, 30], degrees=True).as_quat(),
    }
    waypoints = [start, goal]

    interpolator = TrajectoryInterpolator(waypoints)
    ts = np.linspace(0, 5, 100)

    # Sample the trajectory once, then slice each (x_r, v_r, a_r) triple
    # into its linear and angular parts.
    samples = [interpolator.interpolate(t) for t in ts]

    positions = np.array([x_r[:3] for x_r, _, _ in samples])
    orientations = np.array([
        R.from_quat(x_r[3:]).as_euler('xyz', degrees=True)
        for x_r, _, _ in samples
    ])
    velocities = np.array([v_r[:3] for _, v_r, _ in samples])
    angular_velocities = np.array([v_r[3:] for _, v_r, _ in samples])
    accelerations = np.array([a_r[:3] for _, _, a_r in samples])
    angular_accelerations = np.array([a_r[3:] for _, _, a_r in samples])

    # One subplot per quantity: (data, y-axis label, legend entries).
    panels = [
        (positions, "Position [m]", ['x', 'y', 'z']),
        (velocities, "Velocity [m/s]", ['vx', 'vy', 'vz']),
        (accelerations, "Acceleration [m/s²]", ['ax', 'ay', 'az']),
        (orientations, "Orientation [deg]", ['roll', 'pitch', 'yaw']),
        (angular_velocities, "Ang Vel [rad/s]", ['wx', 'wy', 'wz']),
        (angular_accelerations, "Ang Acc [rad/s²]", ['awx', 'awy', 'awz']),
    ]

    fig, axs = plt.subplots(6, 1, figsize=(10, 16), sharex=True)
    for ax, (series, ylabel, labels) in zip(axs, panels):
        ax.plot(ts, series)
        ax.set_ylabel(ylabel)
        ax.legend(labels)
    # The x axis is shared, so only the bottom panel carries the time label.
    axs[5].set_xlabel("Time [s]")

    plt.tight_layout()
    plt.show()
|