使用Python从零开始训练ViT图像分类模型

使用Python从零开始训练ViT图像分类模型 使用Python从零开始训练ViT图像分类模型1. 准备工作与环境搭建想要自己训练一个图像分类模型吗今天我们就用Python和PyTorch来实战训练一个ViTVision Transformer模型。不用担心复杂的技术概念我会用最直白的方式带你一步步完成整个过程。首先需要准备好编程环境。我推荐使用Anaconda来管理Python环境这样能避免各种依赖包冲突的问题。打开你的终端或命令提示符跟着我一起操作conda create -n vit-train python3.9 conda activate vit-train pip install torch torchvision torchaudio pip install transformers timm matplotlib tqdm这些包都是我们后面要用到的核心工具。torch是深度学习框架torchvision提供了图像处理工具transformers包含了ViT模型的实现timm是另一个常用的视觉模型库。如果你的电脑有NVIDIA显卡建议安装CUDA版本的PyTorch这样训练速度会快很多。可以在PyTorch官网找到对应的安装命令。2. 理解ViT模型的基本原理ViT模型其实很有意思它把处理文本的Transformer架构用到了图像识别上。想象一下我们把一张图片切成很多个小块就像把一篇文章分成多个单词一样。每个图像块都会被转换成一个向量表示然后加上位置信息因为图片中不同位置的内容很重要。这些向量序列输入到Transformer编码器中模型就能学会识别图像中的内容了。与传统卷积神经网络不同ViT不需要一层层地提取特征它能够直接关注图像中的全局信息。这也是为什么ViT在很多图像任务上表现那么出色的原因。3. 准备训练数据好的数据是成功训练模型的关键。我们可以使用CIFAR-10这个经典的数据集它包含10个类别的6万张彩色图片每张图片都是32x32像素。from torchvision import datasets, transforms from torch.utils.data import DataLoader # 定义数据预处理流程 transform transforms.Compose([ transforms.Resize((224, 224)), # ViT通常需要224x224的输入 transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) # 加载训练集和测试集 train_dataset datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtransform) test_dataset datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformtransform) # 创建数据加载器 train_loader DataLoader(train_dataset, batch_size32, shuffleTrue, num_workers2) test_loader DataLoader(test_dataset, batch_size32, shuffleFalse, num_workers2)如果你有自己的图像数据集只需要按照类似的格式组织数据即可。建议把不同类别的图片放在不同的文件夹里这样torchvision能自动处理标签。4. 构建ViT模型现在我们用timm库来创建ViT模型这个库提供了很多预定义的视觉Transformer模型。import timm import torch.nn as nn class ViTForImageClassification(nn.Module): def __init__(self, num_classes10): super(ViTForImageClassification, self).__init__() self.vit timm.create_model(vit_base_patch16_224, pretrainedFalse, num_classesnum_classes) def forward(self, x): return self.vit(x) # 创建模型实例 model ViTForImageClassification(num_classes10) print(f模型参数量{sum(p.numel() for p in model.parameters()):,})这里我们选择了vit_base_patch16_224这个架构意思是使用基础规模的ViT将图像分成16x16的小块输入尺寸是224x224。如果你想要更小的模型可以选择vit_tiny或者vit_small。5. 设置训练参数和优化器训练深度学习模型需要选择合适的优化器和学习率调度策略。import torch.optim as optim from transformers import get_linear_schedule_with_warmup # 定义优化器 optimizer optim.AdamW(model.parameters(), lr1e-4, weight_decay0.01) # 如果有GPU就使用GPU device torch.device(cuda if torch.cuda.is_available() else cpu) model model.to(device) # 定义学习率调度器 num_epochs 10 num_training_steps len(train_loader) * num_epochs lr_scheduler get_linear_schedule_with_warmup( optimizer, num_warmup_steps0.1 * num_training_steps, num_training_stepsnum_training_steps )AdamW是目前最常用的优化器它在Adam的基础上加入了权重衰减能帮助模型更好地泛化。学习率预热warmup让模型在训练初期更稳定。6. 编写训练循环现在来到最核心的部分——训练循环。这个过程会让模型逐渐学会识别图像。from tqdm import tqdm import numpy as np def train_epoch(model, dataloader, optimizer, lr_scheduler, device): model.train() total_loss 0 correct 0 total 0 progress_bar tqdm(dataloader, desc训练中) for batch_idx, (data, target) in enumerate(progress_bar): data, target data.to(device), target.to(device) optimizer.zero_grad() output model(data) loss nn.CrossEntropyLoss()(output, target) loss.backward() optimizer.step() lr_scheduler.step() total_loss loss.item() _, predicted output.max(1) total target.size(0) correct predicted.eq(target).sum().item() progress_bar.set_postfix({ loss: f{loss.item():.4f}, acc: f{100.*correct/total:.2f}% }) return total_loss / len(dataloader), 100. * correct / total每个epoch中模型会遍历整个训练集一次计算损失并更新参数。tqdm进度条能让我们直观地看到训练进度和效果。7. 评估模型性能训练过程中需要定期评估模型在测试集上的表现这能帮助我们判断模型是否过拟合。def evaluate(model, dataloader, device): model.eval() correct 0 total 0 test_loss 0 with torch.no_grad(): for data, target in dataloader: data, target data.to(device), target.to(device) output model(data) test_loss nn.CrossEntropyLoss()(output, target).item() _, predicted output.max(1) total target.size(0) correct predicted.eq(target).sum().item() accuracy 100. * correct / total avg_loss test_loss / len(dataloader) print(f测试集损失: {avg_loss:.4f}, 准确率: {accuracy:.2f}%) return avg_loss, accuracy评估时不计算梯度这样可以节省内存并加速计算。我们关注的是模型在未见过的数据上的表现。8. 开始训练模型现在把所有的部分组合起来开始真正的训练过程。# 训练多个epoch best_accuracy 0 for epoch in range(1, num_epochs 1): print(f\nEpoch {epoch}/{num_epochs}) # 训练一个epoch train_loss, train_acc train_epoch(model, train_loader, optimizer, lr_scheduler, device) # 在测试集上评估 test_loss, test_acc evaluate(model, test_loader, device) # 保存最好的模型 if test_acc best_accuracy: best_accuracy test_acc torch.save(model.state_dict(), best_vit_model.pth) print(f保存新的最佳模型准确率: {test_acc:.2f}%) print(f\n训练完成最佳准确率: {best_accuracy:.2f}%)训练过程中你会看到损失逐渐下降准确率逐渐上升。如果发现准确率不再提升可能就需要调整学习率或者尝试其他优化策略了。9. 使用训练好的模型进行预测训练完成后我们可以用最好的模型来对新图像进行分类。def predict(image_path, model, transform, class_names): model.eval() image Image.open(image_path).convert(RGB) image transform(image).unsqueeze(0).to(device) with torch.no_grad(): output model(image) probabilities torch.softmax(output, dim1) confidence, predicted torch.max(probabilities, 1) return class_names[predicted.item()], confidence.item() # CIFAR-10的类别名称 class_names [飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船, 卡车] # 使用训练好的模型进行预测 model.load_state_dict(torch.load(best_vit_model.pth)) prediction, confidence predict(your_image.jpg, model, transform, class_names) print(f预测结果: {prediction}, 置信度: {confidence:.2%})这个预测函数会输出最可能的类别及其置信度。你可以用自己的图片试试看模型的表现如何。10. 进阶技巧和优化建议如果你想要进一步提升模型性能这里有一些实用建议数据增强能显著提升模型泛化能力。在数据预处理中加入随机翻转、旋转、颜色抖动等变换transform_train transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2, hue0.1), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])使用预训练权重可以加速收敛并提升性能。timm库提供了在ImageNet上预训练的ViT模型model timm.create_model(vit_base_patch16_224, pretrainedTrue, num_classes10)学习率查找能帮助你找到最适合的学习率。可以先在一个小范围内测试不同学习率的效果。早停机制能防止过拟合。当验证集性能不再提升时就停止训练。11. 常见问题解决训练过程中可能会遇到一些问题这里提供一些解决方案如果遇到内存不足的问题可以减小批大小batch size或者使用梯度累积# 梯度累积 accumulation_steps 4 for i, (data, target) in enumerate(train_loader): output model(data) loss criterion(output, target) / accumulation_steps loss.backward() if (i 1) % accumulation_steps 0: optimizer.step() optimizer.zero_grad()如果模型收敛太慢可以尝试增大学习率或者使用学习率预热。如果出现过拟合训练准确率高但测试准确率低可以增加数据增强、使用权重衰减或者添加Dropout。梯度爆炸可以通过梯度裁剪来解决torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0)12. 完整代码示例为了方便你快速开始这里提供一个完整的训练脚本import torch import torch.nn as nn from torchvision import datasets, transforms from torch.utils.data import DataLoader import timm from transformers import get_linear_schedule_with_warmup from tqdm import tqdm # 设置设备 device torch.device(cuda if torch.cuda.is_available() else cpu) # 数据准备 transform transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) train_dataset datasets.CIFAR10(./data, trainTrue, downloadTrue, transformtransform) test_dataset datasets.CIFAR10(./data, trainFalse, downloadTrue, transformtransform) train_loader DataLoader(train_dataset, batch_size32, shuffleTrue) test_loader DataLoader(test_dataset, batch_size32, shuffleFalse) # 模型定义 model timm.create_model(vit_base_patch16_224, pretrainedFalse, num_classes10) model model.to(device) # 训练设置 optimizer torch.optim.AdamW(model.parameters(), lr1e-4) criterion nn.CrossEntropyLoss() # 训练循环 num_epochs 10 for epoch in range(num_epochs): model.train() for batch_idx, (data, target) in enumerate(tqdm(train_loader)): data, target data.to(device), target.to(device) optimizer.zero_grad() output model(data) loss criterion(output, target) loss.backward() optimizer.step() # 保存模型 torch.save(model.state_dict(), vit_model.pth)这个完整示例包含了从数据准备到模型训练的所有步骤你可以直接运行它来开始你的ViT训练之旅。从头开始训练ViT模型确实需要一些时间和耐心但整个过程充满了乐趣和挑战。通过调整超参数、尝试不同的数据增强方法你会逐渐深入理解深度学习的工作原理。最重要的是多实践、多尝试遇到问题时不要气馁每一个问题都是学习的机会。获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。