Skip to content

Commit

Permalink
feat(doctest): SSIM metric doctest (#2241)
Browse files Browse the repository at this point in the history
  • Loading branch information
ydcjeff authored Oct 15, 2021
1 parent 9468c0a commit a967d87
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
25 changes: 25 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,28 @@ jobs:
- name: make linkcheck
working-directory: ./docs/
run: make linkcheck --jobs 2 SPHINXOPTS="--color -W"

doctest:
if: github.event_name == 'pull_request' || github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7

- run: sudo npm install katex@0.13.0 -g
- uses: actions/cache@v2
with:
path: ~/.cache/pip
key: pip-${{ hashFiles('requirements-dev.txt') }}-${{ hashFiles('docs/requirements.txt') }}

- name: Install docs deps
run: bash .github/workflows/install_docs_deps.sh

- name: make doctest
working-directory: ./docs/
run: |
make html SPHINXOPTS="--color -W"
make doctest
make coverage
13 changes: 13 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,16 @@ def run(self):
("py:class", "torch.optim.lr_scheduler._LRScheduler"),
("py:class", "torch.utils.data.dataloader.DataLoader"),
]

# doctest config
doctest_global_setup = """
import torch
from torch import nn, optim
from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
manual_seed(666)
"""
14 changes: 11 additions & 3 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,22 @@ class SSIM(Metric):
``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need
to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape.
.. code-block:: python
.. testcode::
def process_function(engine, batch):
# ...
y_pred, y = batch
return y_pred, y
engine = Engine(process_function)
metric = SSIM(data_range=1.0)
metric.attach(engine, "ssim")
metric.attach(engine, 'ssim')
preds = torch.rand([4, 3, 16, 16])
target = preds * 0.75
state = engine.run([[preds, target]])
print(state.metrics['ssim'])
.. testoutput::
0.9218971...
.. versionadded:: 0.4.2
"""
Expand Down

0 comments on commit a967d87

Please sign in to comment.