From 66b09537ed45f5f21e014567d6f05c2149fa9e66 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 2 Dec 2024 12:36:01 -0500 Subject: [PATCH 1/5] Rename `project/datamodules`->`project/datasets` Signed-off-by: Fabrice Normandin --- .../project/{datamodules => datasets}/datamodules_test/.gitignore | 0 .../test_first_batch/cifar10_algorithm_no_op_test.yaml | 0 .../test_first_batch/cifar10_algorithm_no_op_train.yaml | 0 .../test_first_batch/cifar10_algorithm_no_op_validate.yaml | 0 .../test_first_batch/fashion_mnist_algorithm_no_op_test.yaml | 0 .../test_first_batch/fashion_mnist_algorithm_no_op_train.yaml | 0 .../test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml | 0 .../test_first_batch/glue_cola_algorithm_no_op_test.yaml | 0 .../test_first_batch/glue_cola_algorithm_no_op_train.yaml | 0 .../test_first_batch/glue_cola_algorithm_no_op_validate.yaml | 0 .../test_first_batch/imagenet_algorithm_no_op_test.yaml | 0 .../test_first_batch/imagenet_algorithm_no_op_train.yaml | 0 .../test_first_batch/imagenet_algorithm_no_op_validate.yaml | 0 .../test_first_batch/mnist_algorithm_no_op_test.yaml | 0 .../test_first_batch/mnist_algorithm_no_op_train.yaml | 0 .../test_first_batch/mnist_algorithm_no_op_validate.yaml | 0 project/{datamodules => datasets}/__init__.py | 0 project/{datamodules => datasets}/datamodules_test.py | 0 .../{datamodules => datasets}/image_classification/__init__.py | 0 project/{datamodules => datasets}/image_classification/cifar10.py | 0 .../image_classification/fashion_mnist.py | 0 .../image_classification/image_classification.py | 0 .../{datamodules => datasets}/image_classification/imagenet.py | 0 .../{datamodules => datasets}/image_classification/inaturalist.py | 0 .../image_classification/inaturalist_test.py | 0 project/{datamodules => datasets}/image_classification/mnist.py | 0 project/{datamodules => datasets}/text/__init__.py | 0 project/{datamodules => datasets}/text/text_classification.py | 0 .../{datamodules => 
datasets}/text/text_classification_test.py | 0 project/{datamodules => datasets}/vision.py | 0 30 files changed, 0 insertions(+), 0 deletions(-) rename .regression_files/project/{datamodules => datasets}/datamodules_test/.gitignore (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/{datamodules => 
datasets}/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml (100%) rename .regression_files/project/{datamodules => datasets}/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml (100%) rename project/{datamodules => datasets}/__init__.py (100%) rename project/{datamodules => datasets}/datamodules_test.py (100%) rename project/{datamodules => datasets}/image_classification/__init__.py (100%) rename project/{datamodules => datasets}/image_classification/cifar10.py (100%) rename project/{datamodules => datasets}/image_classification/fashion_mnist.py (100%) rename project/{datamodules => datasets}/image_classification/image_classification.py (100%) rename project/{datamodules => datasets}/image_classification/imagenet.py (100%) rename project/{datamodules => datasets}/image_classification/inaturalist.py (100%) rename project/{datamodules => datasets}/image_classification/inaturalist_test.py (100%) rename project/{datamodules => datasets}/image_classification/mnist.py (100%) rename project/{datamodules => datasets}/text/__init__.py (100%) rename project/{datamodules => datasets}/text/text_classification.py (100%) rename project/{datamodules => datasets}/text/text_classification_test.py (100%) rename project/{datamodules => datasets}/vision.py (100%) diff --git a/.regression_files/project/datamodules/datamodules_test/.gitignore b/.regression_files/project/datasets/datamodules_test/.gitignore similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/.gitignore rename to .regression_files/project/datasets/datamodules_test/.gitignore diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml similarity index 100% 
rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml similarity index 100% rename from 
.regression_files/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml similarity index 100% rename from 
.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml similarity index 100% rename from 
.regression_files/project/datamodules/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datamodules/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml diff --git a/project/datamodules/__init__.py b/project/datasets/__init__.py similarity index 100% rename from project/datamodules/__init__.py rename to project/datasets/__init__.py diff --git a/project/datamodules/datamodules_test.py b/project/datasets/datamodules_test.py similarity index 100% rename from project/datamodules/datamodules_test.py rename to project/datasets/datamodules_test.py diff --git a/project/datamodules/image_classification/__init__.py b/project/datasets/image_classification/__init__.py similarity index 100% rename from project/datamodules/image_classification/__init__.py rename to project/datasets/image_classification/__init__.py diff --git a/project/datamodules/image_classification/cifar10.py b/project/datasets/image_classification/cifar10.py similarity index 100% rename from 
project/datamodules/image_classification/cifar10.py rename to project/datasets/image_classification/cifar10.py diff --git a/project/datamodules/image_classification/fashion_mnist.py b/project/datasets/image_classification/fashion_mnist.py similarity index 100% rename from project/datamodules/image_classification/fashion_mnist.py rename to project/datasets/image_classification/fashion_mnist.py diff --git a/project/datamodules/image_classification/image_classification.py b/project/datasets/image_classification/image_classification.py similarity index 100% rename from project/datamodules/image_classification/image_classification.py rename to project/datasets/image_classification/image_classification.py diff --git a/project/datamodules/image_classification/imagenet.py b/project/datasets/image_classification/imagenet.py similarity index 100% rename from project/datamodules/image_classification/imagenet.py rename to project/datasets/image_classification/imagenet.py diff --git a/project/datamodules/image_classification/inaturalist.py b/project/datasets/image_classification/inaturalist.py similarity index 100% rename from project/datamodules/image_classification/inaturalist.py rename to project/datasets/image_classification/inaturalist.py diff --git a/project/datamodules/image_classification/inaturalist_test.py b/project/datasets/image_classification/inaturalist_test.py similarity index 100% rename from project/datamodules/image_classification/inaturalist_test.py rename to project/datasets/image_classification/inaturalist_test.py diff --git a/project/datamodules/image_classification/mnist.py b/project/datasets/image_classification/mnist.py similarity index 100% rename from project/datamodules/image_classification/mnist.py rename to project/datasets/image_classification/mnist.py diff --git a/project/datamodules/text/__init__.py b/project/datasets/text/__init__.py similarity index 100% rename from project/datamodules/text/__init__.py rename to 
project/datasets/text/__init__.py diff --git a/project/datamodules/text/text_classification.py b/project/datasets/text/text_classification.py similarity index 100% rename from project/datamodules/text/text_classification.py rename to project/datasets/text/text_classification.py diff --git a/project/datamodules/text/text_classification_test.py b/project/datasets/text/text_classification_test.py similarity index 100% rename from project/datamodules/text/text_classification_test.py rename to project/datasets/text/text_classification_test.py diff --git a/project/datamodules/vision.py b/project/datasets/vision.py similarity index 100% rename from project/datamodules/vision.py rename to project/datasets/vision.py From dadff89cd3f197d474aed4a3cd6c24dcb6e633dc Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 2 Dec 2024 12:40:40 -0500 Subject: [PATCH 2/5] Also rename references in python modules Signed-off-by: Fabrice Normandin --- .../datasets/{datamodules_test => datasets_test}/.gitignore | 0 .../test_first_batch/cifar10_algorithm_no_op_test.yaml | 0 .../test_first_batch/cifar10_algorithm_no_op_train.yaml | 0 .../test_first_batch/cifar10_algorithm_no_op_validate.yaml | 0 .../test_first_batch/fashion_mnist_algorithm_no_op_test.yaml | 0 .../fashion_mnist_algorithm_no_op_train.yaml | 0 .../fashion_mnist_algorithm_no_op_validate.yaml | 0 .../test_first_batch/glue_cola_algorithm_no_op_test.yaml | 0 .../test_first_batch/glue_cola_algorithm_no_op_train.yaml | 0 .../test_first_batch/glue_cola_algorithm_no_op_validate.yaml | 0 .../test_first_batch/imagenet_algorithm_no_op_test.yaml | 0 .../test_first_batch/imagenet_algorithm_no_op_train.yaml | 0 .../test_first_batch/imagenet_algorithm_no_op_validate.yaml | 0 .../test_first_batch/mnist_algorithm_no_op_test.yaml | 0 .../test_first_batch/mnist_algorithm_no_op_train.yaml | 0 .../test_first_batch/mnist_algorithm_no_op_validate.yaml | 0 project/algorithms/image_classifier.py | 2 +- project/algorithms/image_classifier_test.py 
| 4 ++-- project/algorithms/jax_image_classifier.py | 4 ++-- project/algorithms/jax_image_classifier_test.py | 2 +- project/algorithms/text_classifier.py | 2 +- project/algorithms/text_classifier_test.py | 2 +- project/conftest.py | 2 +- project/datasets/{datamodules_test.py => datasets_test.py} | 5 +++-- project/datasets/image_classification/cifar10.py | 2 +- project/datasets/image_classification/fashion_mnist.py | 2 +- .../datasets/image_classification/image_classification.py | 2 +- project/datasets/image_classification/imagenet.py | 2 +- project/datasets/image_classification/inaturalist.py | 2 +- project/datasets/image_classification/inaturalist_test.py | 2 +- project/datasets/image_classification/mnist.py | 2 +- project/datasets/text/text_classification_test.py | 2 +- project/utils/testutils.py | 4 ++-- 33 files changed, 22 insertions(+), 21 deletions(-) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/.gitignore (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/cifar10_algorithm_no_op_test.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/cifar10_algorithm_no_op_train.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/cifar10_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/glue_cola_algorithm_no_op_test.yaml (100%) rename 
.regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/glue_cola_algorithm_no_op_train.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/glue_cola_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/imagenet_algorithm_no_op_test.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/imagenet_algorithm_no_op_train.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/imagenet_algorithm_no_op_validate.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/mnist_algorithm_no_op_test.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/mnist_algorithm_no_op_train.yaml (100%) rename .regression_files/project/datasets/{datamodules_test => datasets_test}/test_first_batch/mnist_algorithm_no_op_validate.yaml (100%) rename project/datasets/{datamodules_test.py => datasets_test.py} (96%) diff --git a/.regression_files/project/datasets/datamodules_test/.gitignore b/.regression_files/project/datasets/datasets_test/.gitignore similarity index 100% rename from .regression_files/project/datasets/datamodules_test/.gitignore rename to .regression_files/project/datasets/datasets_test/.gitignore diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/cifar10_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/cifar10_algorithm_no_op_test.yaml diff --git 
a/.regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/cifar10_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/cifar10_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/cifar10_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/fashion_mnist_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/fashion_mnist_algorithm_no_op_train.yaml diff --git 
a/.regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/fashion_mnist_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml diff --git 
a/.regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/imagenet_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/imagenet_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/imagenet_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/imagenet_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/imagenet_algorithm_no_op_validate.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/mnist_algorithm_no_op_test.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_test.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/mnist_algorithm_no_op_test.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml 
b/.regression_files/project/datasets/datasets_test/test_first_batch/mnist_algorithm_no_op_train.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_train.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/mnist_algorithm_no_op_train.yaml diff --git a/.regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml b/.regression_files/project/datasets/datasets_test/test_first_batch/mnist_algorithm_no_op_validate.yaml similarity index 100% rename from .regression_files/project/datasets/datamodules_test/test_first_batch/mnist_algorithm_no_op_validate.yaml rename to .regression_files/project/datasets/datasets_test/test_first_batch/mnist_algorithm_no_op_validate.yaml diff --git a/project/algorithms/image_classifier.py b/project/algorithms/image_classifier.py index 7d397a96..f1555a0a 100644 --- a/project/algorithms/image_classifier.py +++ b/project/algorithms/image_classifier.py @@ -21,7 +21,7 @@ from torch.optim.optimizer import Optimizer from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback -from project.datamodules.image_classification import ImageClassificationDataModule +from project.datasets.image_classification import ImageClassificationDataModule from project.utils.typing_utils import HydraConfigFor logger = getLogger(__name__) diff --git a/project/algorithms/image_classifier_test.py b/project/algorithms/image_classifier_test.py index 7d7023f2..9af90c59 100644 --- a/project/algorithms/image_classifier_test.py +++ b/project/algorithms/image_classifier_test.py @@ -7,8 +7,8 @@ from project.algorithms.testsuites.lightning_module_tests import LightningModuleTests from project.configs import Config from project.conftest import command_line_overrides, skip_on_macOS_in_CI -from project.datamodules.image_classification.cifar10 import CIFAR10DataModule -from 
project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.cifar10 import CIFAR10DataModule +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.testutils import run_for_all_configs_of_type diff --git a/project/algorithms/jax_image_classifier.py b/project/algorithms/jax_image_classifier.py index cdbf0653..fe09cef2 100644 --- a/project/algorithms/jax_image_classifier.py +++ b/project/algorithms/jax_image_classifier.py @@ -16,10 +16,10 @@ from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback -from project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.image_classification.mnist import MNISTDataModule +from project.datasets.image_classification.mnist import MNISTDataModule from project.utils.typing_utils import HydraConfigFor diff --git a/project/algorithms/jax_image_classifier_test.py b/project/algorithms/jax_image_classifier_test.py index 8af161ac..79a0dbb7 100644 --- a/project/algorithms/jax_image_classifier_test.py +++ b/project/algorithms/jax_image_classifier_test.py @@ -6,7 +6,7 @@ from project.algorithms.jax_image_classifier import JaxImageClassifier from project.conftest import fails_on_macOS_in_CI -from project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.testutils import run_for_all_configs_of_type diff --git a/project/algorithms/text_classifier.py b/project/algorithms/text_classifier.py index 2ef16b1a..e0432587 100644 --- a/project/algorithms/text_classifier.py +++ b/project/algorithms/text_classifier.py @@ -11,7 
+11,7 @@ ) from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from project.datamodules.text.text_classification import TextClassificationDataModule +from project.datasets.text.text_classification import TextClassificationDataModule from project.utils.typing_utils import HydraConfigFor diff --git a/project/algorithms/text_classifier_test.py b/project/algorithms/text_classifier_test.py index 7f50ff84..fc9e19c1 100644 --- a/project/algorithms/text_classifier_test.py +++ b/project/algorithms/text_classifier_test.py @@ -12,7 +12,7 @@ from typing_extensions import override from project.algorithms.text_classifier import TextClassifier -from project.datamodules.text.text_classification import TextClassificationDataModule +from project.datasets.text.text_classification import TextClassificationDataModule from project.utils.env_vars import SLURM_JOB_ID from project.utils.testutils import run_for_all_configs_of_type, total_vram_gb diff --git a/project/conftest.py b/project/conftest.py index 6e3d0393..e3c51c1d 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -92,7 +92,7 @@ from torch.utils.data import DataLoader from project.configs.config import Config -from project.datamodules.vision import VisionDataModule, num_cpus_on_node +from project.datasets.vision import VisionDataModule, num_cpus_on_node from project.experiment import ( instantiate_algorithm, instantiate_datamodule, diff --git a/project/datasets/datamodules_test.py b/project/datasets/datasets_test.py similarity index 96% rename from project/datasets/datamodules_test.py rename to project/datasets/datasets_test.py index 311444c3..7b75e13a 100644 --- a/project/datasets/datamodules_test.py +++ b/project/datasets/datasets_test.py @@ -11,10 +11,10 @@ from torch import Tensor from project.conftest import command_line_overrides -from project.datamodules.image_classification.image_classification import ( +from 
project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision import VisionDataModule +from project.datasets.vision import VisionDataModule from project.utils.testutils import run_for_all_configs_in_group from project.utils.typing_utils import is_sequence_of @@ -45,6 +45,7 @@ def test_first_batch( stage: RunningStage, datadir: Path, ): + """Test that the first batch of the dataloader is reproducible (the same for the same seed).""" # Note: using dataloader workers in tests can cause issues, since if a test fails, dataloader # workers aren't always cleaned up properly. if isinstance(datamodule, VisionDataModule) or hasattr(datamodule, "num_workers"): diff --git a/project/datasets/image_classification/cifar10.py b/project/datasets/image_classification/cifar10.py index 0e924186..d57707f8 100644 --- a/project/datasets/image_classification/cifar10.py +++ b/project/datasets/image_classification/cifar10.py @@ -4,7 +4,7 @@ from torchvision.datasets import CIFAR10 from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.typing_utils import C, H, W diff --git a/project/datasets/image_classification/fashion_mnist.py b/project/datasets/image_classification/fashion_mnist.py index 613ea6be..45b379c6 100644 --- a/project/datasets/image_classification/fashion_mnist.py +++ b/project/datasets/image_classification/fashion_mnist.py @@ -2,7 +2,7 @@ from torchvision.datasets import FashionMNIST -from project.datamodules.image_classification.mnist import MNISTDataModule +from project.datasets.image_classification.mnist import MNISTDataModule class FashionMNISTDataModule(MNISTDataModule): diff --git a/project/datasets/image_classification/image_classification.py 
b/project/datasets/image_classification/image_classification.py index 3fe2e26a..00b42f7d 100644 --- a/project/datasets/image_classification/image_classification.py +++ b/project/datasets/image_classification/image_classification.py @@ -4,7 +4,7 @@ from torchvision.tv_tensors import Image from typing_extensions import TypeVar -from project.datamodules.vision import VisionDataModule +from project.datasets.vision import VisionDataModule from project.utils.typing_utils import C, H, W from project.utils.typing_utils.protocols import ClassificationDataModule diff --git a/project/datasets/image_classification/imagenet.py b/project/datasets/image_classification/imagenet.py index 9c774262..cd917358 100644 --- a/project/datasets/image_classification/imagenet.py +++ b/project/datasets/image_classification/imagenet.py @@ -22,7 +22,7 @@ from torchvision.models.resnet import ResNet152_Weights from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.env_vars import DATA_DIR, NETWORK_DIR, NUM_WORKERS diff --git a/project/datasets/image_classification/inaturalist.py b/project/datasets/image_classification/inaturalist.py index 14856fba..eb6f2a61 100644 --- a/project/datasets/image_classification/inaturalist.py +++ b/project/datasets/image_classification/inaturalist.py @@ -9,7 +9,7 @@ import torchvision.transforms as T from torchvision.datasets import INaturalist, VisionDataset -from project.datamodules.vision import VisionDataModule +from project.datasets.vision import VisionDataModule from project.utils.env_vars import DATA_DIR, NUM_WORKERS, SLURM_TMPDIR from project.utils.typing_utils import C, H, W diff --git a/project/datasets/image_classification/inaturalist_test.py b/project/datasets/image_classification/inaturalist_test.py index 7b9757f8..1e836d69 100644 --- 
a/project/datasets/image_classification/inaturalist_test.py +++ b/project/datasets/image_classification/inaturalist_test.py @@ -6,7 +6,7 @@ from torchvision import transforms as T from torchvision.datasets import INaturalist -from project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) diff --git a/project/datasets/image_classification/mnist.py b/project/datasets/image_classification/mnist.py index d635142c..6804ec0f 100644 --- a/project/datasets/image_classification/mnist.py +++ b/project/datasets/image_classification/mnist.py @@ -9,7 +9,7 @@ from torchvision.datasets import MNIST from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.image_classification import ( +from project.datasets.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.env_vars import DATA_DIR diff --git a/project/datasets/text/text_classification_test.py b/project/datasets/text/text_classification_test.py index 7d450812..633dec05 100644 --- a/project/datasets/text/text_classification_test.py +++ b/project/datasets/text/text_classification_test.py @@ -4,7 +4,7 @@ import lightning import pytest -from project.datamodules.text.text_classification import TextClassificationDataModule +from project.datasets.text.text_classification import TextClassificationDataModule from project.experiment import instantiate_datamodule from project.utils.testutils import get_config_loader diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 96c0d9f9..67f3d4d7 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -17,8 +17,8 @@ import torchvision.models from hydra.core.config_store import ConfigStore -from project.datamodules.image_classification.fashion_mnist import FashionMNISTDataModule -from project.datamodules.image_classification.mnist import 
MNISTDataModule +from project.datasets.image_classification.fashion_mnist import FashionMNISTDataModule +from project.datasets.image_classification.mnist import MNISTDataModule from project.utils.env_vars import NETWORK_DIR from project.utils.hydra_utils import get_outer_class From 35ae25a1e96793906c6f1ffa174f677a20e02faf Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 2 Dec 2024 13:07:00 -0500 Subject: [PATCH 3/5] Fix outdated names in configs and docs Signed-off-by: Fabrice Normandin --- docs/examples/text_classification.md | 2 +- project/configs/config.py | 4 ++-- project/configs/datamodule/cifar10.yaml | 4 ++-- project/configs/datamodule/fashion_mnist.yaml | 2 +- project/configs/datamodule/glue_cola.yaml | 2 +- project/configs/datamodule/imagenet.yaml | 2 +- project/configs/datamodule/inaturalist.yaml | 2 +- project/configs/datamodule/mnist.yaml | 4 ++-- project/configs/datamodule/vision.yaml | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/examples/text_classification.md b/docs/examples/text_classification.md index 1ebc1c00..79a3477d 100644 --- a/docs/examples/text_classification.md +++ b/docs/examples/text_classification.md @@ -1,7 +1,7 @@ --- additional_python_references: - project.algorithms.text_classifier - - project.datamodules.text.text_classification + - project.datasets.text.text_classification --- # Text Classification (⚡ + 🤗) diff --git a/project/configs/config.py b/project/configs/config.py index 277b0f6f..bee93e8f 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -34,14 +34,14 @@ class Config: """Configuration for the datamodule (dataset + transforms + dataloader creation). This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. - See the [MNISTDataModule][project.datamodules.image_classification.mnist.MNISTDataModule] for an example. 
+ See the [MNISTDataModule][project.datasets.image_classification.mnist.MNISTDataModule] for an example. """ datamodule: Optional[Any] = None # noqa """Configuration for the datamodule (dataset + transforms + dataloader creation). This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. - See the [MNISTDataModule][project.datamodules.image_classification.mnist.MNISTDataModule] for an example. + See the [MNISTDataModule][project.datasets.image_classification.mnist.MNISTDataModule] for an example. """ trainer: dict = field(default_factory=dict) diff --git a/project/configs/datamodule/cifar10.yaml b/project/configs/datamodule/cifar10.yaml index 2410ef7c..bd121c5c 100644 --- a/project/configs/datamodule/cifar10.yaml +++ b/project/configs/datamodule/cifar10.yaml @@ -1,8 +1,8 @@ defaults: - vision - _self_ -_target_: project.datamodules.CIFAR10DataModule +_target_: project.datasets.CIFAR10DataModule data_dir: ${constant:torchvision_dir,DATA_DIR} batch_size: 128 train_transforms: - _target_: project.datamodules.image_classification.cifar10.cifar10_train_transforms + _target_: project.datasets.image_classification.cifar10.cifar10_train_transforms diff --git a/project/configs/datamodule/fashion_mnist.yaml b/project/configs/datamodule/fashion_mnist.yaml index 472a4d96..0ae43b44 100644 --- a/project/configs/datamodule/fashion_mnist.yaml +++ b/project/configs/datamodule/fashion_mnist.yaml @@ -1,4 +1,4 @@ defaults: - mnist - _self_ -_target_: project.datamodules.FashionMNISTDataModule +_target_: project.datasets.FashionMNISTDataModule diff --git a/project/configs/datamodule/glue_cola.yaml b/project/configs/datamodule/glue_cola.yaml index 078a153d..35aad0ec 100644 --- a/project/configs/datamodule/glue_cola.yaml +++ b/project/configs/datamodule/glue_cola.yaml @@ -1,4 +1,4 @@ -_target_: project.datamodules.text.TextClassificationDataModule +_target_: project.datasets.text.TextClassificationDataModule data_dir: 
${oc.env:SCRATCH,.}/data hf_dataset_path: glue task_name: cola diff --git a/project/configs/datamodule/imagenet.yaml b/project/configs/datamodule/imagenet.yaml index 23804087..b71d8a99 100644 --- a/project/configs/datamodule/imagenet.yaml +++ b/project/configs/datamodule/imagenet.yaml @@ -1,5 +1,5 @@ defaults: - vision - _self_ -_target_: project.datamodules.ImageNetDataModule +_target_: project.datasets.ImageNetDataModule # todo: add good configuration options here. diff --git a/project/configs/datamodule/inaturalist.yaml b/project/configs/datamodule/inaturalist.yaml index d3621b0f..5ff33ba2 100644 --- a/project/configs/datamodule/inaturalist.yaml +++ b/project/configs/datamodule/inaturalist.yaml @@ -1,6 +1,6 @@ defaults: - vision - _self_ -_target_: project.datamodules.INaturalistDataModule +_target_: project.datasets.INaturalistDataModule version: "2021_train" target_type: "full" diff --git a/project/configs/datamodule/mnist.yaml b/project/configs/datamodule/mnist.yaml index 625b1ad9..4de0c908 100644 --- a/project/configs/datamodule/mnist.yaml +++ b/project/configs/datamodule/mnist.yaml @@ -1,9 +1,9 @@ defaults: - vision - _self_ -_target_: project.datamodules.MNISTDataModule +_target_: project.datasets.MNISTDataModule data_dir: ${constant:torchvision_dir,DATA_DIR} normalize: True batch_size: 128 train_transforms: - _target_: project.datamodules.image_classification.mnist.mnist_train_transforms + _target_: project.datasets.image_classification.mnist.mnist_train_transforms diff --git a/project/configs/datamodule/vision.yaml b/project/configs/datamodule/vision.yaml index 561a36b1..fcc6e1b4 100644 --- a/project/configs/datamodule/vision.yaml +++ b/project/configs/datamodule/vision.yaml @@ -1,5 +1,5 @@ # todo: This config should not show up as an option on the command-line. 
-_target_: project.datamodules.VisionDataModule +_target_: project.datasets.VisionDataModule data_dir: ${constant:DATA_DIR} num_workers: ${constant:NUM_WORKERS} val_split: 0.1 # NOTE: reduced from default of 0.2 From b1a0cd0302e6bfea9b996cc0bdaede44d8985bb3 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 2 Dec 2024 13:44:48 -0500 Subject: [PATCH 4/5] Remove unnecessary generic `Callback` class Signed-off-by: Fabrice Normandin --- project/algorithms/callbacks/__init__.py | 2 - project/algorithms/callbacks/callback.py | 273 ------------------ .../callbacks/classification_metrics.py | 86 +++++- .../callbacks/samples_per_second.py | 124 +++++++- project/algorithms/jax_ppo_test.py | 7 +- 5 files changed, 195 insertions(+), 297 deletions(-) delete mode 100644 project/algorithms/callbacks/callback.py diff --git a/project/algorithms/callbacks/__init__.py b/project/algorithms/callbacks/__init__.py index 1f8fc4b3..585714db 100644 --- a/project/algorithms/callbacks/__init__.py +++ b/project/algorithms/callbacks/__init__.py @@ -1,9 +1,7 @@ -from .callback import Callback from .classification_metrics import ClassificationMetricsCallback from .samples_per_second import MeasureSamplesPerSecondCallback __all__ = [ - "Callback", "ClassificationMetricsCallback", "MeasureSamplesPerSecondCallback", ] diff --git a/project/algorithms/callbacks/callback.py b/project/algorithms/callbacks/callback.py deleted file mode 100644 index 05c42bbb..00000000 --- a/project/algorithms/callbacks/callback.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from logging import getLogger as get_logger -from pathlib import Path -from typing import Any, Generic, Literal - -import torch -from lightning import LightningModule, Trainer -from lightning import pytorch as pl -from typing_extensions import TypeVar, override - -from project.utils.typing_utils import NestedMapping - -logger = get_logger(__name__) - -BatchType = TypeVar( - 
"BatchType", - bound=torch.Tensor | tuple[torch.Tensor, ...] | NestedMapping[str, torch.Tensor], - contravariant=True, -) -StepOutputType = TypeVar( - "StepOutputType", - bound=torch.Tensor | Mapping[str, Any] | None, - default=dict[str, torch.Tensor], - contravariant=True, -) - - -class Callback(pl.Callback, Generic[BatchType, StepOutputType]): - """Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class. - - Adds the following typing information: - - The type of inputs that the algorithm takes - - The type of outputs that are returned by the algorithm's `[training/validation/test]_step` methods. - - Adds the following methods: - - `on_shared_batch_start`: called by `on_[train/validation/test]_batch_start` - - `on_shared_batch_end`: called by `on_[train/validation/test]_batch_end` - - `on_shared_epoch_start`: called by `on_[train/validation/test]_epoch_start` - - `on_shared_epoch_end`: called by `on_[train/validation/test]_epoch_end` - """ - - def __init__(self) -> None: - super().__init__() - self.log_dir: Path | None = None - - @override - def setup( - self, - trainer: pl.Trainer, - pl_module: LightningModule, - # todo: "tune" is mentioned in the docstring, is it still used? - stage: Literal["fit", "validate", "test", "predict", "tune"], - ) -> None: - self.log_dir = Path(trainer.log_dir or trainer.default_root_dir) - - def on_shared_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: BatchType, - batch_index: int, - phase: Literal["train", "val", "test"], - dataloader_idx: int | None = None, - ): - """Shared hook, called by `on_[train/validation/test]_batch_start`. - - Use this if you want to do something at the start of batches in more than one phase. 
- """ - - def on_shared_batch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - outputs: StepOutputType, - batch: BatchType, - batch_index: int, - phase: Literal["train", "val", "test"], - dataloader_idx: int | None = None, - ): - """Shared hook, called by `on_[train/validation/test]_batch_end`. - - Use this if you want to do something at the end of batches in more than one phase. - """ - - def on_shared_epoch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - phase: Literal["train", "val", "test"], - ) -> None: - """Shared hook, called by `on_[train/validation/test]_epoch_start`. - - Use this if you want to do something at the start of epochs in more than one phase. - """ - - def on_shared_epoch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - phase: Literal["train", "val", "test"], - ) -> None: - """Shared hook, called by `on_[train/validation/test]_epoch_end`. - - Use this if you want to do something at the end of epochs in more than one phase. 
- """ - - @override - def on_train_batch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - outputs: StepOutputType, - batch: BatchType, - batch_index: int, - ) -> None: - super().on_train_batch_end( - trainer=trainer, - pl_module=pl_module, - outputs=outputs, - batch=batch, - batch_idx=batch_index, - ) - self.on_shared_batch_end( - trainer=trainer, - pl_module=pl_module, - outputs=outputs, - batch=batch, - batch_index=batch_index, - phase="train", - ) - - @override - def on_validation_batch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - outputs: StepOutputType, - batch: BatchType, - batch_index: int, - dataloader_idx: int = 0, - ) -> None: - super().on_validation_batch_end( - trainer=trainer, - pl_module=pl_module, - outputs=outputs, # type: ignore - batch=batch, - batch_idx=batch_index, - dataloader_idx=dataloader_idx, - ) - self.on_shared_batch_end( - trainer=trainer, - pl_module=pl_module, - outputs=outputs, - batch=batch, - batch_index=batch_index, - phase="val", - dataloader_idx=dataloader_idx, - ) - - @override - def on_test_batch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - outputs: StepOutputType, - batch: BatchType, - batch_index: int, - dataloader_idx: int = 0, - ) -> None: - super().on_test_batch_end( - trainer=trainer, - pl_module=pl_module, - outputs=outputs, # type: ignore - batch=batch, - batch_idx=batch_index, - dataloader_idx=dataloader_idx, - ) - self.on_shared_batch_end( - trainer=trainer, - pl_module=pl_module, - outputs=outputs, - batch=batch, - batch_index=batch_index, - dataloader_idx=dataloader_idx, - phase="test", - ) - - @override - def on_train_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: BatchType, - batch_index: int, - ) -> None: - super().on_train_batch_start(trainer, pl_module, batch, batch_index) - self.on_shared_batch_start( - trainer=trainer, - pl_module=pl_module, - batch=batch, - batch_index=batch_index, - phase="train", - ) - - @override - 
def on_validation_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: BatchType, - batch_index: int, - dataloader_idx: int = 0, - ) -> None: - super().on_validation_batch_start(trainer, pl_module, batch, batch_index, dataloader_idx) - self.on_shared_batch_start( - trainer, - pl_module, - batch, - batch_index, - dataloader_idx=dataloader_idx, - phase="val", - ) - - @override - def on_test_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: BatchType, - batch_index: int, - dataloader_idx: int = 0, - ) -> None: - super().on_test_batch_start(trainer, pl_module, batch, batch_index, dataloader_idx) - self.on_shared_batch_start( - trainer, - pl_module, - batch, - batch_index, - dataloader_idx=dataloader_idx, - phase="test", - ) - - @override - def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_train_epoch_start(trainer, pl_module) - self.on_shared_epoch_start(trainer, pl_module, phase="train") - - @override - def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_validation_epoch_start(trainer, pl_module) - self.on_shared_epoch_start(trainer, pl_module, phase="val") - - @override - def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_test_epoch_start(trainer, pl_module) - self.on_shared_epoch_start(trainer, pl_module, phase="test") - - @override - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_train_epoch_end(trainer, pl_module) - self.on_shared_epoch_end(trainer, pl_module, phase="train") - - @override - def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - super().on_validation_epoch_end(trainer, pl_module) - self.on_shared_epoch_end(trainer, pl_module, phase="val") - - @override - def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: - 
super().on_test_epoch_end(trainer, pl_module) - self.on_shared_epoch_end(trainer, pl_module, phase="test") diff --git a/project/algorithms/callbacks/classification_metrics.py b/project/algorithms/callbacks/classification_metrics.py index 8988216d..10ce7e97 100644 --- a/project/algorithms/callbacks/classification_metrics.py +++ b/project/algorithms/callbacks/classification_metrics.py @@ -2,6 +2,7 @@ from logging import getLogger as get_logger from typing import Literal, TypedDict +import lightning import torch import torchmetrics from lightning import LightningModule, Trainer @@ -9,7 +10,6 @@ from torchmetrics.classification import MulticlassAccuracy from typing_extensions import NotRequired, Required, override -from project.algorithms.callbacks.callback import BatchType, Callback from project.utils.typing_utils.protocols import ClassificationDataModule logger = get_logger(__name__) @@ -30,7 +30,7 @@ class ClassificationOutputs(TypedDict, total=False): """The class labels.""" -class ClassificationMetricsCallback(Callback[BatchType, ClassificationOutputs]): +class ClassificationMetricsCallback(lightning.Callback): """Callback that adds classification metrics to a LightningModule.""" def __init__(self) -> None: @@ -105,12 +105,92 @@ def setup( self.add_metrics_to(pl_module, num_classes=num_classes) @override + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: ClassificationOutputs, + batch: tuple[Tensor, Tensor], + batch_index: int, + ) -> None: + super().on_train_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_idx=batch_index, + ) + self.on_shared_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_index=batch_index, + phase="train", + ) + + @override + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: ClassificationOutputs, + batch: tuple[Tensor, Tensor], + batch_idx: int, + 
dataloader_idx: int = 0, + ) -> None: + super().on_validation_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, # type: ignore + batch=batch, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) + self.on_shared_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_index=batch_idx, + phase="val", + dataloader_idx=dataloader_idx, + ) + + @override + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: ClassificationOutputs, + batch: tuple[Tensor, Tensor], + batch_index: int, + dataloader_idx: int = 0, + ) -> None: + super().on_test_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, # type: ignore + batch=batch, + batch_idx=batch_index, + dataloader_idx=dataloader_idx, + ) + self.on_shared_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_index=batch_index, + dataloader_idx=dataloader_idx, + phase="test", + ) + def on_shared_batch_end( self, trainer: Trainer, pl_module: LightningModule, outputs: ClassificationOutputs, - batch: BatchType, + batch: tuple[Tensor, Tensor], batch_index: int, phase: Literal["train", "val", "test"], dataloader_idx: int | None = None, diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index d0134cb1..a6ad820d 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,18 +1,25 @@ import time -from typing import Any, Literal +from typing import Any, Generic, Literal import jax +import lightning import torch from lightning import LightningModule, Trainer +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor -from torch.optim import Optimizer -from typing_extensions import override +from torch.optim.optimizer import Optimizer +from typing_extensions import TypeVar, override -from 
project.algorithms.callbacks.callback import BatchType, Callback, StepOutputType -from project.utils.typing_utils import is_sequence_of +from project.utils.typing_utils import NestedMapping, is_sequence_of +BatchType = TypeVar( + "BatchType", + bound=torch.Tensor | tuple[torch.Tensor, ...] | NestedMapping[str, torch.Tensor], + contravariant=True, +) -class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputType]): + +class MeasureSamplesPerSecondCallback(lightning.Callback, Generic[BatchType]): def __init__(self, num_optimizers: int | None = None): super().__init__() self.last_step_times: dict[Literal["train", "val", "test"], float] = {} @@ -20,6 +27,20 @@ def __init__(self, num_optimizers: int | None = None): self.num_optimizers: int | None = num_optimizers @override + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_train_epoch_start(trainer, pl_module) + self.on_shared_epoch_start(trainer, pl_module, phase="train") + + @override + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_validation_epoch_start(trainer, pl_module) + self.on_shared_epoch_start(trainer, pl_module, phase="val") + + @override + def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + super().on_test_epoch_start(trainer, pl_module) + self.on_shared_epoch_start(trainer, pl_module, phase="test") + def on_shared_epoch_start( self, trainer: Trainer, @@ -36,25 +57,96 @@ def on_shared_epoch_start( self.num_optimizers = len(optimizer_or_optimizers) @override - def on_shared_batch_end( + def on_train_batch_end( self, trainer: Trainer, pl_module: LightningModule, - outputs: StepOutputType, + outputs: STEP_OUTPUT, batch: BatchType, - batch_index: int, - phase: Literal["train", "val", "test"], - dataloader_idx: int | None = None, - ): - super().on_shared_batch_end( - trainer, + batch_idx: int, + ) -> None: + super().on_train_batch_end( + trainer=trainer, 
pl_module=pl_module, outputs=outputs, batch=batch, - batch_index=batch_index, - phase=phase, + batch_idx=batch_idx, + ) + self.on_shared_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_index=batch_idx, + phase="train", + ) + + @override + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: STEP_OUTPUT, + batch: BatchType, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + super().on_validation_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, # type: ignore + batch=batch, + batch_idx=batch_idx, dataloader_idx=dataloader_idx, ) + self.on_shared_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_index=batch_idx, + phase="val", + dataloader_idx=dataloader_idx, + ) + + @override + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: STEP_OUTPUT, + batch: BatchType, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + super().on_test_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, # type: ignore + batch=batch, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) + self.on_shared_batch_end( + trainer=trainer, + pl_module=pl_module, + outputs=outputs, + batch=batch, + batch_index=batch_idx, + dataloader_idx=dataloader_idx, + phase="test", + ) + + def on_shared_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: STEP_OUTPUT, + batch: BatchType, + batch_index: int, + phase: Literal["train", "val", "test"], + dataloader_idx: int | None = None, + ): now = time.perf_counter() if phase in self.last_step_times: elapsed = now - self.last_step_times[phase] diff --git a/project/algorithms/jax_ppo_test.py b/project/algorithms/jax_ppo_test.py index 20a3026a..c250d479 100644 --- a/project/algorithms/jax_ppo_test.py +++ b/project/algorithms/jax_ppo_test.py @@ -23,6 +23,7 @@ from gymnax.environments.environment import 
Environment from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar from lightning.pytorch.loggers import CSVLogger +from lightning.pytorch.utilities.types import STEP_OUTPUT from tensor_regression import TensorRegressionFixture from torch.utils.data import DataLoader from typing_extensions import override @@ -694,11 +695,11 @@ def on_train_batch_end( self, trainer: lightning.Trainer, pl_module: lightning.LightningModule, - outputs: dict[str, torch.Tensor], + outputs: STEP_OUTPUT, batch: TrajectoryWithLastObs, - batch_index: int, + batch_idx: int, ) -> None: - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_index) + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) if not isinstance(batch, TrajectoryWithLastObs): return episodes = batch.trajectories From 87fb5b7f5b6a82be28f284a796f667ab5d33825a Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 2 Dec 2024 14:07:29 -0500 Subject: [PATCH 5/5] Rename other uses of datamodule-->dataset Signed-off-by: Fabrice Normandin --- docs/examples/jax_image_classification.md | 2 +- docs/examples/text_classification.md | 4 +- docs/profiling_test.py | 10 ++-- .../callbacks/classification_metrics.py | 2 +- project/algorithms/image_classifier.py | 2 +- project/algorithms/image_classifier_test.py | 4 +- project/algorithms/jax_image_classifier.py | 4 +- project/algorithms/llm_finetuning_test.py | 8 +-- .../testsuites/lightning_module_tests.py | 15 +++-- project/algorithms/text_classifier_test.py | 8 +-- project/configs/__init__.py | 2 +- project/configs/config.py | 13 +--- project/configs/datamodule/__init__.py | 37 ------------ project/configs/datamodule/cifar10.yaml | 8 --- project/configs/datamodule/fashion_mnist.yaml | 4 -- project/configs/datamodule/glue_cola.yaml | 19 ------ project/configs/datamodule/imagenet.yaml | 5 -- project/configs/datamodule/inaturalist.yaml | 6 -- project/configs/datamodule/mnist.yaml | 9 --- project/configs/datamodule/vision.yaml 
| 10 ---- .../configs/trainer/overfit_one_batch.yaml | 2 +- project/conftest.py | 60 +++++++++---------- project/datasets/datasets_test.py | 30 +++++----- .../image_classification/fashion_mnist.py | 2 +- .../datasets/text/text_classification_test.py | 4 +- project/experiment.py | 28 ++++----- project/main.py | 26 ++++---- project/main_test.py | 6 +- project/utils/remote_launcher_plugin_test.py | 4 +- 29 files changed, 115 insertions(+), 219 deletions(-) delete mode 100644 project/configs/datamodule/__init__.py delete mode 100644 project/configs/datamodule/cifar10.yaml delete mode 100644 project/configs/datamodule/fashion_mnist.yaml delete mode 100644 project/configs/datamodule/glue_cola.yaml delete mode 100644 project/configs/datamodule/imagenet.yaml delete mode 100644 project/configs/datamodule/inaturalist.yaml delete mode 100644 project/configs/datamodule/mnist.yaml delete mode 100644 project/configs/datamodule/vision.yaml diff --git a/docs/examples/jax_image_classification.md b/docs/examples/jax_image_classification.md index ee1ddc99..61595cb2 100644 --- a/docs/examples/jax_image_classification.md +++ b/docs/examples/jax_image_classification.md @@ -41,5 +41,5 @@ pass uses Jax to calculate the gradients, and the weights are updated by a PyTor ## Running the example ```console -$ python project/main.py algorithm=jax_image_classifier network=jax_cnn datamodule=cifar10 +$ python project/main.py algorithm=jax_image_classifier network=jax_cnn dataset=cifar10 ``` diff --git a/docs/examples/text_classification.md b/docs/examples/text_classification.md index 79a3477d..a861cc24 100644 --- a/docs/examples/text_classification.md +++ b/docs/examples/text_classification.md @@ -27,9 +27,9 @@ It accepts a `TextClassificationDataModule` as input, along with a network. ### Datamodule config ??? 
note "Click to show the Datamodule config" - Source: project/configs/datamodule/glue_cola.yaml + Source: project/configs/dataset/glue_cola.yaml - {{ inline('project/configs/datamodule/glue_cola.yaml', 4) }} + {{ inline('project/configs/dataset/glue_cola.yaml', 4) }} ## Running the example diff --git a/docs/profiling_test.py b/docs/profiling_test.py index 14d02549..d29d9b67 100644 --- a/docs/profiling_test.py +++ b/docs/profiling_test.py @@ -11,12 +11,12 @@ algorithm_network_config, command_line_arguments, command_line_overrides, - datamodule_config, + dataset_config, experiment_dictconfig, ) from project.experiment import ( instantiate_algorithm, - instantiate_datamodule, + instantiate_dataset, instantiate_trainer, setup_logging, ) @@ -80,7 +80,7 @@ experiment=profiling \ algorithm=image_classifier \ algorithm/network=fcnet \ - datamodule=mnist \ + dataset=mnist \ trainer.logger.wandb.name="FcNet/MNIST baseline with training" \ trainer.logger.wandb.tags=["CPU/GPU comparison","GPU","MNIST"] """, @@ -122,7 +122,7 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig): setup_logging(log_level=config.log_level) lightning.seed_everything(config.seed, workers=True) _trainer = instantiate_trainer(config) - datamodule = instantiate_datamodule(config.datamodule) - _algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule) + dataset = instantiate_dataset(config.dataset) + _algorithm = instantiate_algorithm(config.algorithm, dataset=dataset) # Note: Here we don't actually do anything with the objects. 
diff --git a/project/algorithms/callbacks/classification_metrics.py b/project/algorithms/callbacks/classification_metrics.py index 10ce7e97..c92b3ccd 100644 --- a/project/algorithms/callbacks/classification_metrics.py +++ b/project/algorithms/callbacks/classification_metrics.py @@ -94,7 +94,7 @@ def setup( warnings.warn( RuntimeWarning( f"Disabling the {type(self).__name__} callback because it only works with " - f"classification datamodules, but {pl_module.datamodule=} isn't a " + f"classification datamodules, but {pl_module.dataset=} isn't a " f"{ClassificationDataModule.__name__}." ) ) diff --git a/project/algorithms/image_classifier.py b/project/algorithms/image_classifier.py index f1555a0a..58d4fb34 100644 --- a/project/algorithms/image_classifier.py +++ b/project/algorithms/image_classifier.py @@ -3,7 +3,7 @@ This can be run from the command-line like so: ```console -python project/main.py algorithm=image_classification datamodule=cifar10 +python project/main.py algorithm=image_classification dataset=cifar10 ``` """ diff --git a/project/algorithms/image_classifier_test.py b/project/algorithms/image_classifier_test.py index 9af90c59..dcfc5ca0 100644 --- a/project/algorithms/image_classifier_test.py +++ b/project/algorithms/image_classifier_test.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize( command_line_overrides.__name__, - ["algorithm=image_classifier datamodule=cifar10"], + ["algorithm=image_classifier dataset=cifar10"], indirect=True, ) def test_example_experiment_defaults(experiment_config: Config) -> None: @@ -28,7 +28,7 @@ def test_example_experiment_defaults(experiment_config: Config) -> None: ImageClassifier.__module__ + "." 
+ ImageClassifier.__qualname__ ) - assert isinstance(experiment_config.datamodule, CIFAR10DataModule) + assert isinstance(experiment_config.dataset, CIFAR10DataModule) @skip_on_macOS_in_CI diff --git a/project/algorithms/jax_image_classifier.py b/project/algorithms/jax_image_classifier.py index fe09cef2..d6d83668 100644 --- a/project/algorithms/jax_image_classifier.py +++ b/project/algorithms/jax_image_classifier.py @@ -229,11 +229,11 @@ def demo(**trainer_kwargs): network = JaxCNN(num_classes=datamodule.num_classes) optimizer = functools.partial(torch.optim.SGD, lr=0.01) # type: ignore model = JaxImageClassifier( - datamodule=datamodule, + dataset=datamodule, network=hydra_zen.just(network), # type: ignore optimizer=hydra_zen.just(optimizer), # type: ignore ) - trainer.fit(model, datamodule=datamodule) + trainer.fit(model, datamodule=datamodule) ... diff --git a/project/algorithms/llm_finetuning_test.py b/project/algorithms/llm_finetuning_test.py index de75dc1a..57b144cc 100644 --- a/project/algorithms/llm_finetuning_test.py +++ b/project/algorithms/llm_finetuning_test.py @@ -122,7 +122,7 @@ def test_training_batch_doesnt_change( def test_initialization_is_reproducible( self, experiment_config: Config, - datamodule: lightning.LightningDataModule, + dataset: lightning.LightningDataModule, seed: int, tensor_regression: TensorRegressionFixture, trainer: lightning.Trainer, @@ -130,7 +130,7 @@ def test_initialization_is_reproducible( ): super().test_initialization_is_reproducible( experiment_config=experiment_config, - datamodule=datamodule, + dataset=dataset, seed=seed, tensor_regression=tensor_regression, trainer=trainer, @@ -159,7 +159,7 @@ def test_forward_pass_is_reproducible( ) def test_backward_pass_is_reproducible( self, - datamodule: lightning.LightningDataModule, + dataset: lightning.LightningDataModule, algorithm: LLMFinetuningExample, seed: int, accelerator: str, @@ -168,5 +168,5 @@ def test_backward_pass_is_reproducible( tmp_path: Path, ): return
super().test_backward_pass_is_reproducible( - datamodule, algorithm, seed, accelerator, devices, tensor_regression, tmp_path + dataset, algorithm, seed, accelerator, devices, tensor_regression, tmp_path ) diff --git a/project/algorithms/testsuites/lightning_module_tests.py b/project/algorithms/testsuites/lightning_module_tests.py index dedc6118..4c49034c 100644 --- a/project/algorithms/testsuites/lightning_module_tests.py +++ b/project/algorithms/testsuites/lightning_module_tests.py @@ -15,7 +15,7 @@ import lightning import pytest import torch -from lightning import LightningDataModule, LightningModule +from lightning import LightningModule from tensor_regression import TensorRegressionFixture from project.configs.config import Config @@ -63,7 +63,7 @@ def test_initialization_is_reproducible( """Check that the network initialization is reproducible given the same random seed.""" with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): torch.random.manual_seed(seed) - algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + algorithm = instantiate_algorithm(experiment_config.algorithm, dataset=datamodule) assert isinstance(algorithm, lightning.LightningModule) # A bit hacky, but we have to do this because the lightningmodule isn't associated # with a Trainer here. 
@@ -103,7 +103,7 @@ def test_forward_pass_is_reproducible( def test_backward_pass_is_reproducible( self, - datamodule: LightningDataModule, + dataset: lightning.LightningDataModule | None, algorithm: AlgorithmType, seed: int, accelerator: str, @@ -119,7 +119,7 @@ def test_backward_pass_is_reproducible( gradients_callback = GetStuffFromFirstTrainingStep() self.do_one_step_of_training( algorithm, - datamodule, + dataset, accelerator=accelerator, devices=devices, callbacks=[gradients_callback], @@ -178,7 +178,7 @@ def to_device(v): def do_one_step_of_training( self, algorithm: AlgorithmType, - datamodule: LightningDataModule, + dataset: lightning.LightningDataModule | None, accelerator: str, devices: int | list[int] | Literal["auto"], callbacks: list[lightning.Callback], @@ -198,7 +198,10 @@ def do_one_step_of_training( deterministic=True, default_root_dir=tmp_path, ) - trainer.fit(algorithm, datamodule=datamodule) + if isinstance(dataset, lightning.LightningDataModule): + trainer.fit(algorithm, datamodule=dataset) + else: + trainer.fit(algorithm) return callbacks diff --git a/project/algorithms/text_classifier_test.py b/project/algorithms/text_classifier_test.py index fc9e19c1..77fd7770 100644 --- a/project/algorithms/text_classifier_test.py +++ b/project/algorithms/text_classifier_test.py @@ -50,7 +50,7 @@ class TestTextClassifier(LightningModuleTests[TextClassifier]): ) def test_backward_pass_is_reproducible( # type: ignore self, - datamodule: TextClassificationDataModule, + dataset: TextClassificationDataModule, algorithm: TextClassifier, seed: int, accelerator: str, @@ -59,7 +59,7 @@ def test_backward_pass_is_reproducible( # type: ignore tmp_path: Path, ): return super().test_backward_pass_is_reproducible( - datamodule=datamodule, + dataset=dataset, algorithm=algorithm, seed=seed, accelerator=accelerator, @@ -73,7 +73,7 @@ def test_backward_pass_is_reproducible( # type: ignore def test_overfit_batch( self, algorithm: TextClassifier, - datamodule: 
TextClassificationDataModule, + dataset: TextClassificationDataModule, tmp_path: Path, num_steps: int = 3, ): @@ -91,7 +91,7 @@ def test_overfit_batch( limit_train_batches=1, max_epochs=num_steps, ) - trainer.fit(algorithm, datamodule) + trainer.fit(algorithm, dataset) losses_at_each_epoch: list[Tensor] = get_loss_cb.losses assert ( diff --git a/project/configs/__init__.py b/project/configs/__init__.py index 8bf93ed8..31be553b 100644 --- a/project/configs/__init__.py +++ b/project/configs/__init__.py @@ -8,7 +8,7 @@ from project.configs.algorithm.network import network_store from project.configs.algorithm.optimizer import optimizers_store from project.configs.config import Config -from project.configs.datamodule import datamodule_store +from project.configs.dataset import datamodule_store from project.utils.remote_launcher_plugin import RemoteSlurmQueueConf cs = ConfigStore.instance() diff --git a/project/configs/config.py b/project/configs/config.py index bee93e8f..e65adb18 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -21,7 +21,7 @@ class Config: """ algorithm: Any - """Configuration for the algorithm (a + """Configuration for the algorithm to use during training (typically a [LightningModule][lightning.pytorch.core.module.LightningModule]). It is suggested for this class to accept a `datamodule` and `network` as arguments. The @@ -30,15 +30,8 @@ class Config: For more info, see the [instantiate_algorithm][project.experiment.instantiate_algorithm] function. """ - datamodule: Any | None = None - """Configuration for the datamodule (dataset + transforms + dataloader creation). - - This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. - See the [MNISTDataModule][project.datasets.image_classification.mnist.MNISTDataModule] for an example. - """ - - datamodule: Optional[Any] = None # noqa - """Configuration for the datamodule (dataset + transforms + dataloader creation). 
+ dataset: Optional[Any] = None # noqa + """Configuration for the dataset or datamodule (dataset + transforms + dataloader creation). This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. See the [MNISTDataModule][project.datasets.image_classification.mnist.MNISTDataModule] for an example. diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py deleted file mode 100644 index d9b68bc5..00000000 --- a/project/configs/datamodule/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -from logging import getLogger as get_logger - -from hydra_zen import store - -logger = get_logger(__name__) - - -# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields -# (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the -# config). -datamodule_store = store(group="datamodule") - - -# @hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) -# class VisionDataModuleConfig: -# data_dir: str | None = str(torchvision_dir or DATA_DIR) -# val_split: int | float = 0.1 # NOTE: reduced from default of 0.2 -# num_workers: int = NUM_WORKERS -# normalize: bool = True # NOTE: Set to True by default instead of False -# batch_size: int = 32 -# seed: int = 42 -# shuffle: bool = True # NOTE: Set to True by default instead of False. -# pin_memory: bool = True # NOTE: Set to True by default instead of False. 
-# drop_last: bool = False - -# __call__ = instantiate - - -# datamodule_store(VisionDataModuleConfig, name="vision") - -# inaturalist_config = hydra_zen.builds( -# INaturalistDataModule, -# builds_bases=(VisionDataModuleConfig,), -# populate_full_signature=True, -# dataclass_name=f"{INaturalistDataModule.__name__}Config", -# ) -# datamodule_store(inaturalist_config, name="inaturalist") diff --git a/project/configs/datamodule/cifar10.yaml b/project/configs/datamodule/cifar10.yaml deleted file mode 100644 index bd121c5c..00000000 --- a/project/configs/datamodule/cifar10.yaml +++ /dev/null @@ -1,8 +0,0 @@ -defaults: - - vision - - _self_ -_target_: project.datasets.CIFAR10DataModule -data_dir: ${constant:torchvision_dir,DATA_DIR} -batch_size: 128 -train_transforms: - _target_: project.datasets.image_classification.cifar10.cifar10_train_transforms diff --git a/project/configs/datamodule/fashion_mnist.yaml b/project/configs/datamodule/fashion_mnist.yaml deleted file mode 100644 index 0ae43b44..00000000 --- a/project/configs/datamodule/fashion_mnist.yaml +++ /dev/null @@ -1,4 +0,0 @@ -defaults: - - mnist - - _self_ -_target_: project.datasets.FashionMNISTDataModule diff --git a/project/configs/datamodule/glue_cola.yaml b/project/configs/datamodule/glue_cola.yaml deleted file mode 100644 index 35aad0ec..00000000 --- a/project/configs/datamodule/glue_cola.yaml +++ /dev/null @@ -1,19 +0,0 @@ -_target_: project.datasets.text.TextClassificationDataModule -data_dir: ${oc.env:SCRATCH,.}/data -hf_dataset_path: glue -task_name: cola -text_fields: - - "sentence" -tokenizer: - _target_: transformers.models.auto.tokenization_auto.AutoTokenizer.from_pretrained - use_fast: true - # Note: We could interpolate this value with `${/algorithm/network/pretrained_model_name_or_path}` - # to avoid duplicating a value, but this also makes it harder to use this by itself or with - # another algorithm. 
- pretrained_model_name_or_path: albert-base-v2 - cache_dir: ${..data_dir} - trust_remote_code: true -num_classes: 2 -max_seq_length: 128 -train_batch_size: 32 -eval_batch_size: 32 diff --git a/project/configs/datamodule/imagenet.yaml b/project/configs/datamodule/imagenet.yaml deleted file mode 100644 index b71d8a99..00000000 --- a/project/configs/datamodule/imagenet.yaml +++ /dev/null @@ -1,5 +0,0 @@ -defaults: - - vision - - _self_ -_target_: project.datasets.ImageNetDataModule -# todo: add good configuration options here. diff --git a/project/configs/datamodule/inaturalist.yaml b/project/configs/datamodule/inaturalist.yaml deleted file mode 100644 index 5ff33ba2..00000000 --- a/project/configs/datamodule/inaturalist.yaml +++ /dev/null @@ -1,6 +0,0 @@ -defaults: - - vision - - _self_ -_target_: project.datasets.INaturalistDataModule -version: "2021_train" -target_type: "full" diff --git a/project/configs/datamodule/mnist.yaml b/project/configs/datamodule/mnist.yaml deleted file mode 100644 index 4de0c908..00000000 --- a/project/configs/datamodule/mnist.yaml +++ /dev/null @@ -1,9 +0,0 @@ -defaults: - - vision - - _self_ -_target_: project.datasets.MNISTDataModule -data_dir: ${constant:torchvision_dir,DATA_DIR} -normalize: True -batch_size: 128 -train_transforms: - _target_: project.datasets.image_classification.mnist.mnist_train_transforms diff --git a/project/configs/datamodule/vision.yaml b/project/configs/datamodule/vision.yaml deleted file mode 100644 index fcc6e1b4..00000000 --- a/project/configs/datamodule/vision.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# todo: This config should not show up as an option on the command-line. -_target_: project.datasets.VisionDataModule -data_dir: ${constant:DATA_DIR} -num_workers: ${constant:NUM_WORKERS} -val_split: 0.1 # NOTE: reduced from default of 0.2 -normalize: True # NOTE: Set to True by default instead of False -shuffle: True # NOTE: Set to True by default instead of False. 
-pin_memory: True # NOTE: Set to True by default instead of False. -seed: 42 -batch_size: 64 diff --git a/project/configs/trainer/overfit_one_batch.yaml b/project/configs/trainer/overfit_one_batch.yaml index 80d02ae8..16f80dca 100644 --- a/project/configs/trainer/overfit_one_batch.yaml +++ b/project/configs/trainer/overfit_one_batch.yaml @@ -1,5 +1,5 @@ # Note: This configuration should be run in combination with an algorithm. For example like this: -# `python project/main.py algorithm=example datamodule=cifar10 trainer=overfit_one_batch` +# `python project/main.py algorithm=example dataset=cifar10 trainer=overfit_one_batch` # defaults: - default diff --git a/project/conftest.py b/project/conftest.py index e3c51c1d..759447b2 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -14,9 +14,9 @@ The fixtures for command-line arguments -For example, one of the fixtures which is created first is [datamodule_config][project.conftest.datamodule_config]. +For example, one of the fixtures which is created first is [dataset_config][project.conftest.dataset_config]. -The first fixtures to be created are the [datamodule_config][project.conftest.datamodule_config], `network_config` and `algorithm_config`, along with `overrides`. +The first fixtures to be created are the [dataset_config][project.conftest.dataset_config], `network_config` and `algorithm_config`, along with `overrides`. 
From these, the `experiment_dictconfig` is created ```mermaid @@ -24,9 +24,9 @@ title: Fixture dependency graph --- flowchart TD -datamodule_config[ - datamodule_config -] -- 'datamodule=A' --> command_line_arguments +dataset_config[ + dataset_config +] -- 'dataset=A' --> command_line_arguments algorithm_config[ algorithm_config ] -- 'algorithm=B' --> command_line_arguments @@ -35,7 +35,7 @@ ] -- 'seed=123' --> command_line_arguments command_line_arguments[ command_line_arguments -] -- load configs for 'datamodule=A algorithm=B seed=123' --> experiment_dictconfig +] -- load configs for 'dataset=A algorithm=B seed=123' --> experiment_dictconfig experiment_dictconfig[ experiment_dictconfig ] -- instantiate objects from configs --> experiment_config @@ -95,7 +95,7 @@ from project.datasets.vision import VisionDataModule, num_cpus_on_node from project.experiment import ( instantiate_algorithm, - instantiate_datamodule, + instantiate_dataset, instantiate_trainer, setup_logging, ) @@ -195,13 +195,13 @@ def algorithm_config(request: pytest.FixtureRequest) -> str | None: @pytest.fixture(scope="session") -def datamodule_config(request: pytest.FixtureRequest) -> str | None: - """The datamodule config to use in the experiment, as if `datamodule=` was passed.""" +def dataset_config(request: pytest.FixtureRequest) -> str | None: + """The dataset config to use in the experiment, as if `dataset=` was passed.""" - datamodule_config_name = getattr(request, "param", None) - if datamodule_config_name: - _add_default_marks_for_config_name(datamodule_config_name, request) - return datamodule_config_name + dataset_config_name = getattr(request, "param", None) + if dataset_config_name: + _add_default_marks_for_config_name(dataset_config_name, request) + return dataset_config_name @pytest.fixture(scope="session") @@ -216,7 +216,7 @@ def algorithm_network_config(request: pytest.FixtureRequest) -> str | None: @pytest.fixture(scope="session") def command_line_arguments( algorithm_config: str 
| None, - datamodule_config: str | None, + dataset_config: str | None, algorithm_network_config: str | None, command_line_overrides: tuple[str, ...], request: pytest.FixtureRequest, @@ -224,7 +224,7 @@ def command_line_arguments( """Fixture that returns the command-line arguments that will be passed to Hydra to run the experiment. - The `algorithm_config`, `network_config` and `datamodule_config` values here are parametrized + The `algorithm_config`, `network_config` and `dataset_config` values here are parametrized indirectly by most tests using the [`project.utils.testutils.run_for_all_configs_of_type`][] function so that the respective components are created in the same way as they would be by Hydra in a regular run. @@ -238,7 +238,7 @@ def command_line_arguments( assert isinstance(param, list | tuple) return tuple(param) - combination = set([datamodule_config, algorithm_network_config, algorithm_config]) + combination = set([dataset_config, algorithm_network_config, algorithm_config]) for configs, marks in default_marks_for_config_combinations.items(): marks = [marks] if not isinstance(marks, list | tuple) else marks configs = set(configs) @@ -262,8 +262,8 @@ def command_line_arguments( default_overrides.append(f"algorithm={algorithm_config}") if algorithm_network_config: default_overrides.append(f"algorithm/network={algorithm_network_config}") - if datamodule_config: - default_overrides.append(f"datamodule={datamodule_config}") + if dataset_config: + default_overrides.append(f"dataset={dataset_config}") all_overrides = default_overrides + list(command_line_overrides) return all_overrides @@ -316,23 +316,23 @@ def experiment_config( @pytest.fixture(scope="session") -def datamodule(experiment_dictconfig: DictConfig) -> lightning.LightningDataModule | None: - """Fixture that creates the datamodule for the given config.""" +def dataset(experiment_dictconfig: DictConfig) -> lightning.LightningDataModule | None: + """Fixture that creates the dataset or datamodule 
for the given config.""" # NOTE: creating the datamodule by itself instead of with everything else. - return instantiate_datamodule(experiment_dictconfig["datamodule"]) + return instantiate_dataset(experiment_dictconfig["dataset"]) @pytest.fixture(scope="function") def algorithm( experiment_config: Config, - datamodule: lightning.LightningDataModule | None, + dataset: lightning.LightningDataModule | None, trainer: lightning.Trainer | JaxTrainer, seed: int, device: torch.device, ): """Fixture that creates the "algorithm" (a [LightningModule][lightning.pytorch.core.module.LightningModule]).""" - algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + algorithm = instantiate_algorithm(experiment_config.algorithm, dataset=dataset) if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule): with trainer.init_module(), device: # A bit hacky, but we have to do this because the lightningmodule isn't associated @@ -353,19 +353,19 @@ def trainer( @pytest.fixture(scope="session") def train_dataloader( - datamodule: lightning.LightningDataModule | None, request: pytest.FixtureRequest + dataset: lightning.LightningDataModule | None, request: pytest.FixtureRequest ) -> DataLoader: - if isinstance(datamodule, VisionDataModule) or hasattr(datamodule, "num_workers"): - datamodule.num_workers = 0 # type: ignore - if datamodule is None: + if isinstance(dataset, VisionDataModule) or hasattr(dataset, "num_workers"): + dataset.num_workers = 0 # type: ignore + if dataset is None: raise NotImplementedError( "This test is trying to use `train_dataloader` directly or indirectly but the " "algorithm that is being tested does not use a datamodule (or the test was not " "configured properly)! Consider overwriting this fixture in your test class." 
) - datamodule.prepare_data() - datamodule.setup("fit") - train_dataloader = datamodule.train_dataloader() + dataset.prepare_data() + dataset.setup("fit") + train_dataloader = dataset.train_dataloader() assert isinstance(train_dataloader, DataLoader) return train_dataloader diff --git a/project/datasets/datasets_test.py b/project/datasets/datasets_test.py index 7b75e13a..a3f29c52 100644 --- a/project/datasets/datasets_test.py +++ b/project/datasets/datasets_test.py @@ -38,7 +38,7 @@ @pytest.mark.parametrize(command_line_overrides.__name__, ["algorithm=no_op"], indirect=True) @run_for_all_configs_in_group(group_name="datamodule") def test_first_batch( - datamodule: LightningDataModule, + dataset: LightningDataModule, request: pytest.FixtureRequest, tensor_regression: TensorRegressionFixture, original_datadir: Path, @@ -48,36 +48,36 @@ def test_first_batch( """Test that the first batch of the dataloader is reproducible (the same for the same seed).""" # Note: using dataloader workers in tests can cause issues, since if a test fails, dataloader # workers aren't always cleaned up properly. 
- if isinstance(datamodule, VisionDataModule) or hasattr(datamodule, "num_workers"): - datamodule.num_workers = 0 # type: ignore + if isinstance(dataset, VisionDataModule) or hasattr(dataset, "num_workers"): + dataset.num_workers = 0 # type: ignore - datamodule.prepare_data() + dataset.prepare_data() if stage == RunningStage.TRAINING: - datamodule.setup("fit") - dataloader = datamodule.train_dataloader() + dataset.setup("fit") + dataloader = dataset.train_dataloader() elif stage in [RunningStage.VALIDATING, RunningStage.SANITY_CHECKING]: - datamodule.setup("validate") - dataloader = datamodule.val_dataloader() + dataset.setup("validate") + dataloader = dataset.val_dataloader() elif stage == RunningStage.TESTING: - datamodule.setup("test") - dataloader = datamodule.test_dataloader() + dataset.setup("test") + dataloader = dataset.test_dataloader() else: assert stage == RunningStage.PREDICTING - datamodule.setup("predict") - dataloader = datamodule.predict_dataloader() + dataset.setup("predict") + dataloader = dataset.predict_dataloader() iterator = iter(dataloader) batch = next(iterator) from torchvision.tv_tensors import Image - if isinstance(datamodule, ImageClassificationDataModule): + if isinstance(dataset, ImageClassificationDataModule): assert isinstance(batch, list | tuple) and len(batch) == 2 # todo: if we tighten this and make it so vision datamodules return Images, then we should # have strict asserts here that check that batch[0] is an Image. It doesn't seem to be the case though. 
# assert isinstance(batch[0], Image) assert isinstance(batch[0], torch.Tensor) assert isinstance(batch[1], torch.Tensor) - elif isinstance(datamodule, VisionDataModule): + elif isinstance(dataset, VisionDataModule): if isinstance(batch, list | tuple): # assert isinstance(batch[0], Image) assert isinstance(batch[0], torch.Tensor) @@ -135,7 +135,7 @@ def test_first_batch( RunningStage.PREDICTING: "prediction(?)", } - fig.suptitle(f"First {split[stage]} batch of datamodule {type(datamodule).__name__}") + fig.suptitle(f"First {split[stage]} batch of datamodule {type(dataset).__name__}") figure_path, _ = get_test_source_and_temp_file_paths( extension=".png", request=request, diff --git a/project/datasets/image_classification/fashion_mnist.py b/project/datasets/image_classification/fashion_mnist.py index 45b379c6..9f1f9038 100644 --- a/project/datasets/image_classification/fashion_mnist.py +++ b/project/datasets/image_classification/fashion_mnist.py @@ -30,7 +30,7 @@ class FashionMNISTDataModule(MNISTDataModule): dm = FashionMNISTDataModule('.') model = LitModel() - Trainer().fit(model, datamodule=dm) + Trainer().fit(model, datamodule=dm) """ name = "fashion_mnist" diff --git a/project/datasets/text/text_classification_test.py b/project/datasets/text/text_classification_test.py index 633dec05..efeb4bd0 100644 --- a/project/datasets/text/text_classification_test.py +++ b/project/datasets/text/text_classification_test.py @@ -5,7 +5,7 @@ import pytest from project.datasets.text.text_classification import TextClassificationDataModule -from project.experiment import instantiate_datamodule +from project.experiment import instantiate_dataset from project.utils.testutils import get_config_loader datamodule_configs = ["glue_cola"] @@ -29,7 +29,7 @@ def datamodule( run_mode=RunMode.RUN, ) datamodule_config = config["datamodule"] - datamodule = instantiate_datamodule(datamodule_config) + datamodule = instantiate_dataset(datamodule_config) assert datamodule is not None return
datamodule diff --git a/project/experiment.py b/project/experiment.py index 8b9e4cc8..9ab22d9b 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -100,32 +100,32 @@ def instantiate_trainer(experiment_config: Config) -> Trainer | JaxTrainer: return trainer -def instantiate_datamodule( - datamodule_config: Builds[type[LightningDataModule]] | LightningDataModule | None, +def instantiate_dataset( + dataset_config: Builds[type[LightningDataModule]] | LightningDataModule | None, ) -> LightningDataModule | None: - """Instantiate the datamodule from the configuration dict. + """Instantiate the dataset/datamodule from the configuration dict. Any interpolations in the config will have already been resolved by the time we get here. """ - if not datamodule_config: + if not dataset_config: return None import lightning - if isinstance(datamodule_config, lightning.LightningDataModule): + if isinstance(dataset_config, lightning.LightningDataModule): logger.info( f"Datamodule was already instantiated (probably to interpolate a field value). " - f"{datamodule_config=}" + f"{dataset_config=}" ) - datamodule = datamodule_config + datamodule = dataset_config else: - logger.debug(f"Instantiating datamodule from config: {datamodule_config}") - datamodule = instantiate(datamodule_config) + logger.debug(f"Instantiating dataset from config: {dataset_config}") + datamodule = instantiate(dataset_config) return datamodule def instantiate_algorithm( - algorithm_config: Config, datamodule: LightningDataModule | None + algorithm_config: Config, dataset: LightningDataModule | None ) -> LightningModule | JaxModule: """Function used to instantiate the algorithm. 
@@ -148,14 +148,14 @@ def instantiate_algorithm( ) return algo_config - if datamodule: - algo_or_algo_partial = hydra.utils.instantiate(algo_config, datamodule=datamodule) + if dataset: + algo_or_algo_partial = hydra.utils.instantiate(algo_config, datamodule=dataset) else: algo_or_algo_partial = hydra.utils.instantiate(algo_config) if isinstance(algo_or_algo_partial, functools.partial): - if datamodule: - algorithm = algo_or_algo_partial(datamodule=datamodule) + if dataset: + algorithm = algo_or_algo_partial(datamodule=dataset) else: algorithm = algo_or_algo_partial() else: diff --git a/project/main.py b/project/main.py index 6c715159..bc2629a2 100644 --- a/project/main.py +++ b/project/main.py @@ -34,7 +34,7 @@ from project.configs.config import Config from project.experiment import ( instantiate_algorithm, - instantiate_datamodule, + instantiate_dataset, setup_logging, ) from project.trainers.jax_trainer import JaxModule, JaxTrainer, Ts, _MetricsT @@ -104,11 +104,11 @@ def main(dict_config: DictConfig) -> dict: ) # Create the datamodule (if present) - datamodule: lightning.LightningDataModule | None = instantiate_datamodule(config.datamodule) + dataset: lightning.LightningDataModule | None = instantiate_dataset(config.dataset) # Create the "algorithm" algorithm: lightning.LightningModule | JaxModule = instantiate_algorithm( - config.algorithm, datamodule=datamodule + config.algorithm, dataset=dataset ) if wandb.run: @@ -118,15 +118,13 @@ def main(dict_config: DictConfig) -> dict: ) # Train the algorithm. - train_results = train( - config=config, trainer=trainer, datamodule=datamodule, algorithm=algorithm - ) + train_results = train(config=config, trainer=trainer, dataset=dataset, algorithm=algorithm) # Evaluate the algorithm. 
if isinstance(trainer, lightning.Trainer): assert isinstance(algorithm, lightning.LightningModule) metric_name, error, _metrics = evaluate_lightningmodule( - algorithm, datamodule=datamodule, trainer=trainer + algorithm, dataset=dataset, trainer=trainer ) else: assert isinstance(trainer, JaxTrainer) @@ -146,7 +144,7 @@ def main(dict_config: DictConfig) -> dict: def train( config: Config, trainer: lightning.Trainer | JaxTrainer, - datamodule: lightning.LightningDataModule | None, + dataset: lightning.LightningDataModule | None, algorithm: lightning.LightningModule | JaxModule, ): if isinstance(trainer, lightning.Trainer): @@ -155,14 +153,14 @@ def train( # The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for # example in RL, where we need to set the actor to use in the environment, as well as # potentially adding Wrappers on top of the environment, or having a replay buffer, etc. - datamodule = getattr(algorithm, "datamodule", datamodule) + dataset = getattr(algorithm, "datamodule", dataset) return trainer.fit( algorithm, - datamodule=datamodule, + datamodule=dataset, ckpt_path=config.ckpt_path, ) - if datamodule is not None: + if dataset is not None: raise NotImplementedError( "The JaxTrainer doesn't yet support using a datamodule. For now, you should " f"return a batch of data from the {JaxModule.get_batch.__name__} method in your " @@ -210,7 +208,7 @@ def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None: def evaluate_lightningmodule( algorithm: lightning.LightningModule, trainer: lightning.Trainer, - datamodule: lightning.LightningDataModule | None, + dataset: lightning.LightningDataModule | None, ) -> tuple[MetricName, float | None, dict]: """Evaluates the algorithm and returns the metrics. 
@@ -235,11 +233,11 @@ def evaluate_lightningmodule( ] elif trainer.limit_val_batches != 0: results_type = "val" - results = trainer.validate(model=algorithm, datamodule=datamodule) + results = trainer.validate(model=algorithm, datamodule=dataset) else: warnings.warn(RuntimeWarning("About to use the test set for evaluation!")) results_type = "test" - results = trainer.test(model=algorithm, datamodule=datamodule) + results = trainer.test(model=algorithm, datamodule=dataset) if results is None: rich.print("RUN FAILED!") diff --git a/project/main_test.py b/project/main_test.py index 9c2f3a0b..b9f06e62 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -109,14 +109,14 @@ def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch): ), pytest.param( "experiment=profiling " - "datamodule=cifar10 " # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..) + "dataset=cifar10 " # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..) "trainer/logger=tensorboard " # Use Tensorboard logger because DeviceStatsMonitor requires a logger being used. "trainer.fast_dev_run=True ", # make each job quicker to run marks=pytest.mark.slow, ), ( "experiment=profiling algorithm=no_op " - "datamodule=cifar10 " # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..) + "dataset=cifar10 " # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..) "trainer/logger=tensorboard " # Use Tensorboard logger because DeviceStatsMonitor requires a logger being used. 
"trainer.fast_dev_run=True " # make each job quicker to run ), @@ -219,7 +219,7 @@ def test_setting_just_algorithm_isnt_enough(experiment_dictconfig: DictConfig) - @pytest.mark.parametrize( command_line_overrides.__name__, [ - "algorithm=image_classifier datamodule=cifar10 seed=1 trainer/callbacks=none trainer.fast_dev_run=True" + "algorithm=image_classifier dataset=cifar10 seed=1 trainer/callbacks=none trainer.fast_dev_run=True" ], indirect=True, ) diff --git a/project/utils/remote_launcher_plugin_test.py b/project/utils/remote_launcher_plugin_test.py index d30f8b08..3ef893d4 100644 --- a/project/utils/remote_launcher_plugin_test.py +++ b/project/utils/remote_launcher_plugin_test.py @@ -40,7 +40,7 @@ def _yaml_files_in(directory: str | Path, recursive: bool = False): "command_line_args", [ pytest.param( - f"algorithm=image_classifier datamodule=cifar10 trainer.fast_dev_run=True cluster={cluster} resources={resources}", + f"algorithm=image_classifier dataset=cifar10 trainer.fast_dev_run=True cluster={cluster} resources={resources}", marks=[ pytest.mark.skipif( SLURM_JOB_ID is None and cluster == "current", @@ -110,7 +110,7 @@ def test_can_load_configs(command_line_args: str): [ [ "algorithm=image_classifier", - "datamodule=cifar10", + "dataset=cifar10", # TODO: The ordering is important here, we can't use `cluster` before `resources`, # otherwise it will use the local launcher! "resources=gpu",