136 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import numpy as np
from scipy.spatial.transform import Rotation as R
from .base_controller import BaseController
from .arm_state import ArmState
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import time
from tools.yaml_operator import read_yaml
import random
class AdmittanceController(BaseController):
    """Cartesian admittance controller that reads and writes an ``ArmState``.

    Each step evaluates the admittance law

        a = M^-1 * ((F_ext - F_des) - D * (v - v_des) - K * e)

    where ``e`` is the 6-D pose error (position error followed by a rotation
    vector) expressed in the frame of the current arm orientation. The
    acceleration is integrated into ``state.arm_desired_twist`` and then into
    ``state.arm_position_command`` / ``state.arm_orientation_command``.
    """

    def __init__(self, name, state: ArmState, config_path) -> None:
        """Create the controller and load its gains.

        Args:
            name: Controller name; must match the ``name`` entry of the config.
            state: Arm state object that ``step``/``step_traj`` read and update.
            config_path: Path of the YAML configuration file.
        """
        super().__init__(name, state)
        self.load_config(config_path)

    def load_config(self, config_path):
        """Load mass, stiffness and damping gains from a YAML file.

        Any negative damping entry is replaced by ``2 * desired_xi * sqrt(k * m)``
        computed from the stiffness and mass of the same axis.

        Raises:
            ValueError: if the config's ``name`` differs from ``self.name``.
        """
        config_dict = read_yaml(config_path)
        if self.name != config_dict['name']:
            raise ValueError(f"Controller name {self.name} does not match config name {config_dict['name']}")
        # dtype=float: a YAML list of integers would otherwise produce an int
        # array, and the auto-computed damping written into it below would be
        # silently truncated.
        mass_tran = np.array(config_dict['mass_tran'], dtype=float)
        mass_rot = np.array(config_dict['mass_rot'], dtype=float)
        stiff_tran = np.array(config_dict['stiff_tran'], dtype=float)
        stiff_rot = np.array(config_dict['stiff_rot'], dtype=float)
        desired_xi = np.array(config_dict['desired_xi'], dtype=float)
        damp_tran = np.array(config_dict['damp_tran'], dtype=float)
        damp_rot = np.array(config_dict['damp_rot'], dtype=float)
        self.pos_scale_factor = config_dict['pos_scale_factor']
        self.rot_scale_factor = config_dict['rot_scale_factor']

        damp_tran = self._fill_auto_damping(damp_tran, stiff_tran, mass_tran, desired_xi)
        damp_rot = self._fill_auto_damping(damp_rot, stiff_rot, mass_rot, desired_xi)

        # Gain matrices are diagonal: translational axes first, then rotational.
        self.M = np.diag(np.concatenate([mass_tran, mass_rot]))
        self.M_inv = np.linalg.inv(self.M)
        self.K = np.diag(np.concatenate([stiff_tran, stiff_rot]))
        self.D = np.diag(np.concatenate([damp_tran, damp_rot]))
        # NOTE(review): "laset" looks like a typo for "last"; the name is kept
        # because code outside this class may read the attribute.
        self.laset_print_time = 0

    @staticmethod
    def _fill_auto_damping(damp, stiff, mass, xi):
        """Return a copy of ``damp`` with negative entries set to ``2*xi*sqrt(stiff*mass)``."""
        damp = damp.copy()
        auto = damp < 0
        damp[auto] = 2 * xi * np.sqrt(stiff[auto] * mass[auto])
        return damp

    def _align_arm_orientation(self):
        """Put the arm quaternion in the desired one's hemisphere; return (Rotation, 3x3 matrix)."""
        # q and -q encode the same rotation; flip the stored arm quaternion so
        # its dot product with the desired quaternion is non-negative.
        if self.state.desired_orientation.dot(self.state.arm_orientation) < 0:
            self.state.arm_orientation = -self.state.arm_orientation
        arm_ori_quat = R.from_quat(self.state.arm_orientation)
        return arm_ori_quat, arm_ori_quat.as_matrix()

    def _integrate_command(self, arm_ori_quat, arm_ori_mat, dt):
        """Clip/integrate ``state.arm_desired_acc`` and write the pose commands to the state."""
        # NOTE(review): clip_command (from BaseController) is called for its
        # side effect and its return value is ignored, so it presumably clips
        # the array in place — confirm.
        self.clip_command(self.state.arm_desired_acc, "acc")
        # Integrate acceleration into the commanded twist; the pose increment
        # below uses this updated twist.
        self.state.arm_desired_twist += self.state.arm_desired_acc * dt
        self.clip_command(self.state.arm_desired_twist, "vel")
        # The linear part of the twist is rotated by the arm orientation
        # matrix before being added to the measured position.
        delta_pose = np.concatenate([
            arm_ori_mat @ (self.state.arm_desired_twist[:3] * dt),
            self.state.arm_desired_twist[3:] * dt,
        ])
        self.clip_command(delta_pose, "pose")
        # Right-multiply: the angular increment is applied in the arm frame.
        arm_ori_quat_new = arm_ori_quat * R.from_rotvec(delta_pose[3:])
        orientation_cmd = arm_ori_quat_new.as_quat()
        self.state.arm_orientation_command = orientation_cmd / np.linalg.norm(orientation_cmd)
        self.state.arm_position_command = self.state.arm_position + delta_pose[:3]

    def step(self, dt):
        """Advance the admittance law by one time step toward a fixed target.

        Args:
            dt: Integration time step (units follow the caller's convention).
        """
        arm_ori_quat, arm_ori_mat = self._align_arm_orientation()
        # Position error, rotated into the current arm frame.
        position_error = self.state.arm_position - self.state.desired_position
        self.state.pose_error[:3] = arm_ori_mat.T @ position_error
        # Orientation error: relative rotation arm -> desired as a rotation vector.
        rot_err_quat = arm_ori_quat.inv() * R.from_quat(self.state.desired_orientation)
        self.state.pose_error[3:] = -rot_err_quat.as_rotvec(degrees=False)
        # Admittance law: desired acceleration from wrench, velocity and pose errors.
        wrench_err = self.state.external_wrench_tcp - self.state.desired_wrench
        D_vel = self.D @ (self.state.arm_desired_twist - self.state.desired_twist)
        K_pose = self.K @ self.state.pose_error
        self.state.arm_desired_acc = self.M_inv @ (wrench_err - D_vel - K_pose)
        self._integrate_command(arm_ori_quat, arm_ori_mat, dt)

    def step_traj(self, dt):
        """Advance the admittance law by one time step while tracking a moving reference.

        Like ``step``, but the position, orientation and velocity errors are
        shifted by one ``dt`` of ``state.desired_twist`` / ``state.desired_acc``,
        and ``state.desired_acc`` is added as a feed-forward acceleration.

        Args:
            dt: Integration time step (units follow the caller's convention).
        """
        arm_ori_quat, arm_ori_mat = self._align_arm_orientation()
        # NOTE(review): a one-step-ahead reference would usually be
        # desired_position + desired_twist[:3] * dt, giving an error of
        # arm_position - desired_position - desired_twist[:3] * dt. The code
        # uses "+" here and in the damping term below; kept unchanged — confirm.
        position_error = (self.state.arm_position - self.state.desired_position
                          + self.state.desired_twist[:3] * dt)
        self.state.pose_error[:3] = arm_ori_mat.T @ position_error
        # Rotation produced by the desired angular velocity over one step
        # (rotation vector = angular velocity * dt).
        step_rot = R.from_rotvec(np.asarray(self.state.desired_twist[3:]) * dt)
        rot_err_quat = step_rot.inv() * arm_ori_quat.inv() * R.from_quat(self.state.desired_orientation)
        self.state.pose_error[3:] = -rot_err_quat.as_rotvec(degrees=False)
        # Admittance law plus feed-forward reference acceleration.
        wrench_err = self.state.external_wrench_tcp - self.state.desired_wrench
        D_vel = self.D @ (self.state.arm_desired_twist - self.state.desired_twist + self.state.desired_acc * dt)
        K_pose = self.K @ self.state.pose_error
        self.state.arm_desired_acc = self.M_inv @ (wrench_err - D_vel - K_pose) + self.state.desired_acc
        self._integrate_command(arm_ori_quat, arm_ori_mat, dt)
if __name__ == "__main__":
    # Manual smoke test: build the controller from a config file and print
    # its gains. An alternative config path may be passed as the first
    # command-line argument; otherwise the original default is used.
    default_config = "/home/zyc/admittance_control/MassageControl/config/admittance.yaml"
    config_path = sys.argv[1] if len(sys.argv) > 1 else default_config
    state = ArmState()
    controller = AdmittanceController("admittance", state, config_path)
    print(controller.name)
    print(controller.state.arm_position)
    # Presumably the controller holds a reference to `state` (passed to
    # BaseController), so this change should be visible via controller.state.
    state.arm_position = np.array([1, 2, 3])
    print(controller.state.arm_position)
    print(controller.M)
    print(controller.D)
    print(controller.K)