parametrize arch tests
the-database committed Jun 30, 2024
1 parent 451f108 commit 40e988e
Showing 2 changed files with 69 additions and 9 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,33 @@
+name: Tests
+
+# Controls when the workflow will run
+on:
+  pull_request:
+    branches: ['*']
+    types:
+      - opened
+      - synchronize
+      - closed
+    paths:
+      - '**.py'
+      - '.github/workflows/**'
+      - '.pyproject.toml'
+      - 'requirements.txt'
+  push:
+    branches: [main]
+
+  # Allows you to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+# A workflow run is made up of one or more jobs that can run sequentially or in parallel
+jobs:
+  tests:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+      - run: pip install -r requirements.txt
+      - run: pytest ./tests/test_archs # TODO expand tests
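
For local runs outside CI, the same suite the workflow targets can be invoked directly. A minimal sketch, assuming the repository root as the working directory and that requirements.txt already provides pytest (run_arch_tests.py is a hypothetical helper, not part of this commit):

# run_arch_tests.py -- hypothetical local runner mirroring the workflow's pytest step.
import sys

import pytest

if __name__ == "__main__":
    # Same target directory as the CI step; append e.g. "-k test_realcugan_2x"
    # (an illustrative parametrized id) to narrow to a single arch/scale case.
    sys.exit(pytest.main(["./tests/test_archs", "-v"]))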
45 changes: 36 additions & 9 deletions tests/test_archs/test_spandrel_archs.py
@@ -1,3 +1,4 @@
+import itertools
 import os
 import sys
 from collections.abc import Callable
@@ -27,6 +28,14 @@
     if name not in EXCLUDE_ARCHS
 ]
 
+ALL_SCALES = [1, 2, 3, 4]
+
+FILTERED_REGISTRIES_SCALES = [
+    (*a, b) for a, b in itertools.product(FILTERED_REGISTRY, ALL_SCALES)
+]
+
+EXCLUDE_ARCH_SCALES = {"swinir_l": [3], "realcugan": [1]}
+
 
 class TestArchData(TypedDict):
     device: str
@@ -49,34 +58,52 @@ def data() -> TestArchData:
 
 class TestArchs:
     @pytest.mark.parametrize(
-        "arch",
-        [pytest.param(arch, id=f"test_{name}") for name, arch in FILTERED_REGISTRY],
+        "name,arch,scale",
+        [
+            pytest.param(name, arch, scale, id=f"test_{name}_{scale}x")
+            for name, arch, scale in FILTERED_REGISTRIES_SCALES
+        ],
     )
     def test_arch_inference(
-        self, data: TestArchData, arch: Callable[..., nn.Module]
+        self,
+        data: TestArchData,
+        name: str,
+        arch: Callable[..., nn.Module],
+        scale: int,
     ) -> None:
+        if name in EXCLUDE_ARCH_SCALES and scale in EXCLUDE_ARCH_SCALES[name]:
+            pytest.skip(f"Skipping known unsupported {scale}x scale for {name}")
+
         device = data["device"]
         lq = data["lq"]
         dtype = data["dtype"]
-        scale = 5
         model = arch(scale=scale).eval().to(device, dtype=dtype)
 
         with torch.inference_mode():
             output = model(lq)
             assert (
                 output.shape[2] == lq.shape[2] * scale
                 and output.shape[3] == lq.shape[3] * scale
-            )
+            ), f"{name}: {output.shape} is not {scale}x {lq.shape}"
 
     @pytest.mark.parametrize(
-        "arch",
-        [pytest.param(arch, id=f"train_{name}") for name, arch in FILTERED_REGISTRY],
+        "name,arch,scale",
+        [
+            pytest.param(name, arch, scale, id=f"train_{name}_{scale}")
+            for name, arch, scale in FILTERED_REGISTRIES_SCALES
+        ],
    )
    def test_arch_training(
-        self, data: TestArchData, arch: Callable[..., nn.Module]
+        self,
+        data: TestArchData,
+        name: str,
+        arch: Callable[..., nn.Module],
+        scale: int,
    ) -> None:
+        if name in EXCLUDE_ARCH_SCALES and scale in EXCLUDE_ARCH_SCALES[name]:
+            pytest.skip(f"Skipping known unsupported {scale}x scale for {name}")
+
        device = data["device"]
-        scale = 4
        lq = data["lq"]
        gt_shape = (lq.shape[0], lq.shape[1], lq.shape[2] * scale, lq.shape[3] * scale)
        dtype = data["dtype"]
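
The parametrization pattern this commit introduces can be illustrated in isolation. Below is a minimal, self-contained sketch, assuming nothing from the project: DUMMY_REGISTRY and its toy PixelShuffle model stand in for FILTERED_REGISTRY, and the flattening of itertools.product into (name, arch, scale) tuples mirrors FILTERED_REGISTRIES_SCALES above:

import itertools

import pytest
import torch
from torch import nn

ALL_SCALES = [1, 2, 3, 4]

# DUMMY_REGISTRY stands in for the project's FILTERED_REGISTRY of
# (name, constructor) pairs; the constructor here is a toy upscaler.
DUMMY_REGISTRY = [
    (
        "pixelshuffle_toy",
        lambda scale: nn.Sequential(
            nn.Conv2d(3, 3 * scale * scale, kernel_size=1),
            nn.PixelShuffle(scale),
        ),
    ),
]

# Same flattening idiom as FILTERED_REGISTRIES_SCALES: each (name, arch) pair
# is expanded into one (name, arch, scale) case per scale.
CASES = [(*a, b) for a, b in itertools.product(DUMMY_REGISTRY, ALL_SCALES)]


@pytest.mark.parametrize(
    "name,arch,scale",
    [pytest.param(n, a, s, id=f"test_{n}_{s}x") for n, a, s in CASES],
)
def test_output_is_scaled(name, arch, scale):
    lq = torch.rand(1, 3, 8, 8)
    model = arch(scale=scale).eval()
    with torch.inference_mode():
        out = model(lq)
    assert out.shape[-2:] == (lq.shape[-2] * scale, lq.shape[-1] * scale)

Each (name, arch) pair expands into one case per scale, and the explicit pytest.param id keeps failures readable (for example, test_pixelshuffle_toy_2x), which is the same effect the commit achieves for the real architecture registry.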
