from hardware.dobot_nova5 import dobot_nova5
from hardware.force_sensor import XjcSensor

from algorithms.arm_state import ArmState
from algorithms.controller_manager import ControllerManager
from algorithms.admittance_controller import AdmittanceController
from algorithms.position_controller import PositionController

from tools.log import CustomLogger
from tools.yaml_operator import read_yaml
from tools.Rate import Rate
import numpy as np
import threading
import time
import os
from scipy.spatial.transform import Rotation as R
class MassageRobot:
    """Nova5 massage-robot driver.

    Wires together the Dobot Nova5 arm, the XJC force/torque sensor, a
    ControllerManager holding an admittance and a position controller, and
    per-massage-head compensation parameters loaded from YAML files. Runs a
    measurement thread (pose + compensated wrench) and, optionally, a
    command thread that steps the active controller and servos the arm.
    """

    def __init__(self, arm_config_path, is_log=False):
        """Load the config and create hardware, controller and filter state.

        Args:
            arm_config_path: path to the robot YAML config. Keys read here:
                'arm_ip', 'controller' (list: [0] admittance, [1] position),
                'massage_head_dir', 'control_rate', 'sensor_rate',
                'command_rate'.
            is_log: when True, also start a daemon thread that writes
                kinematic/force data to a separate log file once per second.
        """
        self.logger = CustomLogger(
            log_name="测试日志",
            log_file="logs/MassageRobot_nova5_test.log",
            precise_time=True)
        if is_log:
            self.logger_data = CustomLogger(log_name="运动数据日志",
                                            log_file="logs/MassageRobot_kinematics_data.log",
                                            precise_time=True)

        self.arm_state = ArmState()
        self.arm_config = read_yaml(arm_config_path)
        # The arm driver opens its connection when it is instantiated.
        self.arm = dobot_nova5(arm_ip=self.arm_config['arm_ip'])
        self.force_sensor = XjcSensor(host=self.arm_config['arm_ip'], port=60000)

        # Controllers; admittance control is the initially active one.
        self.controller_manager = ControllerManager(self.arm_state)
        self.controller_manager.add_controller(AdmittanceController, 'admittance', self.arm_config['controller'][0])
        self.controller_manager.add_controller(PositionController, 'position', self.arm_config['controller'][1])
        self.controller_manager.switch_controller('admittance')

        # Massage-head parameters are (for now) read from local YAML files,
        # one regular file per head, keyed by each file's 'name' field.
        massage_head_dir = self.arm_config['massage_head_dir']
        self.playload_dict = {}
        for file in os.listdir(massage_head_dir):
            file_address = os.path.join(massage_head_dir, file)
            if not os.path.isfile(file_address):
                continue
            play_load = read_yaml(file_address)
            self.playload_dict[play_load['name']] = play_load
        self.current_head = 'none'

        # Loop rates.
        self.control_rate = Rate(self.arm_config['control_rate'])
        self.sensor_rate = Rate(self.arm_config['sensor_rate'])
        self.command_rate = Rate(self.arm_config['command_rate'])

        # Low-pass filter cutoff frequency (presumably Hz — confirm; not used
        # elsewhere in this class).
        self.cutoff_freq = 80.0

        # Thread-control events. exit_event SET means "not running": the
        # worker loops run while it is clear, and start() clears it.
        self.exit_event = threading.Event()
        self.exit_event.set()
        self.interrupt_event = threading.Event()
        self.interrupt_event.clear()  # no interrupt initially
        self.pause_envent = threading.Event()
        self.pause_envent.clear()  # not paused initially
        self.skip_envent = threading.Event()
        self.skip_envent.clear()  # no skip initially
        # Massage adjustment (position increment / wrench adjustment).
        self.adjust_wrench_event = threading.Event()
        self.adjust_wrench_event.clear()  # no adjustment initially
        self.pos_increment = np.zeros(3, dtype=np.float64)
        self.adjust_wrench = np.zeros(6, dtype=np.float64)

        # While True, both worker loops skip their work but keep ticking.
        self.is_waitting = False

        self.last_print_time = 0
        self.last_record_time = 0
        self.last_command_time = 0
        self.move_to_point_count = 0

        # Kalman filter state for the base-frame wrench.
        self.x_base = np.zeros(6)
        self.P_base = np.eye(6)
        self.Q_base = np.eye(6) * 0.01  # process-noise covariance
        self.R_base = np.eye(6) * 0.1   # measurement-noise covariance

        # Kalman filter state for the TCP-frame wrench.
        self.x_tcp = np.zeros(6)
        self.P_tcp = np.eye(6)
        self.Q_tcp = np.eye(6) * 0.01  # process-noise covariance
        self.R_tcp = np.eye(6) * 0.1   # measurement-noise covariance

        # Consecutive force-sensor read failures (reset on success).
        self.sensor_fail_count = 0

        # Start the data-logging thread only now: log_thread reads arm_state
        # and current_head, which do not exist at the top of __init__, so
        # starting it earlier races with their creation.
        if is_log:
            threading.Thread(target=self.log_thread, daemon=True).start()

        # Arm initialisation (adapter for the middleware layer).
        self.arm.init()

    def kalman_predict(self, x, P, Q):
        """Kalman prediction step with an identity state-transition model.

        The state is assumed constant between steps, so only the error
        covariance grows by the process noise Q.

        Returns:
            (x_predict, P_predict)
        """
        x_predict = x
        P_predict = P + Q
        return x_predict, P_predict

    def kalman_update(self, x_predict, P_predict, z, R):
        """Kalman update step for a direct (identity) measurement of the state.

        Args:
            x_predict, P_predict: predicted state and covariance.
            z: measurement vector.
            R: measurement-noise covariance (shadows the module-level
               scipy ``R`` only inside this method).

        Returns:
            (x_update, P_update)
        """
        # Kalman gain.
        K = P_predict @ np.linalg.inv(P_predict + R)
        x_update = x_predict + K @ (z - x_predict)
        P_update = (np.eye(len(K)) - K) @ P_predict
        return x_update, P_update

    def update_wrench(self):
        """Read the force sensor and publish the compensated external wrench.

        Removes zero offsets and payload gravity (force and its moment) from
        the raw sensor reading, transforms the result to the TCP frame using
        the head's ``tcp_offset``, rotates it into the base frame using the
        current arm orientation, Kalman-filters both, and stores them in
        ``arm_state``.

        Returns:
            0 on success, -1 if the sensor returned no data.
        """
        compensation_config = self.playload_dict[self.current_head]
        gravity_base = np.array(compensation_config['gravity_base'])
        force_zero = np.array(compensation_config['force_zero'])
        torque_zero = np.array(compensation_config['torque_zero'])
        tool_position = np.array(compensation_config['tcp_offset'])
        mass_center_position = np.array(compensation_config['mass_center_position'])
        # Rotation built from the arm's current orientation quaternion.
        b_rotation_s = R.from_quat(self.arm_state.arm_orientation).as_matrix()

        sensor_data = self.force_sensor.read()
        if sensor_data is None:
            # NOTE(review): this stops background reading on a failed read,
            # while stop() *starts* it — looks inverted; confirm against
            # XjcSensor.
            self.force_sensor.stop_background_reading()
            self.logger.log_error("传感器数据读取失败")
            return -1

        # External wrench in the sensor frame: remove zero offsets and the
        # payload's gravity (rotated into the sensor frame) plus its moment.
        gravity_in_sensor = b_rotation_s.T @ gravity_base
        s_force = sensor_data[:3] - force_zero - gravity_in_sensor
        s_torque = sensor_data[3:] - torque_zero - np.cross(mass_center_position, gravity_in_sensor)
        wrench = np.concatenate([s_force, s_torque])

        # Sensor -> TCP transform from tcp_offset = [x, y, z, rx, ry, rz].
        s_tf_matrix_t = self.get_tf_matrix(
            tool_position[:3],
            R.from_euler('xyz', tool_position[3:], degrees=False).as_quat())
        wrench = self.wrench_coordinate_conversion(s_tf_matrix_t, wrench)

        # Publish to ArmState (TCP frame, and rotated into the base frame).
        self.arm_state.external_wrench_tcp = wrench
        self.arm_state.external_wrench_base = np.concatenate([b_rotation_s @ self.arm_state.external_wrench_tcp[:3],
                                                              b_rotation_s @ self.arm_state.external_wrench_tcp[3:]])

        # Kalman-smooth the base-frame wrench.
        x_base_predict, P_base_predict = self.kalman_predict(x=self.x_base,
                                                             P=self.P_base,
                                                             Q=self.Q_base)
        self.x_base, self.P_base = self.kalman_update(x_predict=x_base_predict,
                                                      P_predict=P_base_predict,
                                                      z=self.arm_state.external_wrench_base,
                                                      R=self.R_base)
        self.arm_state.external_wrench_base = self.x_base
        # NOTE(review): assigned after the current value, so "last" always
        # equals the current filtered value — confirm whether the previous
        # sample was intended.
        self.arm_state.last_external_wrench_base = self.arm_state.external_wrench_base

        # Kalman-smooth the TCP-frame wrench.
        x_tcp_predict, P_tcp_predict = self.kalman_predict(x=self.x_tcp,
                                                           P=self.P_tcp,
                                                           Q=self.Q_tcp)
        self.x_tcp, self.P_tcp = self.kalman_update(x_predict=x_tcp_predict,
                                                    P_predict=P_tcp_predict,
                                                    z=self.arm_state.external_wrench_tcp,
                                                    R=self.R_tcp)
        self.arm_state.external_wrench_tcp = self.x_tcp
        self.arm_state.last_external_wrench_tcp = self.arm_state.external_wrench_tcp
        return 0

    def arm_measure_loop(self):
        """Measurement thread body: read the arm pose and update the wrench.

        Runs at ``sensor_rate`` until the arm reports exit or ``exit_event``
        is set. More than 10 consecutive sensor failures call stop() and end
        the loop; any exception is logged and sets ``exit_event``.
        """
        self.logger.log_info("机械臂测量线程启动")
        while (not self.arm.is_exit) and (not self.exit_event.is_set()):
            try:
                if not self.is_waitting:
                    # Unpack on ONE line: splitting after the comma turned the
                    # first target into a no-op expression and assigned the
                    # whole (position, orientation) tuple to arm_orientation.
                    self.arm_state.arm_position, self.arm_state.arm_orientation = self.arm.get_arm_position()
                    code = self.update_wrench()
                    if code == -1:
                        self.sensor_fail_count += 1
                        self.logger.log_error(f"传感器线程数据读取失败-{self.sensor_fail_count}")
                        if self.sensor_fail_count > 10:
                            self.logger.log_error("传感器线程数据读取失败超过10次,程序终止")
                            self.stop()
                            break
                    else:
                        self.sensor_fail_count = 0
            except Exception as e:
                self.logger.log_error(f"机械臂测量线程报错:{e}")
                self.exit_event.set()  # make the while condition fail
            self.sensor_rate.precise_sleep()

    def arm_command_loop(self):
        """Control thread body: step the active controller and servo the arm.

        Runs at ``control_rate`` until the arm reports exit or ``exit_event``
        is set. A ServoPose return code of -1 ends the loop (the driver
        handles the emergency stop); any exception sets ``exit_event``.
        """
        self.logger.log_info("机械臂控制线程启动")
        while (not self.arm.is_exit) and (not self.exit_event.is_set()):
            try:
                if not self.is_waitting:
                    self.controller_manager.step(self.control_rate.to_sec())
                    self.last_command_time += 1
                    # Send the *commanded* pose (not the measured position).
                    code = self.arm.ServoPose(self.arm_state.arm_position_command,
                                              self.arm_state.arm_orientation_command)
                    if code == -1:
                        self.logger.log_error("机械臂急停")
                        # The driver already performs the emergency stop.
                        break
            except Exception as e:
                self.logger.log_error(f"机械臂控制失败:{e}")
                self.exit_event.set()
            self.control_rate.precise_sleep()

    def start(self):
        """Start the measurement thread and seed target/command poses.

        The desired and commanded pose are initialised to the arm's current
        pose. No-op when already running.
        """
        if self.exit_event.is_set():
            self.exit_event.clear()
            self.arm_measure_thread = threading.Thread(target=self.arm_measure_loop)
            # NOTE(review): the control thread is created but not started, so
            # arm_command_loop never runs; stop() tolerates this. Confirm
            # whether it should be started after the pose is seeded below.
            self.arm_control_thread = threading.Thread(target=self.arm_command_loop)
            self.arm_measure_thread.start()
            position, quat_rot = self.arm.get_arm_position()
            # Initial targets: hold the current pose.
            self.arm_state.desired_position = position
            self.arm_state.arm_position_command = position
            self.arm_state.desired_orientation = quat_rot
            self.arm_state.arm_orientation_command = quat_rot
            self.logger.log_info("MassageRobot启动")
            time.sleep(0.5)

    def stop(self):
        """Stop the worker threads, the force-sensor stream and arm motion.

        Safe to call from inside arm_measure_loop (the calling thread is not
        joined) and when the control thread was never started. No-op when
        already stopped.
        """
        if not self.exit_event.is_set():
            self.exit_event.set()
            self.interrupt_event.clear()
            current = threading.current_thread()
            for worker in (self.arm_control_thread, self.arm_measure_thread):
                # Thread.join raises RuntimeError for a never-started thread
                # or for the current thread; is_alive() is False for the
                # former and the identity check excludes the latter.
                if worker is not current and worker.is_alive():
                    worker.join()
            # Repeated on purpose (presumably to make sure the sensor sees it).
            for _ in range(3):
                self.force_sensor.disable_active_transmission()
            self.force_sensor.start_background_reading()
            self.arm.stop_motion()
            self.logger.log_info("MassageRobot停止")

    def init_hardwares(self, ready_pose):
        """Store the ready pose, apply the current head and set the target pose.

        Args:
            ready_pose: [x, y, z, rx, ry, rz] with 'xyz' Euler angles in
                radians.
        """
        self.ready_pose = np.array(ready_pose)
        self.switch_payload(self.current_head)
        # NOTE(review): this overrides the TCP-offset-adjusted target that
        # switch_payload just computed — confirm that is intended.
        self.arm_state.desired_position = self.ready_pose[:3]
        euler_angles = self.ready_pose[3:]
        self.arm_state.desired_orientation = R.from_euler('xyz', euler_angles).as_quat()
        time.sleep(0.5)

    def switch_payload(self, name):
        """Switch to massage head ``name``.

        Stops the robot, configures end-effector 1 on the arm with the head's
        ``tcp_offset``, computes the ready pose with that offset applied and
        activates the position controller. Logs an error for unknown names.
        Requires ``self.ready_pose`` (set by init_hardwares).
        """
        if name in self.playload_dict:
            self.stop()
            self.current_head = name
            tcp_offset = self.playload_dict[name]["tcp_offset"]
            # Driver expects the offset as a "{a,b,c,...}" string.
            tcp_offset_str = "{" + ",".join(map(str, tcp_offset)) + "}"
            print(tcp_offset_str)
            self.arm.setEndEffector(i=1, tool_i=tcp_offset_str)
            self.arm.chooseEndEffector(i=1)
            self.logger.log_info(f"切换到{name}按摩头")
            # Apply tcp_offset (position then 'xyz' Euler rotation) on top of
            # the ready pose.
            R_matrix = R.from_euler('xyz', self.ready_pose[3:], degrees=False).as_matrix()
            ready_position = self.ready_pose[:3] + R_matrix @ tcp_offset[:3]
            ready_orientation = R_matrix @ R.from_euler('xyz', tcp_offset[3:], degrees=False).as_matrix()
            ready_orientation_euler = R.from_matrix(ready_orientation).as_euler('xyz', degrees=False)
            self.arm_state.desired_position = ready_position
            self.arm_state.desired_orientation = R.from_euler('xyz', ready_orientation_euler, degrees=False).as_quat()
            self.controller_manager.switch_controller('position')
        else:
            self.logger.log_error(f"未找到{name}按摩头")

    def log_thread(self):
        """Write pose, target and wrench data to the data log once per second.

        Loops forever; it is started as a daemon thread from __init__.
        """
        while True:
            self.logger_data.log_info(f"机械臂位置:{self.arm_state.arm_position},机械臂姿态:{self.arm_state.arm_orientation}",is_print=False)
            self.logger_data.log_info(f"机械臂期望位置:{self.arm_state.desired_position},机械臂真实位置:{self.arm_state.arm_position}",is_print=False)
            self.logger_data.log_info(f"机械臂期望姿态:{self.arm_state.desired_orientation},机械臂真实姿态:{self.arm_state.arm_orientation}",is_print=False)
            self.logger_data.log_info(f"机械臂期望力矩:{self.arm_state.desired_wrench},机械臂真实力矩:{self.arm_state.external_wrench_base}",is_print=False)
            self.logger_data.log_info(f"当前按摩头:{self.current_head}",is_print=False)
            time.sleep(1)

    # Utility functions (static: they were called via ``self.`` without
    # taking ``self``, which raised TypeError).
    @staticmethod
    def get_tf_matrix(position, orientation):
        """Build a 4x4 homogeneous transform.

        Args:
            position: 3-vector translation.
            orientation: quaternion in scipy's default (x, y, z, w) order.
        """
        tf_matrix = np.eye(4)
        tf_matrix[:3, :3] = R.from_quat(orientation).as_matrix()
        tf_matrix[:3, 3] = position
        return tf_matrix

    @staticmethod
    def wrench_coordinate_conversion(tf_matrix, wrench):
        """Re-express a wrench [f, tau] from frame A in frame B.

        ``tf_matrix`` is B's pose in A (rotation Rm, origin p). Returns
        ``[Rm.T @ f, Rm.T @ (f x p + tau)]``: the moment is shifted to B's
        origin and then rotated into B.
        """
        rot_matrix = tf_matrix[:3, :3]
        vector_p = tf_matrix[:3, 3]
        temp_force = wrench[:3]
        torque = wrench[3:]
        force = rot_matrix.T @ temp_force
        torque = rot_matrix.T @ np.cross(temp_force, vector_p) + rot_matrix.T @ torque
        return np.concatenate([force, torque])

if __name__ == "__main__":
    # Manual entry point: build the robot from the default config file.
    config_path = "MassageControl/config/robot_config.yaml"
    robot = MassageRobot(arm_config_path=config_path)