Skip to content

Commit

Permalink
Add a new checkcounts argument to skip raw count check
Browse files Browse the repository at this point in the history
  • Loading branch information
Gokcen Eraslan committed Mar 20, 2021
1 parent 6cd8c67 commit 6ea5f9c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
5 changes: 5 additions & 0 deletions dca/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def parse_args():
parser.add_argument('--tensorboard', dest='tensorboard',
action='store_true', help="Use tensorboard for saving weight distributions and "
"visualization. (default: False)")
parser.add_argument('--checkcounts', dest='checkcounts', action='store_true',
help="Check if the expression matrix has raw (unnormalized) counts (default: True)")
parser.add_argument('--nocheckcounts', dest='checkcounts', action='store_false',
help="Do not check if the expression matrix has raw (unnormalized) counts")
parser.add_argument('--denoisesubset', dest='denoisesubset', type=str,
help='Perform denoising only for the subset of genes '
'in the given file. Gene names should be line '
Expand All @@ -124,6 +128,7 @@ def parse_args():
saveweights=False,
sizefactors=True,
batchnorm=True,
checkcounts=True,
norminput=True,
hyper=False,
debug=False,
Expand Down
8 changes: 6 additions & 2 deletions dca/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def dca(adata,
training_kwds={},
return_model=False,
return_info=False,
copy=False
copy=False,
check_counts=True,
):
"""Deep count autoencoder(DCA) API.
Expand Down Expand Up @@ -116,6 +117,8 @@ def dca(adata,
zinb or zinb-conddisp.
copy : `bool`, optional. Default: `False`.
If true, a copy of anndata is returned.
check_counts : `bool`. Default `True`.
Check if the counts are unnormalized (raw) counts.
Returns
-------
Expand Down Expand Up @@ -153,7 +156,8 @@ def dca(adata,
adata = read_dataset(adata,
transpose=False,
test_split=False,
copy=copy)
copy=copy,
check_counts=check_counts)

# check for zero genes
nonzero_genes, _ = sc.pp.filter_genes(adata.X, min_counts=1)
Expand Down
17 changes: 9 additions & 8 deletions dca/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __getitem__(self, idx):
return {'count': batch, 'size_factors': batch_sf}, batch


def read_dataset(adata, transpose=False, test_split=False, copy=False):
def read_dataset(adata, transpose=False, test_split=False, copy=False, check_counts=True):

if isinstance(adata, sc.AnnData):
if copy:
Expand All @@ -60,13 +60,14 @@ def read_dataset(adata, transpose=False, test_split=False, copy=False):
else:
raise NotImplementedError

# check if observations are unnormalized using first 10
X_subset = adata.X[:10]
norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
if sp.sparse.issparse(X_subset):
assert (X_subset.astype(int) != X_subset).nnz == 0, norm_error
else:
assert np.all(X_subset.astype(int) == X_subset), norm_error
if check_counts:
# check if observations are unnormalized using first 10
X_subset = adata.X[:10]
norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
if sp.sparse.issparse(X_subset):
assert (X_subset.astype(int) != X_subset).nnz == 0, norm_error
else:
assert np.all(X_subset.astype(int) == X_subset), norm_error

if transpose: adata = adata.transpose()

Expand Down
1 change: 1 addition & 0 deletions dca/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def train_with_args(args):

adata = io.read_dataset(args.input,
transpose=(not args.transpose), # assume gene x cell by default
check_counts=args.checkcounts,
test_split=args.testsplit)

adata = io.normalize(adata,
Expand Down

0 comments on commit 6ea5f9c

Please sign in to comment.