Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update stlearn example script to use dance data object #126

Merged
merged 7 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 46 additions & 104 deletions dance/modules/spatial/spatial_domain/stlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,10 @@

"""

# load kmeans from sklearn
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import adjusted_rand_score

from .louvain import Louvain


# kmeans for adata
def stKmeans(adata, n_clusters=19, init="k-means++", n_init=10, max_iter=300, tol=1e-4, algorithm='auto', verbose=False,
random_state=None, use_data='X_pca', key_added="X_pca_kmeans"):
# kmeans for gene expression data
kmeans = KMeans(n_clusters=n_clusters, init=init, n_init=n_init, max_iter=max_iter, tol=tol, algorithm=algorithm,
verbose=verbose, random_state=random_state).fit(adata.obsm[use_data])
adata.obs[key_added] = kmeans.labels_
return adata


def stPrepare(adata):
adata.obs['imagerow'] = adata.obs['x_pixel']
adata.obs['imagecol'] = adata.obs['y_pixel']
adata.obs['array_row'] = adata.obs['x']
adata.obs['array_col'] = adata.obs['y']
from dance.modules.spatial.spatial_domain.louvain import Louvain


class StKmeans:
Expand All @@ -38,163 +21,122 @@ class StKmeans:
Parameters
----------
n_clusters : int optional
the number of clusters to form as well as the number of centroids to generate.
The number of clusters to form as well as the number of centroids to generate.
init : str optional
method for initialization: {‘k-means++’, ‘random’}.
Method for initialization: {‘k-means++’, ‘random’}.
n_init : int optional
number of time the k-means algorithm will be run with different centroid seeds.
Number of time the k-means algorithm will be run with different centroid seeds.
The final results will be the best output of n_init consecutive runs in terms of inertia.
max_iter : int optional
maximum number of iterations of the k-means algorithm for a single run.
Maximum number of iterations of the k-means algorithm for a single run.
tol : float optional
relative tolerance with regards to Frobenius norm of the difference in
the cluster centers of two consecutive iterations to declare convergence.
Relative tolerance with regards to Frobenius norm of the difference in the cluster centers of two consecutive
iterations to declare convergence.
algorithm : str optional
{“lloyd”, “elkan”, “auto”, “full”}, by default "auto"
{“lloyd”, “elkan”, “auto”, “full”}, default is "auto".
verbose : bool optional
verbosity mode.
Verbosity.
random_state : int optional
determines random number generation for centroid initialization.
Use an int to make the randomness deterministic.
Determines random number generation for centroid initialization.
use_data : str optional
by default 'X_pca'.
Default "X_pca".
key_added : str optional
by default 'X_pca_kmeans'.
Default "X_pca_kmeans".

"""

def __init__(self, n_clusters=19, init="k-means++", n_init=10, max_iter=300, tol=1e-4, algorithm='auto',
verbose=False, random_state=None, use_data='X_pca', key_added="X_pca_kmeans"):
def __init__(self, n_clusters=19, init="k-means++", n_init=10, max_iter=300, tol=1e-4, algorithm="auto",
verbose=False, random_state=None, use_data="X_pca", key_added="X_pca_kmeans"):
self.use_data = use_data
self.key_added = key_added
self.model = KMeans(n_clusters=n_clusters, init=init, n_init=n_init, max_iter=max_iter, tol=tol,
algorithm=algorithm, verbose=verbose, random_state=random_state)

def fit(self, adata):
"""fit function for model training.
def fit(self, x):
"""Fit function for model training.

Parameters
----------
adata :
input data.

Returns
-------
None.
x
Input cell feature.

"""
self.model.fit(adata.obsm[self.use_data])
adata.obs[self.key_added] = self.model.labels_
self.model.fit(x)

def predict(self):
"""prediction function.

Parameters
----------

Returns
-------
self.model.labels_ :
predicted label.

"""
"""Prediction function."""
self.predict = self.model.labels_
self.y_pred = self.predict
return self.predict

def score(self, y_true):
"""score function.
"""Score function.

Parameters
----------
adata :
input data.
y_true
Cluster labels.

Returns
-------
self.score :
score.
float
Adjusted rand index score.

"""
from sklearn.metrics.cluster import adjusted_rand_score
score = adjusted_rand_score(y_true, self.y_pred)
print("ARI {}".format(adjusted_rand_score(y_true, self.y_pred)))
return score


class StLouvain:
"""StLouvain class."""

def __init__(self):
self.model = Louvain()
def __init__(self, resolution: float = 1):
self.model = Louvain(resolution)

def fit(self, adata, adj, partition=None, weight='weight', resolution=1., randomize=None, random_state=None):
"""fit function for model training.
def fit(self, adj, partition=None, weight="weight", randomize=None, random_state=None):
"""Fit function for model training.

Parameters
----------
adata :
input data.
adj :
adjacent matrix.
adj
Adjacent matrix.
partition : dict optional
a dictionary where keys are graph nodes and values the part the node
A dictionary where keys are graph nodes and values the part the node
belongs to
weight : str, optional
the key in graph to use as weight. Default to 'weight'
The key in graph to use as weight. Default to "weight"
resolution : float optional
resolution.
Resolution.
randomize : boolean, optional
Will randomize the node evaluation order and the community evaluation
order to get different partitions at each call
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`.
Returns
-------
None.
If int, random_state is the seed used by the random number generator; If RandomState instance, random_state
is the random number generator; If None, the random number generator is the RandomState instance used by
`np.random`.

"""
self.data = adata
self.model.fit(adata, adj, partition, weight, resolution, randomize, random_state)
self.model.fit(adj, partition, weight, randomize, random_state)

def predict(self):
"""prediction function.

Parameters
----------

Returns
-------
self.y_pred :
predicted label.

"""

"""Prediction function."""
self.y_pred = self.model.predict()
self.y_pred = [self.y_pred[i] for i in range(len(self.y_pred))]
self.data.obs['predict'] = self.y_pred
return self.y_pred

def score(self, y_true):
"""score function.
"""Score function.

Parameters
----------
adata :
input data.
y_true
Cluster labels.

Returns
-------
self.score :
score.
float
Adjusted rand index score.

"""
self.data.obs['ground'] = y_true
tempdata = self.data.obs.dropna()
from sklearn.metrics.cluster import adjusted_rand_score
score = adjusted_rand_score(tempdata['ground'], tempdata['predict'])
print("ARI {}".format(adjusted_rand_score(tempdata['ground'], tempdata['predict'])))
score = adjusted_rand_score(y_true, self.y_pred)
return score
3 changes: 3 additions & 0 deletions dance/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from dance.transforms import graph
from dance.transforms.cell_feature import CellPCA, WeightedFeaturePCA
from dance.transforms.interface import AnnDataTransform
from dance.transforms.spatial_feature import MorphologyFeature, SMEFeature

__all__ = [
"AnnDataTransform",
"CellPCA",
"MorphologyFeature",
"SMEFeature",
"WeightedFeaturePCA",
"graph",
] # yapf: disable
3 changes: 2 additions & 1 deletion dance/transforms/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from dance.transforms.graph.cell_feature_graph import CellFeatureGraph, PCACellFeatureGraph
from dance.transforms.graph.dstg_graph import DSTGraph
from dance.transforms.graph.neighbor_graph import NeighborGraph
from dance.transforms.graph.spatial_graph import SpaGCNGraph, SpaGCNGraph2D
from dance.transforms.graph.spatial_graph import SMEGraph, SpaGCNGraph, SpaGCNGraph2D

__all__ = [
"CellFeatureGraph",
"DSTGraph",
"NeighborGraph",
"PCACellFeatureGraph",
"SMEGraph",
"SpaGCNGraph",
"SpaGCNGraph2D",
] # yapf: disable
34 changes: 34 additions & 0 deletions dance/transforms/graph/spatial_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import pairwise_distances

from dance.transforms.base import BaseTransform
from dance.typing import Sequence
Expand Down Expand Up @@ -69,3 +71,35 @@ def __call__(self, data):
x = data.get_feature(channel=self.channel, channel_type="obsm", return_type="numpy")
data.data.obsp[self.out] = pairwise_distance(x.astype(np.float32), dist_func_id=0)
return data


class SMEGraph(BaseTransform):
"""Spatial Morphological gene Expression graph."""

def __init__(self, radius: float = 3, *,
channels: Sequence[str] = ("spatial", "spatial_pixel", "MorphologyFeature", "CellPCA"),
channel_types: Sequence[str] = ("obsm", "obsm", "obsm", "obsm"), **kwargs):
super().__init__(**kwargs)

self.radius = radius
self.channels = channels
self.channel_types = channel_types

def __call__(self, data):
xy = data.get_feature(return_type="numpy", channel=self.channels[0], channel_type=self.channel_types[0])
xy_pixel = data.get_feature(return_type="numpy", channel=self.channels[1], channel_type=self.channel_types[1])
morph_feat = data.get_feature(return_type="numpy", channel=self.channels[2], channel_type=self.channel_types[2])
gene_feat = data.get_feature(return_type="numpy", channel=self.channels[3], channel_type=self.channel_types[3])

reg_x = LinearRegression().fit(xy[:, 0:1], xy_pixel[:, 0:1])
reg_y = LinearRegression().fit(xy[:, 1:2], xy_pixel[:, 1:2])
unit = np.sqrt(reg_x.coef_**2 + reg_y.coef_**2)

# TODO: only captures topk, which are the ones that will be used by SMEFeature.
pdist = pairwise_distances(xy_pixel, metric="euclidean")
adj_p = np.where(pdist >= self.radius * unit, 0, 1)
adj_m = (1 - pairwise_distances(morph_feat, metric="cosine")).clip(0)
adj_g = 1 - pairwise_distances(gene_feat, metric="correlation")
adj = adj_p * adj_m * adj_g

data.data.obsp[self.out] = adj
35 changes: 2 additions & 33 deletions dance/transforms/graph_construct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import os
import pickle
import time
Expand All @@ -10,7 +9,6 @@
import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
import sklearn
import torch
from dgl import nn as dglnn
Expand All @@ -19,11 +17,11 @@
from scipy.spatial import distance, distance_matrix, minkowski_distance
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.metrics import pairwise_distances as pair
from sklearn.neighbors import KDTree, kneighbors_graph
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import normalize
from torch.nn import functional as F

import dance.transforms.preprocess
from dance import logger


Expand All @@ -38,7 +36,6 @@ def csr_cosine_similarity(input_csr_matrix):


def cosine_similarity_gene(input_matrix):
from sklearn.metrics.pairwise import cosine_similarity
res = cosine_similarity(input_matrix)
res = np.abs(res)
return res
Expand Down Expand Up @@ -886,34 +883,6 @@ def basic_feature_graph_propagation(g, layers=3, alpha=0.5, beta=0.5, cell_init=
return hcell[1:]


##############################
# neighbor_graph for stlearn #
##############################


def neighbors_get_adj(adata, n_neighbors=10, n_pcs=10, n_jobs=1, use_rep=None, knn=True, random_state=None,
method='umap', metric='euclidean', metric_kwargs={}, copy=False, obsp=None, neighbors_key=None):

sc.pp.neighbors(
adata,
n_neighbors=n_neighbors,
n_pcs=n_pcs,
use_rep=use_rep,
knn=knn,
random_state=random_state,
method=method,
metric=metric,
metric_kwds=metric_kwargs,
copy=copy,
)

choose_graph = getattr(sc._utils, "_choose_graph", None)
adjacency = choose_graph(adata, obsp, neighbors_key)

print("Created k-Nearest-Neighbor graph in adata.uns['neighbors'] ")
return adjacency


##### scGNN create adjacency, likely much overlap with above functions, nested function defs to avoid possible namespace conflicts


Expand Down
Loading