关键词:CGAN,条件生成对抗网络,生成器,判别器,损失函数,PyTorch,标签嵌入,深度学习
一、什么是CGAN?
“CGAN”即条件生成对抗网络(Conditional GAN),是在经典GAN基础上增加条件信息 y 的改进模型。
- 通过引入标签、文本描述或任何额外条件,CGAN让生成器学会 “听指令画画”,而判别器则负责判断 “这幅画是否按标签画的”。
- 相较传统GAN,CGAN在可控性与多样性上实现显著突破。
简单一句话:有了CGAN,想让它生成“数字7”,就不怕它画成“2”。
二、CGAN训练六步走
1. 初始化
网络权重随机赋值(通常使用Kaiming或Xavier策略)。
快速记忆:把锅碗瓢盆摆好,开火之前先预热。
2. 数据准备
- 训练图像
imgs与标签labels成对送入 DataLoader。 - 新闻件:Mini-Batch 能有效抑制梯度震荡。
gen_labels = torch.randint(0, opt.n_classes, (batch_size,)) # 伪造「真标签」
3. 前向计算
- 生成器:拼接噪声
z与标签y,抽取出gen_imgs。 - 判别器:接收
(img, y)二元组并输出真伪打分。
4. 计算损失
- 生成器损失:鼓励判别器把假样本判真。
- 判别器损失:正确区分真假,两个目标的平均值。
g_loss = adversarial_loss(validity, valid)
d_loss = (d_real_loss + d_fake_loss) / 2
5. 反向传播与优化
调用 PyTorch 的自动求导,交替更新两组参数:
| 网络 | 优化器 |
|---|---|
| 生成器 | optimizer_G |
| 判别器 | optimizer_D |
6. 迭代训练
重复步骤 3—5,直到 Loss 曲线收敛或达到 epoch 上限。Step by Step,小球滚到大山底。
三、核心代码全解析
本段落的核心关键词:Embedding, LeakyReLU, tanh, BCEWithLogitsLoss, Adam。
1. 模型结构一览
1.1 生成器(Generator)
- 标签嵌入
label_emb = nn.Embedding(num_classes, embed_dim):把离散索引转为连续向量。 - 全连接塔:128 → 256 → 512 → 1024 → 像素层。
- 激活策略:
LeakyReLU做骨干,tanh输出像素范围 [-1, 1]。
1.2 判别器(Discriminator)
- 双输入拼接:
(img_flatten, label_embed)沿最后一维拼接,交给三层全连接。 - 正则化:
Dropout(p=0.4)防止粗暴过拟合。
2. 关键片段 + 中文注释
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes) # 标签嵌入
def block(i, o, normalize=True):
layers = [nn.Linear(i, o)]
if normalize:
layers += [nn.BatchNorm1d(o, 0.8), nn.LeakyReLU(0.2, True)]
return layers
self.net = nn.Sequential(
*block(opt.latent_dim+opt.n_classes, 128, False),
*block(128, 256), *block(256, 512), *block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, noise, labels):
inp = torch.cat((self.label_emb(labels), noise), -1)
return self.net(inp).view(-1, *img_shape)
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
self.net = nn.Sequential(
nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, True),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2, True),
nn.Linear(512, 1)
)
def forward(self, img, labels):
x = torch.cat((img.view(img.size(0), -1), self.label_emb(labels)), -1)
return self.net(x)
👉 你还想亲手跑跑效果?点击获得开箱即用的PyTorch脚本
3. 训练曲线与可视化
在 CPU+GPU 双模式下训练 50 个 epoch:
- 生成器损失 先迅速上升后缓降;
- 判别器损失 则从 0.7 逐渐收敛到 0.45 左右。
若出现 “D 损失震荡剧烈” → 适当下调学习率lr到1e-4,能有效提升稳定度。
四、常见问题速答(FAQ)
Q1:为什么噪声维度和条件维度要拼接,而不是相加?
A1:拼接保留了各个维度间的关系;相加会模糊信息边界,导致生成标签错位。
Q2:tanh 激活为什么比 sigmoid 更好?
A2:tanh 输出均值接近 0,可加速网络收敛;同时与归一化后图像的范围 [-1, 1] 自然对齐。
Q3:判别器为什么对生成图像使用 .detach()?
A3:防止梯度流入生成器,避免训练时“左手打右手”,保证两网络交替更新。
Q4:训练多久能看出明显“数字”形状?
A4:在MNIST+单GPU模式下,大约 10 个 epoch 出现可识别的 0–9 粗略轮廓;20–30 epoch 达到逼真。
Q5:模型如何支持更高分辨率(64×64 或 128×128)?
A5:将 Linear 堆叠改为 转置卷积(ConvTranspose2d),同时减少 “展平→重排” 带来的信息瓶颈。
Q6:想给CGAN增加文本描述条件(例如“红色五角星”)怎么做?
A6:用 文本编码器(例如 BERT 得到 768 维向量)替换 数字标签嵌入,然后将文本向量与噪声拼接即可。
五、避坑锦囊 & 进阶路线
- 标签泄漏:别在生成图像时把真实标签也喂给生成器,否则会引发“偷懒”现象。
- 模式崩塌:若某标签只能生成单一纹理,可引入 DRAGAN 梯度惩罚 或 Unrolled GAN。
- 迁移到复杂场景:在 CIFAR-10、STL-10 上训练时,把判别器改为 深度卷积网络 (DCGAN),同时将标签做 one-hot → Embeddings → 条件特征图,可与噪声一起注入每个卷积块。