
Commit

root_dir -> root
adamjstewart committed Sep 27, 2022
1 parent 4aea435 commit b89fc41
Showing 3 changed files with 6 additions and 6 deletions.
README.md (1 addition, 1 deletion)
@@ -123,7 +123,7 @@ All TorchGeo datasets are compatible with PyTorch data loaders, making them easy
In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created PyTorch Lightning [*datamodules*](https://torchgeo.readthedocs.io/en/stable/api/datamodules.html) with well-defined train-val-test splits and [*trainers*](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the [Inria Aerial Image Labeling](https://project.inria.fr/aerialimagelabeling/) dataset is as easy as a few imports and four lines of code.

```python
- datamodule = InriaAerialImageLabelingDataModule(root_dir="...", batch_size=64, num_workers=6)
+ datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")

```
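For context, the README example this hunk edits is four lines long; a minimal sketch of the complete post-rename snippet follows. The imports and the final `trainer.fit` call are assumptions filled in for illustration, since the diff context above is truncated before them.

```python
# Minimal sketch of the four-line example described in the README paragraph above,
# using the renamed `root` argument. The imports and the trainer.fit(...) line are
# assumed for illustration; they are not shown in the diff context.
from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask

datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(segmentation_model="unet", encoder_weights="imagenet", learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)
```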
evaluate.py (3 additions, 3 deletions)
@@ -44,7 +44,7 @@ def set_up_parser() -> argparse.ArgumentParser:
"--gpu", default=0, type=int, help="GPU ID to use", metavar="ID"
)
parser.add_argument(
"--root-dir",
"--root",
required=True,
type=str,
help="root directory of the dataset for the accompanying task",
@@ -123,7 +123,7 @@ def main(args: argparse.Namespace) -> None:
args: command-line arguments
"""
assert os.path.exists(args.input_checkpoint)
- assert os.path.exists(args.root_dir)
+ assert os.path.exists(args.root)
TASK = TASK_TO_MODULES_MAPPING[args.task][0]
DATAMODULE = TASK_TO_MODULES_MAPPING[args.task][1]

@@ -135,7 +135,7 @@ def main(args: argparse.Namespace) -> None:

dm = DATAMODULE( # type: ignore[call-arg]
seed=args.seed,
- root_dir=args.root_dir,
+ root=args.root,
num_workers=args.num_workers,
batch_size=args.batch_size,
)
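To show how the renamed flag threads from the parser into the datamodule, here is a small, self-contained sketch; `DummyDataModule` and the default values are invented stand-ins for illustration, not part of evaluate.py.

```python
# Hypothetical, self-contained sketch of the --root wiring shown above.
# DummyDataModule is an invented stand-in for the task-specific datamodule
# that the real script looks up in TASK_TO_MODULES_MAPPING.
import argparse
import os


class DummyDataModule:
    def __init__(self, seed: int, root: str, num_workers: int, batch_size: int) -> None:
        self.seed = seed
        self.root = root
        self.num_workers = num_workers
        self.batch_size = batch_size


parser = argparse.ArgumentParser()
parser.add_argument("--root", required=True, type=str, help="root directory of the dataset")
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--num-workers", default=6, type=int)
parser.add_argument("--batch-size", default=64, type=int)
args = parser.parse_args()

assert os.path.exists(args.root)
dm = DummyDataModule(
    seed=args.seed,
    root=args.root,
    num_workers=args.num_workers,
    batch_size=args.batch_size,
)
```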
tests/test_train.py (2 additions, 2 deletions)
@@ -69,7 +69,7 @@ def test_overwrite_experiment_dir(tmp_path: Path) -> None:
"program.data_dir=" + data_dir,
"program.log_dir=" + str(log_dir),
"experiment.task=cyclone",
"experiment.datamodule.root_dir=" + data_dir,
"experiment.datamodule.root=" + data_dir,
"program.overwrite=True",
"trainer.fast_dev_run=1",
"trainer.gpus=0",
@@ -126,7 +126,7 @@ def test_config_file(tmp_path: Path) -> None:
name: test
task: cyclone
datamodule:
- root_dir: {data_dir}
+ root: {data_dir}
trainer:
fast_dev_run: true
gpus: 0
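As a sanity check on the renamed config key, the sketch below shows how a YAML config and a dot-list override of `experiment.datamodule.root` (as used in the test above) would merge with OmegaConf. This is an illustration only, not the project's actual config-handling code, and the paths are placeholders.

```python
# Illustration only: merging a YAML config with a command-line style override,
# both using the renamed `root` key. Not the project's actual config-handling code.
from omegaconf import OmegaConf

yaml_str = """\
experiment:
  task: cyclone
  datamodule:
    root: /old/path
"""
yaml_conf = OmegaConf.create(yaml_str)
override = OmegaConf.from_dotlist(["experiment.datamodule.root=/new/path"])
conf = OmegaConf.merge(yaml_conf, override)
print(conf.experiment.datamodule.root)  # prints: /new/path
```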

