用PyTorch代码逐层拆解ResNet18从张量流动理解残差连接当你第一次看到ResNet18的结构图时是否曾被那些交错连接的箭头弄得晕头转向作为计算机视觉领域的里程碑式架构残差网络(ResNet)通过引入跳跃连接(skip connection)解决了深度神经网络中的梯度消失问题。但纸上得来终觉浅本文将带你用PyTorch从零构建ResNet18通过打印每一层的张量形状变化直观理解数据在网络中的流动路径。1. 残差网络的核心思想在传统的卷积神经网络中随着网络层数的增加模型性能往往会达到饱和甚至下降。这种现象被称为退化问题(degradation problem)并非由过拟合引起而是因为深层网络难以优化。ResNet创造性地提出了残差学习框架——与其让网络直接拟合目标映射H(x)不如让它学习残差函数F(x) H(x) - x这样原始映射就变为F(x) x。这种设计的精妙之处在于恒等映射的捷径当残差F(x)趋近于0时该层仅需执行恒等映射这使得深层网络的训练至少不会比浅层网络更困难梯度高速公路跳跃连接为反向传播提供了直达路径有效缓解了梯度消失问题特征复用浅层特征可以直接传递到深层避免了重复学习import torch import torch.nn as nn class BasicBlock(nn.Module): expansion 1 def __init__(self, in_channels, out_channels, stride1): super().__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse) self.bn2 nn.BatchNorm2d(out_channels) self.shortcut nn.Sequential() if stride ! 1 or in_channels ! out_channels: self.shortcut nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size1, stridestride, biasFalse), nn.BatchNorm2d(out_channels) ) def forward(self, x): print(f输入形状: {x.shape}) residual x out torch.relu(self.bn1(self.conv1(x))) print(f第一个卷积后形状: {out.shape}) out self.bn2(self.conv2(out)) print(f第二个卷积后形状: {out.shape}) out self.shortcut(residual) print(f残差连接后形状: {out.shape}) out torch.relu(out) return out2. ResNet18的层次结构解析ResNet18由初始卷积层、四个残差块阶段和全连接层组成。让我们分解每个阶段的数据变化过程重点关注特征图尺寸和通道数的变化规律。2.1 初始卷积层输入图像首先经过一个7×7的大卷积核进行初步特征提取这有助于在早期捕获更大范围的视觉特征# 假设输入为3通道的224x224图像 x torch.randn(1, 3, 224, 224) conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) bn1 nn.BatchNorm2d(64) maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) out conv1(x) print(f初始卷积后形状: {out.shape}) # torch.Size([1, 64, 112, 112]) out bn1(out) out torch.relu(out) out maxpool(out) print(f池化后形状: {out.shape}) # torch.Size([1, 64, 56, 56])2.2 残差块阶段ResNet18包含四个主要阶段每个阶段由多个BasicBlock组成。观察虚线残差块的特殊处理阶段块类型重复次数输出尺寸通道变化1BasicBlock256×5664→642BasicBlock228×2864→1283BasicBlock214×14128→2564BasicBlock27×7256→512# 阶段1示例 - 无下采样 layer1 nn.Sequential( BasicBlock(64, 64), BasicBlock(64, 64) ) out layer1(out) print(f阶段1输出形状: {out.shape}) # 阶段2示例 - 带下采样 layer2 nn.Sequential( BasicBlock(64, 128, stride2), # 注意stride2 BasicBlock(128, 128) ) out layer2(out) print(f阶段2输出形状: {out.shape})2.3 跳跃连接实现细节残差块中的跳跃连接处理分为两种情况实线连接当输入输出通道数相同时直接相加虚线连接当通道数变化时使用1×1卷积调整通道和尺寸# 实线连接示例 x torch.randn(1, 64, 56, 56) block BasicBlock(64, 64) out block(x) # 直接相加 # 虚线连接示例 x torch.randn(1, 64, 56, 56) block BasicBlock(64, 128, stride2) out block(x) # 使用1×1卷积调整3. 完整ResNet18实现与验证现在我们将所有组件组合起来构建完整的ResNet18并验证其结构class ResNet18(nn.Module): def __init__(self, num_classes1000): super().__init__() self.in_channels 64 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) self.layer1 self._make_layer(64, 2, stride1) self.layer2 self._make_layer(128, 2, stride2) self.layer3 self._make_layer(256, 2, stride2) self.layer4 self._make_layer(512, 2, stride2) self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * BasicBlock.expansion, num_classes) def _make_layer(self, out_channels, blocks, stride): layers [] layers.append(BasicBlock(self.in_channels, out_channels, stride)) self.in_channels out_channels * BasicBlock.expansion for _ in range(1, blocks): layers.append(BasicBlock(self.in_channels, out_channels)) return nn.Sequential(*layers) def forward(self, x): print(\n 初始卷积 ) x self.conv1(x) print(f初始卷积输出: {x.shape}) x self.bn1(x) x torch.relu(x) print(\n 最大池化 ) x self.maxpool(x) print(f池化后: {x.shape}) print(\n 阶段1 ) x self.layer1(x) print(f阶段1输出: {x.shape}) print(\n 阶段2 ) x self.layer2(x) print(f阶段2输出: {x.shape}) print(\n 阶段3 ) x self.layer3(x) print(f阶段3输出: {x.shape}) print(\n 阶段4 ) x self.layer4(x) print(f阶段4输出: {x.shape}) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x # 验证网络结构 model ResNet18() input_tensor torch.randn(1, 3, 224, 224) output model(input_tensor)4. 残差连接的可视化分析为了更直观地理解残差连接的作用我们可以对比有无跳跃连接时的梯度流动有残差连接时的梯度计算∂loss/∂x ∂loss/∂F(x) * ∂F(x)/∂x ∂loss/∂F(x)无残差连接时的梯度计算∂loss/∂x ∂loss/∂F(x) * ∂F(x)/∂x这种设计确保了即使深层网络的梯度很小至少能有∂loss/∂F(x)这一项直接传递到浅层避免了梯度消失。在实际项目中调试ResNet时有几个实用技巧值得注意初始化残差块最后一层BN的γ为0这样初始状态下残差块输出为0网络从浅层开始学习下采样放在第一个残差块这样后续块可以专注于特征提取而非尺寸调整适当使用预训练权重特别是当数据集较小时ImageNet预训练的特征提取器非常有效
别再死记ResNet18结构图了!用PyTorch代码逐层拆解,搞懂残差连接到底怎么跑的
用PyTorch代码逐层拆解ResNet18从张量流动理解残差连接当你第一次看到ResNet18的结构图时是否曾被那些交错连接的箭头弄得晕头转向作为计算机视觉领域的里程碑式架构残差网络(ResNet)通过引入跳跃连接(skip connection)解决了深度神经网络中的梯度消失问题。但纸上得来终觉浅本文将带你用PyTorch从零构建ResNet18通过打印每一层的张量形状变化直观理解数据在网络中的流动路径。1. 残差网络的核心思想在传统的卷积神经网络中随着网络层数的增加模型性能往往会达到饱和甚至下降。这种现象被称为退化问题(degradation problem)并非由过拟合引起而是因为深层网络难以优化。ResNet创造性地提出了残差学习框架——与其让网络直接拟合目标映射H(x)不如让它学习残差函数F(x) H(x) - x这样原始映射就变为F(x) x。这种设计的精妙之处在于恒等映射的捷径当残差F(x)趋近于0时该层仅需执行恒等映射这使得深层网络的训练至少不会比浅层网络更困难梯度高速公路跳跃连接为反向传播提供了直达路径有效缓解了梯度消失问题特征复用浅层特征可以直接传递到深层避免了重复学习import torch import torch.nn as nn class BasicBlock(nn.Module): expansion 1 def __init__(self, in_channels, out_channels, stride1): super().__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse) self.bn2 nn.BatchNorm2d(out_channels) self.shortcut nn.Sequential() if stride ! 1 or in_channels ! out_channels: self.shortcut nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size1, stridestride, biasFalse), nn.BatchNorm2d(out_channels) ) def forward(self, x): print(f输入形状: {x.shape}) residual x out torch.relu(self.bn1(self.conv1(x))) print(f第一个卷积后形状: {out.shape}) out self.bn2(self.conv2(out)) print(f第二个卷积后形状: {out.shape}) out self.shortcut(residual) print(f残差连接后形状: {out.shape}) out torch.relu(out) return out2. ResNet18的层次结构解析ResNet18由初始卷积层、四个残差块阶段和全连接层组成。让我们分解每个阶段的数据变化过程重点关注特征图尺寸和通道数的变化规律。2.1 初始卷积层输入图像首先经过一个7×7的大卷积核进行初步特征提取这有助于在早期捕获更大范围的视觉特征# 假设输入为3通道的224x224图像 x torch.randn(1, 3, 224, 224) conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) bn1 nn.BatchNorm2d(64) maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) out conv1(x) print(f初始卷积后形状: {out.shape}) # torch.Size([1, 64, 112, 112]) out bn1(out) out torch.relu(out) out maxpool(out) print(f池化后形状: {out.shape}) # torch.Size([1, 64, 56, 56])2.2 残差块阶段ResNet18包含四个主要阶段每个阶段由多个BasicBlock组成。观察虚线残差块的特殊处理阶段块类型重复次数输出尺寸通道变化1BasicBlock256×5664→642BasicBlock228×2864→1283BasicBlock214×14128→2564BasicBlock27×7256→512# 阶段1示例 - 无下采样 layer1 nn.Sequential( BasicBlock(64, 64), BasicBlock(64, 64) ) out layer1(out) print(f阶段1输出形状: {out.shape}) # 阶段2示例 - 带下采样 layer2 nn.Sequential( BasicBlock(64, 128, stride2), # 注意stride2 BasicBlock(128, 128) ) out layer2(out) print(f阶段2输出形状: {out.shape})2.3 跳跃连接实现细节残差块中的跳跃连接处理分为两种情况实线连接当输入输出通道数相同时直接相加虚线连接当通道数变化时使用1×1卷积调整通道和尺寸# 实线连接示例 x torch.randn(1, 64, 56, 56) block BasicBlock(64, 64) out block(x) # 直接相加 # 虚线连接示例 x torch.randn(1, 64, 56, 56) block BasicBlock(64, 128, stride2) out block(x) # 使用1×1卷积调整3. 完整ResNet18实现与验证现在我们将所有组件组合起来构建完整的ResNet18并验证其结构class ResNet18(nn.Module): def __init__(self, num_classes1000): super().__init__() self.in_channels 64 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) self.layer1 self._make_layer(64, 2, stride1) self.layer2 self._make_layer(128, 2, stride2) self.layer3 self._make_layer(256, 2, stride2) self.layer4 self._make_layer(512, 2, stride2) self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * BasicBlock.expansion, num_classes) def _make_layer(self, out_channels, blocks, stride): layers [] layers.append(BasicBlock(self.in_channels, out_channels, stride)) self.in_channels out_channels * BasicBlock.expansion for _ in range(1, blocks): layers.append(BasicBlock(self.in_channels, out_channels)) return nn.Sequential(*layers) def forward(self, x): print(\n 初始卷积 ) x self.conv1(x) print(f初始卷积输出: {x.shape}) x self.bn1(x) x torch.relu(x) print(\n 最大池化 ) x self.maxpool(x) print(f池化后: {x.shape}) print(\n 阶段1 ) x self.layer1(x) print(f阶段1输出: {x.shape}) print(\n 阶段2 ) x self.layer2(x) print(f阶段2输出: {x.shape}) print(\n 阶段3 ) x self.layer3(x) print(f阶段3输出: {x.shape}) print(\n 阶段4 ) x self.layer4(x) print(f阶段4输出: {x.shape}) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x # 验证网络结构 model ResNet18() input_tensor torch.randn(1, 3, 224, 224) output model(input_tensor)4. 残差连接的可视化分析为了更直观地理解残差连接的作用我们可以对比有无跳跃连接时的梯度流动有残差连接时的梯度计算∂loss/∂x ∂loss/∂F(x) * ∂F(x)/∂x ∂loss/∂F(x)无残差连接时的梯度计算∂loss/∂x ∂loss/∂F(x) * ∂F(x)/∂x这种设计确保了即使深层网络的梯度很小至少能有∂loss/∂F(x)这一项直接传递到浅层避免了梯度消失。在实际项目中调试ResNet时有几个实用技巧值得注意初始化残差块最后一层BN的γ为0这样初始状态下残差块输出为0网络从浅层开始学习下采样放在第一个残差块这样后续块可以专注于特征提取而非尺寸调整适当使用预训练权重特别是当数据集较小时ImageNet预训练的特征提取器非常有效