add arch: CBDNet and UNet
ryanxingql committed Sep 24, 2024
1 parent 8473f67 commit 299033f
Showing 5 changed files with 489 additions and 1 deletion.
51 changes: 51 additions & 0 deletions options/test/CBDNet/DIV2K_LMDB_G1_latest.yml
@@ -0,0 +1,51 @@
# general settings
name: test_CBDNet_DIV2K_LMDB_G1_latest
model_type: SRModel
scale: 1
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# dataset settings
datasets:
  test:  # multiple test datasets are acceptable
    name: DIV2K
    type: PairedImageDataset
    dataroot_gt: datasets/DIV2K/valid
    dataroot_lq: datasets/DIV2K/valid_BPG_QP37
    io_backend:
      type: disk

# network structures
network_g:
  type: CBDNet
  io_channels: 3
  estimate_channels: 32
  nlevel_denoise: 3
  nf_base_denoise: 64

# path
path:
  pretrain_network_g: experiments/train_CBDNet_DIV2K_LMDB_G1/models/net_g_latest.pth
  param_key_g: params_ema  # load the ema model
  strict_load_g: true

# validation settings
val:
  save_img: false
  suffix: ~  # add suffix to saved images, if None, use exp name

  metrics:
    psnr:
      type: pyiqa
    ssim:
      type: pyiqa
    lpips:
      type: pyiqa
    clipiqa+:
      type: pyiqa
    topiq_fr:
      type: pyiqa
    musiq:
      type: pyiqa
    wadiqam_fr:
      type: pyiqa
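
Note: the metrics above all declare type: pyiqa, i.e. they are evaluated through the IQA-PyTorch (pyiqa) package. As a minimal, hedged sketch (not the repository's actual validation loop), a full-reference metric such as LPIPS could be computed as follows; the tensors and device handling here are illustrative assumptions:

# Illustrative sketch only: how a "type: pyiqa" metric could be evaluated.
# The actual wiring between the YAML and pyiqa lives in the model/validation code.
import pyiqa
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metric names mirror the YAML keys, e.g. psnr, ssim, lpips, musiq, wadiqam_fr.
lpips_metric = pyiqa.create_metric("lpips", device=device)

# Restored output and ground truth as (N, C, H, W) tensors in [0, 1].
output = torch.rand(1, 3, 256, 256, device=device)
gt = torch.rand(1, 3, 256, 256, device=device)

score = lpips_metric(output, gt)  # full-reference metrics take (distorted, reference)
print(score.item(), "(lower is better)" if lpips_metric.lower_better else "(higher is better)")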
96 changes: 96 additions & 0 deletions options/train/CBDNet/DIV2K_LMDB_G1.yml
@@ -0,0 +1,96 @@
# general settings
name: train_CBDNet_DIV2K_LMDB_G1
model_type: QEModel
scale: 1
num_gpu: 1  # set num_gpu: 0 for cpu mode
manual_seed: 0

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: PairedImageDataset
    # dataroot_gt: datasets/DIV2K/train
    # dataroot_lq: datasets/DIV2K/train_BPG_QP37
    # io_backend:
    #   type: disk
    dataroot_gt: datasets/DIV2K/train_size128_step64_thresh0.lmdb
    dataroot_lq: datasets/DIV2K/train_BPG_QP37_size128_step64_thresh0.lmdb
    io_backend:
      type: lmdb

    gt_size: 128  # in accord with LMDB
    use_hflip: true
    use_rot: true

    # data loader
    num_worker_per_gpu: 16
    batch_size_per_gpu: 16
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  val:
    name: DIV2K
    type: PairedImageDataset
    dataroot_gt: datasets/DIV2K/valid
    dataroot_lq: datasets/DIV2K/valid_BPG_QP37
    io_backend:
      type: disk

# network structures
network_g:
  type: CBDNet
  io_channels: 3
  estimate_channels: 32
  nlevel_denoise: 3
  nf_base_denoise: 64

# path
path:
  pretrain_network_g: ~
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 2e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: CosineAnnealingRestartLR
    periods: [500000]
    restart_weights: [1]
    eta_min: !!float 1e-7

  total_iter: 500000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean

# validation settings
val:
  val_freq: !!float 5e4
  save_img: false

  metrics:
    psnr:
      type: pyiqa

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 1e4
  use_tb_logger: true

# dist training settings
dist_params:
  backend: nccl
  port: 29500
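
A note on the schedule above: with a single period of 500000 iterations and restart weight 1, a BasicSR-style CosineAnnealingRestartLR reduces to a plain cosine decay from the base learning rate 2e-4 down to eta_min 1e-7. The snippet below is a hedged sketch of that implied curve, not the scheduler implementation itself:

# Hedged sketch of the implied learning-rate curve, assuming the usual
# BasicSR-style cosine annealing with restarts (single period, weight 1).
import math

base_lr = 2e-4
eta_min = 1e-7
period = 500_000

def lr_at(iteration: int) -> float:
    """Cosine-annealed learning rate at a given iteration (no restart needed here)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * iteration / period))

for it in (0, 125_000, 250_000, 500_000):
    print(f"iter {it:>7d}: lr = {lr_at(it):.3e}")

Likewise, ema_decay: 0.999 typically means validation and the saved weights use an exponential moving average of the generator parameters, updated roughly as ema = 0.999 * ema + 0.001 * current after each optimizer step; the test config's param_key_g: params_ema loads exactly that average.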
4 changes: 3 additions & 1 deletion powerqe/archs/__init__.py
@@ -4,8 +4,10 @@

 from .identitynet_arch import IdentityNet
 from .registry import ARCH_REGISTRY
+from .cbdnet_arch import CBDNet
+from .unet_arch import UNet

-__all__ = ["build_network", "ARCH_REGISTRY", "IdentityNet"]
+__all__ = ["build_network", "ARCH_REGISTRY", "IdentityNet", "CBDNet", "UNet"]


 def build_network(opt):
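
With the two architectures registered and exported, the network_g block from the configs above can be turned into a model instance via build_network. The following is a hedged illustration assuming build_network follows the common BasicSR pattern of popping "type" and looking the class up in ARCH_REGISTRY:

# Illustrative sketch: instantiating the newly registered CBDNet from a
# network_g-style options dict (assumes a BasicSR-like build_network).
from powerqe.archs import build_network

network_g_opt = {
    "type": "CBDNet",
    "io_channels": 3,
    "estimate_channels": 32,
    "nlevel_denoise": 3,
    "nf_base_denoise": 64,
}

net_g = build_network(network_g_opt)
print(net_g.__class__.__name__)  # expected: CBDNet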
96 changes: 96 additions & 0 deletions powerqe/archs/cbdnet_arch.py
@@ -0,0 +1,96 @@
import torch
from torch import nn as nn

from .unet_arch import UNet
from .registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class CBDNet(nn.Module):
    """CBDNet network structure.

    Args:
        io_channels (int): Number of I/O channels.
        estimate_channels (int): Channel number of the features in the estimation module.
        nlevel_denoise (int): Level number of the UNet for denoising.
        nf_base_denoise (int): Base channel number of the features in the denoising module.
        nf_gr_denoise (int): Growth rate of the channel number in the denoising module.
        nl_base_denoise (int): Base convolution layer number in the denoising module.
        nl_gr_denoise (int): Growth rate of the convolution layer number in the denoising module.
        down_denoise (str): Downsampling method in the denoising module.
        up_denoise (str): Upsampling method in the denoising module.
        reduce_denoise (str): Reduction method for the guidance/feature maps in the denoising module.
    """

    def __init__(
        self,
        io_channels=3,
        estimate_channels=32,
        nlevel_denoise=3,
        nf_base_denoise=64,
        nf_gr_denoise=2,
        nl_base_denoise=1,
        nl_gr_denoise=2,
        down_denoise="avepool2d",
        up_denoise="transpose2d",
        reduce_denoise="add",
    ):
        super().__init__()

        # Noise-estimation subnetwork: a stack of 3x3 conv + ReLU layers that
        # predicts a noise map with the same spatial size as the input.
        estimate_list = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=io_channels,
                    out_channels=estimate_channels,
                    kernel_size=3,
                    padding=3 // 2,
                ),
                nn.ReLU(inplace=True),
            ]
        )
        for _ in range(3):
            estimate_list += nn.ModuleList(
                [
                    nn.Conv2d(
                        in_channels=estimate_channels,
                        out_channels=estimate_channels,
                        kernel_size=3,
                        padding=3 // 2,
                    ),
                    nn.ReLU(inplace=True),
                ]
            )
        estimate_list += nn.ModuleList(
            [
                nn.Conv2d(estimate_channels, io_channels, 3, padding=3 // 2),
                nn.ReLU(inplace=True),
            ]
        )
        self.estimate = nn.Sequential(*estimate_list)

        # Non-blind denoising subnetwork: a UNet that takes the input image
        # concatenated with the estimated noise map and predicts a residual.
        self.denoise = UNet(
            nf_in=io_channels * 2,
            nf_out=io_channels,
            nlevel=nlevel_denoise,
            nf_base=nf_base_denoise,
            nf_gr=nf_gr_denoise,
            nl_base=nl_base_denoise,
            nl_gr=nl_gr_denoise,
            down=down_denoise,
            up=up_denoise,
            reduce=reduce_denoise,
            residual=False,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Input tensor with the shape of (N, C, H, W).

        Returns:
            Tensor
        """
        estimated_noise_map = self.estimate(x)
        res = self.denoise(torch.cat([x, estimated_noise_map], dim=1))
        out = res + x
        return out
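
Since the noise-estimation branch preserves resolution and the denoising UNet predicts a residual that is added back to the input, the output shape should match the input exactly. A quick, hedged sanity check (it assumes the UNet in unet_arch.py accepts the arguments passed above):

# Shape sanity check for the forward pass defined above.
import torch

from powerqe.archs.cbdnet_arch import CBDNet

net = CBDNet()  # defaults: 3 I/O channels, 3-level denoising UNet
x = torch.rand(1, 3, 128, 128)  # patch size matching gt_size in the train config
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 3, 128, 128])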
