Hands-On LeNet5: Building a Classic CNN from Scratch with PyTorch (with Complete CIFAR-10 Code)

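LeNet5 in Practice: Implementing a Classic Convolutional Neural Network with PyTorch

Introduced in 1998, LeNet5 is a milestone in the history of convolutional neural networks and was applied successfully to handwritten digit recognition. More than twenty years later, there is still a lot to learn from this classic architecture. This article walks through a complete LeNet5 implementation in modern PyTorch and trains it on the CIFAR-10 dataset.

1. Environment Setup and Data Loading

1.1 Installing Dependencies

Make sure recent versions of PyTorch and torchvision are installed:

    pip install torch torchvision matplotlib

Tip: Python 3.8+ is recommended. For GPU acceleration, install a CUDA build of PyTorch.

1.2 The CIFAR-10 Dataset

CIFAR-10 contains 60,000 color images of size 32x32 in 10 classes:

- Training set: 50,000 images
- Test set: 10,000 images
- Classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

    import torchvision
    import torchvision.transforms as transforms

    # Convert images to tensors and scale each channel from [0, 1] to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform
    )

2. Implementing the LeNet5 Model

2.1 Network Architecture

The original LeNet5 has seven layers, not counting the input. The table below lists the corresponding layers as implemented here for 32x32 inputs. Note that the max-pooling layers used in this version have no trainable parameters, unlike the subsampling layers of the original network.

| Layer | Parameters | Output size |
|-------|------------|-------------|
| Conv1 | 6 kernels, 5x5 | 6 x 28 x 28 |
| Pool1 | 2x2 max pooling | 6 x 14 x 14 |
| Conv2 | 16 kernels, 5x5 | 16 x 10 x 10 |
| Pool2 | 2x2 max pooling | 16 x 5 x 5 |
| FC1 | Fully connected | 120 |
| FC2 | Fully connected | 84 |
| FC3 | Output layer | 10 |

2.2 PyTorch Implementation

    import torch.nn as nn
    import torch.nn.functional as F

    class LeNet5(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)    # 3 input channels for RGB images
            self.pool1 = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.pool2 = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool1(F.relu(self.conv1(x)))
            x = self.pool2(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)         # flatten to (batch, 400)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)                    # raw logits, no softmax
            return x

Note: the original paper uses a sigmoid-type (scaled tanh) squashing function. ReLU is used here instead because it makes training more efficient.

3. Model Training Tips

3.1 Training Configuration

    import torch
    import torch.optim as optim

    net = LeNet5()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=32, shuffle=True
    )

3.2 The Training Loop

Key training parameters:

- Batch size: 32-128
- Initial learning rate: 0.001
- Momentum: 0.9
- Epochs: 10-20

    for epoch in range(10):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 500 == 499:  # report the mean loss every 500 batches
                print(f'Epoch {epoch + 1}, Batch {i + 1}: loss {running_loss / 500:.3f}')
                running_loss = 0.0

3.3 Common Problems

Running out of GPU memory:

- Reduce the batch size
- Use gradient accumulation

    accumulation_steps = 4
    optimizer.zero_grad()  # start from clean gradients
    for i, data in enumerate(train_loader):
        inputs, labels = data
        outputs = net(inputs)
        # Scale the loss so the accumulated gradient matches one large batch
        loss = criterion(outputs, labels) / accumulation_steps
        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

Dealing with overfitting:

- Add Dropout layers
- Use L2 regularization
- Use early stopping

4. Model Evaluation and Deployment

4.1 Evaluating on the Test Set

    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=32, shuffle=False
    )

    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)  # index of the largest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')

4.2 Saving and Loading the Model

    # Save the whole model
    torch.save(net, 'lenet5_full.pth')

    # Save only the parameters (recommended)
    torch.save(net.state_dict(), 'lenet5_params.pth')

    # Load the model
    loaded_net = LeNet5()
    loaded_net.load_state_dict(torch.load('lenet5_params.pth'))

4.3 Predicting a Single Image

    from PIL import Image

    # Class names in CIFAR-10 label order
    classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    def predict_image(image_path):
        # The network expects a 3-channel 32x32 input
        image = Image.open(image_path).convert('RGB').resize((32, 32))
        image = transform(image).unsqueeze(0)  # add a batch dimension

        net.eval()
        with torch.no_grad():
            output = net(image)
            _, predicted = torch.max(output, 1)

        return classes[predicted.item()]

LeNet5 is structurally simple, but as one of the earliest CNNs, its design ideas are essential to understanding modern convolutional networks. I have used this implementation in several teaching projects and found that tuning the learning rate and adding data augmentation noticeably improves results on CIFAR-10.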