PyTorch实战:5分钟搞定SE-ResNet18图像分类模型搭建(附完整代码)

PyTorch实战:5分钟搞定SE-ResNet18图像分类模型搭建(附完整代码) PyTorch实战5分钟搞定SE-ResNet18图像分类模型搭建附完整代码刚接触PyTorch时最让人头疼的莫过于从零搭建一个完整的神经网络模型。特别是当需要实现一些复杂结构时比如结合了注意力机制的残差网络光是理解各个模块的衔接关系就够喝一壶的。今天我们就来手把手实现一个SE-ResNet18模型这个在图像分类任务中表现出色的架构其实用PyTorch搭建起来比你想象的简单得多。1. 环境准备与基础概念在开始编码之前确保你的环境中已经安装了最新版的PyTorch。如果你使用conda管理环境可以这样安装conda install pytorch torchvision -c pytorchSE-ResNet是ResNet与Squeeze-and-Excitation(SE)注意力机制的完美结合。简单来说ResNet通过残差连接解决了深层网络训练困难的问题SE模块通过学习通道间的关系来自适应地调整各通道的重要性两者的结合使模型能够更有效地利用特征信息。下表对比了普通ResNet和SE-ResNet的主要区别特性ResNetSE-ResNet核心结构残差块带SE的残差块参数量相对较少略有增加特征利用平等对待各通道动态加权各通道典型应用基础分类任务精细分类任务提示SE模块增加的参数量很少通常不到1%但能带来显著的性能提升2. SE模块实现详解让我们先实现SE注意力模块这是整个模型的核心创新点。SE模块的工作流程可以分为三步Squeeze全局平均池化将空间信息压缩为一个通道描述符Excitation通过全连接层学习通道间关系Scale将学习到的权重应用到原始特征图上class SEModule(nn.Module): def __init__(self, channels, reduction16): super(SEModule, self).__init__() self.avg_pool nn.AdaptiveAvgPool2d(1) self.fc1 nn.Linear(channels, channels // reduction) self.fc2 nn.Linear(channels // reduction, channels) self.relu nn.ReLU(inplaceTrue) self.sigmoid nn.Sigmoid() def forward(self, x): batch, channels, _, _ x.size() y self.avg_pool(x).view(batch, channels) y self.fc1(y) y self.relu(y) y self.fc2(y) y self.sigmoid(y).view(batch, channels, 1, 1) return x * y.expand_as(x)这个实现有几个关键点值得注意AdaptiveAvgPool2d(1)将任意大小的特征图压缩为1x1降维比例reduction通常设为16平衡效果与计算量最后使用expand_as确保权重能正确广播到原始特征图尺寸3. 构建SE-ResNet18主体结构现在我们将SE模块整合到ResNet的基础块中。ResNet18使用的是BasicBlock相比Bottleneck结构更简单。class SEBasicBlock(nn.Module): expansion 1 def __init__(self, inplanes, planes, stride1, downsampleNone): super(SEBasicBlock, self).__init__() self.conv1 nn.Conv2d(inplanes, planes, kernel_size3, stridestride, padding1, biasFalse) self.bn1 nn.BatchNorm2d(planes) self.conv2 nn.Conv2d(planes, planes, kernel_size3, padding1, biasFalse) self.bn2 nn.BatchNorm2d(planes) self.relu nn.ReLU(inplaceTrue) self.se SEModule(planes) self.downsample downsample self.stride stride def forward(self, x): residual x out self.conv1(x) out self.bn1(out) out self.relu(out) out self.conv2(out) out self.bn2(out) out self.se(out) # 在此处插入SE模块 if self.downsample is not None: residual self.downsample(x) out residual out self.relu(out) return out这里最容易出错的点是维度匹配问题特别是在有下采样的块中。常见错误包括忘记处理残差连接的维度变化SE模块插入位置不当导致梯度消失下采样块与普通块混淆使用注意当stride≠1或通道数变化时必须通过downsample调整残差路径的维度4. 完整模型组装与训练技巧现在我们可以组装完整的SE-ResNet18了。模型结构与原始ResNet18基本相同只是用SEBasicBlock替换了原来的BasicBlock。class SEResNet18(nn.Module): def __init__(self, num_classes1000): super(SEResNet18, self).__init__() self.inplanes 64 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.relu nn.ReLU(inplaceTrue) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) self.layer1 self._make_layer(SEBasicBlock, 64, 2) self.layer2 self._make_layer(SEBasicBlock, 128, 2, stride2) self.layer3 self._make_layer(SEBasicBlock, 256, 2, stride2) self.layer4 self._make_layer(SEBasicBlock, 512, 2, stride2) self.avgpool nn.AdaptiveAvgPool2d(1) self.fc nn.Linear(512 * SEBasicBlock.expansion, num_classes) def _make_layer(self, block, planes, blocks, stride1): downsample None if stride ! 1 or self.inplanes ! planes * block.expansion: downsample nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size1, stridestride, biasFalse), nn.BatchNorm2d(planes * block.expansion), ) layers [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def forward(self, x): x self.conv1(x) x self.bn1(x) x self.relu(x) x self.maxpool(x) x self.layer1(x) x self.layer2(x) x self.layer3(x) x self.layer4(x) x self.avgpool(x) x x.view(x.size(0), -1) x self.fc(x) return x训练时有几个技巧可以提升SE-ResNet的表现学习率策略初始学习率设为0.1每30个epoch乘以0.1数据增强随机水平翻转、颜色抖动、随机裁剪优化器选择带动量的SGD通常比Adam表现更好model SEResNet18(num_classes10) criterion nn.CrossEntropyLoss() optimizer torch.optim.SGD(model.parameters(), lr0.1, momentum0.9, weight_decay1e-4) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size30, gamma0.1)5. 常见问题与解决方案在实际项目中你可能会遇到以下典型问题问题1维度不匹配错误特别是在残差连接处解决方案检查downsample是否正确设置打印各层输出的shape进行调试确保SE模块不会改变特征图尺寸问题2模型收敛缓慢或性能不佳解决方案检查初始化是否正确特别是BatchNorm层尝试调整SE模块的reduction ratio增加数据增强的强度问题3显存不足解决方案# 减小batch size train_loader DataLoader(dataset, batch_size32, shuffleTrue) # 或者使用梯度累积 for i, (inputs, targets) in enumerate(train_loader): outputs model(inputs) loss criterion(outputs, targets) loss.backward() if (i1) % 4 0: # 每4个batch更新一次 optimizer.step() optimizer.zero_grad()最后分享一个实用技巧当需要快速验证模型结构时可以使用以下代码检查各层输出尺寸def check_model_output_shapes(model, input_size(1, 3, 224, 224)): device next(model.parameters()).device x torch.randn(input_size).to(device) for name, layer in model.named_children(): x layer(x) print(f{name}: {x.shape})