Skip to content

Commit

Permalink
Merge pull request #162 from the-database/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
the-database authored Dec 18, 2024
2 parents 6f3b3ca + 43a06f6 commit da1837d
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = []
extensions = ["sphinx-jsonschema"]

templates_path = ["_templates"]
exclude_patterns = []
Expand Down
1 change: 1 addition & 0 deletions docs/source/config_reference.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.. jsonschema:: https://raw.githubusercontent.com/the-database/traiNNer-redux/refs/heads/master/schemas/redux-config.schema.json
3 changes: 3 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ documentation for details.
:maxdepth: 2
:caption: Contents:

/index
/training_guidelines
/config_reference
1 change: 1 addition & 0 deletions docs/source/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
sphinx
sphinx-rtd-theme
sphinx-jsonschema
7 changes: 4 additions & 3 deletions traiNNer/archs/rcanmod_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import nn
from torch.nn import functional as F # noqa: N812
from torch.nn.init import trunc_normal_
from traiNNer.utils.registry import ARCH_REGISTRY


def default_conv(
Expand Down Expand Up @@ -116,7 +117,7 @@ def __init__(
modules_body.append(nn.BatchNorm2d(n_feat))
if i == 0:
modules_body.append(act)
modules_body.append(MSG(n_feat))
modules_body.append(SpatialSELayer(n_feat))
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale

Expand Down Expand Up @@ -243,9 +244,9 @@ def forward(self, x):
return res


# @ARCH_REGISTRY.register()
@ARCH_REGISTRY.register()
## Residual Channel Attention Network (RCAN)
class RCANMod(nn.Module):
class RCANSpatialSELayer(nn.Module):
def __init__(
self,
scale: int = 4,
Expand Down

0 comments on commit da1837d

Please sign in to comment.