use the validation set as the test set during training
tony-kuo committed Mar 19, 2024
1 parent: 7f0ff96 · commit: c4270a8
Showing 2 changed files with 5 additions and 55 deletions.
src/scimilarity/training_models.py (4 additions & 5 deletions)
@@ -431,14 +431,14 @@ def test_step(self, batch, batch_idx):
             A batch index as defined by a pytorch-lightning.
         """
 
-        if self.trainer.datamodule.test_dataset is None:
+        if self.trainer.datamodule.val_dataset is None:
             return {}
         return self._eval_step(batch, prefix="test")
 
     def on_test_epoch_end(self):
         """Pytorch-lightning test epoch end evaluation."""
 
-        if self.trainer.datamodule.test_dataset is None:
+        if self.trainer.datamodule.val_dataset is None:
             return {}
         return self._eval_epoch(prefix="test")
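
Note: the test hooks are now gated on val_dataset, so the test loop runs against validation data whenever a validation set exists. A minimal self-contained sketch of the pattern; the class name and the stub bodies below are illustrative stand-ins, only the guard and the _eval_step/_eval_epoch call shape come from this diff:

import pytorch_lightning as pl

class Model(pl.LightningModule):  # hypothetical minimal module for illustration
    def _eval_step(self, batch, prefix: str):
        # Stand-in for the repository's shared per-batch evaluation.
        return {}

    def _eval_epoch(self, prefix: str):
        # Stand-in for the repository's shared epoch-level evaluation.
        return {}

    def test_step(self, batch, batch_idx):
        # Guard on val_dataset: metrics are still logged under the "test" prefix.
        if self.trainer.datamodule.val_dataset is None:
            return {}
        return self._eval_step(batch, prefix="test")

    def on_test_epoch_end(self):
        if self.trainer.datamodule.val_dataset is None:
            return {}
        return self._eval_epoch(prefix="test")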

@@ -599,11 +599,10 @@ def save_all(
 
         # write metadata: data paths, timestamp
         meta_data = {
-            "train_path": self.trainer.datamodule.train_path,
-            "val_path": self.trainer.datamodule.val_path,
-            "test_path": self.trainer.datamodule.test_path,
             "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
         }
+        meta_data["train_path"] = self.trainer.datamodule.train_path
+        meta_data["val_path"] = self.trainer.datamodule.val_path
         with open(os.path.join(model_path, "metadata.json"), "w") as f:
             f.write(json.dumps(meta_data))
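
For reference, metadata.json written by save_all now records only the train and validation paths alongside the timestamp. With placeholder values it looks like:

{"date": "2024-03-19 12:00:00", "train_path": "/data/train", "val_path": "/data/val"}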

src/scimilarity/zarr_data_models.py (1 addition & 50 deletions)
@@ -62,7 +62,6 @@ def __init__(
         train_path: str,
         gene_order: str,
         val_path: Optional[str] = None,
-        test_path: Optional[str] = None,
         obs_field: str = "celltype_name",
         batch_size: int = 1000,
         num_workers: int = 1,
@@ -82,8 +81,6 @@ def __init__(
             after preprocessing.
         val_path: str, optional, default: None
             Path to folder containing all validation datasets.
-        test_path: str, optional, default: None
-            Path to folder containing all test datasets.
         obs_field: str, default: "celltype_name"
             The obs key name containing celltype labels.
         batch_size: int, default: 1000
@@ -105,7 +102,6 @@ def __init__(
         super().__init__()
         self.train_path = train_path
         self.val_path = val_path
-        self.test_path = test_path
         self.batch_size = batch_size
         self.num_workers = num_workers
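
Downstream code simply stops passing a test_path argument. A hypothetical construction call; the class name and paths are placeholders, not taken from this diff:

datamodule = ZarrDataModule(        # class name assumed for illustration
    train_path="/data/train",
    gene_order="/data/gene_order.tsv",
    val_path="/data/val",           # doubles as the test set after this commit
    obs_field="celltype_name",
    batch_size=1000,
    num_workers=4,
)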

@@ -181,41 +177,6 @@ def __init__(
             # Lazy load val data from list of zarr datasets
             self.val_dataset = scDatasetFromList(val_data_list)
 
-        self.test_dataset = None
-        if self.test_path is not None:
-            test_data_list = []
-            self.test_Y = []
-            self.test_study = []
-
-            if self.test_path[-1] != os.sep:
-                self.test_path += os.sep
-
-            self.test_file_list = [
-                (
-                    root.replace(self.test_path, "").split(os.sep)[0],
-                    dirs[0].replace(".aligned.zarr", ""),
-                )
-                for root, dirs, files in os.walk(self.test_path)
-                if dirs and dirs[0].endswith(".aligned.zarr")
-            ]
-
-            for study, sample in tqdm(self.test_file_list):
-                data_path = os.path.join(
-                    self.test_path, study, sample, sample + ".aligned.zarr"
-                )
-                if os.path.isdir(data_path):
-                    zarr_data = ZarrDataset(data_path)
-                    test_data_list.append(zarr_data)
-                    self.test_Y.extend(
-                        zarr_data.get_obs(obs_field).astype(str).tolist()
-                    )
-                    self.test_study.extend(
-                        zarr_data.get_obs("study").astype(str).tolist()
-                    )
-
-            # Lazy load test data from list of zarr datasets
-            self.test_dataset = scDatasetFromList(test_data_list)
 
     def two_way_weighting(self, vec1: list, vec2: list) -> dict:
         """Two-way weighting.
@@ -341,14 +302,4 @@ def test_dataloader(self) -> DataLoader:
             A DataLoader object containing the test dataset.
         """
 
-        if self.test_dataset is None:
-            return None
-        return DataLoader(
-            self.test_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=True,
-            drop_last=True,
-            sampler=self.get_sampler_weights(self.test_Y, self.test_study),
-            collate_fn=self.collate,
-        )
+        return self.val_dataloader()
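
Since test_dataloader now delegates to val_dataloader, a standard Lightning test pass scores the validation datasets. A hypothetical end-to-end sketch; model, datamodule, and settings are placeholders defined elsewhere:

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=500)        # placeholder settings
trainer.fit(model, datamodule=datamodule)   # model/datamodule as constructed above
trainer.test(model, datamodule=datamodule)  # "test" metrics are computed on val data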
