Merge pull request #16 from the-database/dev

cast float32 for more losses, add missing scales for configs

the-database authored Jun 28, 2024
2 parents 85b18eb + c819b6a, commit cc02af3

Showing 15 changed files with 24 additions and 17 deletions.

options/train/Compact/Compact.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_Compact
-scale: 4 # 1, 2, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/Compact/Compact_OTF.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_Compact_OTF
-scale: 4 # 1, 2, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/DAT/DAT.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_DAT_2
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/DAT/DAT_OTF.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_DAT_2_OTF
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/OmniSR/OmniSR.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_OmniSR
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/OmniSR/OmniSR_OTF.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_OmniSR_OTF
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/PLKSR/PLKSR.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_PLKSR
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/PLKSR/PLKSR_OTF.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_PLKSR_OTF
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/RealPLKSR/RealPLKSR.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_RealPLKSR
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

options/train/RealPLKSR/RealPLKSR_OTF.yml (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 # General Settings
 ####################
 name: 4x_RealPLKSR_OTF
-scale: 4 # 2, 3, 4
+scale: 4 # 1, 2, 3, 4, 8
 use_amp: true # Speed up training and reduce VRAM usage.
 amp_bf16: true # Use bf16 for AMP, RTX 3000 series or newer only.
 fast_matmul: false # Trade precision for performance.

scripts/options/generate_default_options.py (13 changes: 6 additions & 7 deletions)
@@ -10,7 +10,6 @@ class ArchInfo(TypedDict):
 
 
 ALL_SCALES = [1, 2, 3, 4, 8]
-SCALES_234 = [2, 3, 4]
 
 
 def final_template(template: str, arch: ArchInfo) -> str:
@@ -46,19 +45,19 @@ def final_template(template: str, arch: ArchInfo) -> str:
         "extras": {"use_pixel_unshuffle": "true"},
     },
     {"names": ["ATD"], "scales": ALL_SCALES},
-    {"names": ["DAT_2"], "scales": SCALES_234},
+    {"names": ["DAT_2"], "scales": ALL_SCALES},
     {"names": ["HAT_L", "HAT_M", "HAT_S"], "scales": ALL_SCALES},
-    {"names": ["OmniSR"], "scales": SCALES_234},
-    {"names": ["PLKSR"], "scales": SCALES_234},
-    {"names": ["RealPLKSR"], "scales": SCALES_234},
+    {"names": ["OmniSR"], "scales": ALL_SCALES},
+    {"names": ["PLKSR"], "scales": ALL_SCALES},
+    {"names": ["RealPLKSR"], "scales": ALL_SCALES},
     {
         "names": ["RealCUGAN"],
-        "scales": SCALES_234,
+        "scales": [2, 3, 4],
         "extras": {"pro": "true", "fast": "false"},
     },
     {"names": ["SPAN"], "scales": [2, 4]},
     {"names": ["SRFormer", "SRFormer_light"], "scales": ALL_SCALES},
-    {"names": ["Compact", "UltraCompact", "SuperUltraCompact"], "scales": [1, 2, 4]},
+    {"names": ["Compact", "UltraCompact", "SuperUltraCompact"], "scales": ALL_SCALES},
    {"names": ["SwinIR_L", "SwinIR_M", "SwinIR_S"], "scales": ALL_SCALES},
     {"names": ["RGT", "RGT_S"], "scales": ALL_SCALES},
     {"names": ["DRCT", "DRCT_L", "DRCT_XL"], "scales": ALL_SCALES},

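For context, the table above drives the generated YAML configs: each entry's "scales" list controls which scale variants an architecture supports, which is where the updated "# 1, 2, 3, 4, 8" comments in the configs above come from. A minimal sketch of that relationship (the scale_comment helper is hypothetical; the real final_template body is not shown in this diff):

from typing import TypedDict


# Simplified stand-in for the ArchInfo TypedDict defined in the script above.
class ArchInfo(TypedDict, total=False):
    names: list[str]
    scales: list[int]
    extras: dict[str, str]


ALL_SCALES = [1, 2, 3, 4, 8]


def scale_comment(scales: list[int]) -> str:
    # Render the supported-scale comment seen in the YAML configs,
    # e.g. [1, 2, 3, 4, 8] -> "1, 2, 3, 4, 8".
    return ", ".join(str(s) for s in scales)


# After this commit DAT_2 advertises every scale, so its generated
# config carries the full comment:
arch: ArchInfo = {"names": ["DAT_2"], "scales": ALL_SCALES}
print(f"scale: 4 # {scale_comment(arch['scales'])}")  # scale: 4 # 1, 2, 3, 4, 8
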
traiNNer/losses/basic_loss.py (4 changes: 4 additions & 0 deletions)
@@ -122,6 +122,7 @@ def __init__(
         self.reduction = reduction
         self.eps = eps
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(
         self, pred: Tensor, target: Tensor, weight: Tensor | None = None, **kwargs
     ) -> Tensor:
@@ -156,6 +157,7 @@ def __init__(
         else:
             raise NotImplementedError(f"{criterion} criterion has not been supported.")
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(self, x: Tensor, y: Tensor) -> Tensor:
         input_yuv = rgb2ycbcr_pt(x)
         target_yuv = rgb2ycbcr_pt(y)
@@ -218,6 +220,7 @@ def __init__(
         else:
             raise NotImplementedError(f"{criterion} criterion has not been supported.")
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(self, x: Tensor, y: Tensor) -> Tensor:
         return self.criterion(self.ds_f(x), self.ds_f(y)) * self.loss_weight
 
@@ -238,6 +241,7 @@ def __init__(self, criterion: str = "l1", loss_weight: float = 1.0) -> None:
         else:
             raise NotImplementedError(f"{criterion} criterion has not been supported.")
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(self, x: Tensor, y: Tensor) -> Tensor:
         x_luma = rgb_to_luma(x)
         y_luma = rgb_to_luma(y)

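The decorator added in the hunks above (and in the loss files below) is PyTorch's torch.cuda.amp.custom_fwd. With cast_inputs=torch.float32, any floating-point CUDA tensor arguments are cast to float32 and autocast is locally disabled for the duration of the call, so the loss arithmetic runs in full precision even when training under AMP (use_amp: true in the configs above). A minimal standalone sketch of the effect (the L1CastLoss module is hypothetical, not from this repo):

import torch
from torch import Tensor, nn


class L1CastLoss(nn.Module):
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        # Inside this call autocast is disabled, and any float16/bfloat16
        # CUDA inputs would arrive already cast to float32, so the
        # reduction accumulates in full precision.
        return (pred - target).abs().mean()


loss_fn = L1CastLoss().cuda()
x = torch.randn(1, 3, 32, 32, device="cuda")
y = torch.randn(1, 3, 32, 32, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    loss = loss_fn(x, y)
print(loss.dtype)  # torch.float32, despite the bf16 autocast region
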
traiNNer/losses/contextual_loss.py (1 change: 1 addition & 0 deletions)
@@ -87,6 +87,7 @@ def __init__(
         else:  # if calc_type == 'regular':
             self.calculate_loss = self.calculate_cx_loss
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(self, images: Tensor, gt: Tensor) -> Tensor:
         device = images.device
 

traiNNer/losses/gan_loss.py (2 changes: 2 additions & 0 deletions)
@@ -93,6 +93,7 @@ def get_target_label(self, input: Tensor, target_is_real: bool) -> Tensor | bool
         target_val = self.real_label_val if target_is_real else self.fake_label_val
         return input.new_ones(input.size()) * target_val
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(
         self, input: Tensor, target_is_real: bool, is_disc: bool = False
     ) -> Tensor:
@@ -137,6 +138,7 @@ def __init__(
     ) -> None:
         super().__init__(gan_type, real_label_val, fake_label_val, loss_weight)
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(
         self, input: Tensor | list[Tensor], target_is_real: bool, is_disc: bool = False
     ) -> Tensor:

traiNNer/losses/perceptual_loss.py (1 change: 1 addition & 0 deletions)
@@ -75,6 +75,7 @@ def __init__(
         else:
             raise NotImplementedError(f"{criterion} criterion has not been supported.")
 
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(self, x: Tensor, gt: Tensor) -> tuple[Tensor | None, Tensor | None]:
         """Forward function.
