yolorestart
This commit is contained in:
275
yolopart/detector.py
Normal file
275
yolopart/detector.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
|
||||
class LicensePlateYOLO:
|
||||
"""
|
||||
车牌YOLO检测器类
|
||||
负责加载YOLO pose模型并进行车牌检测和角点提取
|
||||
"""
|
||||
|
||||
def __init__(self, model_path=None):
|
||||
"""
|
||||
初始化YOLO检测器
|
||||
|
||||
参数:
|
||||
model_path: 模型文件路径,如果为None则使用默认路径
|
||||
"""
|
||||
self.model = None
|
||||
self.model_path = model_path or self._get_default_model_path()
|
||||
self.class_names = {0: '蓝牌', 1: '绿牌'}
|
||||
self.load_model()
|
||||
|
||||
def _get_default_model_path(self):
|
||||
"""获取默认模型路径"""
|
||||
current_dir = os.path.dirname(__file__)
|
||||
return os.path.join(current_dir, "yolo11s-pose42.pt")
|
||||
|
||||
def load_model(self):
|
||||
"""
|
||||
加载YOLO pose模型
|
||||
|
||||
返回:
|
||||
bool: 加载是否成功
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(self.model_path):
|
||||
self.model = YOLO(self.model_path)
|
||||
print(f"YOLO模型加载成功: {self.model_path}")
|
||||
return True
|
||||
else:
|
||||
print(f"模型文件不存在: {self.model_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"YOLO模型加载失败: {e}")
|
||||
return False
|
||||
|
||||
def detect_license_plates(self, image, conf_threshold=0.5):
|
||||
"""
|
||||
检测图像中的车牌
|
||||
|
||||
参数:
|
||||
image: 输入图像 (numpy数组)
|
||||
conf_threshold: 置信度阈值
|
||||
|
||||
返回:
|
||||
list: 检测结果列表,每个元素包含:
|
||||
- box: 边界框坐标 [x1, y1, x2, y2]
|
||||
- keypoints: 四个角点坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
- confidence: 置信度
|
||||
- class_id: 类别ID (0=蓝牌, 1=绿牌)
|
||||
- class_name: 类别名称
|
||||
"""
|
||||
if self.model is None:
|
||||
print("模型未加载")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 进行推理
|
||||
results = self.model(image, conf=conf_threshold, verbose=False)
|
||||
detections = []
|
||||
|
||||
for result in results:
|
||||
# 检查是否有检测结果
|
||||
if result.boxes is None or result.keypoints is None:
|
||||
continue
|
||||
|
||||
# 提取检测信息
|
||||
boxes = result.boxes.xyxy.cpu().numpy() # 边界框
|
||||
keypoints = result.keypoints.xy.cpu().numpy() # 关键点
|
||||
confidences = result.boxes.conf.cpu().numpy() # 置信度
|
||||
classes = result.boxes.cls.cpu().numpy() # 类别
|
||||
|
||||
# 处理每个检测结果
|
||||
for i in range(len(boxes)):
|
||||
# 检查关键点数量是否为4个
|
||||
if len(keypoints[i]) == 4:
|
||||
class_id = int(classes[i])
|
||||
detection = {
|
||||
'box': boxes[i],
|
||||
'keypoints': keypoints[i],
|
||||
'confidence': confidences[i],
|
||||
'class_id': class_id,
|
||||
'class_name': self.class_names.get(class_id, '未知')
|
||||
}
|
||||
detections.append(detection)
|
||||
else:
|
||||
# 关键点不足4个,记录但标记为不完整
|
||||
class_id = int(classes[i])
|
||||
detection = {
|
||||
'box': boxes[i],
|
||||
'keypoints': keypoints[i] if len(keypoints[i]) > 0 else [],
|
||||
'confidence': confidences[i],
|
||||
'class_id': class_id,
|
||||
'class_name': self.class_names.get(class_id, '未知'),
|
||||
'incomplete': True # 标记为不完整
|
||||
}
|
||||
detections.append(detection)
|
||||
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
print(f"检测过程中出错: {e}")
|
||||
return []
|
||||
|
||||
def draw_detections(self, image, detections):
|
||||
"""
|
||||
在图像上绘制检测结果
|
||||
|
||||
参数:
|
||||
image: 输入图像
|
||||
detections: 检测结果列表
|
||||
|
||||
返回:
|
||||
numpy.ndarray: 绘制了检测结果的图像
|
||||
"""
|
||||
draw_image = image.copy()
|
||||
|
||||
for i, detection in enumerate(detections):
|
||||
box = detection['box']
|
||||
keypoints = detection['keypoints']
|
||||
class_name = detection['class_name']
|
||||
confidence = detection['confidence']
|
||||
incomplete = detection.get('incomplete', False)
|
||||
|
||||
# 绘制边界框
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# 根据车牌类型选择颜色
|
||||
if class_name == '绿牌':
|
||||
box_color = (0, 255, 0) # 绿色
|
||||
elif class_name == '蓝牌':
|
||||
box_color = (255, 0, 0) # 蓝色
|
||||
else:
|
||||
box_color = (128, 128, 128) # 灰色
|
||||
|
||||
cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2)
|
||||
|
||||
# 绘制标签
|
||||
label = f"{class_name} {confidence:.2f}"
|
||||
if incomplete:
|
||||
label += " (不完整)"
|
||||
|
||||
# 计算文本大小和位置
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
thickness = 2
|
||||
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
|
||||
|
||||
# 绘制文本背景
|
||||
cv2.rectangle(draw_image, (x1, y1 - text_height - 10),
|
||||
(x1 + text_width, y1), box_color, -1)
|
||||
|
||||
# 绘制文本
|
||||
cv2.putText(draw_image, label, (x1, y1 - 5),
|
||||
font, font_scale, (255, 255, 255), thickness)
|
||||
|
||||
# 绘制关键点和连线
|
||||
if len(keypoints) >= 4 and not incomplete:
|
||||
# 四个角点完整,用黄色连线
|
||||
points = [(int(kp[0]), int(kp[1])) for kp in keypoints[:4]]
|
||||
|
||||
# 绘制关键点
|
||||
for point in points:
|
||||
cv2.circle(draw_image, point, 5, (0, 255, 255), -1)
|
||||
|
||||
# 连接关键点形成四边形(按顺序连接)
|
||||
# 假设关键点顺序为: right_bottom, left_bottom, left_top, right_top
|
||||
for j in range(4):
|
||||
cv2.line(draw_image, points[j], points[(j+1)%4], (0, 255, 255), 2)
|
||||
|
||||
elif len(keypoints) > 0:
|
||||
# 关键点不完整,用红色标记现有点
|
||||
for kp in keypoints:
|
||||
point = (int(kp[0]), int(kp[1]))
|
||||
cv2.circle(draw_image, point, 5, (0, 0, 255), -1)
|
||||
|
||||
return draw_image
|
||||
|
||||
def correct_license_plate(self, image, keypoints, target_size=(240, 80)):
|
||||
"""
|
||||
使用四个角点对车牌进行透视变换矫正
|
||||
|
||||
参数:
|
||||
image: 原始图像
|
||||
keypoints: 四个角点坐标
|
||||
target_size: 目标尺寸 (width, height)
|
||||
|
||||
返回:
|
||||
numpy.ndarray: 矫正后的车牌图像,如果失败返回None
|
||||
"""
|
||||
if len(keypoints) != 4:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 将关键点转换为numpy数组
|
||||
src_points = np.array(keypoints, dtype=np.float32)
|
||||
|
||||
# 定义目标矩形的四个角点
|
||||
# 假设关键点顺序为: right_bottom, left_bottom, left_top, right_top
|
||||
# 重新排序为标准顺序: left_top, right_top, right_bottom, left_bottom
|
||||
width, height = target_size
|
||||
dst_points = np.array([
|
||||
[0, 0], # left_top
|
||||
[width, 0], # right_top
|
||||
[width, height], # right_bottom
|
||||
[0, height] # left_bottom
|
||||
], dtype=np.float32)
|
||||
|
||||
# 重新排序源点以匹配目标点
|
||||
# 原顺序: right_bottom, left_bottom, left_top, right_top
|
||||
# 目标顺序: left_top, right_top, right_bottom, left_bottom
|
||||
reordered_src = np.array([
|
||||
src_points[2], # left_top
|
||||
src_points[3], # right_top
|
||||
src_points[0], # right_bottom
|
||||
src_points[1] # left_bottom
|
||||
], dtype=np.float32)
|
||||
|
||||
# 计算透视变换矩阵
|
||||
matrix = cv2.getPerspectiveTransform(reordered_src, dst_points)
|
||||
|
||||
# 应用透视变换
|
||||
corrected = cv2.warpPerspective(image, matrix, target_size)
|
||||
|
||||
return corrected
|
||||
|
||||
except Exception as e:
|
||||
print(f"车牌矫正失败: {e}")
|
||||
return None
|
||||
|
||||
def get_model_info(self):
|
||||
"""
|
||||
获取模型信息
|
||||
|
||||
返回:
|
||||
dict: 模型信息字典
|
||||
"""
|
||||
if self.model is None:
|
||||
return {"status": "未加载", "path": self.model_path}
|
||||
|
||||
return {
|
||||
"status": "已加载",
|
||||
"path": self.model_path,
|
||||
"model_type": "YOLO11 Pose",
|
||||
"classes": self.class_names
|
||||
}
|
||||
|
||||
def initialize_yolo_detector(model_path=None):
|
||||
"""
|
||||
初始化YOLO检测器的便捷函数
|
||||
|
||||
参数:
|
||||
model_path: 模型文件路径
|
||||
|
||||
返回:
|
||||
LicensePlateYOLO: 初始化后的检测器实例
|
||||
"""
|
||||
detector = LicensePlateYOLO(model_path)
|
||||
return detector
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
detector = initialize_yolo_detector()
|
||||
print("检测器信息:", detector.get_model_info())
|
||||
Reference in New Issue
Block a user