用Google Colab免费GPU,10分钟搞定PyTorch猫狗分类项目(附完整代码和数据集链接)

用Google Colab免费GPU,10分钟搞定PyTorch猫狗分类项目(附完整代码和数据集链接) 零门槛玩转PyTorchColab云端10分钟实现猫狗分类实战记得第一次接触深度学习时最让我头疼的不是算法原理而是环境配置——CUDA版本冲突、显存不足、依赖库缺失...直到发现Google Colab这个神器。今天我们就用Colab的免费GPU资源带你在浏览器里完成一个完整的猫狗分类项目从数据准备到模型训练全流程实战无需任何本地环境配置1. 为什么选择Google Colab入门深度学习很多初学者在搭建本地开发环境时容易陷入配置地狱显卡驱动、CUDA工具包、Python环境这些基础依赖就可能耗掉半天时间。而Google Colab提供了开箱即用的Jupyter Notebook环境特别适合快速验证想法和教学演示。它的核心优势在于免费GPU资源Tesla T4或K80显卡足够中小型模型训练零配置环境预装主流深度学习框架PyTorch/TensorFlow云端存储集成直接挂载Google Drive管理数据集协作分享便捷一键生成可交互的分享链接提示Colab的GPU会话默认持续12小时长时间训练建议定期保存中间结果到Google Drive2. 五分钟快速搭建Colab开发环境打开浏览器访问 Colab官网 点击新建笔记本我们立即开始配置# 验证GPU是否可用 import torch print(fPyTorch版本: {torch.__version__}) print(GPU可用:, torch.cuda.is_available()) print(当前设备:, torch.cuda.get_device_name(0))如果输出显示Tesla T4等GPU信息说明环境已经就绪。接下来安装必要依赖!pip install torchvision0.12.0 !pip install tqdm matplotlib数据集我们使用经典的Kaggle猫狗大战精简版12500张猫狗图片已上传到公开网盘# 下载并解压数据集 !wget -O cats_dogs.zip 你的数据集下载链接 !unzip -q cats_dogs.zip -d /content/3. PyTorch数据管道构建技巧高效的数据加载是模型训练的前提PyTorch提供了Dataset和DataLoader两个核心组件。我们针对Colab环境做了这些优化from torchvision import transforms from torch.utils.data import DataLoader, random_split import os from PIL import Image # 定义图像预处理流水线 transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) class CatDogDataset(torch.utils.data.Dataset): def __init__(self, root_dir, transformNone): self.image_paths [] self.labels [] self.transform transform # 遍历目录收集样本 for label, class_dir in enumerate([Cat, Dog]): dir_path os.path.join(root_dir, class_dir) for img_name in os.listdir(dir_path): self.image_paths.append(os.path.join(dir_path, img_name)) self.labels.append(label) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img Image.open(self.image_paths[idx]).convert(RGB) if self.transform: img self.transform(img) return img, self.labels[idx] # 数据集划分与加载 full_dataset CatDogDataset(/content/cats_dogs, transformtransform) train_size int(0.8 * len(full_dataset)) test_size len(full_dataset) - train_size train_dataset, test_dataset random_split(full_dataset, [train_size, test_size]) train_loader DataLoader(train_dataset, batch_size32, shuffleTrue) test_loader DataLoader(test_dataset, batch_size32)4. 轻量级CNN模型设计与训练考虑到Colab的GPU内存限制我们设计了一个精简但有效的网络结构import torch.nn as nn import torch.optim as optim class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.features nn.Sequential( nn.Conv2d(3, 16, kernel_size3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, kernel_size3, padding1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, kernel_size3, padding1), nn.ReLU(), nn.MaxPool2d(2) ) self.classifier nn.Sequential( nn.Flatten(), nn.Linear(64*28*28, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, 2) ) def forward(self, x): x self.features(x) x self.classifier(x) return x model SimpleCNN().cuda() criterion nn.CrossEntropyLoss() optimizer optim.Adam(model.parameters(), lr0.001)训练过程加入进度条和准确率监控from tqdm import tqdm def train_epoch(model, loader, optimizer, criterion): model.train() running_loss 0.0 correct 0 total 0 for inputs, labels in tqdm(loader, descTraining): inputs, labels inputs.cuda(), labels.cuda() optimizer.zero_grad() outputs model(inputs) loss criterion(outputs, labels) loss.backward() optimizer.step() running_loss loss.item() _, predicted outputs.max(1) total labels.size(0) correct predicted.eq(labels).sum().item() return running_loss/len(loader), 100.*correct/total for epoch in range(5): train_loss, train_acc train_epoch(model, train_loader, optimizer, criterion) print(fEpoch {epoch1}: Loss{train_loss:.4f}, Acc{train_acc:.2f}%)5. 模型评估与可视化分析训练完成后我们通过测试集验证模型表现def evaluate(model, loader): model.eval() correct 0 total 0 with torch.no_grad(): for inputs, labels in tqdm(loader, descEvaluating): inputs, labels inputs.cuda(), labels.cuda() outputs model(inputs) _, predicted outputs.max(1) total labels.size(0) correct predicted.eq(labels).sum().item() return 100.*correct/total test_acc evaluate(model, test_loader) print(fTest Accuracy: {test_acc:.2f}%)可视化部分预测结果import matplotlib.pyplot as plt import numpy as np def imshow(img): img img / 2 0.5 # 反归一化 npimg img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.axis(off) # 获取一批测试数据 dataiter iter(test_loader) images, labels next(dataiter) images, labels images.cuda(), labels.cuda() # 预测并可视化 outputs model(images) _, preds outputs.max(1) images images.cpu() fig plt.figure(figsize(12, 8)) for idx in range(6): ax fig.add_subplot(2, 3, idx1) imshow(images[idx]) ax.set_title(fPred: {Dog if preds[idx] else Cat}\nTrue: {Dog if labels[idx] else Cat}) plt.tight_layout() plt.show()6. 模型保存与Colab使用技巧训练好的模型可以保存到Google Drive避免丢失from google.colab import drive drive.mount(/content/drive) # 保存模型权重 torch.save(model.state_dict(), /content/drive/MyDrive/cat_dog_cnn.pth) # 后续加载方式 # model SimpleCNN().cuda() # model.load_state_dict(torch.load(/content/drive/MyDrive/cat_dog_cnn.pth))Colab使用中的几个实用技巧会话保持在代码单元格添加以下内容防止断连from IPython.display import Javascript def keep_alive(): display(Javascript( function ConnectButton(){ console.log(Connect pushed); document.querySelector(#connect).click() } setInterval(ConnectButton,60000); )) keep_alive()显存清理遇到CUDA内存错误时运行import torch, gc gc.collect() torch.cuda.empty_cache()大数据集处理建议先将数据压缩上传到Google Drive再挂载解压7. 进阶优化方向当基础模型跑通后可以考虑这些改进方案优化方向具体方法预期收益数据增强随机翻转、旋转、色彩抖动提升模型泛化能力模型架构使用ResNet等预训练模型进行迁移学习显著提高准确率超参数调优学习率调度、批量大小调整加快收敛速度类别不平衡处理加权随机采样、Focal Loss改善少数类识别率迁移学习示例代码from torchvision.models import resnet18 pretrained_model resnet18(pretrainedTrue) # 替换最后一层 pretrained_model.fc nn.Linear(pretrained_model.fc.in_features, 2) model pretrained_model.cuda() # 只训练最后一层 optimizer optim.Adam(model.fc.parameters(), lr0.001)数据增强策略train_transform transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])在Colab上跑完整个流程后最大的感受是云端开发确实能节省大量环境配置时间让学习者可以专注于模型本身。有个小技巧分享当需要长时间训练时可以打开Colab Pro的后台执行功能即使关闭浏览器标签页也能继续运行。