From 28e02e42457530d916bd7907ada158bad05805c9 Mon Sep 17 00:00:00 2001 From: ryanxingql <34084019+ryanxingql@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:19:08 +0800 Subject: [PATCH] polish: update pre-commit hooks; add: csv and markdown logging --- .github/workflows/sync_to_private.yml | 2 +- .gitignore | 2 +- .pre-commit-config.yaml | 34 +- basicsr | 2 +- powerqe/archs/__init__.py | 27 +- powerqe/archs/arcnn_arch.py | 4 +- powerqe/archs/cbdnet_arch.py | 48 ++- powerqe/archs/dncnn_arch.py | 4 +- powerqe/archs/identitynet_arch.py | 2 +- powerqe/archs/mprnet_arch.py | 146 +++----- powerqe/archs/rbqe_arch.py | 329 +++++++----------- powerqe/archs/rdn_arch.py | 55 +-- powerqe/archs/unet_arch.py | 94 +++-- powerqe/data/__init__.py | 62 ++-- powerqe/data/registry.py | 1 - powerqe/losses/__init__.py | 8 +- powerqe/metrics/__init__.py | 4 +- powerqe/models/__init__.py | 7 +- powerqe/models/qe_model.py | 1 - powerqe/models/registry.py | 1 - powerqe/test.py | 27 +- powerqe/train.py | 202 ++++------- requirements.txt | 4 +- scripts/data_preparation/compress_img.py | 144 ++++---- scripts/data_preparation/create_lmdb.py | 13 +- scripts/data_preparation/extract_subimages.py | 50 +-- scripts/test.sh | 2 +- scripts/train.sh | 2 +- setup.cfg | 25 ++ 29 files changed, 556 insertions(+), 746 deletions(-) create mode 100644 setup.cfg diff --git a/.github/workflows/sync_to_private.yml b/.github/workflows/sync_to_private.yml index 41f7d36..6f969f1 100644 --- a/.github/workflows/sync_to_private.yml +++ b/.github/workflows/sync_to_private.yml @@ -25,4 +25,4 @@ jobs: - name: Sync to private repository run: | git remote add private git@github.com:ryanxingql/${{ secrets.PRIVATE_REPO_NAME }}.git - git push private basicsr-based-dev --force \ No newline at end of file + git push private basicsr-based-dev --force diff --git a/.gitignore b/.gitignore index 7d767ec..1b97180 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,4 @@ cython_debug/ # be found at 
https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -.idea/ \ No newline at end of file +.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a394eb8..3057e31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,37 @@ repos: - - repo: https://github.com/psf/black - rev: 24.8.0 + # flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 7.1.1 hooks: - - id: black + - id: flake8 + + # yapf + - repo: https://github.com/google/yapf + rev: v0.40.2 + hooks: + - id: yapf + + # isort + - repo: https://github.com/timothycrosley/isort + rev: 5.13.2 + hooks: + - id: isort + + # codespell - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: - id: codespell + + # pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace # Trim trailing whitespace + - id: check-yaml # Attempt to load all yaml files to verify syntax + - id: check-merge-conflict # Check for files that contain merge conflict strings + - id: double-quote-string-fixer # Replace double-quoted strings with single quoted strings + - id: end-of-file-fixer # Make sure files end in a newline and only a newline + - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 + - id: mixed-line-ending # Replace or check mixed line ending + args: ["--fix=lf"] diff --git a/basicsr b/basicsr index c680165..17b9a1a 160000 --- a/basicsr +++ b/basicsr @@ -1 +1 @@ -Subproject commit c68016567bdb9b45293e28f53930503a3735a4e2 +Subproject commit 17b9a1a39d8001dc7b8d90598edb6ac66f1a9103 diff --git a/powerqe/archs/__init__.py b/powerqe/archs/__init__.py index aef0280..b100ddf 100644 --- a/powerqe/archs/__init__.py +++ b/powerqe/archs/__init__.py @@ -1,7 
+1,6 @@ from copy import deepcopy from basicsr.utils import get_root_logger - from .arcnn_arch import ARCNN from .cbdnet_arch import CBDNet from .dcad_arch import DCAD @@ -14,24 +13,24 @@ from .unet_arch import UNet __all__ = [ - "ARCNN", - "CBDNet", - "DCAD", - "DnCNN", - "IdentityNet", - "MPRNet", - "RBQE", - "RDN", - "build_network", - "ARCH_REGISTRY", - "UNet", + 'ARCNN', + 'CBDNet', + 'DCAD', + 'DnCNN', + 'IdentityNet', + 'MPRNet', + 'RBQE', + 'RDN', + 'build_network', + 'ARCH_REGISTRY', + 'UNet', ] def build_network(opt): opt = deepcopy(opt) - network_type = opt.pop("type") + network_type = opt.pop('type') net = ARCH_REGISTRY.get(network_type)(**opt) logger = get_root_logger() - logger.info(f"Network [{net.__class__.__name__}] is created.") + logger.info(f'Network [{net.__class__.__name__}] is created.') return net diff --git a/powerqe/archs/arcnn_arch.py b/powerqe/archs/arcnn_arch.py index 20a1ea5..bd3ad0b 100644 --- a/powerqe/archs/arcnn_arch.py +++ b/powerqe/archs/arcnn_arch.py @@ -37,9 +37,7 @@ def __init__( super().__init__() self.layers = nn.Sequential( - nn.Conv2d( - io_channels, mid_channels_1, in_kernel_size, padding=in_kernel_size // 2 - ), + nn.Conv2d(io_channels, mid_channels_1, in_kernel_size, padding=in_kernel_size // 2), nn.ReLU(inplace=False), nn.Conv2d( mid_channels_1, diff --git a/powerqe/archs/cbdnet_arch.py b/powerqe/archs/cbdnet_arch.py index d327b78..d01ef14 100644 --- a/powerqe/archs/cbdnet_arch.py +++ b/powerqe/archs/cbdnet_arch.py @@ -1,8 +1,8 @@ import torch from torch import nn as nn -from .unet_arch import UNet from .registry import ARCH_REGISTRY +from .unet_arch import UNet @ARCH_REGISTRY.register() @@ -31,41 +31,35 @@ def __init__( nf_gr_denoise=2, nl_base_denoise=1, nl_gr_denoise=2, - down_denoise="avepool2d", - up_denoise="transpose2d", - reduce_denoise="add", + down_denoise='avepool2d', + up_denoise='transpose2d', + reduce_denoise='add', ): super().__init__() - estimate_list = nn.ModuleList( - [ + estimate_list = 
nn.ModuleList([ + nn.Conv2d( + in_channels=io_channels, + out_channels=estimate_channels, + kernel_size=3, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + ]) + for _ in range(3): + estimate_list += nn.ModuleList([ nn.Conv2d( - in_channels=io_channels, + in_channels=estimate_channels, out_channels=estimate_channels, kernel_size=3, padding=3 // 2, ), nn.ReLU(inplace=True), - ] - ) - for _ in range(3): - estimate_list += nn.ModuleList( - [ - nn.Conv2d( - in_channels=estimate_channels, - out_channels=estimate_channels, - kernel_size=3, - padding=3 // 2, - ), - nn.ReLU(inplace=True), - ] - ) - estimate_list += nn.ModuleList( - [ - nn.Conv2d(estimate_channels, io_channels, 3, padding=3 // 2), - nn.ReLU(inplace=True), - ] - ) + ]) + estimate_list += nn.ModuleList([ + nn.Conv2d(estimate_channels, io_channels, 3, padding=3 // 2), + nn.ReLU(inplace=True), + ]) self.estimate = nn.Sequential(*estimate_list) self.denoise = UNet( diff --git a/powerqe/archs/dncnn_arch.py b/powerqe/archs/dncnn_arch.py index 0030e8e..5e10e25 100644 --- a/powerqe/archs/dncnn_arch.py +++ b/powerqe/archs/dncnn_arch.py @@ -32,9 +32,7 @@ def __init__(self, io_channels=3, mid_channels=64, num_blocks=15, if_bn=False): layers += [ # bias is unnecessary and off due to the following BN nn.Conv2d(mid_channels, mid_channels, 3, padding=1, bias=False), - nn.BatchNorm2d( - num_features=mid_channels, momentum=0.9, eps=1e-04, affine=True - ), + nn.BatchNorm2d(num_features=mid_channels, momentum=0.9, eps=1e-04, affine=True), ] else: layers.append(nn.Conv2d(mid_channels, mid_channels, 3, padding=1)) diff --git a/powerqe/archs/identitynet_arch.py b/powerqe/archs/identitynet_arch.py index 76e7bde..a0a5dec 100644 --- a/powerqe/archs/identitynet_arch.py +++ b/powerqe/archs/identitynet_arch.py @@ -8,7 +8,7 @@ class IdentityNet(nn.Module): """Identity network used for testing benchmarks (in tensors). 
Support up-scaling.""" - def __init__(self, scale=1, upscale_mode="nearest"): + def __init__(self, scale=1, upscale_mode='nearest'): super().__init__() self.scale = scale self.upscale_mode = upscale_mode diff --git a/powerqe/archs/mprnet_arch.py b/powerqe/archs/mprnet_arch.py index 01d6c72..350da67 100644 --- a/powerqe/archs/mprnet_arch.py +++ b/powerqe/archs/mprnet_arch.py @@ -22,6 +22,7 @@ def conv(in_channels, out_channels, kernel_size, bias=False, stride=1): # Channel Attention Layer class CALayer(nn.Module): + def __init__(self, channel, reduction=16, bias=False): super().__init__() @@ -43,6 +44,7 @@ def forward(self, x): # Channel Attention Block (CAB) class CAB(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, bias, act): super().__init__() @@ -64,6 +66,7 @@ def forward(self, x): # Supervised Attention Module class SAM(nn.Module): + def __init__(self, n_feat, kernel_size, bias): super().__init__() @@ -105,23 +108,19 @@ def pad_and_add(x, y): x_pads[0] = (-diff) // 2 x_pads[1] = (-diff) - (-diff) // 2 - x = nn_func.pad(input=x, pad=x_pads, mode="constant", value=0) - y = nn_func.pad(input=y, pad=y_pads, mode="constant", value=0) + x = nn_func.pad(input=x, pad=x_pads, mode='constant', value=0) + y = nn_func.pad(input=y, pad=y_pads, mode='constant', value=0) return x + y class Encoder(nn.Module): - def __init__( - self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff - ): + + def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff): super().__init__() - self.encoder_level1 = [ - CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2) - ] + self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] self.encoder_level2 = [ - CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) - for _ in range(2) + CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2) ] self.encoder_level3 = [ CAB( @@ 
-130,8 +129,7 @@ def __init__( reduction, bias=bias, act=act, - ) - for _ in range(2) + ) for _ in range(2) ] self.encoder_level1 = nn.Sequential(*self.encoder_level1) @@ -174,17 +172,13 @@ def __init__( def forward(self, x, encoder_outs=None, decoder_outs=None): enc1 = self.encoder_level1(x) if (encoder_outs is not None) and (decoder_outs is not None): - enc1 = ( - enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) - ) + enc1 = (enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])) x = self.down12(enc1) enc2 = self.encoder_level2(x) if (encoder_outs is not None) and (decoder_outs is not None): - enc2 = ( - enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) - ) + enc2 = (enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])) x = self.down23(enc2) @@ -197,15 +191,13 @@ def forward(self, x, encoder_outs=None, decoder_outs=None): class Decoder(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): super().__init__() - self.decoder_level1 = [ - CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2) - ] + self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] self.decoder_level2 = [ - CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) - for _ in range(2) + CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2) ] self.decoder_level3 = [ CAB( @@ -214,8 +206,7 @@ def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): reduction, bias=bias, act=act, - ) - for _ in range(2) + ) for _ in range(2) ] self.decoder_level1 = nn.Sequential(*self.decoder_level1) @@ -223,9 +214,7 @@ def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): self.decoder_level3 = nn.Sequential(*self.decoder_level3) self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) - self.skip_attn2 = CAB( - 
n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act - ) + self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) self.up21 = SkipUpSample(n_feat, scale_unetfeats) self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats) @@ -245,14 +234,13 @@ def forward(self, outs): # Resizing Modules class DownSample(nn.Module): + def __init__(self, in_channels, s_factor): super().__init__() self.down = nn.Sequential( - nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False), - nn.Conv2d( - in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False - ), + nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), + nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False), ) def forward(self, x): @@ -261,14 +249,13 @@ def forward(self, x): class UpSample(nn.Module): + def __init__(self, in_channels, s_factor): super().__init__() self.up = nn.Sequential( - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), - nn.Conv2d( - in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False - ), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False), ) def forward(self, x): @@ -277,14 +264,13 @@ def forward(self, x): class SkipUpSample(nn.Module): + def __init__(self, in_channels, s_factor): super().__init__() self.up = nn.Sequential( - nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), - nn.Conv2d( - in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False - ), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False), ) def forward(self, x, y): @@ -294,13 +280,11 @@ def forward(self, x, y): # Original Resolution Block (ORB) class ORB(nn.Module): + def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 
super().__init__() - modules_body = [ - CAB(n_feat, kernel_size, reduction, bias=bias, act=act) - for _ in range(num_cab) - ] + modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)] modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) @@ -311,6 +295,7 @@ def forward(self, x): class ORSNet(nn.Module): + def __init__( self, n_feat, @@ -324,15 +309,9 @@ def __init__( ): super().__init__() - self.orb1 = ORB( - n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab - ) - self.orb2 = ORB( - n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab - ) - self.orb3 = ORB( - n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab - ) + self.orb1 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) + self.orb2 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) + self.orb3 = ORB(n_feat + scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) self.up_enc1 = UpSample(n_feat, scale_unetfeats) self.up_dec1 = UpSample(n_feat, scale_unetfeats) @@ -346,25 +325,13 @@ def __init__( UpSample(n_feat, scale_unetfeats), ) - self.conv_enc1 = nn.Conv2d( - n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias - ) - self.conv_enc2 = nn.Conv2d( - n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias - ) - self.conv_enc3 = nn.Conv2d( - n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias - ) + self.conv_enc1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_enc2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_enc3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) - self.conv_dec1 = nn.Conv2d( - n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias - ) - self.conv_dec2 = nn.Conv2d( - n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias - ) - self.conv_dec3 = nn.Conv2d( - 
n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias - ) + self.conv_dec1 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_dec2 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) + self.conv_dec3 = nn.Conv2d(n_feat, n_feat + scale_orsnetfeats, kernel_size=1, bias=bias) def forward(self, x, encoder_outs, decoder_outs): x = self.orb1(x) @@ -382,6 +349,7 @@ def forward(self, x, encoder_outs, decoder_outs): @ARCH_REGISTRY.register() class MPRNet(nn.Module): + def __init__( self, io_channels=3, @@ -410,19 +378,11 @@ def __init__( ) # Cross Stage Feature Fusion (CSFF) - self.stage1_encoder = Encoder( - n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False - ) - self.stage1_decoder = Decoder( - n_feat, kernel_size, reduction, act, bias, scale_unetfeats - ) + self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False) + self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) - self.stage2_encoder = Encoder( - n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True - ) - self.stage2_decoder = Decoder( - n_feat, kernel_size, reduction, act, bias, scale_unetfeats - ) + self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True) + self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) self.stage3_orsnet = ORSNet( n_feat, @@ -439,12 +399,8 @@ def __init__( self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias) - self.concat23 = conv( - n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias - ) - self.tail = conv( - n_feat + scale_orsnetfeats, io_channels, kernel_size, bias=bias - ) + self.concat23 = conv(n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias) + self.tail = conv(n_feat + scale_orsnetfeats, io_channels, kernel_size, bias=bias) def 
forward(self, x3_img): # Original-resolution Image for Stage 3 @@ -454,14 +410,14 @@ def forward(self, x3_img): # Multi-Patch Hierarchy: Split Image into four non-overlapping patches # Two Patches for Stage 2 - x2top_img = x3_img[:, :, 0 : int(hgt / 2), :] - x2bot_img = x3_img[:, :, int(hgt / 2) : hgt, :] + x2top_img = x3_img[:, :, 0:int(hgt / 2), :] + x2bot_img = x3_img[:, :, int(hgt / 2):hgt, :] # Four Patches for Stage 1 - x1ltop_img = x2top_img[:, :, :, 0 : int(wdt / 2)] - x1rtop_img = x2top_img[:, :, :, int(wdt / 2) : wdt] - x1lbot_img = x2bot_img[:, :, :, 0 : int(wdt / 2)] - x1rbot_img = x2bot_img[:, :, :, int(wdt / 2) : wdt] + x1ltop_img = x2top_img[:, :, :, 0:int(wdt / 2)] + x1rtop_img = x2top_img[:, :, :, int(wdt / 2):wdt] + x1lbot_img = x2bot_img[:, :, :, 0:int(wdt / 2)] + x1rbot_img = x2bot_img[:, :, :, int(wdt / 2):wdt] # Stage 1 diff --git a/powerqe/archs/rbqe_arch.py b/powerqe/archs/rbqe_arch.py index 98ca53b..8544f02 100644 --- a/powerqe/archs/rbqe_arch.py +++ b/powerqe/archs/rbqe_arch.py @@ -1,6 +1,5 @@ import math import numbers - import torch import torch.nn as nn import torch.nn.functional as nn_func @@ -36,16 +35,13 @@ def forward(self, x): # -> (N, C, 1) -> (N, 1, C) -> Conv (just like FC, but ks=3) # -> (N, 1, C) -> (N, C, 1) -> (N, C, 1, 1) logic = self.avg_pool(x) - logic = ( - self.conv(logic.squeeze(-1).transpose(-1, -2)) - .transpose(-1, -2) - .unsqueeze(-1) - ) + logic = (self.conv(logic.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)) logic = self.sigmoid(logic) return x * logic.expand_as(x) class SeparableConv2d(nn.Module): + def __init__(self, nf_in, nf_out): super().__init__() @@ -57,9 +53,7 @@ def __init__(self, nf_in, nf_out): padding=3 // 2, groups=nf_in, # each channel is convolved with its own filter ), - nn.Conv2d( - in_channels=nf_in, out_channels=nf_out, kernel_size=1, groups=1 - ), # then point-wise + nn.Conv2d(in_channels=nf_in, out_channels=nf_out, kernel_size=1, groups=1), # then point-wise ) def 
forward(self, x): @@ -91,16 +85,11 @@ def __init__(self, channels, kernel_size, sigma, padding, dim=2): # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 - meshgrids = torch.meshgrid( - [torch.arange(size, dtype=torch.float32) for size in kernel_size] - ) + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 - kernel *= ( - 1 - / (std * math.sqrt(2 * math.pi)) - * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) - ) # ignore the warning: it is a tensor + kernel *= (1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std)**2) / 2) + ) # ignore the warning: it is a tensor # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) # ignore the warning: it is a tensor @@ -109,7 +98,7 @@ def __init__(self, channels, kernel_size, sigma, padding, dim=2): kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - self.register_buffer("weight", kernel) + self.register_buffer('weight', kernel) self.groups = channels self.padding = padding @@ -120,10 +109,8 @@ def __init__(self, channels, kernel_size, sigma, padding, dim=2): elif dim == 3: self.conv = nn_func.conv3d else: - raise ValueError( - "Data with 1/2/3 dimensions is supported;" - f" received {dim} dimensions." - ) + raise ValueError('Data with 1/2/3 dimensions is supported;' + f' received {dim} dimensions.') def forward(self, x): """Apply gaussian filter to input. @@ -134,14 +121,13 @@ def forward(self, x): Returns: Tensor: Filtered output. 
""" - return self.conv( - x, weight=self.weight, groups=self.groups, padding=self.padding - ) + return self.conv(x, weight=self.weight, groups=self.groups, padding=self.padding) class IQAM: - def __init__(self, comp_type="jpeg"): - if comp_type == "jpeg": + + def __init__(self, comp_type='jpeg'): + if comp_type == 'jpeg': self.patch_sz = 8 self.tche_poly = torch.tensor( @@ -223,7 +209,7 @@ def __init__(self, comp_type="jpeg"): self.thr_out = 0.855 - elif comp_type == "hevc": + elif comp_type == 'hevc': self.patch_sz = 4 self.tche_poly = torch.tensor( @@ -245,16 +231,12 @@ def __init__(self, comp_type="jpeg"): self.bigc = torch.tensor(1e-5) # numerical stability self.alpha_block = 0.9 # [0, 1] - self.gaussian_filter = GaussianSmoothing( - channels=1, kernel_size=3, sigma=5, padding=3 // 2 - ).cuda() + self.gaussian_filter = GaussianSmoothing(channels=1, kernel_size=3, sigma=5, padding=3 // 2).cuda() def cal_tchebichef_moments(self, x): x = x.clone() - x /= torch.sqrt( - self.patch_sz * self.patch_sz * (x.reshape((-1,)).pow(2).mean()) - ) - x -= x.reshape((-1,)).mean() + x /= torch.sqrt(self.patch_sz * self.patch_sz * (x.reshape((-1, )).pow(2).mean())) + x -= x.reshape((-1, )).mean() moments = torch.mm(torch.mm(self.tche_poly, x), self.tche_poly_transposed) return moments @@ -282,8 +264,8 @@ def forward(self, x): while start_w + self.patch_sz <= w_cut: patch = x[ - start_h : (start_h + self.patch_sz), - start_w : (start_w + self.patch_sz), + start_h:(start_h + self.patch_sz), + start_w:(start_w + self.patch_sz), ] sum_patch = torch.sum(torch.abs(patch)) @@ -301,50 +283,31 @@ def forward(self, x): num_textured += 1 patch_blurred = torch.squeeze( - self.gaussian_filter( - patch.clone().view(1, 1, self.patch_sz, self.patch_sz) - ) - ) - moments_patch_blurred = self.cal_tchebichef_moments( - patch_blurred - ) + self.gaussian_filter(patch.clone().view(1, 1, self.patch_sz, self.patch_sz))) + moments_patch_blurred = self.cal_tchebichef_moments(patch_blurred) 
similarity_matrix = torch.div( - ( - torch.mul(moments_patch, moments_patch_blurred) * 2.0 - + self.bigc - ), - (moments_patch.pow(2)) - + moments_patch_blurred.pow(2) - + self.bigc, - ) - score_blurred_textured += 1 - torch.mean( - similarity_matrix.reshape((-1)) + (torch.mul(moments_patch, moments_patch_blurred) * 2.0 + self.bigc), + (moments_patch.pow(2)) + moments_patch_blurred.pow(2) + self.bigc, ) + score_blurred_textured += 1 - torch.mean(similarity_matrix.reshape((-1))) else: num_smooth += 1 sum_moments = torch.sum(torch.abs(moments_patch)) strength_vertical = ( - torch.sum(torch.abs(moments_patch[self.patch_sz - 1, :])) - / sum_moments - - torch.abs(moments_patch[0, 0]) - + self.bigc - ) + torch.sum(torch.abs(moments_patch[self.patch_sz - 1, :])) / sum_moments - + torch.abs(moments_patch[0, 0]) + self.bigc) strength_horizontal = ( - torch.sum(torch.abs(moments_patch[:, self.patch_sz - 1])) - / sum_moments - - torch.abs(moments_patch[0, 0]) - + self.bigc - ) + torch.sum(torch.abs(moments_patch[:, self.patch_sz - 1])) / sum_moments - + torch.abs(moments_patch[0, 0]) + self.bigc) if strength_vertical > self.thr_jnd: strength_vertical = self.thr_jnd if strength_horizontal > self.thr_jnd: strength_horizontal = self.thr_jnd - score_ = torch.log( - 1 - ((strength_vertical + strength_horizontal) / 2) - ) / torch.log(1 - self.thr_jnd) + score_ = torch.log(1 - ( + (strength_vertical + strength_horizontal) / 2)) / torch.log(1 - self.thr_jnd) score_blocky_smooth = score_blocky_smooth + score_ @@ -360,9 +323,7 @@ def forward(self, x): else: score_blocky_smooth = torch.tensor(1.0, dtype=torch.float32) - score_quality = (score_blocky_smooth.pow(self.alpha_block)) * ( - score_blurred_textured.pow(1 - self.alpha_block) - ) + score_quality = (score_blocky_smooth.pow(self.alpha_block)) * (score_blurred_textured.pow(1 - self.alpha_block)) if score_quality >= self.thr_out: return True else: @@ -375,55 +336,43 @@ class Down(nn.Module): def __init__(self, nf_in, nf_out, method, 
if_separable, if_eca): super().__init__() - supported_methods = ["avepool2d", "strideconv"] + supported_methods = ['avepool2d', 'strideconv'] if method not in supported_methods: - raise NotImplementedError( - f'Downsampling method should be in "{supported_methods}";' - f' received "{method}".' - ) + raise NotImplementedError(f'Downsampling method should be in "{supported_methods}";' + f' received "{method}".') if if_separable and if_eca: - layers = nn.ModuleList( - [ECA(k_size=3), SeparableConv2d(nf_in=nf_in, nf_out=nf_in)] - ) + layers = nn.ModuleList([ECA(k_size=3), SeparableConv2d(nf_in=nf_in, nf_out=nf_in)]) elif if_separable and (not if_eca): layers = nn.ModuleList([SeparableConv2d(nf_in=nf_in, nf_out=nf_in)]) elif (not if_separable) and if_eca: - layers = nn.ModuleList( - [ - ECA(k_size=3), - nn.Conv2d( - in_channels=nf_in, - out_channels=nf_in, - kernel_size=3, - padding=3 // 2, - ), - ] - ) - else: - layers = nn.ModuleList( - [ - nn.Conv2d( - in_channels=nf_in, - out_channels=nf_in, - kernel_size=3, - padding=3 // 2, - ) - ] - ) - - if method == "avepool2d": - layers.append(nn.AvgPool2d(kernel_size=2)) - elif method == "strideconv": - layers.append( + layers = nn.ModuleList([ + ECA(k_size=3), nn.Conv2d( in_channels=nf_in, - out_channels=nf_out, + out_channels=nf_in, kernel_size=3, padding=3 // 2, - stride=2, - ) - ) + ), + ]) + else: + layers = nn.ModuleList([nn.Conv2d( + in_channels=nf_in, + out_channels=nf_in, + kernel_size=3, + padding=3 // 2, + )]) + + if method == 'avepool2d': + layers.append(nn.AvgPool2d(kernel_size=2)) + elif method == 'strideconv': + layers.append(nn.Conv2d( + in_channels=nf_in, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + stride=2, + )) if if_separable and if_eca: layers += [ @@ -443,14 +392,12 @@ def __init__(self, nf_in, nf_out, method, if_separable, if_eca): ), ] else: - layers.append( - nn.Conv2d( - in_channels=nf_out, - out_channels=nf_out, - kernel_size=3, - padding=3 // 2, - ) - ) + layers.append(nn.Conv2d( + 
in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + )) self.layers = nn.Sequential(*layers) @@ -464,16 +411,14 @@ class Up(nn.Module): def __init__(self, nf_in_s, nf_in, nf_out, method, if_separable, if_eca): super().__init__() - supported_methods = ["upsample", "transpose2d"] + supported_methods = ['upsample', 'transpose2d'] if method not in supported_methods: - raise NotImplementedError( - f'Upsampling method should be in "{supported_methods}";' - f' received "{method}".' - ) + raise NotImplementedError(f'Upsampling method should be in "{supported_methods}";' + f' received "{method}".') - if method == "upsample": + if method == 'upsample': self.up = nn.Upsample(scale_factor=2) - elif method == "transpose2d": + elif method == 'transpose2d': self.up = nn.ConvTranspose2d( in_channels=nf_in_s, out_channels=nf_out, @@ -483,61 +428,53 @@ def __init__(self, nf_in_s, nf_in, nf_out, method, if_separable, if_eca): ) if if_separable and if_eca: - layers = nn.ModuleList( - [ - ECA(k_size=3), - SeparableConv2d(nf_in=nf_in, nf_out=nf_out), - nn.ReLU(inplace=True), - ECA(k_size=3), - SeparableConv2d(nf_in=nf_out, nf_out=nf_out), - ] - ) + layers = nn.ModuleList([ + ECA(k_size=3), + SeparableConv2d(nf_in=nf_in, nf_out=nf_out), + nn.ReLU(inplace=True), + ECA(k_size=3), + SeparableConv2d(nf_in=nf_out, nf_out=nf_out), + ]) elif if_separable and (not if_eca): - layers = nn.ModuleList( - [ - SeparableConv2d(nf_in=nf_in, nf_out=nf_out), - nn.ReLU(inplace=True), - SeparableConv2d(nf_in=nf_out, nf_out=nf_out), - ] - ) + layers = nn.ModuleList([ + SeparableConv2d(nf_in=nf_in, nf_out=nf_out), + nn.ReLU(inplace=True), + SeparableConv2d(nf_in=nf_out, nf_out=nf_out), + ]) elif (not if_separable) and if_eca: - layers = nn.ModuleList( - [ - ECA(k_size=3), - nn.Conv2d( - in_channels=nf_in, - out_channels=nf_out, - kernel_size=3, - padding=3 // 2, - ), - nn.ReLU(inplace=True), - ECA(k_size=3), - nn.Conv2d( - in_channels=nf_out, - out_channels=nf_out, - kernel_size=3, 
- padding=3 // 2, - ), - ] - ) + layers = nn.ModuleList([ + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + ECA(k_size=3), + nn.Conv2d( + in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + ]) else: - layers = nn.ModuleList( - [ - nn.Conv2d( - in_channels=nf_in, - out_channels=nf_out, - kernel_size=3, - padding=3 // 2, - ), - nn.ReLU(inplace=True), - nn.Conv2d( - in_channels=nf_out, - out_channels=nf_out, - kernel_size=3, - padding=3 // 2, - ), - ] - ) + layers = nn.ModuleList([ + nn.Conv2d( + in_channels=nf_in, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels=nf_out, + out_channels=nf_out, + kernel_size=3, + padding=3 // 2, + ), + ]) self.layers = nn.Sequential(*layers) def forward(self, small_t, *normal_t_list): @@ -559,7 +496,7 @@ def forward(self, small_t, *normal_t_list): feat = nn_func.pad( input=feat, pad=[dw // 2, (dw - dw // 2), dh // 2, (dh - dh // 2)], - mode="constant", + mode='constant', value=0, ) @@ -570,17 +507,18 @@ def forward(self, small_t, *normal_t_list): @ARCH_REGISTRY.register() class RBQE(nn.Module): + def __init__( self, nf_io=3, nf_base=32, nlevel=5, - down_method="strideconv", - up_method="transpose2d", + down_method='strideconv', + up_method='transpose2d', if_separable=False, if_eca=False, if_only_last_output=True, - comp_type="hevc", + comp_type='hevc', ): super().__init__() @@ -615,7 +553,7 @@ def __init__( for idx_unet in range(nlevel): setattr( self, - f"down_{idx_unet}", + f'down_{idx_unet}', Down( nf_in=nf_base, nf_out=nf_base, @@ -627,7 +565,7 @@ def __init__( for idx_up in range(idx_unet + 1): setattr( self, - f"up_{idx_unet}_{idx_up}", + f'up_{idx_unet}_{idx_up}', Up( nf_in_s=nf_base, nf_in=nf_base * (2 + idx_up), # dense connection @@ -646,11 +584,7 @@ def __init__( repeat_times = nlevel for _ in range(repeat_times): if if_separable and 
if_eca: - self.out_layers.append( - nn.Sequential( - ECA(k_size=3), SeparableConv2d(nf_in=nf_base, nf_out=nf_io) - ) - ) + self.out_layers.append(nn.Sequential(ECA(k_size=3), SeparableConv2d(nf_in=nf_base, nf_out=nf_io))) elif if_separable and (not if_eca): self.out_layers.append(SeparableConv2d(nf_in=nf_base, nf_out=nf_io)) elif (not if_separable) and if_eca: @@ -663,8 +597,7 @@ def __init__( kernel_size=3, padding=3 // 2, ), - ) - ) + )) else: self.out_layers.append( nn.Conv2d( @@ -672,8 +605,7 @@ def __init__( out_channels=nf_io, kernel_size=3, padding=3 // 2, - ) - ) + )) # IQA module # no trainable parameters @@ -693,9 +625,8 @@ def forward(self, x, idx_out=None): """ if self.if_only_last_output: if idx_out is not None: - raise ValueError( - "Exit cannot be indicated" " since there is only one exit." - ) + raise ValueError('Exit cannot be indicated' + ' since there is only one exit.') idx_out = self.nlevel - 1 feat = self.in_conv_seq(x) @@ -705,7 +636,7 @@ def forward(self, x, idx_out=None): out_img_list = [] for idx_unet in range(self.nlevel): # per U-Net - down = getattr(self, f"down_{idx_unet}") + down = getattr(self, f'down_{idx_unet}') feat = down(feat_level_unet[-1][0]) # the previous U-Net, the first level feat_up_list = [feat] @@ -717,12 +648,10 @@ def forward(self, x, idx_out=None): # It needs C2,1 to C2,3 at feat_level_unet[1][0], # feat_level_unet[2][1] and feat_level_unet[3][2]. # feat_level_unet now contains 4 lists. 
- for idx_, feat_level in enumerate(feat_level_unet[-(idx_up + 1) :]): - dense_inp_list.append( - feat_level[idx_] - ) # append features from previous U-Nets at the same level + for idx_, feat_level in enumerate(feat_level_unet[-(idx_up + 1):]): + dense_inp_list.append(feat_level[idx_]) # append features from previous U-Nets at the same level - up = getattr(self, f"up_{idx_unet}_{idx_up}") + up = getattr(self, f'up_{idx_unet}_{idx_up}') feat_up = up(feat_up_list[-1], *dense_inp_list) feat_up_list.append(feat_up) diff --git a/powerqe/archs/rdn_arch.py b/powerqe/archs/rdn_arch.py index b42a200..627f631 100644 --- a/powerqe/archs/rdn_arch.py +++ b/powerqe/archs/rdn_arch.py @@ -1,5 +1,4 @@ import math - import torch from torch import nn @@ -43,16 +42,10 @@ class RDB(nn.Module): def __init__(self, in_channels, channel_growth, num_layers): super().__init__() self.layers = nn.Sequential( - *[ - DenseLayer(in_channels + channel_growth * i, channel_growth) - for i in range(num_layers) - ] - ) + *[DenseLayer(in_channels + channel_growth * i, channel_growth) for i in range(num_layers)]) # local feature fusion - self.lff = nn.Conv2d( - in_channels + channel_growth * num_layers, in_channels, kernel_size=1 - ) + self.lff = nn.Conv2d(in_channels + channel_growth * num_layers, in_channels, kernel_size=1) def forward(self, x): """Forward function. 
@@ -78,9 +71,7 @@ def __init__(self, scale_factor, mode): self.mode = mode def forward(self, x): - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False - ) + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) return x @@ -118,12 +109,12 @@ def __init__( self.num_layers = num_layers if not math.log2(rescale).is_integer(): - raise ValueError(f"Rescale factor ({rescale}) should be a power of 2.") + raise ValueError(f'Rescale factor ({rescale}) should be a power of 2.') if rescale == 1: self.downscale = nn.Identity() else: - self.downscale = Interpolate(scale_factor=1.0 / rescale, mode="bicubic") + self.downscale = Interpolate(scale_factor=1.0 / rescale, mode='bicubic') # shallow feature extraction self.sfe1 = nn.Conv2d(io_channels, mid_channels, kernel_size=3, padding=3 // 2) @@ -132,18 +123,12 @@ def __init__( # residual dense blocks self.rdbs = nn.ModuleList() for _ in range(self.num_blocks): - self.rdbs.append( - RDB(self.mid_channels, self.channel_growth, self.num_layers) - ) + self.rdbs.append(RDB(self.mid_channels, self.channel_growth, self.num_layers)) # global feature fusion self.gff = nn.Sequential( - nn.Conv2d( - self.mid_channels * self.num_blocks, self.mid_channels, kernel_size=1 - ), - nn.Conv2d( - self.mid_channels, self.mid_channels, kernel_size=3, padding=3 // 2 - ), + nn.Conv2d(self.mid_channels * self.num_blocks, self.mid_channels, kernel_size=1), + nn.Conv2d(self.mid_channels, self.mid_channels, kernel_size=3, padding=3 // 2), ) # upsampling @@ -152,22 +137,18 @@ def __init__( else: self.upscale = [] for _ in range(rescale // 2): - self.upscale.extend( - [ - nn.Conv2d( - self.mid_channels, - self.mid_channels * (2**2), - kernel_size=3, - padding=3 // 2, - ), - nn.PixelShuffle(2), - ] - ) + self.upscale.extend([ + nn.Conv2d( + self.mid_channels, + self.mid_channels * (2**2), + kernel_size=3, + padding=3 // 2, + ), + nn.PixelShuffle(2), + ]) self.upscale = 
nn.Sequential(*self.upscale) - self.output = nn.Conv2d( - self.mid_channels, io_channels, kernel_size=3, padding=3 // 2 - ) + self.output = nn.Conv2d(self.mid_channels, io_channels, kernel_size=3, padding=3 // 2) def forward(self, x): """Forward. diff --git a/powerqe/archs/unet_arch.py b/powerqe/archs/unet_arch.py index 731fa48..d4f5367 100644 --- a/powerqe/archs/unet_arch.py +++ b/powerqe/archs/unet_arch.py @@ -6,19 +6,18 @@ class Up(nn.Module): + def __init__(self, method, nf_in=None): super().__init__() - supported_methods = ["upsample", "transpose2d"] + supported_methods = ['upsample', 'transpose2d'] if method not in supported_methods: - raise NotImplementedError( - f'Upsampling method should be in "{supported_methods}";' - f' received "{method}".' - ) + raise NotImplementedError(f'Upsampling method should be in "{supported_methods}";' + f' received "{method}".') - if method == "upsample": - self.up = nn.Upsample(scale_factor=2, mode="bicubic", align_corners=False) - elif method == "transpose2d": + if method == 'upsample': + self.up = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False) + elif method == 'transpose2d': self.up = nn.ConvTranspose2d( in_channels=nf_in, out_channels=nf_in // 2, @@ -34,10 +33,10 @@ def forward(self, inp_t, ref_big): diff_w = ref_big.size()[3] - feat.size()[3] # W if diff_h < 0: - feat = feat[:, :, : ref_big.size()[2], :] + feat = feat[:, :, :ref_big.size()[2], :] diff_h = 0 if diff_w < 0: - feat = feat[:, :, :, : ref_big.size()[3]] + feat = feat[:, :, :, :ref_big.size()[3]] diff_w = 0 # only pad H and W; left (diff_w//2) @@ -51,7 +50,7 @@ def forward(self, inp_t, ref_big): diff_h // 2, (diff_h - diff_h // 2), ], - mode="constant", + mode='constant', value=0, ) @@ -60,6 +59,7 @@ def forward(self, inp_t, ref_big): @ARCH_REGISTRY.register() class UNet(nn.Module): + def __init__( self, nf_in, @@ -71,48 +71,38 @@ def __init__( nl_base=1, nl_max=8, nl_gr=2, - down="avepool2d", - up="transpose2d", - reduce="concat", + 
down='avepool2d', + up='transpose2d', + reduce='concat', residual=True, ): super().__init__() - supported_up_methods = ["upsample", "transpose2d"] + supported_up_methods = ['upsample', 'transpose2d'] if up not in supported_up_methods: - raise NotImplementedError( - f'Upsampling method should be in "{supported_up_methods}";' - f' received "{up}".' - ) + raise NotImplementedError(f'Upsampling method should be in "{supported_up_methods}";' + f' received "{up}".') - supported_down_methods = ["avepool2d", "strideconv"] + supported_down_methods = ['avepool2d', 'strideconv'] if down not in supported_down_methods: - raise NotImplementedError( - f'Downsampling method should be in "{supported_down_methods}";' - f' received "{down}".' - ) + raise NotImplementedError(f'Downsampling method should be in "{supported_down_methods}";' + f' received "{down}".') - supported_reduce_methods = ["add", "concat"] + supported_reduce_methods = ['add', 'concat'] if reduce not in supported_reduce_methods: - raise NotImplementedError( - f'Reduce method should be in "{supported_reduce_methods}";' - f' received "{reduce}".' - ) + raise NotImplementedError(f'Reduce method should be in "{supported_reduce_methods}";' + f' received "{reduce}".') if residual and (nf_in != nf_out): - raise ValueError( - "The input channel number should be equal to the" - " output channel number." 
- ) + raise ValueError('The input channel number should be equal to the' + ' output channel number.') self.nlevel = nlevel self.reduce = reduce self.residual = residual self.inc = nn.Sequential( - nn.Conv2d( - in_channels=nf_in, out_channels=nf_base, kernel_size=3, padding=1 - ), + nn.Conv2d(in_channels=nf_in, out_channels=nf_base, kernel_size=3, padding=1), nn.ReLU(inplace=True), ) @@ -126,12 +116,12 @@ def __init__( # define downsampling operator - if down == "avepool2d": - setattr(self, f"down_{idx_level}", nn.AvgPool2d(kernel_size=2)) - elif down == "strideconv": + if down == 'avepool2d': + setattr(self, f'down_{idx_level}', nn.AvgPool2d(kernel_size=2)) + elif down == 'strideconv': setattr( self, - f"down_{idx_level}", + f'down_{idx_level}', nn.Sequential( nn.Conv2d( in_channels=nf_lst[-2], @@ -165,15 +155,15 @@ def __init__( ), nn.ReLU(inplace=True), ] - setattr(self, f"enc_{idx_level}", nn.Sequential(*module_lst)) + setattr(self, f'enc_{idx_level}', nn.Sequential(*module_lst)) for idx_level in range((nlevel - 2), -1, -1): # define upsampling operator - setattr(self, f"up_{idx_level}", Up(nf_in=nf_lst[idx_level + 1], method=up)) + setattr(self, f'up_{idx_level}', Up(nf_in=nf_lst[idx_level + 1], method=up)) # define decoding operator - if reduce == "add": + if reduce == 'add': module_lst = [ nn.Conv2d( in_channels=nf_lst[idx_level], @@ -203,11 +193,9 @@ def __init__( ), nn.ReLU(inplace=True), ] - setattr(self, f"dec_{idx_level}", nn.Sequential(*module_lst)) + setattr(self, f'dec_{idx_level}', nn.Sequential(*module_lst)) - self.outc = nn.Conv2d( - in_channels=nf_base, out_channels=nf_out, kernel_size=3, padding=1 - ) + self.outc = nn.Conv2d(in_channels=nf_base, out_channels=nf_out, kernel_size=3, padding=1) def forward(self, inp_t): feat = self.inc(inp_t) @@ -217,21 +205,21 @@ def forward(self, inp_t): map_lst = [] # guidance maps for idx_level in range(1, self.nlevel): map_lst.append(feat) # from level 0, 1, ..., (nlevel-1) - down = getattr(self, 
f"down_{idx_level}") - enc = getattr(self, f"enc_{idx_level}") + down = getattr(self, f'down_{idx_level}') + enc = getattr(self, f'enc_{idx_level}') feat = enc(down(feat)) # up for idx_level in range((self.nlevel - 2), -1, -1): - up = getattr(self, f"up_{idx_level}") - dec = getattr(self, f"dec_{idx_level}") + up = getattr(self, f'up_{idx_level}') + dec = getattr(self, f'dec_{idx_level}') g_map = map_lst[idx_level] up_feat = up(inp_t=feat, ref_big=g_map) - if self.reduce == "add": + if self.reduce == 'add': feat = up_feat + g_map - elif self.reduce == "concat": + elif self.reduce == 'concat': feat = torch.cat((up_feat, g_map), dim=1) feat = dec(feat) diff --git a/powerqe/data/__init__.py b/powerqe/data/__init__.py index 11a9cd5..0722bca 100644 --- a/powerqe/data/__init__.py +++ b/powerqe/data/__init__.py @@ -8,10 +8,9 @@ from basicsr.data.prefetch_dataloader import PrefetchDataLoader from basicsr.utils import get_root_logger from basicsr.utils.dist_util import get_dist_info - from .registry import DATASET_REGISTRY -__all__ = ["build_dataset", "build_dataloader", "DATASET_REGISTRY"] +__all__ = ['build_dataset', 'build_dataloader', 'DATASET_REGISTRY'] def build_dataset(dataset_opt): @@ -23,17 +22,13 @@ def build_dataset(dataset_opt): type (str): Dataset type. """ dataset_opt = deepcopy(dataset_opt) - dataset = DATASET_REGISTRY.get(dataset_opt["type"])(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) logger = get_root_logger() - logger.info( - f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.' - ) + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') return dataset -def build_dataloader( - dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None -): +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): """Build dataloader. Args: @@ -49,16 +44,16 @@ def build_dataloader( sampler (torch.utils.data.sampler): Data sampler. 
Default: None. seed (int | None): Seed. Default: None """ - phase = dataset_opt["phase"] + phase = dataset_opt['phase'] rank, _ = get_dist_info() - if phase == "train": + if phase == 'train': if dist: # distributed training - batch_size = dataset_opt["batch_size_per_gpu"] - num_workers = dataset_opt["num_worker_per_gpu"] + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] else: # non-distributed training multiplier = 1 if num_gpu == 0 else num_gpu - batch_size = dataset_opt["batch_size_per_gpu"] * multiplier - num_workers = dataset_opt["num_worker_per_gpu"] * multiplier + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier dataloader_args = dict( dataset=dataset, batch_size=batch_size, @@ -68,34 +63,23 @@ def build_dataloader( drop_last=True, ) if sampler is None: - dataloader_args["shuffle"] = True - dataloader_args["worker_init_fn"] = ( - partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) - if seed is not None - else None - ) - elif phase in ["val", "test"]: # validation - dataloader_args = dict( - dataset=dataset, batch_size=1, shuffle=False, num_workers=0 - ) + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = ( + partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None) + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) else: - raise ValueError( - f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'." - ) + raise ValueError(f"Wrong dataset phase: {phase}. 
Supported ones are 'train', 'val' and 'test'.") - dataloader_args["pin_memory"] = dataset_opt.get("pin_memory", False) - dataloader_args["persistent_workers"] = dataset_opt.get("persistent_workers", False) + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) - prefetch_mode = dataset_opt.get("prefetch_mode") - if prefetch_mode == "cpu": # CPUPrefetcher - num_prefetch_queue = dataset_opt.get("num_prefetch_queue", 1) + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) logger = get_root_logger() - logger.info( - f"Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}" - ) - return PrefetchDataLoader( - num_prefetch_queue=num_prefetch_queue, **dataloader_args - ) + logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) else: # prefetch_mode=None: Normal dataloader # prefetch_mode='cuda': dataloader for CUDAPrefetcher diff --git a/powerqe/data/registry.py b/powerqe/data/registry.py index 988e487..3eb88df 100644 --- a/powerqe/data/registry.py +++ b/powerqe/data/registry.py @@ -1,4 +1,3 @@ from basicsr.utils.registry import DATASET_REGISTRY as DATASET_REGISTRY_BASICSR - DATASET_REGISTRY = DATASET_REGISTRY_BASICSR diff --git a/powerqe/losses/__init__.py b/powerqe/losses/__init__.py index 1db98c9..25238b6 100644 --- a/powerqe/losses/__init__.py +++ b/powerqe/losses/__init__.py @@ -1,11 +1,9 @@ from copy import deepcopy - from basicsr.utils import get_root_logger - from .registry import LOSS_REGISTRY -__all__ = ["build_loss", "LOSS_REGISTRY"] +__all__ = ['build_loss', 'LOSS_REGISTRY'] def build_loss(opt): @@ -16,8 +14,8 @@ def build_loss(opt): type (str): Model type. 
""" opt = deepcopy(opt) - loss_type = opt.pop("type") + loss_type = opt.pop('type') loss = LOSS_REGISTRY.get(loss_type)(**opt) logger = get_root_logger() - logger.info(f"Loss [{loss.__class__.__name__}] is created.") + logger.info(f'Loss [{loss.__class__.__name__}] is created.') return loss diff --git a/powerqe/metrics/__init__.py b/powerqe/metrics/__init__.py index 2bcff62..0ba93db 100644 --- a/powerqe/metrics/__init__.py +++ b/powerqe/metrics/__init__.py @@ -2,7 +2,7 @@ from .registry import METRIC_REGISTRY -__all__ = ["calculate_metric", "METRIC_REGISTRY"] +__all__ = ['calculate_metric', 'METRIC_REGISTRY'] def calculate_metric(data, opt): @@ -13,6 +13,6 @@ def calculate_metric(data, opt): type (str): Model type. """ opt = deepcopy(opt) - metric_type = opt.pop("type") + metric_type = opt.pop('type') metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) return metric diff --git a/powerqe/models/__init__.py b/powerqe/models/__init__.py index 92bfd2a..86cb031 100644 --- a/powerqe/models/__init__.py +++ b/powerqe/models/__init__.py @@ -1,11 +1,10 @@ from copy import deepcopy from basicsr.utils import get_root_logger - from .qe_model import QEModel from .registry import MODEL_REGISTRY -__all__ = ["build_model", "MODEL_REGISTRY", "QEModel"] +__all__ = ['build_model', 'MODEL_REGISTRY', 'QEModel'] def build_model(opt): @@ -16,7 +15,7 @@ def build_model(opt): model_type (str): Model type. 
""" opt = deepcopy(opt) - model = MODEL_REGISTRY.get(opt["model_type"])(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) logger = get_root_logger() - logger.info(f"Model [{model.__class__.__name__}] is created.") + logger.info(f'Model [{model.__class__.__name__}] is created.') return model diff --git a/powerqe/models/qe_model.py b/powerqe/models/qe_model.py index 4cf5017..9ed0a13 100644 --- a/powerqe/models/qe_model.py +++ b/powerqe/models/qe_model.py @@ -1,5 +1,4 @@ from basicsr.models.sr_model import SRModel - from .registry import MODEL_REGISTRY diff --git a/powerqe/models/registry.py b/powerqe/models/registry.py index 1d62fb7..980158d 100644 --- a/powerqe/models/registry.py +++ b/powerqe/models/registry.py @@ -1,4 +1,3 @@ from basicsr.utils.registry import MODEL_REGISTRY as MODEL_REGISTRY_BASICSR - MODEL_REGISTRY = MODEL_REGISTRY_BASICSR diff --git a/powerqe/test.py b/powerqe/test.py index 1de11de..17679ae 100644 --- a/powerqe/test.py +++ b/powerqe/test.py @@ -4,7 +4,6 @@ from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs from basicsr.utils.options import dict2str, parse_options - from powerqe.data import build_dataloader, build_dataset from powerqe.models import build_model @@ -13,7 +12,7 @@ def test_pipeline(root_path): # parse options, set distributed setting, set random seed opt, _ = parse_options(root_path, is_train=False) - if opt["reproduce"]: + if opt['reproduce']: torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True, warn_only=True) @@ -23,24 +22,22 @@ def test_pipeline(root_path): # mkdir and initialize loggers make_exp_dirs(opt) - log_file = osp.join(opt["path"]["log"], f"{opt['name']}_{get_time_str()}.log") - logger = get_root_logger( - logger_name="basicsr", log_level=logging.INFO, log_file=log_file - ) + log_file = osp.join(opt['path']['log'], f"{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', 
log_level=logging.INFO, log_file=log_file) logger.info(get_env_info()) logger.info(dict2str(opt)) # create test dataset and dataloader test_loaders = [] - for _, dataset_opt in sorted(opt["datasets"].items()): + for _, dataset_opt in sorted(opt['datasets'].items()): test_set = build_dataset(dataset_opt) test_loader = build_dataloader( test_set, dataset_opt, - num_gpu=opt["num_gpu"], - dist=opt["dist"], + num_gpu=opt['num_gpu'], + dist=opt['dist'], sampler=None, - seed=opt["manual_seed"], + seed=opt['manual_seed'], ) logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") test_loaders.append(test_loader) @@ -49,16 +46,16 @@ def test_pipeline(root_path): model = build_model(opt) for test_loader in test_loaders: - test_set_name = test_loader.dataset.opt["name"] - logger.info(f"Testing {test_set_name}...") + test_set_name = test_loader.dataset.opt['name'] + logger.info(f'Testing {test_set_name}...') model.validation( test_loader, - current_iter=opt["name"], + current_iter=opt['name'], tb_logger=None, - save_img=opt["val"]["save_img"], + save_img=opt['val']['save_img'], ) -if __name__ == "__main__": +if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) test_pipeline(root_path) diff --git a/powerqe/train.py b/powerqe/train.py index 1143029..b06cfc1 100644 --- a/powerqe/train.py +++ b/powerqe/train.py @@ -7,132 +7,100 @@ from basicsr.data.data_sampler import EnlargedSampler from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher -from basicsr.utils import ( - AvgTimer, - MessageLogger, - check_resume, - get_env_info, - get_root_logger, - get_time_str, - init_tb_logger, - init_wandb_logger, - make_exp_dirs, - mkdir_and_rename, - scandir, -) +from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, + init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) from basicsr.utils.options import copy_opt_file, dict2str, 
parse_options - -from powerqe.models import build_model from powerqe.data import build_dataloader, build_dataset +from powerqe.models import build_model def init_tb_loggers(opt): # initialize wandb logger before tensorboard logger to allow proper sync - if ( - (opt["logger"].get("wandb") is not None) - and (opt["logger"]["wandb"].get("project") is not None) - and ("debug" not in opt["name"]) - ): - assert ( - opt["logger"].get("use_tb_logger") is True - ), "should turn on tensorboard when using wandb" + if ((opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None) + and ('debug' not in opt['name'])): + assert (opt['logger'].get('use_tb_logger') is True), 'should turn on tensorboard when using wandb' init_wandb_logger(opt) tb_logger = None - if opt["logger"].get("use_tb_logger") and "debug" not in opt["name"]: - tb_logger = init_tb_logger( - log_dir=osp.join(opt["root_path"], "tb_logger", opt["name"]) - ) + if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: + tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name'])) return tb_logger def create_train_val_dataloader(opt, logger): # create train and val dataloaders train_loader, val_loaders = None, [] - for phase, dataset_opt in opt["datasets"].items(): - if phase == "train": - dataset_enlarge_ratio = dataset_opt.get("dataset_enlarge_ratio", 1) + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) train_set = build_dataset(dataset_opt) - train_sampler = EnlargedSampler( - train_set, opt["world_size"], opt["rank"], dataset_enlarge_ratio - ) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) train_loader = build_dataloader( train_set, dataset_opt, - num_gpu=opt["num_gpu"], - dist=opt["dist"], + num_gpu=opt['num_gpu'], + dist=opt['dist'], sampler=train_sampler, - seed=opt["manual_seed"], + 
seed=opt['manual_seed'], ) num_iter_per_epoch = math.ceil( - len(train_set) - * dataset_enlarge_ratio - / (dataset_opt["batch_size_per_gpu"] * opt["world_size"]) - ) - total_iters = int(opt["train"]["total_iter"]) + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) - logger.info( - "Training statistics:" - f"\n\tNumber of train images: {len(train_set)}" - f"\n\tDataset enlarge ratio: {dataset_enlarge_ratio}" - f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' - f'\n\tWorld size (gpu number): {opt["world_size"]}' - f"\n\tRequire iter number per epoch: {num_iter_per_epoch}" - f"\n\tTotal epochs: {total_epochs}; iters: {total_iters}." - ) - elif phase.split("_")[0] == "val": + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + elif phase.split('_')[0] == 'val': val_set = build_dataset(dataset_opt) val_loader = build_dataloader( val_set, dataset_opt, - num_gpu=opt["num_gpu"], - dist=opt["dist"], + num_gpu=opt['num_gpu'], + dist=opt['dist'], sampler=None, - seed=opt["manual_seed"], - ) - logger.info( - f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}' + seed=opt['manual_seed'], ) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}') val_loaders.append(val_loader) else: - raise ValueError(f"Dataset phase {phase} is not recognized.") + raise ValueError(f'Dataset phase {phase} is not recognized.') return train_loader, train_sampler, val_loaders, total_epochs, total_iters def load_resume_state(opt): resume_state_path = None 
- if opt["auto_resume"]: - state_path = osp.join("experiments", opt["name"], "training_states") + if opt['auto_resume']: + state_path = osp.join('experiments', opt['name'], 'training_states') if osp.isdir(state_path): - states = list( - scandir(state_path, suffix="state", recursive=False, full_path=False) - ) + states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) if len(states) != 0: - states = [float(v.split(".state")[0]) for v in states] - resume_state_path = osp.join(state_path, f"{max(states):.0f}.state") - opt["path"]["resume_state"] = resume_state_path + states = [float(v.split('.state')[0]) for v in states] + resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') + opt['path']['resume_state'] = resume_state_path else: - if opt["path"].get("resume_state"): - resume_state_path = opt["path"]["resume_state"] + if opt['path'].get('resume_state'): + resume_state_path = opt['path']['resume_state'] if resume_state_path is None: resume_state = None else: device_id = torch.cuda.current_device() - resume_state = torch.load( - resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id) - ) - check_resume(opt, resume_state["iter"]) + resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) + check_resume(opt, resume_state['iter']) return resume_state def train_pipeline(root_path): # parse options, set distributed setting, set random seed opt, args = parse_options(root_path, is_train=True) - opt["root_path"] = root_path + opt['root_path'] = root_path - if opt["reproduce"]: + if opt['reproduce']: torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True, warn_only=True) @@ -145,22 +113,16 @@ def train_pipeline(root_path): # mkdir for experiments and logger if resume_state is None: make_exp_dirs(opt) - if ( - opt["logger"].get("use_tb_logger") - and "debug" not in opt["name"] - and opt["rank"] == 0 - ): 
- mkdir_and_rename(osp.join(opt["root_path"], "tb_logger", opt["name"])) + if (opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0): + mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name'])) # copy the yml file to the experiment root - copy_opt_file(args.opt, opt["path"]["experiments_root"]) + copy_opt_file(args.opt, opt['path']['experiments_root']) # WARNING: should not use get_root_logger in the above codes, including the called functions # Otherwise the logger will not be properly initialized - log_file = osp.join(opt["path"]["log"], f"{opt['name']}_{get_time_str()}.log") - logger = get_root_logger( - logger_name="basicsr", log_level=logging.INFO, log_file=log_file - ) + log_file = osp.join(opt['path']['log'], f"{opt['name']}_{get_time_str()}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) logger.info(get_env_info()) logger.info(dict2str(opt)) # initialize wandb and tb loggers @@ -174,11 +136,9 @@ def train_pipeline(root_path): model = build_model(opt) if resume_state: # resume training model.resume_training(resume_state) # handle optimizers and schedulers - logger.info( - f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}." 
- ) - start_epoch = resume_state["epoch"] - current_iter = resume_state["iter"] + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] else: start_epoch = 0 current_iter = 0 @@ -187,21 +147,19 @@ def train_pipeline(root_path): msg_logger = MessageLogger(opt, current_iter, tb_logger) # dataloader prefetcher - prefetch_mode = opt["datasets"]["train"].get("prefetch_mode") - if prefetch_mode is None or prefetch_mode == "cpu": + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': prefetcher = CPUPrefetcher(train_loader) - elif prefetch_mode == "cuda": + elif prefetch_mode == 'cuda': prefetcher = CUDAPrefetcher(train_loader, opt) - logger.info(f"Use {prefetch_mode} prefetch dataloader") - if opt["datasets"]["train"].get("pin_memory") is not True: - raise ValueError("Please set pin_memory=True for CUDAPrefetcher.") + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') else: - raise ValueError( - f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'." - ) + raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. 
Supported ones are: None, 'cuda', 'cpu'.") # training - logger.info(f"Start training from epoch: {start_epoch}, iter: {current_iter}") + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') data_timer, iter_timer = AvgTimer(), AvgTimer() start_time = time.time() @@ -218,9 +176,7 @@ def train_pipeline(root_path): break # update learning rate - model.update_learning_rate( - current_iter, warmup_iter=opt["train"].get("warmup_iter", -1) - ) + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) # training model.feed_data(train_data) model.optimize_parameters(current_iter) @@ -230,36 +186,28 @@ def train_pipeline(root_path): # not work in resume mode msg_logger.reset_start_time() # log - if current_iter % opt["logger"]["print_freq"] == 0: - log_vars = {"epoch": epoch, "iter": current_iter} - log_vars.update({"lrs": model.get_current_learning_rate()}) - log_vars.update( - { - "time": iter_timer.get_avg_time(), - "data_time": data_timer.get_avg_time(), - } - ) + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({ + 'time': iter_timer.get_avg_time(), + 'data_time': data_timer.get_avg_time(), + }) log_vars.update(model.get_current_log()) msg_logger(log_vars) # save models and training states - if current_iter % opt["logger"]["save_checkpoint_freq"] == 0: - logger.info("Saving models and training states.") + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') model.save(epoch, current_iter) # validation - if opt.get("val") is not None and ( - current_iter % opt["val"]["val_freq"] == 0 - or current_iter == total_iters - ): + if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0 + or current_iter == total_iters): if len(val_loaders) > 1: - logger.warning( - "Multiple validation datasets are *only* 
supported by SRModel." - ) + logger.warning('Multiple validation datasets are *only* supported by SRModel.') for val_loader in val_loaders: - model.validation( - val_loader, current_iter, tb_logger, opt["val"]["save_img"] - ) + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) data_timer.start() iter_timer.start() @@ -269,13 +217,13 @@ def train_pipeline(root_path): # end of epoch consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) - logger.info(f"End of training. Time consumed: {consumed_time}") - logger.info("Save the latest model.") + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') model.save(epoch=-1, current_iter=-1) # -1 stands for the latest if tb_logger: tb_logger.close() -if __name__ == "__main__": +if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) train_pipeline(root_path) diff --git a/requirements.txt b/requirements.txt index dc7b953..f2a6a4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -six --r basicsr/requirements.txt \ No newline at end of file +-r basicsr/requirements.txt +six diff --git a/scripts/data_preparation/compress_img.py b/scripts/data_preparation/compress_img.py index 567bc24..0becd2e 100644 --- a/scripts/data_preparation/compress_img.py +++ b/scripts/data_preparation/compress_img.py @@ -20,11 +20,10 @@ """ import argparse +import cv2 import multiprocessing as mp import os import os.path as osp - -import cv2 from tqdm import tqdm @@ -35,83 +34,78 @@ def run_cmd(cmd): def opencv_write_jpeg(src_path, quality, tar_path): img = cv2.imread(src_path) encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] # 0-100 - _, jpeg_data = cv2.imencode(".jpg", img, encode_param) + _, jpeg_data = cv2.imencode('.jpg', img, encode_param) comp_img = cv2.imdecode(jpeg_data, cv2.IMREAD_COLOR) cv2.imwrite(tar_path, comp_img) def parse_args(): - parser = 
argparse.ArgumentParser(description="Compress image dataset.") - parser.add_argument("--codec", type=str, required=True, choices=["BPG", "JPEG"]) - parser.add_argument( - "--dataset", type=str, required=True, choices=["DIV2K", "Flickr2K"] - ) - parser.add_argument("--max-npro", type=int, default=16) - parser.add_argument("--quality", type=int, default=37) + parser = argparse.ArgumentParser(description='Compress image dataset.') + parser.add_argument('--codec', type=str, required=True, choices=['BPG', 'JPEG']) + parser.add_argument('--dataset', type=str, required=True, choices=['DIV2K', 'Flickr2K']) + parser.add_argument('--max-npro', type=int, default=16) + parser.add_argument('--quality', type=int, default=37) args = parser.parse_args() return args -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - if args.codec == "BPG": - enc_path = osp.abspath("datasets/libbpg/bpgenc") - dec_path = osp.abspath("datasets/libbpg/bpgdec") + if args.codec == 'BPG': + enc_path = osp.abspath('datasets/libbpg/bpgenc') + dec_path = osp.abspath('datasets/libbpg/bpgdec') paths = [] - if args.dataset == "DIV2K": - src_root = osp.abspath("datasets/DIV2K") - tmp_root = osp.abspath("tmp/datasets/DIV2K") - tar_root = osp.abspath("datasets/DIV2K") + if args.dataset == 'DIV2K': + src_root = osp.abspath('datasets/DIV2K') + tmp_root = osp.abspath('tmp/datasets/DIV2K') + tar_root = osp.abspath('datasets/DIV2K') # training set - src_dir = osp.join(src_root, "train") - tmp_dir = osp.join(tmp_root, f"train_BPG_QP{args.quality}") - tar_dir = osp.join(tar_root, f"train_BPG_QP{args.quality}") + src_dir = osp.join(src_root, 'train') + tmp_dir = osp.join(tmp_root, f'train_BPG_QP{args.quality}') + tar_dir = osp.join(tar_root, f'train_BPG_QP{args.quality}') os.makedirs(tmp_dir) os.makedirs(tar_dir) for idx in range(1, 801): paths.append( dict( - src=osp.join(src_dir, f"{idx:04d}.png"), - bpg=osp.join(tmp_dir, f"{idx:04d}.bpg"), - tar=osp.join(tar_dir, f"{idx:04d}.png"), - ) - ) + 
src=osp.join(src_dir, f'{idx:04d}.png'), + bpg=osp.join(tmp_dir, f'{idx:04d}.bpg'), + tar=osp.join(tar_dir, f'{idx:04d}.png'), + )) # validation set - src_dir = osp.join(src_root, "valid") - tmp_dir = osp.join(tmp_root, f"valid_BPG_QP{args.quality}") - tar_dir = osp.join(tar_root, f"valid_BPG_QP{args.quality}") + src_dir = osp.join(src_root, 'valid') + tmp_dir = osp.join(tmp_root, f'valid_BPG_QP{args.quality}') + tar_dir = osp.join(tar_root, f'valid_BPG_QP{args.quality}') os.makedirs(tmp_dir) os.makedirs(tar_dir) for idx in range(801, 901): paths.append( dict( - src=osp.join(src_dir, f"{idx:04d}.png"), - bpg=osp.join(tmp_dir, f"{idx:04d}.bpg"), - tar=osp.join(tar_dir, f"{idx:04d}.png"), - ) - ) - - if args.dataset == "Flickr2K": - src_dir = osp.abspath("datasets/Flickr2K") - tmp_dir = osp.abspath(f"tmp/datasets/Flickr2K/BPG_QP{args.quality}") - tar_dir = osp.abspath(f"datasets/Flickr2K/BPG_QP{args.quality}") + src=osp.join(src_dir, f'{idx:04d}.png'), + bpg=osp.join(tmp_dir, f'{idx:04d}.bpg'), + tar=osp.join(tar_dir, f'{idx:04d}.png'), + )) + + if args.dataset == 'Flickr2K': + src_dir = osp.abspath('datasets/Flickr2K') + tmp_dir = osp.abspath(f'tmp/datasets/Flickr2K/BPG_QP{args.quality}') + tar_dir = osp.abspath(f'datasets/Flickr2K/BPG_QP{args.quality}') os.makedirs(tmp_dir) os.makedirs(tar_dir) for idx in range(1, 2651): paths.append( dict( - src=osp.join(src_dir, f"{idx:06d}.png"), - bpg=osp.join(tmp_dir, f"{idx:06d}.bpg"), - tar=osp.join(tar_dir, f"{idx:06d}.png"), - ) - ) + src=osp.join(src_dir, f'{idx:06d}.png'), + bpg=osp.join(tmp_dir, f'{idx:06d}.bpg'), + tar=osp.join(tar_dir, f'{idx:06d}.png'), + )) # create meta # with open(osp.join(src_dir, "train.txt"), "w") as file: @@ -138,10 +132,10 @@ def parse_args(): for path in paths: enc_cmd = f'{enc_path} -o {path["bpg"]} -q {args.quality}' f' {path["src"]}' dec_cmd = f'{dec_path} -o {path["tar"]} {path["bpg"]}' - cmd = f"{enc_cmd} && {dec_cmd}" + cmd = f'{enc_cmd} && {dec_cmd}' pool.apply_async( func=run_cmd, - 
args=(cmd,), + args=(cmd, ), callback=lambda _: pbar.update(), error_callback=lambda err: print(err), ) @@ -149,51 +143,45 @@ def parse_args(): pool.join() pbar.close() - elif args.codec == "JPEG": + elif args.codec == 'JPEG': paths = [] - if args.dataset == "DIV2K": - src_root = osp.abspath("datasets/DIV2K") - tar_root = osp.abspath("datasets/DIV2K") + if args.dataset == 'DIV2K': + src_root = osp.abspath('datasets/DIV2K') + tar_root = osp.abspath('datasets/DIV2K') # training set - src_dir = osp.join(src_root, "train") - tar_dir = osp.join(tar_root, f"train_JPEG_QF{args.quality}") + src_dir = osp.join(src_root, 'train') + tar_dir = osp.join(tar_root, f'train_JPEG_QF{args.quality}') os.makedirs(tar_dir) for idx in range(1, 801): - paths.append( - dict( - src=osp.join(src_dir, f"{idx:04d}.png"), - tar=osp.join(tar_dir, f"{idx:04d}.png"), - ) - ) + paths.append(dict( + src=osp.join(src_dir, f'{idx:04d}.png'), + tar=osp.join(tar_dir, f'{idx:04d}.png'), + )) # validation set - src_dir = osp.join(src_root, "valid") - tar_dir = osp.join(tar_root, f"valid_JPEG_QF{args.quality}") + src_dir = osp.join(src_root, 'valid') + tar_dir = osp.join(tar_root, f'valid_JPEG_QF{args.quality}') os.makedirs(tar_dir) for idx in range(801, 901): - paths.append( - dict( - src=osp.join(src_dir, f"{idx:04d}.png"), - tar=osp.join(tar_dir, f"{idx:04d}.png"), - ) - ) - - if args.dataset == "Flickr2K": - src_dir = osp.abspath("datasets/Flickr2K") - tar_dir = osp.abspath(f"datasets/Flickr2K_JPEG_QF{args.quality}") + paths.append(dict( + src=osp.join(src_dir, f'{idx:04d}.png'), + tar=osp.join(tar_dir, f'{idx:04d}.png'), + )) + + if args.dataset == 'Flickr2K': + src_dir = osp.abspath('datasets/Flickr2K') + tar_dir = osp.abspath(f'datasets/Flickr2K_JPEG_QF{args.quality}') os.makedirs(tar_dir) for idx in range(1, 2651): - paths.append( - dict( - src=osp.join(src_dir, f"{idx:06d}.png"), - tar=osp.join(tar_dir, f"{idx:06d}.png"), - ) - ) + paths.append(dict( + src=osp.join(src_dir, f'{idx:06d}.png'), + 
tar=osp.join(tar_dir, f'{idx:06d}.png'), + )) # create meta # with open(osp.join(src_dir, "train.txt"), "w") as file: @@ -221,9 +209,9 @@ def parse_args(): pool.apply_async( func=opencv_write_jpeg, args=( - path["src"], + path['src'], args.quality, - path["tar"], + path['tar'], ), callback=lambda _: pbar.update(), error_callback=lambda err: print(err), diff --git a/scripts/data_preparation/create_lmdb.py b/scripts/data_preparation/create_lmdb.py index a8971d9..c11c9cb 100644 --- a/scripts/data_preparation/create_lmdb.py +++ b/scripts/data_preparation/create_lmdb.py @@ -1,4 +1,5 @@ import argparse + from basicsr.utils import scandir from basicsr.utils.lmdb_util import make_lmdb_from_imgs @@ -13,22 +14,22 @@ def prepare_keys_div2k(folder_path): list[str]: Image path list. list[str]: Key list. """ - print("Reading image path list ...") - img_path_list = sorted(list(scandir(folder_path, suffix="png", recursive=False))) - keys = [img_path.split(".png")[0] for img_path in sorted(img_path_list)] + print('Reading image path list ...') + img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False))) + keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] return img_path_list, keys -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--input_folder", + '--input_folder', type=str, required=True, ) parser.add_argument( - "--lmdb_path", + '--lmdb_path', type=str, required=True, ) diff --git a/scripts/data_preparation/extract_subimages.py b/scripts/data_preparation/extract_subimages.py index fe9d74d..42ad4b9 100644 --- a/scripts/data_preparation/extract_subimages.py +++ b/scripts/data_preparation/extract_subimages.py @@ -37,14 +37,14 @@ def worker(path): for y in w_space: index += 1 assert index < 1000 - cropped_img = img[x : x + args.crop_size, y : y + args.crop_size, ...] + cropped_img = img[x:x + args.crop_size, y:y + args.crop_size, ...] 
cropped_img = np.ascontiguousarray(cropped_img) cv2.imwrite( - osp.join(args.save_folder, f"{img_name}_s{index:03d}{extension}"), + osp.join(args.save_folder, f'{img_name}_s{index:03d}{extension}'), cropped_img, [cv2.IMWRITE_PNG_COMPRESSION, args.compression_level], ) - process_info = f"Processing {img_name} ..." + process_info = f'Processing {img_name} ...' return process_info @@ -52,50 +52,52 @@ def extract_subimages(): """Crop images to subimages.""" if not osp.exists(args.save_folder): os.makedirs(args.save_folder) - print(f"mkdir {args.save_folder} ...") + print(f'mkdir {args.save_folder} ...') else: - print(f"Folder {args.save_folder} already exists. Exit.") + print(f'Folder {args.save_folder} already exists. Exit.') sys.exit(1) img_list = list(scandir(args.input_folder, full_path=True)) - pbar = tqdm(total=len(img_list), unit="image", desc="Extract") + pbar = tqdm(total=len(img_list), unit='image', desc='Extract') pool = Pool(args.n_thread) for path in img_list: - pool.apply_async(worker, args=(path,), callback=lambda arg: pbar.update(1)) + pool.apply_async(worker, args=(path, ), callback=lambda arg: pbar.update(1)) pool.close() pool.join() pbar.close() - print("All processes done.") + print('All processes done.') -if __name__ == "__main__": +if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--dataset", type=str, required=True, choices=["DIV2K"]) - parser.add_argument("--n_thread", type=int, default=20, help="Thread number") + parser.add_argument('--dataset', type=str, required=True, choices=['DIV2K']) + parser.add_argument('--n_thread', type=int, default=20, help='Thread number') parser.add_argument( - "--compression_level", + '--compression_level', type=int, default=3, - help="CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.", + help=('CV_IMWRITE_PNG_COMPRESSION from 0 to 9. 
' + 'A higher value means a smaller size and longer compression time. ' + 'Use 0 for faster CPU decompression. Default: 3, same in cv2.'), ) - parser.add_argument("--crop_size", type=int, default=128, help="Crop size") + parser.add_argument('--crop_size', type=int, default=128, help='Crop size') + parser.add_argument('--step', type=int, default=64, help='Step for overlapped sliding window') parser.add_argument( - "--step", type=int, default=64, help="Step for overlapped sliding window" - ) - parser.add_argument( - "--thresh_size", + '--thresh_size', type=int, default=0, - help="Threshold size. If the remaining portion at the edge of the image is smaller than thresh_size, that portion will be discarded and not included as a patch.", + help=('Threshold size. If the remaining portion at the edge of the image is smaller than thresh_size, ' + 'that portion will be discarded and not included as a patch.'), ) args = parser.parse_args() - if args.dataset == "DIV2K": - args.input_folder = "datasets/DIV2K/train" - args.save_folder = f"tmp/datasets/DIV2K/train_size{args.crop_size}_step{args.step}_thresh{args.thresh_size}" + if args.dataset == 'DIV2K': + args.input_folder = 'datasets/DIV2K/train' + args.save_folder = f'tmp/datasets/DIV2K/train_size{args.crop_size}_step{args.step}_thresh{args.thresh_size}' extract_subimages() - args.input_folder = "datasets/DIV2K/train_BPG_QP37" - args.save_folder = f"tmp/datasets/DIV2K/train_BPG_QP37_size{args.crop_size}_step{args.step}_thresh{args.thresh_size}" + args.input_folder = 'datasets/DIV2K/train_BPG_QP37' + args.save_folder = ('tmp/datasets/DIV2K/train_BPG_QP37_' + f'size{args.crop_size}_step{args.step}_thresh{args.thresh_size}') extract_subimages() diff --git a/scripts/test.sh b/scripts/test.sh index 9bea019..1e11058 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -21,4 +21,4 @@ else PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \ python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ powerqe/test.py 
-opt $CONFIG --launcher pytorch ${@:3} -fi \ No newline at end of file +fi diff --git a/scripts/train.sh b/scripts/train.sh index 0eb771f..08dcb9a 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -21,4 +21,4 @@ else PYTHONPATH="$(dirname $0)/..:${PYTHONPATH}" \ python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ powerqe/train.py -opt $CONFIG --launcher pytorch ${@:3} -fi \ No newline at end of file +fi diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..dc528cb --- /dev/null +++ b/setup.cfg @@ -0,0 +1,25 @@ +[flake8] +ignore = + # line break before binary operator (W503) + W503, + # line break after binary operator (W504) + W504, +max-line-length=120 + +[yapf] +based_on_style = pep8 +column_limit = 120 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true + +[isort] +line_length = 120 +multi_line_output = 0 +known_standard_library = pkg_resources,setuptools +known_first_party = basicsr +known_third_party = PIL,cv2,distutils,lmdb,numpy,pyiqa,pytest,requests,scipy,skimage,torch,torchvision,tqdm,yaml +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY + +[codespell] +skip = .git,./docs/,*.cfg