From 6342fe0d14d7444bc9a9e3856ea6339da58b51ea Mon Sep 17 00:00:00 2001 From: Jonas Teuwen Date: Sun, 15 Oct 2023 18:42:44 +0200 Subject: [PATCH] First commit --- .flake8 | 3 + .github/ISSUE_TEMPLATE/bug_report.md | 26 + .github/ISSUE_TEMPLATE/feature_request.md | 20 + .github/PULL_REQUEST_TEMPLATE.md | 1 + .github/workflows/precommit_checks.yml | 50 + .github/workflows/requirements.txt | 3 + .gitignore | 121 +++ .mypy.ini | 37 + .pre-commit-config.yaml | 50 + CONTRIBUTING.rst | 133 +++ LICENSE | 201 ++++ Makefile | 94 ++ README.md | 13 + ahcore/__init__.py | 2 + ahcore/callbacks.py | 950 ++++++++++++++++++ ahcore/cli/__init__.py | 84 ++ ahcore/cli/data.py | 100 ++ ahcore/cli/tiling.py | 436 ++++++++ ahcore/data/__init__.py | 5 + ahcore/data/dataset.py | 273 +++++ ahcore/data/samplers.py | 111 ++ ahcore/entrypoints.py | 274 +++++ ahcore/exceptions.py | 15 + ahcore/lit_module.py | 216 ++++ ahcore/losses.py | 253 +++++ ahcore/metrics/__init__.py | 10 + ahcore/metrics/metrics.py | 399 ++++++++ ahcore/models/__init__.py | 5 + ahcore/py.typed | 0 ahcore/readers.py | 309 ++++++ ahcore/transforms/__init__.py | 6 + ahcore/transforms/augmentations.py | 322 ++++++ ahcore/transforms/pre_transforms.py | 238 +++++ ahcore/utils/__init__.py | 0 ahcore/utils/data.py | 67 ++ ahcore/utils/database_models.py | 204 ++++ ahcore/utils/io.py | 214 ++++ ahcore/utils/manifest.py | 393 ++++++++ ahcore/utils/rois.py | 127 +++ ahcore/utils/types.py | 31 + ahcore/writers.py | 205 ++++ config/augmentations/segmentation.yaml | 11 + config/augmentations/stages/train_seg.yaml | 76 ++ config/augmentations/stages/val_seg.yaml | 11 + config/augmentations/wsi_classification.yaml | 11 + .../compute_wsi_metrics_callback.yaml | 3 + config/callbacks/default.yaml | 41 + config/callbacks/early_stopping.yaml | 17 + config/callbacks/inference.yaml | 16 + config/callbacks/learningrate_monitor.yaml | 2 + config/callbacks/model_checkpoint.yaml | 19 + config/callbacks/model_summary.yaml | 7 + config/callbacks/rich_progress_bar.yaml | 6 + config/callbacks/write_h5_callback.yaml | 3 + config/callbacks/write_tiff_callback.yaml | 2 + config/datamodule/dataset.yaml | 3 + config/debug/default.yaml | 25 + .../example_folder/example_experiment.yaml | 23 + config/hparams_search/optuna.yaml | 18 + config/hydra/default.yaml | 11 + config/inference.yaml | 50 + config/lit_module/debug_lit.yaml | 2 + config/lit_module/example_model.yaml | 6 + .../monai_segmentation/attention_unet.yaml | 26 + .../monai_segmentation/swin_unetr.yaml | 37 + .../lit_module/monai_segmentation/unet.yaml | 29 + config/lit_module/new_unet.yaml | 25 + .../lit_module/resnet50_classification.yaml | 18 + config/lit_module/unet.yaml | 22 + config/logger/mlflow.yaml | 10 + config/logger/tensorboard.yaml | 7 + config/losses/segmentation_ce.yaml | 9 + config/losses/segmentation_tiger.yaml | 15 + config/losses/segmentation_topk.yaml | 9 + config/machine_settings/example.yaml | 13 + config/metrics/segmentation.yaml | 6 + config/paths/default.yaml | 13 + config/pre_transform/segmentation.yaml | 2 + config/pre_transform/tile_classification.yaml | 2 + config/pre_transform/wsi_classification.yaml | 2 + config/task/segmentation.yaml | 9 + config/task/segmentation_inference.yaml | 7 + config/train.yaml | 64 ++ config/trainer/cpu.yaml | 10 + config/trainer/default.yaml | 15 + config/trainer/default_ddp.yaml | 16 + docs/Makefile | 20 + docs/cli.rst | 5 + docs/conf.py | 276 +++++ docs/contributing.rst | 1 + docs/ext/doi_role.py | 51 + docs/index.rst | 26 + docs/make.bat | 36 + pyproject.toml | 45 
+ setup.cfg | 22 + setup.py | 80 ++ tools/.env.example | 5 + tools/README.md | 10 + tools/inference.py | 35 + tools/populate_tcga_db.py | 133 +++ tools/train.py | 35 + 101 files changed, 7510 insertions(+) create mode 100644 .flake8 create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/workflows/precommit_checks.yml create mode 100644 .github/workflows/requirements.txt create mode 100644 .gitignore create mode 100644 .mypy.ini create mode 100644 .pre-commit-config.yaml create mode 100644 CONTRIBUTING.rst create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 ahcore/__init__.py create mode 100644 ahcore/callbacks.py create mode 100644 ahcore/cli/__init__.py create mode 100644 ahcore/cli/data.py create mode 100644 ahcore/cli/tiling.py create mode 100644 ahcore/data/__init__.py create mode 100644 ahcore/data/dataset.py create mode 100644 ahcore/data/samplers.py create mode 100644 ahcore/entrypoints.py create mode 100644 ahcore/exceptions.py create mode 100644 ahcore/lit_module.py create mode 100644 ahcore/losses.py create mode 100644 ahcore/metrics/__init__.py create mode 100644 ahcore/metrics/metrics.py create mode 100644 ahcore/models/__init__.py create mode 100644 ahcore/py.typed create mode 100644 ahcore/readers.py create mode 100644 ahcore/transforms/__init__.py create mode 100644 ahcore/transforms/augmentations.py create mode 100644 ahcore/transforms/pre_transforms.py create mode 100644 ahcore/utils/__init__.py create mode 100644 ahcore/utils/data.py create mode 100644 ahcore/utils/database_models.py create mode 100644 ahcore/utils/io.py create mode 100644 ahcore/utils/manifest.py create mode 100644 ahcore/utils/rois.py create mode 100644 ahcore/utils/types.py create mode 100644 ahcore/writers.py create mode 100644 config/augmentations/segmentation.yaml create mode 100644 config/augmentations/stages/train_seg.yaml create mode 100644 config/augmentations/stages/val_seg.yaml create mode 100644 config/augmentations/wsi_classification.yaml create mode 100644 config/callbacks/compute_wsi_metrics_callback.yaml create mode 100644 config/callbacks/default.yaml create mode 100644 config/callbacks/early_stopping.yaml create mode 100644 config/callbacks/inference.yaml create mode 100644 config/callbacks/learningrate_monitor.yaml create mode 100644 config/callbacks/model_checkpoint.yaml create mode 100644 config/callbacks/model_summary.yaml create mode 100644 config/callbacks/rich_progress_bar.yaml create mode 100644 config/callbacks/write_h5_callback.yaml create mode 100644 config/callbacks/write_tiff_callback.yaml create mode 100644 config/datamodule/dataset.yaml create mode 100644 config/debug/default.yaml create mode 100644 config/experiment/example_folder/example_experiment.yaml create mode 100644 config/hparams_search/optuna.yaml create mode 100644 config/hydra/default.yaml create mode 100644 config/inference.yaml create mode 100644 config/lit_module/debug_lit.yaml create mode 100644 config/lit_module/example_model.yaml create mode 100644 config/lit_module/monai_segmentation/attention_unet.yaml create mode 100644 config/lit_module/monai_segmentation/swin_unetr.yaml create mode 100644 config/lit_module/monai_segmentation/unet.yaml create mode 100644 config/lit_module/new_unet.yaml create mode 100644 config/lit_module/resnet50_classification.yaml create mode 100644 config/lit_module/unet.yaml create mode 
100644 config/logger/mlflow.yaml create mode 100644 config/logger/tensorboard.yaml create mode 100644 config/losses/segmentation_ce.yaml create mode 100644 config/losses/segmentation_tiger.yaml create mode 100644 config/losses/segmentation_topk.yaml create mode 100644 config/machine_settings/example.yaml create mode 100644 config/metrics/segmentation.yaml create mode 100644 config/paths/default.yaml create mode 100644 config/pre_transform/segmentation.yaml create mode 100644 config/pre_transform/tile_classification.yaml create mode 100644 config/pre_transform/wsi_classification.yaml create mode 100644 config/task/segmentation.yaml create mode 100644 config/task/segmentation_inference.yaml create mode 100644 config/train.yaml create mode 100644 config/trainer/cpu.yaml create mode 100644 config/trainer/default.yaml create mode 100644 config/trainer/default_ddp.yaml create mode 100644 docs/Makefile create mode 100644 docs/cli.rst create mode 100755 docs/conf.py create mode 100644 docs/contributing.rst create mode 100644 docs/ext/doi_role.py create mode 100644 docs/index.rst create mode 100644 docs/make.bat create mode 100644 pyproject.toml create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tools/.env.example create mode 100644 tools/README.md create mode 100644 tools/inference.py create mode 100644 tools/populate_tcga_db.py create mode 100644 tools/train.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..8332de3 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +extend-ignore = E203 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..2fff59f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,26 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +A clear and concise description on how to reproduce the problem. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Environment** +dlup version: +How installed: +Python version: +Operating System: + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..11fc491 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. 
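A note on the `.flake8` file above: `extend-ignore = E203` is the usual companion setting to black. For slices whose bounds are expressions, black puts a space before the `:`, which flake8 otherwise reports as "E203 whitespace before ':'". A minimal sketch (hypothetical values):

    ham = list(range(10))
    lower, upper, offset = 1, 8, 1
    # black's preferred formatting of a complex slice; flake8 would flag the
    # space before ':' as E203 if it were not ignored.
    sliced = ham[lower + offset : upper + offset]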
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..676b18c --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1 @@ +Fixes #{issue number} diff --git a/.github/workflows/precommit_checks.yml b/.github/workflows/precommit_checks.yml new file mode 100644 index 0000000..aacfb2c --- /dev/null +++ b/.github/workflows/precommit_checks.yml @@ -0,0 +1,50 @@ +# This action is triggered when a pull request review is submitted. +# It will run precommit checks only when the pull request is approved. +# Reference: https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#running-a-workflow-when-a-pull-request-is-approved + +name: Run Precommit Checks on PR Approval + +on: + workflow_dispatch: + push: + branches: + - main + pull_request_review: + types: [submitted] + +jobs: + # This job runs the precommit checks on the changed files. + run_precommit_checks: + + # Ensures the job only runs when the PR review state is "approved". + if: github.event.review.state == 'approved' + runs-on: ubuntu-22.04 + + steps: + # Checkout the repository to the GitHub Actions runner. + - name: Checkout Repository + uses: actions/checkout@v4 + + # Set up the desired Python environment. + - name: Setup Python Environment + uses: actions/setup-python@v4 + with: + python-version: '3.10.13' + cache: 'pip' + + # Install the required dependencies for local repos in the precommit hooks + - name: Install Dependencies + run: pip install -r .github/workflows/requirements.txt + + # Determine which python files have changed in the PR. + - name: Get Changed Files + id: changed-files + uses: tj-actions/changed-files@v36 + with: + files: '*.py' + + # Run the precommit hooks only on the changed files. + - name: Execute Precommit Hooks on Changed Files in PR + uses: pre-commit/action@v3.0.0 + with: + extra_args: --files ${{ steps.changed-files.outputs.all_changed_files }} diff --git a/.github/workflows/requirements.txt b/.github/workflows/requirements.txt new file mode 100644 index 0000000..d6df21f --- /dev/null +++ b/.github/workflows/requirements.txt @@ -0,0 +1,3 @@ +pylint==2.17.5 +mypy==1.5.1 +numpy==1.25.2 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1e2c5b1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,121 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# CMake +CMakeCache.txt +CMakeFiles + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# dotenv
+.env
+
+# virtualenv
+.venv
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# IDE settings
+.vscode/
+.idea/
+
+# Output of tests
+dlup/preprocessors/tests/data/test_output
+
+# OS files
+.DS_Store
+
+
+tools/logs
+config/machine_settings/*
+!config/machine_settings/example.yaml
diff --git a/.mypy.ini b/.mypy.ini
new file mode 100644
index 0000000..159fa83
--- /dev/null
+++ b/.mypy.ini
@@ -0,0 +1,37 @@
+[mypy]
+plugins = numpy.typing.mypy_plugin
+
+[mypy-h5py.*]
+ignore_missing_imports = True
+
+[mypy-pytorch_lightning.*]
+ignore_missing_imports = True
+
+[mypy-torchvision.*]
+ignore_missing_imports = True
+
+[mypy-torchmetrics.*]
+ignore_missing_imports = True
+
+[mypy-scipy.ndimage.*]
+ignore_missing_imports = True
+
+[mypy-kornia.*]
+ignore_missing_imports = True
+
+[mypy-shapely.*]
+ignore_missing_imports = True
+
+[mypy-tqdm.*]
+ignore_missing_imports = True
+
+[mypy-cv2.*]
+ignore_missing_imports = True
+
+; shouldn't this be in types-Pillow?
+[mypy-PIL.*]
+ignore_missing_imports = True
+
+# TODO: This needs to be fixed obviously
+[mypy-dlup.*]
+ignore_errors = True
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..1782942
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,50 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v3.2.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-yaml
+  - id: check-added-large-files
+- repo: https://github.com/psf/black
+  rev: 23.7.0
+  hooks:
+  - id: black
+- repo: https://github.com/pycqa/flake8
+  rev: 6.1.0
+  # Ignore the configuration files
+  hooks:
+  - id: flake8
+    exclude: ^docs/
+- repo: https://github.com/pycqa/isort
+  rev: 5.12.0
+  hooks:
+  - id: isort
+    name: isort (python)
+- repo: local  # Use mypy from the local environment as it needs to import packages
+  hooks:
+  - id: mypy
+    name: mypy
+    entry: mypy
+    language: system
+    types: [python]
+    args: ["ahcore", "--strict"]
+- repo: https://github.com/Yelp/detect-secrets
+  rev: v1.4.0
+  hooks:
+  - id: detect-secrets
+- repo: local  # Use pylint from the local environment as it needs to import packages
+  hooks:
+  - id: pylint
+    name: pylint
+    entry: pylint
+    language: system
+    types: [python]
+    args:
+      [
+        "-rn",  # Only display messages
+        "-sn",  # Don't display the score
+        "--errors-only"  # Only show the errors
+      ]
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
new file mode 100644
index 0000000..a1f79ad
--- /dev/null
+++ b/CONTRIBUTING.rst
@@ -0,0 +1,133 @@
+Contributing
+============
+
+Contributions are welcome, and they are greatly appreciated!
+
+You can contribute in many ways:
+
+Types of Contributions
+----------------------
+Report Bugs
+###########
+Report bugs by filing an `issue`_.
+
+If you are reporting a bug, please include:
+
+* Your operating system name and version.
+* Any details about your local setup that might be helpful in troubleshooting.
+* Detailed steps to reproduce the bug.
+
+Fix Bugs
+########
+Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
+wanted" is open to whoever wants to implement it.
+
+Implement Features
+##################
+Look through the GitHub issues for features. Anything tagged with "enhancement"
+and "help wanted" is open to whoever wants to implement it.
+
+Write Documentation
+###################
+Ahcore could always use more documentation, whether as part of the
+official ahcore docs, in docstrings, or even on the web in blog posts,
+articles, and such.
+
+Submit Feedback
+###############
+The best way to send feedback is to file an `issue`_.
+
+If you are proposing a feature:
+
+* Explain in detail how it would work.
+* Keep the scope as narrow as possible, to make it easier to implement.
+* Remember that this is a volunteer-driven project, and that contributions
+  are welcome :)
+
+Get Started!
+############
+Ready to contribute? Here's how to set up `ahcore` for local development.
+
+1. Fork the `ahcore` repo on GitHub.
+2. Clone your fork locally:
+
+.. code-block:: console
+
+    $ git clone git@github.com:your_name_here/ahcore.git
+
+3. Create a virtual environment either through conda, docker or virtualenv.
+
+4. Install git pre-commit hooks:
+
+   - Install `pre-commit`_.
+   - Install pre-commit hooks: :code:`pre-commit install`.
+
+5. Install your local copy into a virtual environment:
+
+.. code-block:: console
+
+    $ cd ahcore/
+    $ pip install --editable ".[dev]"
+
+6. Create a branch for local development:
+
+.. code-block:: console
+
+    $ git checkout -b name-of-your-bugfix-or-feature
+
+Now you can make your changes locally.
+
+7. When you're done making changes, check that your changes pass the tests and
+   the pre-commit hooks, including testing other Python versions with tox:
+
+.. code-block:: console
+
+    $ make test
+    $ tox
+    $ pre-commit
+
+To get pytest and tox, just pip install them.
+
+8. Commit your changes and push your branch to GitHub:
+
+.. code-block:: console
+
+    $ git add .
+    $ git commit -m "Your detailed description of your changes."
+    $ git push origin name-of-your-bugfix-or-feature
+
+9. Submit a pull request through the GitHub website.
+
+Pull Request Guidelines
+#######################
+Before you submit a pull request, check that it meets these guidelines:
+
+1. The pull request should include tests.
+2. If the pull request adds functionality, the docs should be updated. Put
+   your new functionality into a function with a docstring, add the
+   feature to the list in README.md or elsewhere in the documentation,
+   and make sure that the tests pass for all supported Python versions.
+
+Tips
+####
+To run a subset of tests:
+
+.. code-block:: console
+
+    $ pytest tests.test_ahcore
+
+Deploying
+#########
+A reminder for the maintainers on how to deploy.
+Make sure all your changes are committed. Then run:
+
+.. code-block:: console
+
+    $ bump2version patch  # possible: major / minor / patch
+    $ git push
+
+
+.. _pre-commit: https://pre-commit.com/
+.. _GitHub repository: https://github.com/NKI-AI/ahcore
+..
_issue: https://github.com/NKI-AI/ahcore/issues diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..0586611
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,94 @@
+.PHONY: clean clean-test clean-pyc clean-build docs help
+.DEFAULT_GOAL := help
+
+define BROWSER_PYSCRIPT
+import os, webbrowser, sys
+
+from urllib.request import pathname2url
+
+webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
+endef
+export BROWSER_PYSCRIPT
+
+define PRINT_HELP_PYSCRIPT
+import re, sys
+
+for line in sys.stdin:
+    match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
+    if match:
+        target, help = match.groups()
+        print("%-20s %s" % (target, help))
+endef
+export PRINT_HELP_PYSCRIPT
+
+BROWSER := python -c "$$BROWSER_PYSCRIPT"
+
+help:
+	@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
+
+clean: clean-build clean-pyc clean-test clean-docs ## remove all build, test, coverage, docs and Python artifacts
+
+clean-build: ## remove build artifacts
+	rm -fr build/
+	rm -fr dist/
+	rm -fr .eggs/
+	find . -name '*.egg-info' -exec rm -fr {} +
+	find . -name '*.egg' -exec rm -f {} +
+
+clean-pyc: ## remove Python file artifacts
+	find . -name '*.pyc' -exec rm -f {} +
+	find . -name '*.pyo' -exec rm -f {} +
+	find . -name '*~' -exec rm -f {} +
+	find . -name '__pycache__' -exec rm -fr {} +
+
+clean-test: ## remove test and coverage artifacts
+	rm -fr .tox/
+	rm -f .coverage
+	rm -fr htmlcov/
+	rm -fr .pytest_cache
+
+clean-docs: ## clean sphinx docs
+	rm -f docs/ahcore.rst
+	rm -f docs/modules.rst
+	rm -f docs/ahcore.*.rst
+	rm -rf docs/_build
+
+lint: ## check style with flake8
+	flake8 ahcore tests
+
+test: ## run tests quickly with the default Python
+	pytest
+
+test-all: ## run tests on every Python version with tox
+	tox
+
+coverage: ## check code coverage quickly with the default Python
+	coverage run --source ahcore -m pytest
+	coverage report -m
+	coverage html
+	$(BROWSER) htmlcov/index.html
+
+docs: clean-docs ## generate Sphinx HTML documentation, including API docs
+	sphinx-apidoc -o docs/ ahcore
+	$(MAKE) -C docs clean
+	$(MAKE) -C docs html
+
+viewdocs: ## open the built docs in a browser
+	$(BROWSER) docs/_build/html/index.html
+
+uploaddocs: docs ## upload the built docs to the documentation server
+	rsync -avh docs/_build/html/ docs@aiforoncology.nl:/var/www/html/ahcore --delete
+
+servedocs: docs ## compile the docs watching for changes
+	watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .
+
+release: dist ## package and upload a release
+	twine upload dist/*
+
+dist: clean ## builds source and wheel package
+	python setup.py sdist
+	python setup.py bdist_wheel
+	ls -l dist
+
+install: clean ## install the package to the active Python's site-packages
+	python setup.py install
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..8fa547c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+# AI for Oncology Core for Computational Pathology
+[![Pre-commit checks](https://github.com/NKI-AI/ahcore/actions/workflows/precommit_checks.yml/badge.svg)](https://github.com/NKI-AI/ahcore/actions/workflows/precommit_checks.yml)
+[![codecov](https://codecov.io/gh/NKI-AI/ahcore/branch/main/graph/badge.svg?token=OIJ7F9G7OO)](https://codecov.io/gh/NKI-AI/ahcore)
+
+Ahcore comprises the [AI for Oncology](https://aiforoncology.nl) core components for computational pathology.
+It provides a set of tools for working with pathology images and annotations.
+It also offers standard computational pathology algorithms.
+
+Check the [full documentation](https://docs.aiforoncology.nl/ahcore) for more details on how to use ahcore.
+
+
+# License and usage
+Ahcore is not intended for clinical use. It is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
diff --git a/ahcore/__init__.py b/ahcore/__init__.py
new file mode 100644
index 0000000..b4dd3b4
--- /dev/null
+++ b/ahcore/__init__.py
@@ -0,0 +1,2 @@
+"""Main ahcore module"""
+__version__ = "0.1"
diff --git a/ahcore/callbacks.py b/ahcore/callbacks.py
new file mode 100644
index 0000000..2b111a3
--- /dev/null
+++ b/ahcore/callbacks.py
@@ -0,0 +1,950 @@
+"""Ahcore's callbacks"""
+
+from __future__ import annotations
+
+import hashlib
+import itertools
+import json
+import logging
+import multiprocessing
+import time
+from collections import namedtuple
+from multiprocessing import Pipe, Process, Queue, Semaphore
+from multiprocessing.connection import Connection
+from multiprocessing.pool import Pool
+from pathlib import Path
+from typing import Any, Callable, Generator, Iterator, Optional, TypedDict, cast
+
+import numpy as np
+import numpy.typing as npt
+import pytorch_lightning as pl
+import torch
+from dlup import SlideImage
+from dlup._image import Resampling
+from dlup.annotations import WsiAnnotations
+from dlup.data.dataset import ConcatDataset, TiledWsiDataset
+from dlup.data.transforms import convert_annotations, rename_labels
+from dlup.tiling import Grid, GridOrder, TilingMode
+from dlup.writers import TiffCompression, TifffileImageWriter
+from pytorch_lightning.callbacks import Callback
+from shapely.geometry import MultiPoint, Point
+from torch.utils.data import Dataset
+
+from ahcore.lit_module import AhCoreLightningModule
+from ahcore.metrics import WSIMetricFactory
+from ahcore.readers import H5FileImageReader, StitchingMode
+from ahcore.transforms.pre_transforms import one_hot_encoding
+from ahcore.utils.data import DataDescription, GridDescription
+from ahcore.utils.io import get_logger
+from ahcore.utils.manifest import DataManager, ImageMetadata, fetch_image_metadata, get_mask_and_annotations_from_record
+from ahcore.utils.types import DlupDatasetSample, GenericArray
+from ahcore.writers import H5FileImageWriter
+
+logger = get_logger(__name__)
+
+logging.getLogger("pyvips").setLevel(logging.ERROR)
+
+
+class _ValidationDataset(Dataset[DlupDatasetSample]):
+    """Helper dataset to compute the validation metrics."""
+
+    def __init__(
+        self,
+        data_description: Optional[DataDescription],
+        native_mpp: float,
+        reader: H5FileImageReader,
+        annotations: Optional[WsiAnnotations] = None,
+        mask: Optional[WsiAnnotations] = None,
+        region_size: tuple[int, int] = (1024, 1024),
+    ):
+        """
+        Parameters
+        ----------
+        data_description : DataDescription
+        native_mpp : float
+            The actual mpp of the underlying image.
+        reader : H5FileImageReader
+        annotations : WsiAnnotations
+        mask : WsiAnnotations
+        region_size : Tuple[int, int]
+            The region size to use to split up the image into regions.
+        """
+        super().__init__()
+        self._data_description = data_description
+        self._native_mpp = native_mpp
+        self._scaling = self._native_mpp / reader.mpp
+        self._reader = reader
+        self._region_size = region_size
+        self._logger = get_logger(type(self).__name__)
+
+        self._annotations = self._validate_annotations(annotations)
+        self._mask = self._validate_annotations(mask)
+
+        self._grid = Grid.from_tiling(
+            (0, 0),
+            reader.size,
+            tile_size=self._region_size,
+            tile_overlap=(0, 0),
+            mode=TilingMode.overflow,
+            order=GridOrder.C,
+        )
+
+        self._regions = self._generate_regions()
+        self._logger.debug(f"Number of validation regions: {len(self._regions)}")
+
+    def _validate_annotations(self, annotations: Optional[WsiAnnotations]) -> Optional[WsiAnnotations]:
+        if annotations is None:
+            return None
+
+        if isinstance(annotations, WsiAnnotations):
+            if self._data_description is None:
+                raise ValueError(
+                    "Annotations as a `WsiAnnotations` class are provided but no data description is given. "
+                    "This is required to map the labels to indices."
+                )
+        elif isinstance(annotations, SlideImage):
+            pass  # We do not need a specific test for this
+        else:
+            raise NotImplementedError(f"Annotations of type {type(annotations)} are not supported.")
+
+        return annotations
+
+    def _generate_regions(self) -> list[tuple[int, int]]:
+        """Generate the regions to use. These regions are filtered grid cells where there is a mask.
+
+        Returns
+        -------
+        List[Tuple[int, int]]
+            The list of regions.
+        """
+        regions = []
+        for coordinates in self._grid:
+            _coordinates = (coordinates[0], coordinates[1])
+            if self._mask is None or self._is_masked(_coordinates):
+                regions.append(_coordinates)
+        return regions
+
+    def _is_masked(self, coordinates: tuple[int, int]) -> bool:
+        """Check if the region is masked. This works with any masking function that supports a `read_region` method or
+        returns a list of annotations with an `area` attribute. In case there are elements of the form `Point` in the
+        annotation list, these are also added.
+
+        Parameters
+        ----------
+        coordinates : Tuple[int, int]
+            The coordinates of the region to check.
+
+        Returns
+        -------
+        bool
+            True if the region is masked, False otherwise. Will also return True when there is no mask.
+        """
+        if self._mask is None:
+            return True
+
+        region_mask = self._mask.read_region(coordinates, self._scaling, self._region_size)
+
+        if isinstance(region_mask, np.ndarray):
+            return region_mask.sum() > 0
+
+        # We check if the region is not a Point, otherwise this annotation is always included.
+        # Else, we compute if there is a positive area in the region.
+        return bool(sum(_.area if not isinstance(_, (Point, MultiPoint)) else 1.0 for _ in region_mask) > 0)
+
+    def __getitem__(self, idx: int) -> dict[str, Any]:
+        sample = {}
+        coordinates = self._regions[idx]
+
+        sample["prediction"] = self._get_h5_region(coordinates)
+
+        if self._annotations is not None:
+            target, roi = self._get_annotation_data(coordinates)
+            if roi is not None:
+                sample["roi"] = roi.astype(np.uint8)
+            sample["target"] = target
+
+        return sample
+
+    def _get_h5_region(self, coordinates: tuple[int, int]) -> npt.NDArray[np.uint8 | np.uint16 | np.float32 | np.bool_]:
+        x, y = coordinates
+        width, height = self._region_size
+
+        if x + width > self._reader.size[0] or y + height > self._reader.size[1]:
+            region = self._read_and_pad_region(coordinates)
+        else:
+            region = self._reader.read_region_raw(coordinates, self._region_size)
+        return region
+
+    def _read_and_pad_region(self, coordinates: tuple[int, int]) -> npt.NDArray[Any]:
+        x, y = coordinates
+        width, height = self._region_size
+        new_width = min(width, self._reader.size[0] - x)
+        new_height = min(height, self._reader.size[1] - y)
+        clipped_region = self._reader.read_region_raw((x, y), (new_width, new_height))
+
+        prediction = np.zeros((clipped_region.shape[0], *self._region_size), dtype=clipped_region.dtype)
+        prediction[:, :new_height, :new_width] = clipped_region
+        return prediction
+
+    def _get_annotation_data(
+        self, coordinates: tuple[int, int]
+    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.int_] | None]:
+        if not self._annotations:
+            raise ValueError("No annotations are provided.")
+
+        if not self._data_description:
+            raise ValueError("No data description is provided.")
+
+        if not self._data_description.remap_labels:
+            raise ValueError("Remap labels are not provided.")
+
+        if not self._data_description.index_map:
+            raise ValueError("Index map is not provided.")
+
+        _annotations = self._annotations.read_region(coordinates, self._scaling, self._region_size)
+        _annotations = rename_labels(_annotations, remap_labels=self._data_description.remap_labels)
+
+        points, boxes, region, roi = convert_annotations(
+            _annotations,
+            self._region_size,
+            index_map=self._data_description.index_map,
+            roi_name="roi",
+        )
+        encoded_region = one_hot_encoding(index_map=self._data_description.index_map, mask=region)
+        if roi is not None:
+            return encoded_region, roi[np.newaxis, ...]
+        return encoded_region, None
+
+    def __iter__(self) -> Iterator[dict[str, Any]]:
+        for idx in range(len(self)):
+            yield self[idx]
+
+    def __len__(self) -> int:
+        return len(self._regions)
+
+
+class _WriterMessage(TypedDict):
+    queue: Queue[Optional[tuple[GenericArray, GenericArray]]]  # pylint: disable=unsubscriptable-object
+    writer: H5FileImageWriter
+    process: Process
+    connection: Connection
+
+
+def _get_uuid_for_filename(input_path: Path) -> str:
+    """Get a unique filename for the given input path. This is done by hashing the absolute path of the file.
+    This is required because we cannot assume any input format. We hash the complete input path.
+
+    Parameters
+    ----------
+    input_path : Path
+        The input path to hash.
+
+    Returns
+    -------
+    str
+        The hashed filename.
+ """ + # Get the absolute path of the file + input_path = Path(input_path).resolve() + + # Create a SHA256 hash of the file path + hash_object = hashlib.sha256(str(input_path).encode()) + hex_dig = hash_object.hexdigest() + return hex_dig + + +def _get_h5_output_filename(dump_dir: Path, input_path: Path, model_name: str, step: None | int | str = None) -> Path: + hex_dig = _get_uuid_for_filename(input_path=input_path) + + # Return the hashed filename with the new extension + if step is not None: + return dump_dir / "outputs" / model_name / f"step_{step}" / f"{hex_dig}.h5" + return dump_dir / "outputs" / model_name / f"{hex_dig}.h5" + + +class WriteH5Callback(Callback): + def __init__(self, max_queue_size: int, max_concurrent_writers: int, dump_dir: Path): + """ + Callback to write predictions to H5 files. This callback is used to write whole-slide predictions to single H5 + files in a separate thread. + + TODO: + - Add support for distributed data parallel + + Parameters + ---------- + max_queue_size : int + The maximum number of items to store in the queue (i.e. tiles). + max_concurrent_writers : int + The maximum number of concurrent writers. + dump_dir : pathlib.Path + The directory to dump the H5 files to. + """ + super().__init__() + self._writers: dict[str, _WriterMessage] = {} + self._current_filename = None + self._dump_dir = Path(dump_dir) + self._max_queue_size = max_queue_size + self._semaphore = Semaphore(max_concurrent_writers) + self._dataset_index = 0 + + self._logger = get_logger(type(self).__name__) + + @property + def dump_dir(self) -> Path: + return self._dump_dir + + def __process_management(self) -> None: + """ + Handle the graceful termination of multiple processes at the end of h5 writing. + This block ensures proper release of resources allocated during multiprocessing. + + Returns + ------- + None + """ + assert self._current_filename, "_current_filename shouldn't be None here" + + self._writers[self._current_filename]["queue"].put(None) + self._writers[self._current_filename]["process"].join() + self._writers[self._current_filename]["process"].close() + self._writers[self._current_filename]["queue"].close() + + @property + def writers(self) -> dict[str, _WriterMessage]: + return self._writers + + def _batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + stage: str, + dataloader_idx: int = 0, + ) -> None: + filename = batch["path"][0] # Filenames are constant across the batch. + if any([filename != path for path in batch["path"]]): + raise ValueError( + "All paths in a batch must be the same. " + "Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler." 
+ ) + + if filename != self._current_filename: + output_filename = _get_h5_output_filename( + self.dump_dir, + filename, + model_name=str(pl_module.name), + step=pl_module.global_step, + ) + output_filename.parent.mkdir(parents=True, exist_ok=True) + link_fn = ( + self.dump_dir / "outputs" / f"{pl_module.name}" / f"step_{pl_module.global_step}" / "image_h5_link.txt" + ) + with open(link_fn, "a" if link_fn.is_file() else "w") as file: + file.write(f"{filename},{output_filename}\n") + + self._logger.debug("%s -> %s", filename, output_filename) + if self._current_filename is not None: + self.__process_management() + self._semaphore.release() + + self._semaphore.acquire() + + if stage == "validate": + total_dataset: ConcatDataset = trainer.datamodule.validate_dataset # type: ignore + elif stage == "predict": + total_dataset: ConcatDataset = trainer.predict_dataloaders.dataset # type: ignore + else: + raise NotImplementedError(f"Stage {stage} is not supported for {self.__class__.__name__}.") + + current_dataset: TiledWsiDataset + current_dataset, _ = total_dataset.index_to_dataset(self._dataset_index) # type: ignore + slide_image = current_dataset.slide_image + + data_description: DataDescription = pl_module.data_description # type: ignore + inference_grid: GridDescription = data_description.inference_grid + + mpp = inference_grid.mpp + if mpp is None: + mpp = slide_image.mpp + + size = slide_image.get_scaled_size(slide_image.get_scaling(mpp)) + num_samples = len(current_dataset) + + # Let's get the data_description, so we can figure out the tile size and things like that + tile_size = inference_grid.tile_size + tile_overlap = inference_grid.tile_overlap + + # TODO: We are really putting strange things in the Queue if we may believe mypy + new_queue: Queue[Any] = Queue() + parent_conn, child_conn = Pipe() + new_writer = H5FileImageWriter( + output_filename, + size=size, + mpp=mpp, + tile_size=tile_size, + tile_overlap=tile_overlap, + num_samples=num_samples, + progress=None, + ) + new_process = Process(target=new_writer.consume, args=(self.generator(new_queue), child_conn)) + new_process.start() + self._writers[filename] = { + "queue": new_queue, + "writer": new_writer, + "process": new_process, + "connection": parent_conn, + } + self._current_filename = filename + + prediction = outputs["prediction"].detach().cpu().numpy() + coordinates_x, coordinates_y = batch["coordinates"] + coordinates = torch.stack([coordinates_x, coordinates_y]).T.detach().cpu().numpy() + self._writers[filename]["queue"].put((coordinates, prediction)) + self._dataset_index += prediction.shape[0] + + def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if self._current_filename is not None: + self.__process_management() + self._semaphore.release() + self._dataset_index = 0 + # Reset current filename to None for correct execution of subsequent validation loop + self._current_filename = None + # Clear all the writers from the current epoch + self._writers = {} + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, "validate", dataloader_idx) + + def on_predict_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, "predict", 
dataloader_idx) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + + def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + + @staticmethod + def generator( + queue: Queue[Optional[GenericArray]], # pylint: disable=unsubscriptable-object + ) -> Generator[GenericArray, None, None]: + while True: + batch = queue.get() + if batch is None: + break + yield batch + + +# Separate because this cannot be pickled. +def _generator_from_reader( + h5_reader: H5FileImageReader, + tile_size: tuple[int, int], + tile_process_function: Callable[[GenericArray], GenericArray], +) -> Generator[GenericArray, None, None]: + validation_dataset = _ValidationDataset( + data_description=None, + native_mpp=h5_reader.mpp, + reader=h5_reader, + annotations=None, + mask=None, + region_size=(1024, 1024), + ) + + for sample in validation_dataset: + region = sample["prediction"] + yield region if tile_process_function is None else tile_process_function(region) + + +def _write_tiff( + filename: Path, + tile_size: tuple[int, int], + tile_process_function: Callable[[GenericArray], GenericArray], + generator_from_reader: Callable[ + [H5FileImageReader, tuple[int, int], Callable[[GenericArray], GenericArray]], + Iterator[npt.NDArray[np.int_]], + ], +) -> None: + logger.debug("Writing TIFF %s", filename.with_suffix(".tiff")) + with H5FileImageReader(filename, stitching_mode=StitchingMode.CROP) as h5_reader: + writer = TifffileImageWriter( + filename.with_suffix(".tiff"), + size=h5_reader.size, + mpp=h5_reader.mpp, + tile_size=tile_size, + pyramid=True, + compression=TiffCompression.JPEG, + quality=100, + interpolator=Resampling.NEAREST, + ) + writer.from_tiles_iterator(generator_from_reader(h5_reader, tile_size, tile_process_function)) + + +def tile_process_function(x: GenericArray) -> GenericArray: + return np.asarray(np.argmax(x, axis=0).astype(np.uint8)) + + +class WriteTiffCallback(Callback): + def __init__(self, max_concurrent_writers: int, tile_size: tuple[int, int] = (1024, 1024)): + self._pool = multiprocessing.Pool(max_concurrent_writers) + self._logger = get_logger(type(self).__name__) + self._dump_dir: Optional[Path] = None + self.__write_h5_callback_index = -1 + + self._model_name: str | None = None + self._tile_size = tile_size + + # TODO: Handle tile operation such that we avoid repetitions. + + self._tile_process_function = tile_process_function # function that is applied to the tile. 
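+        # For intuition: with the function assigned above, a prediction tile of
+        # shape (num_classes, height, width) is collapsed to a (height, width)
+        # uint8 label map via an argmax over the class axis, e.g. (hypothetical
+        # shapes):
+        #
+        #     tile = np.random.rand(4, 512, 512).astype(np.float32)
+        #     labels = tile_process_function(tile)  # shape (512, 512), uint8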
+ self._filenames: dict[Path, Path] = {} # This has all the h5 files + + @property + def dump_dir(self) -> Optional[Path]: + return self._dump_dir + + def _validate_parameters(self) -> None: + dump_dir = self._dump_dir + if not dump_dir: + raise ValueError("Dump directory is not set.") + + def setup( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + stage: Optional[str] = None, + ) -> None: + if not isinstance(pl_module, AhCoreLightningModule): + # TODO: Make a AhCoreCallback with these features + raise ValueError("AhCoreLightningModule required for WriteTiffCallback.") + + self._model_name = pl_module.name + + _callback: Optional[WriteH5Callback] = None + for idx, callback in enumerate(trainer.callbacks): # type: ignore + if isinstance(callback, WriteH5Callback): + _callback = cast(WriteH5Callback, trainer.callbacks[idx]) # type: ignore + break + if _callback is None: + raise ValueError("WriteH5Callback required before tiff images can be written using this Callback.") + + # This is needed for mypy + assert _callback, "_callback should never be None after the setup." + assert _callback.dump_dir, "_callback.dump_dir should never be None after the setup." + self._dump_dir = _callback.dump_dir + + def _batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + assert self.dump_dir, "dump_dir should never be None here." + + filename = Path(batch["path"][0]) # Filenames are constant across the batch. + if filename not in self._filenames: + output_filename = _get_h5_output_filename( + dump_dir=self.dump_dir, + input_path=filename, + model_name=str(pl_module.name), + step=pl_module.global_step, + ) + self._filenames[filename] = output_filename + + def _epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + assert self.dump_dir, "dump_dir should never be None here." + self._logger.info("Writing TIFF files to %s", self.dump_dir / "outputs" / f"{pl_module.name}") + results = [] + for image_filename, h5_filename in self._filenames.items(): + self._logger.debug( + "Writing image output %s to %s", + Path(image_filename), + Path(image_filename).with_suffix(".tiff"), + ) + output_path = self.dump_dir / "outputs" / f"{pl_module.name}" / f"step_{pl_module.global_step}" + with open(output_path / "image_tiff_link.txt", "a") as file: + file.write(f"{image_filename},{h5_filename.with_suffix('.tiff')}\n") + if not h5_filename.exists(): + self._logger.warning("H5 file %s does not exist. Skipping", h5_filename) + continue + + result = self._pool.apply_async( + _write_tiff, + ( + h5_filename, + self._tile_size, + self._tile_process_function, + _generator_from_reader, + ), + ) + results.append(result) + + for result in results: + result.get() # Wait for the process to complete. 
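+        # Note that `AsyncResult.get()` also re-raises any exception thrown by
+        # `_write_tiff` inside the worker process, so a failed TIFF write
+        # surfaces here instead of being silently dropped.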
+ self._filenames = {} # Reset the filenames + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_predict_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self._batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + + def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._epoch_end(trainer, pl_module) + + +# Create a data structure to hold all required information for each task +TaskData = namedtuple("TaskData", ["filename", "h5_filename", "metadata", "mask", "annotations"]) + + +def prepare_task_data( + filename: Path, + dump_dir: Path, + pl_module: pl.LightningModule, + data_description: DataDescription, + data_manager: DataManager, +) -> TaskData: + h5_filename = _get_h5_output_filename( + dump_dir=dump_dir, + input_path=data_description.data_dir / filename, + model_name=str(pl_module.name), + step=pl_module.global_step, + ) + image = data_manager.get_image_by_filename(str(filename)) + metadata = fetch_image_metadata(image) + mask, annotations = get_mask_and_annotations_from_record(data_description.annotations_dir, image) + + return TaskData(filename, h5_filename, metadata, mask, annotations) + + +def compute_metrics_for_case( + task_data: TaskData, + class_names: dict[int, str], + data_description: DataDescription, + wsi_metrics: WSIMetricFactory, + save_per_image: bool, +) -> list[dict[str, Any]]: + # Extract the data from the namedtuple + filename, h5_filename, metadata, mask, annotations = task_data + + dump_list = [] + + logger.info("Computing metrics for %s", filename) + + with H5FileImageReader(h5_filename, stitching_mode=StitchingMode.CROP) as h5reader: + dataset_of_validation_image = _ValidationDataset( + data_description=data_description, + native_mpp=metadata.mpp, + mask=mask, + annotations=annotations, + reader=h5reader, + ) + for sample in dataset_of_validation_image: + prediction = torch.from_numpy(sample["prediction"]).unsqueeze(0).float() + target = torch.from_numpy(sample["target"]).unsqueeze(0) + roi = torch.from_numpy(sample["roi"]).unsqueeze(0) + + wsi_metrics.process_batch( + predictions=prediction, + target=target, + roi=roi, + wsi_name=str(filename), + ) + if save_per_image: + wsi_metrics_dictionary = { + "image_fn": str(data_description.data_dir / metadata.filename), + "uuid": filename.stem, + } + + # TODO: These need to be removed, this is really weird. + if filename.with_suffix(".tiff").is_file(): + wsi_metrics_dictionary["tiff_fn"] = str(filename.with_suffix(".tiff")) + if filename.is_file(): + wsi_metrics_dictionary["h5_fn"] = str(filename) + for metric in wsi_metrics._metrics: + metric.get_wsi_score(str(filename)) + wsi_metrics_dictionary[metric.name] = { + class_names[class_idx]: metric.wsis[str(filename)][class_idx][metric.name].item() + for class_idx in range(data_description.num_classes) + } + dump_list.append(wsi_metrics_dictionary) + + return dump_list + + +# Adjusted stand-alone function. 
+def schedule_task(
+    task_data: TaskData,
+    pool: Pool,
+    results_dict: dict[Any, str],  # Any because it will be a multiprocessing.pool.AsyncResult
+    class_names: dict[int, str],
+    data_description: DataDescription,
+    wsi_metrics: WSIMetricFactory,
+    save_per_image: bool,
+) -> None:
+    result = pool.apply_async(
+        compute_metrics_for_case,
+        args=(task_data, class_names, data_description, wsi_metrics, save_per_image),
+    )
+    results_dict[result] = task_data.filename
+
+
+class ComputeWsiMetricsCallback(Callback):
+    def __init__(self, max_processes: int = 10, save_per_image: bool = True) -> None:
+        """
+        Callback to compute metrics on whole-slide images. The metrics are computed in separate processes.
+
+        Parameters
+        ----------
+        max_processes : int
+            The maximum number of concurrent processes.
+        save_per_image : bool
+            Whether to also collect the per-image metrics for the results JSON.
+        """
+        self._data_description: Optional[DataDescription] = None
+        self._reader = H5FileImageReader
+        self._max_processes: int = max_processes
+        self._dump_dir: Optional[Path] = None
+        self._save_per_image = save_per_image
+        self._filenames: dict[Path, Path] = {}
+
+        self._wsi_metrics: WSIMetricFactory | None = None
+        self._class_names: dict[int, str] = {}
+        self._data_manager = None
+        self._validate_filenames_gen = None
+
+        self._model_name: str | None = None
+
+        self._validate_metadata_gen: Generator[ImageMetadata, None, None] | None = None
+
+        self._dump_list: list[dict[str, str]] = []
+        self._logger = get_logger(type(self).__name__)
+
+    def setup(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        stage: Optional[str] = None,
+    ) -> None:
+        if not isinstance(pl_module, AhCoreLightningModule):
+            # TODO: Make an AhCoreCallback with these features
+            raise ValueError("AhCoreLightningModule required for ComputeWsiMetricsCallback.")
+
+        self._model_name = pl_module.name
+
+        _callback: Optional[WriteH5Callback] = None
+        for idx, callback in enumerate(trainer.callbacks):  # type: ignore
+            if isinstance(callback, WriteH5Callback):
+                _callback = cast(WriteH5Callback, trainer.callbacks[idx])  # type: ignore
+                break
+
+        if _callback is None:
+            raise ValueError(
+                "WriteH5Callback is not in the trainer's callbacks. 
" + "This is required before WSI metrics can be computed using this Callback" + ) + + self._dump_dir = _callback.dump_dir + + if pl_module.wsi_metrics is None: + raise ValueError("WSI metrics are not set.") + + self._wsi_metrics = pl_module.wsi_metrics + self._data_description = trainer.datamodule.data_description # type: ignore + + # For mypy + assert self._data_description + index_map = self._data_description.index_map + assert index_map + + if not self._data_description: + raise ValueError("Data description is not set.") + + self._class_names = dict([(v, k) for k, v in index_map.items()]) + self._class_names[0] = "background" + + # Here we can query the database for the validation images + self._data_manager: DataManager = trainer.datamodule.data_manager # type: ignore + + def _create_validate_image_metadata_gen( + self, + ) -> Generator[ImageMetadata, None, None]: + assert self._data_description + assert self._data_manager + gen = self._data_manager.get_image_metadata_by_split( + manifest_name=self._data_description.manifest_name, + split_version=self._data_description.split_version, + split_category="validate", + ) + for image_metadata in gen: + yield image_metadata + + @property + def _validate_metadata(self) -> Generator[ImageMetadata, None, None] | None: + return self._validate_metadata_gen + + def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + self._validate_metadata_gen = self._create_validate_image_metadata_gen() + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if not self._dump_dir: + raise ValueError("Dump directory is not set.") + + filenames = batch["path"] # Filenames are constant across the batch. + if len(set(filenames)) != 1: + raise ValueError( + "All paths in a batch must be the same. " + "Either use batch_size=1 or ahcore.data.samplers.WsiBatchSampler." 
+ ) + + def compute_metrics( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> list[list[dict[str, dict[str, float]]]]: + assert self._dump_dir + assert self._data_description + assert self._validate_metadata + assert self._data_manager + metrics = [] + + with multiprocessing.Pool(processes=self._max_processes) as pool: + results_to_filename: dict[list[dict[str, Any]], str] = {} + completed_tasks = 0 + + # Fill up the initial task pool + for image_metadata in itertools.islice(self._validate_metadata, self._max_processes): + logger.info("Metadata: %s", image_metadata) + # Assemble the task data + # filename", "h5_filename", "metadata", "mask", "annotations" + task_data = prepare_task_data( + image_metadata.filename, + self._dump_dir, + pl_module, + self._data_description, + self._data_manager, + ) + + # Schedule task + schedule_task( + task_data, + pool, + results_to_filename, + self._class_names, + self._data_description, + self._wsi_metrics, + self._save_per_image, + ) + + while results_to_filename: + time.sleep(0.1) # Reduce excessive polling + # Check for completed tasks + for result in list(results_to_filename.keys()): + if result.ready(): + filename = results_to_filename.pop(result) + try: + metric = result.get() + except Exception as exc: + self._logger.error("%r generated an exception: %s" % (filename, exc)) + else: + metrics.append(metric) + self._logger.debug("Metric for %r is %s" % (filename, metric)) + + completed_tasks += 1 + + # Schedule a new task if there are more filenames left in the generator + next_metadata = next(self._validate_metadata, None) + while next_metadata: + task_data = prepare_task_data( + next_metadata.filename, # <-- Changed from image_metadata.filename + self._dump_dir, + pl_module, + self._data_description, + self._data_manager, + ) + + # Schedule task + schedule_task( + task_data, + pool, + results_to_filename, + self._class_names, + self._data_description, + self._wsi_metrics, + self._save_per_image, + ) + + next_metadata = next(self._validate_metadata, None) + return metrics + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if not self._dump_dir: + raise ValueError("Dump directory is not set.") + if not self._wsi_metrics: + raise ValueError("WSI metrics are not set.") + assert self._model_name # This should be set in the setup() + + # Ensure that all h5 files have been written + self._logger.debug("Computing metrics for %s predictions", len(self._filenames)) + computed_metrics = self.compute_metrics(trainer, pl_module) + metrics = self._wsi_metrics.get_average_score(computed_metrics) + results_json_fn = ( + self._dump_dir / "outputs" / self._model_name / f"step_{pl_module.global_step}" / "results.json" + ) + with open(results_json_fn, "w", encoding="utf-8") as json_file: + json.dump(self._dump_list, json_file, indent=2) + self._wsi_metrics.reset() + # Reset stuff + self._dump_list = [] + self._filenames = {} + + self._logger.debug("Metrics: %s", metrics) + + # TODO: Maybe put this elsewhere? + metrics = {f"validate/{k}": v for k, v in metrics.items()} + pl_module.log_dict(metrics, prog_bar=True) diff --git a/ahcore/cli/__init__.py b/ahcore/cli/__init__.py new file mode 100644 index 0000000..3cea9dc --- /dev/null +++ b/ahcore/cli/__init__.py @@ -0,0 +1,84 @@ +"""Ahcore Command-line interface. 
This is the file which builds the main parser."""
+from __future__ import annotations
+
+import argparse
+import os
+import pathlib
+from typing import Callable
+
+
+def dir_path(require_writable: bool = False) -> Callable[[str], pathlib.Path]:
+    def check_dir_path(path: str) -> pathlib.Path:
+        """Check if the path is a valid and (optionally) writable directory.
+
+        Parameters
+        ----------
+        path : str
+
+        Returns
+        -------
+        pathlib.Path
+            The path as a pathlib.Path object.
+        """
+        _path = pathlib.Path(path)
+        if _path.is_dir():
+            if require_writable:
+                if os.access(_path, os.W_OK):
+                    return _path
+                else:
+                    raise argparse.ArgumentTypeError(f"{path} is not a writable directory.")
+            else:
+                return _path
+        raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
+
+    return check_dir_path
+
+
+def file_path(path: str) -> pathlib.Path:
+    """Check if the path is a valid file.
+
+    Parameters
+    ----------
+    path : str
+
+    Returns
+    -------
+    pathlib.Path
+        The path as a pathlib.Path object.
+
+    """
+    _path = pathlib.Path(path)
+    if _path.is_file():
+        return _path
+    raise argparse.ArgumentTypeError(f"{path} is not a valid file.")
+
+
+def main() -> None:
+    """
+    Main entrypoint for the CLI command of ahcore.
+    """
+    # From https://stackoverflow.com/questions/17073688/how-to-use-argparse-subparsers-correctly
+    root_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    root_subparsers = root_parser.add_subparsers(help="Possible ahcore CLI utils to run.")
+    root_subparsers.required = True
+    root_subparsers.dest = "subcommand"
+
+    # Prevent circular import
+    from ahcore.cli.data import register_parser as register_data_subcommand
+
+    # Data related commands.
+    register_data_subcommand(root_subparsers)
+
+    # Prevent circular import
+    from ahcore.cli.tiling import register_parser as register_tiling_subcommand
+
+    # Tiling related commands
+    register_tiling_subcommand(root_subparsers)
+
+    args = root_parser.parse_args()
+    args.subcommand(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ahcore/cli/data.py b/ahcore/cli/data.py
new file mode 100644
index 0000000..c1958d5
--- /dev/null
+++ b/ahcore/cli/data.py
@@ -0,0 +1,100 @@
+"""Module to copy the files listed in a manifest over to the SCRATCH directory."""
+import argparse
+import hashlib
+import os
+import shutil
+import sys
+from pathlib import Path
+from typing import Any
+
+from rich.progress import Progress
+
+from ahcore.cli import dir_path
+from ahcore.utils.manifest import DataManager
+
+
+def _quick_hash(file_path: Path, max_bytes: int = 10**6) -> str:
+    hasher = hashlib.sha256()
+    with open(file_path, "rb") as f:
+        block = f.read(max_bytes)
+        hasher.update(block)
+    return hasher.hexdigest()
+
+
+def copy_data(args: argparse.Namespace) -> None:
+    manifest_uri = args.manifest_uri
+    base_dir = args.base_dir
+    dataset_name = args.dataset_name
+    target_dir = os.environ.get("SCRATCH", None)
+
+    if target_dir is None or not os.access(target_dir, os.W_OK):
+        print("Please set the SCRATCH environment variable to a writable directory.")
+        sys.exit(1)
+
+    with DataManager(manifest_uri) as dm:
+        all_records = dm.get_records_by_split(args.manifest_name, args.split_name, split_category=None)
+        with Progress() as progress:
+            task = progress.add_task("[cyan]Copying...")
+            for patient in all_records:
+                for image in patient.images:
+                    image_fn = image.filename
+                    get_from = base_dir / image_fn
+                    write_to = Path(target_dir) / dataset_name / image_fn
+
+                    write_to.parent.mkdir(parents=True, exist_ok=True)
+                    if
write_to.exists(): + # compute the hash of previous and new file + old_hash = _quick_hash(write_to) + new_hash = _quick_hash(get_from) + if old_hash == new_hash: + # Skip if they are the same + progress.console.log("Skipping file as it already exists: {}".format(image_fn)) + progress.update(task, advance=1) + continue + + # Copy file from get_from to write_to + shutil.copy(get_from, write_to) + progress.update(task, advance=1) + + +def register_parser( + parser: argparse._SubParsersAction[Any], +) -> None: # pylint: disable=E1136 + """Register inspect commands to a root parser.""" + data_parser = parser.add_parser("data", help="Data utilities") + data_subparsers = data_parser.add_subparsers(help="Data subparser") + data_subparsers.required = True + data_subparsers.dest = "subcommand" + + _parser: argparse.ArgumentParser = data_subparsers.add_parser( + "copy-data-from-manifest", + help="Copy the data to a different drive based on the manifest. " + "The data will be copied over to $SCRATCH / DATASET_NAME", + ) + + _parser.add_argument( + "manifest_uri", + type=str, + help="URI that refers to the sqlalchemy supported database path.", + ) + _parser.add_argument( + "manifest_name", + type=str, + help="Name of the manifest to copy the data from.", + ) + _parser.add_argument( + "split_name", + type=str, + help="Name of the split in the database to copy the data from.", + ) + _parser.add_argument( + "base_dir", + type=dir_path(require_writable=False), + help="Directory to which the paths defined in the manifest are relative to.", + ) + _parser.add_argument( + "dataset_name", + type=str, + help="Name of the dataset to copy the data to. The data will be copied over to $SCRATCH / DATASET_NAME", + ) + _parser.set_defaults(subcommand=copy_data) diff --git a/ahcore/cli/tiling.py b/ahcore/cli/tiling.py new file mode 100644 index 0000000..5cf1a62 --- /dev/null +++ b/ahcore/cli/tiling.py @@ -0,0 +1,436 @@ +"""Utility to create tiles from the TCGA FFPE H&E slides. + +Many models uses 0.5um/pixel at 224 x 224 size. 
+""" +from __future__ import annotations + +import argparse +import io +from functools import partial +from logging import getLogger +from multiprocessing import Pool +from pathlib import Path +from pprint import pformat +from typing import Any, Generator + +import imageio.v3 as iio +import numpy as np +import numpy.typing as npt +import PIL.Image +from dlup import SlideImage +from dlup.data.dataset import TiledWsiDataset +from dlup.tiling import GridOrder, TilingMode +from PIL import Image +from pydantic import BaseModel +from rich.progress import Progress + +from ahcore.cli import dir_path, file_path +from ahcore.writers import H5FileImageWriter + +logger = getLogger(__name__) + + +def read_mask(path: Path) -> npt.NDArray[np.int_]: + return iio.imread(path)[..., 0] + + +class SlideImageMetaData(BaseModel): + """Metadata of a whole slide image.""" + + path: Path + mpp: float + aspect_ratio: float + magnification: float | None + size: tuple[int, int] + vendor: str | None + + @classmethod + def from_dataset(cls, dataset: TiledWsiDataset) -> "SlideImageMetaData": + _relevant_keys = ["aspect_ratio", "magnification", "mpp", "size", "vendor"] + return cls( + **{ + "path": dataset.path, + **{key: getattr(dataset.slide_image, key) for key in _relevant_keys}, + } + ) + + def __iter__(self) -> Generator[tuple[str, Any], None, None]: + for k, v in self.model_dump().items(): + yield k, v + + +class TileMetaData(BaseModel): + coordinates: tuple[int, int] + region_index: int + grid_local_coordinates: tuple[int, int] + grid_index: int + + def __iter__(self) -> Generator[tuple[str, Any], None, None]: + for k, v in self.model_dump().items(): + yield k, v + + +class DatasetConfigs(BaseModel): + """Configurations of the TiledROIsSlideImageDataset dataset""" + + mpp: float + tile_size: tuple[int, int] + tile_overlap: tuple[int, int] + tile_mode: str + crop: bool + mask_threshold: float + grid_order: str + + +def _save_thumbnail( + image_fn: Path, + dataset_cfg: DatasetConfigs, + mask: npt.NDArray[np.int_] | None, +) -> tuple[npt.NDArray[np.uint8], npt.NDArray[np.uint8] | None, npt.NDArray[np.uint8]]: + target_mpp = max(dataset_cfg.mpp * 30, 30) + tile_size = ( + min(30, dataset_cfg.tile_size[0] // 30), + min(30, dataset_cfg.tile_size[1] // 30), + ) + + dataset = TiledWsiDataset.from_standard_tiling( + image_fn, + target_mpp, + tile_size, + (0, 0), + mask=mask, + mask_threshold=dataset_cfg.mask_threshold, + ) + scaled_region_view = dataset.slide_image.get_scaled_view(dataset.slide_image.get_scaling(target_mpp)) + + if mask is not None: + # Let us write the mask too. + mask_io = io.BytesIO() + pil_mask = PIL.Image.fromarray(mask * 255, mode="L") + pil_mask.save(mask_io, quality=75) + mask_arr = np.frombuffer(mask_io.getvalue(), dtype="uint8") + else: + mask_arr = None + + thumbnail_io = io.BytesIO() + + # TODO: This needs to change in dlup, the scaled_region_view needs to return size in int, int. 
+ _tile_size = tuple(scaled_region_view.size) + tile_size = (_tile_size[0], _tile_size[1]) + + thumbnail = dataset.slide_image.get_thumbnail(tile_size) + thumbnail.convert("RGB").save(thumbnail_io, quality=75) + thumbnail_arr = np.frombuffer(thumbnail_io.getvalue(), dtype="uint8") + + region_size = tuple(scaled_region_view.size) + background = Image.new("RGBA", (region_size[0], region_size[1]), (255, 255, 255, 255)) + + overlay_io = io.BytesIO() + for d in dataset: + tile = d["image"] + coords = np.array(d["coordinates"]) + box = tuple(np.array((*coords, *(coords + tile_size))).astype(int)) + background.paste(tile, (box[0], box[1])) + # You could uncomment this to plot the boxes of the grid as well, but this can quickly become crowded. + # draw = ImageDraw.Draw(background) + # draw.rectangle(box, outline="red") + background.convert("RGB").save(overlay_io, quality=75) + overlay_arr = np.frombuffer(overlay_io.getvalue(), dtype="uint8") + + return thumbnail_arr, mask_arr, overlay_arr + + +def create_slide_image_dataset( + slide_image_path: Path, + mask: SlideImage | npt.NDArray[np.int_] | None, + cfg: DatasetConfigs, + overwrite_mpp: tuple[float, float] | None = None, +) -> TiledWsiDataset: + """ + Initializes and returns a slide image dataset. + + Parameters + ---------- + slide_image_path : Path + Path to a whole slide image file. + mask : np.ndarray | None + Binary mask used to filter each tile. + cfg : DatasetConfigs + Dataset configurations. + overwrite_mpp : tuple[float, float] | None + Tuple of (mpp_x, mpp_y) used to overwrite the mpp of the loaded slide image. + + Returns + ------- + TiledROIsSlideImageDataset + Initialized slide image dataset. + + """ + + return TiledWsiDataset.from_standard_tiling( + path=slide_image_path, + mpp=cfg.mpp, + tile_size=cfg.tile_size, + tile_overlap=cfg.tile_overlap, + grid_order=GridOrder[cfg.grid_order], + tile_mode=TilingMode[cfg.tile_mode], + crop=cfg.crop, + mask=mask, + mask_threshold=cfg.mask_threshold, + overwrite_mpp=overwrite_mpp, + ) + + +def _generator( + dataset: TiledWsiDataset, quality: int | None = 80, compression: str = "JPEG" +) -> Generator[Any, Any, Any]: + for idx in range(len(dataset)): + # TODO: To use: + # for idx, sample in enumerate(dataset): + # The following needs to be added to TiledWsiDataset: + # def __iter__(self) -> Iterator[RegionFromWsiDatasetSample]: + # for i in range(len(self)): + # yield self[i] + sample = dataset[idx] + buffered = io.BytesIO() + if quality is not None: + # If we just cast the PIL.Image to RGB, the alpha channel is set to black + # which is a bit unnatural if you look in the image pyramid where it would be white in lower resolutions + # this is why we take the following approach. + tile: PIL.Image.Image = sample["image"] + background = PIL.Image.new("RGB", tile.size, (255, 255, 255)) # Create a white background + background.paste(tile, mask=tile.split()[3]) # Paste the image using the alpha channel as mask + background.convert("RGB").save(buffered, format=compression, quality=quality) + else: + sample["image"].save(buffered, format=compression, quality=quality) + + # Now we have the image bytes + coordinates = sample["coordinates"] + array = np.frombuffer(buffered.getvalue(), dtype="uint8") + yield [coordinates], array[np.newaxis, :] + + +def save_tiles( + dataset: TiledWsiDataset, + h5_writer: H5FileImageWriter, + quality: int | None = 80, +) -> None: + """ + Saves the tiles in the given image slide dataset to disk. 
+ + Parameters + ---------- + dataset : TiledROIsSlideImageDataset + The image slide dataset containing tiles of a single whole slide image. + h5_writer : H5FileImageWriter + The H5 writer to write the tiles to. + quality : int | None + If not None, the compression quality of the saved tiles in jpg, otherwise png + + """ + compression = "JPEG" if quality is not None else "PNG" + generator = _generator(dataset, quality, compression) + h5_writer.consume(generator) + + +def _tiling_pipeline( + image_path: Path, + mask_path: Path | None, + output_file: Path, + dataset_cfg: DatasetConfigs, + quality: int, + save_thumbnail: bool = False, +) -> None: + output_file.parent.mkdir(parents=True, exist_ok=True) + + try: + # TODO: Come up with a way to inject the mask later on as well. + mask = read_mask(mask_path) if mask_path else None + dataset = create_slide_image_dataset( + slide_image_path=image_path, + mask=mask, + cfg=dataset_cfg, + ) + _scaling = dataset.slide_image.get_scaling(dataset_cfg.mpp) + + h5_writer = H5FileImageWriter( + filename=output_file, + size=dataset.slide_image.get_scaled_size(_scaling), + mpp=dataset_cfg.mpp, + tile_size=dataset_cfg.tile_size, + tile_overlap=dataset_cfg.tile_overlap, + num_samples=len(dataset), + is_binary=True, + ) + save_tiles(dataset, h5_writer, quality) + if save_thumbnail: + thumbnail, thumbnail_mask, overlay = _save_thumbnail(image_path, dataset_cfg, mask) + + if thumbnail_mask is not None: + h5_writer.add_associated_images( + images=( + ("thumbnail", thumbnail), + ("mask", thumbnail_mask), + ("overlay", overlay), + ), + description="thumbnail, mask and overlay", + ) + else: + h5_writer.add_associated_images(images=(("thumbnail", thumbnail),), description="thumbnail") + + except Exception as e: + logger.error(f"Failed: {image_path} with exception {e}") + return + + logger.debug("Working on %s. 
Writing to %s", image_path, output_file) + + +def _wrapper( + dataset_cfg: DatasetConfigs, + quality: int, + save_thumbnail: bool, + args: tuple[Path, Path, Path], +) -> None: + image_path, mask_path, output_file = args + return _tiling_pipeline(image_path, mask_path, output_file, dataset_cfg, quality, save_thumbnail) + + +def _do_tiling(args: argparse.Namespace) -> None: + images_list: list[tuple[Path, Path | None, Path]] = [] + + with open(args.file_list, "r") as file_list: + for line in file_list: + if line.startswith("#"): + continue + image_file, mask_file, output_filename = line.split(",") + if (args.output_directory / "data" / Path(output_filename.strip())).is_file() and args.simple_check: + continue + + mask_fn = Path(mask_file.strip()) if mask_file != "" else None + images_list.append( + ( + Path(image_file.strip()), + mask_fn, + args.output_directory / "data" / Path(output_filename.strip()), + ) + ) + + logger.info(f"Number of slides: {len(images_list)}") + logger.info(f"Output directory: {args.output_directory}") + logger.info("Tiling...") + + save_dir_data = args.output_directory / "data" + save_dir_data.mkdir(parents=True, exist_ok=True) + + crop = False + tile_mode = TilingMode.overflow + tile_size = (args.tile_size, args.tile_size) + tile_overlap = (args.tile_overlap, args.tile_overlap) + mpp = args.mpp + mask_threshold = args.mask_threshold + + dataset_cfg = DatasetConfigs( + mpp=mpp, + tile_size=tile_size, + tile_overlap=tile_overlap, + tile_mode=tile_mode, + crop=crop, + mask_threshold=mask_threshold, + grid_order="C", + ) + + logger.info(f"Dataset configurations: {pformat(dataset_cfg)}") + + if args.num_workers > 0: + # Create a partially applied function with dataset_cfg + partial_wrapper = partial(_wrapper, dataset_cfg, args.quality, args.save_thumbnail) + + with Progress() as progress: + task = progress.add_task("[cyan]Tiling...", total=len(images_list)) + with Pool(processes=args.num_workers) as pool: + for _ in pool.imap_unordered(partial_wrapper, images_list): + progress.update(task, advance=1) + else: + with Progress() as progress: + task = progress.add_task("[cyan]Tiling...", total=len(images_list)) + for idx, (image_path, mask_path, output_file) in enumerate(images_list): + _tiling_pipeline( + image_path, + mask_path, + output_file, + dataset_cfg, + quality=args.quality, + save_thumbnail=args.save_thumbnail, + ) + progress.update(task, advance=1) + + +def register_parser(parser: argparse._SubParsersAction[Any]) -> None: + """Register inspect commands to a root parser.""" + tiling_parser = parser.add_parser("tiling", help="Tiling utilities") + tiling_subparsers = tiling_parser.add_subparsers(help="Tiling subparser") + tiling_subparsers.required = True + tiling_subparsers.dest = "subcommand" + + _parser: argparse.ArgumentParser = tiling_subparsers.add_parser( + "tile-to-h5", + help="Tiling WSI images to h5", + ) + + # Assume a comma separated format from image_file,mask_file + _parser.add_argument( + "--file-list", + type=file_path, + required=True, + help="Path to the file list. Each comma-separated line is of the form `,,`" + " where the output directory is with request to --output-dir. 
mask_fn can be empty.", + ) + _parser.add_argument( + "--output-directory", + type=dir_path(require_writable=True), + required=True, + help="Path to the output directory.", + ) + _parser.add_argument( + "--mpp", + type=float, + required=True, + help="Resolution (microns per pixel) at which the slides should be tiled.", + ) + _parser.add_argument("--tile-size", type=int, required=True, help="Size of the tiles in pixels.") + _parser.add_argument( + "--tile-overlap", + type=int, + default=0, + help="Overlap of the tile in pixels (default=0).", + ) + _parser.add_argument( + "--mask-threshold", + type=float, + default=0.6, + help="0 every tile is discarded, 1 requires the whole tile to be foreground (default=0.6).", + ) + _parser.add_argument( + "--num-workers", + type=int, + default=8, + help="Number of workers to use for tiling. 0 disables the tiling (default: 8)", + ) + _parser.add_argument( + "--save-thumbnail", + action="store_true", + help="Save a thumbnail of the slide, including the filtered tiles and the mask itself.", + ) + _parser.add_argument( + "--simple-check", + action="store_true", + help="Filter the list based on if the h5 images already exist.", + ) + _parser.add_argument( + "--quality", + type=int, + default=80, + help="Quality of the saved tiles in jpg, otherwise png (default: 80)", + ) + + _parser.set_defaults(subcommand=_do_tiling) diff --git a/ahcore/data/__init__.py b/ahcore/data/__init__.py new file mode 100644 index 0000000..8d400aa --- /dev/null +++ b/ahcore/data/__init__.py @@ -0,0 +1,5 @@ +"""General module for datasets, samplers and lightning modules. + +- Generic dataset generated by a manifest, which can handle classification, detection and segmentation. +- Samplers that for instance perform adaptive sampling, or define different weights per sample. +""" diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py new file mode 100644 index 0000000..ecff6b6 --- /dev/null +++ b/ahcore/data/dataset.py @@ -0,0 +1,273 @@ +""" +Utilities to construct datasets and DataModule's from manifests. +""" +from __future__ import annotations + +import uuid as uuid_module +from typing import Any, Callable, Generator, Iterator, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +from dlup.data.dataset import ConcatDataset, TiledWsiDataset +from pytorch_lightning.trainer.states import TrainerFn +from torch.utils.data import DataLoader, Sampler + +import ahcore.data.samplers +from ahcore.utils.data import DataDescription, basemodel_to_uuid +from ahcore.utils.io import fullname, get_cache_dir, get_logger +from ahcore.utils.manifest import DataManager, datasets_from_data_description +from ahcore.utils.types import DlupDatasetSample, _DlupDataset + + +class DlupDataModule(pl.LightningDataModule): + """Datamodule for the Ahcore framework. This datamodule is based on `dlup`.""" + + def __init__( + self, + data_description: DataDescription, + pre_transform: Callable[[bool], Callable[[DlupDatasetSample], DlupDatasetSample]], + batch_size: int = 32, # noqa,pylint: disable=unused-argument + validate_batch_size: int | None = None, # noqa,pylint: disable=unused-argument + num_workers: int = 16, + persistent_workers: bool = False, + pin_memory: bool = False, + ) -> None: + """ + Construct a DataModule based on a manifest. + + Parameters + ---------- + data_description : DataDescription + See `ahcore.utils.data.DataDescription` for more information. 
+ pre_transform : Callable + A pre-transform is a callable which is directly applied to the output of the dataset before collation in + the dataloader. The transforms typically convert the image in the output to a tensor, convert the + `WsiAnnotations` to a mask or similar. + batch_size : int + The batch size of the data loader. + validate_batch_size : int, optional + Sometimes the batch size for validation can be larger. If so, set this variable. Will also use this for + prediction. + num_workers : int + The number of workers used to fetch tiles. + persistent_workers : bool + Whether to use persistent workers. Check the pytorch documentation for more information. + pin_memory : bool + Whether to use cuda pin workers. Check the pytorch documentation for more information. + """ + super().__init__() + self._logger = get_logger(fullname(self)) + + self.save_hyperparameters( + logger=True, + ignore=[ + "data_description", + "pre_transform", + "data_dir", + "annotations_dir", + "num_workers", + "persistent_workers", + "pin_memory", + ], + ) # save all relevant hyperparams + + # Data settings + self.data_description: DataDescription = data_description + + self._data_manager = DataManager(database_uri=data_description.manifest_database_uri) + + self._batch_size = self.hparams.batch_size # type: ignore + self._validate_batch_size = self.hparams.validate_batch_size # type: ignore + + mask_threshold = data_description.mask_threshold + if mask_threshold is None: + mask_threshold = 0.0 + self._mask_threshold = mask_threshold + + self._pre_transform = pre_transform + + # DataLoader settings + self._num_workers = num_workers + self._persistent_workers = persistent_workers + self._pin_memory = pin_memory + + self._fit_data_iterator: Iterator[_DlupDataset] | None = None + self._validate_data_iterator: Iterator[_DlupDataset] | None = None + self._test_data_iterator: Iterator[_DlupDataset] | None = None + self._predict_data_iterator: Iterator[_DlupDataset] | None = None + + # Variables to keep track if a dataset has already be constructed (it's a slow operation) + self._already_called: dict[str, bool] = { + "fit": False, + "validate": False, + "test": False, + "predict": False, + } + self._num_classes = data_description.num_classes + + @property + def data_manager(self) -> DataManager: + return self._data_manager + + def setup(self, stage: str) -> None: + if stage not in (e.value for e in TrainerFn): # type: ignore + raise ValueError(f"Stage should be one of {TrainerFn}") + + if stage and self._already_called[stage]: + return + + self._logger.info("Constructing dataset iterator for stage %s", stage) + + def dataset_iterator() -> Generator[TiledWsiDataset, None, None]: + gen = datasets_from_data_description( + db_manager=self._data_manager, + data_description=self.data_description, + transform=self._pre_transform( + requires_target=True if stage != "predict" else False # type: ignore + ), # This bool adds the target + stage=stage, + ) + for dataset in gen: + yield dataset + + setattr(self, f"_{stage}_data_iterator", dataset_iterator()) + + def _construct_concatenated_dataloader( + self, data_iterator: Iterator[_DlupDataset], batch_size: int, stage: str + ) -> Optional[DataLoader[DlupDatasetSample]]: + if not data_iterator: + return None + + def construct_dataset() -> ConcatDataset[DlupDatasetSample]: + datasets = [] + for _, ds in enumerate(data_iterator): + datasets.append(ds) + return ConcatDataset(datasets) + + self._logger.info("Constructing dataset for stage %s (this can take a while)", stage) + dataset 
= self._load_from_cache(construct_dataset, stage=stage) + setattr(self, f"{stage}_dataset", dataset) + + lengths = np.asarray([len(ds) for ds in dataset.datasets]) + self._logger.info( + f"Dataset for stage {stage} has {len(dataset)} samples and the following statistics:\n" + f" - Mean: {lengths.mean():.2f}\n" + f" - Std: {lengths.std():.2f}\n" + f" - Min: {lengths.min():.2f}\n" + f" - Max: {lengths.max():.2f}" + ) + + batch_sampler: Sampler[list[int]] + if stage == "fit": + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.RandomSampler(data_source=dataset, replacement=True), + batch_size=batch_size, + drop_last=True, + ) + + elif stage == "predict": + batch_sampler = ahcore.data.samplers.WsiBatchSamplerPredict( + dataset=dataset, + batch_size=batch_size, + ) + + else: + batch_sampler = ahcore.data.samplers.WsiBatchSampler( + dataset=dataset, + batch_size=batch_size, + ) + + return DataLoader( + dataset, + num_workers=self._num_workers, + batch_sampler=batch_sampler, + persistent_workers=self._persistent_workers, + pin_memory=self._pin_memory, + ) + + def _load_from_cache(self, func: Callable[[], Any], stage: str, *args: Any, **kwargs: Any) -> Any: + name = fullname(func) + path = get_cache_dir() / stage / name + filename = path / f"{self.uuid}.pkl" + if not filename.is_file(): + path.mkdir(exist_ok=True, parents=True) + self._logger.info("Caching %s", name) + + obj = func(*args, **kwargs) + + with open(filename, "wb") as file: + torch.save(obj, file) + else: + with open(filename, "rb") as file: + self._logger.info("Loading %s from cache %s file", name, filename) + obj = torch.load(file) + + return obj + + def train_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: + if not self._fit_data_iterator: + self.setup("fit") + assert self._fit_data_iterator + return self._construct_concatenated_dataloader( + self._fit_data_iterator, + batch_size=self._batch_size, + stage="fit", + ) + + def val_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: + if not self._validate_data_iterator: + self.setup("validate") + + batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size + assert self._validate_data_iterator + val_dataloader = self._construct_concatenated_dataloader( + self._validate_data_iterator, + batch_size=batch_size, + stage="validate", + ) + if val_dataloader: + setattr(self, "val_concat_dataset", val_dataloader.dataset) + else: + setattr(self, "val_concat_dataset", None) + return val_dataloader + + def test_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: + if not self._test_data_iterator: + self.setup("test") + batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size + assert self._validate_data_iterator + return self._construct_concatenated_dataloader( + self._validate_data_iterator, batch_size=batch_size, stage="test" + ) + + def predict_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: + if not self._predict_data_iterator: + self.setup("predict") + batch_size = self._validate_batch_size if self._validate_batch_size else self._batch_size + assert self._predict_data_iterator + return self._construct_concatenated_dataloader( + self._predict_data_iterator, batch_size=batch_size, stage="predict" + ) + + def teardown(self, stage: str | None = None) -> None: + if stage is not None: + getattr(self, f"_{stage}_data_iterator").__del__() + self._data_manager.close() + + @property + def uuid(self) -> uuid_module.UUID: + """This property is used to create a unique cache file 
for each dataset. The constructor of this dataset + is completely determined by the data description, including the pre_transforms. Therefore, we can use the + data description to create an uuid that is unique for each datamodule. + + The uuid is computed by hashing the data description using the `basemodel_to_uuid` function, which uses + a sha256 hash of the pickled object and converts it to an UUID. + As pickles can change with python versions, this uuid will be different when using different python versions. + + Returns + ------- + str + A unique identifier for this datamodule. + """ + return basemodel_to_uuid(self.data_description) diff --git a/ahcore/data/samplers.py b/ahcore/data/samplers.py new file mode 100644 index 0000000..f75b7a7 --- /dev/null +++ b/ahcore/data/samplers.py @@ -0,0 +1,111 @@ +""" +Module implementing the samplers. These are used for instance to create batches of the same WSI. +""" +from __future__ import annotations + +import math +from typing import Generator, List + +from dlup.data.dataset import ConcatDataset, TiledROIsSlideImageDataset +from torch.utils.data import Sampler, SequentialSampler + +from ahcore.utils.io import get_logger + +logger = get_logger() + + +class WsiBatchSampler(Sampler[List[int]]): + def __init__(self, dataset: ConcatDataset[TiledROIsSlideImageDataset], batch_size: int) -> None: + super().__init__(data_source=dataset) + self._dataset = dataset + self.batch_size = batch_size + + self._slices: List[slice] = [] + self._populate_slices() + + def _populate_slices(self) -> None: + for idx, _ in enumerate(self._dataset.datasets): + slice_start = 0 if len(self._slices) == 0 else self._slices[-1].stop + slice_stop = self._dataset.cumulative_sizes[idx] + self._slices.append(slice(slice_start, slice_stop)) + + def __iter__(self) -> Generator[List[int], None, None]: + for slice_ in self._slices: + batch = [] + # Within each slice, create batches of size self._batch_size + for idx in range(slice_.start, slice_.stop): + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + # If there are remaining items that couldn't form a full batch, yield them as a smaller batch + if len(batch) > 0: + yield batch + + def __len__(self) -> int: + # The total number of batches is the sum of the number of batches in each slice + return sum(math.ceil((s.stop - s.start) / self.batch_size) for s in self._slices) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"batch_size={self.batch_size}, " + f"num_batches={self.__len__()}, " + f"num_wsis={len(self._dataset.datasets)})" + ) + + +class WsiBatchSamplerPredict(Sampler[List[int]]): + """This Sampler is identical to the WsiBatchSampler, + but its signature is changed for compatibility with the predict phase of Lightning. 
+ """ + + def __init__( + self, + sampler: SequentialSampler | None = None, + batch_size: int | None = None, + drop_last: bool | None = None, + dataset: ConcatDataset[TiledROIsSlideImageDataset] | None = None, + ) -> None: + if sampler is not None: # During the predict phase, the sampler is passed as a parameter + self._dataset: ConcatDataset[TiledROIsSlideImageDataset] = sampler.data_source # type: ignore + else: + self._dataset: ConcatDataset[TiledROIsSlideImageDataset] = dataset # type: ignore + super().__init__(data_source=self._dataset) + self.batch_size = batch_size + + self._slices: List[slice] = [] + self._populate_slices() + + def _populate_slices(self) -> None: + for idx, _ in enumerate(self._dataset.datasets): + slice_start = 0 if len(self._slices) == 0 else self._slices[-1].stop + slice_stop = self._dataset.cumulative_sizes[idx] + self._slices.append(slice(slice_start, slice_stop)) + + def __iter__(self) -> Generator[List[int], None, None]: + for slice_ in self._slices: + batch = [] + # Within each slice, create batches of size self._batch_size + for idx in range(slice_.start, slice_.stop): + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + + # If there are remaining items that couldn't form a full batch, yield them as a smaller batch + if len(batch) > 0: + yield batch + + def __len__(self) -> int: + # The total number of batches is the sum of the number of batches in each slice + return sum(math.ceil((s.stop - s.start) / self.batch_size) for s in self._slices) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"batch_size={self.batch_size}, " + f"num_batches={self.__len__()}, " + f"num_wsis={len(self._dataset.datasets)})" + ) diff --git a/ahcore/entrypoints.py b/ahcore/entrypoints.py new file mode 100644 index 0000000..62b3c17 --- /dev/null +++ b/ahcore/entrypoints.py @@ -0,0 +1,274 @@ +""" +Entrypoints +""" +from __future__ import annotations + +import os +import pathlib +from pprint import pformat + +import hydra +import torch +from omegaconf import DictConfig +from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer, seed_everything +from pytorch_lightning.loggers import Logger +from torch import nn + +from ahcore.utils.data import DataDescription +from ahcore.utils.io import get_logger, log_hyperparameters + +logger = get_logger(__name__) + + +def create_datamodule( + config: DictConfig, +) -> tuple[DataDescription, LightningDataModule]: + # Load generic description of the data + if not config.data_description.get("_target_"): + raise NotImplementedError(f"No data description defined in <{config.data_description}>") + data_description: DataDescription = hydra.utils.instantiate(config.data_description) + + if config.datamodule.get("_target_"): + logger.info(f"Instantiating datamodule <{config.datamodule._target_}>") # noqa + if not config.pre_transform.get("_target_"): + raise RuntimeError("No pre-transform defined in ") + logger.info(f"Instantiating pre_transforms <{config.pre_transform._target_}>") # noqa + pre_transform = hydra.utils.instantiate(config.pre_transform, data_description=data_description) + datamodule: LightningDataModule = hydra.utils.instantiate( + config.datamodule, + data_description=data_description, + pre_transform=pre_transform, + ) + logger.info(pformat(data_description)) # TODO: Use nice rich formatting + return data_description, datamodule + + raise NotImplementedError(f"No datamodule target found in <{config.datamodule}>") + + +def train(config: DictConfig) -> 
torch.Tensor | None: + """Contains the training pipeline. Can additionally evaluate model on a testset, using best + weights achieved during training. + Arguments + --------- + config : DictConfig + Configuration composed by Hydra. + + Returns + ------- + Optional : float + Metric score for hyperparameter optimization. + """ + + # Set seed for random number generators in pytorch, numpy and python.random + if config.get("seed"): + seed_everything(config.seed, workers=True) + + # TODO: Configurable? + torch.set_float32_matmul_precision("high") # type: ignore + + # Convert relative ckpt path to absolute path if necessary + checkpoint_path = config.get("ckpt_path") + if checkpoint_path and not os.path.isabs(checkpoint_path): + config.trainer.resume_from_checkpoint = pathlib.Path(hydra.utils.get_original_cwd()) / checkpoint_path + + data_description, datamodule = create_datamodule(config) + + # Init augmentations + augmentations: dict[str, nn.Module] | None = None + if "augmentations" in config: + augmentations = {} + for stage in config.augmentations: + if not config.augmentations[stage].get("_target_"): + raise NotImplementedError(f"No augmentations target found in <{config.augmentations[stage]}>") + logger.info(f"Instantiating {stage} augmentations <{config.augmentations[stage]._target_}>") # noqa + augmentations[stage] = hydra.utils.instantiate( + config.augmentations[stage], + data_description=data_description, + data_module=datamodule, + _convert_="object", + ) + + if not config.losses.get("_target_"): + raise NotImplementedError(f"No loss target found in <{config.metrics}>") + loss = hydra.utils.instantiate(config.losses) + + metrics: dict[str, nn.Module] | None = None + if "metrics" in config: + metrics = {} + for metric_class in config.metrics: + if not config.metrics[metric_class].get("_target_"): + raise NotImplementedError(f"No metrics target found in <{config.metrics[metric_class]}>") + logger.info(f"Instantiating metrics <{config.metrics[metric_class]._target_}>") # noqa + metrics[metric_class] = hydra.utils.instantiate( + config.metrics[metric_class], data_description=data_description + ) + logger.info(f"Metrics: {metrics}") + + # Init lightning model + if not config.lit_module.get("_target_"): + raise NotImplementedError(f"No model target found in <{config.lit_module}>") + + logger.info(f"Instantiating model <{config.lit_module._target_}>") # noqa + model: LightningModule = hydra.utils.instantiate( + config.lit_module, + data_description=data_description, + augmentations=augmentations, + loss=loss, + metrics=metrics, + _convert_="partial", + ) + + # Init lightning callbacks + callbacks: list[Callback] = [] + if "callbacks" in config: + for _, cb_conf in config.callbacks.items(): + if "_target_" in cb_conf: + logger.info("Instantiating callback <%s>", cb_conf._target_) # noqa + callbacks.append(hydra.utils.instantiate(cb_conf)) + + # Init lightning loggers + lightning_loggers: list[Logger] = [] + if "logger" in config: + for _, lg_conf in config.logger.items(): + if "_target_" in lg_conf: + logger.info("Instantiating logger <%s>", lg_conf._target_) # noqa + lightning_loggers.append(hydra.utils.instantiate(lg_conf)) + + # Init lightning trainer + if config.trainer.get("_target_"): + logger.info("Instantiating trainer <%s>", str(config.trainer._target_)) # noqa + trainer: Trainer = hydra.utils.instantiate( + config.trainer, + callbacks=callbacks, + logger=lightning_loggers, + _convert_="partial", + ) + + else: + raise NotImplementedError(f"No trainer target found in 
<{config.trainer}>") + + # Send some parameters from config to all lightning loggers + logger.info("Logging hyperparameters...") + log_hyperparameters(config=config, model=model, trainer=trainer) + + if config.get("train"): + logger.info("Starting training...") + trainer.fit(model=model, datamodule=datamodule) + + # Get metric score for hyperparameter optimization + optimized_metric = config.get("optimized_metric") + if optimized_metric and optimized_metric not in trainer.callback_metrics: + raise Exception( + "Metric for hyperparameter optimization not found. " + "Make sure the `optimized_metric` in `hparams_search` config is correct." + ) + score = trainer.callback_metrics.get(optimized_metric) + + # Test the model + if config.get("test_after_training"): + checkpoint_path = "best" + if not config.get("train") or config.trainer.get("fast_dev_run"): + checkpoint_path = None + logger.info("Starting testing...") + trainer.test(model=model, datamodule=datamodule, ckpt_path=checkpoint_path) + + # Make sure everything closed properly + logger.info("Finalizing...") + + # Print path to best checkpoint + if trainer.checkpoint_callback: + if not config.trainer.get("fast_dev_run") and config.get("train"): + logger.info(f"Best model checkpoint at {trainer.checkpoint_callback.best_model_path}") # type: ignore + + # Return metric score for hyperparameter optimization + return score + + +def inference(config: DictConfig) -> None: + """Contains the inference pipeline. + Arguments + --------- + config : DictConfig + Configuration composed by Hydra. + + Returns + ------- + None + """ + + # Set seed for random number generators in pytorch, numpy and python.random + if config.get("seed"): + seed_everything(config.seed, workers=True) + + # Convert relative ckpt path to absolute path if necessary + checkpoint_path = config.get("ckpt_path") + if not checkpoint_path: + raise RuntimeError("No checkpoint inputted in config.ckpt_path") + if checkpoint_path and not os.path.isabs(checkpoint_path): + config.trainer.resume_from_checkpoint = pathlib.Path(hydra.utils.get_original_cwd()) / checkpoint_path + + data_description, datamodule = create_datamodule(config) + + # Init augmentations + augmentations: dict[str, nn.Module] | None = None + if "augmentations" in config: + augmentations = {} + for stage in config.augmentations: + if not config.augmentations[stage].get("_target_"): + raise NotImplementedError(f"No augmentations target found in <{config.augmentations[stage]}>") + logger.info(f"Instantiating {stage} augmentations <{config.augmentations[stage]._target_}>") # noqa + augmentations[stage] = hydra.utils.instantiate( + config.augmentations[stage], + data_description=data_description, + data_module=datamodule, + _convert_="object", + ) + # Init lightning model + if not config.lit_module.get("_target_"): + raise NotImplementedError(f"No model target found in <{config.lit_module}>") + logger.info(f"Instantiating model <{config.lit_module._target_}>") # noqa + model: LightningModule = hydra.utils.instantiate( + config.lit_module, + augmentations=augmentations, + data_description=data_description, + _convert_="partial", + ) + + # Load checkpoint weights + lit_ckpt = torch.load(config.ckpt_path) + model.load_state_dict(lit_ckpt["state_dict"], strict=True) + + # Init lightning callbacks + callbacks: list[Callback] = [] + if "callbacks" in config: + for _, cb_conf in config.callbacks.items(): + if "_target_" in cb_conf: + logger.info("Instantiating callback <%s>", cb_conf._target_) # noqa + 
callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    # Init lightning loggers
+    lightning_loggers: list[Logger] = []
+    if "logger" in config:
+        for _, lg_conf in config.logger.items():
+            if "_target_" in lg_conf:
+                logger.info("Instantiating logger <%s>", lg_conf._target_)  # noqa
+                lightning_loggers.append(hydra.utils.instantiate(lg_conf))
+
+    # Init lightning trainer
+    if config.trainer.get("_target_"):
+        logger.info("Instantiating trainer <%s>", str(config.trainer._target_))
+        trainer: Trainer = hydra.utils.instantiate(
+            config.trainer,
+            callbacks=callbacks,
+            logger=lightning_loggers,
+            _convert_="partial",
+        )
+    else:
+        raise NotImplementedError(f"No trainer target found in <{config.trainer}>")
+
+    # Inference
+    logger.info("Starting inference...")
+    trainer.predict(model=model, datamodule=datamodule)
+
+    # Make sure everything closed properly
+    logger.info("Finalizing...")
diff --git a/ahcore/exceptions.py b/ahcore/exceptions.py
new file mode 100644
index 0000000..5cdb86d
--- /dev/null
+++ b/ahcore/exceptions.py
@@ -0,0 +1,15 @@
+"""
+Exceptions for ahcore
+"""
+from __future__ import annotations
+
+
+class ConfigurationError(Exception):
+    def __init__(self, message: str | None) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class RecordNotFoundError(Exception):
+    """Exception for the database manager."""
+
+    pass
diff --git a/ahcore/lit_module.py b/ahcore/lit_module.py
new file mode 100644
index 0000000..2c4748a
--- /dev/null
+++ b/ahcore/lit_module.py
@@ -0,0 +1,216 @@
+"""
+This module contains the core Lightning module for ahcore. This module is responsible for:
+- Training, Validation and Inference
+- Wrapping models
+"""
+from __future__ import annotations
+
+from typing import Any
+
+import pytorch_lightning as pl
+import torch.optim.optimizer
+from pytorch_lightning.trainer.states import TrainerFn
+from torch import nn
+
+from ahcore.exceptions import ConfigurationError
+from ahcore.metrics import MetricFactory, WSIMetricFactory
+from ahcore.utils.data import DataDescription
+from ahcore.utils.io import get_logger
+from ahcore.utils.types import DlupDatasetSample
+
+logger = get_logger(__name__)
+
+LitModuleSample = dict[str, Any]  # TODO: This can be a TypedDict
+
+
+class AhCoreLightningModule(pl.LightningModule):
+    RELEVANT_KEYS = [
+        "coordinates",
+        "mpp",
+        "path",
+        "region_index",
+        "grid_local_coordinates",
+        "grid_index",
+    ]
+
+    def __init__(
+        self,
+        model: nn.Module,
+        optimizer: torch.optim.Optimizer,  # noqa
+        data_description: DataDescription,
+        loss: nn.Module | None = None,
+        augmentations: dict[str, nn.Module] | None = None,
+        metrics: dict[str, MetricFactory | WSIMetricFactory] | None = None,
+        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,  # noqa
+    ):
+        super().__init__()
+
+        self.save_hyperparameters(
+            logger=False,
+            ignore=[
+                "model",
+                "augmentations",
+                "metrics",
+                "data_description",
+                "loss",
+            ],
+        )  # TODO: we should send the hyperparams to the logger elsewhere
+
+        self._num_classes = data_description.num_classes
+        self._model = model(out_channels=self._num_classes)
+        self._augmentations = augmentations
+
+        self._loss = loss
+        if metrics is not None:
+            tile_metric = metrics.get("tile_level")
+            wsi_metric = metrics.get("wsi_level", None)
+            if tile_metric is not None and not isinstance(tile_metric, MetricFactory):
+                raise ConfigurationError("Tile metrics must be of type MetricFactory")
+            if wsi_metric is not None and not isinstance(wsi_metric, WSIMetricFactory):
+                raise ConfigurationError("WSI metrics must be of type WSIMetricFactory")
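Note how `__init__` above calls `model(out_channels=self._num_classes)` even though `model` is annotated as `nn.Module`: the argument is expected to behave like a partially-applied constructor (presumably produced with Hydra's `_partial_` mechanism), so the network is only finalized once the number of classes is known. A hedged sketch of that pattern, with `TinySegHead` as an invented stand-in for the real network:

```python
# Hedged sketch: pass a partially-applied constructor, finish it later.
from functools import partial

import torch
from torch import nn


class TinySegHead(nn.Module):
    # Invented stand-in for the real segmentation network.
    def __init__(self, in_channels: int = 3, out_channels: int = 2) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model_factory = partial(TinySegHead, in_channels=3)  # what a config system could hand over
model = model_factory(out_channels=5)  # mirrors self._model = model(out_channels=...)
print(model(torch.rand(1, 3, 8, 8)).shape)  # torch.Size([1, 5, 8, 8])
```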
+ + self._tile_metric = tile_metric + self._wsi_metrics = wsi_metric + + self._data_description = data_description + + @property + def wsi_metrics(self) -> WSIMetricFactory | None: + return self._wsi_metrics + + @property + def name(self) -> str: + return str(self._model.__class__.__name__) + + def forward(self, sample: torch.Tensor) -> Any: + """This function is only used during inference""" + self._model.eval() + return self._model.forward(sample) + + @property + def data_description(self) -> DataDescription: + return self._data_description + + def _compute_metrics( + self, + prediction: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None, + stage: TrainerFn | str, + ) -> dict[str, torch.Tensor]: + if not self._tile_metric: + return {} + + _stage = stage.value if isinstance(stage, TrainerFn) else stage + metrics = {f"{_stage}/{k}": v for k, v in self._tile_metric(prediction, target, roi).items()} + return metrics + + def do_step(self, batch: DlupDatasetSample, batch_idx: int, stage: TrainerFn | str) -> LitModuleSample: + if self._augmentations and stage in self._augmentations: + batch = self._augmentations[stage](batch) + + if self._loss is None: + raise RuntimeError( + f"Loss is not defined for {self.__class__.__name__}. " + f"This is required during training and validation" + ) + + _target = batch["target"] + # Batch size is required for accurate loss calculation and logging + batch_size = batch["image"].shape[0] + # ROIs can reduce the usable area of the inputs, the loss should be scaled appropriately + roi = batch.get("roi", None) + + if stage == "fit": + _prediction = self._model(batch["image"]) + batch["prediction"] = _prediction + else: + batch = {**batch, **self._get_inference_prediction(batch["image"])} + _prediction = batch["prediction"] + + loss = self._loss(_prediction, _target, roi) + + # The relevant_dict contains values to know where the tiles originate. + _relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS} + _metrics = self._compute_metrics(_prediction, _target, roi, stage=stage) + _loss = loss.mean() + # TODO: This can be a TypedDict + output = { + "loss": _loss, + "loss_per_sample": loss.clone().detach(), + "metrics": _metrics, + **_relevant_dict, + } + if stage != "fit": + output["prediction"] = _prediction + + _stage = stage.value if isinstance(stage, TrainerFn) else stage + + self.log( + f"{_stage}/loss", + _loss, + batch_size=batch_size, + sync_dist=True, + on_epoch=True, + prog_bar=True, + ) + + # Log the metrics + self.log_dict( + _metrics, + batch_size=batch_size, + sync_dist=True, + prog_bar=False, + on_epoch=True, + on_step=False, + ) + + return output + + def _get_inference_prediction(self, _input: torch.Tensor) -> dict[str, torch.Tensor]: + output = {} + output["prediction"] = self._model(_input) + return output + + def training_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]: + output = self.do_step(batch, batch_idx, stage="fit") + return output + + def validation_step(self, batch: dict[str, Any], batch_idx: int) -> dict[str, Any]: + output = self.do_step(batch, batch_idx, stage="validate") + + # This is a sanity check. We expect the filenames to be constant across the batch. 
+ filename = batch["path"][0] + if any([filename != f for f in batch["path"]]): + raise ValueError("Filenames are not constant across the batch.") + return output + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + if self._augmentations and "predict" in self._augmentations: + batch = self._augmentations["predict"](batch) + + _relevant_dict = {k: v for k, v in batch.items() if k in self.RELEVANT_KEYS} + batch = {**batch, **self._get_inference_prediction(batch["image"])} + _prediction = batch["prediction"] + output = {"prediction": _prediction, **_relevant_dict} + + # This is a sanity check. We expect the filenames to be constant across the batch. + filename = batch["path"][0] + if any([filename != f for f in batch["path"]]): + raise ValueError("Filenames are not constant across the batch.") + return output + + def configure_optimizers(self) -> Any: + optimizer = self.hparams.optimizer(params=self.parameters()) # type: ignore + if self.hparams.scheduler is not None: # type: ignore + scheduler = self.hparams.scheduler(optimizer=optimizer) # type: ignore + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "validate/loss", + "interval": "epoch", + "frequency": self.trainer.check_val_every_n_epoch, + }, + } + return {"optimizer": optimizer} diff --git a/ahcore/losses.py b/ahcore/losses.py new file mode 100644 index 0000000..64930c5 --- /dev/null +++ b/ahcore/losses.py @@ -0,0 +1,253 @@ +""" +Loss factory + +All the relevant loss modules. In ahcore, losses are returned per sample in the batch. + +""" +from __future__ import annotations + +from typing import Callable, Optional, Union, cast + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LossFactory(nn.Module): + """Loss factory to construct the total loss.""" + + def __init__( + self, + losses: list[ + dict[ + str, + Callable[ + [ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + ], + torch.Tensor, + ], + ] + ], + weights: list[Union[torch.Tensor, float]] | None = None, + class_proportions: torch.Tensor | None = None, + ): + """ + Parameters + ---------- + losses : list[dict[str, Callable[[torch.Tensor, torch.Tensor, torch.Tensor | None], torch.Tensor]] + List of losses which are functions which accept `(input, target, roi, weight)`. The weight will be + applied per class. + weights : list + List of length `losses`. The weights weight the total contribution so `weight_0 * loss_0_val + ...` will + be the resulting loss. + class_proportions : torch.Tensor, optional + The class proportions are used to weight the loss per class. This is useful for class imbalance. + + TODO + ---- + Is this the best way to pass the losses? 
+ """ + super().__init__() + if weights is None: + weights = [1.0] * len(losses) + + _weights = [torch.Tensor([_]) for _ in weights] + self._weights: list[torch.Tensor] = _weights + + self._losses = [] + for loss in losses: + self._losses += list(loss.values()) + + if class_proportions is not None: + _class_weights = 1 / class_proportions + _class_weights[_class_weights.isnan()] = 0.0 + _class_weights = _class_weights / _class_weights.max() + self._class_weights = _class_weights + else: + self._class_weights = None + + def forward(self, input: torch.Tensor, target: torch.Tensor, roi: torch.Tensor | None = None) -> torch.Tensor: + total_loss = sum( + [ + weight.to(input.device) * curr_loss(input, target, roi, self._class_weights) + for weight, curr_loss in zip(self._weights, self._losses) + ], + torch.tensor([0.0] * input.shape[0], device=input.device), # Default value for sum + ) + return total_loss + + +def cross_entropy( + input: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None = None, + weight: torch.Tensor | None = None, + ignore_index: int | None = None, + topk: float | None = None, + label_smoothing: float = 0.0, + limit: float | None = None, +) -> torch.Tensor: + """ + Compute a ROI weighted cross entropy function. The resulting output is a per-sample cross entropy. + + Parameters + ---------- + input : torch.Tensor + Input of shape `(N, C, H, W)`. + target : torch.Tensor + One-hot encoded target of shape `(N, C, H, W)`. + roi : torch.Tensor + ROI of shape `(N, 1, H, W)` + weight : torch.Tensor, optional + Per class weight + ignore_index : int, optional + Specifies a target value that is ignored and does not contribute to the input gradient. + topk : float, optional + Apply top-k in the loss + label_smoothing : float, optional + Float in [0, 1]. Amount of smoothing. + `Rethinking the Inception Architecture for Computer Vision `__. + Default: :math:`0.0`. + limit : float, optional + If set this will be the value the cross entropy is clipped (from below). This has to be a negative value. + + Returns + ------- + torch.Tensor + Output as a torch.Tensor float + """ + if limit is not None: + if limit >= 0: + raise ValueError(f"Limit has to be a negative value. Got {limit}") + + if topk is not None: + if topk <= 0 or topk >= 1.0: + raise ValueError(f"topk value needs to be between 0 and 1. Got {topk}.") + + if roi is not None: + roi_sum = roi.sum() / input.shape[0] + if roi_sum == 0: + # Return a loss of zero if there is no ROI of length batch size + return torch.tensor([0.0] * input.shape[0], requires_grad=True).to(input.device) + else: + roi_sum = torch.tensor([np.prod(tuple(input.shape)[2:])]).to(input.device) + + if ignore_index is None: + ignore_index = -100 + + # compute cross_entropy pixel by pixel + _cross_entropy = F.cross_entropy( + input, + target.argmax(dim=1), + ignore_index=ignore_index, + weight=None if weight is None else weight.to(input.device), + reduction="none", + label_smoothing=label_smoothing, + ) + + if limit is not None: + _cross_entropy = torch.clip(_cross_entropy, limit, None) + + if roi is not None: + _cross_entropy = roi[:, 0, ...] 
+ if topk is None: + return _cross_entropy.sum(dim=(1, 2)) / roi_sum + + k = int(round(float(roi_sum.cpu()) * topk)) + # top-k returns Any + return cast( + torch.Tensor, + torch.topk(_cross_entropy.view(input.shape[0], -1), k).values.sum(dim=1), + ) / (roi_sum * topk) + + +def soft_dice( + input: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None = None, + weight: torch.Tensor | None = None, + ignore_index: int | None = None, + eps: float = 1e-17, +) -> torch.Tensor: + r"""Criterion that computes the Sørensen-Dice Coefficient loss. + + According to [1], we compute the Sørensen-Dice Coefficient as follows: + + .. math:: + + \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|} + + where: + - :math:`X` denotes the scores of each class. + - :math:`Y` denotes the one-hot tensor with the class labels. + + The loss is finally computed as: + + .. math:: + + \text{loss}(x, class) = 1 - \text{Dice}(x, class) + + [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient + + The shapes of input and target need to be :math:`(N, C, H, W)` where :math:`C` = number of classes. + + + Parameters + ---------- + input : torch.Tensor + Input of shape `(N, C, H, W)`. + target : torch.Tensor + One-hot encoded target of shape `(N, C, H, W)`. + roi : torch.Tensor, optional + ROI of shape `(N, 1, H, W)`. + weight : torch.Tensor, optional + Per class weight + ignore_index : int, optional + Specifies a target value that is ignored and does not contribute to the input gradient. + eps : float + Regularizer in the division + + Returns + ------- + torch.Tensor + Per-sample loss of shape `(N,)`. + """ + if weight is not None: + raise NotImplementedError("Weight not yet implemented for dice loss.") + + if not torch.is_tensor(input): # type: ignore[no-untyped-call] + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + if not input.shape[-2:] == target.shape[-2:]: + raise ValueError(f"Input and target shapes must be the same. Got: {input.shape} and {target.shape}") + if not input.device == target.device: + raise ValueError(f"Input and target must be on the same device. Got: {input.device} and {target.device}") + + if ignore_index is not None: + mask = target != ignore_index + input = mask * input + target = mask * target + + # Apply the ROI if it is there + if roi is not None: + input = roi * input + target = roi * target + + # Softmax still needs to be applied (the network outputs logits) + input_soft = F.softmax(input, dim=1) + # FIXME: This is not correct, when we input 0s for both target and input, we get 1.0.
+ + # Compute the dice score + intersection = torch.sum(input_soft * target, dim=(2, 3)) + cardinality = torch.sum(input_soft**2, dim=(2, 3)) + torch.sum(target**2, dim=(2, 3)) + + # dice_score has shape (B, C) + dice_score = (2.0 * intersection + eps) / (cardinality + eps) + return 1.0 - dice_score.mean(axis=1) # type: ignore diff --git a/ahcore/metrics/__init__.py b/ahcore/metrics/__init__.py new file mode 100644 index 0000000..edf15ef --- /dev/null +++ b/ahcore/metrics/__init__.py @@ -0,0 +1,10 @@ +"Ahcore's metrics" +from ahcore.metrics.metrics import DiceMetric, MetricFactory, TileMetric, WSIMetric, WSIMetricFactory + +__all__ = [ + "TileMetric", + "DiceMetric", + "MetricFactory", + "WSIMetricFactory", + "WSIMetric", +] diff --git a/ahcore/metrics/metrics.py b/ahcore/metrics/metrics.py new file mode 100644 index 0000000..d867bf4 --- /dev/null +++ b/ahcore/metrics/metrics.py @@ -0,0 +1,399 @@ +""" +Metrics module, including factory. +""" +from __future__ import annotations + +import abc +from collections import defaultdict +from typing import Any, List, Tuple + +import torch +import torch.nn.functional as F # noqa + +from ahcore.exceptions import ConfigurationError +from ahcore.utils.data import DataDescription + + +class TileMetric: + def __init__(self, data_description: DataDescription) -> None: + """Initialize the metric class""" + self._data_description = data_description + self.name: str | None = None + + @abc.abstractmethod + def __call__( + self, predictions: torch.Tensor, target: torch.Tensor, roi: torch.Tensor | None + ) -> dict[str, torch.Tensor]: + """Call metric computation""" + + +class DiceMetric(TileMetric): + def __init__(self, data_description: DataDescription) -> None: + r""" + Metric computing dice over classes. The classes are derived from the index_map that's defined in the + data_description. + + First, a softmax is taken over the predictions. Then, if there is a ROI available, the + input and target are masked with this ROI. This is followed by an argmax over the predictions and target, + resulting in a tensor of shape (batch_size, height, width) with values in [0, num_classes - 1]. The dice is then + computed over each class. + + We use as definition for the dice score: + :math:`\text{dice} = \frac{2 |X \cap Y|}{|X| + |Y|}` where :math:`|X|` is the number of pixels in + the prediction, :math:`|Y|` is the number of pixels in the target, and :math:`X \cap Y` is the intersection + of :math:`X` and :math:`Y`. + + The `__call__` returns the dice score for each class, with the class name (prefixed with dice/) as key + in a dictionary.
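+ For example, with `index_map = {"stroma": 1, "tumor": 2}` the returned dictionary contains the keys + `dice/background`, `dice/stroma` and `dice/tumor`.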
+ + Parameters + ---------- + data_description : DataDescription + """ + super().__init__(data_description=data_description) + self._num_classes = self._data_description.num_classes + + # Invert the index map + _index_map = {} + if self._data_description.index_map is None: + raise ConfigurationError("`index_map` is required to set up the dice metric.") + else: + _index_map = self._data_description.index_map + + _label_to_class = {v: k for k, v in _index_map.items()} + _label_to_class[0] = "background" + self._label_to_class = _label_to_class + + self.name = "dice" + + def __call__( + self, predictions: torch.Tensor, target: torch.Tensor, roi: torch.Tensor | None + ) -> dict[str, torch.Tensor]: + dice_components = _get_intersection_and_cardinality(predictions, target, roi, self._num_classes) + dices = [] + for intersection, cardinality in dice_components: + # dice_score is a float + dice_score = _compute_dice(intersection, cardinality) + dices.append(dice_score) + + output = {f"{self.name}/{self._label_to_class[idx]}": dices[idx] for idx in range(self._num_classes)} + return output + + def __repr__(self) -> str: + return f"{type(self).__name__}(num_classes={self._num_classes})" + + +class MetricFactory: + # TODO: this should be rewritten to actually be a factory + """Factory to create the metrics. These are fixed for the different tasks + (e.g., segmentation, detection, whole-slide-level classification). + """ + + def __init__(self, metrics: list[TileMetric]) -> None: + """ + Parameters + ---------- + metrics : list + List of metrics of type `TileMetric`. + """ + super().__init__() + names = [metric.name for metric in metrics] + if len(set(names)) != len(names): + raise RuntimeError("Each individual metric must have a different name.") + + self._metrics = metrics + + @classmethod + def for_segmentation(cls, *args: Any, **kwargs: Any) -> MetricFactory: + dices = DiceMetric(*args, **kwargs) + return cls([dices]) + + @classmethod + def for_wsi_classification(cls, *args: Any, **kwargs: Any) -> MetricFactory: + raise NotImplementedError + + @classmethod + def for_tile_classification(cls, *args: Any, **kwargs: Any) -> MetricFactory: + raise NotImplementedError + + def __call__( + self, predictions: torch.Tensor, target: torch.Tensor, roi: torch.Tensor | None + ) -> dict[str, torch.Tensor]: + output = {} + for metric in self._metrics: + output.update(metric(predictions, target, roi=roi)) + return output + + def __repr__(self) -> str: + return f"{type(self).__name__}(metrics={self._metrics})" + + +class WSIMetric(abc.ABC): + def __init__(self, data_description: DataDescription) -> None: + """Initialize the WSI metric class""" + self.wsis: dict[str, Any] = {} + self._data_description = data_description + + @property + @abc.abstractmethod + def name(self) -> str: + pass + + @abc.abstractmethod + # TODO: Fix Any + def process_batch(self, *args: Any, **kwargs: Any) -> None: + pass + + @abc.abstractmethod + # TODO: Fix Any + def get_wsi_score(self, *args: Any, **kwargs: Any) -> None: + pass + + @abc.abstractmethod + def get_average_score(self, *args: Any, **kwargs: Any) -> dict[str, float]: + pass + + @abc.abstractmethod + def reset(self) -> None: + pass + + +class WSIDiceMetric(WSIMetric): + """WSI Dice metric class, computes the dice score over the whole WSI""" + + def __init__(self, data_description: DataDescription, compute_overall_dice: bool = False) -> None: + super().__init__(data_description=data_description) + self.compute_overall_dice = compute_overall_dice + self._num_classes = self._data_description.num_classes
+ # self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self._device = "cpu" + + # Invert the index map + _index_map = {} + if self._data_description.index_map is None: + raise ConfigurationError("`index_map` is required to set up the wsi-dice metric.") + else: + _index_map = self._data_description.index_map + + _label_to_class: dict[int, str] = {v: k for k, v in _index_map.items()} + _label_to_class[0] = "background" + self._label_to_class = _label_to_class + + @property + def name(self) -> str: + return "wsi_dice" + + def process_batch( + self, + predictions: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None, + wsi_name: str, + ) -> None: + if wsi_name not in self.wsis: + self._initialize_wsi_dict(wsi_name) + dice_components = _get_intersection_and_cardinality( + predictions.to(self._device), + target.to(self._device), + roi.to(self._device) if roi is not None else None, + self._num_classes, + ) + for class_idx, (intersection, cardinality) in enumerate(dice_components): + self.wsis[wsi_name][class_idx]["intersection"] += intersection + self.wsis[wsi_name][class_idx]["cardinality"] += cardinality + + def get_wsi_score(self, wsi_name: str) -> None: + for class_idx in self.wsis[wsi_name]: + intersection = self.wsis[wsi_name][class_idx]["intersection"] + cardinality = self.wsis[wsi_name][class_idx]["cardinality"] + self.wsis[wsi_name][class_idx]["dice"] = _compute_dice(intersection, cardinality) + + def _get_overall_dice(self) -> dict[int, float]: + """ + Compute the overall dice score (per class) over all the WSIs + + Returns + ------- + dict + Dictionary with the overall dice scores across wsis per class + """ + overall_dices: dict[int, dict[str, float]] = { + class_idx: { + "total_intersection": 0.0, + "total_cardinality": 0.0, + "overall_dice": 0.0, + } + for class_idx in range(self._num_classes) + } + for wsi_name in self.wsis: + for class_idx in range(self._num_classes): + overall_dices[class_idx]["total_intersection"] += self.wsis[wsi_name][class_idx]["intersection"] + overall_dices[class_idx]["total_cardinality"] += self.wsis[wsi_name][class_idx]["cardinality"] + for class_idx in overall_dices.keys(): + intersection = overall_dices[class_idx]["total_intersection"] + cardinality = overall_dices[class_idx]["total_cardinality"] + overall_dices[class_idx]["overall_dice"] = (2 * intersection + 0.01) / (cardinality + 0.01) + return { + class_idx: torch.tensor(overall_dices[class_idx]["overall_dice"]).item() + for class_idx in overall_dices.keys() + } + + def _get_dice_averaged_over_total_wsis(self) -> dict[int, float]: + """ + Compute the dice score (per class) averaged over all the WSIs + + Returns + ------- + dict + Dictionary with the dice scores averaged over all the WSIs per class + """ + dices: dict[int, list[float]] = {class_idx: [] for class_idx in range(self._num_classes)} + for wsi_name in self.wsis: + self.get_wsi_score(wsi_name) + for class_idx in range(self._num_classes): + dices[class_idx].append(self.wsis[wsi_name][class_idx]["dice"].item()) + return {class_idx: sum(dices[class_idx]) / len(dices[class_idx]) for class_idx in dices.keys()} + + def _initialize_wsi_dict(self, wsi_name: str) -> None: + self.wsis[wsi_name] = { + class_idx: {"intersection": 0, "cardinality": 0, "dice": None} for class_idx in range(self._num_classes) + } + + def get_average_score( + self, precomputed_output: list[list[dict[str, dict[str, float]]]] | None = None + ) -> dict[Any, Any]: + # `precomputed_output` is used for multiprocessing, where multiple instances of this class are created. + if precomputed_output is not None:
+ avg_dict = self.static_average_wsi_dice(precomputed_output) + # Check whether the precomputed output contained WSI dice scores; otherwise compute them normally. + if avg_dict: + return avg_dict + if self.compute_overall_dice: + dices = self._get_overall_dice() + else: + dices = self._get_dice_averaged_over_total_wsis() + avg_dict = {f"{self.name}/{self._label_to_class[idx]}": value for idx, value in dices.items()} + return avg_dict + + @staticmethod + def static_average_wsi_dice(precomputed_output: list[list[dict[str, dict[str, float]]]]) -> dict[str, float]: + """Static method to compute the average WSI dice score over a list of WSI dice scores, + useful for multiprocessing.""" + # Initialize defaultdicts to handle the sum and count of dice scores for each class + class_sum: dict[str, float] = defaultdict(float) + class_count: dict[str, int] = defaultdict(int) + + # Flatten the list and extract 'wsi_dice' dictionaries + wsi_dices: list[dict[str, float]] = [ + wsi_metric.get("wsi_dice", {}) for sublist in precomputed_output for wsi_metric in sublist + ] + # If the list is empty, the precomputed output did not contain any WSI dice scores + if not wsi_dices: + return {} + + # Update sum and count for each class in a single pass + for wsi_dice in wsi_dices: + for class_name, dice_score in wsi_dice.items(): + class_sum[class_name] += dice_score + class_count[class_name] += 1 + + # Compute average dice scores in a dictionary comprehension with consistent naming + avg_dice_scores = { + f"wsi_dice/{class_name}": class_sum[class_name] / class_count[class_name] + for class_name in class_sum.keys() + } + return avg_dice_scores + + def reset(self) -> None: + self.wsis = {} + + def __repr__(self) -> str: + return f"{type(self).__name__}(num_classes={self._num_classes})" + + +class WSIMetricFactory: + # TODO: this should be rewritten to actually be a factory + def __init__(self, metrics: list[WSIMetric]) -> None: + super().__init__() + names = [metric.name for metric in metrics] + if len(set(names)) != len(names): + raise RuntimeError("Each individual metric must have a different name.") + + self._metrics = metrics + + @classmethod + def for_segmentation(cls, *args: Any, **kwargs: Any) -> WSIMetricFactory: + dices = WSIDiceMetric(*args, **kwargs) + return cls([dices]) + + @classmethod + def for_wsi_classification(cls, *args: Any, **kwargs: Any) -> WSIMetricFactory: + raise NotImplementedError + + @classmethod + def for_tile_classification(cls, roi_name: str, label: str, threshold: float) -> WSIMetricFactory: + raise NotImplementedError + + def process_batch( + self, + predictions: torch.Tensor, + target: torch.Tensor, + wsi_name: str, + roi: torch.Tensor | None, + ) -> None: + for metric in self._metrics: + metric.process_batch(predictions, target, wsi_name=wsi_name, roi=roi) + + def get_average_score( + self, precomputed_output: list[list[dict[str, dict[str, float]]]] | None = None + ) -> dict[str, float]: + output = {} + for metric in self._metrics: + output.update(metric.get_average_score(precomputed_output)) + return output + + def reset(self) -> None: + for metric in self._metrics: + metric.reset() + + def __repr__(self) -> str: + return f"{type(self).__name__}(metrics={self._metrics})" + + +def _get_intersection_and_cardinality( + predictions: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None, + num_classes: int, +) -> List[Tuple[torch.Tensor, torch.Tensor]]:
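+ # For each class c: intersection_c = |P_c ∩ T_c| and cardinality_c = |P_c| + |T_c|, accumulated over the whole batch, so that dice_c = 2 * intersection_c / cardinality_c.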
soft_predictions = F.softmax(predictions, dim=1) + # if roi is not None: + # soft_predictions = soft_predictions * roi + # target = target * roi + + predictions = soft_predictions.argmax(dim=1) + _target = target.argmax(dim=1) + + dice_components = [] + for class_idx in range(num_classes): + curr_predictions = (predictions == class_idx).int() + curr_target = (_target == class_idx).int() + # Compute the dice score + if roi is not None: + intersection = torch.sum((curr_predictions * curr_target) * roi.squeeze(1), dim=(0, 1, 2)) + cardinality = torch.sum(curr_predictions * roi.squeeze(1), dim=(0, 1, 2)) + torch.sum( + curr_target * roi.squeeze(1), dim=(0, 1, 2) + ) + else: + intersection = torch.sum((curr_predictions * curr_target), dim=(0, 1, 2)) + cardinality = torch.sum(curr_predictions, dim=(0, 1, 2)) + torch.sum(curr_target, dim=(0, 1, 2)) + dice_components.append((intersection, cardinality)) + return dice_components + + +def _compute_dice(intersection: torch.Tensor, cardinality: torch.Tensor) -> torch.Tensor: + dice_score = 2.0 * intersection / cardinality + dice_score[dice_score.isnan()] = 1.0 + return dice_score diff --git a/ahcore/models/__init__.py b/ahcore/models/__init__.py new file mode 100644 index 0000000..99e0775 --- /dev/null +++ b/ahcore/models/__init__.py @@ -0,0 +1,5 @@ +""" +Module which contains all the ahcore models. + +The interfaces should be implemented like Monai models. +""" diff --git a/ahcore/py.typed b/ahcore/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ahcore/readers.py b/ahcore/readers.py new file mode 100644 index 0000000..c1f1e64 --- /dev/null +++ b/ahcore/readers.py @@ -0,0 +1,309 @@ +""" +Reader classes. + +- `H5FileImageReader`: to read files written using the `ahcore.writers.H5FileImageWriter`. 
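+ + A minimal usage sketch (the filename is hypothetical): + + >>> from pathlib import Path + >>> with H5FileImageReader.from_file_path(Path("prediction.h5"), stitching_mode=StitchingMode.CROP) as reader: + ... region = reader.read_region(location=(0, 0), scaling=1.0, size=(512, 512)) # (num_channels, 512, 512)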
+ +""" +import errno +import json +import math +import os +from enum import Enum +from pathlib import Path +from types import TracebackType +from typing import Literal, Optional, Type, cast + +import h5py +import numpy as np +from scipy.ndimage import map_coordinates + +from ahcore.utils.io import get_logger +from ahcore.utils.types import BoundingBoxType, GenericArray + +logger = get_logger(__name__) + + +class StitchingMode(str, Enum): + CROP = "crop" + AVERAGE = "average" + MAXIMUM = "maximum" + + +def crop_to_bbox(array: GenericArray, bbox: BoundingBoxType) -> GenericArray: + (start_x, start_y), (width, height) = bbox + return array[:, start_y : start_y + height, start_x : start_x + width] + + +class H5FileImageReader: + def __init__(self, filename: Path, stitching_mode: StitchingMode) -> None: + self._filename = filename + self._stitching_mode = stitching_mode + + self.__empty_tile: GenericArray | None = None + + self._h5file: Optional[h5py.File] = None + self._metadata = None + self._mpp = None + self._tile_size = None + self._tile_overlap = None + self._size = None + self._num_channels = None + self._dtype = None + self._stride = None + + @classmethod + def from_file_path(cls, filename: Path, stitching_mode: StitchingMode = StitchingMode.CROP) -> "H5FileImageReader": + return cls(filename=filename, stitching_mode=stitching_mode) + + @property + def size(self) -> tuple[int, int]: + if not self._size: + self._open_file() + assert self._size + return self._size + + @property + def mpp(self) -> float: + if not self._mpp: + self._open_file() + assert self._mpp + return self._mpp + + def get_mpp(self, scaling: Optional[float]) -> float: + if not self._mpp: + self._open_file() + assert self._mpp + if scaling is None: + return self.mpp + + return self._mpp / scaling + + def get_scaling(self, mpp: Optional[float]) -> float: + """Inverse of get_mpp().""" + if not self._mpp: + self._open_file() + assert self._mpp + if not mpp: + return 1.0 + return self._mpp / mpp + + def _open_file(self) -> None: + if not self._filename.is_file(): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), str(self._filename)) + + try: + self._h5file = h5py.File(self._filename, "r") + except OSError as e: + logger.error(f"Could not open file {self._filename}: {e}") + raise e + + try: + self._metadata = json.loads(self._h5file.attrs["metadata"]) + except KeyError as e: + logger.error(f"Could not read metadata from file {self._filename}: {e}") + raise e + + if not self._metadata: + raise ValueError("Metadata of h5 file is empty.") + self._mpp = self._metadata["mpp"] + self._tile_size = self._metadata["tile_size"] + self._tile_overlap = self._metadata["tile_overlap"] + self._size = self._metadata["size"] + self._num_channels = self._metadata["num_channels"] + self._dtype = self._metadata["dtype"] + self._stride = ( + self._tile_size[0] - self._tile_overlap[0], + self._tile_size[1] - self._tile_overlap[1], + ) + + def __enter__(self) -> "H5FileImageReader": + if self._h5file is None: + self._open_file() + return self + + def _empty_tile(self) -> GenericArray: + if self.__empty_tile is not None: + return self.__empty_tile + + # When this happens we would already be in the read_region, and self._num_channels would be populated. 
+ assert self._num_channels + + self.__empty_tile = np.zeros((self._num_channels, *self._tile_size), dtype=self._dtype) + return self.__empty_tile + + def read_region( + self, + location: tuple[int, int], + scaling: float, + size: tuple[int, int], + ) -> GenericArray: + """ + Read a region at the requested scaling. + + Parameters + ---------- + location : tuple[int, int] + Location from the top left (x, y) in pixel coordinates given at the requested scaling. + scaling : float + Scaling factor relative to the resolution at which the data was stored. A value of `1.0` reads the + raw stored data. + size : tuple[int, int] + Size of the output region. + + Returns + ------- + np.ndarray + The requested region. + """ + if scaling == 1.0: + return self.read_region_raw(location, size) + + order = 1 + # Calculate original location and size considering the scaling + + # unpack for mypy + l1, l2 = location + s1, s2 = size + + original_location = ( + int(math.floor(l1 / scaling)) - order, + int(math.floor(l2 / scaling)) - order, + ) + original_size = ( + int(math.ceil(s1 / scaling)) + order, + int(math.ceil(s2 / scaling)) + order, + ) + + raw_region = self.read_region_raw(original_location, original_size) + + # Determine the fractional start and end coordinates for mapping + fractional_start = tuple(map(lambda _, ol: (_ / scaling) - ol + order, location, original_location)) + fractional_end = tuple(fs + size[i] / scaling for i, fs in enumerate(fractional_start)) + + # Create an array of coordinates for map_coordinates + # mypy doesn't properly understand yet that the complex type is valid + coordinates = np.mgrid[ + fractional_start[0] : fractional_end[0] : complex(size[0]), # type: ignore + fractional_start[1] : fractional_end[1] : complex(size[1]), # type: ignore + ] + coordinates = np.moveaxis(coordinates, 0, -1) + + # Interpolate using map_coordinates for all channels + grid = np.mgrid[: raw_region.shape[0]] + coordinates = np.concatenate([grid[:, None, None], coordinates], axis=0) + # scipy doesn't have proper typing yet + rescaled_region = cast(GenericArray, map_coordinates(raw_region, coordinates, order=order)) + + return rescaled_region + + def read_region_raw(self, location: tuple[int, int], size: tuple[int, int]) -> GenericArray: + """ + Reads a region in the stored h5 file. This function stitches the regions as saved in the h5 file. In doing + so, it takes into account: + 1) The region overlap; several region merging strategies are implemented: cropping, averaging across borders + and taking the maximum across borders. + 2) Whether tiles are saved or not. In case a tile was skipped due to a background mask, an empty tile is returned. + + Parameters + ---------- + location : tuple[int, int] + Coordinates (x, y) of the upper left corner of the region. + size : tuple[int, int] + The (w, h) size of the extracted region. + + Returns + ------- + np.ndarray + Extracted region + """ + if self._h5file is None: + self._open_file() + assert self._h5file, "File is not open. Should not happen"
+ assert self._tile_size + assert self._tile_overlap + + image_dataset = self._h5file["data"] + num_tiles = self._metadata["num_tiles"] + tile_indices = self._h5file["tile_indices"] + + total_rows = math.ceil((self._size[1] - self._tile_overlap[1]) / self._stride[1]) + total_cols = math.ceil((self._size[0] - self._tile_overlap[0]) / self._stride[0]) + + assert total_rows * total_cols == num_tiles + + x, y = location + w, h = size + if x < 0 or y < 0 or x + w > self._size[0] or y + h > self._size[1]: + logger.error(f"Requested region is out of bounds: {location}, {self._size}") + raise ValueError("Requested region is out of bounds") + + start_row = y // self._stride[1] + end_row = min((y + h - 1) // self._stride[1] + 1, total_rows) + start_col = x // self._stride[0] + end_col = min((x + w - 1) // self._stride[0] + 1, total_cols) + + if self._stitching_mode == StitchingMode.AVERAGE: + divisor_array = np.zeros((h, w), dtype=np.uint8) + stitched_image = np.zeros((self._num_channels, h, w), dtype=self._dtype) + for i in range(start_row, end_row): + for j in range(start_col, end_col): + tile_idx = (i * total_cols) + j + # Map through tile indices + tile_index_in_image_dataset = tile_indices[tile_idx] + tile = ( + self._empty_tile() + if tile_index_in_image_dataset == -1 + else image_dataset[tile_index_in_image_dataset] + ) + start_y = i * self._stride[1] - y + end_y = start_y + self._tile_size[1] + start_x = j * self._stride[0] - x + end_x = start_x + self._tile_size[0] + + img_start_y = max(0, start_y) + img_end_y = min(h, end_y) + img_start_x = max(0, start_x) + img_end_x = min(w, end_x) + + if self._stitching_mode == StitchingMode.CROP: + crop_start_y = img_start_y - start_y + crop_end_y = img_end_y - start_y + crop_start_x = img_start_x - start_x + crop_end_x = img_end_x - start_x + + bbox = (crop_start_x, crop_start_y), ( + crop_end_x - crop_start_x, + crop_end_y - crop_start_y, + ) + cropped_tile = crop_to_bbox(tile, bbox) + stitched_image[:, img_start_y:img_end_y, img_start_x:img_end_x] = cropped_tile + + elif self._stitching_mode == StitchingMode.AVERAGE: + raise NotImplementedError + tile_start_y = max(0, -start_y) + tile_end_y = img_end_y - img_start_y + tile_start_x = max(0, -start_x) + tile_end_x = img_end_x - img_start_x + + # TODO: Replace this with crop_to_bbox + cropped_tile = tile[tile_start_y:tile_end_y, tile_start_x:tile_end_x] + stitched_image[img_start_y:img_end_y, img_start_x:img_end_x] += cropped_tile + divisor_array[img_start_y:img_end_y, img_start_x:img_end_x] += 1 + else: + raise ValueError("Unsupported stitching mode") + + if self._stitching_mode == StitchingMode.AVERAGE: + stitched_image = (stitched_image / divisor_array[..., np.newaxis]).astype(float) + + return stitched_image + + def close(self) -> None: + if self._h5file is not None: + self._h5file.close() + self._h5file = None # Reset the attribute so the reader can be reopened later + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: + self.close() + return False diff --git a/ahcore/transforms/__init__.py b/ahcore/transforms/__init__.py new file mode 100644 index 0000000..ef86b49 --- /dev/null +++ b/ahcore/transforms/__init__.py @@ -0,0 +1,6 @@ +""" +All transformations. These include: +- Augmentations: these are nn.Modules, and can be part of the model +- Pre-transforms: these are transforms that are applied to samples directly out of the dataset. E.g., convert
E.g., convert +a polygon to a mask +""" diff --git a/ahcore/transforms/augmentations.py b/ahcore/transforms/augmentations.py new file mode 100644 index 0000000..55b4f75 --- /dev/null +++ b/ahcore/transforms/augmentations.py @@ -0,0 +1,322 @@ +""" +Augmentations factory +""" +from __future__ import annotations + +from typing import Any, Optional, Union, cast + +import kornia.augmentation as K +import pytorch_lightning as pl +import torch +from kornia.augmentation import random_generator as rg +from kornia.constants import DataKey, Resample +from omegaconf import ListConfig +from torch import nn + +from ahcore.utils.data import DataDescription +from ahcore.utils.io import get_logger + +logger = get_logger(__name__) + + +class MeanStdNormalizer(nn.Module): + """ + Normalizes the mean and standard deviation of the input image. Assumes the original range is `[0, 255]`. + """ + + def __init__( + self, + mean: tuple[float, float, float] | None = None, + std: tuple[float, float, float] | None = None, + ): + """ + Parameters + ---------- + mean : tuple[float, float, float], optional + std : tuple[float, float, float], optional + """ + super().__init__() + if mean is None: + self._mean = nn.Parameter(torch.Tensor([0.0] * 3), requires_grad=False) + else: + self._mean = nn.Parameter(torch.Tensor(mean), requires_grad=False) + + if std is None: + self._std = nn.Parameter(torch.Tensor([1.0] * 3), requires_grad=False) + else: + self._std = nn.Parameter(torch.Tensor(std), requires_grad=False) + + def forward(self, *args: torch.Tensor, **kwargs: Any) -> list[torch.Tensor] | torch.Tensor: + output = [] + data_keys = kwargs["data_keys"] + for sample, data_key in zip(args, data_keys): + if data_key in [DataKey.INPUT, 0, "INPUT"]: + sample = sample / 255.0 + sample = (sample - self._mean[..., None, None].to(sample.device)) / self._std[..., None, None].to( + sample.device + ) + output.append(sample) + + if len(output) == 1: + return output[0] + + return output + + +class HEDColorAugmentation(K.IntensityAugmentationBase2D): + """ + A torch implementation of the color stain augmentation algorithm on the + deconvolved Hemaetoxylin-Eosin-DAB (HED) channels of an image as described + by Tellez et al. (2018) in Appendix A & B here: https://arxiv.org/pdf/1808.05896.pdf. + """ + + # Normalized OD matrix from Ruifrok et al. (2001) + HED_REFERENCE = torch.Tensor([[0.65, 0.70, 0.29], [0.07, 0.99, 0.11], [0.27, 0.57, 0.78]]) + + def __init__( + self, + scale_sigma: float | list[float] | ListConfig, + bias_sigma: float | list[float] | ListConfig, + epsilon: float = 1e-6, + p: float = 0.5, + p_batch: float = 1.0, + same_on_batch: bool = False, + keepdim: bool = False, + **kwargs: Any, + ) -> None: + """ + Apply a color stain augmentation in the Hemaetoxylin-Eosin-DAB (HED) color space based on [1]. + The fixed normalized OD matrix values are based on [2]. + + Parameters + ---------- + scale_sigma: float, ListConfig or list of floats + For each channel in the HED space a random scaling factor is drawn from alpha_i ~ U(1-sigma_i,1+sigma_i). + bias_sigma: float, ListConfig or list of floats + For each channel in the HED space a random bias is added drawn from beta_i ~ U(-sigma_i,sigma_i). + epsilon: float + Small positive bias to avoid numerical errors + + References + ---------- + [1] Tellez, David, et al. "Whole-slide mitosis detection in H&E breast histology using PHH3 as + a reference to train distilled stain-invariant convolutional networks." + IEEE transactions on medical imaging 37.9 (2018): 2126-2136. 
+ [2] Ruifrok AC, Johnston DA. Quantification of histochemical staining by color deconvolution. + Anal Quant Cytol Histol. 2001 Aug;23(4):291-9. PMID: 11531144. + """ + super().__init__(p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim) + if (isinstance(scale_sigma, (list, ListConfig)) and len(scale_sigma) != 3) or ( + isinstance(bias_sigma, (list, ListConfig)) and len(bias_sigma) != 3 + ): + raise ValueError( + f"scale_sigma and bias_sigma should have either 1 or 3 values, " + f"got {scale_sigma} and {bias_sigma} instead." + ) + + if isinstance(scale_sigma, float): + scale_sigma = [scale_sigma for _ in range(3)] + + if isinstance(bias_sigma, float): + bias_sigma = [bias_sigma for _ in range(3)] + + _scale_sigma = torch.tensor(scale_sigma) + _bias_sigma = torch.tensor(bias_sigma) + + scale_factor = torch.stack([1 - _scale_sigma, 1 + _scale_sigma]) + bias_factor = torch.stack([-_bias_sigma, _bias_sigma]) + + self._param_generator = rg.PlainUniformGenerator( + (scale_factor, "scale", None, None), (bias_factor, "bias", None, None) + ) + self.flags = { + "epsilon": torch.tensor([epsilon]), + "M": self.HED_REFERENCE, + "M_inv": torch.linalg.inv(self.HED_REFERENCE), + } + + def apply_transform( + self, + input: torch.Tensor, + params: dict[str, torch.Tensor], + flags: dict[str, Any], + transform: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: + """ + Apply HED color augmentation on an input tensor. + """ + assert flags, "Flags should be provided" + assert params, "Params should be provided" + + epsilon = flags["epsilon"].to(input) + reference_matrix = flags["M"].to(input) + reference_matrix_inv = flags["M_inv"].to(input) + alpha = params["scale"][:, None, None, :].to(input) + beta = params["bias"][:, None, None, :].to(input) + + rgb_tensor = input.permute(0, 2, 3, 1) + optical_density = -torch.log(rgb_tensor + epsilon) + # Mypy doesn't understand this is a tensor. 
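+ # Beer-Lambert style deconvolution: project the optical density onto the HED stain basis via the inverse stain matrix.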
+ hed_tensor = cast(torch.Tensor, optical_density @ reference_matrix_inv) + + augmented_hed_tensor = torch.where(hed_tensor > epsilon, alpha * hed_tensor + beta, hed_tensor) + # Same problem that mypy doesn't understand + augmented_rgb_tensor = cast(torch.Tensor, torch.exp(-augmented_hed_tensor @ reference_matrix) - epsilon) + augmented_sample = augmented_rgb_tensor.permute(0, 3, 1, 2) + return augmented_sample + + +class CenterCrop(nn.Module): + """Perform a center crop of the image and target""" + + def __init__(self, size: int | tuple[int, int], **kwargs: Any) -> None: + super().__init__() + _size = size + if isinstance(size, int): + _size = (size, size) + + if isinstance(size, ListConfig) or isinstance(size, list): + _size = tuple(size) + + self._cropper = K.CenterCrop(size=_size, align_corners=True, p=1.0, keepdim=False, cropping_mode="slice") + + def forward( + self, + *sample: torch.Tensor, + data_keys: Optional[list[str | int | DataKey]] = None, + **kwargs: Any, + ) -> Union[list[torch.Tensor], torch.Tensor]: + output = [cast(torch.Tensor, self._cropper(item)) for item in sample] + + if len(output) == 1: + return output[0] + return output + + +def _parse_random_apply(random_apply: int | bool | tuple[int, int] | ListConfig) -> int | bool | tuple[int, int]: + if isinstance(random_apply, (int, bool)): + return random_apply + + if isinstance(random_apply, ListConfig) or isinstance(random_apply, list): + return cast(tuple[int, int], tuple(random_apply)) + + return random_apply + + +def _parse_random_apply_weights( + random_apply_weights: list[float] | ListConfig | None, +) -> list[float] | None: + if isinstance(random_apply_weights, ListConfig) or isinstance(random_apply_weights, list): + return cast(list[float], list(random_apply_weights)) + + return random_apply_weights + + +class AugmentationFactory(nn.Module): + """Factory for the augmentation. There are three classes of augmentations: + - `initial_transforms`: Transforms which are the first to always be applied to the sample + - `intensity_augmentations`: Transforms which only affect the intensity and not the geometry. Only applied to the + image. + - `geometric_augmentations`: Transforms which affect the geometry. They are applied to both the image, ROI and mask. 
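+ + A minimal construction sketch (the choice of augmentations is illustrative; `data_description` and + `data_module` are assumed to be available): + + >>> factory = AugmentationFactory( + ... data_description=data_description, + ... data_module=data_module, + ... initial_transforms=[MeanStdNormalizer()], + ... geometric_augmentations=[K.RandomHorizontalFlip(p=0.5)], + ... ) + >>> sample = factory(sample) # sample is a dict holding at least "image", "target" and "path"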
+ """ + + DATA_KEYS = {"image": DataKey.INPUT, "target": DataKey.MASK, "roi": DataKey.MASK} + + def __init__( + self, + data_description: DataDescription, + data_module: pl.LightningDataModule, + initial_transforms: list[nn.Module] | None = None, + random_apply_intensity: int | bool | ListConfig = False, + random_apply_weights_intensity: list[float] | None = None, + intensity_augmentations: list[K.IntensityAugmentationBase2D] | None = None, + random_apply_geometric: int | bool | ListConfig = False, + random_apply_weights_geometric: list[float] | ListConfig | None = None, + geometric_augmentations: list[K.GeometricAugmentationBase2D] | None = None, + final_transforms: list[nn.Module] | None = None, + ) -> None: + super().__init__() + + self._transformable_keys = ["image", "target"] + if data_description.use_roi: + self._transformable_keys.append("roi") + self._data_keys = [self.DATA_KEYS[key] for key in self._transformable_keys] + + # Initial transforms will be applied sequentially + if initial_transforms: + for transform in initial_transforms: + logger.info("Using initial transform %s", transform) + self._initial_transforms = nn.ModuleList(initial_transforms) + + # Intensity augmentations will be selected in random order + if intensity_augmentations: + for transform in intensity_augmentations: + logger.info("Adding intensity augmentation %s", transform) + + self._intensity_augmentations = None + if intensity_augmentations: + self._intensity_augmentations = K.AugmentationSequential( + *intensity_augmentations, + data_keys=list(self.DATA_KEYS.values()), + same_on_batch=False, + random_apply=_parse_random_apply(random_apply_intensity), + random_apply_weights=_parse_random_apply_weights(random_apply_weights_intensity), + ) + + # Geometric augmentations will be selected in random order. 
+ if geometric_augmentations: + for transform in geometric_augmentations: + logger.info("Adding geometric augmentation %s", transform) + + self._geometric_augmentations = None + if geometric_augmentations: + self._geometric_augmentations = K.AugmentationSequential( + *geometric_augmentations, + data_keys=list(self.DATA_KEYS.values()), + same_on_batch=False, + random_apply=_parse_random_apply(random_apply_geometric), + random_apply_weights=_parse_random_apply_weights(random_apply_weights_geometric), + extra_args={DataKey.MASK: dict(resample=Resample.NEAREST, align_corners=True)}, + ) + + # Final transforms will be applied sequentially + if final_transforms: + for transform in final_transforms: + logger.info("Using final transform %s", transform) + self._final_transforms = nn.ModuleList(final_transforms) + + def forward(self, sample: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + output_data = [sample[key] for key in self._transformable_keys if key in sample] + if self._initial_transforms: + kwargs = { + "data_keys": self._data_keys, + "filenames": sample["path"], + } + for transform in self._initial_transforms: + output_data = transform(*output_data, **kwargs) + + if isinstance(output_data, torch.Tensor): + output_data = [output_data] + + if self._intensity_augmentations: + output_data[0] = self._intensity_augmentations(*output_data[:1], data_keys=[DataKey.INPUT]) + + if self._geometric_augmentations: + output_data = self._geometric_augmentations(*output_data, data_keys=self._data_keys) + + if isinstance(output_data, torch.Tensor): + output_data = [output_data] + + if self._final_transforms: + for transform in self._final_transforms: + output_data = transform(*output_data, data_keys=self._data_keys) + + if isinstance(output_data, torch.Tensor): + output_data = [output_data] + + # Add the output data back into the sample + for key, curr_output in zip(self._transformable_keys, output_data): + sample[key] = curr_output + + return sample diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py new file mode 100644 index 0000000..601b5f2 --- /dev/null +++ b/ahcore/transforms/pre_transforms.py @@ -0,0 +1,238 @@ +""" +Module for the pre-transforms, which are the transforms that are applied before samples are outputted in a +dataset. +""" +from __future__ import annotations + +from typing import Any, Callable + +import numpy as np +import numpy.typing as npt +import torch +from dlup.data.dataset import TileSample +from dlup.data.transforms import ContainsPolygonToLabel, ConvertAnnotationsToMask, RenameLabels +from torchvision.transforms import functional as F + +from ahcore.exceptions import ConfigurationError +from ahcore.utils.data import DataDescription +from ahcore.utils.io import get_logger +from ahcore.utils.types import DlupDatasetSample + +PreTransformCallable = Callable[[Any], Any] + +logger = get_logger(__name__) + + +class PreTransformTaskFactory: + def __init__(self, transforms: list[PreTransformCallable]): + """ + Pre-transforms are transforms that are applied to the samples directly originating from the dataset. + These transforms are typically the same for the specific tasks (e.g., segmentation, + detection or whole-slide classification). + Each of these tasks has a specific constructor. In all cases, the final transforms convert the PIL image + (as the image key of the output sample) to a tensor, and ensure that the sample dictionary can be collated. 
+ In ahcore, the augmentations are done separately and are part of the model in the forward function. + + Parameters + ---------- + transforms : list + List of transforms to be used. + """ + # These are always finally added. + transforms = list(transforms) + [ + ImageToTensor(), + AllowCollate(), + ] + self._transforms = transforms + + @classmethod + def for_segmentation( + cls, data_description: DataDescription, requires_target: bool = True + ) -> PreTransformTaskFactory: + """ + Pre-transforms for segmentation tasks. If the target is required these transforms are applied as follows: + - Labels are renamed (for instance if you wish to map several labels to one specific class) + - `Polygon` and `Point` annotations are converted to a mask + - The mask is one-hot encoded. + + Parameters + ---------- + data_description : DataDescription + requires_target : bool + + Returns + ------- + PreTransformTaskFactory + The `PreTransformTaskFactory` initialized for segmentation tasks. + """ + transforms: list[PreTransformCallable] = [] + if not requires_target: + return cls(transforms) + + if data_description.index_map is None: + raise ConfigurationError("`index_map` is required for segmentation models when the target is required.") + + if data_description.remap_labels is not None: + transforms.append(RenameLabels(remap_labels=data_description.remap_labels)) + + transforms.append( + ConvertAnnotationsToMask(roi_name=data_description.roi_name, index_map=data_description.index_map) + ) + transforms.append(OneHotEncodeMask(index_map=data_description.index_map)) + + return cls(transforms) + + @classmethod + def for_wsi_classification( + cls, data_description: DataDescription, requires_target: bool = True + ) -> PreTransformTaskFactory: + transforms: list[PreTransformCallable] = [] + if not requires_target: + return cls(transforms) + + index_map = data_description.index_map + if index_map is None: + raise ConfigurationError("`index_map` is required for classification models when the target is required.") + + transforms.append(LabelToClassIndex(index_map=index_map)) + + return cls(transforms) + + @classmethod + def for_tile_classification(cls, roi_name: str, label: str, threshold: float) -> PreTransformTaskFactory: + """Tile classification is based on a transform which checks if a polygon is present for a given threshold""" + convert_annotations = ContainsPolygonToLabel(roi_name=roi_name, label=label, threshold=threshold) + return cls([convert_annotations]) + + def __call__(self, data: DlupDatasetSample) -> DlupDatasetSample: + for transform in self._transforms: + data = transform(data) + return data + + def __repr__(self) -> str: + return f"PreTransformTaskFactory(transforms={self._transforms})" + + +class LabelToClassIndex: + """ + Maps label values to class indices according to the index_map specified in the data description.
+ Example: + If there are two tasks: + - Task1 with classes {A, B, C} + - Task2 with classes {X, Y} + Then an input sample could look like: {"labels": {"Task1": "C", "Task2": "Y"}, ...} + If the index map is: {"A": 0, "B": 1, "C": 2, "X": 0, "Y": 1} + The returned sample will look like: {"labels": {"Task1": 2, "Task2": 1}, ...} + """ + + def __init__(self, index_map: dict[str, int]): + self._index_map = index_map + + def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: + sample["labels"] = { + label_name: self._index_map[label_value] for label_name, label_value in sample["labels"].items() + } + + return sample + + +class OneHotEncodeMask: + def __init__(self, index_map: dict[str, int]): + """Create the one-hot encoding of the mask for segmentation. + If we have `N` classes, the result will be an `(N + 1, H, W)` tensor per sample (batched to + `(B, N + 1, H, W)`), where the first channel is the background. + Parameters + ---------- + index_map : dict[str, int] + Index map mapping the label name to the integer value it has in the mask. + """ + self._index_map = index_map + + # The largest index in the index map determines the number of output channels + self._largest_index = max(index_map.values()) + + def __call__(self, sample: DlupDatasetSample) -> DlupDatasetSample: + mask = sample["annotation_data"]["mask"] + + new_mask = np.zeros((self._largest_index + 1, *mask.shape)) + for idx in range(self._largest_index + 1): + new_mask[idx] = (mask == idx).astype(np.float32) + + sample["annotation_data"]["mask"] = new_mask + return sample + + +def one_hot_encoding(index_map: dict[str, int], mask: npt.NDArray[np.int_ | np.float_]) -> npt.NDArray[np.float32]: + """ + Functional interface to convert labels/predictions into one-hot encodings + + Parameters + ---------- + index_map : dict[str, int] + Index map mapping the label name to the integer value it has in the mask. + + mask: npt.NDArray + The numpy array of model predictions or ground truth labels. + + Returns + ------- + new_mask: npt.NDArray + One-hot encoded output + """ + largest_index = max(index_map.values()) + new_mask = np.zeros((largest_index + 1, *mask.shape), dtype=np.float32) + for idx in range(largest_index + 1): + new_mask[idx] = mask == idx + return new_mask + + +class AllowCollate: + """Path objects cannot be collated in the standard pytorch collate function. + This transform converts the path to a string, and removes the annotation keys (and `labels` when it is + `None`), since these cannot be collated either. + """ + + def __call__(self, sample: TileSample) -> dict[str, Any]: + # Path objects cannot be collated + output = dict(sample.copy()) + for key in sample: + if key == "path": + output["path"] = str(sample["path"]) + if key in ["annotation_data", "annotations"]: + # remove annotation_data and annotations keys from output + del output[key] + if key == "labels" and sample["labels"] is None: + del output[key] + + return output + + +class ImageToTensor: + """ + Transform to translate the output of a dlup dataset into a format supported by ahcore. + """ + + def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: + sample["image"] = F.pil_to_tensor(sample["image"].convert("RGB")).float() + + if sample["image"].sum() == 0: + raise RuntimeError(f"Empty tile for {sample['path']} at {sample['coordinates']}") + + # annotation_data is added by the ConvertAnnotationsToMask transform. + if "annotation_data" not in sample: + return sample + + if "mask" in sample["annotation_data"]: + mask = sample["annotation_data"]["mask"] + if len(mask.shape) == 2: + # Mask is not one-hot encoded + mask = mask[np.newaxis, ...]
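+ # The mask now has shape (C, H, W); targets and ROIs are stored as float tensors for the loss and metric computations.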
+ sample["target"] = torch.from_numpy(mask).float() + + if "roi" in sample["annotation_data"]: + roi = sample["annotation_data"]["roi"] + sample["roi"] = torch.from_numpy(roi[np.newaxis, ...]).float() + + return sample + + def __repr__(self) -> str: + return f"{type(self).__name__}()" diff --git a/ahcore/utils/__init__.py b/ahcore/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py new file mode 100644 index 0000000..5e01bbe --- /dev/null +++ b/ahcore/utils/data.py @@ -0,0 +1,67 @@ +"""Utilities to describe the dataset to be used and the way it should be parsed.""" +from __future__ import annotations + +import hashlib +import pickle +import uuid +from pathlib import Path +from typing import Dict, Optional, Tuple + +from pydantic import BaseModel + +from ahcore.utils.types import NonNegativeInt, PositiveFloat, PositiveInt + + +def basemodel_to_uuid(base_model: BaseModel) -> uuid.UUID: + """Create a unique identifier for a pydantic model. + + This is done by pickling the object, and computing the sha256 hash of the pickled object and converting this to + an UUID. The UUID is generated using the sha256 hash as a namespace, ensuring similar lengths. The chance of + a collision is astronomically small. + + Arguments + --------- + base_model: BaseModel + The BaseModel to create a unique identifier for. + + Returns + ------- + uuid.UUID + A unique identifier for the BaseModel. + """ + # Serialize the object + serialized_data = pickle.dumps(base_model.model_dump()) + + # Generate a sha256 hash of the serialized data + obj_hash = hashlib.sha256(serialized_data).digest() + + # Use the hash as a namespace to generate a UUID + unique_id = uuid.uuid5(uuid.NAMESPACE_DNS, obj_hash.hex()) + + return unique_id + + +class GridDescription(BaseModel): + mpp: Optional[PositiveFloat] + tile_size: Tuple[PositiveInt, PositiveInt] + tile_overlap: Tuple[NonNegativeInt, NonNegativeInt] + output_tile_size: Optional[Tuple[int, int]] = None + + +class DataDescription(BaseModel): + mask_label: Optional[str] = None + mask_threshold: Optional[float] = None # This is only used for training + roi_name: Optional[str] = None + num_classes: PositiveInt + data_dir: Path + manifest_database_uri: str + manifest_name: str + split_version: str + annotations_dir: Path + training_grid: GridDescription + inference_grid: GridDescription + index_map: Optional[Dict[str, int]] + remap_labels: Optional[Dict[str, str]] = None + use_class_weights: Optional[bool] = False + convert_mask_to_rois: bool = True + use_roi: bool = True diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py new file mode 100644 index 0000000..0731d3d --- /dev/null +++ b/ahcore/utils/database_models.py @@ -0,0 +1,204 @@ +"""Database models for ahcore's manifest database.""" +from enum import Enum as PyEnum +from typing import List + +from sqlalchemy import Column, DateTime, Enum, Float, ForeignKey, Integer, String, UniqueConstraint, func +from sqlalchemy.orm import DeclarativeBase, Mapped, relationship + + +class CategoryEnum(PyEnum): + TRAIN = "fit" + VALIDATE = "validate" + TEST = "test" + PREDICT = "predict" + + +class Base(DeclarativeBase): + pass + + +class Manifest(Base): + """Manifest table.""" + + __tablename__ = "manifest" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + name = Column(String, 
unique=True) + + patients: Mapped[List["Patient"]] = relationship("Patient", back_populates="manifest") + + +class Patient(Base): + """Patient table.""" + + __tablename__ = "patient" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + patient_code = Column(String, unique=True) + manifest_id = Column(Integer, ForeignKey("manifest.id")) + + manifest: Mapped["Manifest"] = relationship("Manifest", back_populates="patients") + images: Mapped[List["Image"]] = relationship("Image", back_populates="patient") + labels: Mapped[List["PatientLabels"]] = relationship("PatientLabels", back_populates="patient") + split: Mapped[List["Split"]] = relationship("Split", uselist=False, back_populates="patient") + + +class Image(Base): + """Image table.""" + + __tablename__ = "image" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True) + reader = Column(String) + patient_id = Column(Integer, ForeignKey("patient.id")) + + height = Column(Integer) + width = Column(Integer) + mpp = Column(Float) + + patient: Mapped["Patient"] = relationship("Patient", back_populates="images") + masks: Mapped[List["Mask"]] = relationship("Mask", back_populates="image") + annotations: Mapped[List["ImageAnnotations"]] = relationship("ImageAnnotations", back_populates="image") + labels: Mapped["ImageLabels"] = relationship("ImageLabels", back_populates="image") + cache: Mapped["ImageCache"] = relationship("ImageCache", uselist=False, back_populates="image") + + +class ImageCache(Base): + """Image cache table.""" + + __tablename__ = "image_cache" + + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True) + reader = Column(String) + num_tiles = Column(Integer) + image_id = Column(Integer, ForeignKey("image.id")) + + image: Mapped["Image"] = relationship("Image", back_populates="cache") + description_id = Column(Integer, ForeignKey("cache_description.id")) + description: Mapped["CacheDescription"] = relationship("CacheDescription", back_populates="caches") + + +class CacheDescription(Base): + """Cache description table.""" + + __tablename__ = "cache_description" + + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + mpp = Column(Float) + tile_size_width = Column(Integer) + tile_size_height = Column(Integer) + tile_overlap_width = Column(Integer) + tile_overlap_height = Column(Integer) + tile_mode = Column(String) + crop = Column(Integer, default=False) # using Integer for boolean for DB compatibility + mask_threshold = Column(Float) + grid_order = Column(String) + + caches: Mapped["ImageCache"] = relationship("ImageCache", back_populates="description") + + +class Mask(Base): + """Mask table.""" + + __tablename__ = "mask" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), 
default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True) + reader = Column(String) + image_id = Column(Integer, ForeignKey("image.id")) + + image: Mapped["Image"] = relationship("Image", back_populates="masks") + + +class ImageAnnotations(Base): + """Image annotations table.""" + + __tablename__ = "image_annotations" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + filename = Column(String, unique=True) + reader = Column(String) + image_id = Column(Integer, ForeignKey("image.id")) + + image: Mapped["Image"] = relationship("Image", back_populates="annotations") + + +class ImageLabels(Base): + """Image labels table.""" + + __tablename__ = "image_labels" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + label_data = Column(String) # e.g. "cancer" or "benign" + image_id = Column(Integer, ForeignKey("image.id")) + + image: Mapped["Image"] = relationship("Image", back_populates="labels") + + +class PatientLabels(Base): + """Patient labels table.""" + + __tablename__ = "patient_labels" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + key = Column(String) + value = Column(String) + patient_id = Column(Integer, ForeignKey("patient.id")) + + # Add a unique constraint + __table_args__ = (UniqueConstraint("key", "patient_id", name="uq_patient_label_key"),) + + patient: Mapped["Patient"] = relationship("Patient", back_populates="labels") + + +class SplitDefinitions(Base): + """Split definitions table.""" + + __tablename__ = "split_definitions" + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + version = Column(String, nullable=False) + description = Column(String) + splits: Mapped[List["Split"]] = relationship("Split", back_populates="split_definition") + + +class Split(Base): + """Split table.""" + + __tablename__ = "split" + + id = Column(Integer, primary_key=True) + # pylint: disable=E1102 + created = Column(DateTime(timezone=True), default=func.now()) + last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now()) + category: Column[CategoryEnum] = Column(Enum(CategoryEnum), nullable=False) + + patient_id = Column(Integer, ForeignKey("patient.id")) + patient: Mapped["Patient"] = relationship("Patient", back_populates="split") + + split_definition_id = Column(Integer, ForeignKey("split_definitions.id"), nullable=False) + split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits") diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py new file mode 100644 index 0000000..f91e62d --- /dev/null +++ b/ahcore/utils/io.py @@ -0,0 +1,214 @@ +"""Input/ Output utils. + +A part of the functions in this module are derived/taken from pytorch lightning template at: +https://github.com/ashleve/lightning-hydra-template +This template is licensed under the MIT License. 
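+
+A hypothetical usage sketch of the rank-zero-safe logger provided below (the
+logger name is a placeholder):
+
+    logger = get_logger("ahcore.example")
+    logger.info("This is only emitted on rank zero in multi-GPU runs.")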
+ +""" +from __future__ import annotations + +import logging +import os +import warnings +from enum import Enum +from pathlib import Path +from types import FunctionType +from typing import Any, Optional, Sequence, Type + +import pytorch_lightning as pl +import rich +import rich.syntax +import rich.tree +from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf.errors import InterpolationKeyError +from pytorch_lightning.utilities import rank_zero_only # type: ignore[attr-defined] + + +def get_logger(name: str = __name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +logger = get_logger(__name__) + + +def debug_function(x: int) -> int: + """ + Function to use for debugging (e.g. github workflow testing) + :param x: + :return: x^2 + """ + return x**2 + + +def validate_config(cfg: Any) -> None: + if isinstance(cfg, ListConfig): + for x in cfg: + validate_config(x) + elif isinstance(cfg, DictConfig): + for name, v in cfg.items(): + if name == "hydra": + logger.warning("Skipped validating hydra native configs") + continue + try: + validate_config(v) + except InterpolationKeyError: + logger.warning("Skipped validating %s: %s", name, str(v)) + continue + + +@rank_zero_only # type: ignore[misc] +def print_config( + config: DictConfig, + fields: Sequence[str] = ( + "trainer", + "model", + "experiment", + "transforms", + "datamodule", + "callbacks", + "logger", + "test_after_training", + "seed", + "name", + ), + resolve: bool = True, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + Arguments + --------- + config : DictConfig + Configuration composed by Hydra. + fields : Sequence[str], optional + Determines which main fields from config will be printed and in what order. + resolve : bool, optional + Whether to resolve reference fields of DictConfig. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + rich.print(tree) + + with open("config_tree.log", "w") as fp: + rich.print(tree, file=fp) + + +def extras(config: DictConfig) -> None: + """A couple of optional utilities, controlled by main config file: + - disabling warnings + - forcing debug friendly configuration + - verifying experiment name is set when running in experiment mode + Modifies DictConfig in place. + Args: + config (DictConfig): Configuration composed by Hydra. 
+ """ + + logger = get_logger(__name__) + + # disable python warnings if + if config.get("ignore_warnings"): + logger.info("Disabling python warnings ") + warnings.filterwarnings("ignore") + + # verify experiment name is set when running in experiment mode + if config.get("enforce_tags") and (not config.get("tags") or config.get("tags") == ["dev"]): + logger.info( + "Running in experiment mode without tags specified" + "Use `python run.py experiment=some_experiment tags=['some_tag',...]`, or change it in the experiment yaml" + ) + logger.info("Exiting...") + exit() + + # force debugger friendly configuration if + # debuggers don't like GPUs and multiprocessing + if config.trainer.get("fast_dev_run"): + logger.info("Forcing debugger friendly configuration ") + if config.trainer.get("gpus"): + config.trainer.gpus = 0 + if config.datamodule.get("pin_memory"): + config.datamodule.pin_memory = False + if config.datamodule.get("num_workers"): + config.datamodule.num_workers = 0 + + +@rank_zero_only # type:ignore[misc] +def log_hyperparameters( + config: DictConfig, + model: pl.LightningModule, + trainer: pl.Trainer, +) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - number of model parameters + """ + + if not trainer.logger: + return + + hparams = { + "model": config["lit_module"], + "model/params/total": sum(p.numel() for p in model.parameters()), + "model/params/trainable": sum(p.numel() for p in model.parameters() if p.requires_grad), + "model/params/non_trainable": sum(p.numel() for p in model.parameters() if not p.requires_grad), + "datamodule": config["datamodule"], + "trainer": config["trainer"], + } + + if "seed" in config: + hparams["seed"] = config["seed"] + if "callbacks" in config: + hparams["callbacks"] = config["callbacks"] + + # send hparams to all loggers + trainer.logger.log_hyperparams(hparams) + + +def fullname(obj: Any) -> str: + if isinstance(obj, type) or isinstance(obj, FunctionType): + cls = obj + else: # if obj is an instance, get its class + cls = type(obj) + + module = cls.__module__ + if module is None or module == str.__class__.__module__: # don't want to return 'builtins' + return cls.__name__ + return module + "." + cls.__name__ + + +def get_enum_key_from_value(value: str, enum_class: Type[Enum]) -> Optional[str]: + for enum_member in enum_class: + if enum_member.value == value: + return enum_member.name + return None + + +def get_cache_dir() -> Path: + return Path(os.environ.get("SCRATCH", "/tmp")) / "ahcore_cache" diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py new file mode 100644 index 0000000..216142b --- /dev/null +++ b/ahcore/utils/manifest.py @@ -0,0 +1,393 @@ +""" +All utilities to parse manifests into datasets. A manifest is a database containing the description of a dataset. +See the documentation for more information and examples. 
+""" + +from __future__ import annotations + +import functools +from pathlib import Path +from types import TracebackType +from typing import Any, Callable, Generator, Literal, Optional, Type, TypedDict, cast + +from dlup import SlideImage +from dlup.annotations import WsiAnnotations +from dlup.data.dataset import RegionFromWsiDatasetSample, TiledWsiDataset, TileSample +from dlup.experimental_backends import ImageBackend # type: ignore +from dlup.tiling import GridOrder, TilingMode +from pydantic import BaseModel +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session, sessionmaker + +from ahcore.exceptions import RecordNotFoundError +from ahcore.utils.data import DataDescription +from ahcore.utils.database_models import ( + Base, + CategoryEnum, + Image, + ImageAnnotations, + Manifest, + Mask, + Patient, + Split, + SplitDefinitions, +) +from ahcore.utils.io import get_enum_key_from_value, get_logger +from ahcore.utils.rois import compute_rois +from ahcore.utils.types import PositiveFloat, PositiveInt, Rois + +logger = get_logger(__name__) + +_AnnotationReturnTypes = WsiAnnotations | SlideImage + + +class _AnnotationReadersDict(TypedDict): + ASAP_XML: Callable[[Path], WsiAnnotations] + GEOJSON: Callable[[Path], WsiAnnotations] + PYVIPS: Callable[[Path], SlideImage] + TIFFFILE: Callable[[Path], SlideImage] + OPENSLIDE: Callable[[Path], SlideImage] + + +_AnnotationReaders: _AnnotationReadersDict = { + "ASAP_XML": WsiAnnotations.from_asap_xml, + "GEOJSON": WsiAnnotations.from_geojson, + "PYVIPS": functools.partial(SlideImage.from_file_path, backend=ImageBackend.PYVIPS), + "TIFFFILE": functools.partial(SlideImage.from_file_path, backend=ImageBackend.TIFFFILE), + "OPENSLIDE": functools.partial(SlideImage.from_file_path, backend=ImageBackend.OPENSLIDE), +} + + +def parse_annotations_from_record( + annotations_root: Path, record: list[Mask] | list[ImageAnnotations] +) -> _AnnotationReturnTypes | None: + """ + Parse the annotations from a record of type ImageAnnotations. + + Parameters + ---------- + annotations_root : Path + The root directory of the annotations. + record : list[Type[ImageAnnotations]] + The record containing the annotations. + + Returns + ------- + WsiAnnotations + The parsed annotations. + """ + if not record: + return None + assert len(record) == 1 + + valid_readers = list(_AnnotationReaders.keys()) + reader_name = cast( + Literal["ASAP_XML", "GEOJSON", "PYVIPS", "TIFFFILE", "OPENSLIDE"], + record[0].reader, + ) + + if reader_name not in valid_readers: + raise ValueError(f"Invalid reader: {record[0].reader}") + assert reader_name in valid_readers + + filename = record[0].filename + + try: + reader_func = _AnnotationReaders[reader_name] + except KeyError: + raise NotImplementedError(f"Reader {reader_name} not implemented.") + + return reader_func(annotations_root / filename) + + +def get_mask_and_annotations_from_record( + annotations_root: Path, record: Image +) -> tuple[_AnnotationReturnTypes | None, _AnnotationReturnTypes | None]: + """ + Get the mask and annotations from a record of type Image. + + Parameters + ---------- + annotations_root : Path + The root directory of the annotations. + record : Type[Image] + The record containing the mask and annotations. + + Returns + ------- + tuple[WsiAnnotations, WsiAnnotations] + The mask and annotations. 
+ """ + _masks = parse_annotations_from_record(annotations_root, record.masks) + _annotations = parse_annotations_from_record(annotations_root, record.annotations) + return _masks, _annotations + + +def _get_rois(mask: WsiAnnotations | None, data_description: DataDescription, stage: str) -> Optional[Rois]: + if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): + return None + + tile_size = data_description.training_grid.tile_size + tile_overlap = data_description.training_grid.tile_overlap + + return compute_rois(mask, tile_size=tile_size, tile_overlap=tile_overlap, centered=True) + + +class DataManager: + def __init__(self, database_uri: str) -> None: + self._database_uri = database_uri + self.__session: Optional[Session] = None + self._logger = get_logger(type(self).__name__) + + @property + def _session(self) -> Session: + if self.__session is None: + self.__session = open_db(self._database_uri) + return self.__session + + @staticmethod + def _ensure_record(record: Any, description: str) -> None: + """Raises an error if the record is None.""" + if not record: + raise RecordNotFoundError(f"{description} not found.") + + def get_records_by_split( + self, + manifest_name: str, + split_version: str, + split_category: Optional[str] = None, + ) -> Generator[Patient, None, None]: + manifest = self._session.query(Manifest).filter_by(name=manifest_name).first() + self._ensure_record(manifest, f"Manifest with name {manifest_name}") + + split_definition = self._session.query(SplitDefinitions).filter_by(version=split_version).first() + self._ensure_record(split_definition, f"Split definition with version {split_version}") + + # This is because mypy is complaining otherwise, + # but _ensure_record effectively ensures that the record is not None + assert manifest is not None + assert split_definition is not None + query = ( + self._session.query(Patient) + .join(Split) + .filter( + Patient.manifest_id == manifest.id, + Split.split_definition_id == split_definition.id, + ) + ) + + if split_category is not None: + split_category_key = get_enum_key_from_value(split_category, CategoryEnum) + query = query.filter(Split.category == split_category_key) + + patients = query.all() + + self._logger.info( + f"Found {len(patients)} patients for split {split_category if split_category else 'all categories'}" + ) + for patient in patients: + yield patient + + def get_image_metadata_by_split( + self, + manifest_name: str, + split_version: str, + split_category: Optional[str] = None, + ) -> Generator[ImageMetadata, None, None]: + """ + Yields the metadata of images for a given manifest name, split version, and optional split category. + + Parameters + ---------- + manifest_name : str + The name of the manifest. + split_version : str + The version of the split. + split_category : Optional[str], default=None + The category of the split (e.g., "fit", "validate", "test"). + + Yields + ------- + ImageMetadata + The metadata of the image. + """ + for patient in self.get_records_by_split(manifest_name, split_version, split_category): + for image in patient.images: + yield fetch_image_metadata(image) + + def get_image_metadata_by_patient(self, patient_code: str) -> list[ImageMetadata]: + """ + Fetch the metadata for the images associated with a specific patient. + + Parameters + ---------- + patient_code : str + The unique code of the patient. + + Returns + ------- + list[ImageData] + A list of metadata for all images associated with the patient. 
+ """ + patient = self._session.query(Patient).filter_by(patient_code=patient_code).first() + self._ensure_record(patient, f"Patient with code {patient_code} not found") + assert patient is not None # for mypy + return [fetch_image_metadata(image) for image in patient.images] + + def get_image_by_filename(self, filename: str) -> Image: + """ + Fetch the metadata for an image based on its filename. + + Parameters + ---------- + filename : str + The filename of the image. + + Returns + ------- + Image + The image from the database. + """ + image = self._session.query(Image).filter_by(filename=filename).first() + self._ensure_record(image, f"Image with filename {filename} not found") + assert image + return image + + def get_image_metadata_by_id(self, image_id: int) -> ImageMetadata: + """ + Fetch the metadata for an image based on its ID. + + Parameters + ---------- + image_id : int + The ID of the image. + + Returns + ------- + ImageMetadata + Metadata of the image. + """ + image = self._session.query(Image).filter_by(id=image_id).first() + self._ensure_record(image, f"No image found with ID {image_id}") + assert image is not None # mypy + return fetch_image_metadata(image) + + def __enter__(self) -> "DataManager": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: + if self._session is not None: + self.close() + return False + + def close(self) -> None: + if self.__session is not None: + self.__session.close() + self.__session = None + + +def datasets_from_data_description( + db_manager: DataManager, + data_description: DataDescription, + transform: Callable[[TileSample], RegionFromWsiDatasetSample] | None, + stage: str, +) -> Generator[TiledWsiDataset, None, None]: + logger.info(f"Reading manifest from {data_description.manifest_database_uri} for stage {stage}") + + image_root = data_description.data_dir + annotations_root = data_description.annotations_dir + + assert isinstance(stage, str), "Stage should be a string." 
+ + if stage == "fit": + grid_description = data_description.training_grid + else: + grid_description = data_description.inference_grid + + records = db_manager.get_records_by_split( + manifest_name=data_description.manifest_name, + split_version=data_description.split_version, + split_category=stage, + ) + for record in records: + labels = [(str(label.key), str(label.value)) for label in record.labels] if record.labels else None + + for image in record.images: + mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) + assert isinstance(mask, WsiAnnotations) or (mask is None) + rois = _get_rois(mask, data_description, stage) + mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold + + dataset = TiledWsiDataset.from_standard_tiling( + path=image_root / image.filename, + mpp=grid_description.mpp, + tile_size=grid_description.tile_size, + tile_overlap=grid_description.tile_overlap, + tile_mode=TilingMode.overflow, + grid_order=GridOrder.C, + crop=False, + mask=mask, + mask_threshold=mask_threshold, + output_tile_size=getattr(grid_description, "output_tile_size", None), + rois=rois if rois is not None else None, + annotations=annotations if stage != "predict" else None, + labels=labels, # type: ignore + transform=transform, + backend=ImageBackend[str(image.reader)], + overwrite_mpp=(image.mpp, image.mpp), + limit_bounds=False if rois is not None else True, + ) + + yield dataset + + +class ImageMetadata(BaseModel): + """Model to hold image metadata""" + + class Config: + frozen = True + + filename: Path + height: PositiveInt + width: PositiveInt + mpp: PositiveFloat + + +def open_db(database_uri: str) -> Session: + """Open a database connection. + + Parameters + ---------- + database_uri : str + The URI of the database. + + Returns + ------- + Session + The database session. + """ + engine = create_engine(database_uri) + create_tables(engine) + SessionLocal = sessionmaker(bind=engine) + return SessionLocal() + + +def create_tables(engine: Engine) -> None: + """Create the database tables.""" + Base.metadata.create_all(bind=engine) + + +def fetch_image_metadata(image: Image) -> ImageMetadata: + """Extract metadata from an Image object.""" + return ImageMetadata( + filename=Path(image.filename), + height=int(image.height), + width=int(image.width), + mpp=float(image.mpp), + ) diff --git a/ahcore/utils/rois.py b/ahcore/utils/rois.py new file mode 100644 index 0000000..4496b6f --- /dev/null +++ b/ahcore/utils/rois.py @@ -0,0 +1,127 @@ +"""Utilities to work with regions-of-interest. Has utilities to compute balanced ROIs from masks""" +from __future__ import annotations + +import numpy as np +from dlup.annotations import WsiAnnotations +from dlup.tiling import TilingMode, tiles_grid_coordinates + +from ahcore.utils.io import logger +from ahcore.utils.types import Rois + + +def compute_rois( + mask: WsiAnnotations, + tile_size: tuple[int, int], + tile_overlap: tuple[int, int], + centered: bool = True, +) -> Rois: + """ + Compute the ROIs from a `WsiAnnotations` object. The ROIs are computed by: + 1) The bounding box of the whole `mask` object is computed (by `dlup`). + 2) The whole region up to the bounding box is computed. + 3) For each of the regions in the result, the bounding box is determined. + + If `centered` is True, the bounding box is centered in the region. This is done as follows: + 1) The effective region size is computed, depending on the tiling mode. This is the size of the region which + is tiled by the given tile size and overlap. 
Tiles always overflow, so there are tiles on the border which are + partially covered. + 2) The ROIs obtained in the first step are replaced by the effective region size, and centered around the original + ROI. + + Parameters + ---------- + mask : WsiAnnotations + tile_size : tuple[int, int] + tile_overlap : tuple[int, int] + centered : bool + + Returns + ------- + list[tuple[tuple[int, int], tuple[int, int]]] + List of ROIs (coordinates, size). + + """ + bbox_coords, bbox_size = mask.bounding_box + logger.debug("Annotations bounding box: %s, %s", bbox_coords, bbox_size) + total_roi = mask.read_region((0, 0), 1.0, (bbox_coords[0] + bbox_size[0], bbox_coords[1] + bbox_size[1])) + + _rois = np.asarray([_.bounds for _ in total_roi]) + _rois[:, 2:] = _rois[:, 2:] - _rois[:, :2] + rois = [((roi[0], roi[1]), (roi[2], roi[3])) for roi in _rois] + logger.debug("Regions of interest: %s", rois) + + if centered: + centered_rois = _get_centered_rois(rois, tile_size, tile_overlap) + logger.debug("Centered ROIs: %s", centered_rois) + return centered_rois + + return rois + + +def _get_centered_rois(roi_boxes: Rois, tile_size: tuple[int, int], tile_overlap: tuple[int, int]) -> Rois: + """ + Based on the ROI and the effective region size compute a grid aligned at the center of the region + + Parameters + ---------- + roi_boxes : + The effective roi boxes + tile_size : tuple[int, int] + The size of the tiles which will be used to compute the balanced ROIs + tile_overlap : tuple[int, int] + The tile overlap of the tiles in the dataset. + + Returns + ------- + list[tuple[tuple[int, int], tuple[int, int]]] + List of ROIs (coordinates, size). + + """ + logger.debug("Computing balanced ROIs from ROI boxes %s", roi_boxes) + region_sizes = [_compute_effective_size(roi_size, tile_size, tile_overlap) for _, roi_size in roi_boxes] + + output_rois: Rois = [] + for roi, region_size in zip(roi_boxes, region_sizes): + offset, size = roi + _region_size = np.asarray(region_size) + _offset = np.asarray(offset) + _size = np.asarray(size) + + _new_offset = _offset - (_region_size - _size) / 2 + _coordinates = int(_new_offset[0]), int(_new_offset[1]) + output_rois.append((_coordinates, region_size)) + return output_rois + + +def _compute_effective_size( + size: tuple[int, int], + tile_size: tuple[int, int], + tile_overlap: tuple[int, int], + mode: TilingMode = TilingMode.overflow, +) -> tuple[int, int]: + """ + Compute the effective size of a tiled region, depending on the tiling mode and given the size. The effective size + basically is the size of the region which is tiled by the given tile size and overlap. If tiles overflow, there are + tiles on the border which are partially covered. If tiles are centered, the effective size of the region will be + smaller than the original size. 
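+
+    For example (a worked, hypothetical case): with ``size=(1000, 1000)``,
+    ``tile_size=(512, 512)``, ``tile_overlap=(0, 0)`` and overflow tiling, tiles
+    start at 0 and 512 along each axis, so the effective size is
+    ``(512 + 512, 512 + 512) = (1024, 1024)``.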
+ + Parameters + ---------- + size : tuple[int, int] + tile_size : tuple[int, int] + tile_overlap : tuple[int, int] + mode : TilingMode + + Returns + ------- + tuple[int, int] + The effective size of the grid + """ + coordinates_x, coordinates_y = tiles_grid_coordinates( + size=size, tile_size=tile_size, tile_overlap=tile_overlap, mode=mode + ) + effective_size = ( + coordinates_x.max() + tile_size[0], + coordinates_y.max() + tile_size[1], + ) + return effective_size diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py new file mode 100644 index 0000000..9057ec1 --- /dev/null +++ b/ahcore/utils/types.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import numpy.typing as npt +from dlup.data.dataset import Dataset +from pydantic import AfterValidator +from typing_extensions import Annotated + + +def is_positive(v: int | float) -> int | float: + assert v > 0, f"{v} is not a positive {type(v)}" + return v + + +def is_non_negative(v: int | float) -> int | float: + assert v >= 0, f"{v} is not a non-negative {type(v)}" + return v + + +PositiveInt = Annotated[int, AfterValidator(is_positive)] +PositiveFloat = Annotated[float, AfterValidator(is_positive)] +NonNegativeInt = Annotated[int, AfterValidator(is_non_negative)] +NonNegativeFloat = Annotated[float, AfterValidator(is_non_negative)] +BoundingBoxType = tuple[tuple[int, int], tuple[int, int]] +Rois = list[BoundingBoxType] +GenericArray = npt.NDArray[np.generic] + +DlupDatasetSample = dict[str, Any] +_DlupDataset = Dataset[DlupDatasetSample] diff --git a/ahcore/writers.py b/ahcore/writers.py new file mode 100644 index 0000000..6f72b6f --- /dev/null +++ b/ahcore/writers.py @@ -0,0 +1,205 @@ +""" +This module contains writer classes. Currently implemented: + +- `H5FileImageWriter`: class to write H5 files based on iterators, for instance, the output of a dataset + class. Can for instance be used to store outputs of models. The `readers` contain separate modules to read these + h5 files. 
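+
+A minimal, hypothetical usage sketch (the path, sizes and the ``batches``
+generator are placeholders; ``batches`` must yield ``(coordinates, batch)``
+pairs of numpy arrays):
+
+    writer = H5FileImageWriter(
+        filename=Path("predictions.h5"),
+        size=(2048, 2048),
+        mpp=1.0,
+        tile_size=(512, 512),
+        tile_overlap=(0, 0),
+        num_samples=16,
+    )
+    writer.consume(batches)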
+ +""" +import json +from multiprocessing.connection import Connection +from pathlib import Path +from typing import Any, Generator, Optional + +import h5py +import numpy as np +import numpy.typing as npt +from dlup.tiling import Grid, GridOrder, TilingMode + +from ahcore.utils.io import get_logger +from ahcore.utils.types import GenericArray + +logger = get_logger(__name__) + + +class H5FileImageWriter: + """Image writer that writes tile-by-tile to h5.""" + + def __init__( + self, + filename: Path, + size: tuple[int, int], + mpp: float, + tile_size: tuple[int, int], + tile_overlap: tuple[int, int], + num_samples: int, + is_binary: bool = False, + progress: Optional[Any] = None, + ) -> None: + self._grid: Optional[Grid] = None + self._grid_coordinates: Optional[npt.NDArray[np.int_]] = None + self._filename: Path = filename + self._size: tuple[int, int] = size + self._mpp: float = mpp + self._tile_size: tuple[int, int] = tile_size + self._tile_overlap: tuple[int, int] = tile_overlap + self._num_samples: int = num_samples + self._is_binary: bool = is_binary + self._progress = progress + self._data: Optional[h5py.Dataset] = None + self._coordinates_dataset: Optional[h5py.Dataset] = None + self._tile_indices: Optional[h5py.Dataset] = None + self._current_index: int = 0 + + self._logger = logger # maybe not the best way, think about it + self._logger.debug("Writing h5 to %s", self._filename) + + def init_writer(self, first_batch: GenericArray, h5file: h5py.File) -> None: + """Initializes the image_dataset based on the first tile.""" + batch_shape = np.asarray(first_batch).shape + batch_dtype = np.asarray(first_batch).dtype + + self._current_index = 0 + + self._coordinates_dataset = h5file.create_dataset( + "coordinates", + shape=(self._num_samples, 2), + dtype=int, + compression="gzip", + ) + + # TODO: We only support a single Grid + grid = Grid.from_tiling( + (0, 0), + size=self._size, + tile_size=self._tile_size, + tile_overlap=self._tile_overlap, + mode=TilingMode.overflow, + order=GridOrder.C, + ) + num_tiles = len(grid) + self._grid = grid + self._tile_indices = h5file.create_dataset( + "tile_indices", + shape=(num_tiles,), + dtype=int, + compression="gzip", + ) + # Initialize to -1, which is the default value + self._tile_indices[:] = -1 + + if not self._is_binary: + self._data = h5file.create_dataset( + "data", + shape=(self._num_samples,) + batch_shape[1:], + dtype=batch_dtype, + compression="gzip", + chunks=(1,) + batch_shape[1:], + ) + else: + dt = h5py.vlen_dtype(np.dtype("uint8")) # Variable-length uint8 data type + self._data = h5file.create_dataset( + "data", + shape=(self._num_samples,), + dtype=dt, + chunks=(1,), + ) + + # This only works when the mode is 'overflow' and in 'C' order. 
+ metadata = { + "mpp": self._mpp, + "dtype": str(batch_dtype), + "shape": tuple(batch_shape[1:]), + "size": (int(self._size[0]), int(self._size[1])), + "num_channels": batch_shape[1], + "num_samples": self._num_samples, + "tile_size": tuple(self._tile_size), + "tile_overlap": tuple(self._tile_overlap), + "num_tiles": num_tiles, + "grid_order": "C", + "mode": "overflow", + "is_binary": self._is_binary, + } + metadata_json = json.dumps(metadata) + h5file.attrs["metadata"] = metadata_json + + def add_associated_images( + self, + images: tuple[tuple[str, npt.NDArray[np.uint8]], ...], + description: Optional[str] = None, + ) -> None: + """Adds associated images to the h5 file.""" + + # Create a compound dataset "associated_images" + with h5py.File(self._filename, "r+") as h5file: + associated_images = h5file.create_group("associated_images") + for name, image in images: + associated_images.create_dataset(name, data=image) + + if description: + associated_images.attrs["description"] = description + + def consume( + self, + batch_generator: Generator[tuple[GenericArray, GenericArray], None, None], + connection_to_parent: Optional[Connection] = None, + ) -> None: + """Consumes tiles one-by-one from a generator and writes them to the h5 file.""" + grid_counter = 0 + + try: + with h5py.File(self._filename.with_suffix(".h5.partial"), "w") as h5file: + first_coordinates, first_batch = next(batch_generator) + self.init_writer(first_batch, h5file) + + # Mostly for mypy + assert self._grid, "Grid is not initialized" + assert self._tile_indices, "Tile indices are not initialized" + assert self._data, "Dataset is not initialized" + assert self._coordinates_dataset, "Coordinates dataset is not initialized" + + batch_generator = self._batch_generator((first_coordinates, first_batch), batch_generator) + # progress bar will be used if self._progress is not None + if self._progress: + batch_generator = self._progress(batch_generator, total=self._num_samples) + + for coordinates, batch in batch_generator: + # We take a coordinate, and step through the grid until we find it. + # Note that this assumes that the coordinates come in C-order, so we will always hit it + for idx, curr_coordinates in enumerate(coordinates): + # As long as our current coordinates are not equal to the grid coordinates, we make a step + while not np.all(curr_coordinates == self._grid[grid_counter]): + grid_counter += 1 + # If we find it, we set it to the index, so we can find it later on + # This can be tested by comparing the grid evaluated at a grid index with the tile index + # mapped to its coordinates + self._tile_indices[grid_counter] = self._current_index + idx + grid_counter += 1 + + batch_size = batch.shape[0] + self._data[self._current_index : self._current_index + batch_size] = batch + self._coordinates_dataset[self._current_index : self._current_index + batch_size] = coordinates + self._current_index += batch_size + + except Exception as e: + self._logger.error("Error in consumer thread for %s: %s", self._filename, e, exc_info=e) + if connection_to_parent: + connection_to_parent.send((False, self._filename, e)) # Send a message to the parent + else: + # When done writing rename the file. 
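+            # Writing to a ".h5.partial" file and renaming only on success means
+            # an interrupted run never leaves behind a file that looks complete.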
+ self._filename.with_suffix(".h5.partial").rename(self._filename) + finally: + if connection_to_parent: + connection_to_parent.send((True, None, None)) + connection_to_parent.close() + + @staticmethod + def _batch_generator( + first_coordinates_batch: Any, batch_generator: Generator[Any, None, None] + ) -> Generator[Any, None, None]: + # We yield the first batch too so the progress bar takes the first batch also into account + yield first_coordinates_batch + for tile in batch_generator: + if tile is None: + break + yield tile diff --git a/config/augmentations/segmentation.yaml b/config/augmentations/segmentation.yaml new file mode 100644 index 0000000..3a5dc3a --- /dev/null +++ b/config/augmentations/segmentation.yaml @@ -0,0 +1,11 @@ +# Below we set up transforms for each stage in pytorch lightning +# the '@name' behind the renames the config to 'name' -- allows us to re-use the validation config for other stages +# To override specific values you can use the expected override, e.g., 'transforms.predict.probability=0' +# We can, however, also override groups, albeit a bit more unfamiliar: 'transforms/stages@transforms.predict=train_seg' + +defaults: + - _self_ + - /augmentations/stages@fit: train_seg + - /augmentations/stages@validate: val_seg + - /augmentations/stages@test: val_seg + - /augmentations/stages@predict: val_seg diff --git a/config/augmentations/stages/train_seg.yaml b/config/augmentations/stages/train_seg.yaml new file mode 100644 index 0000000..7c274f7 --- /dev/null +++ b/config/augmentations/stages/train_seg.yaml @@ -0,0 +1,76 @@ +_target_: ahcore.transforms.augmentations.AugmentationFactory + +# These transforms will be applied in order as a first step. +initial_transforms: + - _target_: ahcore.transforms.augmentations.MeanStdNormalizer + mean: [0.0, 0.0, 0.0] + std: [1.0, 1.0, 1.0] + +# [a, b] select between a and b random intensity transforms to apply. +# If True, will select all in a random order. If False, will apply all sequentially. +random_apply_intensity: false +random_apply_weights_intensity: null +intensity_augmentations: + - _target_: ahcore.transforms.augmentations.HEDColorAugmentation + p: 0.5 + scale_sigma: 0.2 # HED-light + bias_sigma: 0.2 # HED-light + - _target_: kornia.augmentation.ColorJitter + p: 0.5 + brightness: 0.05 + contrast: 0.05 + saturation: 0.05 + hue: 0.05 + - _target_: kornia.augmentation.RandomGaussianBlur + p: 0.5 + kernel_size: [9, 9] + sigma: [0.1, 1.0] + - _target_: kornia.augmentation.RandomSharpness + p: 0.5 + sharpness: 10 + +# - _target_: kornia.augmentation.RandomSaturation +# p: 0.5 +# saturation: [0.5, 2.0] + +# [a, b] select between a and b random geometric transforms to apply. +# If True, will select all in a random order. If False, will apply all sequentially. 
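+# e.g. `[1]` below picks one geometric transform at random per call, following
+# the `[a, b]` convention described above.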
+random_apply_geometric: [1]
+random_apply_weights_geometric: null
+geometric_augmentations:
+  - _target_: kornia.augmentation.RandomHorizontalFlip
+    p: 0.5
+  - _target_: kornia.augmentation.RandomVerticalFlip
+    p: 0.5
+  - _target_: kornia.augmentation.RandomPerspective
+    p: 0.5
+    distortion_scale: 0.5
+  - _target_: kornia.augmentation.RandomAffine
+    p: 0.5
+    degrees: [0.0, 90.0]
+    scale: null
+    translate: null
+    shear: null
+  - _target_: kornia.augmentation.RandomAffine
+    p: 0.5
+    degrees: 0
+    scale: [0.75, 1.0]
+    translate: null
+    shear: null
+  - _target_: kornia.augmentation.RandomAffine
+    p: 0.5
+    degrees: 0
+    scale: null
+    translate: [0, 0.05] # This is given in percentages
+    shear: null
+  - _target_: kornia.augmentation.RandomAffine
+    p: 0.5
+    degrees: 0
+    scale: null
+    translate: null
+    shear: [0, 5]
+
+# These transforms are applied in order as a final step.
+final_transforms:
+  - _target_: ahcore.transforms.augmentations.CenterCrop
+    size: [512, 512]
diff --git a/config/augmentations/stages/val_seg.yaml b/config/augmentations/stages/val_seg.yaml
new file mode 100644
index 0000000..417b80b
--- /dev/null
+++ b/config/augmentations/stages/val_seg.yaml
@@ -0,0 +1,11 @@
+_target_: ahcore.transforms.augmentations.AugmentationFactory
+
+# These transforms will be applied in order as a first step.
+initial_transforms:
+  - _target_: ahcore.transforms.augmentations.MeanStdNormalizer
+    mean: [0.0, 0.0, 0.0]
+    std: [1.0, 1.0, 1.0]
+
+intensity_augmentations: null
+geometric_augmentations: null
+final_transforms: null
diff --git a/config/augmentations/wsi_classification.yaml b/config/augmentations/wsi_classification.yaml
new file mode 100644
index 0000000..2851a4a
--- /dev/null
+++ b/config/augmentations/wsi_classification.yaml
@@ -0,0 +1,11 @@
+# Below we set up transforms for each stage in pytorch lightning
+# the '@name' behind the group renames that config to 'name' -- allows us to re-use the validation config for other stages
+# To override specific values you can use the expected override, e.g., 'transforms.predict.probability=0'
+# We can, however, also override groups, albeit a bit more unfamiliar: 'transforms/stages@transforms.predict=train_seg'
+
+defaults:
+  - _self_
+  - /augmentations/stages@fit: train_wsi_classification
+  - /augmentations/stages@validate: val_seg
+  - /augmentations/stages@test: val_seg
+  - /augmentations/stages@predict: null
diff --git a/config/callbacks/compute_wsi_metrics_callback.yaml b/config/callbacks/compute_wsi_metrics_callback.yaml
new file mode 100644
index 0000000..b8cbf7b
--- /dev/null
+++ b/config/callbacks/compute_wsi_metrics_callback.yaml
@@ -0,0 +1,3 @@
+compute_wsi_metrics_callback:
+  _target_: ahcore.callbacks.ComputeWsiMetricsCallback
+  save_per_image: True
diff --git a/config/callbacks/default.yaml b/config/callbacks/default.yaml
new file mode 100644
index 0000000..f4e65e1
--- /dev/null
+++ b/config/callbacks/default.yaml
@@ -0,0 +1,41 @@
+defaults:
+  - learningrate_monitor.yaml
+  - model_checkpoint.yaml
+  - early_stopping.yaml
+  - model_summary.yaml
+  - rich_progress_bar.yaml
+  - write_h5_callback.yaml
+#  - write_tiff_callback.yaml
+  - compute_wsi_metrics_callback.yaml # This needs to follow write_h5_callback as it requires the outputs
+  - _self_
+
+learningrate_monitor:
+  logging_interval: null
+  log_momentum: false
+
+model_checkpoint:
+  dirpath: ${paths.output_dir}/checkpoints
+  filename: "epoch_{epoch:03d}"
+  monitor: "validate/loss"
+  mode: "min"
+  save_last: True
+  auto_insert_metric_name: False
+
+early_stopping:
+  monitor: "validate/wsi_dice/stroma"
+  patience: 1000
+  mode: "max" # the monitored WSI dice is higher-is-better
+
+model_summary:
+  max_depth: -1
+
+write_h5_callback:
+  max_queue_size: 100
+  max_concurrent_writers: 4
+
+write_tiff_callback:
+  max_concurrent_writers: 4
+
+compute_wsi_metrics_callback:
+  max_processes: 4
+  save_per_image: true
diff --git a/config/callbacks/early_stopping.yaml b/config/callbacks/early_stopping.yaml
new file mode 100644
index 0000000..20ed267
--- /dev/null
+++ b/config/callbacks/early_stopping.yaml
@@ -0,0 +1,17 @@
+# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html
+
+# Monitor a metric and stop training when it stops improving.
+# Look at the above link for more detailed information.
+early_stopping:
+  _target_: pytorch_lightning.callbacks.EarlyStopping
+  monitor: ??? # quantity to be monitored, must be specified !!!
+  min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+  patience: 3 # number of checks with no improvement after which training will be stopped
+  verbose: False # verbosity mode
+  mode: "min" # "max" means higher metric value is better, can be also "min"
+  strict: True # whether to crash the training if monitor is not found in the validation metrics
+  check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
+  stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
+  divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
+  check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
+  # log_rank_zero_only: False # this keyword argument isn't available in stable version
diff --git a/config/callbacks/inference.yaml b/config/callbacks/inference.yaml
new file mode 100644
index 0000000..ce37b35
--- /dev/null
+++ b/config/callbacks/inference.yaml
@@ -0,0 +1,16 @@
+defaults:
+  - model_summary.yaml
+  - rich_progress_bar.yaml
+  - write_h5_callback.yaml
+  - write_tiff_callback.yaml
+  - _self_
+
+model_summary:
+  max_depth: -1
+
+write_h5_callback:
+  max_queue_size: 100
+  max_concurrent_writers: 4
+
+write_tiff_callback:
+  max_concurrent_writers: 4
diff --git a/config/callbacks/learningrate_monitor.yaml b/config/callbacks/learningrate_monitor.yaml
new file mode 100644
index 0000000..c6114a2
--- /dev/null
+++ b/config/callbacks/learningrate_monitor.yaml
@@ -0,0 +1,2 @@
+learningrate_monitor:
+  _target_: pytorch_lightning.callbacks.LearningRateMonitor
diff --git a/config/callbacks/model_checkpoint.yaml b/config/callbacks/model_checkpoint.yaml
new file mode 100644
index 0000000..d2ae439
--- /dev/null
+++ b/config/callbacks/model_checkpoint.yaml
@@ -0,0 +1,19 @@
+# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html
+
+# Save the model periodically by monitoring a quantity.
+# Look at the above link for more detailed information.
+model_checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: 1 # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/config/callbacks/model_summary.yaml b/config/callbacks/model_summary.yaml new file mode 100644 index 0000000..04da98d --- /dev/null +++ b/config/callbacks/model_summary.yaml @@ -0,0 +1,7 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html + +# Generates a summary of all layers in a LightningModule with rich text formatting. +# Look at the above link for more detailed information. +model_summary: + _target_: pytorch_lightning.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/config/callbacks/rich_progress_bar.yaml b/config/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000..b6be5b4 --- /dev/null +++ b/config/callbacks/rich_progress_bar.yaml @@ -0,0 +1,6 @@ +# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html + +# Create a progress bar with rich text formatting. +# Look at the above link for more detailed information. +rich_progress_bar: + _target_: pytorch_lightning.callbacks.RichProgressBar diff --git a/config/callbacks/write_h5_callback.yaml b/config/callbacks/write_h5_callback.yaml new file mode 100644 index 0000000..f301552 --- /dev/null +++ b/config/callbacks/write_h5_callback.yaml @@ -0,0 +1,3 @@ +write_h5_callback: + _target_: ahcore.callbacks.WriteH5Callback + dump_dir: ${oc.env:SCRATCH}/${task_name}/${now:%Y-%m-%d}_${now:%H-%M-%S} diff --git a/config/callbacks/write_tiff_callback.yaml b/config/callbacks/write_tiff_callback.yaml new file mode 100644 index 0000000..5aab591 --- /dev/null +++ b/config/callbacks/write_tiff_callback.yaml @@ -0,0 +1,2 @@ +write_tiff_callback: + _target_: ahcore.callbacks.WriteTiffCallback diff --git a/config/datamodule/dataset.yaml b/config/datamodule/dataset.yaml new file mode 100644 index 0000000..77bc3af --- /dev/null +++ b/config/datamodule/dataset.yaml @@ -0,0 +1,3 @@ +_target_: ahcore.data.dataset.DlupDataModule +batch_size: ??? # specify in machine_settings +num_workers: ??? 
# specify in machine_settings diff --git a/config/debug/default.yaml b/config/debug/default.yaml new file mode 100644 index 0000000..f579b66 --- /dev/null +++ b/config/debug/default.yaml @@ -0,0 +1,25 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +trainer: + max_epochs: 1 + gpus: 0 # debuggers don't like gpus + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + track_grad_norm: 2 # track gradient norm with loggers + +datamodule: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + verbose: True + + # use this to set level of only chosen command line loggers to 'DEBUG': + # verbose: [src.train, src.utils] + +# config is already printed by hydra when `hydra/verbose: True` +print_config: False diff --git a/config/experiment/example_folder/example_experiment.yaml b/config/experiment/example_folder/example_experiment.yaml new file mode 100644 index 0000000..eff7a16 --- /dev/null +++ b/config/experiment/example_folder/example_experiment.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example_folder/example_experiment + + +defaults: + - override /datamodule: debug_dataset.yaml + - override /lit_module: debug_lit.yaml + - override /callbacks: null + - override /logger: null + - override /trainer: segmentation.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +# Set experiment to true -- allowing any checks to run for experiment +tags: ["example", "debug"] + +# name of the run determines folder name in logs (if the name is 'default' -- it will complain!) 
+task_name: "debug_experiment" + +seed: 12345 diff --git a/config/hparams_search/optuna.yaml b/config/hparams_search/optuna.yaml new file mode 100644 index 0000000..ee20ce7 --- /dev/null +++ b/config/hparams_search/optuna.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module +optimized_metric: "val/acc_best" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper diff --git a/config/hydra/default.yaml b/config/hydra/default.yaml new file mode 100644 index 0000000..6e473f0 --- /dev/null +++ b/config/hydra/default.yaml @@ -0,0 +1,11 @@ +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/config/inference.yaml b/config/inference.yaml new file mode 100644 index 0000000..8d6e978 --- /dev/null +++ b/config/inference.yaml @@ -0,0 +1,50 @@ +# @package _global_ + +# specify here default inference configuration +defaults: + - _self_ # Means that any variables defined here will be overridden by possible subconfigs below + - data_description: ??? + - datamodule: dataset.yaml + - pre_transform: ??? + - augmentations: ??? + - task: segmentation_inference.yaml + - logger: null + - lit_module: ??? + - callbacks: inference.yaml + - trainer: default.yaml + - hydra: default.yaml # if we want to setup different hydra logging dirs / color log + - paths: default.yaml + - machine_settings: example.yaml + + # experiment configs allow for version control of specific configurations + # e.g. specific hyperparameter searches or something + - experiment: null + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + + # optional local config for user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: segmentation.yaml + +# checkpoint path +ckpt_path: ??? + +# task name -- used for output directory path (or name it to a specific experiment if you'd like) +task_name: "inference" + +# tags that can be found in mlflow to keep track of experiments +# note you cannot append in cli currently (perhaps in future) +tags: ["dev"] + +# pretty print config at the start of the run using Rich library +print_config: True + +# disable python warnings if they annoy you +ignore_warnings: False + +# force tags to be entered (e.g. 
for running experiments): +enforce_tags: False + +# seed for random number generators in pytorch, numpy and python.random +seed: 42 diff --git a/config/lit_module/debug_lit.yaml b/config/lit_module/debug_lit.yaml new file mode 100644 index 0000000..9f344d9 --- /dev/null +++ b/config/lit_module/debug_lit.yaml @@ -0,0 +1,2 @@ +_target_: ahcore.debug_lit_multigpu.Debug +input_size: [100, 100, 3] diff --git a/config/lit_module/example_model.yaml b/config/lit_module/example_model.yaml new file mode 100644 index 0000000..6791c55 --- /dev/null +++ b/config/lit_module/example_model.yaml @@ -0,0 +1,6 @@ +# _target_: here_should_be_some_Lit_module +lr: 0.1 + +net: + # _target_: some_model + input_size: 200 diff --git a/config/lit_module/monai_segmentation/attention_unet.yaml b/config/lit_module/monai_segmentation/attention_unet.yaml new file mode 100644 index 0000000..31360da --- /dev/null +++ b/config/lit_module/monai_segmentation/attention_unet.yaml @@ -0,0 +1,26 @@ +_target_: ahcore.lit_module.AhCoreLightningModule +model: + # Do not set out_channels, this is derived from data_description. + _target_: monai.networks.nets.attentionunet.AttentionUnet + _partial_: true + spatial_dims: 2 + in_channels: 3 + channels: [32, 64, 128, 256, 512] + strides: [2, 2, 2, 2] + dropout: 0.1 + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 5e-4 + weight_decay: 1e-4 + betas: + - 0.99 + - 0.9 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 diff --git a/config/lit_module/monai_segmentation/swin_unetr.yaml b/config/lit_module/monai_segmentation/swin_unetr.yaml new file mode 100644 index 0000000..793f2d7 --- /dev/null +++ b/config/lit_module/monai_segmentation/swin_unetr.yaml @@ -0,0 +1,37 @@ +_target_: ahcore.lit_module.AhCoreLightningModule + +model: + # TODO(?): Do not set out_channels, this is derived from data_description. + _target_: monai.networks.nets.swin_unetr.SwinUNETR + _partial_: true + img_size: [512, 512] + in_channels: 3 + out_channels: 4 + depths: [2, 2, 2, 2] + num_heads: [3, 6, 12, 24] + feature_size: 24 + norm_name: "instance" + use_checkpoint: false + drop_rate: 0.0 + attn_drop_rate: 0.0 + dropout_path_rate: 0.0 + normalize: true + spatial_dims: 2 + downsample: "merging" + use_v2: true + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 5e-4 + weight_decay: 1e-4 + betas: + - 0.99 + - 0.9 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 diff --git a/config/lit_module/monai_segmentation/unet.yaml b/config/lit_module/monai_segmentation/unet.yaml new file mode 100644 index 0000000..06eeb9f --- /dev/null +++ b/config/lit_module/monai_segmentation/unet.yaml @@ -0,0 +1,29 @@ +_target_: ahcore.lit_module.AhCoreLightningModule + +model: + # Do not set out_channels, this is derived from data_description. + _target_: monai.networks.nets.UNet + _partial_: true + spatial_dims: 2 + in_channels: 3 + out_channels: 4 # TODO: Pass automatically? 
+ kernel_size: 3 + channels: [4, 8, 16, 32, 64] + strides: [2, 2, 2, 2] + dropout: 0.1 + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 5e-4 + weight_decay: 1e-4 + betas: + - 0.99 + - 0.9 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 diff --git a/config/lit_module/new_unet.yaml b/config/lit_module/new_unet.yaml new file mode 100644 index 0000000..7c381d3 --- /dev/null +++ b/config/lit_module/new_unet.yaml @@ -0,0 +1,25 @@ +_target_: ahcore.lit_module.AhCoreLightningModule + +model: + _target_: ahcore.models.new_unet.UNet + _partial_: true + num_input_ch: 3 + depth: 4 + num_initial_filters: 128 +# dropout_prob: 0.1 + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 5e-4 + weight_decay: 1e-4 + betas: + - 0.99 + - 0.9 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 100 diff --git a/config/lit_module/resnet50_classification.yaml b/config/lit_module/resnet50_classification.yaml new file mode 100644 index 0000000..81ac523 --- /dev/null +++ b/config/lit_module/resnet50_classification.yaml @@ -0,0 +1,18 @@ +_target_: ahcore.lit_module.AhCoreLightningModule + +model: + _target_: ahcore.models.resnet50.Resnet50Classification + _partial_: true + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0005 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 diff --git a/config/lit_module/unet.yaml b/config/lit_module/unet.yaml new file mode 100644 index 0000000..e3c90ed --- /dev/null +++ b/config/lit_module/unet.yaml @@ -0,0 +1,22 @@ +_target_: ahcore.lit_module.AhCoreLightningModule + +model: + _target_: ahcore.models.unet.Unet + num_input_ch: 3 + depth: 4 + num_initial_filters: 64 + dropout_prob: 0.1 + _partial_: true + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.0005 + weight_decay: 0.0005 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 100 diff --git a/config/logger/mlflow.yaml b/config/logger/mlflow.yaml new file mode 100644 index 0000000..de51287 --- /dev/null +++ b/config/logger/mlflow.yaml @@ -0,0 +1,10 @@ +mlflow: + _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.output_dir}/mlflow/mlruns + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/config/logger/tensorboard.yaml b/config/logger/tensorboard.yaml new file mode 100644 index 0000000..b22fc33 --- /dev/null +++ b/config/logger/tensorboard.yaml @@ -0,0 +1,7 @@ +tensorboard: + _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + save_dir: ${paths.output_dir}/tensorboard + name: null + log_graph: False + default_hp_metric: False + prefix: "" diff --git a/config/losses/segmentation_ce.yaml b/config/losses/segmentation_ce.yaml new file mode 100644 index 0000000..9ec43e0 --- /dev/null +++ b/config/losses/segmentation_ce.yaml @@ -0,0 +1,9 @@ +_target_: ahcore.losses.LossFactory +weights: + - 1.0 +losses: + - cross_entropy: + _target_: ahcore.losses.cross_entropy + _partial_: true + ignore_index: 0 + topk: null diff --git a/config/losses/segmentation_tiger.yaml b/config/losses/segmentation_tiger.yaml new file mode 100644 index 0000000..445d8cc --- /dev/null +++ b/config/losses/segmentation_tiger.yaml @@ -0,0 +1,15 
@@ +_target_: ahcore.losses.LossFactory +weights: + - 1.0 + - 1.0 +losses: + - cross_entropy: + _target_: ahcore.losses.cross_entropy + _partial_: true + label_smoothing: 0.08 + limit: -15.0 + topk: null + - soft_dice: + _target_: ahcore.losses.soft_dice + _partial_: true + ignore_index: 0 diff --git a/config/losses/segmentation_topk.yaml b/config/losses/segmentation_topk.yaml new file mode 100644 index 0000000..04fd7c9 --- /dev/null +++ b/config/losses/segmentation_topk.yaml @@ -0,0 +1,9 @@ +_target_: ahcore.losses.LossFactory +weights: + - 1.0 +losses: + - cross_entropy: + _target_: ahcore.losses.cross_entropy + _partial_: true + ignore_index: null + topk: 0.1 diff --git a/config/machine_settings/example.yaml b/config/machine_settings/example.yaml new file mode 100644 index 0000000..b05b673 --- /dev/null +++ b/config/machine_settings/example.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# In this file you can override machine / partition specific parameters +# This allows you to change machine_settings=your_machine to align with the computational availability +# Note that the @package _global_ on top is important; this allows you to override in the main config + +# Note that machine settings are excluded from version control + +# Example for datamodule: + +datamodule: + num_workers: 1 + batch_size: 16 diff --git a/config/metrics/segmentation.yaml b/config/metrics/segmentation.yaml new file mode 100644 index 0000000..cac63e8 --- /dev/null +++ b/config/metrics/segmentation.yaml @@ -0,0 +1,6 @@ +tile_level: + _target_: ahcore.metrics.MetricFactory.for_segmentation + +wsi_level: + _target_: ahcore.metrics.WSIMetricFactory.for_segmentation + compute_overall_dice: True diff --git a/config/paths/default.yaml b/config/paths/default.yaml new file mode 100644 index 0000000..2cb4101 --- /dev/null +++ b/config/paths/default.yaml @@ -0,0 +1,13 @@ +# path to root directory +root_dir: ${oc.env:PROJECT_ROOT} # specify in .env -- this will be the folder in which /logs and other stuff is dropped + +# path to logging directory (here is where hydra will create all of its subdirectories -- only internal usage) +log_dir: ${paths.root_dir}/logs + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics etc. 
+output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/config/pre_transform/segmentation.yaml b/config/pre_transform/segmentation.yaml new file mode 100644 index 0000000..f5f03b1 --- /dev/null +++ b/config/pre_transform/segmentation.yaml @@ -0,0 +1,2 @@ +_target_: ahcore.transforms.pre_transforms.PreTransformTaskFactory.for_segmentation +_partial_: true diff --git a/config/pre_transform/tile_classification.yaml b/config/pre_transform/tile_classification.yaml new file mode 100644 index 0000000..23198a9 --- /dev/null +++ b/config/pre_transform/tile_classification.yaml @@ -0,0 +1,2 @@ +_target_: ahcore.transforms.pre_transforms.PreTransformTaskFactory.for_tile_classification +_partial_: true diff --git a/config/pre_transform/wsi_classification.yaml b/config/pre_transform/wsi_classification.yaml new file mode 100644 index 0000000..c2654f0 --- /dev/null +++ b/config/pre_transform/wsi_classification.yaml @@ -0,0 +1,2 @@ +_target_: ahcore.transforms.pre_transforms.PreTransformTaskFactory.for_wsi_classification +_partial_: true diff --git a/config/task/segmentation.yaml b/config/task/segmentation.yaml new file mode 100644 index 0000000..c4c8a39 --- /dev/null +++ b/config/task/segmentation.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +defaults: + - override /pre_transform: segmentation.yaml + - override /metrics: segmentation.yaml + - override /losses: segmentation_ce.yaml + # TODO: we could also consider specifying the lit_module here + +task_name: ??? diff --git a/config/task/segmentation_inference.yaml b/config/task/segmentation_inference.yaml new file mode 100644 index 0000000..11bd5e9 --- /dev/null +++ b/config/task/segmentation_inference.yaml @@ -0,0 +1,7 @@ +# @package _global_ + +defaults: + - override /pre_transform: segmentation.yaml + - override /augmentations: segmentation.yaml + +task_name: "segmentation_inference" diff --git a/config/train.yaml b/config/train.yaml new file mode 100644 index 0000000..1db0f44 --- /dev/null +++ b/config/train.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +# specify here default training configuration +defaults: + - _self_ # Means that any variables defined here will be overridden by possible subconfigs below + - data_description: ??? + - datamodule: dataset.yaml + - pre_transform: ??? + - augmentations: ??? + - metrics: ??? + - losses: ??? + - task: segmentation.yaml + - lit_module: ??? + - logger: # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - mlflow + - tensorboard + - callbacks: default.yaml + - trainer: default.yaml + - hydra: default.yaml # if we want to setup different hydra logging dirs / color log + - paths: default.yaml + - machine_settings: example.yaml + + # experiment configs allow for version control of specific configurations + # e.g. specific hyperparameter searches or something + - experiment: null + + # debugging config (enable through command line, e.g. 
`python train.py debug=default`)
+  - debug: null
+
+  # config for hyperparameter optimization
+  - hparams_search: null
+
+  # optional local config for user-specific settings
+  # it's optional since it doesn't need to exist and is excluded from version control
+  - optional local: segmentation.yaml
+
+# checkpoint path
+ckpt_path: null
+
+# task name -- used for the output directory path (or name it after a specific experiment if you'd like)
+task_name: "train"
+
+# tags that can be found in mlflow to keep track of experiments
+# note: tags cannot currently be appended from the CLI (this may change in the future)
+tags: ["dev"]
+
+# pretty print the config at the start of the run using the Rich library
+print_config: True
+
+# disable python warnings if they annoy you
+ignore_warnings: False
+
+# force tags to be entered (e.g. for running experiments)
+enforce_tags: False
+
+# set False to skip model training
+train: True
+
+# evaluate on the test set, using the best model weights achieved during training
+# lightning chooses the best weights based on the metric specified in the checkpoint callback
+test_after_training: False
+
+# seed for the random number generators in pytorch, numpy and python.random
+seed: 42
diff --git a/config/trainer/cpu.yaml b/config/trainer/cpu.yaml
new file mode 100644
index 0000000..85fcea1
--- /dev/null
+++ b/config/trainer/cpu.yaml
@@ -0,0 +1,10 @@
+_target_: pytorch_lightning.Trainer
+
+accelerator: cpu
+devices: 1
+max_epochs: 1000
+num_nodes: 1
+
+# number of validation steps to execute at the beginning of training
+num_sanity_val_steps: 0
+log_every_n_steps: 2
diff --git a/config/trainer/default.yaml b/config/trainer/default.yaml
new file mode 100644
index 0000000..2882869
--- /dev/null
+++ b/config/trainer/default.yaml
@@ -0,0 +1,15 @@
+_target_: pytorch_lightning.Trainer
+
+accelerator: gpu
+devices: 1
+max_epochs: 1000
+num_nodes: 1
+
+# number of validation steps to execute at the beginning of training
+num_sanity_val_steps: 0
+#log_every_n_steps: 2
+
+# if a smaller grid is used, the val check interval should be smaller
+#val_check_interval: 2 # use if you want to check val more than once per epoch
+check_val_every_n_epoch: 1 # increase if you want to check val less than once per epoch
+accumulate_grad_batches: 1
diff --git a/config/trainer/default_ddp.yaml b/config/trainer/default_ddp.yaml
new file mode 100644
index 0000000..02fc412
--- /dev/null
+++ b/config/trainer/default_ddp.yaml
@@ -0,0 +1,16 @@
+_target_: pytorch_lightning.Trainer
+
+accelerator: gpu
+devices: 2
+num_nodes: 1
+max_epochs: 1000
+strategy: ddp_find_unused_parameters_false
+precision: 32
+
+gradient_clip_val: null
+gradient_clip_algorithm: null
+track_grad_norm: -1
+
+# number of validation steps to execute at the beginning of training
+num_sanity_val_steps: 0
+log_every_n_steps: 2
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..88999dc
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = python -msphinx
+SPHINXPROJ    = ahcore
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
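The defaults list in train.yaml above is what makes command-line overrides such as `debug=default` or `machine_settings=your_machine` work: each entry selects one file from the corresponding config group, and `_self_` fixes where train.yaml's own values sit in the merge order. The `???` entries are mandatory group choices, so they must be selected at composition time. A sketch of inspecting the composed config with Hydra's compose API, where `data_description=my_dataset` is a hypothetical config you would add yourself (none ships with this commit):

```python
# Sketch: compose config/train.yaml the way `python tools/train.py ...` would,
# without launching a run. Overrides use the same syntax as the command line.
from hydra import compose, initialize
from omegaconf import OmegaConf

# config_path is relative to this file, mirroring the layout of tools/train.py.
with initialize(config_path="../config", version_base="1.3"):
    cfg = compose(
        config_name="train",
        overrides=[
            "data_description=my_dataset",  # hypothetical; mandatory `???` choice
            "augmentations=segmentation",
            "lit_module=unet",
            "trainer=cpu",
        ],
    )

print(OmegaConf.to_yaml(cfg, resolve=False))
```

Note that `task: segmentation.yaml` already fills in the `???` choices for pre_transform, metrics, and losses via its own `override` entries, which is why they need not be passed here.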
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/cli.rst b/docs/cli.rst new file mode 100644 index 0000000..0b20530 --- /dev/null +++ b/docs/cli.rst @@ -0,0 +1,5 @@ +Command line tools +================== + +Ahcore provides several command-line tools, all starting with :code:`ahcore `. The documentation can +be found using :code:`ahcore --help`, or similarly by appending :code:`--help` to any of the subcommands. diff --git a/docs/conf.py b/docs/conf.py new file mode 100755 index 0000000..a1342e0 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python +# coding=utf-8 +# +# Ahcore documentation build configuration file. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. +import ast + +# If extensions (or modules to document with autodoc) are in another +# directory, add these directories to sys.path here. If the directory is +# relative to the documentation root, use os.path.abspath to make it +# absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) + +from typing import Dict, List # noqa: E402 + +import ahcore # noqa: E402 + +# -- General configuration --------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +curpath = os.path.dirname(__file__) +sys.path.append(os.path.join(curpath, "ext")) + +extensions = [ + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "numpydoc", + "doi_role", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", + "myst_parser", +] + +# Do not copy prompts in code. +copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " +copybutton_prompt_is_regexp = True + +# Do not add class members in generated docs +numpydoc_show_class_members = False + +# Add class content from main and derived classes +autoclass_content = "both" + +# build the templated autosummary files +autosummary_generate = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = ".rst" + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "ahcore" +copyright = "2022, ahcore contributors" +author = "Jonas Teuwen" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. + +with open("../ahcore/__init__.py") as f: + for line in f: + if line.startswith("__version__"): + version = ast.parse(line).body[0].value.s # type: ignore + break + + +# The version info for the project you're documenting, acts as replacement +# for |version| and |release|, also used in various other places throughout +# the built documents. 
+release = version
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# These patterns also affect html_static_path and html_extra_path
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# List of directories, relative to source directory, that shouldn't be searched
+# for source files.
+exclude_trees: List[str] = []
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = False
+
+
+# ------------------------------------------------------------------------
+# Sphinx-gallery configuration
+# ------------------------------------------------------------------------
+
+from packaging.version import parse  # noqa: E402
+
+v = parse(release)
+if v.release is None:
+    raise ValueError("Ill-formed version: {!r}. Version should follow PEP440".format(version))
+
+if v.is_devrelease:
+    binder_branch = "main"
+else:
+    major, minor = v.release[:2]
+    binder_branch = "v{}.{}.x".format(major, minor)
+
+
+# -- Options for HTML output -------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_book_theme"
+
+# Theme options are theme-specific and customize the look and feel of a
+# theme further. For a list of options available for each theme, see the
+# documentation.
+#
+html_theme_options = {
+    "repository_url": "https://github.com/NKI-AI/ahcore.git",
+    "repository_branch": "main",
+    "use_issues_button": True,
+    "use_edit_page_button": True,
+    "use_download_button": False,
+    "single_page": False,
+    "use_fullscreen_button": False,
+    "home_page_in_toc": True,
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+
+# -- Options for HTMLHelp output ---------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = "ahcoredoc"
+
+
+# -- Options for LaTeX output ------------------------------------------
+
+latex_elements: Dict[str, str] = {
+    # The paper size ('letterpaper' or 'a4paper').
+    #
+    # 'papersize': 'letterpaper',
+    # The font size ('10pt', '11pt' or '12pt').
+    #
+    # 'pointsize': '10pt',
+    # Additional stuff for the LaTeX preamble.
+    #
+    # 'preamble': '',
+    # Latex figure (float) alignment
+    #
+    # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title, author, documentclass
+# [howto, manual, or own class]).
+latex_documents = [
+    (master_doc, "ahcore.tex", "Ahcore Documentation", "Jonas Teuwen", "manual"),
+]
+
+
+# -- Options for manual page output ------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [(master_doc, "ahcore", "Ahcore Documentation", [author], 1)]
+
+
+# -- Options for Texinfo output ----------------------------------------
+
+# Grouping the document tree into Texinfo files.
List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "ahcore", + "Ahcore Documentation", + author, + "ahcore", + "Computational pathology models.", + "Miscellaneous", + ), +] + + +# ---------------------------------------------------------------------------- +# Source code links +# ---------------------------------------------------------------------------- + +import inspect # noqa: E402 +from os.path import dirname, relpath # noqa: E402 + + +# Function courtesy of NumPy to return URLs containing line numbers +def linkcode_resolve(domain, info): + """ + Determine the URL corresponding to Python object + """ + if domain != "py": + return None + + modname = info["module"] + fullname = info["fullname"] + + submod = sys.modules.get(modname) + if submod is None: + return None + + obj = submod + for part in fullname.split("."): + try: + obj = getattr(obj, part) + except: # noqa + return None + + # Strip decorators which would resolve to the source of the decorator + obj = inspect.unwrap(obj) + + try: + fn = inspect.getsourcefile(obj) + except: # noqa + fn = None + if not fn: + return None + + try: + source, start_line = inspect.getsourcelines(obj) + except: # noqa + linespec = "" + else: + stop_line = start_line + len(source) - 1 + linespec = f"#L{start_line}-L{stop_line}" + + fn = relpath(fn, start=dirname(ahcore.__file__)) + + if "dev" in ahcore.__version__: + return "https://github.com/NKI-AI/ahcore/blob/" "main/ahcore/%s%s" % ( + fn, + linespec, + ) + return "https://github.com/NKI-AI/ahcore/blob/" "v%s/ahcore/%s%s" % ( + ahcore.__version__, + fn, + linespec, + ) diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100644 index 0000000..e582053 --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1 @@ +.. include:: ../CONTRIBUTING.rst diff --git a/docs/ext/doi_role.py b/docs/ext/doi_role.py new file mode 100644 index 0000000..a56fd38 --- /dev/null +++ b/docs/ext/doi_role.py @@ -0,0 +1,51 @@ +# coding=utf-8 +""" +doilinks +~~~~~~~~ +Extension to add links to DOIs. With this extension you can use e.g. +:doi:`10.1016/S0022-2836(05)80360-2` in your documents. This will +create a link to a DOI resolver +(``https://doi.org/10.1016/S0022-2836(05)80360-2``). +The link caption will be the raw DOI. +You can also give an explicit caption, e.g. +:doi:`Basic local alignment search tool <10.1016/S0022-2836(05)80360-2>`. + +:copyright: Copyright 2015 Jon Lund Steffensen. Based on extlinks by + the Sphinx team. +:license: BSD. 
+""" + +from docutils import nodes, utils # type: ignore +from sphinx.util.nodes import split_explicit_title # type: ignore + + +def doi_role(typ, rawtext, text, lineno, inliner, options={}, content=[]): + text = utils.unescape(text) + has_explicit_title, title, part = split_explicit_title(text) + full_url = "https://doi.org/" + part + if not has_explicit_title: + title = "DOI:" + part + pnode = nodes.reference(title, title, internal=False, refuri=full_url) + return [pnode], [] + + +def arxiv_role(typ, rawtext, text, lineno, inliner, options={}, content=[]): + text = utils.unescape(text) + has_explicit_title, title, part = split_explicit_title(text) + full_url = "https://arxiv.org/abs/" + part + if not has_explicit_title: + title = "arXiv:" + part + pnode = nodes.reference(title, title, internal=False, refuri=full_url) + return [pnode], [] + + +def setup_link_role(app): + app.add_role("doi", doi_role, override=True) + app.add_role("DOI", doi_role, override=True) + app.add_role("arXiv", arxiv_role, override=True) + app.add_role("arxiv", arxiv_role, override=True) + + +def setup(app): + app.connect("builder-inited", setup_link_role) + return {"version": "0.1", "parallel_read_safe": True} diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..cec8712 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,26 @@ +.. role:: bash(code) + :language: bash + +AI for Oncology Core for Comptuational Pathology +================================================ + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + cli + contributing + modules + +Indices and tables +================== +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + + + +.. _OpenSlide: https://openslide.org +.. _fork of OpenSlide: https://github.com/NKI-AI/OpenSlide +.. _Github repo: https://github.com/NKI-AI/dlup diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..0eeac90 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=python -msphinx +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=dlup + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The Sphinx module was not found. Make sure you have Sphinx installed, + echo.then set the SPHINXBUILD environment variable to point to the full + echo.path of the 'sphinx-build' executable. Alternatively you may add the + echo.Sphinx directory to PATH. + echo. 
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c983c61
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,45 @@
+[tool.black]
+line-length = 120  # PyCharm line length
+target-version = ['py39', 'py310']
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.eggs
+  | \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | \.idea
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
+
+[tool.isort]
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+ensure_newline_before_comments = true
+line_length = 120
+
+[tool.pylint.messages_control]
+disable = "W1203, C0415"
+
+[tool.pylint.basic]
+good-names="i,j,k,x,y,ex,Run,_"
+
+[tool.pylint.'TYPECHECK']
+generated-members=['numpy.*', 'torch.*', 'np.*', 'cv2.*']
+ignored-modules=['pydantic', 'dlup']
+
+[tool.pylint.format]
+max-line-length = "120"
+
+[tool.pylint.design]
+max-args=20
+max-branches=30
+max-parents=15
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..d6b68f3
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,22 @@
+[bumpversion]
+current_version = 0.1
+commit = True
+tag = False
+parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
+serialize =
+	{major}.{minor}.{patch}-{release}{build}
+	{major}.{minor}.{patch}
+
+[bumpversion:part:release]
+optional_value = prod
+first_value = dev
+values =
+	dev
+	prod
+
+[bumpversion:file:ahcore/__init__.py]
+search = {current_version}
+replace = {new_version}
+
+[aliases]
+test = pytest
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..1d0075d
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""The setup script."""
+import ast
+
+from setuptools import find_packages, setup  # type: ignore # noqa
+
+with open("ahcore/__init__.py") as f:
+    for line in f:
+        if line.startswith("__version__"):
+            version = ast.parse(line).body[0].value.s  # type: ignore
+            break
+
+
+with open("README.md") as readme_file:
+    long_description = readme_file.read()
+
+install_requires = [
+    "numpy>=1.25.2",
+    "torch>=2.0.1",
+    "pillow>=9.5.0",
+    "pytorch-lightning>=2.0.8",
+    "torchvision>=0.15.2",
+    "pydantic>=2.0.3",
+    "tensorboard>=2.14.0",
+    "mlflow>=2.6.0",
+    "hydra-core>=1.3.2",
+    "python-dotenv>=1.0.0",
+    "tqdm>=4.64",
+    "rich>=12.4",
+    "hydra-submitit-launcher>=1.2.0",
+    "hydra-optuna-sweeper>=1.3.0.dev0",
+    "hydra-colorlog>=1.2.0",
+    "dlup>=0.3.30",
+    "kornia>=0.7.0",
+    "h5py>=3.8.0",
+    "monai[einops]==1.2.0",
+    "imagecodecs==2023.9.4",
+]
+
+
+setup(
+    author="AI for Oncology Lab @ The Netherlands Cancer Institute",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    python_requires=">=3.10",
+    classifiers=[
+        "Development Status :: 4 - Beta",
+        "Natural Language :: English",
+        "Programming Language :: Python :: 3",
+    ],
+    entry_points={
+        "console_scripts": [
+            "ahcore=ahcore.cli:main",
+        ],
+    },
+    description="Ahcore: the AI for Oncology core components for computational pathology.",
+    install_requires=install_requires,
+    extras_require={
+        "dev": [
+            "pytest",
+            "numpydoc",
+            "pylint==2.17.7",
+            "black==23.9.1",
+            "types-Pillow",
+            "sphinx",
+            "sphinx_copybutton",
+            "myst-parser",
+            "sphinx-book-theme",
+        ],
+    },
+    license="Apache Software License 2.0",
+    include_package_data=True,
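The `parse` and `serialize` entries in the `[bumpversion]` section above are two sides of one contract: `parse` must produce exactly the named groups that the `serialize` templates consume. A quick self-contained check of that pattern (plain `re`, nothing bumpversion-specific; the sample versions are illustrative):

```python
# Sketch: confirm the bumpversion `parse` regex yields the groups that the
# `serialize` templates ({major}.{minor}.{patch}[-{release}{build}]) expect.
import re

PARSE = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?")

for candidate in ("0.1.0", "0.1.0-dev3", "1.2.3-prod1"):
    match = PARSE.fullmatch(candidate)
    assert match is not None, candidate
    groups = {k: v for k, v in match.groupdict().items() if v is not None}
    print(candidate, "->", groups)
```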
name="ahcore",
+    test_suite="tests",
+    url="https://github.com/NKI-AI/ahcore",
+    py_modules=["ahcore"],
+    version=version,
+)
diff --git a/tools/.env.example b/tools/.env.example
new file mode 100644
index 0000000..1b7a98a
--- /dev/null
+++ b/tools/.env.example
@@ -0,0 +1,5 @@
+PROJECT_ROOT=/home/j.teuwen/logs/ahcore
+DATA_DIR=/data/groups/aiforoncology/archive/pathology/TCGA/images/gdc_manifest.2021-11-01_diagnostic_breast.txt/
+ANNOTATIONS_DIR=/data/groups/aiforoncology/derived/pathology/TCGA/gdc_manifest.2021-11-01_diagnostic_breast.txt/
+MANIFEST_PATH=${ANNOTATIONS_DIR}
+DATASET_SPLIT=${ANNOTATIONS_DIR}
diff --git a/tools/README.md b/tools/README.md
new file mode 100644
index 0000000..1292f90
--- /dev/null
+++ b/tools/README.md
@@ -0,0 +1,9 @@
+# ahcore utilities
+
+## `convert_wsi_to_tiles.py`
+Utility to pretile WSIs into small tiles for use with DINOv2.
+Combine with `detect_problems.py` to detect images which were not correctly parsed.
+
+The process for v1 is:
+- run `convert_wsi_to_tiles.py` on all *diagnostic* slides
+- run `detect_problems.py` on the output and remove cases with a `parsing` error
diff --git a/tools/inference.py b/tools/inference.py
new file mode 100644
index 0000000..22e1cf1
--- /dev/null
+++ b/tools/inference.py
@@ -0,0 +1,35 @@
+import dotenv
+import hydra
+from omegaconf import DictConfig
+
+# load environment variables from `.env` file if it exists
+# recursively searches for `.env` in all folders starting from work dir
+dotenv.load_dotenv(override=True)
+
+
+@hydra.main(
+    config_path="../config",
+    config_name="inference.yaml",
+    version_base="1.3",
+)
+def main(config: DictConfig):
+    # Imports can be nested inside @hydra.main to optimize tab completion
+    # https://github.com/facebookresearch/hydra/issues/934
+    from ahcore.entrypoints import inference
+    from ahcore.utils.io import extras, print_config, validate_config
+
+    # Validate config -- Fails if there are mandatory missing values
+    validate_config(config)
+
+    # Applies optional utilities
+    extras(config)
+
+    if config.get("print_config"):
+        print_config(config, resolve=True)
+
+    # Run inference
+    return inference(config)
+
+
+if __name__ == "__main__":
+    main()  # pylint: disable=no-value-for-parameter
diff --git a/tools/populate_tcga_db.py b/tools/populate_tcga_db.py
new file mode 100644
index 0000000..1be57b0
--- /dev/null
+++ b/tools/populate_tcga_db.py
@@ -0,0 +1,133 @@
+"""An example of how to populate an ahcore manifest database using the TCGA dataset."""
+import json
+import random
+from pathlib import Path
+
+from dlup import SlideImage
+from dlup.experimental_backends import ImageBackend  # type: ignore
+
+from ahcore.utils.database_models import (
+    CategoryEnum,
+    Image,
+    ImageAnnotations,
+    ImageLabels,
+    Manifest,
+    Mask,
+    Patient,
+    PatientLabels,
+    Split,
+    SplitDefinitions,
+)
+from ahcore.utils.manifest import open_db
+
+
+def get_patient_from_tcga_id(tcga_filename: str) -> str:
+    return tcga_filename[:12]
+
+
+def populate_from_annotated_tcga(
+    session,
+    image_folder: Path,
+    annotation_folder: Path,
+    path_to_mapping: Path,
+    predict: bool = False,
+):
+    """This is a basic example, adjust to your needs."""
+    # TODO: We should do the mpp as well here
+
+    with open(path_to_mapping, "r") as f:
+        mapping = json.load(f)
+
+    manifest = Manifest(name="v20230228")
+    session.add(manifest)
+    session.flush()
+
+    split_definition = SplitDefinitions(version="v1", description="Initial split")
+    session.add(split_definition)
+    session.flush()
+
+    for folder in annotation_folder.glob("TCGA*"):
+ patient_code = get_patient_from_tcga_id(folder.name) + + if not predict: + annotation_path = folder / "annotations.json" + mask_path = folder / "roi.json" + + # Only add patient if it doesn't exist + existing_patient = session.query(Patient).filter_by(patient_code=patient_code).first() # type: ignore + if existing_patient: + patient = existing_patient + else: + patient = Patient(patient_code=patient_code, manifest=manifest) + session.add(patient) + session.flush() + + # For now random. + if predict: + split_category = CategoryEnum.PREDICT + else: + split_category = random.choices( + [CategoryEnum.TRAIN, CategoryEnum.VALIDATE, CategoryEnum.TEST], + [67, 33, 0], + )[0] + + split = Split( + category=split_category, + patient=patient, + split_definition=split_definition, + ) + session.add(split) + session.flush() + + # Add only the label if it does not exist yet. + existing_label = session.query(PatientLabels).filter_by(key="study", patient_id=patient.id).first() + if not existing_label: + patient_label = PatientLabels(key="study", value="BRCA", patient=patient) + session.add(patient_label) + session.flush() + + filename = mapping[folder.name] + + # TODO: OPENSLIDE doesn't work + kwargs = {} + if ( + "TCGA-OL-A5RY-01Z-00-DX1.AE4E9D74-FC1C-4C1E-AE6D-5DF38899BBA6.svs" in filename + or "TCGA-OL-A5RW-01Z-00-DX1.E16DE8EE-31AF-4EAF-A85F-DB3E3E2C3BFF.svs" in filename + ): + kwargs["overwrite_mpp"] = (0.25, 0.25) + + with SlideImage.from_file_path( + image_folder / filename, backend=ImageBackend.PYVIPS, **kwargs # type: ignore + ) as slide: # type: ignore + mpp = slide.mpp + width, height = slide.size + image = Image( + filename=str(filename), + mpp=mpp, + height=height, + width=width, + reader="OPENSLIDE", + patient=patient, + ) + session.add(image) + session.flush() # Flush so that Image ID is populated for future records + + if not predict: + mask = Mask(filename=str(mask_path), reader="GEOJSON", image=image) + session.add(mask) + + image_annotation = ImageAnnotations(filename=str(annotation_path), reader="GEOJSON", image=image) + session.add(image_annotation) + + label_data = "cancer" if random.choice([True, False]) else "benign" # Randomly decide if it's cancer or benign + image_label = ImageLabels(label_data=label_data, image=image) + session.add(image_label) + + session.commit() + + +if __name__ == "__main__": + annotation_folder = Path("tissue_subtypes/v20230228_debug/") + image_folder = Path("/data/groups/aiforoncology/archive/pathology/TCGA/images/") + path_to_mapping = Path("/data/groups/aiforoncology/archive/pathology/TCGA/identifier_mapping.json") + with open_db("manifest.db") as session: + populate_from_annotated_tcga(session, image_folder, annotation_folder, path_to_mapping, predict=True) diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000..61212ed --- /dev/null +++ b/tools/train.py @@ -0,0 +1,35 @@ +import dotenv +import hydra +from omegaconf import DictConfig + +# load environment variables from `.env` file if it exists +# recursively searches for `.env` in all folders starting from work dir +dotenv.load_dotenv(override=True) + + +@hydra.main( + config_path="../config", + config_name="train.yaml", + version_base="1.3", +) +def main(config: DictConfig): + # Imports can be nested inside @hydra.main to optimize tab completion + # https://github.com/facebookresearch/hydra/issues/934 + from ahcore.entrypoints import train + from ahcore.utils.io import extras, print_config, validate_config + + # Validate config -- Fails if there are mandatory missing values + 
validate_config(config) + + # Applies optional utilities + extras(config) + + if config.get("print_config"): + print_config(config, resolve=True) + + # Train model + return train(config) + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter
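Since tools/populate_tcga_db.py above only writes the manifest database, a natural sanity check after running it is to read the records back. A short sketch using the same `open_db` helper and ORM models the script imports (the queries themselves are illustrative, not part of ahcore):

```python
# Sketch: read back a manifest database written by tools/populate_tcga_db.py.
from ahcore.utils.database_models import Image, Patient
from ahcore.utils.manifest import open_db

with open_db("manifest.db") as session:
    n_patients = session.query(Patient).count()
    print(f"patients in manifest: {n_patients}")

    # Spot-check a few images and the metadata recorded for them.
    for image in session.query(Image).limit(5):
        print(image.filename, image.mpp, image.width, image.height)
```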