Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: wrap method specific composed preprocessing pipelines for clustering examples #192

Merged
merged 6 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dance/modules/single_modality/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from .scdcc import ScDCC
from .scdeepcluster import ScDeepCluster
from .scdsc import SCDSC
from .sctag import SCTAG
from .sctag import ScTAG

__all__ = [
"GraphSC",
"ScDCC",
"ScDeepCluster",
"SCDSC",
"SCTAG",
"ScTAG",
]
43 changes: 43 additions & 0 deletions dance/modules/single_modality/clustering/graphsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from torch.nn.functional import binary_cross_entropy_with_logits as BCELoss
from tqdm import tqdm

from dance.transforms import AnnDataTransform, Compose, SetConfig
from dance.transforms.graph import PCACellFeatureGraph
from dance.typing import LogLevel


class GraphSC:
"""GraphSC class.
Expand All @@ -47,6 +51,45 @@ def __init__(self, args):
self.device = get_device(args.use_cpu)
self.model = GCNAE(args).to(self.device)

@staticmethod
def preprocessing_pipeline(n_top_genes: int = 3000, normalize_weights: str = "log_per_cell", n_components: int = 50,
normalize_edges: bool = False, log_level: LogLevel = "INFO"):
transforms = [
AnnDataTransform(sc.pp.filter_genes, min_counts=3),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
AnnDataTransform(sc.pp.normalize_total),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.highly_variable_genes, min_mean=0.0125, max_mean=4, flavor="cell_ranger",
min_disp=0.5, n_top_genes=n_top_genes, subset=True),
]

if normalize_weights == "log_per_cell":
transforms.extend([
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.normalize_total, target_sum=1),
])
elif normalize_weights == "per_cell":
transforms.append(AnnDataTransform(sc.pp.normalize_total, target_sum=1))
elif normalize_weights != "none":
raise ValueError(f"Unknown normalization option {normalize_weights!r}."
"Available options are: 'none', 'log_per_cell', 'per_cell'")

# Cell-gene graph construction
transforms.extend([
PCACellFeatureGraph(
n_components=n_components,
normalize_edges=normalize_edges,
feat_norm_mode="standardize",
),
SetConfig({
"feature_channel": "CellFeatureGraph",
"feature_channel_type": "uns",
"label_channel": "labels",
}),
])

return Compose(*transforms, log_level=log_level)

def fit(self, g, n_epochs, n_clusters, lr, cluster=["KMeans"]):
"""Train graph-sc.

Expand Down
20 changes: 20 additions & 0 deletions dance/modules/single_modality/clustering/scdcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -23,6 +24,8 @@
from torch.nn import Parameter
from torch.utils.data import DataLoader, TensorDataset

from dance.transforms import AnnDataTransform, Compose, SaveRaw, SetConfig
from dance.typing import LogLevel
from dance.utils.loss import ZINBLoss
from dance.utils.metrics import cluster_acc

Expand Down Expand Up @@ -106,6 +109,23 @@ def __init__(self, input_dim, z_dim, n_clusters, encodeLayer=[], decodeLayer=[],
self.mu = Parameter(torch.Tensor(n_clusters, z_dim))
self.zinb_loss = ZINBLoss().cpu()

@staticmethod
def preprocessing_pipeline(log_level: LogLevel = "INFO"):
return Compose(
AnnDataTransform(sc.pp.filter_genes, min_counts=1),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
SaveRaw(),
AnnDataTransform(sc.pp.normalize_total),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.scale),
SetConfig({
"feature_channel": [None, None, "n_counts"],
"feature_channel_type": ["X", "raw_X", "obs"],
"label_channel": "Group"
}),
log_level=log_level,
)

def save_model(self, path):
"""Save model to path.

Expand Down
20 changes: 20 additions & 0 deletions dance/modules/single_modality/clustering/scdeepcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -23,6 +24,8 @@
from torch.nn import Parameter
from torch.utils.data import DataLoader, TensorDataset

from dance.transforms import AnnDataTransform, Compose, SaveRaw, SetConfig
from dance.typing import LogLevel
from dance.utils.loss import ZINBLoss
from dance.utils.metrics import cluster_acc

Expand Down Expand Up @@ -106,6 +109,23 @@ def __init__(self, input_dim, z_dim, encodeLayer=[], decodeLayer=[], activation=
self.zinb_loss = ZINBLoss().to(self.device)
self.to(device)

@staticmethod
def preprocessing_pipeline(log_level: LogLevel = "INFO"):
return Compose(
AnnDataTransform(sc.pp.filter_genes, min_counts=1),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
SaveRaw(),
AnnDataTransform(sc.pp.normalize_total),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.scale),
SetConfig({
"feature_channel": [None, None, "n_counts"],
"feature_channel_type": ["X", "raw_X", "obs"],
"label_channel": "Group",
}),
log_level=log_level,
)

def save_model(self, path):
"""Save model to path.

Expand Down
31 changes: 31 additions & 0 deletions dance/modules/single_modality/clustering/scdsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import math

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -22,7 +23,10 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, TensorDataset

from dance.transforms import AnnDataTransform, Compose, SaveRaw, SetConfig
from dance.transforms.graph import NeighborGraph
from dance.transforms.preprocess import sparse_mx_to_torch_sparse_tensor
from dance.typing import LogLevel
from dance.utils.loss import ZINBLoss
from dance.utils.metrics import cluster_acc

Expand All @@ -46,6 +50,33 @@ def __init__(self, args):
n_dec_2=args.n_dec_2, n_dec_3=args.n_dec_3, n_input=args.n_input, n_z1=args.n_z1,
n_z2=args.n_z2, n_z3=args.n_z3).to(self.device)

@staticmethod
def preprocessing_pipeline(n_top_genes: int = 2000, n_neighbors: int = 50, log_level: LogLevel = "INFO"):
return Compose(
# Filter data
AnnDataTransform(sc.pp.filter_genes, min_counts=3),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
AnnDataTransform(sc.pp.normalize_per_cell),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.highly_variable_genes, min_mean=0.0125, max_mean=4, flavor="cell_ranger",
min_disp=0.5, n_top_genes=n_top_genes, subset=True),
# Normalize data
AnnDataTransform(sc.pp.filter_genes, min_counts=1),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
SaveRaw(),
AnnDataTransform(sc.pp.normalize_total),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.scale),
# Construct k-neighbors graph using the noramlized feature matrix
NeighborGraph(n_neighbors=n_neighbors, metric="correlation", channel="X"),
SetConfig({
"feature_channel": [None, None, "n_counts", "NeighborGraph"],
"feature_channel_type": ["X", "raw_X", "obs", "obsp"],
"label_channel": "Group"
}),
log_level=log_level,
)

def target_distribution(self, q):
"""Calculate auxiliary target distribution p with q.

Expand Down
35 changes: 34 additions & 1 deletion dance/modules/single_modality/clustering/sctag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import dgl
import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -21,11 +22,14 @@
from sklearn.cluster import KMeans
from torch.nn import Parameter

from dance.transforms import AnnDataTransform, CellPCA, Compose, SaveRaw, SetConfig
from dance.transforms.graph import NeighborGraph
from dance.typing import LogLevel
from dance.utils.loss import ZINBLoss, dist_loss
from dance.utils.metrics import cluster_acc


class SCTAG(nn.Module):
class ScTAG(nn.Module):
"""scTAG class.

Parameters
Expand Down Expand Up @@ -90,6 +94,35 @@ def __init__(self, X, adj, n_clusters, k=3, hidden_dim=128, latent_dim=15, dec_d
self.zinb_loss = ZINBLoss().to(self.device)
self.to(self.device)

@staticmethod
def preprocessing_pipeline(n_top_genes: int = 3000, n_components: int = 50, n_neighbors: int = 15,
log_level: LogLevel = "INFO"):
return Compose(
# Filter data
AnnDataTransform(sc.pp.filter_genes, min_counts=3),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
AnnDataTransform(sc.pp.normalize_per_cell),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.highly_variable_genes, min_mean=0.0125, max_mean=4, flavor="cell_ranger",
min_disp=0.5, n_top_genes=n_top_genes, subset=True),
# Normalize data
AnnDataTransform(sc.pp.filter_genes, min_counts=1),
AnnDataTransform(sc.pp.filter_cells, min_counts=1),
SaveRaw(),
AnnDataTransform(sc.pp.normalize_total),
AnnDataTransform(sc.pp.log1p),
AnnDataTransform(sc.pp.scale),
# Construct k-neighbors graph
CellPCA(n_components=n_components),
NeighborGraph(n_neighbors=n_neighbors, n_pcs=n_components),
SetConfig({
"feature_channel": [None, None, "n_counts", "NeighborGraph"],
"feature_channel_type": ["X", "raw_X", "obs", "obsp"],
"label_channel": "Group",
}),
log_level=log_level,
)

def forward(self, A_in, X_input):
"""Forward propagation.

Expand Down
30 changes: 7 additions & 23 deletions examples/single_modality/clustering/graphsc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import argparse

import numpy as np
import scanpy as sc

from dance.data import Data
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.graphsc import GraphSC
from dance.transforms import AnnDataTransform
from dance.transforms.graph import PCACellFeatureGraph
from dance.utils import set_seed

if __name__ == "__main__":
Expand Down Expand Up @@ -46,27 +43,14 @@
adata.obsm["labels"] = labels
data = Data(adata, train_size="all")

# Filter data
AnnDataTransform(sc.pp.filter_genes, min_counts=3)(data)
AnnDataTransform(sc.pp.filter_cells, min_counts=1)(data)
AnnDataTransform(sc.pp.normalize_total)(data)
AnnDataTransform(sc.pp.log1p)(data)
AnnDataTransform(sc.pp.highly_variable_genes, min_mean=0.0125, max_mean=4, flavor="cell_ranger", min_disp=0.5,
n_top_genes=args.nb_genes, subset=True)(data)
preprocessing_pipeline = GraphSC.preprocessing_pipeline(
n_top_genes=args.nb_genes,
normalize_weights=args.normalize_weights,
n_components=args.in_feats,
normalize_edges=args.edge_norm,
)
preprocessing_pipeline(data)

# Normalize
if args.normalize_weights == "log_per_cell":
AnnDataTransform(sc.pp.log1p)(data)
AnnDataTransform(sc.pp.normalize_total, target_sum=1)(data)
elif args.normalize_weights == "per_cell":
AnnDataTransform(sc.pp.normalize_total, target_sum=1)(data)
elif args.normalize_weights != "none":
raise ValueError(f"Unknown normalization option {args.normalize_weights!r}."
"Available options are: 'none', 'log_per_cell', 'per_cell'")

# Construct cell-gene graph
PCACellFeatureGraph(n_components=args.in_feats, normalize_edges=args.edge_norm, feat_norm_mode="standardize")(data)
data.set_config(feature_channel="CellFeatureGraph", feature_channel_type="uns", label_channel="labels")
graph, y = data.get_train_data()
n_clusters = len(np.unique(y))

Expand Down
25 changes: 5 additions & 20 deletions examples/single_modality/clustering/scdcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
from time import time

import numpy as np
import scanpy as sc
import torch

from dance.data import Data
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdcc import ScDCC
from dance.transforms import AnnDataTransform, SaveRaw
from dance.transforms.preprocess import generate_random_pair
from dance.utils import set_seed

# for repeatability
set_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="train", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--label_cells", default=0.1, type=float)
Expand Down Expand Up @@ -45,19 +45,9 @@
adata.obsm["Group"] = labels
data = Data(adata, train_size="all")

# Normalize data
AnnDataTransform(sc.pp.filter_genes, min_counts=1)(data)
AnnDataTransform(sc.pp.filter_cells, min_counts=1)(data)
SaveRaw()(data)
AnnDataTransform(sc.pp.normalize_total)(data)
AnnDataTransform(sc.pp.log1p)(data)
AnnDataTransform(sc.pp.scale)(data)

data.set_config(
feature_channel=[None, None, "n_counts"],
feature_channel_type=["X", "raw_X", "obs"],
label_channel="Group",
)
preprocessing_pipeline = ScDCC.preprocessing_pipeline()
preprocessing_pipeline(data)

(x, x_raw, n_counts), y = data.get_train_data()
n_clusters = len(np.unique(y))

Expand All @@ -80,11 +70,6 @@

# Construct model
sigma = 2.75
use_cuda = torch.cuda.is_available()
if use_cuda:
device = "cuda"
else:
device = "cpu"
model = ScDCC(input_dim=x.shape[1], z_dim=32, n_clusters=n_clusters, encodeLayer=[256, 64], decodeLayer=[64, 256],
sigma=args.sigma, gamma=args.gamma, ml_weight=args.ml_weight, cl_weight=args.ml_weight).to(device)

Expand Down
Loading