# Recovered from patch b78764f ("create cell feature function(updated version of
# gen_batch_feature) (#39)") against dance/transforms/graph_construct.py, where the
# function is inserted between ``gen_batch_features`` and
# ``construct_modality_prediction_graph``.
#
# The function relies on module-level names of that file: ``List``/``Union`` (typing),
# ``ad`` (anndata) and ``logger`` (dance) are added by the same patch; ``np``, ``pd`` and
# ``torch`` are presumably imported elsewhere in the module (not visible in the patch).
def generate_cell_features(
    data: Union[ad.AnnData, List[ad.AnnData]],
    *,
    group_batch: bool = False,
    batch_col_id: str = "batch",
) -> torch.Tensor:
    """Generate cell node features from anndata objects.

    For every cell (row of ``adata.X``) nine summary statistics of its expression
    profile are computed, in this column order: ``cell_mean``, ``cell_std``,
    ``nonzero_25%``, ``nonzero_50%``, ``nonzero_75%``, ``nonzero_max``,
    ``nonzero_count``, ``nonzero_mean``, ``nonzero_std``.

    Parameters
    ----------
    data: Union[anndata.AnnData, List[anndata.AnnData]]
        A list of or a single AnnData object(s). Rows of ``X`` may be sparse (anything
        exposing ``toarray``) or dense.
    group_batch: bool
        If set to True, every cell receives the mean feature vector of the batch it
        belongs to instead of its own statistics.
    batch_col_id: str
        Column ID in ``adata.obs`` corresponding to the batches.

    Returns
    -------
    cell_features: torch.Tensor
        A float32 cell feature matrix with one row per cell (in input order) and nine
        columns, generated based on the statistics of the cell's gene expression
        profiles. If no cells are provided, an empty ``(0, 9)`` tensor is returned.

    Notes
    -----
    * ``nonzero_count`` is the number of nonzero entries divided by 1000.
    * ``nonzero_max`` is computed as the maximum over the whole row.
    * Cells without any nonzero entry get all-zero statistics (a warning is logged).

    TODO
    ----
    Add option for providing call-backs for additional flexibility of generating different types of features.

    """
    data = data if isinstance(data, list) else [data]

    columns = [
        "cell_mean", "cell_std", "nonzero_25%", "nonzero_50%", "nonzero_75%", "nonzero_max", "nonzero_count",
        "nonzero_mean", "nonzero_std", "batch"
    ]
    num_stats = len(columns) - 1  # every column except the trailing batch label
    cells = []

    for adata in data:
        bcl = adata.obs[batch_col_id].tolist()
        logger.info(f"Unique batches: {sorted(set(bcl))}")

        for i, cell in enumerate(adata.X):
            # Sparse rows must be densified first; dense rows are used as they are.
            cell = cell.toarray() if hasattr(cell, "toarray") else np.asarray(cell)
            nz = cell[np.nonzero(cell)]

            if nz.size == 0:
                logger.warning("Encountered a cell with all zero features.")
                # The batch label has to be wrapped in a list: concatenating the bare
                # label onto the zero vector raises a TypeError.
                cells.append([0] * num_stats + [bcl[i]])
            else:
                q25, q50, q75 = np.percentile(nz, [25, 50, 75])
                cells.append([
                    cell.mean(),
                    cell.std(),
                    q25,
                    q50,
                    q75,
                    cell.max(),
                    len(nz) / 1000,
                    nz.mean(),
                    nz.std(),
                    bcl[i],
                ])

    if not cells:
        # An empty frame has object dtype, which torch cannot convert.
        return torch.zeros((0, num_stats), dtype=torch.float32)

    features = pd.DataFrame(cells, columns=columns)
    logger.debug(f"features=\n{features}")

    if group_batch:
        batch_source = features.groupby("batch", as_index=False).mean()
        logger.info(f"Batch features:\n{batch_source.set_index('batch')}")

        # Assign cell features with corresponding reduced batch features
        b2i = {j: i for i, j in enumerate(batch_source["batch"].tolist())}
        batch_source = batch_source.drop("batch", axis=1).to_numpy()
        cell_batch_idxs = list(map(b2i.get, features["batch"]))
        cell_features = batch_source[cell_batch_idxs]

    else:
        cell_features = features.drop("batch", axis=1).to_numpy()

    return torch.tensor(cell_features, dtype=torch.float32)