refactor: adapt Stagate to BasePretrain class; improve consistency #204

Merged · 5 commits · Feb 20, 2023
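This PR decouples Stagate from AnnData: fit now takes a plain (features, edge index) tuple instead of an AnnData object plus graph, pretraining moves behind the BasePretrain interface (with optional caching of the learned representation via pretrain_path), and the old mclust wrapper with hard-coded hyperparameters is replaced by a GaussianMixture whose settings are exposed as fit arguments. A minimal sketch of the new call pattern, mirroring the updated example script (the random arrays are placeholders for real data, and the identity adjacency is only there to keep the sketch runnable):

```python
import numpy as np

from dance.modules.spatial.spatial_domain.stagate import Stagate

# Placeholder inputs for illustration only: x is a (cells x genes) feature
# matrix and edge_index a 2 x E COO edge index of the spatial graph.
rng = np.random.default_rng(0)
x = rng.random((100, 3000), dtype=np.float32)
adj = np.eye(100)  # identity adjacency (self-loops), just to keep the sketch runnable
edge_index = np.vstack(np.nonzero(adj))

model = Stagate([3000, 512, 32])  # [in_dim, num_hidden, out_dim]
model.fit((x, edge_index), epochs=10, num_cluster=7)
pred = model.predict()  # GMM cluster assignments, one label per cell
```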
184 changes: 80 additions & 104 deletions dance/modules/spatial/spatial_domain/stagate.py
@@ -10,44 +10,23 @@
"""
import numpy as np
import scanpy as sc
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import mixture
from sklearn.mixture import GaussianMixture
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
from torch_sparse import SparseTensor, set_diag
from tqdm import tqdm

from dance import logger
from dance.modules.base import BaseClusteringMethod
from dance.modules.base import BaseClusteringMethod, BasePretrain
from dance.transforms import AnnDataTransform, Compose, SetConfig
from dance.transforms.graph import StagateGraph
from dance.typing import Any, LogLevel, Optional


def transfer_pytorch_data(adata, adj):
edgeList = adj
if type(adata.X) == np.ndarray:
data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])),
x=torch.FloatTensor(adata.X)) # .todense()
else:
data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])),
x=torch.FloatTensor(adata.X.todense())) # .todense()
return data


def mclust(adata, num_cluster, used_obsm="STAGATE", modelNames="EEE"):
g = mixture.GaussianMixture(n_components=num_cluster, covariance_type="tied", warm_start=True, n_init=100,
max_iter=300, reg_covar=1.4663143602030552e-04, random_state=36282,
tol=0.00022187708009762592)
res = g.fit_predict(adata.obsm[used_obsm])
adata.obs["mclust"] = res
return adata
from dance.typing import Any, LogLevel, Optional, Tuple
from dance.utils import get_device


class GATConv(MessagePassing):
@@ -101,7 +80,6 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten

if not attention:
return x[0].mean(dim=1)
# return x[0].view(-1, self.heads * self.out_channels)

if tied_attention is None:
# Next, we compute node-level attention coefficients, both for source
Expand Down Expand Up @@ -147,10 +125,7 @@ def forward(self, x, edge_index, size=None, return_attention_weights=None, atten

def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

# alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = torch.sigmoid(alpha)

alpha = softmax(alpha, index, ptr, size_i)
self._alpha = alpha # Save for later use.
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
@@ -160,25 +135,33 @@ def __repr__(self):
return "{}({}, {}, heads={})".format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads)


class Stagate(torch.nn.Module, BaseClusteringMethod):
class Stagate(nn.Module, BasePretrain, BaseClusteringMethod):
"""Stagate class.

Parameters
----------
hidden_dims : int
hidden_dims
Hidden dimensions.
device
Computation device.
pretrain_path
Save the cell representations from the trained STAGATE model to the specified path. Do not save if unspecified.

"""

def __init__(self, hidden_dims):
def __init__(self, hidden_dims, device: str = "auto", pretrain_path: Optional[str] = None):
super().__init__()
self.pretrain_path = pretrain_path

[in_dim, num_hidden, out_dim] = hidden_dims
self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)
self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)
self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)
self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False, dropout=0, add_self_loops=False, bias=False)

self.device = get_device(device)
self.to(self.device)

@staticmethod
def preprocessing_pipeline(hvg_flavor: str = "seurat_v3", n_top_hvgs: int = 3000, model_name: str = "radius",
radius: float = 150, n_neighbors: int = 5, log_level: LogLevel = "INFO"):
@@ -201,9 +184,9 @@ def forward(self, features, edge_index):

Parameters
----------
features :
features
Node features.
edge_index :
edge_index
Adjacency matrix.

Returns
@@ -221,87 +204,80 @@ def forward(self, features, edge_index):
h3 = F.elu(self.conv3(h2, edge_index, attention=True, tied_attention=self.conv1.attentions))
h4 = self.conv4(h3, edge_index, attention=False)

return h2, h4 # F.log_softmax(x, dim=-1)
return h2, h4

def pretrain(
self,
x: np.ndarray,
edge_index_array: np.ndarray,
lr: float = 1e-3,
weight_decay: float = 1e-4,
epochs: int = 100,
gradient_clipping: float = 5,
):
x_tensor = torch.from_numpy(x.astype(np.float32)).to(self.device)
edge_index_tensor = torch.from_numpy(edge_index_array.astype(int)).to(self.device)

optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
self.train()
for epoch in tqdm(range(1, epochs + 1)):
optimizer.zero_grad()
z, out = self(x_tensor, edge_index_tensor)
loss = F.mse_loss(x_tensor, out)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), gradient_clipping)
optimizer.step()

def fit(self, adata, graph, n_epochs=1, lr=0.001, key_added="STAGATE", gradient_clipping=5., pre_resolution=0.2,
weight_decay=0.0001, verbose=True, random_seed=0, save_loss=False, save_reconstrction=False,
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
self.eval()
z, _ = self(x_tensor, edge_index_tensor)
self.rep = z.detach().clone().cpu().numpy()

def save_pretrained(self, path):
np.save(path, self.rep)

def load_pretrained(self, path):
self.rep = np.load(path)

def fit(
self,
inputs: Tuple[np.ndarray, np.ndarray],
epochs: int = 100,
lr: float = 0.001,
gradient_clipping: float = 5,
weight_decay: float = 1e-4,
num_cluster: int = 7,
gmm_reg_covar: float = 1.5e-4,
gmm_n_init: int = 10,
gmm_max_iter: int = 300,
gmm_tol: float = 2e-4,
random_state: Optional[int] = None,
):
"""Fit function for training.

Parameters
----------
adata :
Input data.
graph :
Graph structure.
n_epochs : int
inputs
A tuple containing (1) the input features and (2) the edge index array (COO representation) of the
adjacency matrix.
epochs
Number of epochs.
lr : float
lr
Learning rate.
key_added : str
Default "STAGATE".
gradient_clipping : float
gradient_clipping
Gradient clipping.
pre_resolution : float
Pre-resolution.
weight_decay : float
weight_decay
Weight decay.
verbose : bool
Verbosity, by default to be True.
random_seed : int
Random seed.
save_loss : bool
Whether to save loss or not.
save_reconstrction : bool
Whether to save reconstruction or not.
device : str
Computation device.
num_cluster
Number of clusters.
gmm_reg_covar, gmm_n_init, gmm_max_iter, gmm_tol
Hyperparameters of the Gaussian Mixture model used for cluster assignment.
random_state
Random state passed to the Gaussian Mixture model.

"""
adata.X = sp.csr_matrix(adata.X)

if "highly_variable" in adata.var.columns:
adata_Vars = adata[:, adata.var["highly_variable"]]
else:
adata_Vars = adata

if verbose:
logger.info(f"Size of Input: {adata_Vars.shape}")

data = transfer_pytorch_data(adata_Vars, graph)

model = self.to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in tqdm(range(1, n_epochs + 1)):
model.train()
optimizer.zero_grad()
z, out = model(data.x, data.edge_index)
loss = F.mse_loss(data.x, out)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
optimizer.step()

model.eval()
z, out = model(data.x, data.edge_index)

STAGATE_rep = z.to("cpu").detach().numpy()
adata.obsm[key_added] = STAGATE_rep

if save_loss:
adata.uns["STAGATE_loss"] = loss
if save_reconstrction:
ReX = out.to("cpu").detach().numpy()
ReX[ReX < 0] = 0
adata.layers["STAGATE_ReX"] = ReX
x, edge_index_array = inputs
self._pretrain(x, edge_index_array, lr, weight_decay, epochs, gradient_clipping)

logger.info("Start post-processing")
sc.pp.neighbors(adata, use_rep="STAGATE")
sc.tl.umap(adata)
adata = mclust(adata, used_obsm="STAGATE", num_cluster=7)
self.adata = adata
logger.info("Fitting Gaussian Mixture model for cluster assignments.")
gmm = GaussianMixture(n_components=num_cluster, covariance_type="tied", n_init=gmm_n_init, tol=gmm_tol,
max_iter=gmm_max_iter, reg_covar=gmm_reg_covar, random_state=random_state)
self.clust_res = gmm.fit_predict(self.rep)

def predict(self, x: Optional[Any] = None):
"""Prediction function.
@@ -312,4 +288,4 @@ def predict(self, x: Optional[Any] = None):
Not used, for compatibility with :class:`BaseClusteringMethod`.

"""
return self.adata.obs["mclust"].values
return self.clust_res
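The new fit delegates pretraining to self._pretrain, inherited from BasePretrain, which is not shown in this diff. A plausible sketch of the caching contract it provides, assuming it keys off pretrain_path (the actual method in dance.modules.base may differ):

```python
import os

# Hypothetical sketch of BasePretrain._pretrain; the real method in
# dance.modules.base is not part of this diff and may differ.
def _pretrain(self, *args, **kwargs):
    # Reuse a previously saved representation when one exists ...
    if self.pretrain_path is not None and os.path.exists(self.pretrain_path):
        self.load_pretrained(self.pretrain_path)
        return
    # ... otherwise pretrain from scratch and persist if a path was given.
    self.pretrain(*args, **kwargs)
    if self.pretrain_path is not None:
        self.save_pretrained(self.pretrain_path)
```

Under this reading, Stagate only needs to implement pretrain, save_pretrained, and load_pretrained; the branching logic lives once in the base class.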
1 change: 0 additions & 1 deletion examples/spatial/spatial_domain/louvain.py
@@ -13,7 +13,6 @@
parser.add_argument("--n_components", type=int, default=50, help="Number of PC components.")
parser.add_argument("--neighbors", type=int, default=17, help="Number of neighbors.")
args = parser.parse_args()

set_seed(args.seed)

# Initialize model and get model specific preprocessing pipeline
1 change: 0 additions & 1 deletion examples/spatial/spatial_domain/spagcn.py
@@ -24,7 +24,6 @@
parser.add_argument("--lr", type=float, default=0.05, help="learning rate")
parser.add_argument("--random_state", type=int, default=100, help="")
args = parser.parse_args()

set_seed(args.random_state)

# Initialize model and get model specific preprocessing pipeline
12 changes: 5 additions & 7 deletions examples/spatial/spatial_domain/stagate.py
@@ -14,10 +14,9 @@
parser.add_argument("--hidden_dims", type=list, default=[512, 32], help="hidden dimensions")
parser.add_argument("--rad_cutoff", type=int, default=150, help="")
parser.add_argument("--seed", type=int, default=3, help="")
parser.add_argument("--n_epochs", type=int, default=1000, help="epochs")
parser.add_argument("--epochs", type=int, default=1000, help="epochs")
parser.add_argument("--high_variable_genes", type=int, default=3000, help="")
args = parser.parse_args()

set_seed(args.seed)

# Initialize model and get model specific preprocessing pipeline
@@ -28,13 +27,12 @@
dataloader = SpatialLIBDDataset(data_id=args.sample_number)
data = dataloader.load_data(transform=preprocessing_pipeline, cache=args.cache)
adj, y = data.get_data(return_type="default")
x = data.data.X.A
edge_list_array = np.vstack(np.nonzero(adj))

# Train and evaluate model
model = Stagate([args.high_variable_genes] + args.hidden_dims)
# TODO: extract nn model part of stagate and wrap with BaseClusteringMethod
# TODO: extract features from adata and directly pass to model.
model.fit(data.data, np.nonzero(adj), n_epochs=args.n_epochs)
pred = model.predict()
score = model.default_score_func(y.values.ravel(), pred)
score = model.fit_score((x, edge_list_array), y, epochs=args.epochs, random_state=args.seed)
print(f"ARI: {score:.4f}")
""" To reproduce Stagate on other samples, please refer to command lines belows:
NOTE: since the stagate method is unstable, you have to run at least 5 times to get
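The updated example calls model.fit_score, which is also defined outside this diff. Assuming it follows the usual convenience pattern on dance's base classes (fit, predict, then score predictions against the labels with the default clustering metric, ARI, matching the default_score_func call it replaces), it would look roughly like:

```python
from sklearn.metrics import adjusted_rand_score

# Hypothetical sketch of fit_score; the actual convenience method lives in
# dance.modules.base and may differ in signature and metric selection.
def fit_score(self, inputs, y, **fit_kwargs):
    self.fit(inputs, **fit_kwargs)
    return adjusted_rand_score(y, self.predict())
```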
1 change: 0 additions & 1 deletion examples/spatial/spatial_domain/stlearn.py
@@ -17,7 +17,6 @@
parser.add_argument("--n_components", type=int, default=50, help="the number of components in PCA")
parser.add_argument("--device", type=str, default="cuda", help="device for resnet extract feature")
args = parser.parse_args()

set_seed(args.seed)

# Initialize model and get model specific preprocessing pipeline