Skip to content

Commit

Permalink
create cell feature function(updated version of gen_batch_feature) (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
RemyLau authored Oct 16, 2022
1 parent 4a45b71 commit b78764f
Showing 1 changed file with 83 additions and 3 deletions.
86 changes: 83 additions & 3 deletions dance/transforms/graph_construct.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright 2022 DSE lab. All rights reserved.

import glob
import itertools
import os
import pickle
import random
import time
from collections import defaultdict
from typing import List, Union

import anndata as ad
import dgl
import networkx as nx
import numba
Expand All @@ -27,6 +26,7 @@
from torch.nn import functional as F

import dance.transforms.preprocess
from dance import logger


@numba.njit("f4(f4[:], f4[:])")
Expand Down Expand Up @@ -486,6 +486,86 @@ def gen_batch_features(ad_inputs):
return batch_features


def generate_cell_features(
    data: Union[ad.AnnData, List[ad.AnnData]],
    *,
    group_batch: bool = False,
    batch_col_id: str = "batch",
) -> torch.Tensor:
    """Generate cell node features from anndata objects.

    Each cell (row of ``adata.X``) is summarized by nine statistics of its expression profile: mean and standard
    deviation over all entries; 25th, 50th and 75th percentiles of the nonzero entries; maximum entry; scaled
    count of nonzero entries; mean and standard deviation of the nonzero entries.

    Parameters
    ----------
    data: Union[anndata.AnnData, List[anndata.AnnData]]
        A list of or a single AnnData object(s). ``X`` may be a sparse matrix (rows exposing ``toarray``) or a
        dense array.
    group_batch: bool
        If set to True, set features of cell within a batch to the mean values.
    batch_col_id: str
        Column ID (in ``adata.obs``) corresponding to the batches.

    Returns
    -------
    cell_features: torch.Tensor
        A float32 cell feature matrix with one row per cell (in input order, concatenated across all provided
        AnnData objects) and nine columns, generated based on the statistics of the cell's gene expression
        profiles. Cells with no nonzero entries get all-zero statistics.

    TODO
    ----
    Add option for providing call-backs for additional flexibility of generating different types of features.

    """
    data = data if isinstance(data, list) else [data]

    columns = [
        "cell_mean", "cell_std", "nonzero_25%", "nonzero_50%", "nonzero_75%", "nonzero_max", "nonzero_count",
        "nonzero_mean", "nonzero_std", "batch"
    ]
    num_stats = len(columns) - 1  # every column except the trailing batch label
    cells = []

    for adata in data:
        bcl = adata.obs[batch_col_id].tolist()
        logger.info(f"Unique batches: {sorted(set(bcl))}")

        for i, cell in enumerate(adata.X):
            # Rows of a sparse matrix need to be densified first; rows of a dense array are used directly.
            if hasattr(cell, "toarray"):
                cell = cell.toarray()
            cell = np.asarray(cell)
            nz = cell[np.nonzero(cell)]

            if len(nz) == 0:
                logger.warning("Encountered a cell with all zero features.")
                # The batch label must be wrapped in a list so that the row keeps one entry per column.
                cells.append([0] * num_stats + [bcl[i]])
            else:
                cells.append([
                    cell.mean(),
                    cell.std(),
                    np.percentile(nz, 25),
                    np.percentile(nz, 50),
                    np.percentile(nz, 75),
                    cell.max(),
                    # Nonzero count is scaled down by a fixed factor of 1000
                    # (presumably to keep it on a scale comparable to the other statistics -- TODO confirm).
                    len(nz) / 1000,
                    nz.mean(),
                    nz.std(),
                    bcl[i],
                ])

    features = pd.DataFrame(cells, columns=columns)
    logger.debug(f"features=\n{features}")

    if group_batch:
        batch_source = features.groupby("batch", as_index=False).mean()
        logger.info(f"Batch features:\n{batch_source.set_index('batch')}")

        # Assign cell features with corresponding reduced batch features
        b2i = {j: i for i, j in enumerate(batch_source["batch"].tolist())}
        batch_source = batch_source.drop("batch", axis=1).to_numpy()
        cell_batch_idxs = list(map(b2i.get, features["batch"]))
        cell_features = batch_source[cell_batch_idxs]

    else:
        cell_features = features.drop("batch", axis=1).to_numpy()

    return torch.tensor(cell_features, dtype=torch.float32)


def construct_modality_prediction_graph(dataset, **kwargs):
"""Construct the cell-feature graph object for modality prediction task, based on
the input dataset.
Expand Down

0 comments on commit b78764f

Please sign in to comment.