diff --git a/dance/transforms/cell_feature.py b/dance/transforms/cell_feature.py index 5179f50e..2ba22e1c 100644 --- a/dance/transforms/cell_feature.py +++ b/dance/transforms/cell_feature.py @@ -7,7 +7,7 @@ from dance.utils.matrix import normalize -class WeightedGenePCA(BaseTransform): +class WeightedFeaturePCA(BaseTransform): """Compute the weighted gene PCA as cell features. Given a gene expression matrix of dimension (cell x gene), the gene PCA is first compured. Then, the representation @@ -16,7 +16,7 @@ class WeightedGenePCA(BaseTransform): """ def __init__(self, n_components: int = 400, split_name: Optional[str] = None, **kwargs): - """Initialize WeightedGenePCA. + """Initialize WeightedFeaturePCA. Parameters ---------- diff --git a/dance/transforms/graph/__init__.py b/dance/transforms/graph/__init__.py new file mode 100644 index 00000000..3e3c1c6b --- /dev/null +++ b/dance/transforms/graph/__init__.py @@ -0,0 +1,6 @@ +from dance.transforms.graph.cell_feature_graph import CellFeatureGraph, PCACellFeatureGraph + +__all__ = [ + "CellFeatureGraph", + "PCACellFeatureGraph", +] # yapf: disable diff --git a/dance/transforms/graph.py b/dance/transforms/graph/cell_feature_graph.py similarity index 84% rename from dance/transforms/graph.py rename to dance/transforms/graph/cell_feature_graph.py index 9312beb3..994ff9ab 100644 --- a/dance/transforms/graph.py +++ b/dance/transforms/graph/cell_feature_graph.py @@ -3,11 +3,11 @@ import torch from dance.transforms.base import BaseTransform -from dance.transforms.cell_feature import WeightedGenePCA +from dance.transforms.cell_feature import WeightedFeaturePCA from dance.typing import LogLevel, Optional -class CellGeneGraph(BaseTransform): +class CellFeatureGraph(BaseTransform): def __init__(self, cell_feature_channel: str, gene_feature_channel: Optional[str] = None, *, layer: Optional[str] = None, mod: Optional[str] = None, **kwargs): @@ -52,9 +52,9 @@ def __call__(self, data): g.add_edges(g.nodes(), g.nodes(), {"weight": torch.ones(g.number_of_nodes())[:, None]}) gene_feature = data.get_feature(return_type="torch", channel=self.gene_feature_channel, mod=self.mod, - channel_type="var") + channel_type="varm") cell_feature = data.get_feature(return_type="torch", channel=self.cell_feature_channel, mod=self.mod, - channel_type="obs") + channel_type="obsm") g.ndata["features"] = torch.vstack((gene_feature, cell_feature)) data.data.uns[self.out] = g @@ -62,7 +62,7 @@ def __call__(self, data): return data -class PCACellGeneGraph(BaseTransform): +class PCACellFeatureGraph(BaseTransform): def __init__(self, n_components: int = 400, split_name: Optional[str] = None, *, layer: Optional[str] = None, mod: Optional[str] = None, log_level: LogLevel = "WARNING"): @@ -75,7 +75,7 @@ def __init__(self, n_components: int = 400, split_name: Optional[str] = None, *, self.mod = mod def __call__(self, data): - WeightedGenePCA(self.n_components, self.split_name, log_level=self.log_level)(data) - CellGeneGraph(cell_feature_channel="WeightedGenePCA", layer=self.layer, mod=self.mod, - log_level=self.log_level)(data) + WeightedFeaturePCA(self.n_components, self.split_name, log_level=self.log_level)(data) + CellFeatureGraph(cell_feature_channel="WeightedFeaturePCA", layer=self.layer, mod=self.mod, + log_level=self.log_level)(data) return data diff --git a/examples/single_modality/cell_type_annotation/scdeepsort.py b/examples/single_modality/cell_type_annotation/scdeepsort.py index 8db442fb..5f47e235 100644 --- a/examples/single_modality/cell_type_annotation/scdeepsort.py +++ b/examples/single_modality/cell_type_annotation/scdeepsort.py @@ -6,7 +6,7 @@ from dance.data import Data from dance.datasets.singlemodality import CellTypeDataset from dance.modules.single_modality.cell_type_annotation.scdeepsort import ScDeepSort -from dance.transforms.graph import PCACellGeneGraph +from dance.transforms.graph import PCACellFeatureGraph from dance.utils.preprocess import cell_label_to_df if __name__ == "__main__": @@ -49,7 +49,7 @@ adata, cell_labels, idx_to_label, train_size = dataloader.load_data() adata.obsm["cell_type"] = cell_label_to_df(cell_labels, idx_to_label, index=adata.obs.index) data = Data(adata, train_size=train_size) - PCACellGeneGraph(n_components=params.dense_dim, split_name="train", log_level="INFO")(data) + PCACellFeatureGraph(n_components=params.dense_dim, split_name="train", log_level="INFO")(data) data.set_config(label_channel="cell_type") y_train = data.get_y(split_name="train", return_type="torch").argmax(1) @@ -57,7 +57,7 @@ num_labels = y_test.shape[1] # TODO: make api for the following blcok? - g = data.data.uns["CellGeneGraph"] + g = data.data.uns["CellFeatureGraph"] num_genes = data.num_features gene_ids = torch.arange(num_genes) train_cell_ids = torch.LongTensor(data.train_idx) + num_genes diff --git a/examples/single_modality/cell_type_annotation/svm.py b/examples/single_modality/cell_type_annotation/svm.py index 134ec996..3cf69c83 100644 --- a/examples/single_modality/cell_type_annotation/svm.py +++ b/examples/single_modality/cell_type_annotation/svm.py @@ -4,7 +4,7 @@ from dance.data import Data from dance.datasets.singlemodality import CellTypeDataset from dance.modules.single_modality.cell_type_annotation.svm import SVM -from dance.transforms.cell_feature import WeightedGenePCA +from dance.transforms.cell_feature import WeightedFeaturePCA from dance.utils.preprocess import cell_label_to_df if __name__ == "__main__": @@ -35,9 +35,9 @@ adata, cell_labels, idx_to_label, train_size = dataloader.load_data() adata.obsm["cell_type"] = cell_label_to_df(cell_labels, idx_to_label, index=adata.obs.index) data = Data(adata, train_size=train_size) - WeightedGenePCA(n_components=params.dense_dim, split_name="train", log_level="INFO")(data) + WeightedFeaturePCA(n_components=params.dense_dim, split_name="train", log_level="INFO")(data) - data.set_config(feature_channel="WeightedGenePCA", label_channel="cell_type") + data.set_config(feature_channel="WeightedFeaturePCA", label_channel="cell_type") x_train, y_train = data.get_train_data() y_train_converted = y_train.argmax(1) # convert one-hot representation into label index representation