From 036a3add445daf51137dd586018821b40aeddaff Mon Sep 17 00:00:00 2001 From: tony-kuo Date: Wed, 20 Mar 2024 12:36:13 -0700 Subject: [PATCH] update training path saving --- src/scimilarity/training_models.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/scimilarity/training_models.py b/src/scimilarity/training_models.py index 822e2eb..3e1c0f2 100644 --- a/src/scimilarity/training_models.py +++ b/src/scimilarity/training_models.py @@ -601,8 +601,20 @@ def save_all( meta_data = { "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } - meta_data["train_path"] = self.trainer.datamodule.train_path - meta_data["val_path"] = self.trainer.datamodule.val_path + if "train_path" in dir(self.trainer.datamodule): + meta_data["train_path"] = self.trainer.datamodule.train_path + meta_data["val_path"] = self.trainer.datamodule.val_path + elif "cell_tdb_uri" in dir(self.trainer.datamodule): + meta_data["cell_tdb_uri"] = self.trainer.datamodule.cell_tdb_uri + meta_data["counts_tdb_uri"] = self.trainer.datamodule.counts_tdb_uri + meta_data["gene_tdb_uri"] = self.trainer.datamodule.gene_tdb_uri + self.trainer.datamodule.data_df.to_csv( + os.path.join(model_path, "train_cells.csv") + ) + if self.trainer.datamodule.val_df is not None: + self.trainer.datamodule.val_df.to_csv( + os.path.join(model_path, "val_cells.csv") + ) with open(os.path.join(model_path, "metadata.json"), "w") as f: f.write(json.dumps(meta_data))