by Qiaole Dong*, Chenjie Cao*, Yanwei Fu
Paper and Supplemental Material (arXiv)
Our project page is available at https://dqiaole.github.io/ZITS_inpainting/.
🔥🔥🔥 News: Our extended version, ZITS++, has been accepted by TPAMI; its code and dataset have been released here.
Overview of ZITS. First, the TSR model restores structures at low resolution. Then a simple CNN-based upsampler (SSU) upsamples the edge and line maps. Finally, the upsampled sketch space is encoded and added to the FTR through ZeroRA to restore the textures.
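ZeroRA (zero-initialized residual addition) is understood here as adding the encoded sketch features to the FTR through a learnable gate that starts at zero, so the pretrained texture model is unchanged at the start of training. The following is a minimal pure-Python sketch of that idea. It assumes features are flat lists of floats and the gate is a single scalar; `ZeroRASketch`, `alpha` and `forward` are illustrative names, not this repository's API, and the real model works on PyTorch tensors.

```python
from typing import List


class ZeroRASketch:
    """Illustrative zero-initialized residual addition (not the repository's code).

    out = texture + alpha * structure, with alpha starting at 0.0 so the
    texture branch's output is initially identical to its input.
    """

    def __init__(self) -> None:
        # Learnable gate in the real model; a plain float in this sketch.
        self.alpha = 0.0

    def forward(self, texture: List[float], structure: List[float]) -> List[float]:
        if len(texture) != len(structure):
            raise ValueError("feature lengths must match")
        # Element-wise residual addition scaled by the gate.
        return [t + self.alpha * s for t, s in zip(texture, structure)]


if __name__ == "__main__":
    layer = ZeroRASketch()
    feats = [0.5, -1.0, 2.0]
    sketch = [1.0, 1.0, 1.0]
    print(layer.forward(feats, sketch))  # [0.5, -1.0, 2.0] while alpha == 0
    layer.alpha = 0.1  # pretend training has moved the gate away from zero
    print(layer.forward(feats, sketch))
```

Because the gate starts at zero, the structure branch can be attached to an already trained texture model without changing its initial predictions, which is what makes incremental training possible.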
- Releasing inference codes.
- Releasing pre-trained model.
- Releasing training codes.
1. Preparing the environment:
Because there are bugs when using the GP loss with DDP (link), we strongly recommend installing Apex without CUDA extensions, on torch 1.9.0, for multi-GPU training:
conda create -n train_env python=3.6
conda activate train_env
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirement.txt
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --no-build-isolation ./
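After installation, a quick check can catch a wrong torch build or a failed Apex install before a long multi-GPU run. This is a minimal sketch using only the standard library; the helper names are hypothetical, and it avoids syntax newer than Python 3.6 so it runs inside train_env.

```python
import importlib.util
from typing import List, Sequence

RECOMMENDED_TORCH = "1.9.0"  # version recommended above for multi-GPU training


def is_recommended_torch(version: str) -> bool:
    """Return True if a torch version string such as '1.9.0+cu111' is 1.9.0."""
    # Drop the local build tag (e.g. '+cu111') before comparing.
    return version.split("+", 1)[0] == RECOMMENDED_TORCH


def missing_modules(names: Sequence[str] = ("torch", "torchvision", "apex")) -> List[str]:
    """List the top-level modules that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("missing:", ", ".join(missing))
    else:
        import torch

        print("torch", torch.__version__,
              "recommended:", is_recommended_torch(torch.__version__))
```

Save it under any file name and run it with `python` after `conda activate train_env`.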
2. For training, MST provides irregular and segmentation masks (download) with different masking rates. You should define the mask file list before training, as in MST.
The training masks we used are listed in coco_mask_list.txt and irregular_mask_list.txt, and test_mask.zip contains 1000 test masks.
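If you build your own mask file list instead of using coco_mask_list.txt and irregular_mask_list.txt, a script like the following can generate one. It is a sketch under an assumption: it writes one absolute mask path per line, which may not match MST's exact list format, so compare its output with the provided lists before training. The script name `make_mask_list.py` and the accepted extensions are hypothetical.

```python
import sys
from pathlib import Path
from typing import Iterable, List

# Assumed mask image extensions; adjust to the files you actually downloaded.
MASK_SUFFIXES = (".png", ".jpg", ".jpeg", ".bmp")


def select_mask_paths(paths: Iterable[str], suffixes: Iterable[str] = MASK_SUFFIXES) -> List[str]:
    """Keep image-like paths (case-insensitive suffix match), sorted."""
    allowed = {s.lower() for s in suffixes}
    return sorted(p for p in paths if Path(p).suffix.lower() in allowed)


def write_mask_list(mask_dir: str, out_file: str) -> int:
    """Write one absolute mask path per line (assumed format); return the count."""
    candidates = [str(p.resolve()) for p in Path(mask_dir).rglob("*") if p.is_file()]
    selected = select_mask_paths(candidates)
    Path(out_file).write_text("".join(path + "\n" for path in selected))
    return len(selected)


if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python make_mask_list.py <mask_dir> <out_list.txt>
    print(write_mask_list(sys.argv[1], sys.argv[2]), "masks written")
```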
3. Download the pretrained masked wireframe detection model, LSM-HAWP (retrained by MST, ICCV 2021, from HAWP, CVPR 2020), to the './ckpt' folder.
4. Prepare the wireframes:
Update: A separate environment is no longer needed; just extract the wireframes with the following commands:
conda activate train_env
python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path> --gpu_ids '0'
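When wireframes are needed for several folders (for example, training and validation images), the extraction command above can be scripted. This minimal sketch reuses only the flags shown; the helper names are hypothetical. It uses `shlex.quote` rather than `shlex.join`, which only exists from Python 3.8, because train_env runs Python 3.6.

```python
import shlex
import subprocess
from typing import List


def lsm_hawp_command(ckpt_path: str, input_path: str, output_path: str,
                     gpu_ids: str = "0") -> List[str]:
    """Build the wireframe-extraction command above as an argv list."""
    return [
        "python", "lsm_hawp_inference.py",
        "--ckpt_path", ckpt_path,
        "--input_path", input_path,
        "--output_path", output_path,
        "--gpu_ids", gpu_ids,
    ]


def run(argv: List[str], dry_run: bool = True) -> str:
    """Print the shell-quoted command and, unless dry_run, execute it."""
    printable = " ".join(shlex.quote(arg) for arg in argv)
    print(printable)
    if not dry_run:
        # check=True raises CalledProcessError if the extraction fails.
        subprocess.run(argv, check=True)
    return printable
```

For example, looping over hypothetical pairs such as `("train_imgs", "train_wf")` and `("val_imgs", "val_wf")` and calling `run(lsm_hawp_command("./ckpt/best_lsm_hawp.pth", src, dst), dry_run=False)` from the repository root processes each folder in turn; `dry_run` defaults to True so nothing runs until you opt in.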
5. If you need to train the model, please download the pretrained models for the perceptual loss, provided by LaMa:
mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/
wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth
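The same download can be done from Python, for example on machines without wget. This sketch uses only the standard library; the URL and directory are copied from the commands above, while the function names are hypothetical. Unlike `wget -P`, which saves a numbered copy such as `encoder_epoch_20.pth.1` when the file already exists, this version skips an existing file.

```python
import os
import urllib.request
from urllib.parse import urlparse

# URL and target directory copied from the commands above.
URL = ("http://sceneparsing.csail.mit.edu/model/pytorch/"
       "ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth")
DEST_DIR = "ade20k/ade20k-resnet50dilated-ppm_deepsup/"


def target_path(url: str, dest_dir: str) -> str:
    """Where the file lands: dest_dir plus the last segment of the URL path."""
    return os.path.join(dest_dir, os.path.basename(urlparse(url).path))


def download(url: str = URL, dest_dir: str = DEST_DIR) -> str:
    """Roughly `mkdir -p` + `wget -P`, but an existing file is kept as-is."""
    os.makedirs(dest_dir, exist_ok=True)
    path = target_path(url, dest_dir)
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path
```

Call `download()` from the repository root so the relative directory matches the one created by the shell commands.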
6. Indoor Dataset and Test Set of Places2 (Optional)
To download the full Indoor dataset: BaiduDrive (password: hfok); Google Drive (link).
The training and validation splits of Indoor can be found in indoor_train_list.txt and indoor_val_list.txt.
The test set of our Places2 can be found in places2_test_list.txt.
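A small loader for these split files can also confirm that no image appears in both the training and validation splits. This sketch assumes each list holds one image entry per line, possibly with blank lines; the function names are hypothetical.

```python
from typing import List, Set


def parse_split_list(text: str) -> List[str]:
    """Parse a split list, assumed to hold one image entry per line."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def load_split(path: str) -> List[str]:
    """Read and parse a split file such as indoor_train_list.txt."""
    with open(path, encoding="utf-8") as f:
        return parse_split_list(f.read())


def check_disjoint(train: List[str], val: List[str]) -> Set[str]:
    """Return entries present in both splits; an empty set means no leakage."""
    return set(train) & set(val)
```

For instance, `check_disjoint(load_split("indoor_train_list.txt"), load_split("indoor_val_list.txt"))` is expected to return an empty set.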
Download the models pretrained on Places2 here.
Link for BaiduDrive, password: qnm5
For batch testing, you need to complete steps 3 and 4 above (downloading LSM-HAWP and extracting the wireframes).
Put the pretrained models in the './ckpt' folder. Then modify the config file so that it points to your image, mask, and wireframe paths.
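If you prefer to script this edit, the following sketch rewrites top-level `key: value` lines using only the standard library. The key names are up to you: any key in the usage note is a hypothetical placeholder, so substitute the keys that actually appear in the config file you are editing. Rewriting lines in place, rather than loading and re-dumping the YAML, keeps comments and key order intact.

```python
import re
from typing import Dict


def set_top_level_keys(yaml_text: str, updates: Dict[str, str]) -> str:
    """Replace the values of unindented `key: value` lines in a simple YAML text.

    Only top-level scalar lines are touched; indented (nested) lines and
    comments are left alone. Raises KeyError if a requested key is missing.
    """
    lines = yaml_text.splitlines(keepends=True)
    found = set()
    for i, line in enumerate(lines):
        match = re.match(r"^([A-Za-z0-9_]+):[ \t]*", line)
        if match and match.group(1) in updates:
            key = match.group(1)
            newline = "\n" if line.endswith("\n") else ""
            lines[i] = "{}: {}{}".format(key, updates[key], newline)
            found.add(key)
    missing = set(updates) - found
    if missing:
        raise KeyError("keys not found in config: {}".format(sorted(missing)))
    return "".join(lines)
```

Usage: read the config text, call `set_top_level_keys(text, {"YOUR_IMAGE_PATH_KEY": "/data/test_images"})` with your real keys, and write the result back to the same file.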
Test on images of resolution 256:
conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2 --config_file ./config_list/config_ZITS_places2.yml --GPU_ids '0'
Test on images of resolution 512:
conda activate train_env
python FTR_inference.py --path ./ckpt/zits_places2_hr --config_file ./config_list/config_ZITS_HR_places2.yml --GPU_ids '0'
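The two test commands differ only in the checkpoint folder and the config file, so the choice can be driven by the target resolution. A minimal sketch; the function name is hypothetical, and the returned list can be passed to `subprocess.run(argv, check=True)` from the repository root.

```python
from typing import List

# (checkpoint folder, config file) pairs copied from the two commands above.
SETTINGS = {
    256: ("./ckpt/zits_places2", "./config_list/config_ZITS_places2.yml"),
    512: ("./ckpt/zits_places2_hr", "./config_list/config_ZITS_HR_places2.yml"),
}


def ftr_inference_argv(resolution: int, gpu_ids: str = "0") -> List[str]:
    """Build the FTR_inference.py argv for a supported test resolution."""
    if resolution not in SETTINGS:
        raise ValueError("unsupported resolution {}; expected one of {}".format(
            resolution, sorted(SETTINGS)))
    ckpt, config = SETTINGS[resolution]
    return ["python", "FTR_inference.py", "--path", ckpt,
            "--config_file", config, "--GPU_ids", gpu_ids]
```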
Single-image test: this code only supports square images (other images will be center-cropped).
conda activate train_env
python single_image_test.py --path <ckpt_path> --config_file <config_path> \
--GPU_ids '0' --img_path ./image.png --mask_path ./mask.png --save_path ./
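For a non-square input, a center crop keeps the largest centered square. This minimal sketch computes that crop box, assuming integer pixel sizes and floor rounding of the offsets (the repository's own code may round differently); the function name is hypothetical.

```python
from typing import Tuple


def center_square_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side
```

With Pillow, `img.crop(center_square_box(*img.size))` applies it, since `Image.size` is `(width, height)` and `Image.crop` takes `(left, upper, right, lower)`. For a 640×480 image the box is (80, 0, 560, 480).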
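Training TSR (run the two commands below in order; the second sets --train_epoch 15 and adds --MaP):
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \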
--train_line_path [training_wireframes_path] \
--mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
--train_epoch 12 --validation_path [validation_data_path] \
--val_line_path [validation_wireframes_path] \
--valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP
python TSR_train.py --name places2_continous_edgeline --data_path [training_data_path] \
--train_line_path [training_wireframes_path] \
--mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] \
--train_epoch 15 --validation_path [validation_data_path] \
--val_line_path [validation_wireframes_path] \
--valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP --MaP
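The two TSR commands differ only in `--train_epoch` (12, then 15) and the extra `--MaP` flag, and they share the same `--name`. This sketch generates and runs them in order. The bracketed placeholders are copied from the commands above and must be replaced before running, and the helper names are hypothetical. Whether the second run resumes from the first run's checkpoint is decided by TSR_train.py, not by this script.

```python
import subprocess
from typing import List

# Placeholders exactly as in the commands above; replace them with your paths.
BASE = (
    "python TSR_train.py --name places2_continous_edgeline "
    "--data_path [training_data_path] "
    "--train_line_path [training_wireframes_path] "
    "--mask_path ['irregular_mask_list.txt', 'coco_mask_list.txt'] "
    "--train_epoch {epochs} --validation_path [validation_data_path] "
    "--val_line_path [validation_wireframes_path] "
    "--valid_mask_path [validation_mask] --nodes 1 --gpus 1 --GPU_ids '0' --AMP"
)


def tsr_stages() -> List[str]:
    """Return the two TSR commands: 12 epochs, then 15 epochs with --MaP."""
    return [BASE.format(epochs=12), BASE.format(epochs=15) + " --MaP"]


def run_stages(commands: List[str], dry_run: bool = True) -> None:
    """Print each command and, unless dry_run, run it, stopping on failure."""
    for cmd in commands:
        print(cmd)
        if not dry_run:
            # shell=True because the commands rely on shell quoting, as written above.
            subprocess.run(cmd, shell=True, check=True)
```

`run_stages(tsr_stages())` only prints the commands; pass `dry_run=False` once the placeholders are filled in.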
Training SSU: We recommend using the pretrained SSU. You can also train your own SSU following https://github.com/ewrfcas/StructureUpsampling.
Train LaMa first:
python FTR_train.py --nodes 1 --gpus 1 --GPU_ids '0' --path ./ckpt/lama_places2 \
--config_file ./config_list/config_LAMA.yml --lama
Training FTR at 256:
python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2 \
--config_file ./config_list/config_ZITS_places2.yml --DDP
Training FTR at 256~512:
python FTR_train.py --nodes 1 --gpus 2 --GPU_ids '0,1' --path ./ckpt/places2_HR \
--config_file ./config_list/config_ZITS_HR_places2.yml --DDP
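The LaMa and FTR training commands above can be chained so that a failed stage stops the later ones. In this minimal sketch, the argument lists are copied from the commands above in the order they appear, and the helper names are hypothetical. The sketch does not know whether one stage actually depends on another's checkpoint; that is determined by the config files.

```python
import subprocess
from typing import List, Sequence

# Commands copied from the steps above, in the order they are listed.
FTR_PIPELINE: List[List[str]] = [
    ["python", "FTR_train.py", "--nodes", "1", "--gpus", "1", "--GPU_ids", "0",
     "--path", "./ckpt/lama_places2",
     "--config_file", "./config_list/config_LAMA.yml", "--lama"],
    ["python", "FTR_train.py", "--nodes", "1", "--gpus", "2", "--GPU_ids", "0,1",
     "--path", "./ckpt/places2",
     "--config_file", "./config_list/config_ZITS_places2.yml", "--DDP"],
    ["python", "FTR_train.py", "--nodes", "1", "--gpus", "2", "--GPU_ids", "0,1",
     "--path", "./ckpt/places2_HR",
     "--config_file", "./config_list/config_ZITS_HR_places2.yml", "--DDP"],
]


def run_pipeline(stages: Sequence[Sequence[str]], dry_run: bool = True) -> int:
    """Print (and unless dry_run, run) each stage in order; return the stage count."""
    for count, argv in enumerate(stages, start=1):
        print("[stage {}/{}]".format(count, len(stages)), " ".join(argv))
        if not dry_run:
            # check=True raises CalledProcessError, so later stages never start.
            subprocess.run(list(argv), check=True)
    return len(stages)
```

`run_pipeline(FTR_PIPELINE)` only prints the commands; `run_pipeline(FTR_PIPELINE, dry_run=False)` starts training from the repository root.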
If you find our work helpful, please consider citing:
@inproceedings{dong2022incremental,
title={Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding},
author={Qiaole Dong and Chenjie Cao and Yanwei Fu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}