264 lines
10 KiB
Python
264 lines
10 KiB
Python
from ultralytics import YOLO
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
from PIL import ImageFont, ImageDraw, Image
|
||
try:
|
||
from .config import Config
|
||
except:
|
||
from config import Config
|
||
|
||
class LegAcupointsDetector:
|
||
def __init__(self):
|
||
"""
|
||
初始化 LegAcupointsDetector 类,加载模型和字体。
|
||
这个类使用 YOLO 模型自动检测腹部穴位点。
|
||
"""
|
||
# 初始化模型
|
||
self.model = YOLO(Config.LEG_MODEL_PATH)
|
||
self.device = "cpu" # 默认使用 CPU
|
||
|
||
# 穴位名称列表
|
||
self.name = ["承扶左", "委中左", "昆仑左", "承扶右", "委中右", "昆仑右", "殷门左", "殷门右", "上委中左", "上委中右", "承山左", "承筋左", "合阳左","承山右","承筋右","合阳右"]
|
||
|
||
# 定义约束
|
||
self.unique_list = [
|
||
"委中左", "委中右", "殷门左", "殷门右", "上委中左", "上委中右", "承山左", "承筋左", "合阳左","承山右","承筋右","合阳右"
|
||
]
|
||
# 获取字体路径
|
||
self.font_path = Config.get_font_path()
|
||
if self.font_path is None:
|
||
print("警告:未找到合适的字体文件,将使用系统默认字体")
|
||
|
||
def detect_keypoints(self, image_path):
|
||
"""
|
||
加载模型并推理关键点
|
||
:param image_path: 输入图片路径
|
||
:return: 推理结果(包含关键点信息)
|
||
"""
|
||
# 进行推理
|
||
results = self.model.predict(
|
||
source=image_path,
|
||
conf=0.5, # 置信度阈值
|
||
device=self.device, # 使用 CPU 或 GPU
|
||
)
|
||
result = results[0]
|
||
if result.keypoints is not None:
|
||
print("检测到的关键点坐标:", result.keypoints.data.cpu().numpy())
|
||
else:
|
||
print("未检测到关键点")
|
||
return result
|
||
|
||
# 异常处理
|
||
def error_process(self, coordinate):
|
||
"""
|
||
检查坐标数据是否有效。
|
||
|
||
:param coordinate: 包含穴位点坐标的字典
|
||
:return: 如果坐标有效返回 True,否则返回 None
|
||
"""
|
||
if not coordinate: # 判断 coordinates 是否为空
|
||
print("Coordinates are empty...")
|
||
return None
|
||
|
||
# 遍历 unique_list 中的每个点
|
||
for point in self.unique_list:
|
||
# 检查点是否存在于 coordinate 中
|
||
if point in coordinate:
|
||
# 检查点的值是否为空或 (0, 0)
|
||
if not coordinate[point] or coordinate[point] == (0, 0):
|
||
print(f"Point '{point}' is empty or (0, 0), skipping the calculation for this iteration...")
|
||
return None
|
||
else:
|
||
print(f"Point '{point}' is missing in the coordinates, skipping the calculation for this iteration...")
|
||
return None
|
||
|
||
# 如果所有检查都通过,返回 True
|
||
print("All points are valid, proceeding with the calculation...")
|
||
return True
|
||
|
||
|
||
def plot_acupoints_on_image(self, image_path, acupoints, output_path):
|
||
"""
|
||
在指定图片上绘制穴位点,并标注名称,保存结果。
|
||
|
||
参数:
|
||
image_path (str): 图片文件路径。
|
||
acupoints (dict): 包含穴位坐标的字典 {"穴位名称": (x, y), ...}。
|
||
output_path (str): 保存结果图片的路径。
|
||
"""
|
||
# 读取图片
|
||
image = cv2.imread(image_path)
|
||
if image is None:
|
||
raise FileNotFoundError(f"无法加载背景图:{image_path}")
|
||
|
||
# 将图像分辨率增大到原来的三倍
|
||
scale_factor = 2
|
||
image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
|
||
|
||
# 调整穴位点坐标
|
||
acupoints = {name: (x * scale_factor, y * scale_factor) for name, (x, y) in acupoints.items()}
|
||
# (252, 229, 179)
|
||
# 定义蓝紫色渐变颜色BGR
|
||
purple_colors = [
|
||
(223, 87, 98), # 浅蓝紫色
|
||
(223, 87, 98), # 中等蓝紫色
|
||
(223, 87, 98),
|
||
]
|
||
|
||
# 绘制穴位点
|
||
for name, (x, y) in acupoints.items():
|
||
if x == 0 and y == 0:
|
||
continue
|
||
|
||
# 绘制蓝紫色带光圈的穴位点
|
||
radius = 4 # 穴位点半径
|
||
for r in range(radius, 0, -1):
|
||
alpha = r / radius # alpha 从 1(中心)到 0(边缘)
|
||
|
||
# 计算渐变颜色
|
||
color_index = int(alpha * (len(purple_colors) - 1))
|
||
color = purple_colors[color_index]
|
||
|
||
# 绘制渐变圆
|
||
cv2.circle(image, (x, y), r, color, -1, cv2.LINE_AA)
|
||
|
||
# 添加高光效果
|
||
highlight_radius = int(radius * 0.6)
|
||
# highlight_color = (200, 200, 200) # 白色高光
|
||
highlight_color = (0, 255, 185) # 白色高光
|
||
highlight_pos = (x, y)
|
||
cv2.circle(image, highlight_pos, highlight_radius, highlight_color, -1, cv2.LINE_AA)
|
||
|
||
# 使用 PIL 绘制文本
|
||
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||
draw = ImageDraw.Draw(image_pil)
|
||
font = ImageFont.truetype(self.font_path, 8 * scale_factor) # 字体大小也放大三倍
|
||
|
||
for name, (x, y) in acupoints.items():
|
||
if x == 0 and y == 0:
|
||
continue
|
||
if name == "上委中左":
|
||
x_offset, y_offset = -40 * scale_factor, -5 * scale_factor
|
||
elif "左" in name:
|
||
x_offset, y_offset = -30 * scale_factor, -5 * scale_factor
|
||
else:
|
||
x_offset, y_offset = 5 * scale_factor, -5 * scale_factor
|
||
|
||
|
||
# 绘制带白边的黑字
|
||
text_pos = (x + x_offset, y + y_offset)
|
||
for offset in [(-1,-1), (-1,1), (1,-1), (1,1)]:
|
||
draw.text((text_pos[0]+offset[0], text_pos[1]+offset[1]),
|
||
name, font=font, fill=(255,255,255)) # 白边
|
||
draw.text(text_pos, name, font=font, fill=(0,0,0)) # 黑字
|
||
|
||
# 将图像转换回 OpenCV 格式
|
||
image = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
||
|
||
# 保存结果
|
||
cv2.imwrite(output_path, image)
|
||
print(f"结果已保存到:{output_path}")
|
||
|
||
|
||
def process_image(self, image_path, output_path):
|
||
"""
|
||
处理单张图片:推理关键点并绘制结果
|
||
|
||
参数:
|
||
image_path (str): 输入图片路径
|
||
output_path (str): 输出图片路径
|
||
|
||
返回:
|
||
dict: 包含穴位坐标的字典 {"穴位名称": (x, y), ...}
|
||
"""
|
||
# 检查输入图片是否存在
|
||
if not os.path.exists(image_path):
|
||
print(f"输入图片不存在:{image_path}")
|
||
return {}
|
||
|
||
# 创建输出文件夹(如果不存在)
|
||
output_dir = os.path.dirname(output_path)
|
||
if not os.path.exists(output_dir):
|
||
os.makedirs(output_dir)
|
||
|
||
# 推理关键点
|
||
result = self.detect_keypoints(image_path)
|
||
|
||
# 获取关键点坐标
|
||
keypoints = result.keypoints.data.cpu().numpy() # 关键点坐标
|
||
num_keypoints = keypoints.shape[1] # 关键点数量
|
||
|
||
# 构建穴位点字典
|
||
acupoints = {}
|
||
for i in range(num_keypoints):
|
||
x, y = int(keypoints[0, i, 0]), int(keypoints[0, i, 1])
|
||
point_name = self.name[i] if i < len(self.name) else f"未知穴位_{i + 1}"
|
||
acupoints[point_name] = (x, y)
|
||
|
||
result = self.error_process(acupoints)
|
||
if result is None:
|
||
print("Pass this image due to error processing:Yolov8穴位产生为None或(0,0),人躺太上或者是躺太下了.")
|
||
return None
|
||
# 绘制穴位点并保存结果
|
||
self.plot_acupoints_on_image(image_path, acupoints, output_path)
|
||
|
||
# 返回穴位点字典
|
||
return acupoints
|
||
|
||
|
||
if __name__ == "__main__":
|
||
"""
|
||
single_picture
|
||
"""
|
||
input_image_path = Config.get_image_path("leg.png")
|
||
output_image_path = Config.get_output_path("leg_yolo.png")
|
||
# input_image_path = "aucpuncture2point/configs/using_img/leg.png" # 替换为你的图片路径
|
||
# output_coordinate_image_path = "aucpuncture2point/configs/using_img/leg_16points.png" # 替换为输出图片路径
|
||
# 初始化检测器
|
||
detector = LegAcupointsDetector()
|
||
|
||
# 处理单张图片,并获取穴位点字典
|
||
acupoints = detector.process_image(input_image_path, output_image_path)
|
||
print(acupoints)
|
||
|
||
# """
|
||
# batch_picture
|
||
# """
|
||
# # 批量处理
|
||
# leg_point = LegAcupointsDetector()
|
||
|
||
# # # 输入和输出文件夹路径
|
||
# # input_folder = r"/home/kira/codes/datas/colors-0872E1"
|
||
# # output_folder = r"/home/kira/codes/datas/results-0872E11"
|
||
|
||
# # input_folder = r"/home/kira/codes/IoT_img_process/output_classified/leg"
|
||
# # output_folder = r"/home/kira/codes/datas/results-3-31-leg"
|
||
|
||
|
||
# #异常处理的
|
||
# input_folder = r"/home/kira/codes/datas/colors_leg_yichang"
|
||
# output_folder = r"/home/kira/codes/datas/results-leg_yichang"
|
||
|
||
# # 确保输出文件夹存在
|
||
# if not os.path.exists(output_folder):
|
||
# os.makedirs(output_folder)
|
||
|
||
# # 遍历输入文件夹中的所有文件
|
||
# for filename in os.listdir(input_folder):
|
||
# # 只处理图片文件(支持常见格式)
|
||
# if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
|
||
# # 输入图片路径
|
||
# input_image_path = os.path.join(input_folder, filename)
|
||
|
||
# # 输出图片路径
|
||
# output_image_name = f"processed_{filename}"
|
||
# output_coordinate_image_path = os.path.join(output_folder, output_image_name)
|
||
|
||
# # 处理单张图片
|
||
# try:
|
||
# leg_acupoints_list = leg_point.process_image(input_image_path, output_coordinate_image_path)
|
||
# print(f"处理完成:{filename} -> {output_image_name}")
|
||
# print(f"穴位点坐标:{leg_acupoints_list}")
|
||
# except Exception as e:
|
||
# print(f"处理失败:{filename},错误信息:{e}") |