力控中间层适配新的底层接口,开发中

This commit is contained in:
ziwei.he 2025-05-09 14:03:24 +08:00
parent f56944f573
commit 2dfb5f8447
13 changed files with 807 additions and 0 deletions

View File

@ -0,0 +1,151 @@
from hardware.dobot_nova5 import dobot_nova5
from algorithms.arm_state import ArmState
from algorithms.controller_manager import ControllerManager
from algorithms.admittance_controller import AdmittanceController
from tools.log import CustomLogger
from tools.yaml_operator import read_yaml
from tools.Rate import Rate
import numpy as np
import threading
import time
import os
class MassageRobot:
def __init__(self,arm_config_path,is_log=False):
self.logger = CustomLogger(
log_name="测试日志",
log_file="logs/MassageRobot_nova5_test.log",
precise_time=True)
if is_log:
self.logger_data = CustomLogger(log_name="运动数据日志",
log_file="logs/MassageRobot_kinematics_data.log",
precise_time=True)
# 日志线程
threading.Thread(target=self.log_thread,daemon=True).start()
self.arm_state = ArmState()
self.arm_config = read_yaml(arm_config_path)
self.arm = dobot_nova5(arm_ip=self.arm_config['arm_ip'])
self.force_sensor = None
# controller
self.controller_manager = ControllerManager(self.arm_state)
self.controller_manager.add_controller(AdmittanceController,'admittance',self.arm_config['controller'][0])
# massage heads
massage_head_dir = self.arm_config['massage_head_dir']
all_items = os.listdir(massage_head_dir)
head_config_files = [f for f in all_items if os.path.isfile(os.path.join(massage_head_dir, f))]
self.playload_dict = {}
for file in head_config_files:
file_address = massage_head_dir + '/' + file
play_load = read_yaml(file_address)
self.playload_dict[play_load['name']] = play_load
self.current_head = 'none'
# 频率
self.control_rate = Rate(self.arm_config['control_rate'])
self.sensor_rate = Rate(self.arm_config['sensor_rate'])
self.command_rate = Rate(self.arm_config['command_rate'])
# 低通滤波
self.cutoff_freq = 80.0
# flags
self.exit_event = threading.Event()
self.exit_event.set() # 运行 True
self.interrupt_event = threading.Event()
self.interrupt_event.clear() # 中断 False
self.pause_envent = threading.Event()
self.pause_envent.clear() # 暂停 False
self.skip_envent = threading.Event()
self.skip_envent.clear() # 跳过 False
self.is_waitting = False
self.last_print_time = 0
self.last_record_time = 0
self.last_command_time = 0
self.move_to_point_count = 0
self.width_default = 5
self.x_base = np.zeros(6)
self.P_base = np.eye(6)
# 过程噪声协方差矩阵
self.Q_base = np.eye(6) * 0.01
# 测量噪声协方差矩阵
self.R_base = np.eye(6) * 0.1
self.x_tcp = np.zeros(6)
self.P_tcp = np.eye(6)
# 过程噪声协方差矩阵
self.Q_tcp = np.eye(6) * 0.01
# 测量噪声协方差矩阵
self.R_tcp = np.eye(6) * 0.1
# 传感器故障计数器
self.sensor_fail_count = 0
# init
self.arm.__init__()
# 预测步骤
def kalman_predict(self,x, P, Q):
# 预测状态(这里假设状态不变)
x_predict = x
# 预测误差协方差
P_predict = P + Q
return x_predict, P_predict
# 更新步骤
def kalman_update(self,x_predict, P_predict, z, R):
# 卡尔曼增益
K = P_predict @ np.linalg.inv(P_predict + R)
# 更新状态
x_update = x_predict + K @ (z - x_predict)
# 更新误差协方差
P_update = (np.eye(len(K)) - K) @ P_predict
return x_update, P_update
def arm_measure_loop(self):
return
def arm_command_loop(self):
return
def start(self):
if self.exit_event.is_set():
self.exit_event.clear()
self.arm_measure_thread = threading.Thread(target=self.arm_measure_loop)
self.arm_control_thread = threading.Thread(target=self.arm_command_loop)
self.arm_measure_thread.start()
def stop(self):
if not self.exit_event.is_set():
self.exit_event.set()
self.interrupt_event.clear()
self.arm_control_thread.join()
self.arm_measure_thread.join()
return
def init_hardwares(self,ready_pose):
self.ready_pose = np.array(ready_pose)
def switch_payload(self,name):
if name in self.playload_dict:
return
def log_thread(self):
while True:
self.logger_data.log_info(f"机械臂位置:{self.arm_state.arm_position},机械臂姿态:{self.arm_state.arm_orientation}",is_print=False)
self.logger_data.log_info(f"机械臂期望位置:{self.arm_state.desired_position},机械臂真实位置:{self.arm_state.arm_position}",is_print=False)
self.logger_data.log_info(f"机械臂期望姿态:{self.arm_state.desired_orientation},机械臂真实姿态:{self.arm_state.arm_orientation}",is_print=False)
self.logger_data.log_info(f"机械臂期望力矩:{self.arm_state.desired_wrench},机械臂真实力矩:{self.arm_state.external_wrench_base}",is_print=False)
self.logger_data.log_info(f"当前按摩头:{self.current_head}",is_print=False)
time.sleep(1)
if __name__ == "__main__":
robot = MassageRobot(arm_config_path="MassageControl/config/robot_config.yaml")

View File

@ -0,0 +1,122 @@
import sys
import numpy as np
from scipy.spatial.transform import Rotation as R
from .base_controller import BaseController
from .arm_state import ArmState
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import time
from tools.yaml_operator import read_yaml
class AdmittanceController(BaseController):
def __init__(self, name, state:ArmState,config_path) -> None:
super().__init__(name, state)
self.load_config(config_path)
def load_config(self, config_path):
config_dict = read_yaml(config_path)
if self.name != config_dict['name']:
raise ValueError(f"Controller name {self.name} does not match config name {config_dict['name']}")
mass_tran = np.array(config_dict['mass_tran'])
mass_rot = np.array(config_dict['mass_rot'])
stiff_tran = np.array(config_dict['stiff_tran'])
stiff_rot = np.array(config_dict['stiff_rot'])
desired_xi = np.array(config_dict['desired_xi'])
damp_tran = np.array(config_dict['damp_tran'])
damp_rot = np.array(config_dict['damp_rot'])
self.pos_scale_factor = config_dict['pos_scale_factor']
self.rot_scale_factor = config_dict['rot_scale_factor']
for i in range(3):
if damp_tran[i] < 0:
damp_tran[i] = 2 * desired_xi * np.sqrt(stiff_tran[i] * mass_tran[i])
if damp_rot[i] < 0:
damp_rot[i] = 2 * desired_xi * np.sqrt(stiff_rot[i] * mass_rot[i])
self.M = np.diag(np.concatenate([mass_tran, mass_rot]))
self.K = np.diag(np.concatenate([stiff_tran, stiff_rot]))
self.D = np.diag(np.concatenate([damp_tran, damp_rot]))
self.laset_print_time = 0
def step(self,dt):
# 计算误差 位置直接作差,姿态误差以旋转向量表示
temp_pose_error = self.state.arm_position - self.state.desired_position
# if time.time() - self.laset_print_time > 5:
# print(f'temp_pose_error: {temp_pose_error} ||| arm_position: {self.state.arm_position} ||| desired_position: {self.state.desired_position}')
if self.state.desired_orientation.dot(self.state.arm_orientation) < 0:
self.state.arm_orientation = -self.state.arm_orientation
self.state.pose_error[:3] = R.from_quat(self.state.arm_orientation).as_matrix().T @ temp_pose_error
# if time.time() - self.laset_print_time > 5:
# print("pose_error:",self.state.pose_error[:3])
# 计算误差 位置直接作差,姿态误差以旋转向量表示
#rot_err_mat = R.from_quat(self.state.arm_orientation).as_matrix() @ R.from_quat(self.state.desired_orientation).as_matrix().T
rot_err_mat = R.from_quat(self.state.arm_orientation).as_matrix().T @ R.from_quat(self.state.desired_orientation).as_matrix()
# print(f'rot_err_mat: {rot_err_mat} ||| arm_orientation: {R.from_quat(self.state.arm_orientation).as_euler('xyz',False)} ||| desired_orientation: {R.from_quat(self.state.desired_orientation).as_euler('xyz',False)}')
rot_err_rotvex = R.from_matrix(rot_err_mat).as_rotvec(degrees=False)
self.state.pose_error[3:] = -rot_err_rotvex
#wrench_err = self.state.external_wrench_base - self.state.desired_wrench
wrench_err = self.state.external_wrench_tcp - self.state.desired_wrench
# if time.time() - self.laset_print_time > 5:
# print(f'wrench_err: {wrench_err} ||| external_wrench_tcp: {self.state.external_wrench_tcp} ||| desired_wrench: {self.state.desired_wrench}')
self.state.arm_desired_acc = np.linalg.inv(self.M) @ (wrench_err - self.D @ (self.state.arm_desired_twist -self.state.desired_twist) - self.K @ self.state.pose_error)
# if time.time() - self.laset_print_time > 5:
# print("@@@:",wrench_err - self.D @ (self.state.arm_desired_twist -self.state.desired_twist) - self.K @ self.state.pose_error)
self.clip_command(self.state.arm_desired_acc,"acc")
self.state.arm_desired_twist = self.state.arm_desired_acc * dt + self.state.arm_desired_twist
self.clip_command(self.state.arm_desired_twist,"vel")
delta_pose = self.state.arm_desired_twist * dt
delta_pose[:3] = self.pos_scale_factor * delta_pose[:3]
delta_pose[3:] = self.rot_scale_factor * delta_pose[3:]
# if time.time() - self.laset_print_time > 5:
# print("delta_pose:",delta_pose)
delta_pose[:3] = R.from_quat(self.state.arm_orientation).as_matrix() @ delta_pose[:3]
# if time.time() - self.laset_print_time > 5:
# print("tf_delta_pose:",delta_pose)
self.clip_command(delta_pose,"pose")
# testlsy
delta_ori_mat = R.from_rotvec(delta_pose[3:]).as_matrix()
#arm_ori_mat = delta_ori_mat @ R.from_quat(self.state.arm_orientation).as_matrix()
arm_ori_mat = R.from_quat(self.state.arm_orientation).as_matrix() @ delta_ori_mat
# self.state.arm_orientation_command = R.from_matrix(arm_ori_mat).as_quat()
self.state.arm_orientation_command = R.from_matrix(arm_ori_mat).as_quat()
# 归一化四元数
self.state.arm_orientation_command /= np.linalg.norm(self.state.arm_orientation_command)
# arm_ori_mat = R.from_quat(self.state.arm_orientation).as_rotvec(degrees=False) + delta_pose[3:]
# self.state.arm_orientation_command = R.from_rotvec(arm_ori_mat).as_quat()
# self.state.arm_orientation_command = R.from_matrix(arm_ori_mat).as_quat()
self.state.arm_position_command = self.state.arm_position + delta_pose[:3]
# if time.time() - self.laset_print_time > 0.1:
# print("-------------admittance_1-------------------------------")
# print("arm_position:",self.state.arm_position)
# print("desired_position:",self.state.desired_position)
# print("arm_orientation",R.from_quat(self.state.arm_orientation).as_euler('xyz',degrees=True))
# print("desired_orientation",R.from_quat(self.state.desired_orientation).as_euler('xyz',degrees=True))
# print("arm_position_command",self.state.arm_position_command)
# print("arm_orientation_command",R.from_quat(self.state.arm_orientation_command).as_euler('xyz',degrees=True))
# print("delta_pose:",delta_pose)
# self.laset_print_time = time.time()
if __name__ == "__main__":
state = ArmState()
controller = AdmittanceController("admittance",state,"/home/zyc/admittance_control/MassageControl/config/admittance.yaml")
print(controller.name)
print(controller.state.arm_position)
state.arm_position = np.array([1,2,3])
print(controller.state.arm_position)
print(controller.M)
print(controller.D)
print(controller.K)

View File

@ -0,0 +1,44 @@
import numpy as np
class ArmState:
def __init__(self) -> None:
# 当前状态
self.arm_position = np.zeros(3,dtype=np.float64)
self.arm_orientation = np.array([0.0,0.0,0.0,1.0]) # [qx, qy, qz, qw]
self.external_wrench_base = np.zeros(6,dtype=np.float64)
self.external_wrench_tcp = np.zeros(6,dtype=np.float64)
# 上一个状态
self.last_arm_position = np.zeros(3,dtype=np.float64)
self.last_arm_orientation = np.array([0.0,0.0,0.0,1.0]) # [qx, qy, qz, qw]
self.last_external_wrench_base = np.zeros(6,dtype=np.float64)
self.last_external_wrench_tcp = np.zeros(6,dtype=np.float64)
# 目标状态
self.desired_position = np.zeros(3,dtype=np.float64)
self.desired_orientation = np.array([0.0,0,0,1]) # [qx, qy, qz, qw]
self.desired_wrench = np.zeros(6,dtype=np.float64)
self.desired_twist = np.zeros(6,dtype=np.float64)
# 导纳计算过程变量
self.arm_desired_twist = np.zeros(6,dtype=np.float64)
self.arm_desired_twist_tcp = np.zeros(6,dtype=np.float64)
self.arm_desired_acc = np.zeros(6,dtype=np.float64)
# 控制输出
self.arm_position_command = np.zeros(3,dtype=np.float64)
self.arm_orientation_command = np.array([0,0,0,1]) # [qx, qy, qz, qw]
# 误差信息
self.pose_error = np.zeros(6,dtype=np.float64)
self.twist_error = np.zeros(6,dtype=np.float64)
self.wrench_error = np.zeros(6,dtype=np.float64)
# clip项
self.max_acc_tran = 40
self.max_acc_rot = 20 * 3.5
self.max_vel_tran = 0.5 * 2.1
self.max_vel_rot = 3.1415 * 2 #1.5
self.max_dx = 0.01 *1.5
self.max_dr = 0.0087 * 3*1.5

View File

@ -0,0 +1,80 @@
from abc import ABC, abstractmethod
from typing import Literal
import numpy as np
from scipy.spatial.transform import Rotation as R
from .arm_state import ArmState
"""
位置单位为m力单位为N力矩单位为Nm角度单位为rad
"""
class BaseController(ABC):
def __init__(self,name,state:ArmState) -> None:
super().__init__()
self.name = name
self.state = state
@abstractmethod
def step(self,dt):
# 算法的一次迭代
pass
@abstractmethod
def load_config(self, config_path):
# 加载配置文件
pass
def clip_command(self, command :np.array,type: Literal["acc", "vel", "pose"],is_print=False):
if type == "acc":
if np.linalg.norm(command[:3]) > self.state.max_acc_tran:
command[:3] = command[:3] / np.linalg.norm(command[:3]) * self.state.max_acc_tran
if is_print:
print(f"translational acceleration {np.linalg.norm(command[:3])}m/s exceeds maximum allowed value")
if np.linalg.norm(command[3:]) > self.state.max_acc_rot:
if is_print:
print(f"rotational acceleration {np.linalg.norm(command[3:])}rad/s exceeds maximum allowed value,")
command[3:] = command[3:] / np.linalg.norm(command[3:]) * self.state.max_acc_rot
elif type == "vel":
if np.linalg.norm(command[:3]) > self.state.max_vel_tran:
if is_print:
print(f"translational velocity {np.linalg.norm(command[:3])}m/s exceeds maximum allowed value,")
command[:3] = command[:3] / np.linalg.norm(command[:3]) * self.state.max_vel_tran
if np.linalg.norm(command[3:]) > self.state.max_vel_rot:
command[3:] = command[3:] / np.linalg.norm(command[3:]) * self.state.max_vel_rot
if is_print:
print(f"rotational velocity {np.linalg.norm(command[3:])}rad/s exceeds maximum allowed value")
elif type == "pose":
if np.linalg.norm(command[:3]) > self.state.max_dx:
command[:3] = command[:3] / np.linalg.norm(command[:3]) * self.state.max_dx
if is_print:
print(f"translational displacement {np.linalg.norm(command[:3])}m exceeds maximum allowed value")
if np.linalg.norm(command[3:]) > self.state.max_dr:
command[3:] = command[3:] / np.linalg.norm(command[3:]) * self.state.max_dr
if is_print:
print(f"rotational displacement {np.linalg.norm(command[3:])}rad exceeds maximum allowed value")
@staticmethod
def rotvec_pose_add(pose, delta_pose):
"""
Compute the pose sum between two poses, which consists if position (x, y, z) and rotation vector (rx, ry, rz).
Update rule: x_t+1 = x_t + dx, R_t+1 = dR * R_t (rotation matrix)
:param pose: np.ndarray (6,)
:param delta_pose: np.ndarray (6,)
:return: np.ndarray (6,)
"""
assert len(pose) == 6 and len(delta_pose) == 6, "pose and delta_pose must be 6-dimensional"
ret = np.zeros(6)
ret[:3] = pose[:3] + delta_pose[:3]
# 当前姿态的旋转矩阵
pose_matrix = R.from_rotvec(pose[3:]).as_matrix()
# 旋转矩阵的增量
pose_delta_matrix = R.from_rotvec(delta_pose[3:]).as_matrix()
# 更新后的旋转矩阵,然后转换为旋转向量
ret[3:] = R.from_matrix(pose_delta_matrix @ pose_matrix).as_rotvec()
return ret

View File

@ -0,0 +1,42 @@
from .arm_state import ArmState
class ControllerManager:
def __init__(self, state: ArmState):
self.state = state
self.controllers = {}
self.current_controller = None
def add_controller(self, controller_class, name, config_path):
if name not in self.controllers:
self.controllers[name] = controller_class(name,self.state,config_path)
else:
raise ValueError(f"Controller {name} already exists")
def remove_controller(self, name):
if name in self.controllers:
del self.controllers[name]
if self.current_controller == self.controllers.get(name):
self.current_controller = None
else:
raise ValueError(f"Controller {name} does not exist")
def switch_controller(self, name):
if name in self.controllers:
self.current_controller = self.controllers[name]
else:
raise ValueError(f"Controller {name} does not exist")
def __getattr__(self, name):
if self.current_controller:
method = getattr(self.current_controller, name, None)
if method:
return method
else:
raise AttributeError(f"'{type(self.current_controller).__name__}' object has no attribute '{name}'")
else:
raise RuntimeError("No current controller is set")
def step(self,dt):
if self.current_controller:
self.current_controller.step(dt)

View File

@ -0,0 +1,66 @@
name: admittance
# mass_tran: [5,5,5]
# mass_rot: [5,5,5]
# stiff_tran: [500,500,500]
# stiff_rot: [500,500,500]
# desired_xi: 0.7
# damp_tran: [-125,-125,-125]
# damp_rot: [-125,-125,-125]
# pos_scale_factor: 1
# rot_scale_factor: 1
# name: admittance
# mass_tran: [1,1,1]
# mass_rot: [0.5,0.5,0.5]
# stiff_tran: [800,800,800]
# stiff_rot: [100,100,100]
# desired_xi: 1.1
# damp_tran: [-2,-2,-2]
# damp_rot: [-10,-10,-10]
# pos_scale_factor: 1
# rot_scale_factor: 1
# mass_tran: [4,4,4]
# mass_rot: [0.4,0.4,0.4]
# stiff_tran: [500,500,500]
# stiff_rot: [6,6,6]
# desired_xi: 1.1
# damp_tran: [-1,-1,-1]
# damp_rot: [-1,-1,-1]
# pos_scale_factor: 1
# rot_scale_factor: 12
# mass_tran: [1.3,0.4,2.4]
# mass_rot: [0.08,0.02,0.01]
# stiff_tran: [200,200,200]
# stiff_rot: [200,200,200]
# desired_xi: 1.1
# damp_tran: [-1,-1,-1]
# damp_rot: [-1,-1,-1]
# pos_scale_factor: 1
# rot_scale_factor: 1
desired_xi: 1
pos_scale_factor: 1
rot_scale_factor: 1
# mass_tran: [2.0, 2.0, 2.0]
# mass_rot: [0.2, 0.2, 0.2]
# stiff_tran: [300, 300, 300]
# stiff_rot: [3, 3, 3]
# damp_tran: [30, 30, 30]
# damp_rot: [1.0, 1.0, 1.0]
mass_tran: [2.0, 2.0, 2.0]
mass_rot: [0.2, 0.2, 0.2]
stiff_tran: [400, 400, 400]
stiff_rot: [4, 4, 4]
# stiff_tran: [100, 100, 100]
# stiff_rot: [1, 1, 1]
damp_tran: [20, 20, 20]
damp_rot: [0.6, 0.6, 0.6]
# damp_tran: [40, 40, 40]
# damp_rot: [3, 3, 3]

View File

@ -0,0 +1,28 @@
name: none
sensor_mass: 0.334
tool_type: Cylinder
tcp_offset:
- 0
- 0
- 0.031
- 0
- 0
- 0
tool_radius: 0.04
tool_mass: 0
mass_center_position:
- 0.021919029914178913
- -0.010820480799427892
- 0.011034252651402962
force_zero:
- 0.8846671183185211
- -0.6473878547983709
- -2.1312346218888862
torque_zero:
- 0.017893715524241308
- 0.04546799757174578
- -0.029532236049108707
gravity_base:
- 0.9041730658541057
- 1.6570854791729466
- -1.8745612276068087

View File

@ -0,0 +1,12 @@
# IP for DOBOT NOVA5
arm_ip: '192.168.5.1'
# controller
controller: ['Massage/MassageControl/config/admittance.yaml']
# massage heads diretory
massage_head_dir: 'Massage/MassageControl/config/massage_head'
control_rate: 50
sensor_rate: 100
command_rate: 120

View File

@ -323,6 +323,14 @@ class dobot_nova5:
else:
print("下使能机械臂失败")
return
def clearError(self):
code = self.dashboard.ClearError()
if code == 0:
print("清楚报警成功")
else:
print("清楚报警失败")
return
def __del__(self):
del self.dashboard

View File

@ -0,0 +1,80 @@
import time
import asyncio
class Rate:
def __init__(self, hz):
self.interval = 1.0 / hz
self.last_time = time.monotonic()
self.start_time = self.last_time
def sleep(self,print_duration=False):
now = time.monotonic()
sleep_duration = self.interval - (now - self.last_time)
if sleep_duration > 0:
if print_duration:
print(f"睡眠时间: {sleep_duration:.4f}")
time.sleep(sleep_duration)
self.last_time = time.monotonic()
def precise_sleep(self):
# 计算距离下一次休眠的剩余时间
now = time.perf_counter()
elapsed = now - self.last_time
sleep_duration = self.interval - elapsed
# 分成多个小的睡眠周期
while sleep_duration > 0:
# 最大睡眠时间限制在 1ms 以内,避免过多的误差
sleep_time = min(sleep_duration, 0.001) # 最大休眠1ms
time.sleep(sleep_time)
now = time.perf_counter()
elapsed = now - self.last_time
sleep_duration = self.interval - elapsed
self.last_time = time.perf_counter()
async def async_sleep(self):
now = time.monotonic()
sleep_duration = self.interval - (now - self.last_time)
if sleep_duration > 0:
# print(f"睡眠时间: {sleep_duration:.4f} 秒")
await asyncio.sleep(sleep_duration)
self.last_time = time.monotonic()
def remaining(self):
now = time.monotonic()
remaining_time = self.interval - (now - self.last_time)
return max(0.0, remaining_time)
def cycle_time(self):
now = time.monotonic()
cycle_duration = now - self.start_time
self.start_time = now
return cycle_duration
def to_sec(self):
return self.interval
if __name__ == "__main__":
# 示例用法:同步代码
def main_sync():
rate = Rate(2) # 2 赫兹的循环频率每次迭代0.5秒)
for i in range(10):
print(f"同步迭代 {i}")
print(f"剩余时间: {rate.remaining():.4f}")
print(f"上一个循环周期时间: {rate.cycle_time():.4f}")
rate.sleep()
# 示例用法:异步代码
async def main_async():
rate = Rate(2) # 2 赫兹的循环频率每次迭代0.5秒)
for i in range(10):
print(f"异步迭代 {i}")
print(f"剩余时间: {rate.remaining():.4f}")
print(f"上一个循环周期时间: {rate.cycle_time():.4f}")
await rate.async_sleep()
main_sync()
asyncio.run(main_async())

View File

@ -0,0 +1,122 @@
import json
import logging
from colorama import Fore, Style, init
from datetime import datetime
import numpy as np
import sys
import inspect
import os
# 初始化 colorama自动重置颜色
init(autoreset=True)
class CustomLogger:
def __init__(self, log_name=None, log_file=None, propagate=False, precise_time=True):
"""初始化日志记录器
:param log_name: 日志名称标识
:param propagate: 是否传递日志到父记录器
:param precise_time: 是否启用毫秒级时间戳
"""
self.logger = logging.getLogger(f"custom_logger_{log_name}")
self.logger.setLevel(logging.INFO)
self.logger.propagate = propagate
self.log_name = log_name
self.precise_time = precise_time
# 配置日志格式器
self.log_formatter = logging.Formatter(
'%(asctime)s - %(log_name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
)
# 默认添加控制台处理器
if not self.logger.handlers:
console_handler = logging.StreamHandler()
console_handler.setFormatter(self.log_formatter)
self.logger.addHandler(console_handler)
# 添加文件处理器
if log_file:
# 确保日志目录存在
os.makedirs(os.path.dirname(log_file), exist_ok=True)
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(self.log_formatter)
self.logger.addHandler(file_handler)
def __enter__(self):
"""上下文管理器入口"""
return self
def __exit__(self, exc_type, exc_value, traceback):
"""上下文管理器退出时清理资源"""
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
handler.close()
# 彩色日志方法
def log_info(self, message, is_print=True):
self._log(message, logging.INFO, Fore.GREEN, is_print)
def log_yellow(self, message, is_print=True):
self._log(message, logging.INFO, Fore.YELLOW, is_print)
def log_blue(self, message, is_print=True):
self._log(message, logging.INFO, Fore.CYAN, is_print)
def log_warning(self, message, is_print=True):
self._log(message, logging.WARNING, Fore.YELLOW, is_print)
def log_error(self, message, is_print=True):
self._log(message, logging.ERROR, Fore.RED, is_print)
def _log(self, message, level, color, is_print=True):
"""内部日志处理核心方法
:param message: 日志内容支持字符串/字典/数组等
:param level: 日志级别 logging.INFO
:param color: colorama 颜色控制符
:param is_print: 是否打印到控制台
"""
# 1. 时间戳格式化
current_time = datetime.now()
if self.precise_time:
formatted_time = current_time.strftime('%Y-%m-%d %H:%M:%S.') + f"{current_time.microsecond // 1000:03d}"
else:
formatted_time = current_time.strftime('%Y-%m-%d %H:%M:%S')
# 2. 获取调用者信息(跳过两层栈帧)
caller_frame = inspect.stack()[2]
filename = os.path.basename(caller_frame.filename)
lineno = caller_frame.lineno
# 3. 序列化复杂数据类型
if isinstance(message, (dict, list, tuple, np.ndarray)):
try:
message = json.dumps(message if not isinstance(message, np.ndarray) else message.tolist(),
ensure_ascii=False)
except (TypeError, ValueError):
message = str(message)
else:
message = str(message)
# 4. 控制台彩色输出
if is_print:
level_name = logging.getLevelName(level)
print(f"{formatted_time} - {self.log_name} - {level_name} - {filename}:{lineno} - {color}{message}{Style.RESET_ALL}")
# 5. 写入日志系统
self.logger.log(level, message, extra={'log_name': self.log_name})
if __name__ == "__main__":
# 示例用法
with CustomLogger(
log_name="测试日志",
log_file="logs/MassageRobot_nova5.log",
precise_time=True
) as logger:
logger.log_info("普通信息(绿色)")
logger.log_yellow("黄色提示")
logger.log_blue("蓝色消息")
logger.log_warning("警告信息(黄色)")
logger.log_error("错误信息(红色)")
# 测试复杂数据类型
# logger.log_info({"key": "值", "数组": [1, 2, 3]})
# logger.log_info(np.random.rand(3, 3))

View File

@ -0,0 +1,52 @@
import yaml
def read_yaml(file_path):
"""
读取 YAML 文件并返回 Python 对象
参数:
file_path (str): YAML 文件路径
返回:
data (dict): YAML 文件内容转换的 Python 对象
"""
with open(file_path, 'r', encoding='utf-8') as file:
data = yaml.safe_load(file)
return data
def write_yaml(data, file_path):
"""
Python 对象写入 YAML 文件
参数:
data (dict): 要写入 YAML 文件的 Python 对象
file_path (str): 目标 YAML 文件路径
"""
with open(file_path, 'w', encoding='utf-8') as file:
yaml.safe_dump(data, file, default_flow_style=False, allow_unicode=True, sort_keys=False)
def update_yaml(file_path, key, value):
"""
更新 YAML 文件中的指定键值对
参数:
file_path (str): YAML 文件路径
key (str): 要更新的键
value: 要更新的值
"""
data = read_yaml(file_path)
data[key] = value
write_yaml(data, file_path)
def delete_key_yaml(file_path, key):
"""
删除 YAML 文件中的指定键
参数:
file_path (str): YAML 文件路径
key (str): 要删除的键
"""
data = read_yaml(file_path)
if key in data:
del data[key]
write_yaml(data, file_path)

View File