GAN for image colorization based on Johnson's network structure.
Install the following Python libraries:
- numpy
- scipy
- PyTorch
- scikit-image
- Pillow
- opencv-python
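A quick way to confirm the libraries above are installed is to check that their import modules can be located. The sketch below is not part of the repository; the helper name `find_missing` is made up for illustration, and the import names (`torch`, `skimage`, `PIL`, `cv2`) are the usual module names for these packages.

```python
import importlib.util

# Import names differ from the pip package names for some of these libraries.
REQUIRED_MODULES = {
    "numpy": "numpy",
    "scipy": "scipy",
    "PyTorch": "torch",
    "scikit-image": "skimage",
    "Pillow": "PIL",
    "opencv-python": "cv2",
}


def find_missing(modules=REQUIRED_MODULES):
    """Return the library names whose import module cannot be located."""
    return [name for name, module in modules.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = find_missing()
    print("Missing libraries:", ", ".join(missing) if missing else "none")
```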
# Download the pre-trained model
wget -O model.pth "https://github.com/zeruniverse/neural-colorization/releases/download/1.1/G.pth"
# Colorize an image with CPU
python colorize.py -m model.pth -i input.jpg -o output.jpg --gpu -1
# If you want to colorize all images in a folder with GPU
python colorize.py -m model.pth -i input -o output --gpu 0
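If you prefer to drive colorize.py from another Python script, the commands above can be wrapped with `subprocess`. This is a minimal sketch: the helper names `colorize_command` and `colorize` are hypothetical, and only the flags shown above (`-m`, `-i`, `-o`, `--gpu`) are used.

```python
import subprocess


def colorize_command(model, source, target, gpu=-1):
    """Build the argument list for colorize.py using the flags shown above."""
    return ["python", "colorize.py", "-m", model,
            "-i", source, "-o", target, "--gpu", str(gpu)]


def colorize(model, source, target, gpu=-1):
    """Run colorize.py from the repository directory; gpu=-1 selects the CPU."""
    # check=True raises CalledProcessError if colorize.py exits with an error.
    subprocess.run(colorize_command(model, source, target, gpu), check=True)
```

For example, `colorize("model.pth", "input", "output", gpu=0)` corresponds to the folder command above.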
Note: Training is only supported with GPU (CUDA).
- Download some datasets and unzip them into the same folder (say, train_raw_dataset). If the images are not in .jpg format, convert them all to .jpg.
- Run python build_dataset_directory.py -i train_raw_dataset -o train (you can skip this if all your images are directly under train_raw_dataset, in which case just rename the folder to train).
- Run python resize_all_imgs.py -d train to resize all training images to 256*256 (you can skip this if your images are already 256*256).
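The repository does not appear to ship a script for the .jpg conversion in the first step, so here is a rough sketch of one way to do it with Pillow. The helper names `find_non_jpg` and `convert_to_jpg` are hypothetical, and treating only the .jpg extension as acceptable is an assumption based on the wording above.

```python
import os


def find_non_jpg(root):
    """Return sorted paths of files under root that lack a .jpg extension."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            # Assumption: only the .jpg extension (any letter case) is accepted.
            if os.path.splitext(name)[1].lower() != ".jpg":
                found.append(os.path.join(dirpath, name))
    return sorted(found)


def convert_to_jpg(path):
    """Re-encode one image as JPEG next to the original and return the new path."""
    from PIL import Image  # imported lazily so the scan above works without Pillow

    target = os.path.splitext(path)[0] + ".jpg"
    with Image.open(path) as img:
        # JPEG has no alpha channel, so convert to RGB before saving.
        img.convert("RGB").save(target, "JPEG")
    return target
```

Calling `convert_to_jpg` on each path returned by `find_non_jpg("train_raw_dataset")` leaves the original files in place, so remove them afterwards. Files that are not images make `Image.open` raise an error.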
It's highly recommended to train from my pretrained models. You can get both generator model and discriminator model from the GitHub Release:
wget "https://github.com/zeruniverse/neural-colorization/releases/download/1.1/G.pth"
wget "https://github.com/zeruniverse/neural-colorization/releases/download/1.1/D.pth"
It's also recommended to have a test image (the script will generate a colorization for the test image you give at every checkpoint so you can see how the model works during training).
The required arguments are the training image directory (e.g. train) and the path to save checkpoints (e.g. checkpoints):
python train.py -d train -c checkpoints
To add initial weights and test images:
python train.py -d train -c checkpoints --d_init D.pth --g_init G.pth -t test.jpg
More options are available; run python train.py --help to print them all.
For the torch equivalent (no GAN), set the option -p 1e9 (a very large weight for pixel loss).
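To see why a very large pixel-loss weight effectively turns the GAN off, assume the generator objective is a weighted sum of an adversarial term and a pixel term. That form is an assumption made for illustration; train.py may combine its losses differently, and the function names below are hypothetical.

```python
def generator_loss(adversarial_loss, pixel_loss, pixel_weight):
    """Hypothetical combined objective: adversarial term plus weighted pixel term."""
    return adversarial_loss + pixel_weight * pixel_loss


def adversarial_share(adversarial_loss, pixel_loss, pixel_weight):
    """Fraction of the combined loss contributed by the adversarial term."""
    total = generator_loss(adversarial_loss, pixel_loss, pixel_weight)
    return adversarial_loss / total


if __name__ == "__main__":
    # Illustrative loss values only, not measurements from the real model.
    for weight in (1.0, 100.0, 1e9):
        print(weight, adversarial_share(0.7, 0.05, weight))
```

Under this assumption, with a weight of 1e9 the adversarial share becomes negligible, so the gradients are driven almost entirely by the pixel loss.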
Perceptual Losses for Real-Time Style Transfer and Super-Resolution
GNU GPL 3.0 for personal or research use. COMMERCIAL USE PROHIBITED.
Model weights are released under CC BY 4.0