PyTorch实战:用AE和VAE模型重构MNIST手写数字(附完整代码)

PyTorch实战:用AE和VAE模型重构MNIST手写数字(附完整代码) PyTorch实战用AE和VAE模型重构MNIST手写数字附完整代码在深度学习领域自编码器AutoEncoder和变分自编码器Variational AutoEncoder是两种极具代表性的生成模型。它们不仅在数据降维和特征提取方面表现出色还能用于图像生成、异常检测等实际应用场景。本文将带你从零开始使用PyTorch框架实现这两种模型并在MNIST数据集上进行实战演练。1. 环境准备与数据加载在开始构建模型之前我们需要确保开发环境配置正确并准备好实验所需的数据集。以下是完整的准备工作流程import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms import matplotlib.pyplot as plt import os # 设置随机种子保证结果可复现 torch.manual_seed(42) # 数据预处理管道 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) # 将像素值归一化到[-1,1]范围 ]) # 加载MNIST数据集 train_dataset datasets.MNIST( root./data, trainTrue, downloadTrue, transformtransform ) test_dataset datasets.MNIST( root./data, trainFalse, downloadTrue, transformtransform ) # 创建数据加载器 train_loader DataLoader( train_dataset, batch_size128, shuffleTrue, num_workers2 ) test_loader DataLoader( test_dataset, batch_size128, shuffleFalse, num_workers2 )注意在实际项目中建议将数据预处理步骤封装成独立的函数或类便于维护和复用。特别是当处理更复杂的数据集时这种模块化设计能显著提高代码的可读性。2. 自编码器AE实现与训练自编码器是一种通过无监督学习来获取数据高效表示的人工神经网络。它由编码器和解码器两部分组成通过最小化重构误差来训练网络。2.1 AE模型架构设计class AutoEncoder(nn.Module): def __init__(self, latent_dim32): super(AutoEncoder, self).__init__() # 编码器部分 self.encoder nn.Sequential( nn.Linear(28*28, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, latent_dim) ) # 解码器部分 self.decoder nn.Sequential( nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 28*28), nn.Tanh() # 配合归一化到[-1,1]的输出 ) def forward(self, x): batch_size x.size(0) x x.view(batch_size, -1) # 展平输入 encoded self.encoder(x) decoded self.decoder(encoded) return decoded.view(batch_size, 1, 28, 28) # 恢复原始形状2.2 AE训练流程训练自编码器需要定义损失函数和优化器然后通过迭代训练数据来优化模型参数# 初始化模型和优化器 ae_model AutoEncoder(latent_dim32) ae_optimizer optim.Adam(ae_model.parameters(), lr1e-3) criterion nn.MSELoss() # 均方误差损失 # 训练函数 def train_ae(epoch): ae_model.train() train_loss 0 for batch_idx, (data, _) in enumerate(train_loader): ae_optimizer.zero_grad() recon_batch ae_model(data) loss criterion(recon_batch, data) loss.backward() train_loss loss.item() ae_optimizer.step() print(fEpoch: {epoch} | Loss: {train_loss/len(train_loader):.4f}) # 训练模型 num_epochs 50 for epoch in range(1, num_epochs 1): train_ae(epoch)3. 变分自编码器VAE实现与训练变分自编码器是自编码器的概率版本它不仅能够重构输入数据还能生成新的样本。VAE通过将输入编码为概率分布而非固定向量来实现这一功能。3.1 VAE模型架构设计class VariationalAutoEncoder(nn.Module): def __init__(self, latent_dim32): super(VariationalAutoEncoder, self).__init__() # 编码器部分 self.encoder nn.Sequential( nn.Linear(28*28, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU() ) # 潜在空间的均值和方差 self.fc_mu nn.Linear(128, latent_dim) self.fc_logvar nn.Linear(128, latent_dim) # 解码器部分 self.decoder nn.Sequential( nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 28*28), nn.Tanh() ) def encode(self, x): h self.encoder(x) return self.fc_mu(h), self.fc_logvar(h) def reparameterize(self, mu, logvar): std torch.exp(0.5 * logvar) eps torch.randn_like(std) return mu eps * std def decode(self, z): return self.decoder(z) def forward(self, x): batch_size x.size(0) x x.view(batch_size, -1) mu, logvar self.encode(x) z self.reparameterize(mu, logvar) return self.decode(z).view(batch_size, 1, 28, 28), mu, logvar3.2 VAE损失函数与训练VAE的损失函数由两部分组成重构损失和KL散度。前者衡量重构质量后者确保潜在空间的正则化。def vae_loss(recon_x, x, mu, logvar): VAE损失函数 重构损失 KL散度 BCE nn.functional.mse_loss(recon_x, x, reductionsum) KLD -0.5 * torch.sum(1 logvar - mu.pow(2) - logvar.exp()) return BCE KLD # 初始化模型和优化器 vae_model VariationalAutoEncoder(latent_dim32) vae_optimizer optim.Adam(vae_model.parameters(), lr1e-3) # 训练函数 def train_vae(epoch): vae_model.train() train_loss 0 for batch_idx, (data, _) in enumerate(train_loader): vae_optimizer.zero_grad() recon_batch, mu, logvar vae_model(data) loss vae_loss(recon_batch, data, mu, logvar) loss.backward() train_loss loss.item() vae_optimizer.step() print(fEpoch: {epoch} | Loss: {train_loss/len(train_loader.dataset):.4f}) # 训练模型 num_epochs 50 for epoch in range(1, num_epochs 1): train_vae(epoch)4. 结果可视化与模型评估训练完成后我们需要评估模型性能并可视化重构结果。这不仅能直观展示模型效果还能帮助我们发现潜在问题。4.1 重构效果对比def compare_reconstructions(ae_model, vae_model, num_samples8): 比较AE和VAE的重构效果 ae_model.eval() vae_model.eval() with torch.no_grad(): # 获取测试样本 data next(iter(test_loader))[0][:num_samples] # 获取重构结果 ae_recon ae_model(data) vae_recon, _, _ vae_model(data) # 拼接原始图像和重构结果 comparison torch.cat([data, ae_recon, vae_recon]) # 绘制结果 plt.figure(figsize(12, 6)) for i in range(3 * num_samples): plt.subplot(3, num_samples, i1) plt.imshow(comparison[i].squeeze().numpy(), cmapgray) plt.axis(off) plt.tight_layout() plt.show() # 调用比较函数 compare_reconstructions(ae_model, vae_model)4.2 潜在空间可视化VAE的一个关键优势是其潜在空间的连续性我们可以通过可视化来验证这一点def plot_latent_space(vae_model): 可视化VAE的潜在空间 vae_model.eval() with torch.no_grad(): # 获取测试集的所有潜在表示 latents [] labels [] for data, label in test_loader: data data.view(data.size(0), -1) mu, _ vae_model.encode(data) latents.append(mu) labels.append(label) latents torch.cat(latents).numpy() labels torch.cat(labels).numpy() # 绘制散点图 plt.figure(figsize(10, 8)) scatter plt.scatter(latents[:, 0], latents[:, 1], clabels, cmaptab10, alpha0.5) plt.colorbar(scatter) plt.xlabel(Latent Dimension 1) plt.ylabel(Latent Dimension 2) plt.title(VAE Latent Space Visualization) plt.show() plot_latent_space(vae_model)4.3 生成新样本VAE能够从潜在空间采样生成新的手写数字def generate_samples(vae_model, num_samples16): 从VAE生成新样本 vae_model.eval() with torch.no_grad(): # 从标准正态分布采样 z torch.randn(num_samples, 32) samples vae_model.decode(z).view(-1, 1, 28, 28) # 绘制生成样本 plt.figure(figsize(8, 8)) for i in range(num_samples): plt.subplot(4, 4, i1) plt.imshow(samples[i].squeeze().numpy(), cmapgray) plt.axis(off) plt.tight_layout() plt.show() generate_samples(vae_model)5. 常见问题与优化技巧在实际项目中实现和训练AE/VAE模型可能会遇到各种问题。以下是几个常见问题及其解决方案5.1 重构图像模糊VAE生成图像常常比AE更模糊这是因为VAE优化的是概率下界而非精确重构KL散度项迫使潜在变量接近标准正态分布解决方案调整KL散度的权重β-VAE使用更复杂的解码器结构尝试其他损失函数如感知损失5.2 潜在空间坍塌当模型忽略潜在变量所有输入都映射到相同点时发生。解决方案增加KL散度项的权重使用更小的学习率尝试更深的网络结构5.3 训练不稳定VAE训练可能比AE更不稳定特别是当潜在维度较高时。解决方案# 梯度裁剪防止爆炸 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0) # 学习率调度器 scheduler optim.lr_scheduler.ReduceLROnPlateau( optimizer, modemin, factor0.5, patience5 )5.4 性能优化技巧批归一化在编码器和解码器中添加批归一化层残差连接使用残差块构建更深网络渐进式训练先训练低分辨率版本再逐步提高分辨率# 带批归一化的编码器示例 self.encoder nn.Sequential( nn.Linear(28*28, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU() )