2025-05-27 15:46:31 +08:00

213 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import cv2
import os
import numpy as np
from PIL import ImageFont, ImageDraw, Image
try:
from .config import Config
except:
from config import Config
class AbdominalAcupointsDetector:
def __init__(self):
"""
初始化 AbdominalAcupointsDetector 类,加载模型和字体。
这个类使用 YOLO 模型自动检测腹部穴位点。
"""
# 初始化模型
self.model = YOLO(Config.ABDOMEN_MODEL_PATH)
self.device = "cpu" # 默认使用 CPU
# 穴位名称列表
self.name = ["神阙", "天枢右", "天枢左", "气海", "石门", "关元", "水分", "外陵右", "滑肉右", "外陵左", "滑肉左", "大横左", "大横右"]
# 获取字体路径
self.font_path = Config.get_font_path()
if self.font_path is None:
print("警告:未找到合适的字体文件,将使用系统默认字体")
def detect_keypoints(self, image_path):
"""
加载模型并推理关键点
:param image_path: 输入图片路径
:return: 推理结果(包含关键点信息)
"""
# 进行推理
results = self.model.predict(
source=image_path,
conf=0.5, # 置信度阈值
device=self.device, # 使用 CPU 或 GPU
)
result = results[0]
if result.keypoints is not None:
print("检测到的关键点坐标:", result.keypoints.data.cpu().numpy())
else:
print("未检测到关键点")
return result
def plot_acupoints_on_image(self, image_path, acupoints, output_path):
"""
在指定图片上绘制穴位点,并标注名称,保存结果。
参数:
image_path (str): 图片文件路径。
acupoints (dict): 包含穴位坐标的字典 {"穴位名称": (x, y), ...}。
output_path (str): 保存结果图片的路径。
"""
# 读取图片
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"无法加载背景图:{image_path}")
# 将图像分辨率增大到原来的三倍
scale_factor = 2
image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
# 调整穴位点坐标
acupoints = {name: (x * scale_factor, y * scale_factor) for name, (x, y) in acupoints.items()}
# (252, 229, 179)
# 定义蓝紫色渐变颜色BGR
purple_colors = [
(223, 87, 98), # 浅蓝紫色
(223, 87, 98), # 中等蓝紫色
(223, 87, 98),
]
# 绘制穴位点
for name, (x, y) in acupoints.items():
if x == 0 and y == 0:
continue
# 绘制蓝紫色带光圈的穴位点
radius = 4 # 穴位点半径
for r in range(radius, 0, -1):
alpha = r / radius # alpha 从 1中心到 0边缘
# 计算渐变颜色
color_index = int(alpha * (len(purple_colors) - 1))
color = purple_colors[color_index]
# 绘制渐变圆
cv2.circle(image, (x, y), r, color, -1, cv2.LINE_AA)
# 添加高光效果
highlight_radius = int(radius * 0.6)
# highlight_color = (200, 200, 200) # 白色高光
highlight_color = (0, 255, 185) # 白色高光
highlight_pos = (x, y)
cv2.circle(image, highlight_pos, highlight_radius, highlight_color, -1, cv2.LINE_AA)
# 使用 PIL 绘制文本
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(image_pil)
font = ImageFont.truetype(self.font_path, 8 * scale_factor) # 字体大小也放大三倍
for name, (x, y) in acupoints.items():
if x == 0 and y == 0:
continue
if name in ["滑肉左", "大横左", "天枢左"]:
x_offset, y_offset = -15 * scale_factor, -15 * scale_factor
elif name in ["滑肉右", "大横右","天枢右"]:
x_offset, y_offset = -15 * scale_factor, -15 * scale_factor
elif name == "外陵左":
x_offset, y_offset = -35 * scale_factor, -5 * scale_factor
elif name == "关元":
x_offset, y_offset = -20 * scale_factor, -5 * scale_factor
elif name in ["石门","外陵右"]:
x_offset, y_offset = 5 * scale_factor, -5 * scale_factor
else:
x_offset, y_offset = -10 * scale_factor, -15 * scale_factor
# 绘制带白边的黑字
text_pos = (x + x_offset, y + y_offset)
for offset in [(-1,-1), (-1,1), (1,-1), (1,1)]:
draw.text((text_pos[0]+offset[0], text_pos[1]+offset[1]),
name, font=font, fill=(255,255,255)) # 白边
draw.text(text_pos, name, font=font, fill=(0,0,0)) # 黑字
# 将图像转换回 OpenCV 格式
image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
# 保存结果
cv2.imwrite(output_path, image)
print(f"结果已保存到:{output_path}")
def process_image(self, image_path, output_path):
"""
处理单张图片:推理关键点并绘制结果
参数:
image_path (str): 输入图片路径
output_path (str): 输出图片路径
返回:
dict: 包含穴位坐标的字典 {"穴位名称": (x, y), ...}
"""
# 检查输入图片是否存在
if not os.path.exists(image_path):
print(f"输入图片不存在:{image_path}")
return {}
# 创建输出文件夹(如果不存在)
output_dir = os.path.dirname(output_path)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 推理关键点
result = self.detect_keypoints(image_path)
# 获取关键点坐标
keypoints = result.keypoints.data.cpu().numpy() # 关键点坐标
num_keypoints = keypoints.shape[1] # 关键点数量
# 构建穴位点字典
acupoints = {}
for i in range(num_keypoints):
x, y = int(keypoints[0, i, 0]), int(keypoints[0, i, 1])
point_name = self.name[i] if i < len(self.name) else f"未知穴位_{i + 1}"
acupoints[point_name] = (x, y)
# 绘制穴位点并保存结果
self.plot_acupoints_on_image(image_path, acupoints, output_path)
# 返回穴位点字典
return acupoints
if __name__ == "__main__":
"""
single_picture
"""
input_image_path = Config.get_image_path("color_mz.png")
output_image_path = Config.get_output_path("color_mz_abdomen.png")
# 初始化检测器
detector = AbdominalAcupointsDetector()
# 处理单张图片,并获取穴位点字典
acupoints = detector.process_image(input_image_path, output_image_path)
print(acupoints)
"""
batch_picture
"""
# # 输入输出文件夹路径
# input_folder = os.path.join(Config.IMAGE_DIR, "train", "images")
# output_folder = os.path.join(Config.IMAGE_DIR, "output_all")
# Config.ensure_dir(output_folder)
# # 遍历输入文件夹中的所有图片文件
# for filename in os.listdir(input_folder):
# if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
# input_image_path = os.path.join(input_folder, filename)
# output_image_path = os.path.join(output_folder, filename)
# acupoints = detector.process_image(input_image_path, output_image_path)
# print(f"处理图片: {filename}")
# print("检测到的穴位点:")
# for name, coords in acupoints.items():
# print(f"{name}: {coords}")
# print("-" * 40)