Skip to content

Commit

Permalink
update stagate example script to use dance data object (#127)
Browse files Browse the repository at this point in the history
* refactor StagateGraph

* remove unused function, fix docstring and format

* add docstring

* update stagate example script to use dance data object

* remove deprecated graph construct functions
  • Loading branch information
RemyLau authored Jan 10, 2023
1 parent 7de5b40 commit b8c00eb
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 319 deletions.
161 changes: 47 additions & 114 deletions dance/modules/spatial/spatial_domain/stagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,18 @@
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import sklearn.neighbors
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import mixture
from sklearn.metrics.cluster import adjusted_rand_score
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
from torch_sparse import SparseTensor, set_diag
from tqdm import tqdm
Expand All @@ -40,62 +37,22 @@ def transfer_pytorch_data(adata, adj):
return data


def Stats_Spatial_Net(adata):
Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]
Mean_edge = Num_edge / adata.shape[0]
plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1']))
plot_df = plot_df / adata.shape[0]
fig, ax = plt.subplots(figsize=[3, 2])
plt.ylabel('Percentage')
plt.xlabel('')
plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge)
ax.bar(plot_df.index, plot_df)


def mclust_P(adata, num_cluster, used_obsm='STAGATE', modelNames='EEE'):
from sklearn import mixture
g = mixture.GaussianMixture(n_components=num_cluster, covariance_type='tied', warm_start=True, n_init=100,
def mclust(adata, num_cluster, used_obsm="STAGATE", modelNames="EEE"):
g = mixture.GaussianMixture(n_components=num_cluster, covariance_type="tied", warm_start=True, n_init=100,
max_iter=300, reg_covar=1.4663143602030552e-04, random_state=36282,
tol=0.00022187708009762592)
res = g.fit_predict(adata.obsm[used_obsm])
adata.obs['mclust'] = res
adata.obs["mclust"] = res
return adata


'''
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
"""\
Clustering using the mclust algorithm.
The parameters are the same as those in the R package mclust.
"""
np.random.seed(random_seed)
import rpy2.robjects as robjects
robjects.r.library("mclust")
import rpy2.robjects.numpy2ri
rpy2.robjects.numpy2ri.activate()
r_random_seed = robjects.r['set.seed']
r_random_seed(random_seed)
rmclust = robjects.r['Mclust']
res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)
mclust_res = np.array(res[-2])
adata.obs['mclust'] = mclust_res
adata.obs['mclust'] = adata.obs['mclust'].astype('int')
adata.obs['mclust'] = adata.obs['mclust'].astype('category')
return adata
'''


class GATConv(MessagePassing):
"""Graph attention layer from Graph Attention Network."""
_alpha = None

def __init__(self, in_channels, out_channels, heads: int = 1, concat: bool = True, negative_slope: float = 0.2,
dropout: float = 0.0, add_self_loops=True, bias=True, **kwargs):
kwargs.setdefault('aggr', 'add')
kwargs.setdefault("aggr", "add")
super().__init__(node_dim=0, **kwargs)

self.in_channels = in_channels
Expand Down Expand Up @@ -126,12 +83,12 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten
# We first transform the input node features. If a tuple is passed, we
# transform source and target node features via separate weights:
if isinstance(x, Tensor):
assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
assert x.dim() == 2, "Static graphs not supported in GATConv"
# x_src = x_dst = self.lin_src(x).view(-1, H, C)
x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C)
else: # Tuple of source and target node features:
x_src, x_dst = x
assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
assert x_src.dim() == 2, "Static graphs not supported in GATConv"
x_src = self.lin_src(x_src).view(-1, H, C)
if x_dst is not None:
x_dst = self.lin_dst(x_dst).view(-1, H, C)
Expand All @@ -142,7 +99,7 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten
return x[0].mean(dim=1)
# return x[0].view(-1, self.heads * self.out_channels)

if tied_attention == None:
if tied_attention is None:
# Next, we compute node-level attention coefficients, both for source
# and target nodes (if present):
alpha_src = (x_src * self.att_src).sum(dim=-1)
Expand Down Expand Up @@ -180,7 +137,7 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten
if isinstance(edge_index, Tensor):
return out, (edge_index, alpha)
elif isinstance(edge_index, SparseTensor):
return out, edge_index.set_value(alpha, layout='coo')
return out, edge_index.set_value(alpha, layout="coo")
else:
return out

Expand All @@ -196,7 +153,7 @@ def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
return x_j * alpha.unsqueeze(-1)

def __repr__(self):
return '{}({}, {}, heads={})'.format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads)
return "{}({}, {}, heads={})".format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads)


class Stagate(torch.nn.Module):
Expand All @@ -219,21 +176,19 @@ def __init__(self, hidden_dims):
self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)

def forward(self, features, edge_index):
"""forward function for training.
"""Forward function for training.
Parameters
----------
features :
node features.
Node features.
edge_index :
adjacent matrix.
Adjacent matrix.
Returns
-------
h2 :
the second hidden layer.
h4 :
the forth hidden layer.
Tuple[Tensor, Tensor]
The second and the forth hidden layerx.
"""
h1 = F.elu(self.conv1(features, edge_index))
Expand All @@ -247,56 +202,50 @@ def forward(self, features, edge_index):

return h2, h4 # F.log_softmax(x, dim=-1)

def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added='STAGATE', gradient_clipping=5., pre_resolution=0.2,
def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added="STAGATE", gradient_clipping=5., pre_resolution=0.2,
weight_decay=0.0001, verbose=True, random_seed=0, save_loss=False, save_reconstrction=False,
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
"""fit function for training.
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
"""Fit function for training.
Parameters
----------
adata :
input data.
Input data.
graph :
graph structure.
Graph structure.
n_epochs : int optional
number of epochs.
Number of epochs.
lr : float optional
learning rate.
Learning rate.
key_added : str optional
by default 'STAGATE'.
Default "STAGATE".
gradient_clipping : float optional
gradient clipping.
Gradient clipping.
pre_resolution : float optional
pre resolution.
Pre-resolution.
weight_decay : float optional
weight decay.
Weight decay.
verbose : bool optional
verbose, by default to be True.
Verbosity, by default to be True.
random_seed : int optional
random seed by default to be 0.
Random seed.
save_loss : bool optional
by default to be False.
Whether to save loss or not.
save_reconstrction : bool optional
by default to be False.
Whether to save reconstruction or not.
device : str optional
to indicate gpu or cpu device.
Returns
-------
None.
Computation device.
"""
adata.X = sp.csr_matrix(adata.X)

if 'highly_variable' in adata.var.columns:
adata_Vars = adata[:, adata.var['highly_variable']]
if "highly_variable" in adata.var.columns:
adata_Vars = adata[:, adata.var["highly_variable"]]
else:
adata_Vars = adata

if verbose:
print('Size of Input: ', adata_Vars.shape)
if 'Spatial_Net' not in adata.uns.keys():
raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!")
print("Size of Input: ", adata_Vars.shape)

data = transfer_pytorch_data(adata_Vars, graph)

Expand All @@ -320,55 +269,39 @@ def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added='STAGATE', gradient_
model.eval()
z, out = model(data.x, data.edge_index)

STAGATE_rep = z.to('cpu').detach().numpy()
STAGATE_rep = z.to("cpu").detach().numpy()
adata.obsm[key_added] = STAGATE_rep

if save_loss:
adata.uns['STAGATE_loss'] = loss
adata.uns["STAGATE_loss"] = loss
if save_reconstrction:
ReX = out.to('cpu').detach().numpy()
ReX = out.to("cpu").detach().numpy()
ReX[ReX < 0] = 0
adata.layers['STAGATE_ReX'] = ReX
adata.layers["STAGATE_ReX"] = ReX

print("post process...")
sc.pp.neighbors(adata, use_rep='STAGATE')
sc.pp.neighbors(adata, use_rep="STAGATE")
sc.tl.umap(adata)
#adata = mclust_R(adata, used_obsm='STAGATE', num_cluster=7)
adata = mclust_P(adata, used_obsm='STAGATE', num_cluster=7)
adata = mclust(adata, used_obsm="STAGATE", num_cluster=7)
self.adata = adata

def predict(self, ):
"""prediction function.
Parameters
----------
Returns
-------
self.y_pred :
predicted label.
"""
data_dropna = self.adata.obs.dropna()
self.y_pred = data_dropna['mclust']
self.target = data_dropna['ground_truth']
return data_dropna['mclust']
def predict(self):
"""Prediction function."""
return self.adata.obs["mclust"].values

def score(self, y_true=None):
"""score function to get score of prediction.
Parameters
----------
y_true :
ground truth label.
Ground truth label.
Returns
-------
score : float
metric eval score.
float
Adjusted rand index score.
"""
from sklearn.metrics.cluster import adjusted_rand_score
score = adjusted_rand_score(self.target, self.y_pred)
print("ARI {}".format(adjusted_rand_score(self.target, self.y_pred)))
score = adjusted_rand_score(y_true, self.predict())
return score
3 changes: 2 additions & 1 deletion dance/transforms/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dance.transforms.graph.cell_feature_graph import CellFeatureGraph, PCACellFeatureGraph
from dance.transforms.graph.dstg_graph import DSTGraph
from dance.transforms.graph.neighbor_graph import NeighborGraph
from dance.transforms.graph.spatial_graph import SMEGraph, SpaGCNGraph, SpaGCNGraph2D
from dance.transforms.graph.spatial_graph import SMEGraph, SpaGCNGraph, SpaGCNGraph2D, StagateGraph

__all__ = [
"CellFeatureGraph",
Expand All @@ -11,4 +11,5 @@
"SMEGraph",
"SpaGCNGraph",
"SpaGCNGraph2D",
"StagateGraph",
] # yapf: disable
43 changes: 43 additions & 0 deletions dance/transforms/graph/spatial_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

from dance.transforms.base import BaseTransform
from dance.typing import Sequence
Expand Down Expand Up @@ -103,3 +104,45 @@ def __call__(self, data):
adj = adj_p * adj_m * adj_g

data.data.obsp[self.out] = adj


class StagateGraph(BaseTransform):
"""STAGATE spatial graph."""

_MODELS = ("radius", "knn")
_DISPLAY_ATTRS = ("model_name", "radius", "n_neighbors")

def __init__(self, model_name: str = "radius", *, radius: float = 1, n_neighbors: int = 5,
channel: str = "spatial_pixel", channel_type: str = "obsm", **kwargs):
"""Initialize StagateGraph.
Parameters
----------
model_name
Type of graph to construct. Currently support `radius` and `knn`. See
:class:`~sklearn.neighbors.NearestNeighbors` for more info.
radius
Radius parameter for `radius_neighbors_graph`.
n_neighbors
Number of neighbors for `kneighbors_graph`.
"""
super().__init__(**kwargs)

if not isinstance(model_name, str) or (model_name.lower() not in self._MODELS):
raise ValueError(f"Unknown model {model_name!r}, available options are {self._MODELS}")
self.model_name = model_name
self.radius = radius
self.n_neighbors = n_neighbors
self.channel = channel
self.channel_type = channel_type

def __call__(self, data):
xy_pixel = data.get_feature(return_type="numpy", channel=self.channel, channel_type=self.channel_type)

if self.model_name.lower() == "radius":
adj = NearestNeighbors(radius=self.radius).fit(xy_pixel).radius_neighbors_graph(xy_pixel)
elif self.model_name.lower() == "knn":
adj = NearestNeighbors(n_neighbors=self.n_neighbors).fit(xy_pixel).kneighbors_graph(xy_pixel)

data.data.obsp[self.out] = adj
Loading

0 comments on commit b8c00eb

Please sign in to comment.