Skip to content

Commit

Permalink
ensure no breaking change in kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
anton-emanuel committed Oct 8, 2024
1 parent 8257b8e commit b5291a6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
8 changes: 4 additions & 4 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
RandomBatchGeoSampler,
)
from ..transforms import AugmentationSequential
from .utils import MisconfigurationException, get_prefixed_kwargs
from .utils import MisconfigurationException, split_kwargs


class BaseDataModule(LightningDataModule):
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
self.dataset_class = dataset_class
self.batch_size = batch_size
self.num_workers = num_workers
self.kwargs = kwargs
self.dataloader_kwargs, self.kwargs = split_kwargs('dataloader_', **kwargs)

# Datasets
self.dataset: Dataset[dict[str, Tensor]] | None = None
Expand Down Expand Up @@ -287,7 +287,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
**get_prefixed_kwargs('dataloader_', **self.kwargs),
**self.dataloader_kwargs,
)

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
Expand Down Expand Up @@ -432,7 +432,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
**get_prefixed_kwargs('dataloader_', **self.kwargs),
**self.dataloader_kwargs,
)

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
Expand Down
23 changes: 18 additions & 5 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,27 @@ def group_shuffle_split(
return train_idxs, test_idxs


def get_prefixed_kwargs(prefix: str, **kwargs: Any) -> dict[str, Any]:
"""Get kwargs with a specific prefix.
def split_kwargs(*prefixes: str, **kwargs: Any) -> tuple[dict[str, Any], ...]:
"""Split kwargs into prefixed and other kwargs.
Args:
prefix: Prefix to filter kwargs by.
*prefixes: Prefixes to filter kwargs by.
**kwargs: Keyword arguments to filter.
Returns:
Dictionary of kwargs with the specified prefix.
Tuple of prefixed kwargs and other kwargs.
"""
return {k.replace(prefix, ''): v for k, v in kwargs.items() if k.startswith(prefix)}
prefixed_kwargs: list[dict[str, Any]] = [{} for _ in prefixes]
other_kwargs: dict[str, Any] = {}

for key, value in kwargs.items():
matched = False
for i, prefix in enumerate(prefixes):
if key.startswith(prefix):
prefixed_kwargs[i][key[len(prefix) :]] = value
matched = True
break
if not matched:
other_kwargs[key] = value

return *prefixed_kwargs, other_kwargs

0 comments on commit b5291a6

Please sign in to comment.