English | 简体中文
本项目是基于PaddlePaddle复现论文**《Conditional Image Synthesis with Auxiliary Classifier GANs》**(ACGAN), 该论文的主要工作是向条件式生成对抗网络(Conditional GAN)中加入辅助判别器来指导图像生成过程,具体的做法是在模型的判别器中加入分类层来强迫生成的图像类别与输入的标签尽可能接近。实验证明,ACGAN在合成高分辨的图像时表现良好。
论文
- [1] Odena, A. , C. Olah , and J. Shlens . "Conditional Image Synthesis With Auxiliary Classifier GANs." (2016).
参考项目
由于作者并未开源代码,所以本项目参考了以下非官方实现:
在线运行
- Ai Studio 脚本项目:ACGAN-Paddle
本次复现未涉及指标测评,主要目标是生成图像能够在肉眼评估上与真实的样本接近,故以下展示了随机生成的样本和真实样本:
生成样本 | 真实样本 |
---|---|
论文中的数据集是ImageNet, 数据集的组织格式如下:
- 训练集:1279591张图像
- 验证集:50000张图像
- 测试集:10000张图像
按照论文中的设置,将1000个图像类别分组,每10个类别一组用来训练一个模型。本次复现共进行三组不同实验:
- 图像类别序号为10-20共10000张图像作为训练集
- 图像类别序号为100-100共10000张图像作为训练集
- 随机挑选10个类别共10000张图像作为训练集
- 硬件:GPU、CPU
- 框架:PaddlePaddle>=2.0.0
git https://github.com/Callifrey/ACGAN-Paddle.git
cd ACGAN-Paddle
python trian.py --dataroot [imagenet path] # [eg:xxx/ImageNet/train]
python test.py --check_path [checkpoints path] --which_epoch [epoch]
visuldl --logdir ./log
预训练模型见百度网盘链接( 提取码: ce8r )其中每个文件夹内有三个文件,分别是生成器模型参数、判别器模型参数以及该组实验对应的log, 请将预训练模型置于checkpoints目录下,测试时设置对应的文件夹路径。
├─checkpoints # 保存模型
├─imgs # 保存各类图像
├─log # 保存入职文件
├─results # 保存生成结果
│ README.md # 英文readme
│ README_cn.md # 中文readme
│ dataset.py # 数据集类
│ network.py # 模型结构
│ train.py # 训练
│ test.py # 测试
│ utils.py # 部分工具类
-
train.py 参数说明(部分)
参数 默认值 说明 --dataroot str: ‘/media/gallifrey/DJW/Dataset/Imagenet/train’ 训练集路径 --workers int : 4 数据加载子进程数量 --batchSize int: 100 开始训练的断点 --imageSize int: 128 读取/生成图像尺寸 --nz int: 110 随机噪声维度 --ngf int: 64 生成器通道数基数 --ndf int: 5 判别器通道数基数 --lr float: 0.0002 初始学习率 --beta1 float: 0.5 优化器参数 --check_path str: './checkpoints' 模型保存路径 --result_path str:'./result' 结果保存路径 --log_path str: './log' 日志保存路径 --save_freq int: 5 每隔几个epoch保存一次模型 --num_classes int: 10 图像类别 --niter int: 500 训练的epoch -
test.py 参数说明(部分)
参数 默认值 说明 --batchSize int: 100 测试时的样本数量 --nz int: 110 随机噪声维度 --check_path str: './checkpoints' 模型保存路径 --imageSize int: 128 读取/生成图像尺寸 --result_path str:'./result' 结果保存路径 --num_classes int: 10 图像类别 --which_epoch int: 499 测试模型序号
Accuracy | D Loss | G Loss |
---|---|---|
-
生成的图像与真实图像(类别序号 10-20)
生成的假样本 参考实现生成的假样本 真实样本 -
更多类别结果对比
类别 假样本1 假样本2 假样本3 真实样本 100-110序号类别 随机10类别
关于模型的其他信息,可以参考下表:
信息 | 说明 |
---|---|
发布者 | 戴家武 |
时间 | 2021.09 |
框架版本 | Paddle 2.0.2 |
应用场景 | 图像生成 |
支持硬件 | GPU、CPU |
下载链接 | 预训练模型 (提取码:ce8r) |
在线运行 | 脚本任务 |