185 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from scipy.spatial.transform import Rotation as R, Slerp
class TrajectoryInterpolator:
    """Interpolate position and orientation through a list of timed waypoints.

    Each waypoint is a dict with keys:

    * ``"time"`` -- scalar time stamp; waypoints are sorted by it.
    * ``"position"`` -- 3-vector.
    * ``"orientation"`` -- quaternion accepted by ``Rotation.from_quat``
      (SciPy's default convention is scalar-last ``(x, y, z, w)``).
    * ``"velocity"`` / ``"acceleration"`` -- optional 3-vectors, default zero.

    Position on each segment follows a quintic polynomial that matches the
    position, velocity and acceleration at both segment ends. Orientation
    follows a SLERP whose parameter grows linearly with time, so within a
    segment the angular velocity is constant and the angular acceleration is
    zero.
    """

    def __init__(self, waypoints):
        """Sort *waypoints* by time and precompute the per-segment interpolants.

        Raises:
            ValueError: if *waypoints* is empty, or if two waypoints share the
                same time (a zero-length segment makes the quintic boundary
                system singular and the SLERP parameter undefined).
        """
        self.waypoints = sorted(waypoints, key=lambda x: x["time"])
        self.times = [wp["time"] for wp in self.waypoints]
        if not self.times:
            raise ValueError("at least one waypoint is required")
        # After sorting, only equal neighbouring times can violate this.
        if any(t1 <= t0 for t0, t1 in zip(self.times, self.times[1:])):
            raise ValueError("waypoint times must be distinct")
        self._compute_position_coeffs()
        self._prepare_slerp()

    def _compute_position_coeffs(self):
        """Fit a quintic polynomial to the position on every segment.

        With ``tau = t - t0`` in ``[0, T]``, the position on a segment is
        ``sum(coeffs[k] * tau**k for k in range(6))``. For each segment, one
        ``(coeffs, T)`` pair is appended to ``self.coeffs_pos``. ``coeffs`` has
        shape (6, 3): rows are in ascending power of ``tau`` and there is one
        column per axis.
        """
        self.coeffs_pos = []
        for wp0, wp1 in zip(self.waypoints, self.waypoints[1:]):
            T = wp1["time"] - wp0["time"]
            # Rows 0-2 constrain p, p', p'' at tau = 0; rows 3-5 at tau = T.
            A = np.array([
                [1, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0],
                [0, 0, 2, 0, 0, 0],
                [1, T, T**2, T**3, T**4, T**5],
                [0, 1, 2*T, 3*T**2, 4*T**3, 5*T**4],
                [0, 0, 2, 6*T, 12*T**2, 20*T**3],
            ], dtype=float)
            b = np.vstack([
                np.asarray(wp0["position"], dtype=float),
                np.asarray(wp0.get("velocity", np.zeros(3)), dtype=float),
                np.asarray(wp0.get("acceleration", np.zeros(3)), dtype=float),
                np.asarray(wp1["position"], dtype=float),
                np.asarray(wp1.get("velocity", np.zeros(3)), dtype=float),
                np.asarray(wp1.get("acceleration", np.zeros(3)), dtype=float),
            ])  # shape (6, 3): one right-hand side per axis
            coeffs = np.linalg.solve(A, b)
            self.coeffs_pos.append((coeffs, T))

    def _prepare_slerp(self):
        """Build one SLERP per segment and that segment's angular velocity.

        ``self.rot_slerps[i]`` maps a normalized parameter in [0, 1] to an
        orientation on segment ``i``. ``self.ang_vels[i]`` is the angular
        velocity (rad/s) when that parameter advances uniformly over the
        segment's duration. It is expressed in the frame of
        ``rot1 * rot0.inv()``, the fixed frame.
        """
        self.rot_slerps = []
        self.ang_vels = []
        for i in range(len(self.waypoints) - 1):
            rot0 = R.from_quat(self.waypoints[i]["orientation"])
            rot1 = R.from_quat(self.waypoints[i + 1]["orientation"])
            self.rot_slerps.append(Slerp([0, 1], R.concatenate([rot0, rot1])))
            # SLERP sweeps the relative rotation rot1 * rot0^-1 at a constant
            # rate, so its rotation vector divided by the duration is the exact
            # (constant) angular velocity over the whole segment.
            self.ang_vels.append(self._quat_diff_to_angular_velocity(
                rot1.as_quat(), rot0.as_quat(),
                self.times[i + 1] - self.times[i]))

    def interpolate(self, t):
        """Return the reference state ``(x_r, v_r, a_r)`` at time *t*.

        * ``x_r``: position (3) followed by the orientation quaternion (4).
        * ``v_r``: linear velocity (3) followed by angular velocity (3).
        * ``a_r``: linear acceleration (3) followed by angular acceleration (3).

        Times at or before the first waypoint, or at or after the last one,
        return that end waypoint's own state (see ``_format_output``).
        """
        if t <= self.times[0]:
            return self._format_output(self.waypoints[0])
        if t >= self.times[-1]:
            return self._format_output(self.waypoints[-1])

        i = int(np.searchsorted(self.times, t, side='right')) - 1
        coeffs, T = self.coeffs_pos[i]
        tau = t - self.times[i]

        # Evaluate all three axes at once; coeffs are in ascending power order.
        poly = np.polynomial.polynomial
        pos = poly.polyval(tau, coeffs)
        vel = poly.polyval(tau, poly.polyder(coeffs, 1))
        acc = poly.polyval(tau, poly.polyder(coeffs, 2))

        # Orientation: SLERP with a normalized parameter linear in time.
        ori = self.rot_slerps[i]([tau / T])[0].as_quat()
        ang_vel, ang_acc = self._estimate_angular_derivatives(t)

        x_r = np.concatenate([pos, ori])
        v_r = np.concatenate([vel, ang_vel])
        a_r = np.concatenate([acc, ang_acc])
        return x_r, v_r, a_r

    def _format_output(self, wp):
        """Pack a single waypoint into the ``(x_r, v_r, a_r)`` layout.

        Angular velocity and acceleration are reported as zero. The quaternion
        is passed through ``Rotation`` so that it is unit-norm, like the SLERP
        output returned inside segments.
        """
        pos = np.asarray(wp["position"], dtype=float)
        ori = R.from_quat(wp["orientation"]).as_quat()
        vel = np.asarray(wp.get("velocity", np.zeros(3)), dtype=float)
        acc = np.asarray(wp.get("acceleration", np.zeros(3)), dtype=float)
        x_r = np.concatenate([pos, ori])
        v_r = np.concatenate([vel, np.zeros(3)])
        a_r = np.concatenate([acc, np.zeros(3)])
        return x_r, v_r, a_r

    def _quat_diff_to_angular_velocity(self, q1, q0, dt):
        """Return the angular velocity (rad/s) that carries q0 to q1 in *dt*.

        The relative rotation ``q1 * q0^-1`` is converted to a rotation vector
        (unit axis scaled by the angle) and divided by *dt*. Returns zeros when
        *dt* is 0 or the two orientations coincide.
        """
        dq = R.from_quat(q1) * R.from_quat(q0).inv()
        if dt == 0 or dq.magnitude() == 0:
            return np.zeros(3)
        return dq.as_rotvec() / dt

    def _estimate_angular_derivatives(self, t, delta=0.08):
        """Return ``(angular_velocity, angular_acceleration)`` at time *t*.

        Within a segment, the SLERP parameter is linear in time. The exact
        derivatives are therefore the segment's precomputed constant angular
        velocity and zero angular acceleration.

        Finite differences are deliberately not used here. Clamping the
        samples to the segment without shrinking the step under-reports the
        velocity near every waypoint.

        *delta* is accepted only for backward compatibility and is unused.
        Outside ``(times[0], times[-1])`` both returned vectors are zero.
        """
        if t <= self.times[0] or t >= self.times[-1]:
            return np.zeros(3), np.zeros(3)
        i = int(np.searchsorted(self.times, t, side='right')) - 1
        return self.ang_vels[i].copy(), np.zeros(3)
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Demo: a single 5 s segment. Positions are in metres; orientations are
    # built from XYZ Euler angles given in degrees.
    waypoints = [
        {"time": 0.0, "position": [0.25, -0.135, 0.3443], "velocity": [0, 0, 0],
         "acceleration": [0, 0, 0],
         "orientation": R.from_euler("xyz", [0, 0, 0], degrees=True).as_quat()},
        {"time": 5.0, "position": [0.30, -0.135, 0.3043], "velocity": [0, 0, 0],
         "acceleration": [0, 0, 0],
         "orientation": R.from_euler("xyz", [0, 30, 30], degrees=True).as_quat()},
    ]
    interpolator = TrajectoryInterpolator(waypoints)
    ts = np.linspace(0, 5, 100)

    # Sample the trajectory once, then slice each quantity out of the samples.
    samples = [interpolator.interpolate(t) for t in ts]
    positions = np.array([pose[:3] for pose, _, _ in samples])
    orientations = np.array(
        [R.from_quat(pose[3:]).as_euler('xyz', degrees=True) for pose, _, _ in samples]
    )
    velocities = np.array([twist[:3] for _, twist, _ in samples])
    angular_velocities = np.array([twist[3:] for _, twist, _ in samples])
    accelerations = np.array([accel[:3] for _, _, accel in samples])
    angular_accelerations = np.array([accel[3:] for _, _, accel in samples])

    # (data, y-axis label, legend entries) for each stacked subplot, top to bottom.
    panels = [
        (positions, "Position [m]", ['x', 'y', 'z']),
        (velocities, "Velocity [m/s]", ['vx', 'vy', 'vz']),
        (accelerations, "Acceleration [m/s²]", ['ax', 'ay', 'az']),
        (orientations, "Orientation [deg]", ['roll', 'pitch', 'yaw']),
        (angular_velocities, "Ang Vel [rad/s]", ['wx', 'wy', 'wz']),
        (angular_accelerations, "Ang Acc [rad/s²]", ['awx', 'awy', 'awz']),
    ]
    fig, axs = plt.subplots(len(panels), 1, figsize=(10, 16), sharex=True)
    for ax, (data, ylabel, labels) in zip(axs, panels):
        ax.plot(ts, data)
        ax.set_ylabel(ylabel)
        ax.legend(labels)
    axs[-1].set_xlabel("Time [s]")
    plt.tight_layout()
    plt.show()