Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update clustering examples to use data object #95

Merged
merged 4 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,10 @@ def load_data(self):
assert self.is_complete()

data_mat = h5py.File(f"{self.data_dir}/{self.dataset}.h5", "r")
self.X = np.array(data_mat["X"])
self.Y = np.array(data_mat["Y"])
return self
X = np.array(data_mat["X"])
adata = ad.AnnData(X, dtype=np.float32)
Y = np.array(data_mat["Y"])
return adata, Y


class PretrainDataset(Dataset):
Expand Down
2 changes: 1 addition & 1 deletion dance/modules/single_modality/clustering/graphsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def fit(self, n_epochs, dataloader, n_clusters, lr, cluster=["KMeans"]):
y.extend(blocks[-1].dstdata["label"].cpu().numpy())
order.extend(blocks[-1].dstdata["order"].cpu().numpy())

adj = g.adjacency_matrix().to_dense()
adj = g.adjacency_matrix().to_dense().to(device)
adj = adj[g.dstnodes()]
pos_weight = torch.Tensor([float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()])
factor = float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
Expand Down
19 changes: 11 additions & 8 deletions dance/transforms/graph_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,9 @@ def construct_modality_prediction_graph(dataset, **kwargs):
return g


def make_graph(X, Y=None, threshold=0, dense_dim=100, gene_data={}, normalize_weights="log_per_cell", nb_edges=1,
node_features="scale", same_edge_values=False, edge_norm=True):
"""Create DGL graph for graph-sc.
def cell_gene_graph(data, threshold=0, dense_dim=100, gene_data={}, normalize_weights="log_per_cell", nb_edges=1,
node_features="scale", same_edge_values=False, edge_norm=True):
"""Create DGL cell-gene graph for graph-sc.

Parameters
----------
Expand Down Expand Up @@ -665,6 +665,7 @@ def make_graph(X, Y=None, threshold=0, dense_dim=100, gene_data={}, normalize_we
constructed dgl graph.

"""
X, Y = data.get_x_y()
num_genes = X.shape[1]

graph = dgl.DGLGraph()
Expand Down Expand Up @@ -743,7 +744,7 @@ def make_graph(X, Y=None, threshold=0, dense_dim=100, gene_data={}, normalize_we

graph.add_edges(graph.nodes(), graph.nodes(),
{'weight': torch.ones(graph.number_of_nodes(), dtype=torch.float).unsqueeze(1)})
return graph
data.data.uns["graph"] = graph


def external_data_connections(graph, gene_data, X, gene_idx, cell_idx):
Expand Down Expand Up @@ -814,7 +815,7 @@ def external_data_connections(graph, gene_data, X, gene_idx, cell_idx):
return graph


def get_adj(count, k=15, pca_dim=50, mode="connectivity"):
def get_adj(data, k=15, pca_dim=50, mode="connectivity"):
"""Construct adjacency matrix for scTAG.

Parameters
Expand All @@ -834,6 +835,7 @@ def get_adj(count, k=15, pca_dim=50, mode="connectivity"):
prediction of leiden.

"""
count = data.get_x()
if pca_dim:
countp = PCA(n_components=pca_dim).fit_transform(count)
else:
Expand All @@ -842,8 +844,8 @@ def get_adj(count, k=15, pca_dim=50, mode="connectivity"):
adj = A.toarray()
normalized_D = degree_power(adj, -0.5)
adj_n = normalized_D.dot(adj).dot(normalized_D)

return adj, adj_n
data.data.obsp["adj"] = adj
data.data.obsp["adj_n"] = adj_n


def degree_power(A, k):
Expand Down Expand Up @@ -1052,7 +1054,7 @@ def stAdjConstruct(st_scale, st_label, adj_data, k_filter=1):
############################
# scDSC #
############################
def construct_graph_sc(fname, features, label, method, topk):
def construct_graph_scdsc(fname, data, method, topk):
"""Graph construction function for scDSC.

Parameters
Expand All @@ -1073,6 +1075,7 @@ def construct_graph_sc(fname, features, label, method, topk):
None.

"""
features, label = data.get_train_data()
num = len(label)
if topk == None:
topk = 0
Expand Down
100 changes: 15 additions & 85 deletions dance/transforms/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,38 +2039,15 @@ def selectTopGenes(Loadings, dims, DimGenes, maxGenes):
return (selgene)


def filter_data(X, highly_genes=500):
"""Remove less variable genes.

Parameters
----------
X :
cell-gene data.
highly_genes : int optional
number of chosen genes.

Returns
-------
genes_idx :
index of chosen genes
cells_idx :
index of chosen cells

"""

X = np.ceil(X).astype(int)
adata = sc.AnnData(X, dtype=np.float32)

def filter_data(data, highly_genes=500):
adata = data.data.copy()
sc.pp.filter_genes(adata, min_counts=3)
sc.pp.filter_cells(adata, min_counts=1)
sc.pp.normalize_per_cell(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=4, flavor='cell_ranger', min_disp=0.5,
n_top_genes=highly_genes, subset=True)
genes_idx = np.array(adata.var_names.tolist()).astype(int)
cells_idx = np.array(adata.obs_names.tolist()).astype(int)

return genes_idx, cells_idx
data._data = data.data[adata.obs_names, adata.var_names]


def generate_random_pair(y, label_cell_indx, num, error_rate=0):
Expand Down Expand Up @@ -2121,9 +2098,8 @@ def check_ind(ind1, ind2, ind_list1, ind_list2):
return ml_ind1, ml_ind2, cl_ind1, cl_ind2, error_num


def geneSelection(data, threshold=0, atleast=10, yoffset=.02, xoffset=5, decay=1.5, n=None, plot=True, markers=None,
genes=None, figsize=(6, 3.5), markeroffsets=None, labelsize=10, alpha=1, verbose=1):
if sparse.issparse(data):
def geneSelection(data, threshold=0, atleast=10, yoffset=.02, xoffset=5, decay=1.5, n=None, verbose=1):
if sp.issparse(data):
zeroRate = 1 - np.squeeze(np.array((data > threshold).mean(axis=0)))
A = data.multiply(data > threshold)
A.data = np.log2(A.data)
Expand Down Expand Up @@ -2164,50 +2140,6 @@ def geneSelection(data, threshold=0, atleast=10, yoffset=.02, xoffset=5, decay=1
nonan = ~np.isnan(zeroRate)
selected = np.zeros_like(zeroRate).astype(bool)
selected[nonan] = zeroRate[nonan] > np.exp(-decay * (meanExpr[nonan] - xoffset)) + yoffset

if plot:
if figsize is not None:
plt.figure(figsize=figsize)
plt.ylim([0, 1])
if threshold > 0:
plt.xlim([np.log2(threshold), np.ceil(np.nanmax(meanExpr))])
else:
plt.xlim([0, np.ceil(np.nanmax(meanExpr))])
x = np.arange(plt.xlim()[0], plt.xlim()[1] + .1, .1)
y = np.exp(-decay * (x - xoffset)) + yoffset
if decay == 1:
plt.text(.4, 0.2, '{} genes selected\ny = exp(-x+{:.2f})+{:.2f}'.format(np.sum(selected), xoffset, yoffset),
color='k', fontsize=labelsize, transform=plt.gca().transAxes)
else:
plt.text(
.4, 0.2,
'{} genes selected\ny = exp(-{:.1f}*(x-{:.2f}))+{:.2f}'.format(np.sum(selected), decay, xoffset,
yoffset), color='k', fontsize=labelsize,
transform=plt.gca().transAxes)

plt.plot(x, y, color=sns.color_palette()[1], linewidth=2)
xy = np.concatenate((np.concatenate((x[:, None], y[:, None]), axis=1), np.array([[plt.xlim()[1], 1]])))
t = plt.matplotlib.patches.Polygon(xy, color=sns.color_palette()[1], alpha=.4)
plt.gca().add_patch(t)

plt.scatter(meanExpr, zeroRate, s=1, alpha=alpha, rasterized=True)
if threshold == 0:
plt.xlabel('Mean log2 nonzero expression')
plt.ylabel('Frequency of zero expression')
else:
plt.xlabel('Mean log2 nonzero expression')
plt.ylabel('Frequency of near-zero expression')
plt.tight_layout()

if markers is not None and genes is not None:
if markeroffsets is None:
markeroffsets = [(0, 0) for g in markers]
for num, g in enumerate(markers):
i = np.where(genes == g)[0]
plt.scatter(meanExpr[i], zeroRate[i], s=10, color='k')
dx, dy = markeroffsets[num]
plt.text(meanExpr[i] + dx + .1, zeroRate[i] + dy, g, color='k', fontsize=labelsize)

return selected


Expand All @@ -2229,29 +2161,27 @@ def load_graph(path, data):
return adj


def normalize_adata(adata, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True):
def normalize_adata(data, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True):
if filter_min_counts:
sc.pp.filter_genes(adata, min_counts=1)
sc.pp.filter_cells(adata, min_counts=1)
sc.pp.filter_genes(data.data, min_counts=1)
sc.pp.filter_cells(data.data, min_counts=1)

if size_factors or normalize_input or logtrans_input:
adata.raw = adata.copy()
data.data.raw = data.data.copy()
else:
adata.raw = adata
data.data.raw = data.data

if size_factors:
sc.pp.normalize_per_cell(adata)
adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
sc.pp.normalize_per_cell(data.data)
data.data.obs['size_factors'] = data.data.obs.n_counts / np.median(data.data.obs.n_counts)
else:
adata.obs['size_factors'] = 1.0
data.data.obs['size_factors'] = 1.0

if logtrans_input:
sc.pp.log1p(adata)
sc.pp.log1p(data.data)

if normalize_input:
sc.pp.scale(adata)

return adata
sc.pp.scale(data.data)


def row_normalize(mx):
Expand Down
38 changes: 16 additions & 22 deletions examples/single_modality/clustering/graphsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,36 @@
import numpy as np
import torch

from dance.data import Data
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.graphsc import *
from dance.transforms.graph_construct import make_graph
from dance.transforms.graph_construct import cell_gene_graph
from dance.transforms.preprocess import filter_data
from dance.utils import set_seed


def pipeline(**args):
data = ClusteringDataset(args['data_dir'], args['dataset']).load_data()
X = data.X
Y = data.Y
adata, labels = ClusteringDataset(args['data_dir'], args['dataset']).load_data()
adata.obsm["labels"] = labels
adata = adata[:100, :]
RemyLau marked this conversation as resolved.
Show resolved Hide resolved
data = Data(adata, train_size=adata.n_obs)
data.set_config(label_channel="labels")
RemyLau marked this conversation as resolved.
Show resolved Hide resolved

filter_data(data, highly_genes=args['nb_genes'])
cell_gene_graph(data, dense_dim=args['in_feats'], node_features=args['node_features'],
normalize_weights=args['normalize_weights'], same_edge_values=args['same_edge_values'],
edge_norm=args['edge_norm'])
data.set_config(feature_channel="graph", feature_channel_type="uns")
graph, Y = data.get_train_data()
n_clusters = len(np.unique(Y))

genes_idx, cells_idx = filter_data(X, highly_genes=args['nb_genes'])
X = X[cells_idx][:, genes_idx]
Y = Y[cells_idx]

t0 = time.time()
graph = make_graph(X, Y, dense_dim=args['in_feats'], node_features=args['node_features'],
normalize_weights=args['normalize_weights'], same_edge_values=args['same_edge_values'],
edge_norm=args['edge_norm'])

labels = graph.ndata["label"]
train_ids = np.where(labels != -1)[0]
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(args['n_layers'])
dataloader = dgl.dataloading.NodeDataLoader(graph, train_ids, sampler, batch_size=args['batch_size'], shuffle=True,
drop_last=False, num_workers=args['num_workers'])

t1 = time.time()

for run in range(args['num_run']):
t_start = time.time()
torch.manual_seed(run)
torch.cuda.manual_seed_all(run)
np.random.seed(run)
random.seed(run)

set_seed(run)
model = GraphSC(Namespace(**args))
model.fit(args['epochs'], dataloader, n_clusters, args['learning_rate'], cluster=["KMeans", "Leiden"])
pred = model.predict(n_clusters, cluster=["KMeans", "Leiden"])
Expand Down
Loading