From 6ea5f9ca69af2a5cde8fbad1749a8584a3690c22 Mon Sep 17 00:00:00 2001 From: Gokcen Eraslan Date: Sat, 20 Mar 2021 18:35:38 -0400 Subject: [PATCH] Add a new checkcounts argument to skip raw count check --- dca/__main__.py | 5 +++++ dca/api.py | 8 ++++++-- dca/io.py | 17 +++++++++-------- dca/train.py | 1 + 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/dca/__main__.py b/dca/__main__.py index 4c4c494..8a5d09f 100644 --- a/dca/__main__.py +++ b/dca/__main__.py @@ -114,6 +114,10 @@ def parse_args(): parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help="Use tensorboard for saving weight distributions and " "visualization. (default: False)") + parser.add_argument('--checkcounts', dest='checkcounts', action='store_true', + help="Check if the expression matrix has raw (unnormalized) counts (default: True)") + parser.add_argument('--nocheckcounts', dest='checkcounts', action='store_false', + help="Do not check if the expression matrix has raw (unnormalized) counts") parser.add_argument('--denoisesubset', dest='denoisesubset', type=str, help='Perform denoising only for the subset of genes ' 'in the given file. Gene names should be line ' @@ -124,6 +128,7 @@ def parse_args(): saveweights=False, sizefactors=True, batchnorm=True, + checkcounts=True, norminput=True, hyper=False, debug=False, diff --git a/dca/api.py b/dca/api.py index 59b5ca9..5c4a1ec 100644 --- a/dca/api.py +++ b/dca/api.py @@ -40,7 +40,8 @@ def dca(adata, training_kwds={}, return_model=False, return_info=False, - copy=False + copy=False, + check_counts=True, ): """Deep count autoencoder(DCA) API. @@ -116,6 +117,8 @@ def dca(adata, zinb or zinb-conddisp. copy : `bool`, optional. Default: `False`. If true, a copy of anndata is returned. + check_counts : `bool`. Default `True`. + Check if the counts are unnormalized (raw) counts. Returns ------- @@ -153,7 +156,8 @@ def dca(adata, adata = read_dataset(adata, transpose=False, test_split=False, - copy=copy) + copy=copy, + check_counts=check_counts) # check for zero genes nonzero_genes, _ = sc.pp.filter_genes(adata.X, min_counts=1) diff --git a/dca/io.py b/dca/io.py index e43c1af..1e3f47b 100644 --- a/dca/io.py +++ b/dca/io.py @@ -50,7 +50,7 @@ def __getitem__(self, idx): return {'count': batch, 'size_factors': batch_sf}, batch -def read_dataset(adata, transpose=False, test_split=False, copy=False): +def read_dataset(adata, transpose=False, test_split=False, copy=False, check_counts=True): if isinstance(adata, sc.AnnData): if copy: @@ -60,13 +60,14 @@ def read_dataset(adata, transpose=False, test_split=False, copy=False): else: raise NotImplementedError - # check if observations are unnormalized using first 10 - X_subset = adata.X[:10] - norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.' - if sp.sparse.issparse(X_subset): - assert (X_subset.astype(int) != X_subset).nnz == 0, norm_error - else: - assert np.all(X_subset.astype(int) == X_subset), norm_error + if check_counts: + # check if observations are unnormalized using first 10 + X_subset = adata.X[:10] + norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.' + if sp.sparse.issparse(X_subset): + assert (X_subset.astype(int) != X_subset).nnz == 0, norm_error + else: + assert np.all(X_subset.astype(int) == X_subset), norm_error if transpose: adata = adata.transpose() diff --git a/dca/train.py b/dca/train.py index fbb59e6..6171f26 100644 --- a/dca/train.py +++ b/dca/train.py @@ -123,6 +123,7 @@ def train_with_args(args): adata = io.read_dataset(args.input, transpose=(not args.transpose), # assume gene x cell by default + check_counts=args.checkcounts, test_split=args.testsplit) adata = io.normalize(adata,