101 lines
3.6 KiB
Python
101 lines
3.6 KiB
Python
"""
|
|
NanoDetHeat v4: 残差 + SE注意力 + 扩通道
|
|
MACs: ~8.7M, 参数: ~18K, 目标: 20 TPS @ 0.44 GOPS (79%)
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
|
|
def conv_bn_relu(in_c, out_c, kernel, stride=1, groups=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack.

    The convolution has no bias term (the BatchNorm that follows provides the
    shift). It is padded by ``kernel // 2``, which keeps the spatial size
    unchanged at ``stride=1`` for odd kernel sizes.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel: Square kernel size (an int).
        stride: Convolution stride.
        groups: Convolution groups; ``groups == in_c`` gives a depthwise conv.

    Returns:
        ``nn.Sequential`` holding the conv, the BN and an in-place ReLU at
        indices 0, 1 and 2.
    """
    layers = [
        nn.Conv2d(
            in_c,
            out_c,
            kernel_size=kernel,
            stride=stride,
            padding=kernel // 2,
            groups=groups,
            bias=False,
        ),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
class SEResDWBlock(nn.Module):
    """Residual depthwise-separable block with squeeze-and-excitation (SE) gating.

    Main branch: a 3x3 depthwise conv-BN-ReLU (which carries the stride),
    then a 1x1 pointwise conv-BN-ReLU to ``out_c`` channels. The shortcut is
    the identity when stride and channel count are unchanged; otherwise it is
    a strided 1x1 conv + BN projection. The SE gate is computed from the
    residual sum and multiplies that sum, and a final ReLU is applied.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        stride: Stride of the depthwise conv and of the projection shortcut.
        reduction: Channel reduction factor of the SE bottleneck. The
            bottleneck width is clamped to at least 1 so that
            ``out_c < reduction`` still yields an input-dependent gate.
    """

    def __init__(self, in_c, out_c, stride=1, reduction=4):
        super().__init__()
        self.dw = conv_bn_relu(in_c, in_c, 3, stride=stride, groups=in_c)
        self.pw = conv_bn_relu(in_c, out_c, 1)

        # Project the shortcut only when the spatial size or channel count changes.
        if stride == 1 and in_c == out_c:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, stride, bias=False),
                nn.BatchNorm2d(out_c),
            )

        # out_c // reduction is 0 when out_c < reduction. That would build a
        # zero-width bottleneck whose output no longer depends on the input,
        # so keep at least one hidden channel.
        hidden = max(1, out_c // reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_c, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_c, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the block: ``relu(s * se(s))`` with ``s = pw(dw(x)) + skip(x)``."""
        out = self.pw(self.dw(x)) + self.skip(x)
        # Per-channel gate in (0, 1), computed from the globally pooled sum.
        return self.relu(out * self.se(out))
|
|
|
|
|
|
class NanoDetHeatV4(nn.Module):
    """Heatmap-style detector: a stride-8 backbone, a shared 1x1 neck and two 1x1 heads.

    ``forward`` returns one tensor. Its first ``num_classes + 1`` channels come
    from ``cls_head``; the last 2 channels come from ``size_head``.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # Backbone. The stem, block2 and block4 each halve the resolution.
        self.stem = conv_bn_relu(3, 8, 3, stride=2)
        self.block1 = SEResDWBlock(8, 16, stride=1)
        self.block2 = SEResDWBlock(16, 28, stride=2)
        self.block3 = SEResDWBlock(28, 40, stride=1)
        self.block4 = SEResDWBlock(40, 56, stride=2)
        self.block5 = SEResDWBlock(56, 64, stride=1)

        # Shared neck feeding both prediction heads.
        self.shared = conv_bn_relu(64, 24, 1)
        self.cls_head = nn.Conv2d(24, num_classes + 1, 1, bias=True)
        self.size_head = nn.Conv2d(24, 2, 1, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Initialize every conv and BN, then bias the last class-head channel.

        Every conv weight gets Kaiming-normal init (fan_out, ReLU gain) and
        every conv bias is zeroed. BN gets weight 1 and bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        # Small positive bias on the last class channel. The author's note calls
        # this a slight background bias that focal loss adjusts during training.
        with torch.no_grad():
            self.cls_head.bias[-1] = 0.5

    def forward(self, x):
        """Run the backbone and neck, then concatenate the head outputs along channels."""
        stages = (
            self.stem,
            self.block1,
            self.block2,
            self.block3,
            self.block4,
            self.block5,
            self.shared,
        )
        for stage in stages:
            x = stage(x)
        heads = [self.cls_head(x), self.size_head(x)]
        return torch.cat(heads, dim=1)
|
|
|
|
|
|
# Model output grid for the 120x160 input used by the script below. Three
# stride-2 stages (stem, block2, block4) give a total downsampling of 8.
STRIDE = 8
OUT_H = 15  # 120 // STRIDE
OUT_W = 20  # 160 // STRIDE
|
|
|
|
|
|
def count_conv_macs(model, input_shape=(1, 3, 120, 160)):
    """Return the per-sample multiply-accumulate count of all ``nn.Conv2d`` layers.

    Runs one forward pass on a zero tensor of ``input_shape`` with a forward
    hook on every ``nn.Conv2d``. Each conv contributes ``output elements *
    (in_channels / groups) * kernel_h * kernel_w``, and the total is divided
    by the batch size. BatchNorm, activations, pooling, element-wise ops and
    bias additions are not counted.

    Measuring the real forward pass keeps the figure in sync with the
    architecture (projection shortcuts, SE convs, head width), which a
    hand-maintained formula does not.

    Args:
        model: Module to measure. It runs in eval mode under
            ``torch.no_grad()``; ``model.train(<previous flag>)`` is called
            afterwards.
        input_shape: Shape of the dummy input, batch dimension first (>= 1).

    Returns:
        MACs per sample, as an ``int``.
    """
    total = 0

    def _hook(module, inputs, output):
        nonlocal total
        k_h, k_w = module.kernel_size
        macs_per_output = (module.in_channels // module.groups) * k_h * k_w
        total += (output.numel() // output.shape[0]) * macs_per_output

    handles = [
        m.register_forward_hook(_hook)
        for m in model.modules()
        if isinstance(m, nn.Conv2d)
    ]
    was_training = model.training
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    try:
        model.eval()
        with torch.no_grad():
            model(torch.zeros(input_shape, device=device))
    finally:
        # Always detach the hooks and restore the mode, even if forward fails.
        for handle in handles:
            handle.remove()
        model.train(was_training)
    return total


if __name__ == "__main__":
    model = NanoDetHeatV4()
    n_params = sum(t.numel() for t in model.parameters())

    # Input size implied by the output grid: (OUT_H, OUT_W) * STRIDE = 120x160.
    in_h, in_w = OUT_H * STRIDE, OUT_W * STRIDE
    macs = count_conv_macs(model, (1, 3, in_h, in_w))

    print(f"Params: {n_params:,} = ~{n_params/1000:.1f}K")
    print(f"MACs: {macs:,} = ~{macs/1e6:.2f}M")
    budget_gops = 0.44  # compute budget, in GOPS
    for tps in (20, 25):
        gops = macs * tps * 2 / 1e9  # 2 ops (multiply + add) per MAC
        print(f"{tps} TPS: {gops:.2f} GOPS ({gops / budget_gops * 100:.0f}%)")

    # Shape check in inference mode; no autograd graph is needed.
    model.eval()
    with torch.no_grad():
        y = model(torch.randn(1, 3, in_h, in_w))
    print(f"Output: {y.shape}")
    expected = (1, model.cls_head.out_channels + model.size_head.out_channels, OUT_H, OUT_W)
    if tuple(y.shape) != expected:
        raise RuntimeError(f"Output shape {tuple(y.shape)} != expected {expected}")
|