Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor graph transforms and rename Gene -> Feature #82

Merged
merged 1 commit into from
Dec 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dance/transforms/cell_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dance.utils.matrix import normalize


class WeightedGenePCA(BaseTransform):
class WeightedFeaturePCA(BaseTransform):
"""Compute the weighted gene PCA as cell features.

Given a gene expression matrix of dimension (cell x gene), the gene PCA is first compured. Then, the representation
Expand All @@ -16,7 +16,7 @@ class WeightedGenePCA(BaseTransform):
"""

def __init__(self, n_components: int = 400, split_name: Optional[str] = None, **kwargs):
"""Initialize WeightedGenePCA.
"""Initialize WeightedFeaturePCA.

Parameters
----------
Expand Down
6 changes: 6 additions & 0 deletions dance/transforms/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dance.transforms.graph.cell_feature_graph import CellFeatureGraph, PCACellFeatureGraph

__all__ = [
"CellFeatureGraph",
"PCACellFeatureGraph",
] # yapf: disable
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import torch

from dance.transforms.base import BaseTransform
from dance.transforms.cell_feature import WeightedGenePCA
from dance.transforms.cell_feature import WeightedFeaturePCA
from dance.typing import LogLevel, Optional


class CellGeneGraph(BaseTransform):
class CellFeatureGraph(BaseTransform):

def __init__(self, cell_feature_channel: str, gene_feature_channel: Optional[str] = None, *,
layer: Optional[str] = None, mod: Optional[str] = None, **kwargs):
Expand Down Expand Up @@ -52,17 +52,17 @@ def __call__(self, data):
g.add_edges(g.nodes(), g.nodes(), {"weight": torch.ones(g.number_of_nodes())[:, None]})

gene_feature = data.get_feature(return_type="torch", channel=self.gene_feature_channel, mod=self.mod,
channel_type="var")
channel_type="varm")
cell_feature = data.get_feature(return_type="torch", channel=self.cell_feature_channel, mod=self.mod,
channel_type="obs")
channel_type="obsm")
g.ndata["features"] = torch.vstack((gene_feature, cell_feature))

data.data.uns[self.out] = g

return data


class PCACellGeneGraph(BaseTransform):
class PCACellFeatureGraph(BaseTransform):

def __init__(self, n_components: int = 400, split_name: Optional[str] = None, *, layer: Optional[str] = None,
mod: Optional[str] = None, log_level: LogLevel = "WARNING"):
Expand All @@ -75,7 +75,7 @@ def __init__(self, n_components: int = 400, split_name: Optional[str] = None, *,
self.mod = mod

def __call__(self, data):
WeightedGenePCA(self.n_components, self.split_name, log_level=self.log_level)(data)
CellGeneGraph(cell_feature_channel="WeightedGenePCA", layer=self.layer, mod=self.mod,
log_level=self.log_level)(data)
WeightedFeaturePCA(self.n_components, self.split_name, log_level=self.log_level)(data)
CellFeatureGraph(cell_feature_channel="WeightedFeaturePCA", layer=self.layer, mod=self.mod,
log_level=self.log_level)(data)
return data
6 changes: 3 additions & 3 deletions examples/single_modality/cell_type_annotation/scdeepsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dance.data import Data
from dance.datasets.singlemodality import CellTypeDataset
from dance.modules.single_modality.cell_type_annotation.scdeepsort import ScDeepSort
from dance.transforms.graph import PCACellGeneGraph
from dance.transforms.graph import PCACellFeatureGraph
from dance.utils.preprocess import cell_label_to_df

if __name__ == "__main__":
Expand Down Expand Up @@ -49,15 +49,15 @@
adata, cell_labels, idx_to_label, train_size = dataloader.load_data()
adata.obsm["cell_type"] = cell_label_to_df(cell_labels, idx_to_label, index=adata.obs.index)
data = Data(adata, train_size=train_size)
PCACellGeneGraph(n_components=params.dense_dim, split_name="train", log_level="INFO")(data)
PCACellFeatureGraph(n_components=params.dense_dim, split_name="train", log_level="INFO")(data)
data.set_config(label_channel="cell_type")

y_train = data.get_y(split_name="train", return_type="torch").argmax(1)
y_test = data.get_y(split_name="test", return_type="torch")
num_labels = y_test.shape[1]

# TODO: make api for the following blcok?
g = data.data.uns["CellGeneGraph"]
g = data.data.uns["CellFeatureGraph"]
num_genes = data.num_features
gene_ids = torch.arange(num_genes)
train_cell_ids = torch.LongTensor(data.train_idx) + num_genes
Expand Down
6 changes: 3 additions & 3 deletions examples/single_modality/cell_type_annotation/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dance.data import Data
from dance.datasets.singlemodality import CellTypeDataset
from dance.modules.single_modality.cell_type_annotation.svm import SVM
from dance.transforms.cell_feature import WeightedGenePCA
from dance.transforms.cell_feature import WeightedFeaturePCA
from dance.utils.preprocess import cell_label_to_df

if __name__ == "__main__":
Expand Down Expand Up @@ -35,9 +35,9 @@
adata, cell_labels, idx_to_label, train_size = dataloader.load_data()
adata.obsm["cell_type"] = cell_label_to_df(cell_labels, idx_to_label, index=adata.obs.index)
data = Data(adata, train_size=train_size)
WeightedGenePCA(n_components=params.dense_dim, split_name="train", log_level="INFO")(data)
WeightedFeaturePCA(n_components=params.dense_dim, split_name="train", log_level="INFO")(data)

data.set_config(feature_channel="WeightedGenePCA", label_channel="cell_type")
data.set_config(feature_channel="WeightedFeaturePCA", label_channel="cell_type")

x_train, y_train = data.get_train_data()
y_train_converted = y_train.argmax(1) # convert one-hot representation into label index representation
Expand Down