更新接口

This commit is contained in:
2025-09-04 00:07:52 +08:00
parent 95aa6b6bba
commit 6c7f013a0c
5 changed files with 290 additions and 187 deletions

View File

@@ -1,28 +1,27 @@
import torch
import torch.nn as nn
import numpy as np
import cv2
from torch.autograd import Variable
import numpy as np
import os
import sys
from torch.autograd import Variable
from PIL import Image
# 添加父目录到路径,以便导入模型和数据加载器
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 字符集定义
# LPRNet字符集定义与训练时保持一致
CHARS = ['', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'',
'', '', '', '', '', '', '', '', '', '', '',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', 'I', 'O', '-'
]
'W', 'X', 'Y', 'Z', 'I', 'O', '-']
CHARS_DICT = {char: i for i, char in enumerate(CHARS)}
# 全局变量
lprnet_model = None
device = None
# 简化的LPRNet模型定义
class small_basic_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(small_basic_block, self).__init__()
@@ -35,7 +34,7 @@ class small_basic_block(nn.Module):
nn.ReLU(),
nn.Conv2d(ch_out // 4, ch_out, kernel_size=1),
)
def forward(self, x):
return self.block(x)
@@ -58,20 +57,20 @@ class LPRNet(nn.Module):
nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 10
small_basic_block(ch_in=256, ch_out=256), # *** 11 ***
nn.BatchNorm2d(num_features=256), # 12
nn.ReLU(),
nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 13
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)), # 14
nn.Dropout(dropout_rate),
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16
nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 18
nn.Dropout(dropout_rate),
nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20
nn.BatchNorm2d(num_features=class_num),
nn.ReLU(), # *** 22 ***
nn.ReLU(), # 22
)
self.container = nn.Sequential(
nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)),
nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1,1), stride=(1,1)),
)
def forward(self, x):
@@ -98,101 +97,177 @@ class LPRNet(nn.Module):
return logits
def build_lprnet(lpr_max_len=8, phase=False, class_num=66, dropout_rate=0.5):
"""构建LPRNet模型"""
Net = LPRNet(lpr_max_len, phase, class_num, dropout_rate)
class LPRNetInference:
def __init__(self, model_path=None, img_size=(94, 24), lpr_max_len=8, dropout_rate=0.5):
    """Build the LPRNet model, load its weights if available, and move it to the device.

    Args:
        model_path: Path to the trained weight file. When None, the file
            'LPRNet__iteration_74000.pth' next to this module is used.
        img_size: Input image size as [width, height]. Any two-element
            sequence is accepted; a private list copy is stored.
        lpr_max_len: Maximum plate length, forwarded to LPRNet.
        dropout_rate: Dropout rate, forwarded to LPRNet.

    If the weight file is missing or fails to load, a warning is printed
    and the model keeps its freshly initialised weights (best effort, no
    exception is raised).
    """
    # Copy instead of aliasing: the previous mutable list default was shared
    # by every instance, so mutating one instance's img_size changed the
    # default for all later instances.
    self.img_size = list(img_size)
    self.lpr_max_len = lpr_max_len
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fall back to the checkpoint shipped alongside this file.
    if model_path is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, 'LPRNet__iteration_74000.pth')

    # The class count is derived from the module-level character table.
    self.model = LPRNet(lpr_max_len=lpr_max_len, phase=False, class_num=len(CHARS), dropout_rate=dropout_rate)

    if model_path and os.path.exists(model_path):
        print(f"Loading LPRNet model from {model_path}")
        try:
            # NOTE: torch.load unpickles the file; only load checkpoints
            # that come from a trusted source.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            print("LPRNet模型权重加载成功")
        except Exception as e:
            # Deliberate best effort: keep running with the current weights.
            print(f"Warning: 加载模型权重失败: {e}. 使用随机权重.")
    else:
        print(f"Warning: 模型文件不存在或未指定: {model_path}. 使用随机权重.")

    self.model.to(self.device)
    self.model.eval()  # inference mode
    print(f"LPRNet模型加载完成设备: {self.device}")
    print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters()):,}")
if phase == "train":
return Net.train()
else:
return Net.eval()
def preprocess_image(image_array, img_size=(94, 24)):
"""图像预处理"""
# 确保输入是numpy数组
if not isinstance(image_array, np.ndarray):
raise ValueError("输入必须是numpy数组")
def preprocess_image(self, image_array):
    """Validate an H x W x 3 image array and convert it to a model input tensor.

    Args:
        image_array: numpy array laid out as (height, width, channels)
            with exactly three channels.

    Returns:
        A float32 torch tensor of shape (1, 3, height, width) whose values
        are (pixel - 127.5) * 0.0078125.

    Raises:
        ValueError: if the input is None, is not a numpy array, is not
            three-dimensional, or does not have three channels.
    """
    # Guard clauses: reject anything that is not an H x W x 3 ndarray.
    if image_array is None:
        raise ValueError("Input image is None")
    if not isinstance(image_array, np.ndarray):
        raise ValueError("Input must be numpy array")

    rank = len(image_array.shape)
    if rank != 3:
        raise ValueError(f"Expected 3D image array, got {rank}D")

    rows, cols, depth = image_array.shape
    if depth != 3:
        raise ValueError(f"Expected 3 channels, got {depth}")

    # self.img_size is [width, height]; resize only on a size mismatch.
    if (rows, cols) != (self.img_size[1], self.img_size[0]):
        image_array = cv2.resize(image_array, tuple(self.img_size))

    # Centre and scale the pixels, then reorder HWC -> CHW.
    scaled = (image_array.astype('float32') - 127.5) * 0.0078125
    chw = scaled.transpose(2, 0, 1)

    # Prepend the batch dimension.
    return torch.unsqueeze(torch.from_numpy(chw), 0)
# 调整图像尺寸
height, width = image_array.shape[:2]
if height != img_size[1] or width != img_size[0]:
image_array = cv2.resize(image_array, img_size)
# 归一化到[0,1]
image_array = image_array.astype(np.float32) / 255.0
# 转换为CHW格式
if len(image_array.shape) == 3:
image_array = np.transpose(image_array, (2, 0, 1))
# 添加batch维度
image_array = np.expand_dims(image_array, axis=0)
return image_array
def greedy_decode(prebs):
"""贪婪解码"""
preb_labels = list()
for i in range(prebs.shape[0]):
preb = prebs[i, :, :]
preb_label = list()
for j in range(preb.shape[1]):
def decode_prediction(self, logits):
"""
解码模型预测结果 - 使用正确的CTC贪婪解码
Args:
logits: 模型输出的logits [batch_size, num_classes, sequence_length]
Returns:
predicted_text: 预测的车牌号码
"""
# 转换为numpy进行处理
prebs = logits.cpu().detach().numpy()
preb = prebs[0, :, :] # 取第一个batch [num_classes, sequence_length]
# 贪婪解码:对每个时间步选择最大概率的字符
preb_label = []
for j in range(preb.shape[1]): # 遍历每个时间步
preb_label.append(np.argmax(preb[:, j], axis=0))
no_repeat_blank_label = list()
# CTC解码去除重复字符和空白字符
no_repeat_blank_label = []
pre_c = preb_label[0]
if pre_c != len(CHARS) - 1:
# 处理第一个字符
if pre_c != len(CHARS) - 1: # 不是空白字符
no_repeat_blank_label.append(pre_c)
for c in preb_label: # 去除重复标签和空白标签
if (pre_c == c) or (c == len(CHARS) - 1):
# 处理后续字符
for c in preb_label:
if (pre_c == c) or (c == len(CHARS) - 1): # 重复字符或空白字符
if c == len(CHARS) - 1:
pre_c = c
continue
no_repeat_blank_label.append(c)
pre_c = c
preb_labels.append(no_repeat_blank_label)
# 转换为字符
decoded_chars = [CHARS[idx] for idx in no_repeat_blank_label]
return ''.join(decoded_chars)
return preb_labels
def predict(self, image_array):
    """Recognise the plate text on a single image.

    Args:
        image_array: Image as a numpy array; it is handed to
            self.preprocess_image, which validates and converts it.

    Returns:
        (prediction, confidence): the decoded plate string and the mean,
        over all output positions, of the highest softmax probability.
        On any failure the error is printed and (None, 0.0) is returned;
        no exception reaches the caller.
    """
    try:
        # preprocess_image raises on bad input and never returns None, so
        # the former `if image is None` branch was unreachable. Any failure
        # here is handled by the except clause below.
        image = self.preprocess_image(image_array).to(self.device)

        with torch.no_grad():
            # Expected layout: (batch_size, class_num, sequence_length).
            logits = self.model(image)

            # Confidence: best class probability per position, averaged.
            probs = torch.softmax(logits, dim=1)
            max_probs = torch.max(probs, dim=1)[0]
            confidence = torch.mean(max_probs).item()

            prediction = self.decode_prediction(logits)

        return prediction, confidence
    except Exception as e:
        # Boundary handler: report the failure and return the sentinel.
        print(f"预测图像失败: {e}")
        return None, 0.0
def LPRNinitialize_model(model_path=None):
"""初始化LPRNet模型"""
global lprnet_model, device
# 全局变量
lpr_model = None
def LPRNinitialize_model():
"""
初始化LPRNet模型
返回:
bool: 初始化是否成功
"""
global lpr_model
try:
# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 模型权重文件路径
model_path = os.path.join(os.path.dirname(__file__), 'LPRNet__iteration_74000.pth')
# 构建模型
lprnet_model = build_lprnet(
lpr_max_len=8,
phase=False,
class_num=len(CHARS),
dropout_rate=0.5
)
# 加载预训练权重
if model_path is None:
model_path = os.path.join(os.path.dirname(__file__), "Final_LPRNet_model.pth")
if os.path.exists(model_path):
checkpoint = torch.load(model_path, map_location=device)
lprnet_model.load_state_dict(checkpoint)
print(f"成功加载预训练模型: {model_path}")
else:
print(f"警告: 未找到预训练模型文件 {model_path},使用随机初始化权重")
lprnet_model.to(device)
lprnet_model.eval()
# 创建推理对象
lpr_model = LPRNetInference(model_path)
print("LPRNet模型初始化完成")
# 统计模型参数
total_params = sum(p.numel() for p in lprnet_model.parameters())
print(f"LPRNet模型参数数量: {total_params:,}")
return True
except Exception as e:
@@ -209,76 +284,45 @@ def LPRNmodel_predict(image_array):
image_array: numpy数组格式的车牌图像已经过矫正处理
返回:
list: 包含7个字符的列表,代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5']
list: 包含最多8个字符的列表,代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5'] (蓝牌7位)
['', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
"""
global lprnet_model, device
global lpr_model
if lprnet_model is None:
if lpr_model is None:
print("LPRNet模型未初始化请先调用LPRNinitialize_model()")
return ['', '', '', '0', '0', '0', '0']
return ['', '', '', '0', '0', '0', '0', '0']
try:
# 预处理图像
processed_image = preprocess_image(image_array)
# 预测车牌号
predicted_text, confidence = lpr_model.predict(image_array)
# 转换为tensor
input_tensor = torch.from_numpy(processed_image).float()
input_tensor = input_tensor.to(device)
if predicted_text is None:
print("LPRNet识别失败")
return ['', '', '', '', '0', '0', '0', '0']
# 模型推理
with torch.no_grad():
prebs = lprnet_model(input_tensor)
prebs = prebs.cpu().detach().numpy()
print(f"LPRNet识别结果: {predicted_text}, 置信度: {confidence:.3f}")
# 贪婪解码
preb_labels = greedy_decode(prebs)
# 将字符串转换为字符列表
char_list = list(predicted_text)
# 确保返回至少7个字符最多8个字符
if len(char_list) < 7:
# 如果识别结果少于7个字符用'0'补齐到7位
char_list.extend(['0'] * (7 - len(char_list)))
elif len(char_list) > 8:
# 如果识别结果多于8个字符截取前8个
char_list = char_list[:8]
# 如果是7位补齐到8位以保持接口一致性第8位用空字符或占位符
if len(char_list) == 7:
char_list.append('') # 添加空字符作为第8位占位符
return char_list
if len(preb_labels) > 0 and len(preb_labels[0]) > 0:
# 将索引转换为字符
predicted_chars = [CHARS[idx] for idx in preb_labels[0] if idx < len(CHARS)]
print(f"LPRNet识别结果: {''.join(predicted_chars)}")
# 确保返回7个字符车牌标准长度
if len(predicted_chars) < 7:
# 如果识别结果少于7个字符用'0'补齐
predicted_chars.extend(['0'] * (7 - len(predicted_chars)))
elif len(predicted_chars) > 7:
# 如果识别结果多于7个字符截取前7个
predicted_chars = predicted_chars[:7]
return predicted_chars
else:
print("LPRNet识别结果为空")
return ['', '', '', '', '0', '0', '0']
except Exception as e:
print(f"LPRNet识别失败: {e}")
import traceback
traceback.print_exc()
return ['', '', '', '', '0', '0', '0']
# Processor class offered so this module exposes the same interface as the others.
class LPRProcessor:
    """Small wrapper around the module-level initialise and predict functions."""

    def __init__(self):
        # Set by initialize(); predict() refuses to run while this is falsy.
        self.initialized = False

    def initialize(self, model_path=None):
        """Initialise the model and return the result of LPRNinitialize_model."""
        status = LPRNinitialize_model(model_path)
        self.initialized = status
        return status

    def predict(self, image_array):
        """Return the recognised characters, or a placeholder list if uninitialised."""
        if self.initialized:
            return LPRNmodel_predict(image_array)

        print("模型未初始化")
        return ['', '', '', '', '0', '0', '0']
# Single module-level processor instance handed out by get_lpr_processor().
_processor = LPRProcessor()


def get_lpr_processor():
    """Return the module-level LPRProcessor instance."""
    return _processor
return ['', '', '', '', '0', '0', '0', '0']