113_站在巨人的肩膀上:PyTorch 经典模型(VGG16)的获取与自定义修改

113_站在巨人的肩膀上:PyTorch 经典模型(VGG16)的获取与自定义修改 在处理图像识别任务时我们不需要总是从零开始搭建网络。PyTorch 的torchvision.models库提供了大量已经在 ImageNet 大规模数据集上训练好的经典模型。本文将教你如何获取这些模型并根据自己的需求如将 1000 分类改为 10 分类进行灵活修改。1. 现有网络模型的获取官方模型库提供了两种获取方式区别在于是否加载预训练好的参数pretrainedFalse只下载网络结构参数是随机初始化的默认方式。pretrainedTrue不仅下载结构还下载已经在 ImageNet 上训练好的参数。这通常用于迁移学习能极大地加快收敛速度。代码实现2. 为什么要修改网络模型像 VGG16 这样的模型其输出层通常是为 ImageNet 设计的输出 1000 个类别。但如果我们处理的是 CIFAR-10只需输出 10 个类别我们就需要对模型的结构进行微调。3. 实战修改模型的两种常用方法文件展示了如何通过“添加层”或“替换层”来适配 10 分类任务方法一在现有结构后添加层 (add_module)我们可以保持原有的classifier结构不变在其最后追加一个新的线性层。import torchvision from torch import nn dataset torchvision.datasets.CIFAR10(./dataset,trainTrue,transformtorchvision.transforms.ToTensor(),downloadTrue) vgg16_true torchvision.models.vgg16(pretrainedTrue) # 下载卷积层对应的参数是多少、池化层对应的参数时多少这些参数时ImageNet训练好了的 vgg16_true.add_module(add_linear,nn.Linear(1000,10)) # 在VGG16后面添加一个线性层使得输出为适应CIFAR10的输出CIFAR10需要输出10个种类 print(vgg16_true)方法二直接修改/替换现有层如果你觉得多加一层太麻烦可以直接修改classifier中的最后一个子模块。import torchvision from torch import nn vgg16_false torchvision.models.vgg16(pretrainedFalse) # 没有预训练的参数 print(vgg16_false) vgg16_false.classifier[6] nn.Linear(4096,10) print(vgg16_false)4. 迁移学习的意义通过修改现有模型我们实际上是在进行迁移学习Transfer Learning特征提取器利用 VGG 前半部分强大的特征提取能力这部分在 ImageNet 上学到了识别线条、形状、纹理的通用能力。自定义分类器只针对我们特定的数据集训练最后的几层全连接层。这种方法在数据集较小时效果尤为显著能避免过拟合且大幅节省算力。5. 总结分析该文件后我们可以掌握以下技巧加载模型利用torchvision.models快速调用经典结构。查看结构直接print(model)找到需要修改的层名称或索引。动态修改使用add_module或直接索引赋值来改变网络层级使其适配你的任务。 学习小结学会修改官方模型是迈向中高级开发者的重要一步。你不再受限于简单的 3 层卷积而是可以自由调用 ResNet、VGG、MobileNet 等工业级模型来解决复杂的视觉问题。