Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create cell feature function (updated version of gen_batch_feature) #39

Merged
merged 1 commit into from
Oct 16, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 83 additions & 3 deletions dance/transforms/graph_construct.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright 2022 DSE lab. All rights reserved.

import glob
import itertools
import os
import pickle
import random
import time
from collections import defaultdict
from typing import List, Union

import anndata as ad
import dgl
import networkx as nx
import numba
Expand All @@ -27,6 +26,7 @@
from torch.nn import functional as F

import dance.transforms.preprocess
from dance import logger


@numba.njit("f4(f4[:], f4[:])")
Expand Down Expand Up @@ -486,6 +486,86 @@ def gen_batch_features(ad_inputs):
return batch_features


def generate_cell_features(
    data: Union[ad.AnnData, List[ad.AnnData]],
    *,
    group_batch: bool = False,
    batch_col_id: str = "batch",
) -> torch.Tensor:
    """Generate cell node features from anndata objects.

    For every cell (row of ``adata.X``) a fixed set of summary statistics of its expression profile is computed.
    Optionally, each cell's statistics are replaced by the mean statistics of the batch the cell belongs to.

    Parameters
    ----------
    data: Union[anndata.AnnData, List[anndata.AnnData]]
        A list of or a single AnnData object(s). Rows of ``.X`` may be sparse (anything exposing ``toarray``) or
        dense array-likes.
    group_batch: bool
        If set to True, set features of cell within a batch to the mean values.
    batch_col_id: str
        Column ID (in ``adata.obs``) corresponding to the batches.

    Returns
    -------
    cell_features: torch.Tensor
        A float32 cell feature matrix with one row per cell and one column per statistic, in this order: mean and
        standard deviation over all entries; 25th, 50th and 75th percentiles of the nonzero entries; maximum over
        all entries; number of nonzero entries divided by 1000; mean and standard deviation of the nonzero
        entries. Cells without any nonzero entry get an all-zero feature row.

    TODO
    ----
    Add option for providing call-backs for additional flexibility of generating different types of features.

    """
    data = data if isinstance(data, list) else [data]

    cells = []
    columns = [
        "cell_mean", "cell_std", "nonzero_25%", "nonzero_50%", "nonzero_75%", "nonzero_max", "nonzero_count",
        "nonzero_mean", "nonzero_std", "batch"
    ]
    # Every column except the trailing "batch" label is a numeric statistic.
    num_stats = len(columns) - 1

    for adata in data:
        bcl = adata.obs[batch_col_id].tolist()
        logger.info(f"Unique batches: {sorted(set(bcl))}")

        for i, cell in enumerate(adata.X):
            # Sparse rows need to be densified; dense rows are used as they are.
            cell = cell.toarray() if hasattr(cell, "toarray") else np.asarray(cell)
            nz = cell[np.nonzero(cell)]

            if nz.size == 0:
                logger.warning("Encountered a cell with all zero features.")
                # The batch label must be wrapped in a list: concatenating the bare label to the list of zeros
                # raises a TypeError.
                cells.append([0] * num_stats + [bcl[i]])
            else:
                cells.append([
                    cell.mean(),
                    cell.std(),
                    *np.percentile(nz, [25, 50, 75]),
                    cell.max(),
                    len(nz) / 1000,  # scaled down by a fixed factor of 1000
                    nz.mean(),
                    nz.std(),
                    bcl[i],
                ])

    features = pd.DataFrame(cells, columns=columns)
    logger.debug(f"features=\n{features}")

    if group_batch:
        # One row of averaged statistics per unique batch label.
        batch_source = features.groupby("batch", as_index=False).mean()
        logger.info(f"Batch features:\n{batch_source.set_index('batch')}")

        # Assign cell features with corresponding reduced batch features
        b2i = {j: i for i, j in enumerate(batch_source["batch"].tolist())}
        batch_source = batch_source.drop("batch", axis=1).to_numpy()
        cell_batch_idxs = [b2i[b] for b in features["batch"]]
        cell_features = batch_source[cell_batch_idxs]

    else:
        cell_features = features.drop("batch", axis=1).to_numpy()

    return torch.tensor(cell_features, dtype=torch.float32)


def construct_modality_prediction_graph(dataset, **kwargs):
"""Construct the cell-feature graph object for modality prediction task, based on
the input dataset.
Expand Down