前言在深度学习训练中我们经常需要修改官方预训练模型VGG、ResNet给网络新增网络层保存训练好的模型、断点续训加载模型继续训练/推理PyTorch 提供了两种保存模型、两种加载模型的方式新手极易混淆、极易报错。本文结合 VGG16 实战代码一次性彻底讲透以后永远不用死记硬背。一、基础认知PyTorch 模型构成一个完整的神经网络包含两部分模型结构卷积、池化、全连接、Sequential 等网络骨架模型参数权重weight、bias可训练参数两种保存方式本质区别方式1保存结构 权重方式2只保存权重参数官方推荐二、实战1加载官方模型 自定义修改网络1. 加载原生VGG16无预训练权重import torch from torchvision import models # 只加载网络结构权重随机初始化 vgg16 models.vgg16(weightsNone) print(vgg16)此时 VGG16 默认输出是1000 分类ImageNet数据集。2. 给模型【新增网络层】重点如果我要改成10分类CIFAR10可以直接给模型末尾加一层# 给模型动态添加一层全连接 vgg16.add_module(add_linear, torch.nn.Linear(1000, 10)) print(vgg16)add_module(层名, 层结构)执行后模型结构多了add_linear: Linear(1000, 10)✅ 此时模型结构被永久修改权重也多出两组参数add_linear.weightadd_linear.bias后续保存权重就会包含这两个参数加载时结构必须一致否则报错三、PyTorch 两种模型保存方式核心重点方式一保存整个模型对象结构权重# 保存整个模型 torch.save(vgg16, vgg16_method1.pth)保存内容网络结构 所有权重参数优点加载不用写模型结构一行直接用致命缺点兼容性极差PyTorch版本更新就炸2.6版本后严格限制存在安全风险新版本默认禁止加载❌工作、比赛、项目一律不推荐方式二只保存模型参数state_dict【官方推荐】# 只保存权重参数字典 torch.save(vgg16.state_dict(), vgg16_method2.pth)保存内容仅保存所有 weight、bias字典形式优点体积最小、纯参数、无版本兼容问题绝对安全工业界通用标准缺点加载前必须手动搭建一模一样的模型结构✅唯一推荐的保存方式四、两种对应的加载方式必须一一对应方式1加载【完整模型文件】对应保存方式1import torch # 新版PyTorch必须加 weights_onlyFalse model torch.load(vgg16_method1.pth, weights_onlyFalse) print(model)直接加载出完整模型不需要搭建网络。⚠️ 仅限自己本地测试使用方式2加载【权重参数文件】对应保存方式2最容易报错、最重要规则加载时的模型结构 保存时的模型结构必须完全一致因为我保存时多加了一层 add_linear所以加载时也要加import torch from torchvision import models # 1. 先搭建【和保存时一模一样】的结构 model models.vgg16(weightsNone) model.add_module(add_linear, torch.nn.Linear(1000, 10)) # 2. 再加载权重 model.load_state_dict(torch.load(vgg16_method2.pth)) print(model)五、你之前报错的终极解释必看报错内容Unexpected key(s) in state_dict: add_linear.weight, add_linear.bias报错原因100%精准保存时模型多了 add_linear 层权重文件里存了这两个参数加载时你只创建了原生 VGG16没有 add_linear结构对不上权重参数无处安放 → 直接报错✅解决方法加载前补全所有自定义层六、拓展常用两种模型修改方式方式1add_module 新增层你学的适合模型末尾追加分类头方式2直接替换原有层更常用# 直接替换VGG最后一层分类器 vgg16.classifier[6] torch.nn.Linear(4096, 10)七、最终总结可背诵1. 两种保存区别torch.save(model, path)保存结构权重省事但不通用、新版本易报错torch.save(model.state_dict(), path)只保存权重官方标准、工业唯一推荐2. 两种加载规则完整模型文件 → torch.load() 直接加载权重参数文件 →先搭结构、后载权重3. 报错核心口诀权重里有的层模型里必须有权重里没有的层模型不能乱加。4. 工作最佳实践永远只用state_dict 保存 先建结构再加载八、完整可运行模板以后直接复制保存模板model models.vgg16(weightsNone) model.add_module(add_linear, torch.nn.Linear(1000,10)) torch.save(model.state_dict(), best.pth)加载模板model models.vgg16(weightsNone) model.add_module(add_linear, torch.nn.Linear(1000,10)) model.load_state_dict(torch.load(best.pth))
PyTorch零基础】模型修改、添加网络层、两种模型保存与加载方式
前言在深度学习训练中我们经常需要修改官方预训练模型VGG、ResNet给网络新增网络层保存训练好的模型、断点续训加载模型继续训练/推理PyTorch 提供了两种保存模型、两种加载模型的方式新手极易混淆、极易报错。本文结合 VGG16 实战代码一次性彻底讲透以后永远不用死记硬背。一、基础认知PyTorch 模型构成一个完整的神经网络包含两部分模型结构卷积、池化、全连接、Sequential 等网络骨架模型参数权重weight、bias可训练参数两种保存方式本质区别方式1保存结构 权重方式2只保存权重参数官方推荐二、实战1加载官方模型 自定义修改网络1. 加载原生VGG16无预训练权重import torch from torchvision import models # 只加载网络结构权重随机初始化 vgg16 models.vgg16(weightsNone) print(vgg16)此时 VGG16 默认输出是1000 分类ImageNet数据集。2. 给模型【新增网络层】重点如果我要改成10分类CIFAR10可以直接给模型末尾加一层# 给模型动态添加一层全连接 vgg16.add_module(add_linear, torch.nn.Linear(1000, 10)) print(vgg16)add_module(层名, 层结构)执行后模型结构多了add_linear: Linear(1000, 10)✅ 此时模型结构被永久修改权重也多出两组参数add_linear.weightadd_linear.bias后续保存权重就会包含这两个参数加载时结构必须一致否则报错三、PyTorch 两种模型保存方式核心重点方式一保存整个模型对象结构权重# 保存整个模型 torch.save(vgg16, vgg16_method1.pth)保存内容网络结构 所有权重参数优点加载不用写模型结构一行直接用致命缺点兼容性极差PyTorch版本更新就炸2.6版本后严格限制存在安全风险新版本默认禁止加载❌工作、比赛、项目一律不推荐方式二只保存模型参数state_dict【官方推荐】# 只保存权重参数字典 torch.save(vgg16.state_dict(), vgg16_method2.pth)保存内容仅保存所有 weight、bias字典形式优点体积最小、纯参数、无版本兼容问题绝对安全工业界通用标准缺点加载前必须手动搭建一模一样的模型结构✅唯一推荐的保存方式四、两种对应的加载方式必须一一对应方式1加载【完整模型文件】对应保存方式1import torch # 新版PyTorch必须加 weights_onlyFalse model torch.load(vgg16_method1.pth, weights_onlyFalse) print(model)直接加载出完整模型不需要搭建网络。⚠️ 仅限自己本地测试使用方式2加载【权重参数文件】对应保存方式2最容易报错、最重要规则加载时的模型结构 保存时的模型结构必须完全一致因为我保存时多加了一层 add_linear所以加载时也要加import torch from torchvision import models # 1. 先搭建【和保存时一模一样】的结构 model models.vgg16(weightsNone) model.add_module(add_linear, torch.nn.Linear(1000, 10)) # 2. 再加载权重 model.load_state_dict(torch.load(vgg16_method2.pth)) print(model)五、你之前报错的终极解释必看报错内容Unexpected key(s) in state_dict: add_linear.weight, add_linear.bias报错原因100%精准保存时模型多了 add_linear 层权重文件里存了这两个参数加载时你只创建了原生 VGG16没有 add_linear结构对不上权重参数无处安放 → 直接报错✅解决方法加载前补全所有自定义层六、拓展常用两种模型修改方式方式1add_module 新增层你学的适合模型末尾追加分类头方式2直接替换原有层更常用# 直接替换VGG最后一层分类器 vgg16.classifier[6] torch.nn.Linear(4096, 10)七、最终总结可背诵1. 两种保存区别torch.save(model, path)保存结构权重省事但不通用、新版本易报错torch.save(model.state_dict(), path)只保存权重官方标准、工业唯一推荐2. 两种加载规则完整模型文件 → torch.load() 直接加载权重参数文件 →先搭结构、后载权重3. 报错核心口诀权重里有的层模型里必须有权重里没有的层模型不能乱加。4. 工作最佳实践永远只用state_dict 保存 先建结构再加载八、完整可运行模板以后直接复制保存模板model models.vgg16(weightsNone) model.add_module(add_linear, torch.nn.Linear(1000,10)) torch.save(model.state_dict(), best.pth)加载模板model models.vgg16(weightsNone) model.add_module(add_linear, torch.nn.Linear(1000,10)) model.load_state_dict(torch.load(best.pth))