128 lines
6.1 KiB
Python

import sys
import numpy as np
from scipy.spatial.transform import Rotation as R
from .base_controller import BaseController
from .arm_state import ArmState
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import time
from tools.yaml_operator import read_yaml
import random
class AdmittanceController(BaseController):
    """Cartesian admittance controller driven by a shared ``ArmState``.

    Each control step solves the admittance law

        M * a = (F_ext - F_des) - D * (v - v_des) - K * e

    for the commanded acceleration ``a``, integrates it into a commanded twist
    and a pose increment, and writes ``arm_position_command`` /
    ``arm_orientation_command`` back into ``self.state``. ``M``, ``D`` and ``K``
    are 6x6 diagonal matrices (3 translational + 3 rotational axes) loaded from
    a YAML config in :meth:`load_config`.
    """

    def __init__(self, name, state: ArmState, config_path) -> None:
        """Create the controller and load its gains.

        Args:
            name: Controller name; must equal the ``name`` key of the config.
            state: Arm state object read and written by :meth:`step`.
            config_path: Path to the YAML gain file passed to ``read_yaml``.
        """
        super().__init__(name, state)
        self.load_config(config_path)

    def load_config(self, config_path):
        """Load mass/stiffness/damping gains from a YAML file.

        Expected keys: ``name``, ``mass_tran``, ``mass_rot``, ``stiff_tran``,
        ``stiff_rot``, ``desired_xi``, ``damp_tran``, ``damp_rot``,
        ``pos_scale_factor``, ``rot_scale_factor``. Any negative damping entry
        is replaced by ``2 * desired_xi * sqrt(stiffness * mass)`` for that axis.
        ``desired_xi`` may be a scalar or one value per axis.

        Raises:
            ValueError: if the config ``name`` differs from ``self.name``.
        """
        config_dict = read_yaml(config_path)
        if self.name != config_dict['name']:
            raise ValueError(f"Controller name {self.name} does not match config name {config_dict['name']}")
        # Force float dtype: an all-integer YAML list would otherwise create an
        # int array and silently truncate the auto-computed damping below.
        mass_tran = np.array(config_dict['mass_tran'], dtype=float)
        mass_rot = np.array(config_dict['mass_rot'], dtype=float)
        stiff_tran = np.array(config_dict['stiff_tran'], dtype=float)
        stiff_rot = np.array(config_dict['stiff_rot'], dtype=float)
        desired_xi = np.array(config_dict['desired_xi'], dtype=float)
        damp_tran = np.array(config_dict['damp_tran'], dtype=float)
        damp_rot = np.array(config_dict['damp_rot'], dtype=float)
        # Stored only; not used by the code in this class.
        self.pos_scale_factor = config_dict['pos_scale_factor']
        self.rot_scale_factor = config_dict['rot_scale_factor']

        damp_tran = self._fill_auto_damping(damp_tran, stiff_tran, mass_tran, desired_xi)
        damp_rot = self._fill_auto_damping(damp_rot, stiff_rot, mass_rot, desired_xi)

        self.M = np.diag(np.concatenate([mass_tran, mass_rot]))
        self.M_inv = np.linalg.inv(self.M)
        self.K = np.diag(np.concatenate([stiff_tran, stiff_rot]))
        self.D = np.diag(np.concatenate([damp_tran, damp_rot]))
        # NOTE(review): misspelling of "last_print_time"; name kept because
        # code outside this class may reference it.
        self.laset_print_time = 0

    @staticmethod
    def _fill_auto_damping(damp, stiff, mass, xi):
        """Return a copy of ``damp`` with negative entries set to 2*xi*sqrt(k*m)."""
        damp = np.array(damp, dtype=float)
        auto = damp < 0
        # Broadcast so that both a scalar xi and a per-axis xi are accepted.
        xi_per_axis = np.broadcast_to(xi, damp.shape)
        damp[auto] = 2 * xi_per_axis[auto] * np.sqrt(stiff[auto] * mass[auto])
        return damp

    def _align_quaternion_hemisphere(self):
        """Negate the measured quaternion if it points away from the desired one.

        ``q`` and ``-q`` describe the same rotation; keeping both in the same
        hemisphere (non-negative dot product) keeps the sign of the
        orientation quaternion consistent between steps.
        """
        if self.state.desired_orientation.dot(self.state.arm_orientation) < 0:
            self.state.arm_orientation = -self.state.arm_orientation

    def _integrate_and_command(self, arm_ori_quat, arm_ori_mat, dt):
        """Integrate ``arm_desired_acc`` over ``dt`` and write the pose commands.

        Clips the acceleration, twist and pose increment via ``clip_command``
        (presumably in place, since its return value is ignored — defined in
        BaseController). The linear part of the twist is rotated into the base
        frame by ``arm_ori_mat``. The angular part is applied on the right of
        the current orientation.
        """
        self.clip_command(self.state.arm_desired_acc, "acc")
        # Update twist, then derive the pose increment from it.
        self.state.arm_desired_twist += self.state.arm_desired_acc * dt
        self.clip_command(self.state.arm_desired_twist, "vel")
        delta_pose = np.concatenate([
            arm_ori_mat @ (self.state.arm_desired_twist[:3] * dt),
            self.state.arm_desired_twist[3:] * dt,
        ])
        self.clip_command(delta_pose, "pose")
        # Compose the orientation increment directly as a Rotation.
        new_ori_quat = (arm_ori_quat * R.from_rotvec(delta_pose[3:])).as_quat()
        # Renormalise to guard against accumulated floating-point drift.
        self.state.arm_orientation_command = new_ori_quat / np.linalg.norm(new_ori_quat)
        self.state.arm_position_command = self.state.arm_position + delta_pose[:3]

    def step(self, dt):
        """Run one admittance step toward a set-point (desired_position/orientation).

        Args:
            dt: Control period used to integrate acceleration and twist.
        """
        self._align_quaternion_hemisphere()
        arm_ori_quat = R.from_quat(self.state.arm_orientation)
        arm_ori_mat = arm_ori_quat.as_matrix()

        # Position error (arm - desired), rotated into the current arm frame.
        pos_err_world = self.state.arm_position - self.state.desired_position
        self.state.pose_error[:3] = arm_ori_mat.T @ pos_err_world
        # Orientation error: -rotvec(R_arm^-1 * R_des), i.e. rotvec(R_des^-1 * R_arm).
        rot_err = arm_ori_quat.inv() * R.from_quat(self.state.desired_orientation)
        self.state.pose_error[3:] = -rot_err.as_rotvec(degrees=False)

        # Admittance law: a = M^-1 * ((F_ext - F_des) - D*(v - v_des) - K*e).
        wrench_err = self.state.external_wrench_tcp - self.state.desired_wrench
        damping_term = self.D @ (self.state.arm_desired_twist - self.state.desired_twist)
        stiffness_term = self.K @ self.state.pose_error
        self.state.arm_desired_acc = self.M_inv @ (wrench_err - damping_term - stiffness_term)

        self._integrate_and_command(arm_ori_quat, arm_ori_mat, dt)

    def step_traj(self, dt):
        """Run one admittance step while tracking a trajectory.

        Like :meth:`step`, but it adds ``desired_twist * dt`` and
        ``desired_acc`` feed-forward terms so that the reference moves along a
        trajectory.

        Args:
            dt: Control period used to integrate acceleration and twist.
        """
        self._align_quaternion_hemisphere()
        arm_ori_quat = R.from_quat(self.state.arm_orientation)
        arm_ori_mat = arm_ori_quat.as_matrix()

        # Position error including the desired-velocity feed-forward term.
        # NOTE(review): sign convention of "+ desired_twist*dt" kept as in the
        # original; confirm against the trajectory generator.
        pos_err_world = (self.state.arm_position - self.state.desired_position
                         + self.state.desired_twist[:3] * dt)
        self.state.pose_error[:3] = arm_ori_mat.T @ pos_err_world
        # desired_twist[3:] is a 3-component angular velocity, so the per-step
        # rotation must be built with from_rotvec. The previous from_quat call
        # needs 4 components and raised ValueError.
        step_rot = R.from_rotvec(self.state.desired_twist[3:] * dt)
        rot_err = step_rot.inv() * arm_ori_quat.inv() * R.from_quat(self.state.desired_orientation)
        self.state.pose_error[3:] = -rot_err.as_rotvec(degrees=False)

        # Admittance law plus desired-acceleration feed-forward.
        wrench_err = self.state.external_wrench_tcp - self.state.desired_wrench
        damping_term = self.D @ (self.state.arm_desired_twist - self.state.desired_twist
                                 + self.state.desired_acc * dt)
        stiffness_term = self.K @ self.state.pose_error
        self.state.arm_desired_acc = (self.M_inv @ (wrench_err - damping_term - stiffness_term)
                                      + self.state.desired_acc)

        self._integrate_and_command(arm_ori_quat, arm_ori_mat, dt)
if __name__ == "__main__":
    # Manual smoke test: build the controller from a YAML config and print the
    # loaded gain matrices. arm_position is printed before and after mutating
    # `state` to check whether controller.state is the same ArmState object
    # (that depends on BaseController.__init__, not visible here).
    # The config path may be given as the first CLI argument. The default is
    # the original developer's machine-specific path.
    default_config = "/home/zyc/admittance_control/MassageControl/config/admittance.yaml"
    config_path = sys.argv[1] if len(sys.argv) > 1 else default_config
    state = ArmState()
    controller = AdmittanceController("admittance", state, config_path)
    print(controller.name)
    print(controller.state.arm_position)
    state.arm_position = np.array([1, 2, 3])
    print(controller.state.arm_position)
    print(controller.M)
    print(controller.D)
    print(controller.K)