Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CI for Super Resolution example and tqdm bar to the example #2899

Merged
merged 38 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
a7829a9
Add the example for Super-Resolution
guptaaryan16 Mar 3, 2023
74602d4
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 3, 2023
1b0baf3
Made some changes
guptaaryan16 Mar 3, 2023
7ebee49
Made some changes
guptaaryan16 Mar 6, 2023
f6b5b41
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 14, 2023
d810510
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 15, 2023
3982d7b
Add the time profiling features
guptaaryan16 Mar 15, 2023
bc219c7
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 17, 2023
982a0eb
Added torchvision dataset
guptaaryan16 Mar 17, 2023
51fe3df
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 17, 2023
0cd5c59
Changed the dataset used in README to cifar10
guptaaryan16 Mar 17, 2023
83f10e2
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 20, 2023
7bcea2f
Used snake case in arguments
guptaaryan16 Mar 20, 2023
698d76f
Made some changes
guptaaryan16 Mar 20, 2023
51f47b4
Make some formatting changes
guptaaryan16 Mar 20, 2023
235c908
Make the formatting changes
guptaaryan16 Mar 20, 2023
3b2fde9
some changes
guptaaryan16 Mar 20, 2023
0e2f9a3
update the crop method
guptaaryan16 Mar 21, 2023
3d9dda7
Made the suggested changes
guptaaryan16 Mar 21, 2023
a91912b
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 22, 2023
689b7e4
Add SR example to unit tests
guptaaryan16 Mar 22, 2023
3303d86
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 22, 2023
fb3f64a
Add tqdm to the SR example and some CI changes
guptaaryan16 Mar 22, 2023
051999e
Update unit-tests.yml
guptaaryan16 Mar 22, 2023
e36beff
Update unit-tests.yml
guptaaryan16 Mar 22, 2023
87456cd
changed crop_size in SR example
guptaaryan16 Mar 22, 2023
780dbdb
Made crop_size a parameter in SR example
guptaaryan16 Mar 22, 2023
b69c914
Add debug mode in SR example
guptaaryan16 Mar 24, 2023
4b1d337
Added Cifar image example
guptaaryan16 Mar 24, 2023
93766f7
autopep8 fix
guptaaryan16 Mar 24, 2023
93d1584
Some reformatting of files
guptaaryan16 Mar 24, 2023
8541b2c
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 24, 2023
655e569
Added Basic Profile Handler in SR example
guptaaryan16 Mar 27, 2023
2ce8749
made some changes
guptaaryan16 Mar 27, 2023
52b3043
Merge branch 'pytorch:master' into master
guptaaryan16 Mar 27, 2023
b53b150
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Mar 27, 2023
9f81e33
Update README
guptaaryan16 Mar 28, 2023
5ccd25d
Update README.md
guptaaryan16 Mar 29, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,8 @@ jobs:
#train
mkdir -p ~/.cache/torch/checkpoints/ && wget "https://download.pytorch.org/models/vgg16-397923af.pth" -O ~/.cache/torch/checkpoints/vgg16-397923af.pth
python examples/fast_neural_style/neural_style.py train --epochs 1 --cuda 0 --dataset test --dataroot . --image_size 32 --style_image examples/fast_neural_style/images/style_images/mosaic.jpg --style_size 32
- name: Run SR Example
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
# Super-Resolution
python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2
5 changes: 3 additions & 2 deletions examples/super_resolution/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/sup
This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution.

```
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE]
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--crop_size CROPSIZE] [--batch_size BATCHSIZE]
[--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR]
[--cuda] [--threads THREADS] [--seed SEED]

Expand All @@ -14,6 +14,7 @@ PyTorch Super Res Example
optional arguments:
-h, --help show this help message and exit
--upscale_factor super resolution upscale factor
--crop_size cropped size of the images for training
--batch_size training batch size
--test_batch_size testing batch size
--n_epochs number of epochs to train for
Expand All @@ -30,7 +31,7 @@ This example trains a super-resolution network on the [Caltech101 dataset](https

### Train

`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`
`python main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`

### Super Resolve
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
10 changes: 8 additions & 2 deletions examples/super_resolution/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop, resize, to_tensor

from ignite.contrib.handlers import ProgressBar

from ignite.engine import Engine, Events
from ignite.metrics import PSNR

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--crop_size", type=int, default=256, help="cropped size of the images for training")
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
Expand All @@ -22,6 +25,7 @@
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")

opt = parser.parse_args()

print(opt)
Expand Down Expand Up @@ -70,8 +74,8 @@ def __len__(self):
trainset = torchvision.datasets.Caltech101(root="./data", download=True)
testset = torchvision.datasets.Caltech101(root="./data", download=False)

trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=opt.crop_size)

training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)
Expand Down Expand Up @@ -145,4 +149,6 @@ def checkpoint():
print("Checkpoint saved to {}".format(model_out_path))


ProgressBar().attach(trainer)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove log_training_loss handler and update ProgressBar options to display similar things.


vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
trainer.run(training_data_loader, opt.n_epochs)