"""NanoDetHeat v4: residual depthwise blocks + SE channel attention + wider channels.

Cost for one 3x120x160 input (conv layers only, 1 MAC = 2 ops):
    MACs:   ~11.1M  (as measured by ``count_conv_macs``)
    Params: ~24.6K
    Target: 20 TPS @ 0.44 GOPS -> needs ~0.44 GOPS, i.e. ~101% of budget.

NOTE: earlier revisions of this header quoted ~8.7M MACs / ~18K params / 79%
of budget. The MAC figure came from a hand count that omitted the 1x1
projection shortcuts of block3 (28->40) and block5 (56->64). The parameter
figure did not match the layers defined here. Run this module as a script to
print up-to-date numbers.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def conv_bn_relu(in_c, out_c, kernel, stride=1, groups=1):
    """Return Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    Padding is ``kernel // 2``, so odd kernels with stride 1 preserve the
    spatial size.
    """
    pad = kernel // 2
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel, stride, pad, groups=groups, bias=False),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )


class SEResDWBlock(nn.Module):
    """Residual depthwise-separable block with squeeze-and-excitation (SE) attention.

    Main path: 3x3 depthwise conv (carries the stride), then 1x1 pointwise conv.
    Each is followed by BN + ReLU.
    Shortcut: identity when stride is 1 and channels are unchanged; otherwise
    a strided 1x1 conv + BN projection.
    The SE gate is computed on the sum of main path and shortcut, and the
    final ReLU is applied after gating.

    Args:
        in_c: input channels.
        out_c: output channels.
        stride: spatial stride of the depthwise conv and of the projection shortcut.
        reduction: SE bottleneck ratio. The hidden width is ``out_c // reduction``,
            clamped to at least 1 so narrow blocks remain valid.
    """

    def __init__(self, in_c, out_c, stride=1, reduction=4):
        super().__init__()
        self.dw = conv_bn_relu(in_c, in_c, 3, stride=stride, groups=in_c)
        self.pw = conv_bn_relu(in_c, out_c, 1)
        if stride == 1 and in_c == out_c:
            self.skip = nn.Identity()
        else:
            # Projection shortcut. Its MAC cost equals the pointwise conv's,
            # so it must be included in any cost estimate.
            self.skip = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, stride, bias=False),
                nn.BatchNorm2d(out_c),
            )
        hidden = max(1, out_c // reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_c, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_c, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.pw(self.dw(x)) + self.skip(x)
        # NOTE: SE gates the post-residual sum here. Canonical SE-ResNet gates
        # the residual branch before the add. Kept as-is so existing trained
        # weights behave the same.
        return self.relu(out * self.se(out))


class NanoDetHeatV4(nn.Module):
    """Tiny detector: stem + 5 SE residual DW blocks + shared 1x1 neck + two 1x1 heads.

    Overall stride is 8 (the stem, block2 and block4 each downsample by 2).
    For example, a 120x160 input yields a 15x20 output map.

    Args:
        num_classes: number of foreground classes. The classification head
            emits ``num_classes + 1`` channels; the last channel gets a
            background bias at init.

    ``forward`` returns a tensor of shape (B, num_classes + 3, H', W').
    It holds the raw (no activation) classification outputs followed by the
    2 size-head channels. Presumably those are box width/height — TODO
    confirm against the decoder.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.stem = conv_bn_relu(3, 8, 3, stride=2)
        self.block1 = SEResDWBlock(8, 16, stride=1)
        self.block2 = SEResDWBlock(16, 28, stride=2)
        self.block3 = SEResDWBlock(28, 40, stride=1)
        self.block4 = SEResDWBlock(40, 56, stride=2)
        self.block5 = SEResDWBlock(56, 64, stride=1)
        self.shared = conv_bn_relu(64, 24, 1)
        self.cls_head = nn.Conv2d(24, num_classes + 1, 1, bias=True)
        self.size_head = nn.Conv2d(24, 2, 1, bias=True)
        self._init_weights()

    def _init_weights(self):
        """Set Kaiming (fan_out) conv weights, zero biases, and identity BN.

        Also sets the background logit bias of the classification head.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        # Slight background bias on the last class channel. Per the original
        # note, the focal loss adjusts it during training. This is done after
        # the loop so the head's own zero-init cannot overwrite it.
        with torch.no_grad():
            self.cls_head.bias[-1] = 0.5

    def forward(self, x):
        x = self.stem(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.shared(x)
        return torch.cat([self.cls_head(x), self.size_head(x)], dim=1)


def count_conv_macs(model, input_size=(1, 3, 120, 160)):
    """Count multiply-accumulates of every ``nn.Conv2d`` in *model* for one sample.

    Runs one dummy forward pass with a hook on each Conv2d. The count
    therefore always matches the real architecture, including projection
    shortcuts and SE convs. BatchNorm, activations, pooling and elementwise
    ops are not counted.

    Args:
        model: module to profile. Its train/eval mode is restored afterwards.
        input_size: shape of the dummy input, including the batch dimension.

    Returns:
        Conv MACs per sample, as an int.
    """
    total = 0

    def _hook(conv, _inputs, output):
        nonlocal total
        macs_per_output = (
            (conv.in_channels // conv.groups) * conv.kernel_size[0] * conv.kernel_size[1]
        )
        total += (output.numel() // output.shape[0]) * macs_per_output

    handles = [
        mod.register_forward_hook(_hook)
        for mod in model.modules()
        if isinstance(mod, nn.Conv2d)
    ]
    was_training = model.training
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "cpu"
    try:
        model.eval()
        with torch.no_grad():
            model(torch.zeros(input_size, device=device))
    finally:
        model.train(was_training)
        for handle in handles:
            handle.remove()
    return total


OUT_H, OUT_W = 15, 20  # output map size for a 120x160 input
STRIDE = 8  # total downsampling factor of the network


if __name__ == "__main__":
    budget_ops = 0.44e9  # accelerator budget in ops/s (1 MAC = 2 ops)
    in_shape = (1, 3, OUT_H * STRIDE, OUT_W * STRIDE)

    model = NanoDetHeatV4()
    n_params = sum(t.numel() for t in model.parameters())
    macs = count_conv_macs(model, in_shape)

    print(f"Params: {n_params:,} = ~{n_params/1000:.1f}K")
    print(f"MACs: {macs:,} = ~{macs/1e6:.2f}M")
    for tps in (20, 25):
        ops = macs * tps * 2
        print(f"{tps} TPS: {ops/1e9:.2f} GOPS ({ops/budget_ops*100:.0f}%)")

    # Smoke test in eval mode so BN running stats are not updated.
    model.eval()
    with torch.no_grad():
        y = model(torch.randn(*in_shape))
    print(f"Output: {y.shape}")