从零实现你的第一个GANPyTorch实战指南与可视化技巧当你第一次听说生成对抗网络这个词时脑海中浮现的是什么是复杂的数学公式还是晦涩难懂的论文今天我们要打破这种刻板印象。想象一下你只需要几行代码就能让计算机学会创造——生成从未存在过的人脸、风景画甚至艺术作品。这就是GAN的魅力所在而我们将用最直接的方式带你体验这种创造力的爆发。1. 环境准备与项目初始化在开始之前确保你的开发环境已经准备就绪。我们将使用Python 3.8和PyTorch 1.10版本这是目前最稳定的组合。如果你还没有安装PyTorch可以通过以下命令快速完成pip install torch torchvision matplotlib numpy创建一个新的项目目录结构如下gan_project/ ├── data/ # 存放训练数据 ├── models/ # 保存训练好的模型 ├── utils.py # 工具函数 ├── train.py # 训练脚本 └── visualize.py # 可视化工具我们将使用MNIST手写数字数据集作为起点因为它简单易懂且训练速度快。PyTorch内置了数据加载工具可以轻松获取from torchvision import datasets, transforms transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_dataset datasets.MNIST( root./data, trainTrue, downloadTrue, transformtransform ) train_loader torch.utils.data.DataLoader( datasettrain_dataset, batch_size64, shuffleTrue )提示在Windows系统上如果遇到数据加载问题可以尝试设置num_workers0。对于Linux/macOS可以适当增加这个值以提高加载速度。2. 构建生成器与判别器现在让我们动手构建GAN的两个核心组件。生成器的任务是接收随机噪声并输出伪造的图像而判别器则需要判断输入图像是真实的还是生成的。2.1 生成器架构我们的生成器采用全连接网络结构逐步将100维的随机噪声转换为28×28的手写数字图像import torch.nn as nn class Generator(nn.Module): def __init__(self, latent_dim100): super(Generator, self).__init__() self.main nn.Sequential( nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2), nn.Linear(256, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1024), nn.LeakyReLU(0.2), nn.Linear(1024, 784), nn.Tanh() ) def forward(self, z): img self.main(z) return img.view(-1, 1, 28, 28)关键设计选择使用LeakyReLU激活函数防止梯度消失负斜率设为0.2最后一层使用Tanh将输出压缩到[-1,1]范围与归一化后的输入数据匹配逐步扩大网络容量100→256→512→1024→7842.2 判别器架构判别器是一个二分类网络结构上与生成器对称但方向相反class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main nn.Sequential( nn.Linear(784, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, img): flattened img.view(-1, 784) validity self.main(flattened) return validity判别器的特殊设计添加Dropout层防止过拟合丢弃率0.3使用Sigmoid输出0到1之间的概率值同样采用LeakyReLU保持梯度流动注意在实际应用中现代GAN更常使用卷积结构(DCGAN)但全连接网络更易于理解和调试适合初学者入门。3. 训练过程详解GAN的训练就像一场精妙的博弈游戏我们需要平衡生成器和判别器的学习进度。以下是训练循环的核心代码# 初始化模型和优化器 generator Generator() discriminator Discriminator() optimizer_G torch.optim.Adam(generator.parameters(), lr0.0002, betas(0.5, 0.999)) optimizer_D torch.optim.Adam(discriminator.parameters(), lr0.0002, betas(0.5, 0.999)) adversarial_loss nn.BCELoss() for epoch in range(epochs): for i, (imgs, _) in enumerate(train_loader): # 真实和假标签 real torch.ones(imgs.size(0), 1) fake torch.zeros(imgs.size(0), 1) # --------------------- # 训练判别器 # --------------------- optimizer_D.zero_grad() # 真实图像的损失 real_loss adversarial_loss(discriminator(imgs), real) # 生成图像的损失 z torch.randn(imgs.size(0), 100) gen_imgs generator(z) fake_loss adversarial_loss(discriminator(gen_imgs.detach()), fake) d_loss (real_loss fake_loss) / 2 d_loss.backward() optimizer_D.step() # ----------------- # 训练生成器 # ----------------- optimizer_G.zero_grad() z torch.randn(imgs.size(0), 100) gen_imgs generator(z) g_loss adversarial_loss(discriminator(gen_imgs), real) g_loss.backward() optimizer_G.step()训练过程中的关键点交替训练先固定生成器训练判别器然后固定判别器训练生成器标签定义真实图像标签为1生成图像标签为0损失计算判别器需要同时识别真实和生成图像生成器试图欺骗判别器使其输出接近1优化器设置使用Adam优化器β10.5有助于稳定训练常见的训练问题与解决方案问题现象可能原因解决方法生成图像模糊判别器太强降低判别器学习率模式崩溃生成单一结果生成器找到捷径增加噪声调整损失函数训练不稳定学习率过高减小学习率使用梯度裁剪4. 可视化训练过程可视化是理解GAN训练进展的绝佳方式。我们将实现三种可视化技术4.1 实时生成样本展示在训练过程中定期保存生成器输出的样本import matplotlib.pyplot as plt def save_sample_images(epoch, generator, latent_dim100, n_rows4, n_cols4): 保存生成样本的网格图像 z torch.randn(n_rows * n_cols, latent_dim) gen_imgs generator(z) fig, axs plt.subplots(n_rows, n_cols, figsize(8,8)) cnt 0 for i in range(n_rows): for j in range(n_cols): axs[i,j].imshow(gen_imgs[cnt,0,:,:].detach().numpy(), cmapgray) axs[i,j].axis(off) cnt 1 fig.savefig(fimages/epoch_{epoch}.png) plt.close()4.2 损失曲线绘制记录并绘制生成器和判别器的损失变化def plot_losses(g_losses, d_losses): plt.figure(figsize(10,5)) plt.title(Generator and Discriminator Loss During Training) plt.plot(g_losses, labelG) plt.plot(d_losses, labelD) plt.xlabel(iterations) plt.ylabel(Loss) plt.legend() plt.savefig(loss_curve.png)4.3 潜在空间插值探索生成器的创造力如何随输入噪声变化def interpolate_between_points(generator, point1, point2, n_steps10): 在两个噪声向量之间插值并生成图像 ratios torch.linspace(0, 1, n_steps) vectors [] for ratio in ratios: v (1.0 - ratio) * point1 ratio * point2 vectors.append(v) vectors torch.stack(vectors) with torch.no_grad(): images generator(vectors) fig, axs plt.subplots(1, n_steps, figsize(20,2)) for i, img in enumerate(images): axs[i].imshow(img[0].numpy(), cmapgray) axs[i].axis(off) plt.show()可视化分析要点初期阶段生成图像通常是随机噪声判别器损失快速下降中期阶段开始出现可辨认的形状但可能模糊或扭曲后期阶段图像质量显著提高损失曲线趋于平衡提示如果发现判别器损失降至接近0说明判别器过于强大生成器无法学习。此时应暂停判别器训练让生成器追赶几步。5. 进阶技巧与优化当基本模型能够生成可辨认的图像后我们可以尝试以下改进策略5.1 架构改进将全连接网络升级为DCGAN深度卷积GAN结构class DCGenerator(nn.Module): def __init__(self, latent_dim100): super(DCGenerator, self).__init__() self.main nn.Sequential( # 输入是Z进入卷积 nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, biasFalse), nn.BatchNorm2d(512), nn.ReLU(True), # 状态大小 (512,4,4) nn.ConvTranspose2d(512, 256, 4, 2, 1, biasFalse), nn.BatchNorm2d(256), nn.ReLU(True), # 状态大小 (256,8,8) nn.ConvTranspose2d(256, 128, 4, 2, 1, biasFalse), nn.BatchNorm2d(128), nn.ReLU(True), # 状态大小 (128,16,16) nn.ConvTranspose2d(128, 1, 4, 2, 1, biasFalse), nn.Tanh() # 输出大小 (1,32,32) ) def forward(self, input): return self.main(input)卷积GAN的优势更好地保留空间信息参数效率更高通常生成质量更好的图像5.2 训练技巧标签平滑将真实标签从1.0改为0.9防止判别器过于自信real_labels torch.full((batch_size,1), 0.9)噪声注入在判别器的输入中添加随机噪声real_imgs imgs 0.01 * torch.randn_like(imgs)历史缓冲保存之前生成的图像用于判别器训练fake_buffer deque(maxlen100)5.3 评估指标如何客观评价GAN的性能常用指标包括Inception Score (IS)衡量生成图像的多样性和质量Fréchet Inception Distance (FID)比较生成与真实图像的统计特性人工评估仍然是最可靠的方法实现简单的FID计算from scipy.linalg import sqrtm def calculate_fid(real_activations, fake_activations): # 计算均值和协方差 mu1, sigma1 real_activations.mean(axis0), np.cov(real_activations, rowvarFalse) mu2, sigma2 fake_activations.mean(axis0), np.cov(fake_activations, rowvarFalse) # 计算平方根||mu1-mu2||^2 ssdiff np.sum((mu1 - mu2)**2.0) # 计算协方差矩阵的平方根 covmean sqrtm(sigma1.dot(sigma2)) # 检查复数 if np.iscomplexobj(covmean): covmean covmean.real # 计算FID fid ssdiff np.trace(sigma1 sigma2 - 2.0 * covmean) return fid6. 实战生成手写数字之外的图像掌握了MNIST生成后我们可以挑战更复杂的数据集。以Fashion-MNIST为例# 加载Fashion-MNIST数据集 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) fashion_dataset datasets.FashionMNIST( root./data, trainTrue, downloadTrue, transformtransform ) fashion_loader torch.utils.data.DataLoader( fashion_dataset, batch_size64, shuffleTrue )训练调整建议增加模型容量更多层或更宽的层延长训练时间可能需要更多epoch数据增强随机旋转、裁剪等调整学习率更复杂数据通常需要更小的学习率生成时尚单品的关键观察服装细节如纹理、褶皱更难学习不同类别鞋子、包、上衣需要模型有更强的区分能力可能需要更深的网络结构7. 常见问题与调试指南即使按照教程操作你仍可能遇到各种问题。以下是常见问题及其解决方案7.1 生成器不学习现象生成图像始终是噪声损失值不下降。可能原因判别器太强压倒性优势生成器架构不合理学习率设置不当解决方案# 尝试以下调整 # 1. 降低判别器学习率 optimizer_D torch.optim.Adam(discriminator.parameters(), lr0.0001) # 2. 减少判别器更新频率 if i % 2 0: # 每两次更新一次生成器 optimizer_G.step() # 3. 增加生成器容量 self.main nn.Sequential( nn.Linear(latent_dim, 512), # 增加第一层宽度 # ...其余层保持不变 )7.2 模式崩溃现象生成器只产生少量几种样本缺乏多样性。解决方案表方法实现效果小批量判别在判别器中添加小批量特征统计增加样本多样性噪声注入在生成器和判别器的各层添加噪声防止确定性行为历史缓冲保存旧生成样本用于训练防止生成器遗忘7.3 训练不稳定现象损失值剧烈波动生成质量时好时坏。稳定训练的技巧使用梯度裁剪torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)采用Wasserstein损失WGAN使用谱归一化torch.nn.utils.spectral_norm(nn.Linear(in_dim, out_dim))7.4 生成图像模糊原因分析L1/L2损失倾向于生成平均结果。改进方向改用感知损失Perceptual Loss添加对抗性特征匹配使用多尺度判别器# 示例特征匹配损失 real_features discriminator.extract_features(real_imgs) fake_features discriminator.extract_features(fake_imgs) feature_loss torch.mean(torch.abs(real_features - fake_features))8. 从GAN到现代变体基础GAN只是起点现代GAN已经发展出众多强大变体DCGAN使用卷积结构的标准架构WGAN采用Wasserstein距离改善训练稳定性CycleGAN实现无配对图像的风格转换StyleGAN生成高分辨率、可控性强的图像以StyleGAN为例的关键创新渐进式增长从低分辨率开始逐步增加风格混合分离高层次和低层次特征噪声输入在每个卷积层添加噪声实现简单的渐进式增长class ProgressiveGenerator(nn.Module): def __init__(self): super().__init__() # 初始块生成4x4图像 self.block4x4 nn.Sequential(...) # 当需要提升分辨率时添加新块 self.block8x8 nn.Sequential(...) # 当前分辨率 self.current_scale 1 # 1表示4x4 def forward(self, z): if self.current_scale 1: return self.block4x4(z) elif self.current_scale 2: x self.block4x4(z) return self.block8x8(x) def add_scale(self): 增加一个新的分辨率级别 self.current_scale 1 # 初始化新块的权重...9. 实际应用案例GAN不仅限于学术研究在实际项目中也有广泛应用9.1 数据增强当真实数据稀缺时用GAN生成补充样本def generate_synthetic_samples(generator, n_samples): z torch.randn(n_samples, latent_dim) synthetic_data generator(z) return synthetic_data # 混合真实和生成数据 augmented_dataset torch.utils.data.ConcatDataset([ real_dataset, SyntheticDataset(generate_synthetic_samples(generator, 10000)) ])9.2 图像修复用GAN填充图像缺失部分class InpaintingGAN(nn.Module): def __init__(self): super().__init__() self.encoder ... # 编码器网络 self.decoder ... # 解码器网络 def forward(self, masked_img, mask): # masked_img: 带缺失的图像 # mask: 缺失区域标记 features self.encoder(masked_img) output self.decoder(features) return masked_img * (1-mask) output * mask9.3 艺术创作结合CLIP等模型实现文本到图像生成def text_to_image(text_prompt, generator, clip_model): # 将文本编码为向量 text_features clip_model.encode_text(text_prompt) # 优化噪声向量以匹配文本特征 z torch.randn(1, latent_dim, requires_gradTrue) optimizer torch.optim.Adam([z], lr0.01) for _ in range(100): image generator(z) image_features clip_model.encode_image(image) loss -torch.cosine_similarity(text_features, image_features) loss.backward() optimizer.step() return generator(z)10. 资源与后续学习为了帮助你继续深入GAN领域以下是一些优质资源开源实现参考PyTorch-GAN包含多种GAN变体的实现StyleGAN2-ADA官方PyTorch实现CLIP-GAN文本引导的图像生成推荐学习路径掌握基础GAN实现本文内容学习DCGAN和WGAN-GP探索条件GANcGAN研究最新架构如StyleGAN3实用工具库# GAN训练监控 from torch.utils.tensorboard import SummaryWriter # 高级GAN实现 import pytorch_lightning as pl # 模型可视化 import netron在完成第一个GAN项目后试着挑战以下任务生成更复杂的数据如人脸、风景实现图像到图像的转换如素描→彩色探索潜在空间操作属性编辑
别再死磕理论了!用PyTorch手把手带你跑通第一个GAN(附完整代码与可视化结果)
从零实现你的第一个GANPyTorch实战指南与可视化技巧当你第一次听说生成对抗网络这个词时脑海中浮现的是什么是复杂的数学公式还是晦涩难懂的论文今天我们要打破这种刻板印象。想象一下你只需要几行代码就能让计算机学会创造——生成从未存在过的人脸、风景画甚至艺术作品。这就是GAN的魅力所在而我们将用最直接的方式带你体验这种创造力的爆发。1. 环境准备与项目初始化在开始之前确保你的开发环境已经准备就绪。我们将使用Python 3.8和PyTorch 1.10版本这是目前最稳定的组合。如果你还没有安装PyTorch可以通过以下命令快速完成pip install torch torchvision matplotlib numpy创建一个新的项目目录结构如下gan_project/ ├── data/ # 存放训练数据 ├── models/ # 保存训练好的模型 ├── utils.py # 工具函数 ├── train.py # 训练脚本 └── visualize.py # 可视化工具我们将使用MNIST手写数字数据集作为起点因为它简单易懂且训练速度快。PyTorch内置了数据加载工具可以轻松获取from torchvision import datasets, transforms transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_dataset datasets.MNIST( root./data, trainTrue, downloadTrue, transformtransform ) train_loader torch.utils.data.DataLoader( datasettrain_dataset, batch_size64, shuffleTrue )提示在Windows系统上如果遇到数据加载问题可以尝试设置num_workers0。对于Linux/macOS可以适当增加这个值以提高加载速度。2. 构建生成器与判别器现在让我们动手构建GAN的两个核心组件。生成器的任务是接收随机噪声并输出伪造的图像而判别器则需要判断输入图像是真实的还是生成的。2.1 生成器架构我们的生成器采用全连接网络结构逐步将100维的随机噪声转换为28×28的手写数字图像import torch.nn as nn class Generator(nn.Module): def __init__(self, latent_dim100): super(Generator, self).__init__() self.main nn.Sequential( nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2), nn.Linear(256, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1024), nn.LeakyReLU(0.2), nn.Linear(1024, 784), nn.Tanh() ) def forward(self, z): img self.main(z) return img.view(-1, 1, 28, 28)关键设计选择使用LeakyReLU激活函数防止梯度消失负斜率设为0.2最后一层使用Tanh将输出压缩到[-1,1]范围与归一化后的输入数据匹配逐步扩大网络容量100→256→512→1024→7842.2 判别器架构判别器是一个二分类网络结构上与生成器对称但方向相反class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.main nn.Sequential( nn.Linear(784, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, img): flattened img.view(-1, 784) validity self.main(flattened) return validity判别器的特殊设计添加Dropout层防止过拟合丢弃率0.3使用Sigmoid输出0到1之间的概率值同样采用LeakyReLU保持梯度流动注意在实际应用中现代GAN更常使用卷积结构(DCGAN)但全连接网络更易于理解和调试适合初学者入门。3. 训练过程详解GAN的训练就像一场精妙的博弈游戏我们需要平衡生成器和判别器的学习进度。以下是训练循环的核心代码# 初始化模型和优化器 generator Generator() discriminator Discriminator() optimizer_G torch.optim.Adam(generator.parameters(), lr0.0002, betas(0.5, 0.999)) optimizer_D torch.optim.Adam(discriminator.parameters(), lr0.0002, betas(0.5, 0.999)) adversarial_loss nn.BCELoss() for epoch in range(epochs): for i, (imgs, _) in enumerate(train_loader): # 真实和假标签 real torch.ones(imgs.size(0), 1) fake torch.zeros(imgs.size(0), 1) # --------------------- # 训练判别器 # --------------------- optimizer_D.zero_grad() # 真实图像的损失 real_loss adversarial_loss(discriminator(imgs), real) # 生成图像的损失 z torch.randn(imgs.size(0), 100) gen_imgs generator(z) fake_loss adversarial_loss(discriminator(gen_imgs.detach()), fake) d_loss (real_loss fake_loss) / 2 d_loss.backward() optimizer_D.step() # ----------------- # 训练生成器 # ----------------- optimizer_G.zero_grad() z torch.randn(imgs.size(0), 100) gen_imgs generator(z) g_loss adversarial_loss(discriminator(gen_imgs), real) g_loss.backward() optimizer_G.step()训练过程中的关键点交替训练先固定生成器训练判别器然后固定判别器训练生成器标签定义真实图像标签为1生成图像标签为0损失计算判别器需要同时识别真实和生成图像生成器试图欺骗判别器使其输出接近1优化器设置使用Adam优化器β10.5有助于稳定训练常见的训练问题与解决方案问题现象可能原因解决方法生成图像模糊判别器太强降低判别器学习率模式崩溃生成单一结果生成器找到捷径增加噪声调整损失函数训练不稳定学习率过高减小学习率使用梯度裁剪4. 可视化训练过程可视化是理解GAN训练进展的绝佳方式。我们将实现三种可视化技术4.1 实时生成样本展示在训练过程中定期保存生成器输出的样本import matplotlib.pyplot as plt def save_sample_images(epoch, generator, latent_dim100, n_rows4, n_cols4): 保存生成样本的网格图像 z torch.randn(n_rows * n_cols, latent_dim) gen_imgs generator(z) fig, axs plt.subplots(n_rows, n_cols, figsize(8,8)) cnt 0 for i in range(n_rows): for j in range(n_cols): axs[i,j].imshow(gen_imgs[cnt,0,:,:].detach().numpy(), cmapgray) axs[i,j].axis(off) cnt 1 fig.savefig(fimages/epoch_{epoch}.png) plt.close()4.2 损失曲线绘制记录并绘制生成器和判别器的损失变化def plot_losses(g_losses, d_losses): plt.figure(figsize(10,5)) plt.title(Generator and Discriminator Loss During Training) plt.plot(g_losses, labelG) plt.plot(d_losses, labelD) plt.xlabel(iterations) plt.ylabel(Loss) plt.legend() plt.savefig(loss_curve.png)4.3 潜在空间插值探索生成器的创造力如何随输入噪声变化def interpolate_between_points(generator, point1, point2, n_steps10): 在两个噪声向量之间插值并生成图像 ratios torch.linspace(0, 1, n_steps) vectors [] for ratio in ratios: v (1.0 - ratio) * point1 ratio * point2 vectors.append(v) vectors torch.stack(vectors) with torch.no_grad(): images generator(vectors) fig, axs plt.subplots(1, n_steps, figsize(20,2)) for i, img in enumerate(images): axs[i].imshow(img[0].numpy(), cmapgray) axs[i].axis(off) plt.show()可视化分析要点初期阶段生成图像通常是随机噪声判别器损失快速下降中期阶段开始出现可辨认的形状但可能模糊或扭曲后期阶段图像质量显著提高损失曲线趋于平衡提示如果发现判别器损失降至接近0说明判别器过于强大生成器无法学习。此时应暂停判别器训练让生成器追赶几步。5. 进阶技巧与优化当基本模型能够生成可辨认的图像后我们可以尝试以下改进策略5.1 架构改进将全连接网络升级为DCGAN深度卷积GAN结构class DCGenerator(nn.Module): def __init__(self, latent_dim100): super(DCGenerator, self).__init__() self.main nn.Sequential( # 输入是Z进入卷积 nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, biasFalse), nn.BatchNorm2d(512), nn.ReLU(True), # 状态大小 (512,4,4) nn.ConvTranspose2d(512, 256, 4, 2, 1, biasFalse), nn.BatchNorm2d(256), nn.ReLU(True), # 状态大小 (256,8,8) nn.ConvTranspose2d(256, 128, 4, 2, 1, biasFalse), nn.BatchNorm2d(128), nn.ReLU(True), # 状态大小 (128,16,16) nn.ConvTranspose2d(128, 1, 4, 2, 1, biasFalse), nn.Tanh() # 输出大小 (1,32,32) ) def forward(self, input): return self.main(input)卷积GAN的优势更好地保留空间信息参数效率更高通常生成质量更好的图像5.2 训练技巧标签平滑将真实标签从1.0改为0.9防止判别器过于自信real_labels torch.full((batch_size,1), 0.9)噪声注入在判别器的输入中添加随机噪声real_imgs imgs 0.01 * torch.randn_like(imgs)历史缓冲保存之前生成的图像用于判别器训练fake_buffer deque(maxlen100)5.3 评估指标如何客观评价GAN的性能常用指标包括Inception Score (IS)衡量生成图像的多样性和质量Fréchet Inception Distance (FID)比较生成与真实图像的统计特性人工评估仍然是最可靠的方法实现简单的FID计算from scipy.linalg import sqrtm def calculate_fid(real_activations, fake_activations): # 计算均值和协方差 mu1, sigma1 real_activations.mean(axis0), np.cov(real_activations, rowvarFalse) mu2, sigma2 fake_activations.mean(axis0), np.cov(fake_activations, rowvarFalse) # 计算平方根||mu1-mu2||^2 ssdiff np.sum((mu1 - mu2)**2.0) # 计算协方差矩阵的平方根 covmean sqrtm(sigma1.dot(sigma2)) # 检查复数 if np.iscomplexobj(covmean): covmean covmean.real # 计算FID fid ssdiff np.trace(sigma1 sigma2 - 2.0 * covmean) return fid6. 实战生成手写数字之外的图像掌握了MNIST生成后我们可以挑战更复杂的数据集。以Fashion-MNIST为例# 加载Fashion-MNIST数据集 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) fashion_dataset datasets.FashionMNIST( root./data, trainTrue, downloadTrue, transformtransform ) fashion_loader torch.utils.data.DataLoader( fashion_dataset, batch_size64, shuffleTrue )训练调整建议增加模型容量更多层或更宽的层延长训练时间可能需要更多epoch数据增强随机旋转、裁剪等调整学习率更复杂数据通常需要更小的学习率生成时尚单品的关键观察服装细节如纹理、褶皱更难学习不同类别鞋子、包、上衣需要模型有更强的区分能力可能需要更深的网络结构7. 常见问题与调试指南即使按照教程操作你仍可能遇到各种问题。以下是常见问题及其解决方案7.1 生成器不学习现象生成图像始终是噪声损失值不下降。可能原因判别器太强压倒性优势生成器架构不合理学习率设置不当解决方案# 尝试以下调整 # 1. 降低判别器学习率 optimizer_D torch.optim.Adam(discriminator.parameters(), lr0.0001) # 2. 减少判别器更新频率 if i % 2 0: # 每两次更新一次生成器 optimizer_G.step() # 3. 增加生成器容量 self.main nn.Sequential( nn.Linear(latent_dim, 512), # 增加第一层宽度 # ...其余层保持不变 )7.2 模式崩溃现象生成器只产生少量几种样本缺乏多样性。解决方案表方法实现效果小批量判别在判别器中添加小批量特征统计增加样本多样性噪声注入在生成器和判别器的各层添加噪声防止确定性行为历史缓冲保存旧生成样本用于训练防止生成器遗忘7.3 训练不稳定现象损失值剧烈波动生成质量时好时坏。稳定训练的技巧使用梯度裁剪torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)采用Wasserstein损失WGAN使用谱归一化torch.nn.utils.spectral_norm(nn.Linear(in_dim, out_dim))7.4 生成图像模糊原因分析L1/L2损失倾向于生成平均结果。改进方向改用感知损失Perceptual Loss添加对抗性特征匹配使用多尺度判别器# 示例特征匹配损失 real_features discriminator.extract_features(real_imgs) fake_features discriminator.extract_features(fake_imgs) feature_loss torch.mean(torch.abs(real_features - fake_features))8. 从GAN到现代变体基础GAN只是起点现代GAN已经发展出众多强大变体DCGAN使用卷积结构的标准架构WGAN采用Wasserstein距离改善训练稳定性CycleGAN实现无配对图像的风格转换StyleGAN生成高分辨率、可控性强的图像以StyleGAN为例的关键创新渐进式增长从低分辨率开始逐步增加风格混合分离高层次和低层次特征噪声输入在每个卷积层添加噪声实现简单的渐进式增长class ProgressiveGenerator(nn.Module): def __init__(self): super().__init__() # 初始块生成4x4图像 self.block4x4 nn.Sequential(...) # 当需要提升分辨率时添加新块 self.block8x8 nn.Sequential(...) # 当前分辨率 self.current_scale 1 # 1表示4x4 def forward(self, z): if self.current_scale 1: return self.block4x4(z) elif self.current_scale 2: x self.block4x4(z) return self.block8x8(x) def add_scale(self): 增加一个新的分辨率级别 self.current_scale 1 # 初始化新块的权重...9. 实际应用案例GAN不仅限于学术研究在实际项目中也有广泛应用9.1 数据增强当真实数据稀缺时用GAN生成补充样本def generate_synthetic_samples(generator, n_samples): z torch.randn(n_samples, latent_dim) synthetic_data generator(z) return synthetic_data # 混合真实和生成数据 augmented_dataset torch.utils.data.ConcatDataset([ real_dataset, SyntheticDataset(generate_synthetic_samples(generator, 10000)) ])9.2 图像修复用GAN填充图像缺失部分class InpaintingGAN(nn.Module): def __init__(self): super().__init__() self.encoder ... # 编码器网络 self.decoder ... # 解码器网络 def forward(self, masked_img, mask): # masked_img: 带缺失的图像 # mask: 缺失区域标记 features self.encoder(masked_img) output self.decoder(features) return masked_img * (1-mask) output * mask9.3 艺术创作结合CLIP等模型实现文本到图像生成def text_to_image(text_prompt, generator, clip_model): # 将文本编码为向量 text_features clip_model.encode_text(text_prompt) # 优化噪声向量以匹配文本特征 z torch.randn(1, latent_dim, requires_gradTrue) optimizer torch.optim.Adam([z], lr0.01) for _ in range(100): image generator(z) image_features clip_model.encode_image(image) loss -torch.cosine_similarity(text_features, image_features) loss.backward() optimizer.step() return generator(z)10. 资源与后续学习为了帮助你继续深入GAN领域以下是一些优质资源开源实现参考PyTorch-GAN包含多种GAN变体的实现StyleGAN2-ADA官方PyTorch实现CLIP-GAN文本引导的图像生成推荐学习路径掌握基础GAN实现本文内容学习DCGAN和WGAN-GP探索条件GANcGAN研究最新架构如StyleGAN3实用工具库# GAN训练监控 from torch.utils.tensorboard import SummaryWriter # 高级GAN实现 import pytorch_lightning as pl # 模型可视化 import netron在完成第一个GAN项目后试着挑战以下任务生成更复杂的数据如人脸、风景实现图像到图像的转换如素描→彩色探索潜在空间操作属性编辑