"""CRNN-based Chinese license-plate recognition interface.

The module exposes two entry points:

* ``initialize_crnn_model()`` builds the network, loads ``best_model.pth``
  from the directory containing this file and publishes the ready-to-use
  components in module-level globals.
* ``crnn_predict(image_array)`` runs preprocessing, inference and greedy CTC
  decoding on one plate crop and returns exactly seven characters.
"""
import os
import traceback

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Module-level state. It is populated by initialize_crnn_model() only after
# the weights have been loaded successfully and is read by crnn_predict().
crnn_model = None
crnn_decoder = None
crnn_preprocessor = None
device = None

# Number of characters crnn_predict() always returns.
_PLATE_LENGTH = 7


class CRNN(nn.Module):
    """Convolutional-recurrent network for plate recognition.

    Seven convolution layers feed a 2-layer bidirectional LSTM followed by a
    linear classifier over ``num_classes`` symbols.

    The LSTM input size is fixed at 512, so the CNN output must have height 1.
    With the pooling layout below that holds for 32-pixel-high inputs;
    ``img_height`` is stored on the instance but is not used to size any layer.
    """

    def __init__(self, img_height=32, num_classes=68, hidden_size=256):
        super().__init__()
        self.img_height = img_height
        self.num_classes = num_classes
        self.hidden_size = hidden_size

        # NOTE: the order of the modules below defines the state-dict keys
        # ("cnn.0.weight", ...). Do not reorder, or saved weights stop loading.
        self.cnn = nn.Sequential(
            # Layer 1: 3 -> 64, 3x3 conv, then 2x2 pooling.
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Layer 2: 64 -> 128, 3x3 conv, then 2x2 pooling.
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Layer 3: 128 -> 256, 3x3 conv.
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # Layer 4: 256 -> 256, 3x3 conv; pooling halves the height only,
            # keeping the width (the sequence axis) intact.
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # Layer 5: 256 -> 512, 3x3 conv.
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # Layer 6: 512 -> 512, 3x3 conv; height-only pooling again.
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # Layer 7: 512 -> 512, unpadded 2x2 conv (shrinks H and W by 1).
            nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Sequence model: 2-layer bidirectional LSTM over the width axis.
        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

        # Per-time-step classifier (forward + backward hidden states).
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: image batch laid out as (batch, 3, height, width).

        Returns:
            Unnormalised class scores shaped (width', batch, num_classes),
            i.e. time-major as expected by CTC.
        """
        features = self.cnn(x)

        # (batch, channels, height, width) -> (batch, width, channels * height)
        batch_size, channels, height, width = features.size()
        sequence = features.permute(0, 3, 1, 2).contiguous()
        sequence = sequence.view(batch_size, width, channels * height)

        recurrent, _ = self.rnn(sequence)
        scores = self.fc(recurrent)

        # Time-major layout for CTC: (width, batch, num_classes).
        return scores.permute(1, 0, 2)


class CTCDecoder:
    """Greedy CTC decoder over the Chinese license-plate alphabet."""

    def __init__(self):
        # 68 symbols in total. The order defines the class indices the model
        # was built with (num_classes=len(chars)), so it must not change.
        self.chars = [
            # Index 0: CTC blank.
            '',
            # Province abbreviations.
            '京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
            '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
            '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新',
            # Letters A-Z.
            'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
            'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
            # Digits 0-9.
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
        ]

        self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}
        self.idx_to_char = {idx: char for idx, char in enumerate(self.chars)}
        self.blank_idx = 0

    def _collapse(self, indices):
        """Apply the CTC collapsing rule to a list of per-step class indices.

        A step is kept when its index differs from the previous step's index,
        is not the blank and is a valid index into ``self.chars``.

        Returns:
            list of (time_step, class_index) pairs for the kept steps.
        """
        kept = []
        previous = -1
        for step, idx in enumerate(indices):
            if idx != previous and idx != self.blank_idx and idx < len(self.chars):
                kept.append((step, idx))
            # The previous index is updated for every step, kept or not, so a
            # blank between two equal symbols separates them.
            previous = idx
        return kept

    def decode_greedy(self, predictions):
        """Decode a (time_steps, num_classes) score tensor into text."""
        indices = torch.argmax(predictions, dim=1).tolist()
        return ''.join(self.chars[idx] for _, idx in self._collapse(indices))

    def decode_with_confidence(self, predictions):
        """Decode a (time_steps, num_classes) score tensor with confidences.

        Returns:
            tuple ``(text, avg_confidence, char_confidences)`` where
            ``char_confidences`` holds the softmax probability of each emitted
            character and ``avg_confidence`` is their mean (0.0 when nothing
            was emitted).
        """
        probs = torch.softmax(predictions, dim=1)
        indices = torch.argmax(probs, dim=1).tolist()
        max_probs = torch.max(probs, dim=1)[0].tolist()

        kept = self._collapse(indices)
        text = ''.join(self.chars[idx] for _, idx in kept)
        char_confidences = [max_probs[step] for step, _ in kept]
        avg_confidence = float(np.mean(char_confidences)) if char_confidences else 0.0

        return text, avg_confidence, char_confidences


class LicensePlatePreprocessor:
    """Turns a plate crop (numpy array) into a normalised model input."""

    def __init__(self, target_height=32, target_width=128):
        self.target_height = target_height
        self.target_width = target_width

        # Resize, convert to a float tensor and normalise per channel.
        self.transform = transforms.Compose([
            transforms.Resize((target_height, target_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_rgb_uint8(image_array):
        """Convert a supported image array to a contiguous RGB uint8 array.

        Colour input is assumed to be in OpenCV channel order (BGR / BGRA).
        Floating-point input is assumed to be scaled to [0, 1]; other integer
        dtypes are assumed to already be on the 0-255 scale. Values outside
        those ranges are clipped instead of wrapping around.

        Raises:
            ValueError: if the array is not 2-D or (H, W, 1|3|4).
        """
        image = np.asarray(image_array)

        # Normalise the dtype first so every layout below gets the same
        # colour handling (previously only uint8 input was channel-swapped).
        if image.dtype == np.bool_:
            image = image.astype(np.uint8) * 255
        elif image.dtype != np.uint8:
            if np.issubdtype(image.dtype, np.floating):
                image = np.clip(image, 0.0, 1.0) * 255
            else:
                image = np.clip(image, 0, 255)
            image = image.astype(np.uint8)

        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        image = np.ascontiguousarray(image)

        if image.ndim == 2:
            # Normalize() expects three channels, so expand grayscale.
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        raise ValueError(f"不支持的图像形状: {image.shape}")

    def preprocess_numpy_array(self, image_array):
        """Preprocess one image given as a numpy array.

        Args:
            image_array: grayscale, BGR or BGRA image (see ``_to_rgb_uint8``
                for the dtype and range assumptions).

        Returns:
            A (1, 3, target_height, target_width) float tensor, or ``None``
            when the input cannot be processed (the error is printed).
        """
        try:
            rgb = self._to_rgb_uint8(image_array)
            tensor = self.transform(Image.fromarray(rgb))
            # Add the batch dimension.
            return tensor.unsqueeze(0)
        except Exception as e:
            # Deliberate best-effort boundary: callers test for None.
            print(f"图像预处理失败: {e}")
            return None


def _extract_state_dict(checkpoint):
    """Return the model state dict contained in a loaded checkpoint object.

    Handles a full training checkpoint (dict with ``'model_state_dict'``), a
    bare state dict, and a pickled ``nn.Module``.
    """
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            # Full checkpoint: report the training metadata that came with it.
            print("检查点信息:")
            print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
            print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
            return checkpoint['model_state_dict']
        # Slim format: the dict itself is the state dict.
        print("加载精简模型(仅权重)")
        return checkpoint
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    return checkpoint


def initialize_crnn_model():
    """Build the CRNN pipeline and load ``best_model.pth`` next to this file.

    The module-level ``crnn_model``, ``crnn_decoder``, ``crnn_preprocessor``
    and ``device`` are assigned only after the weights were loaded
    successfully. A failed call leaves them untouched, so crnn_predict() can
    never run inference with an untrained, randomly initialised network.

    Returns:
        bool: True on success, False on any failure (the error is printed).
    """
    global crnn_model, crnn_decoder, crnn_preprocessor, device

    try:
        target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"CRNN使用设备: {target_device}")

        model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        decoder = CTCDecoder()
        preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128)
        model = CRNN(num_classes=len(decoder.chars), hidden_size=256)

        print(f"正在加载CRNN模型: {model_path}")

        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load checkpoints from a
        # trusted source; switch to weights_only=True if the checkpoint format
        # allows it.
        checkpoint = torch.load(model_path, map_location=target_device, weights_only=False)

        model.load_state_dict(_extract_state_dict(checkpoint))
        model.to(target_device)
        model.eval()
    except Exception as e:
        print(f"CRNN模型初始化失败: {e}")
        traceback.print_exc()
        return False

    # Publish the fully initialised pipeline in one step.
    crnn_model = model
    crnn_decoder = decoder
    crnn_preprocessor = preprocessor
    device = target_device

    print("CRNN模型初始化完成")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"CRNN模型参数数量: {total_params:,}")

    return True


def crnn_predict(image_array):
    """Recognise the license plate shown in ``image_array``.

    Args:
        image_array: plate crop as a numpy array (see
            ``LicensePlatePreprocessor.preprocess_numpy_array``).

    Returns:
        list: exactly seven single-character strings, e.g.
        ``['京', 'A', '1', '2', '3', '4', '5']``. Shorter predictions are
        padded with ``'0'`` and longer ones are truncated. A fixed placeholder
        list is returned when the model is not initialised or when
        recognition fails.
    """
    if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
        print("CRNN模型未初始化,请先调用initialize_crnn_model()")
        return ['待', '识', '别', '0', '0', '0', '0']

    try:
        input_tensor = crnn_preprocessor.preprocess_numpy_array(image_array)
        if input_tensor is None:
            raise ValueError("图像预处理失败")

        input_tensor = input_tensor.to(device)

        with torch.no_grad():
            outputs = crnn_model(input_tensor)  # (seq_len, batch_size, num_classes)

        # The batch holds a single image; drop the batch axis.
        outputs = outputs[:, 0, :]  # (seq_len, num_classes)

        predicted_text, confidence, _ = crnn_decoder.decode_with_confidence(outputs)
        print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")

        # Force the fixed plate length: truncate first, then pad with '0'.
        char_list = list(predicted_text)[:_PLATE_LENGTH]
        char_list.extend(['0'] * (_PLATE_LENGTH - len(char_list)))
        return char_list

    except Exception as e:
        print(f"CRNN识别失败: {e}")
        traceback.print_exc()
        return ['识', '别', '失', '败', '0', '0', '0']