213 lines
7.9 KiB
Python
213 lines
7.9 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
from PIL import ImageFont, ImageDraw, Image
|
||
try:
|
||
from .config import Config
|
||
except:
|
||
from config import Config
|
||
|
||
class AbdominalAcupointsDetector:
|
||
def __init__(self):
|
||
"""
|
||
初始化 AbdominalAcupointsDetector 类,加载模型和字体。
|
||
这个类使用 YOLO 模型自动检测腹部穴位点。
|
||
"""
|
||
# 初始化模型
|
||
self.model = YOLO(Config.ABDOMEN_MODEL_PATH)
|
||
self.device = "cpu" # 默认使用 CPU
|
||
|
||
# 穴位名称列表
|
||
self.name = ["神阙", "天枢右", "天枢左", "气海", "石门", "关元", "水分", "外陵右", "滑肉右", "外陵左", "滑肉左", "大横左", "大横右"]
|
||
|
||
# 获取字体路径
|
||
self.font_path = Config.get_font_path()
|
||
if self.font_path is None:
|
||
print("警告:未找到合适的字体文件,将使用系统默认字体")
|
||
|
||
def detect_keypoints(self, image_path):
|
||
"""
|
||
加载模型并推理关键点
|
||
:param image_path: 输入图片路径
|
||
:return: 推理结果(包含关键点信息)
|
||
"""
|
||
# 进行推理
|
||
results = self.model.predict(
|
||
source=image_path,
|
||
conf=0.5, # 置信度阈值
|
||
device=self.device, # 使用 CPU 或 GPU
|
||
)
|
||
result = results[0]
|
||
if result.keypoints is not None:
|
||
print("检测到的关键点坐标:", result.keypoints.data.cpu().numpy())
|
||
else:
|
||
print("未检测到关键点")
|
||
return result
|
||
|
||
|
||
def plot_acupoints_on_image(self, image_path, acupoints, output_path):
|
||
"""
|
||
在指定图片上绘制穴位点,并标注名称,保存结果。
|
||
|
||
参数:
|
||
image_path (str): 图片文件路径。
|
||
acupoints (dict): 包含穴位坐标的字典 {"穴位名称": (x, y), ...}。
|
||
output_path (str): 保存结果图片的路径。
|
||
"""
|
||
# 读取图片
|
||
image = cv2.imread(image_path)
|
||
if image is None:
|
||
raise FileNotFoundError(f"无法加载背景图:{image_path}")
|
||
|
||
# 将图像分辨率增大到原来的三倍
|
||
scale_factor = 2
|
||
image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
|
||
|
||
# 调整穴位点坐标
|
||
acupoints = {name: (x * scale_factor, y * scale_factor) for name, (x, y) in acupoints.items()}
|
||
# (252, 229, 179)
|
||
# 定义蓝紫色渐变颜色BGR
|
||
purple_colors = [
|
||
(223, 87, 98), # 浅蓝紫色
|
||
(223, 87, 98), # 中等蓝紫色
|
||
(223, 87, 98),
|
||
]
|
||
|
||
# 绘制穴位点
|
||
for name, (x, y) in acupoints.items():
|
||
if x == 0 and y == 0:
|
||
continue
|
||
|
||
# 绘制蓝紫色带光圈的穴位点
|
||
radius = 4 # 穴位点半径
|
||
for r in range(radius, 0, -1):
|
||
alpha = r / radius # alpha 从 1(中心)到 0(边缘)
|
||
|
||
# 计算渐变颜色
|
||
color_index = int(alpha * (len(purple_colors) - 1))
|
||
color = purple_colors[color_index]
|
||
|
||
# 绘制渐变圆
|
||
cv2.circle(image, (x, y), r, color, -1, cv2.LINE_AA)
|
||
|
||
# 添加高光效果
|
||
highlight_radius = int(radius * 0.6)
|
||
# highlight_color = (200, 200, 200) # 白色高光
|
||
highlight_color = (0, 255, 185) # 白色高光
|
||
highlight_pos = (x, y)
|
||
cv2.circle(image, highlight_pos, highlight_radius, highlight_color, -1, cv2.LINE_AA)
|
||
|
||
# 使用 PIL 绘制文本
|
||
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||
draw = ImageDraw.Draw(image_pil)
|
||
font = ImageFont.truetype(self.font_path, 8 * scale_factor) # 字体大小也放大三倍
|
||
|
||
for name, (x, y) in acupoints.items():
|
||
if x == 0 and y == 0:
|
||
continue
|
||
|
||
|
||
if name in ["滑肉左", "大横左", "天枢左"]:
|
||
x_offset, y_offset = -15 * scale_factor, -15 * scale_factor
|
||
elif name in ["滑肉右", "大横右","天枢右"]:
|
||
x_offset, y_offset = -15 * scale_factor, -15 * scale_factor
|
||
elif name == "外陵左":
|
||
x_offset, y_offset = -35 * scale_factor, -5 * scale_factor
|
||
elif name == "关元":
|
||
x_offset, y_offset = -20 * scale_factor, -5 * scale_factor
|
||
elif name in ["石门","外陵右"]:
|
||
x_offset, y_offset = 5 * scale_factor, -5 * scale_factor
|
||
|
||
else:
|
||
x_offset, y_offset = -10 * scale_factor, -15 * scale_factor
|
||
|
||
# 绘制带白边的黑字
|
||
text_pos = (x + x_offset, y + y_offset)
|
||
for offset in [(-1,-1), (-1,1), (1,-1), (1,1)]:
|
||
draw.text((text_pos[0]+offset[0], text_pos[1]+offset[1]),
|
||
name, font=font, fill=(255,255,255)) # 白边
|
||
draw.text(text_pos, name, font=font, fill=(0,0,0)) # 黑字
|
||
|
||
# 将图像转换回 OpenCV 格式
|
||
image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
||
|
||
# 保存结果
|
||
cv2.imwrite(output_path, image)
|
||
print(f"结果已保存到:{output_path}")
|
||
|
||
def process_image(self, image_path, output_path):
|
||
"""
|
||
处理单张图片:推理关键点并绘制结果
|
||
|
||
参数:
|
||
image_path (str): 输入图片路径
|
||
output_path (str): 输出图片路径
|
||
|
||
返回:
|
||
dict: 包含穴位坐标的字典 {"穴位名称": (x, y), ...}
|
||
"""
|
||
# 检查输入图片是否存在
|
||
if not os.path.exists(image_path):
|
||
print(f"输入图片不存在:{image_path}")
|
||
return {}
|
||
|
||
# 创建输出文件夹(如果不存在)
|
||
output_dir = os.path.dirname(output_path)
|
||
if not os.path.exists(output_dir):
|
||
os.makedirs(output_dir)
|
||
|
||
# 推理关键点
|
||
result = self.detect_keypoints(image_path)
|
||
|
||
# 获取关键点坐标
|
||
keypoints = result.keypoints.data.cpu().numpy() # 关键点坐标
|
||
num_keypoints = keypoints.shape[1] # 关键点数量
|
||
|
||
# 构建穴位点字典
|
||
acupoints = {}
|
||
for i in range(num_keypoints):
|
||
x, y = int(keypoints[0, i, 0]), int(keypoints[0, i, 1])
|
||
point_name = self.name[i] if i < len(self.name) else f"未知穴位_{i + 1}"
|
||
acupoints[point_name] = (x, y)
|
||
|
||
# 绘制穴位点并保存结果
|
||
self.plot_acupoints_on_image(image_path, acupoints, output_path)
|
||
|
||
# 返回穴位点字典
|
||
return acupoints
|
||
|
||
|
||
if __name__ == "__main__":
|
||
"""
|
||
single_picture
|
||
"""
|
||
input_image_path = Config.get_image_path("color_mz.png")
|
||
output_image_path = Config.get_output_path("color_mz_abdomen.png")
|
||
|
||
# 初始化检测器
|
||
detector = AbdominalAcupointsDetector()
|
||
|
||
# 处理单张图片,并获取穴位点字典
|
||
acupoints = detector.process_image(input_image_path, output_image_path)
|
||
print(acupoints)
|
||
|
||
"""
|
||
batch_picture
|
||
"""
|
||
# # 输入输出文件夹路径
|
||
# input_folder = os.path.join(Config.IMAGE_DIR, "train", "images")
|
||
# output_folder = os.path.join(Config.IMAGE_DIR, "output_all")
|
||
# Config.ensure_dir(output_folder)
|
||
|
||
# # 遍历输入文件夹中的所有图片文件
|
||
# for filename in os.listdir(input_folder):
|
||
# if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||
# input_image_path = os.path.join(input_folder, filename)
|
||
# output_image_path = os.path.join(output_folder, filename)
|
||
# acupoints = detector.process_image(input_image_path, output_image_path)
|
||
# print(f"处理图片: {filename}")
|
||
# print("检测到的穴位点:")
|
||
# for name, coords in acupoints.items():
|
||
# print(f"{name}: {coords}")
|
||
# print("-" * 40) |