告别CNN?手把手教你用PyTorch复现ViT(Vision Transformer)图像分类模型

告别CNN?手把手教你用PyTorch复现ViT(Vision Transformer)图像分类模型 从零实现ViT用PyTorch构建视觉Transformer图像分类器当Transformer在自然语言处理领域大获成功后计算机视觉领域的研究者们开始思考能否用同样的架构彻底改变图像处理的方式2020年Google Research团队提出的Vision Transformer(ViT)给出了肯定答案。本文将带你从零开始用PyTorch实现这个颠覆性的图像分类模型。1. 环境准备与数据预处理在开始构建模型前我们需要搭建合适的开发环境。推荐使用Python 3.8和PyTorch 1.10版本这些版本在兼容性和性能方面都经过了充分验证。conda create -n vit python3.8 conda activate vit pip install torch torchvision torchaudio pytorch-lightning对于数据集我们将使用经典的CIFAR-10作为示例但ViT的设计同样适用于ImageNet等更大规模的数据集。首先实现数据预处理流程from torchvision import transforms, datasets # 定义ViT专用的数据增强 train_transform transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) val_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) # 加载CIFAR-10数据集 train_set datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtrain_transform) val_set datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformval_transform)注意ViT通常需要较大的输入尺寸如224x224这与传统CNN处理的小尺寸图像如32x32不同。这种改变是为了让每个图像块(patch)包含足够的信息。2. ViT核心组件实现2.1 图像分块与嵌入层ViT的核心创新是将图像视为一系列小块(patch)的序列。标准的ViT-Base模型使用16x16的分块大小import torch import torch.nn as nn from einops import rearrange class PatchEmbedding(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim768): super().__init__() self.img_size img_size self.patch_size patch_size self.n_patches (img_size // patch_size) ** 2 self.proj nn.Conv2d( in_chans, embed_dim, kernel_sizepatch_size, stridepatch_size ) def forward(self, x): x self.proj(x) # (B, E, H/P, W/P) x rearrange(x, b e h w - b (h w) e) return x这个模块使用卷积操作实现高效的分块处理然后通过重排维度将空间信息转换为序列形式。2.2 位置编码与类别标记与CNN不同Transformer需要显式的位置信息来理解图像结构class ViTEmbeddings(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim768): super().__init__() self.patch_embed PatchEmbedding(img_size, patch_size, in_chans, embed_dim) self.cls_token nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed nn.Parameter( torch.zeros(1, self.patch_embed.n_patches 1, embed_dim) ) self.pos_drop nn.Dropout(p0.1) def forward(self, x): batch_size x.shape[0] x self.patch_embed(x) # (B, N, E) cls_tokens self.cls_token.expand(batch_size, -1, -1) x torch.cat((cls_tokens, x), dim1) # (B, 1 N, E) x x self.pos_embed return self.pos_drop(x)提示类别标记(cls_token)是ViT用于分类的特殊标记类似于BERT中的[CLS]标记。它会在Transformer处理过程中聚合整个序列的信息。3. Transformer编码器实现3.1 多头自注意力机制自注意力是Transformer的核心组件它允许模型动态地关注输入序列的不同部分class MultiHeadAttention(nn.Module): def __init__(self, embed_dim768, num_heads12, dropout0.1): super().__init__() self.num_heads num_heads self.head_dim embed_dim // num_heads self.qkv nn.Linear(embed_dim, embed_dim * 3) self.attn_drop nn.Dropout(dropout) self.proj nn.Linear(embed_dim, embed_dim) self.proj_drop nn.Dropout(dropout) def forward(self, x): B, N, C x.shape qkv self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) qkv qkv.permute(2, 0, 3, 1, 4) q, k, v qkv[0], qkv[1], qkv[2] attn (q k.transpose(-2, -1)) * (self.head_dim ** -0.5) attn attn.softmax(dim-1) attn self.attn_drop(attn) x (attn v).transpose(1, 2).reshape(B, N, C) x self.proj(x) x self.proj_drop(x) return x3.2 前馈网络与编码器层完整的Transformer编码器层包含自注意力和前馈网络两部分class MLP(nn.Module): def __init__(self, in_features, hidden_featuresNone, dropout0.1): super().__init__() hidden_features hidden_features or in_features self.fc1 nn.Linear(in_features, hidden_features) self.act nn.GELU() self.fc2 nn.Linear(hidden_features, in_features) self.drop nn.Dropout(dropout) def forward(self, x): x self.fc1(x) x self.act(x) x self.drop(x) x self.fc2(x) x self.drop(x) return x class TransformerBlock(nn.Module): def __init__(self, embed_dim768, num_heads12, mlp_ratio4.0, dropout0.1): super().__init__() self.norm1 nn.LayerNorm(embed_dim) self.attn MultiHeadAttention(embed_dim, num_heads, dropout) self.norm2 nn.LayerNorm(embed_dim) self.mlp MLP(embed_dim, int(embed_dim * mlp_ratio), dropout) def forward(self, x): x x self.attn(self.norm1(x)) x x self.mlp(self.norm2(x)) return x4. 完整ViT模型与训练策略4.1 组装完整ViT模型现在我们可以将所有组件组合成完整的Vision Transformerclass VisionTransformer(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, num_classes1000, embed_dim768, depth12, num_heads12, mlp_ratio4.0): super().__init__() self.embeddings ViTEmbeddings(img_size, patch_size, in_chans, embed_dim) self.blocks nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth) ]) self.norm nn.LayerNorm(embed_dim) self.head nn.Linear(embed_dim, num_classes) def forward(self, x): x self.embeddings(x) for blk in self.blocks: x blk(x) x self.norm(x) return self.head(x[:, 0]) # 只使用cls_token进行分类4.2 训练技巧与超参数设置ViT训练需要特别注意以下几点学习率调度使用带热启动的余弦退火调度优化器选择AdamW通常表现最好正则化策略权重衰减和dropout至关重要from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR model VisionTransformer(num_classes10) # CIFAR-10有10个类别 optimizer AdamW(model.parameters(), lr3e-4, weight_decay0.05) scheduler CosineAnnealingLR(optimizer, T_max10, eta_min1e-5) criterion nn.CrossEntropyLoss()4.3 自定义数据集微调当需要在特定领域应用ViT时微调预训练模型是关键# 加载预训练权重 pretrained_dict torch.load(vit_base_patch16_224.pth) model_dict model.state_dict() # 过滤不匹配的键特别是分类头 pretrained_dict {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].shape v.shape} # 更新模型参数 model_dict.update(pretrained_dict) model.load_state_dict(model_dict) # 冻结部分层可选 for name, param in model.named_parameters(): if head not in name: # 只训练分类头 param.requires_grad False注意当目标数据集与预训练数据分布差异较大时建议解冻更多层进行微调。较小的数据集可能需要更强的正则化来防止过拟合。5. 模型评估与可视化理解ViT如何看待图像对于调试和改进模型至关重要。我们可以可视化注意力图来洞察模型的工作机制import matplotlib.pyplot as plt def visualize_attention(model, image, layer_idx11, head_idx0): model.eval() with torch.no_grad(): embeddings model.embeddings(image.unsqueeze(0)) attn_maps [] x embeddings for blk in model.blocks[:layer_idx1]: x blk.norm1(x) _, attn blk.attn(x, return_attentionTrue) attn_maps.append(attn) # 获取特定层和头的注意力图 attn attn_maps[layer_idx][0, head_idx, 0, 1:] # 忽略cls_token patch_size model.embeddings.patch_embed.patch_size grid_size image.shape[-1] // patch_size attn attn.reshape(grid_size, grid_size) plt.imshow(attn, cmaphot) plt.colorbar() plt.show()这种可视化可以帮助我们理解模型关注图像的哪些区域进行决策对于调试和解释模型行为非常有价值。在实际项目中我发现ViT的注意力机制在捕捉全局上下文方面表现出色但在处理细粒度局部特征时可能不如CNN精确。这解释了为什么ViT通常在中等规模数据集上需要更强的数据增强和正则化。