From 94583184ac47b85e90aba3f27f6c5875932ddf9a Mon Sep 17 00:00:00 2001 From: RemyLau Date: Wed, 8 Mar 2023 08:37:19 -0500 Subject: [PATCH] use dance logger --- dance/datasets/multimodality.py | 22 +++++++++---------- .../multi_modality/predict_modality/babel.py | 4 ++-- dance/transforms/graph/scmogcn_graph.py | 9 ++++---- .../multi_modality/predict_modality/babel.py | 6 ++--- 4 files changed, 19 insertions(+), 22 deletions(-) diff --git a/dance/datasets/multimodality.py b/dance/datasets/multimodality.py index fd7b09c2..9770edaa 100644 --- a/dance/datasets/multimodality.py +++ b/dance/datasets/multimodality.py @@ -1,4 +1,3 @@ -import logging import os import pickle @@ -7,6 +6,7 @@ import scanpy as sc import torch +from dance import logger from dance.transforms.preprocess import lsiTransformer from dance.utils.download import download_file, unzip_file @@ -183,7 +183,7 @@ def __init__(self, subtask, data_dir="./data"): def preprocess(self, kind='feature_selection', selection_threshold=10000): if kind == 'pca': - logging.info('Preprocessing method not supported.') + logger.info('Preprocessing method not supported.') return self elif kind == 'feature_selection': if self.modalities[0].shape[1] > selection_threshold: @@ -193,9 +193,9 @@ def preprocess(self, kind='feature_selection', selection_threshold=10000): for i in [0, 2]: self.modalities[i] = self.modalities[i][:, self.modalities[i].var['highly_variable']] else: - logging.info('Preprocessing method not supported.') + logger.info('Preprocessing method not supported.') return self - logging.info('Preprocessing done.') + logger.info('Preprocessing done.') return self @@ -294,9 +294,9 @@ def preprocess(self, kind='pca', pkl_path=None, selection_threshold=10000): self.modalities[i] = self.modalities[i][:, self.modalities[i].var['highly_variable']] self.modalities[i + 2] = self.modalities[i + 2][:, self.modalities[i + 2].var['highly_variable']] else: - logging.info('Preprocessing method not supported.') + logger.info('Preprocessing method not supported.') return self - logging.info('Preprocessing done.') + logger.info('Preprocessing done.') self.preprocessed = True return self @@ -373,7 +373,7 @@ def preprocess(self, kind='aux', pretrained_folder='.', selection_threshold=1000 # cell types, batch labels, cell cycle self.nb_cell_types, self.nb_batches, self.nb_phases = pickle.load(f) self.preprocessed = True - logging.info('Preprocessing done.') + logger.info('Preprocessing done.') return self ########################################## @@ -466,8 +466,8 @@ def preprocess(self, kind='aux', pretrained_folder='.', selection_threshold=1000 'AURKA', 'PSRC1', 'ANLN', 'LBR', 'CKAP5', 'CENPE', 'CTCF', \ 'NEK2', 'G2E3', 'GAS2L3', 'CBX5', 'CENPA'] - logging.info('Data loading and pca done', mod1_pca.shape, mod2_pca.shape) - logging.info('Start to calculate cell_cycle score. It may roughly take an hour.') + logger.info('Data loading and pca done', mod1_pca.shape, mod2_pca.shape) + logger.info('Start to calculate cell_cycle score. It may roughly take an hour.') cell_type_labels = self.test_sol.obs['cell_type'].to_numpy() #mod1_obs['cell_type'] batch_ids = mod1_obs['batch'] @@ -523,10 +523,10 @@ def preprocess(self, kind='aux', pretrained_folder='.', selection_threshold=1000 n_top_genes=selection_threshold) self.modalities[i] = self.modalities[i][:, self.modalities[i].var['highly_variable']] else: - logging.info('Preprocessing method not supported.') + logger.info('Preprocessing method not supported.') return self self.preprocessed = True - logging.info('Preprocessing done.') + logger.info('Preprocessing done.') return self def get_preprocessed_data(self): diff --git a/dance/modules/multi_modality/predict_modality/babel.py b/dance/modules/multi_modality/predict_modality/babel.py index 919fd994..e7cd0511 100644 --- a/dance/modules/multi_modality/predict_modality/babel.py +++ b/dance/modules/multi_modality/predict_modality/babel.py @@ -8,7 +8,6 @@ multiomic profiles at single-cell resolution." Proceedings of the National Academy of Sciences 118, no. 15 (2021). """ -import logging import math from typing import Callable, List, Tuple, Union @@ -18,6 +17,7 @@ from torch.utils.data import DataLoader import dance.utils.loss as loss_functions +from dance import logger REDUCE_LR_ON_PLATEAU_PARAMS = { "mode": "min", @@ -317,7 +317,7 @@ def __init__( self.final_activations["act1"] = final_activations else: raise ValueError(f"Unrecognized type for final_activation: {type(final_activations)}") - logging.info(f"ChromDecoder with {len(self.final_activations)} output activations") + logger.info(f"ChromDecoder with {len(self.final_activations)} output activations") self.final_decoders = nn.ModuleList() # List[List[Module]] for n in self.num_outputs: diff --git a/dance/transforms/graph/scmogcn_graph.py b/dance/transforms/graph/scmogcn_graph.py index 9609a9f9..d17ec986 100644 --- a/dance/transforms/graph/scmogcn_graph.py +++ b/dance/transforms/graph/scmogcn_graph.py @@ -1,4 +1,3 @@ -import logging import os import pickle from collections import defaultdict @@ -9,8 +8,9 @@ import torch from sklearn.decomposition import TruncatedSVD +from dance import logger from dance.data.base import Data -from dance.typing import Optional, Tuple, Union +from dance.typing import Union from ..base import BaseTransform @@ -76,9 +76,8 @@ def create_pathway_graph(gex_features: scipy.sparse.spmatrix, gene_names: Union[ pk_path = f'pw_{subtask}_{pathway_weight}.pkl' if os.path.exists(pk_path): - logging.warning( - 'Pathway file exist. Load pickle file by default. Auguments "--pathway_weight" and "--pathway_path" will not take effect.' - ) + logger.warning("Pathway file exist. Load pickle file by default. " + "Auguments '--pathway_weight' and '--pathway_path' will not take effect.") uu, vv, ee = pickle.load(open(pk_path, 'rb')) else: # Load Original Pathway File diff --git a/examples/multi_modality/predict_modality/babel.py b/examples/multi_modality/predict_modality/babel.py index a7c3c958..045349bb 100644 --- a/examples/multi_modality/predict_modality/babel.py +++ b/examples/multi_modality/predict_modality/babel.py @@ -7,14 +7,13 @@ import mudata import torch +from dance import logger from dance.data import Data from dance.datasets.multimodality import ModalityPredictionDataset from dance.modules.multi_modality.predict_modality.babel import BabelWrapper from dance.utils import set_seed if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - OPTIMIZER_DICT = { "adam": torch.optim.Adam, "rmsprop": torch.optim.RMSprop, @@ -52,13 +51,12 @@ os.makedirs(os.path.dirname(args.outdir)) # Specify output log file - logger = logging.getLogger() fh = logging.FileHandler(f"{args.outdir}/training_{args.subtask}_{args.rnd_seed}.log", "w") fh.setLevel(logging.INFO) logger.addHandler(fh) for arg in vars(args): - logging.info(f"Parameter {arg}: {getattr(args, arg)}") + logger.info(f"Parameter {arg}: {getattr(args, arg)}") # Construct data object mod1 = anndata.concat((dataset.modalities[0], dataset.modalities[2]))