Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update clustering examples to use data object #95

Merged
merged 4 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,10 @@ def load_data(self):
assert self.is_complete()

data_mat = h5py.File(f"{self.data_dir}/{self.dataset}.h5", "r")
self.X = np.array(data_mat["X"])
self.Y = np.array(data_mat["Y"])
return self
X = np.array(data_mat["X"])
adata = ad.AnnData(X, dtype=np.float32)
Y = np.array(data_mat["Y"])
return adata, Y


class PretrainDataset(Dataset):
Expand Down
5 changes: 3 additions & 2 deletions dance/modules/single_modality/clustering/graphsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class GraphSC:
def __init__(self, args):
super().__init__()
self.args = args
self.model = GCNAE(args).to(get_device(args.use_cpu))
self.device = get_device(args.use_cpu)
self.model = GCNAE(args).to(self.device)

def fit(self, n_epochs, dataloader, n_clusters, lr, cluster=["KMeans"]):
"""Train graph-sc.
Expand Down Expand Up @@ -88,7 +89,7 @@ def fit(self, n_epochs, dataloader, n_clusters, lr, cluster=["KMeans"]):
y.extend(blocks[-1].dstdata["label"].cpu().numpy())
order.extend(blocks[-1].dstdata["order"].cpu().numpy())

adj = g.adjacency_matrix().to_dense()
adj = g.adjacency_matrix().to_dense().to(device)
adj = adj[g.dstnodes()]
pos_weight = torch.Tensor([float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()])
factor = float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
Expand Down
73 changes: 12 additions & 61 deletions dance/transforms/graph_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,39 +632,10 @@ def construct_modality_prediction_graph(dataset, **kwargs):
return g


def make_graph(X, Y=None, threshold=0, dense_dim=100, gene_data={}, normalize_weights="log_per_cell", nb_edges=1,
node_features="scale", same_edge_values=False, edge_norm=True):
"""Create DGL graph for graph-sc.
def cell_gene_graph(data, threshold=0, dense_dim=100, gene_data={}, normalize_weights="log_per_cell", nb_edges=1,
node_features="scale", same_edge_values=False, edge_norm=True):

Parameters
----------
X :
input cell-gene features.
Y : list optional
true labels.
threshold : int optional
minimum value of selected feature.
dense_dim : int optional
dense dimension for PCA.
gene_data : dict optional
external gene data.
normalize_weights : str optional
weights normalization method.
nb_edges : float, optional
proportion of edges selected.
node_features : str optional
type of node features.
same_edge_values : bool optional
set identical edge value or not.
edge_norm : bool optional
perform edge normalization or not.

Returns
-------
graph :
constructed dgl graph.

"""
X = data.get_x()
num_genes = X.shape[1]

graph = dgl.DGLGraph()
Expand Down Expand Up @@ -722,11 +693,7 @@ def make_graph(X, Y=None, threshold=0, dense_dim=100, gene_data={}, normalize_we

graph.ndata['order'] = torch.tensor([-1] * num_genes + list(np.arange(len(X))),
dtype=torch.long) # [gene_num+train_num]
if Y is not None:
graph.ndata['label'] = torch.tensor([-1] * num_genes + list(np.array(Y).astype(int)),
dtype=torch.long) # [gene_num+train_num]
else:
graph.ndata['label'] = torch.tensor([-1] * num_genes + [np.nan] * len(X))
graph.ndata['label'] = torch.tensor([-1] * num_genes + [np.nan] * len(X))
nb_edges = graph.num_edges()

if len(gene_data) != 0 and len(gene_data['gene1']) > 0:
Expand All @@ -743,7 +710,7 @@ def make_graph(X, Y=None, threshold=0, dense_dim=100, gene_data={}, normalize_we

graph.add_edges(graph.nodes(), graph.nodes(),
{'weight': torch.ones(graph.number_of_nodes(), dtype=torch.float).unsqueeze(1)})
return graph
data.data.uns["graph"] = graph


def external_data_connections(graph, gene_data, X, gene_idx, cell_idx):
Expand Down Expand Up @@ -814,26 +781,9 @@ def external_data_connections(graph, gene_data, X, gene_idx, cell_idx):
return graph


def get_adj(count, k=15, pca_dim=50, mode="connectivity"):
"""Conctruct adjacency matrix for scTAG.

Parameters
----------
count :
input cell-gene features.
k : int optional
number of neighbors for each sample in k-neighbors graph.
pca_dim : int optional
number of components in PCA.
mode : str optional
type of returned adjacency matrix.

Returns
-------
pred : list
prediction of leiden.

"""
def get_adj(data, k=15, pca_dim=50, mode="connectivity"):
"""Conctruct adjacency matrix for scTAG."""
count = data.get_x()
if pca_dim:
countp = PCA(n_components=pca_dim).fit_transform(count)
else:
Expand All @@ -842,8 +792,8 @@ def get_adj(count, k=15, pca_dim=50, mode="connectivity"):
adj = A.toarray()
normalized_D = degree_power(adj, -0.5)
adj_n = normalized_D.dot(adj).dot(normalized_D)

return adj, adj_n
data.data.obsp["adj"] = adj
data.data.obsp["adj_n"] = adj_n


def degree_power(A, k):
Expand Down Expand Up @@ -1052,7 +1002,7 @@ def stAdjConstruct(st_scale, st_label, adj_data, k_filter=1):
############################
# scDSC #
############################
def construct_graph_sc(fname, features, label, method, topk):
def construct_graph_scdsc(fname, data, method, topk):
"""Graph construction function for scDSC.

Parameters
Expand All @@ -1073,6 +1023,7 @@ def construct_graph_sc(fname, features, label, method, topk):
None.

"""
features, label = data.get_train_data()
num = len(label)
if topk == None:
topk = 0
Expand Down
100 changes: 15 additions & 85 deletions dance/transforms/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,38 +2039,15 @@ def selectTopGenes(Loadings, dims, DimGenes, maxGenes):
return (selgene)


def filter_data(X, highly_genes=500):
"""Remove less variable genes.

Parameters
----------
X :
cell-gene data.
highly_genes : int optional
number of chosen genes.

Returns
-------
genes_idx :
index of chosen genes
cells_idx :
index of chosen cells

"""

X = np.ceil(X).astype(int)
adata = sc.AnnData(X, dtype=np.float32)

def filter_data(data, highly_genes=500):
adata = data.data.copy()
sc.pp.filter_genes(adata, min_counts=3)
sc.pp.filter_cells(adata, min_counts=1)
sc.pp.normalize_per_cell(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=4, flavor='cell_ranger', min_disp=0.5,
n_top_genes=highly_genes, subset=True)
genes_idx = np.array(adata.var_names.tolist()).astype(int)
cells_idx = np.array(adata.obs_names.tolist()).astype(int)

return genes_idx, cells_idx
data._data = data.data[adata.obs_names, adata.var_names]


def generate_random_pair(y, label_cell_indx, num, error_rate=0):
Expand Down Expand Up @@ -2121,9 +2098,8 @@ def check_ind(ind1, ind2, ind_list1, ind_list2):
return ml_ind1, ml_ind2, cl_ind1, cl_ind2, error_num


def geneSelection(data, threshold=0, atleast=10, yoffset=.02, xoffset=5, decay=1.5, n=None, plot=True, markers=None,
genes=None, figsize=(6, 3.5), markeroffsets=None, labelsize=10, alpha=1, verbose=1):
if sparse.issparse(data):
def geneSelection(data, threshold=0, atleast=10, yoffset=.02, xoffset=5, decay=1.5, n=None, verbose=1):
if sp.issparse(data):
zeroRate = 1 - np.squeeze(np.array((data > threshold).mean(axis=0)))
A = data.multiply(data > threshold)
A.data = np.log2(A.data)
Expand Down Expand Up @@ -2164,50 +2140,6 @@ def geneSelection(data, threshold=0, atleast=10, yoffset=.02, xoffset=5, decay=1
nonan = ~np.isnan(zeroRate)
selected = np.zeros_like(zeroRate).astype(bool)
selected[nonan] = zeroRate[nonan] > np.exp(-decay * (meanExpr[nonan] - xoffset)) + yoffset

if plot:
if figsize is not None:
plt.figure(figsize=figsize)
plt.ylim([0, 1])
if threshold > 0:
plt.xlim([np.log2(threshold), np.ceil(np.nanmax(meanExpr))])
else:
plt.xlim([0, np.ceil(np.nanmax(meanExpr))])
x = np.arange(plt.xlim()[0], plt.xlim()[1] + .1, .1)
y = np.exp(-decay * (x - xoffset)) + yoffset
if decay == 1:
plt.text(.4, 0.2, '{} genes selected\ny = exp(-x+{:.2f})+{:.2f}'.format(np.sum(selected), xoffset, yoffset),
color='k', fontsize=labelsize, transform=plt.gca().transAxes)
else:
plt.text(
.4, 0.2,
'{} genes selected\ny = exp(-{:.1f}*(x-{:.2f}))+{:.2f}'.format(np.sum(selected), decay, xoffset,
yoffset), color='k', fontsize=labelsize,
transform=plt.gca().transAxes)

plt.plot(x, y, color=sns.color_palette()[1], linewidth=2)
xy = np.concatenate((np.concatenate((x[:, None], y[:, None]), axis=1), np.array([[plt.xlim()[1], 1]])))
t = plt.matplotlib.patches.Polygon(xy, color=sns.color_palette()[1], alpha=.4)
plt.gca().add_patch(t)

plt.scatter(meanExpr, zeroRate, s=1, alpha=alpha, rasterized=True)
if threshold == 0:
plt.xlabel('Mean log2 nonzero expression')
plt.ylabel('Frequency of zero expression')
else:
plt.xlabel('Mean log2 nonzero expression')
plt.ylabel('Frequency of near-zero expression')
plt.tight_layout()

if markers is not None and genes is not None:
if markeroffsets is None:
markeroffsets = [(0, 0) for g in markers]
for num, g in enumerate(markers):
i = np.where(genes == g)[0]
plt.scatter(meanExpr[i], zeroRate[i], s=10, color='k')
dx, dy = markeroffsets[num]
plt.text(meanExpr[i] + dx + .1, zeroRate[i] + dy, g, color='k', fontsize=labelsize)

return selected


Expand All @@ -2229,29 +2161,27 @@ def load_graph(path, data):
return adj


def normalize_adata(adata, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True):
def normalize_adata(data, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True):
if filter_min_counts:
sc.pp.filter_genes(adata, min_counts=1)
sc.pp.filter_cells(adata, min_counts=1)
sc.pp.filter_genes(data.data, min_counts=1)
sc.pp.filter_cells(data.data, min_counts=1)

if size_factors or normalize_input or logtrans_input:
adata.raw = adata.copy()
data.data.raw = data.data.copy()
else:
adata.raw = adata
data.data.raw = data.data

if size_factors:
sc.pp.normalize_per_cell(adata)
adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
sc.pp.normalize_per_cell(data.data)
data.data.obs['size_factors'] = data.data.obs.n_counts / np.median(data.data.obs.n_counts)
else:
adata.obs['size_factors'] = 1.0
data.data.obs['size_factors'] = 1.0

if logtrans_input:
sc.pp.log1p(adata)
sc.pp.log1p(data.data)

if normalize_input:
sc.pp.scale(adata)

return adata
sc.pp.scale(data.data)


def row_normalize(mx):
Expand Down
36 changes: 14 additions & 22 deletions examples/single_modality/clustering/graphsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,34 @@
import numpy as np
import torch

from dance.data import Data
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.graphsc import *
from dance.transforms.graph_construct import make_graph
from dance.transforms.graph_construct import cell_gene_graph
from dance.transforms.preprocess import filter_data
from dance.utils import set_seed


def pipeline(**args):
data = ClusteringDataset(args['data_dir'], args['dataset']).load_data()
X = data.X
Y = data.Y
adata, labels = ClusteringDataset(args['data_dir'], args['dataset']).load_data()
adata.obsm["labels"] = labels
data = Data(adata, train_size="all")

filter_data(data, highly_genes=args['nb_genes'])
cell_gene_graph(data, dense_dim=args['in_feats'], node_features=args['node_features'],
normalize_weights=args['normalize_weights'], same_edge_values=args['same_edge_values'],
edge_norm=args['edge_norm'])
data.set_config(feature_channel="graph", feature_channel_type="uns", label_channel="labels")
graph, Y = data.get_train_data()
n_clusters = len(np.unique(Y))

genes_idx, cells_idx = filter_data(X, highly_genes=args['nb_genes'])
X = X[cells_idx][:, genes_idx]
Y = Y[cells_idx]

t0 = time.time()
graph = make_graph(X, Y, dense_dim=args['in_feats'], node_features=args['node_features'],
normalize_weights=args['normalize_weights'], same_edge_values=args['same_edge_values'],
edge_norm=args['edge_norm'])

labels = graph.ndata["label"]
train_ids = np.where(labels != -1)[0]
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['n_layers'])
dataloader = dgl.dataloading.NodeDataLoader(graph, train_ids, sampler, batch_size=args['batch_size'], shuffle=True,
drop_last=False, num_workers=args['num_workers'])

t1 = time.time()

for run in range(args['num_run']):
t_start = time.time()
torch.manual_seed(run)
torch.cuda.manual_seed_all(run)
np.random.seed(run)
random.seed(run)

set_seed(run)
model = GraphSC(Namespace(**args))
model.fit(args['epochs'], dataloader, n_clusters, args['learning_rate'], cluster=["KMeans", "Leiden"])
pred = model.predict(n_clusters, cluster=["KMeans", "Leiden"])
Expand Down
Loading