Skip to content

Commit

Permalink
Fix OSCD tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Jan 22, 2023
1 parent b1a9ae2 commit 98c690d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/datamodules/test_oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
batch = next(iter(datamodule.train_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)
assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
assert batch["image"].shape[0] == batch["mask"].shape[0] == 2
assert batch["image"].shape[0] == batch["mask"].shape[0] == 1
if datamodule.bands == "all":
assert batch["image"].shape[1] == 26
else:
Expand All @@ -47,7 +47,7 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
batch = datamodule.on_after_batch_transfer(batch, 0)
if datamodule.val_split_pct > 0.0:
assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
assert batch["image"].shape[0] == batch["mask"].shape[0] == 2
assert batch["image"].shape[0] == batch["mask"].shape[0] == 1
if datamodule.bands == "all":
assert batch["image"].shape[1] == 26
else:
Expand All @@ -59,7 +59,7 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
batch = next(iter(datamodule.test_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)
assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
assert batch["image"].shape[0] == batch["mask"].shape[0] == 2
assert batch["image"].shape[0] == batch["mask"].shape[0] == 1
if datamodule.bands == "all":
assert batch["image"].shape[1] == 26
else:
Expand Down

0 comments on commit 98c690d

Please sign in to comment.