Skip to content

Commit

Permalink
update celltypist example script to use dance data object (#72)
Browse files Browse the repository at this point in the history
* replace celltypist logger with dance logger

* update celltypist data loader

* update score function to take generic prediction matrix as input

* add celltypist compatibility with numpy array and pandas dataframe

* format fixes

* update celltypist example script to use dance data object
  • Loading branch information
RemyLau authored Dec 14, 2022
1 parent 8b419fc commit b027bd1
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 118 deletions.
17 changes: 13 additions & 4 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import glob
import os
import os.path as osp
Expand Down Expand Up @@ -371,7 +370,7 @@ def load_data(self):

if self.data_type == "celltypist":
self.download_benchmark_data(download_pretrained=False)
self.map_dict = get_map_dict(self.params.map_path, self.params.tissue) # load map
map_dict = get_map_dict(self.params.map_path, self.params.tissue) # load map
train_data = pd.read_csv(
osp.join(self.params.proj_path, self.params.train_dir, self.params.species,
self.params.species + "_" + self.params.tissue + str(self.params.train_dataset) + "_data.csv"),
Expand All @@ -381,7 +380,6 @@ def load_data(self):
self.params.proj_path, self.params.train_dir, self.params.species,
self.params.species + "_" + self.params.tissue + str(self.params.train_dataset) + "_celltype.csv"),
index_col=1)
self.train_adata = ad.AnnData(train_data.T, train_celltype)
test_data = pd.read_csv(
osp.join(self.params.proj_path, self.params.test_dir, self.params.species,
self.params.species + "_" + self.params.tissue + str(self.params.test_dataset) + "_data.csv"),
Expand All @@ -391,7 +389,18 @@ def load_data(self):
self.params.proj_path, self.params.test_dir, self.params.species,
self.params.species + "_" + self.params.tissue + str(self.params.test_dataset) + "_celltype.csv"),
index_col=1)
self.test_adata = ad.AnnData(test_data.T, test_celltype)

train_size = train_data.shape[1]
df = pd.concat(train_data.T.align(test_data.T, axis=1, join="left", fill_value=0))
adata = ad.AnnData(df, dtype=np.float32)
adata.obs_names_make_unique()

idx_to_label = sorted(train_celltype["Cell_type"].unique())
cell_labels = [{i} for i in train_celltype["Cell_type"].tolist()]
for i in test_celltype["Cell_type"].tolist():
cell_labels.append(map_dict[self.params.test_dataset][i])

return adata, cell_labels, idx_to_label, train_size

if self.data_type == "singlecellnet_exp":
if self.is_singlecellnet_complete():
Expand Down
122 changes: 38 additions & 84 deletions dance/modules/single_modality/cell_type_annotation/celltypist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,14 @@
from scipy.sparse import spmatrix
from scipy.special import expit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler

from dance import logger
from dance.transforms.preprocess import (LRClassifier_celltypist, SGDClassifier_celltypist, downsample_adata,
get_sample_csv_celltypist, get_sample_data_celltypist, prepare_data_celltypist,
to_array_celltypist, to_vector_celltypist)

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
set_level = logger.setLevel
info = logger.info
warn = logger.warning
error = logger.error
debug = logger.debug

celltypist_path = os.getenv('CELLTYPIST_FOLDER', default=os.path.join(str(pathlib.Path.home()), '.celltypist'))
pathlib.Path(celltypist_path).mkdir(parents=True, exist_ok=True)
data_path = os.path.join(celltypist_path, "data")
Expand Down Expand Up @@ -763,13 +757,13 @@ def __init__(
if isinstance(model, str):
model = Model.load(model)
self.model = model
if not filename:
logger.warn(f" No input file provided to the classifier")
if isinstance(filename, str) and filename == "":
logger.warn(" No input file provided to the classifier")
return
if isinstance(filename, str):
self.filename = filename
logger.info(f" Input file is '{self.filename}'")
logger.info(f" Loading data")
logger.info(" Loading data")
if isinstance(filename, str) and filename.endswith(('.csv', '.txt', '.tsv', '.tab', '.mtx', '.mtx.gz')):
self.adata = sc.read(self.filename)
if transpose:
Expand Down Expand Up @@ -803,9 +797,16 @@ def __init__(
self.indata = self.adata.X
self.indata_genes = self.adata.var_names
self.indata_names = self.adata.obs_names
elif isinstance(filename, AnnData) or (isinstance(filename, str) and filename.endswith('.h5ad')):
self.adata = sc.read(filename) if isinstance(filename, str) else filename
elif isinstance(filename, (AnnData, np.ndarray, pd.DataFrame)) or (isinstance(filename, str)
and filename.endswith('.h5ad')):
if isinstance(filename, (np.ndarray, pd.DataFrame)):
self.adata = AnnData(filename)
elif isinstance(filename, str):
self.adata = sc.read(filename)
else:
self.adata = filename
self.adata.var_names_make_unique()

if self.adata.X.min() < 0:
logger.info(" Detected scaled expression in the input data, will try the `.raw` attribute")
try:
Expand Down Expand Up @@ -991,57 +992,21 @@ def majority_vote(predictions: AnnotationResult, over_clustering: Union[list, tu
return predictions


class Celltypist():
r"""Build the ACTINN model.
Parameters
----------
classifier : Classification function
Class that wraps the celltyping and majority voting processes, as defined above
scaler : StandardScaler
The scale factor for normalization.
description : str
text description of the model.
"""
class Celltypist:

def __init__(self, clf=None, scaler=None, description=None):
self.classifier = clf
self.scaler = scaler
self.description = description

def fit(
self,
X=None,
labels: Optional[Union[str, list, tuple, np.ndarray, pd.Series, pd.Index]] = None,
genes: Optional[Union[str, list, tuple, np.ndarray, pd.Series, pd.Index]] = None,
transpose_input: bool = False,
check_expression: bool = True,
#LR param
C: float = 1.0,
solver: Optional[str] = None,
max_iter: int = 1000,
n_jobs: Optional[int] = None,
#SGD param
use_SGD: bool = False,
alpha: float = 0.0001,
#mini-batch
mini_batch: bool = False,
batch_number: int = 100,
batch_size: int = 1000,
epochs: int = 10,
balance_cell_type: bool = False,
#feature selection
feature_selection: bool = False,
top_genes: int = 300,
#description
date: str = '',
details: str = '',
url: str = '',
source: str = '',
version: str = '',
#other param
**kwargs):
def fit(self, X=None, labels: Optional[Union[str, list, tuple, np.ndarray, pd.Series,
pd.Index]] = None, genes: Optional[Union[str, list, tuple, np.ndarray,
pd.Series, pd.Index]] = None,
transpose_input: bool = False, check_expression: bool = True, C: float = 1.0, solver: Optional[str] = None,
max_iter: int = 1000, n_jobs: Optional[int] = None, use_SGD: bool = False, alpha: float = 0.0001,
mini_batch: bool = False, batch_number: int = 100, batch_size: int = 1000, epochs: int = 10,
balance_cell_type: bool = False, feature_selection: bool = False, top_genes: int = 300, date: str = '',
details: str = '', url: str = '', source: str = '', version: str = '', **kwargs):
"""Train a celltypist model using mini-batch (optional) logistic classifier with
a global solver or stochastic gradient descent (SGD) learning.
Expand Down Expand Up @@ -1328,38 +1293,27 @@ def predict(self, filename: Union[AnnData, str] = "", check_expression: bool = F
print(predictions)
return Classifier.majority_vote(predictions, over_clustering, min_prop=min_prop)

def score(self, input_adata, predictions, labels, map, label_conversion=False):
"""Run the prediction and (optional) majority voting to evaluate the model
performance.
def score(self, pred, true):
"""Model performance score measured by accuracy.
Parameters
----------
input_adata: Anndata
Input data anndata with label ground truth
predictions: classifier.AnnotationResult
Output from prediction function.
labels : string
column name for annotated cell types in the input Anndata
label_conversion: boolean optional
whether to match predicted labels to annotated cell type labels provided in input_adata
map: dictionary
a dictionary for label conversion
pred: np.ndarray
Predicted labels.
true: np.ndarray
True labels. Can be either a maxtrix of size (samples x labels) with ones indicating positives, or a
vector of size (sameples x 1) where each element is the index of the corresponding label for the sample.
The first option provides flexibility to cases where a sample could be associated with multiple labels
at test time while the model was trained as a multi-class classifier.
Returns
-------
correct: int
Number of correct predictions
Accuracy: float
Prediction accuracy from the model
score: float
Accuracy score.
"""
pred_labels = np.array(predictions.predicted_labels)[:, 0]
if label_conversion:
correct = 0
#for i, cell in enumerate(np.array(adVal.obs.Cell_type)):
for i, cell in enumerate(np.array(input_adata.obs[labels])):
if pred_labels[i] in map[cell]:
correct += 1
if true.max() == 1:
num_samples = true.shape[0]
return (true[range(num_samples), pred.ravel()]).sum() / num_samples
else:
correct = sum(np.array(input_adata.obs[labels]) == pred_labels)
return (correct / len(predictions.predicted_labels))
return accuracy_score(pred, true)
12 changes: 10 additions & 2 deletions dance/transforms/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,8 +1114,16 @@ def to_array_celltypist(_array_like):
def prepare_data_celltypist(X, labels, genes, transpose):
if (X is None) or (labels is None):
raise Exception("?? Missing training data and/or training labels. Please provide both arguments")
if isinstance(X, AnnData) or (isinstance(X, str) and X.endswith('.h5ad')):
adata = sc.read(X) if isinstance(X, str) else X
if isinstance(X, (AnnData, np.ndarray, pd.DataFrame)) or (isinstance(X, str) and X.endswith('.h5ad')):
if isinstance(X, str):
adata = sc.read(X)
elif isinstance(X, np.ndarray):
adata = AnnData(pd.DataFrame(X, columns=list(map(str, range(X.shape[1])))))
elif isinstance(X, pd.DataFrame):
adata = AnnData(X)
else:
adata = X

adata.var_names_make_unique()
if adata.X.min() < 0:
logger.info("?? Detected scaled expression in the input data, will try the .raw attribute")
Expand Down
55 changes: 27 additions & 28 deletions examples/single_modality/cell_type_annotation/celltypist.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import argparse
import time

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse

from dance.data import Data
from dance.datasets.singlemodality import CellTypeDataset
from dance.modules.single_modality.cell_type_annotation.celltypist import Celltypist
from dance.utils.preprocess import cell_label_to_df

if __name__ == '__main__':

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--random_seed", type=int, default=10)
parser.add_argument("--train_dataset", default=4682, type=int, help="train id")
Expand All @@ -22,43 +16,48 @@
default="Cell_type")
parser.add_argument("--check_expression", type=bool,
help="whether to check the normalization of training and test data", default=False)
parser.add_argument("--species", default='mouse', type=str)
parser.add_argument("--tissue", default='Kidney', type=str)
parser.add_argument("--train_dir", type=str, default='train')
parser.add_argument("--test_dir", type=str, default='test')
parser.add_argument("--proj_path", type=str, default='./')
parser.add_argument("--map_path", type=str, default='map/mouse/')
parser.add_argument("--species", default="mouse", type=str)
parser.add_argument("--tissue", default="Kidney", type=str)
parser.add_argument("--train_dir", type=str, default="train")
parser.add_argument("--test_dir", type=str, default="test")
parser.add_argument("--proj_path", type=str, default="./")
parser.add_argument("--map_path", type=str, default="map/mouse/")
parser.add_argument("--n_jobs", type=int, help="Number of jobs", default=10)
parser.add_argument("--max_iter", type=int, help="Max iteration during training", default=5)
parser.add_argument("--use_SGD", type=bool,
help="Training algorithm -- weather it will be stochastic gradient descent", default=True)
parser.add_argument("--label_conversion", type=bool,
help="whether to convert cell type labels between training and test dataset for scoring",
default=False)
args = parser.parse_args()

dataloader = CellTypeDataset(random_seed=args.random_seed, data_type="celltypist", proj_path=args.proj_path,
train_dir=args.train_dir, test_dir=args.test_dir, train_dataset=args.train_dataset,
test_dataset=args.test_dataset, species=args.species, tissue=args.tissue,
map_path=args.map_path)
dataloader.load_data()

adata, cell_labels, idx_to_label, train_size = dataloader.load_data()
adata.obsm["cell_type"] = cell_label_to_df(cell_labels, idx_to_label, index=adata.obs_names)
data = Data(adata, train_size=train_size)
data.set_config(label_channel="cell_type")

x_train, y_train = data.get_train_data()
y_train = y_train.argmax(1)
model = Celltypist()
model_fs = model.fit(dataloader.train_adata, labels=args.cell_type_train, check_expression=args.check_expression,
n_jobs=args.n_jobs, max_iter=args.max_iter, use_SGD=args.use_SGD)
predictions = model.predict(dataloader.test_adata, check_expression=args.check_expression)
accuracy = model.score(dataloader.test_adata, predictions=predictions, labels=args.cell_type_test,
map=dataloader.map_dict[args.test_dataset], label_conversion=True)
print(accuracy)
model.fit(x_train, y_train, check_expression=args.check_expression, n_jobs=args.n_jobs, max_iter=args.max_iter,
use_SGD=args.use_SGD)

x_test, y_test = data.get_test_data()
pred_obj = model.predict(x_test, check_expression=args.check_expression)
pred = pred_obj.predicted_labels["predicted_labels"].values
score = model.score(pred, y_test)
print(f"{score=}")
"""To reproduce CellTypist benchmarks, please refer to command lines belows:
Mouse Brain
$ python celltypist.py --species mouse --tissue Brain --train_dataset 753 --test_dataset 2695 --label_conversion True
$ python celltypist.py --species mouse --tissue Brain --train_dataset 753 --test_dataset 2695
Mouse Spleen
$ python celltypist.py --species mouse --tissue Spleen --train_dataset 1970 --test_dataset 1759 --label_conversion False
$ python celltypist.py --species mouse --tissue Spleen --train_dataset 1970 --test_dataset 1759
Mouse Kidney
$ python celltypist.py --species mouse --tissue Kidney --train_dataset 4682 --test_dataset 203 --label_conversion False
$ python celltypist.py --species mouse --tissue Kidney --train_dataset 4682 --test_dataset 203
"""

0 comments on commit b027bd1

Please sign in to comment.