2025-05-27 15:46:31 +08:00

264 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import os
import numpy as np
from PIL import ImageFont, ImageDraw, Image
try:
from .config import Config
except:
from config import Config
class LegAcupointsDetector:
def __init__(self):
"""
初始化 LegAcupointsDetector 类,加载模型和字体。
这个类使用 YOLO 模型自动检测腹部穴位点。
"""
# 初始化模型
self.model = YOLO(Config.LEG_MODEL_PATH)
self.device = "cpu" # 默认使用 CPU
# 穴位名称列表
self.name = ["承扶左", "委中左", "昆仑左", "承扶右", "委中右", "昆仑右", "殷门左", "殷门右", "上委中左", "上委中右", "承山左", "承筋左", "合阳左","承山右","承筋右","合阳右"]
# 定义约束
self.unique_list = [
"委中左", "委中右", "殷门左", "殷门右", "上委中左", "上委中右", "承山左", "承筋左", "合阳左","承山右","承筋右","合阳右"
]
# 获取字体路径
self.font_path = Config.get_font_path()
if self.font_path is None:
print("警告:未找到合适的字体文件,将使用系统默认字体")
def detect_keypoints(self, image_path):
"""
加载模型并推理关键点
:param image_path: 输入图片路径
:return: 推理结果(包含关键点信息)
"""
# 进行推理
results = self.model.predict(
source=image_path,
conf=0.5, # 置信度阈值
device=self.device, # 使用 CPU 或 GPU
)
result = results[0]
if result.keypoints is not None:
print("检测到的关键点坐标:", result.keypoints.data.cpu().numpy())
else:
print("未检测到关键点")
return result
# 异常处理
def error_process(self, coordinate):
"""
检查坐标数据是否有效。
:param coordinate: 包含穴位点坐标的字典
:return: 如果坐标有效返回 True,否则返回 None
"""
if not coordinate: # 判断 coordinates 是否为空
print("Coordinates are empty...")
return None
# 遍历 unique_list 中的每个点
for point in self.unique_list:
# 检查点是否存在于 coordinate 中
if point in coordinate:
# 检查点的值是否为空或 (0, 0)
if not coordinate[point] or coordinate[point] == (0, 0):
print(f"Point '{point}' is empty or (0, 0), skipping the calculation for this iteration...")
return None
else:
print(f"Point '{point}' is missing in the coordinates, skipping the calculation for this iteration...")
return None
# 如果所有检查都通过,返回 True
print("All points are valid, proceeding with the calculation...")
return True
def plot_acupoints_on_image(self, image_path, acupoints, output_path):
"""
在指定图片上绘制穴位点,并标注名称,保存结果。
参数:
image_path (str): 图片文件路径。
acupoints (dict): 包含穴位坐标的字典 {"穴位名称": (x, y), ...}。
output_path (str): 保存结果图片的路径。
"""
# 读取图片
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"无法加载背景图:{image_path}")
# 将图像分辨率增大到原来的三倍
scale_factor = 2
image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
# 调整穴位点坐标
acupoints = {name: (x * scale_factor, y * scale_factor) for name, (x, y) in acupoints.items()}
# (252, 229, 179)
# 定义蓝紫色渐变颜色BGR
purple_colors = [
(223, 87, 98), # 浅蓝紫色
(223, 87, 98), # 中等蓝紫色
(223, 87, 98),
]
# 绘制穴位点
for name, (x, y) in acupoints.items():
if x == 0 and y == 0:
continue
# 绘制蓝紫色带光圈的穴位点
radius = 4 # 穴位点半径
for r in range(radius, 0, -1):
alpha = r / radius # alpha 从 1中心到 0边缘
# 计算渐变颜色
color_index = int(alpha * (len(purple_colors) - 1))
color = purple_colors[color_index]
# 绘制渐变圆
cv2.circle(image, (x, y), r, color, -1, cv2.LINE_AA)
# 添加高光效果
highlight_radius = int(radius * 0.6)
# highlight_color = (200, 200, 200) # 白色高光
highlight_color = (0, 255, 185) # 白色高光
highlight_pos = (x, y)
cv2.circle(image, highlight_pos, highlight_radius, highlight_color, -1, cv2.LINE_AA)
# 使用 PIL 绘制文本
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(image_pil)
font = ImageFont.truetype(self.font_path, 8 * scale_factor) # 字体大小也放大三倍
for name, (x, y) in acupoints.items():
if x == 0 and y == 0:
continue
if name == "上委中左":
x_offset, y_offset = -40 * scale_factor, -5 * scale_factor
elif "" in name:
x_offset, y_offset = -30 * scale_factor, -5 * scale_factor
else:
x_offset, y_offset = 5 * scale_factor, -5 * scale_factor
# 绘制带白边的黑字
text_pos = (x + x_offset, y + y_offset)
for offset in [(-1,-1), (-1,1), (1,-1), (1,1)]:
draw.text((text_pos[0]+offset[0], text_pos[1]+offset[1]),
name, font=font, fill=(255,255,255)) # 白边
draw.text(text_pos, name, font=font, fill=(0,0,0)) # 黑字
# 将图像转换回 OpenCV 格式
image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
# 保存结果
cv2.imwrite(output_path, image)
print(f"结果已保存到:{output_path}")
def process_image(self, image_path, output_path):
"""
处理单张图片:推理关键点并绘制结果
参数:
image_path (str): 输入图片路径
output_path (str): 输出图片路径
返回:
dict: 包含穴位坐标的字典 {"穴位名称": (x, y), ...}
"""
# 检查输入图片是否存在
if not os.path.exists(image_path):
print(f"输入图片不存在:{image_path}")
return {}
# 创建输出文件夹(如果不存在)
output_dir = os.path.dirname(output_path)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 推理关键点
result = self.detect_keypoints(image_path)
# 获取关键点坐标
keypoints = result.keypoints.data.cpu().numpy() # 关键点坐标
num_keypoints = keypoints.shape[1] # 关键点数量
# 构建穴位点字典
acupoints = {}
for i in range(num_keypoints):
x, y = int(keypoints[0, i, 0]), int(keypoints[0, i, 1])
point_name = self.name[i] if i < len(self.name) else f"未知穴位_{i + 1}"
acupoints[point_name] = (x, y)
result = self.error_process(acupoints)
if result is None:
print("Pass this image due to error processing:Yolov8穴位产生为None或(0,0),人躺太上或者是躺太下了.")
return None
# 绘制穴位点并保存结果
self.plot_acupoints_on_image(image_path, acupoints, output_path)
# 返回穴位点字典
return acupoints
if __name__ == "__main__":
"""
single_picture
"""
input_image_path = Config.get_image_path("leg.png")
output_image_path = Config.get_output_path("leg_yolo.png")
# input_image_path = "aucpuncture2point/configs/using_img/leg.png" # 替换为你的图片路径
# output_coordinate_image_path = "aucpuncture2point/configs/using_img/leg_16points.png" # 替换为输出图片路径
# 初始化检测器
detector = LegAcupointsDetector()
# 处理单张图片,并获取穴位点字典
acupoints = detector.process_image(input_image_path, output_image_path)
print(acupoints)
# """
# batch_picture
# """
# # 批量处理
# leg_point = LegAcupointsDetector()
# # # 输入和输出文件夹路径
# # input_folder = r"/home/kira/codes/datas/colors-0872E1"
# # output_folder = r"/home/kira/codes/datas/results-0872E11"
# # input_folder = r"/home/kira/codes/IoT_img_process/output_classified/leg"
# # output_folder = r"/home/kira/codes/datas/results-3-31-leg"
# #异常处理的
# input_folder = r"/home/kira/codes/datas/colors_leg_yichang"
# output_folder = r"/home/kira/codes/datas/results-leg_yichang"
# # 确保输出文件夹存在
# if not os.path.exists(output_folder):
# os.makedirs(output_folder)
# # 遍历输入文件夹中的所有文件
# for filename in os.listdir(input_folder):
# # 只处理图片文件(支持常见格式)
# if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
# # 输入图片路径
# input_image_path = os.path.join(input_folder, filename)
# # 输出图片路径
# output_image_name = f"processed_{filename}"
# output_coordinate_image_path = os.path.join(output_folder, output_image_name)
# # 处理单张图片
# try:
# leg_acupoints_list = leg_point.process_image(input_image_path, output_coordinate_image_path)
# print(f"处理完成:{filename} -> {output_image_name}")
# print(f"穴位点坐标:{leg_acupoints_list}")
# except Exception as e:
# print(f"处理失败:{filename},错误信息:{e}")