Skip to content

Commit

Permalink
DANCE V1.0.0 benchmark production (#371)
Browse files Browse the repository at this point in the history
* update graphsc for dgl 1.1

* update logging frequency

* remove unused line

* hotfix: more genereal gene split

* hotfix: save copy to raw

* update examples for benchmark

* set random seeds for all scripts in CTA, SD, CTD

* set seed; multiple runs

* rename ScDeepSortDatatse to CellTypeAnnotationDataset

* update the eval pipeline of imp models

* hotfix: split by batch_size

* minor update: minor doc fix; allow none idx_to_label input

* update benchmarking examples

* fix: scn predicted unsure label

* Update multimodality modules and modality predict examples.

* Update joint modality score functions.

* upload dance data tutorial

* minor fix: use majority voting results when majority_voting option is on

* Update examples for multimodality.

* fix cta dataest class name

* hotfix: update scripts for benchmark

* hotfix: undo incorrect overwrits by 517681c

* fix: cell type annotation dataset handling

* fix: follow ACTNN staircase exp decay as described in paper

* update ACTINN default params

* fix: parse device to spagcn

* Update multimodality.

* format code: sort imports; format example cmd; rnd_seed -> seed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Wenzhuo Tang <tangwen2@msu.edu>
Co-authored-by: Hongzhi Wen <wenhongz@msu.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Nov 29, 2023
1 parent a573eb1 commit f44b9d0
Show file tree
Hide file tree
Showing 60 changed files with 2,443 additions and 883 deletions.
4 changes: 2 additions & 2 deletions dance/datasets/multimodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def __init__(self, subtask, root="./data", preprocess=None, pkl_path=None, span=

def _raw_to_dance(self, raw_data):
train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label = self._maybe_preprocess(raw_data)
# Align matched cells
train_mod2 = train_mod2[train_label.to_df().values.argmax(1)]

mod1 = ad.concat((train_mod1, test_mod1))
mod2 = ad.concat((train_mod2, test_mod2))
Expand All @@ -272,8 +274,6 @@ def _raw_to_dance(self, raw_data):
mod2.obs_names = mod1.obs_names
train_size = train_mod1.shape[0]

# Align matched cells
train_mod2 = train_mod2[train_label.to_df().values.argmax(1)]
mod1.obsm["labels"] = np.concatenate([np.zeros(train_size), np.argmax(test_label.X.toarray(), 1)])

# Combine modalities into mudata
Expand Down
10 changes: 6 additions & 4 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def _load_scdeepsort_metadata():
return bench_url_dict, available_data


@register_dataset("scdeepsort")
class ScDeepSortDataset(BaseDataset):
@register_dataset("CellTypeAnnotation")
class CellTypeAnnotationDataset(BaseDataset):
_DISPLAY_ATTRS = ("species", "tissue", "train_dataset", "test_dataset")
ALL_URL_DICT: Dict[str, str] = {
"train_human_cell_atlas": "https://www.dropbox.com/s/1itq1pokplbqxhx?dl=1",
Expand Down Expand Up @@ -84,7 +84,7 @@ def download_all(self):

def get_all_filenames(self, filetype: str = "csv", feat_suffix: str = "data", label_suffix: str = "celltype"):
filenames = []
for id in self.train_dataset:
for id in self.train_dataset + self.test_dataset:
filenames.append(f"{self.species}_{self.tissue}{id}_{feat_suffix}.{filetype}")
filenames.append(f"{self.species}_{self.tissue}{id}_{label_suffix}.{filetype}")
return filenames
Expand Down Expand Up @@ -123,8 +123,10 @@ def is_complete_all(self):
def is_complete(self):
"""Check if benchmarking data is complete."""
for name in self.BENCH_URL_DICT:
if any(i not in name for i in (self.species, self.tissue)):
continue
filename = name[name.find(self.species):]
file_i = osp.join(self.data_dir, *name.split("_")[:2], filename)
file_i = osp.join(self.data_dir, *(name.split("_"))[:2], filename)
if not osp.exists(file_i):
logger.info(file_i)
logger.info(f"file {filename} doesn't exist")
Expand Down
64 changes: 36 additions & 28 deletions dance/modules/multi_modality/joint_embedding/dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import time
import warnings
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -618,6 +619,7 @@ def fit(self, train_loader, test_loader, total_loader, model_pre, args, criterio
test_like_max = test_loss.item()
reco_epoch_test = epoch
patience_epoch = 0
best_dict = deepcopy(self.state_dict())

if flag_break == 1:
print("containin NA")
Expand All @@ -636,6 +638,7 @@ def fit(self, train_loader, test_loader, total_loader, model_pre, args, criterio
break

duration = time.time() - start
self.load_state_dict(best_dict)

print('Finish training, total time is: ' + str(duration) + 's')
self.eval()
Expand Down Expand Up @@ -823,7 +826,7 @@ def fit(self, train_loader, test_loader, total_loader, first="RNA"):

used_cycle = used_cycle + 1

def score(self, dataloader):
def score(self, dataloader, metric='clustering'):
"""Score function to get score of prediction.
Parameters
Expand All @@ -844,45 +847,50 @@ def score(self, dataloader):
"""

self.model1.eval()
self.model2.eval()
if metric == 'clustering':
self.model1.eval()
self.model2.eval()

with torch.no_grad():
with torch.no_grad():

kmeans1 = KMeans(n_clusters=self.args.cluster1, n_init=5, random_state=200)
kmeans2 = KMeans(n_clusters=self.args.cluster2, n_init=5, random_state=200)
kmeans1 = KMeans(n_clusters=self.args.cluster1, n_init=5, random_state=200)
kmeans2 = KMeans(n_clusters=self.args.cluster2, n_init=5, random_state=200)

latent_code_rna = []
latent_code_atac = []
latent_code_rna = []
latent_code_atac = []

for batch_idx, (X1, _, size_factor1, X2, _, size_factor2) in enumerate(dataloader):
for batch_idx, (X1, _, size_factor1, X2, _, size_factor2) in enumerate(dataloader):

X1, size_factor1 = X1.to(self.args.device), size_factor1.to(self.args.device)
X2, size_factor2 = X2.to(self.args.device), size_factor2.to(self.args.device)
X1, size_factor1 = X1.to(self.args.device), size_factor1.to(self.args.device)
X2, size_factor2 = X2.to(self.args.device), size_factor2.to(self.args.device)

X1, size_factor1 = Variable(X1), Variable(size_factor1)
X2, size_factor2 = Variable(X2), Variable(size_factor2)
X1, size_factor1 = Variable(X1), Variable(size_factor1)
X2, size_factor2 = Variable(X2), Variable(size_factor2)

result1 = self.model1.inference(X1, size_factor1)
result2 = self.model2.inference(X2, size_factor2)
result1 = self.model1.inference(X1, size_factor1)
result2 = self.model2.inference(X2, size_factor2)

latent_code_rna.append(result1["latent_z1"].data.cpu().numpy())
latent_code_atac.append(result2["latent_z1"].data.cpu().numpy())
latent_code_rna.append(result1["latent_z1"].data.cpu().numpy())
latent_code_atac.append(result2["latent_z1"].data.cpu().numpy())

latent_code_rna = np.concatenate(latent_code_rna)
latent_code_atac = np.concatenate(latent_code_atac)
latent_code_rna = np.concatenate(latent_code_rna)
latent_code_atac = np.concatenate(latent_code_atac)

pred_z1 = kmeans1.fit_predict(latent_code_rna)
NMI_score1 = round(normalized_mutual_info_score(self.ground_truth1, pred_z1, average_method='max'), 3)
ARI_score1 = round(metrics.adjusted_rand_score(self.ground_truth1, pred_z1), 3)
pred_z1 = kmeans1.fit_predict(latent_code_rna)
NMI_score1 = round(normalized_mutual_info_score(self.ground_truth1, pred_z1, average_method='max'), 3)
ARI_score1 = round(metrics.adjusted_rand_score(self.ground_truth1, pred_z1), 3)

pred_z2 = kmeans1.fit_predict(latent_code_atac)
NMI_score2 = round(normalized_mutual_info_score(self.ground_truth1, pred_z2, average_method='max'), 3)
ARI_score2 = round(metrics.adjusted_rand_score(self.ground_truth1, pred_z2), 3)
pred_z2 = kmeans1.fit_predict(latent_code_atac)
NMI_score2 = round(normalized_mutual_info_score(self.ground_truth1, pred_z2, average_method='max'), 3)
ARI_score2 = round(metrics.adjusted_rand_score(self.ground_truth1, pred_z2), 3)

print('scRNA-ARI: ' + str(ARI_score1) + ' NMI: ' + str(NMI_score1) + ' scEpigenomics-ARI: ' +
str(ARI_score2) + ' NMI: ' + str(NMI_score2))
return NMI_score1, ARI_score1, NMI_score2, ARI_score2
print('scRNA-ARI: ' + str(ARI_score1) + ' NMI: ' + str(NMI_score1) + ' scEpigenomics-ARI: ' +
str(ARI_score2) + ' NMI: ' + str(NMI_score2))
return NMI_score1, ARI_score1, NMI_score2, ARI_score2
elif metric == 'openproblems':
raise NotImplementedError
else:
raise NotImplementedError

def _encodeBatch(self, total_loader):
"""Helper function to get latent representation, normalized representation and
Expand Down
27 changes: 21 additions & 6 deletions dance/modules/multi_modality/joint_embedding/jae.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

import os
from copy import deepcopy

import numpy as np
import torch
Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(self, args, num_celL_types, num_batches, num_phases, num_features):
print(num_celL_types, num_batches, num_phases, num_features)
self.args = args

def fit(self, inputs, cell_type, batch_label, phase_score):
def fit(self, inputs, cell_type, batch_label, phase_score, max_epochs=60):
"""Fit function for training.
Parameters
Expand Down Expand Up @@ -85,7 +86,7 @@ def fit(self, inputs, cell_type, batch_label, phase_score):
optimizer = torch.optim.Adam([{'params': self.model.parameters()}], lr=1e-4)
vals = []

for epoch in range(60):
for epoch in range(max_epochs):
self.model.train()
total_loss = [0] * 5
print('epoch', epoch)
Expand Down Expand Up @@ -127,10 +128,15 @@ def fit(self, inputs, cell_type, batch_label, phase_score):
if min(vals) == vals[-1]:
if not os.path.exists('models'):
os.mkdir('models')
torch.save(self.model.state_dict(), f'models/model_joint_embedding_{self.args.rnd_seed}.pth')
best_dict = deepcopy(self.model.state_dict())


# torch.save(self.model.state_dict(), f'models/model_joint_embedding_{self.args.rnd_seed}.pth')

if min(vals) != min(vals[-10:]):
print('Early stopped.')
break
self.model.load_state_dict(best_dict)

def to(self, device):
"""Performs device conversion.
Expand Down Expand Up @@ -191,7 +197,7 @@ def predict(self, inputs, idx):
prediction = self.model.encoder(inputs[idx])
return prediction

def score(self, inputs, idx, cell_type, batch_label=None, phase_score=None, metric='loss'):
def score(self, inputs, idx, cell_type, batch_label=None, phase_score=None, adata_sol=None, metric='loss'):
"""Score function to get score of prediction.
Parameters
Expand All @@ -206,6 +212,8 @@ def score(self, inputs, idx, cell_type, batch_label=None, phase_score=None, metr
Cell cycle phase labels.
metric : str optional
The type of evaluation metric, by default to be 'loss'.
adata_sol : anndata.AnnData optional
The solution anndata containing cell stypes, phase scores and batches. Required by 'openproblems' evaluation.
Returns
-------
Expand Down Expand Up @@ -234,7 +242,7 @@ def score(self, inputs, idx, cell_type, batch_label=None, phase_score=None, metr
loss4 = mse(output[3], phase_score[idx]).item()

return loss1, loss2, loss3, loss4
else:
elif metric == 'clustering':
emb = self.predict(inputs, idx).cpu().numpy()

kmeans = KMeans(n_clusters=10, n_init=5, random_state=200)
Expand All @@ -248,7 +256,14 @@ def score(self, inputs, idx, cell_type, batch_label=None, phase_score=None, metr
ARI_score = round(adjusted_rand_score(true_labels, pred_labels), 3)

# print('ARI: ' + str(ARI_score) + ' NMI: ' + str(NMI_score))
return NMI_score, ARI_score
return {'dance_nmi': NMI_score, 'dance_ari': ARI_score}
elif metric == 'openproblems':
emb = self.predict(inputs, idx).cpu().numpy()
assert adata_sol, 'adata_sol is required by `openproblems` evaluation but not provided.'
adata_sol.obsm['X_emb'] = emb
return integration_openproblems_evaluate(adata_sol)
else:
raise NotImplementedError


class JAE(nn.Module):
Expand Down
20 changes: 16 additions & 4 deletions dance/modules/multi_modality/joint_embedding/scmogcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import os
from copy import deepcopy

import dgl.nn.pytorch as dglnn
import numpy as np
Expand All @@ -19,6 +20,7 @@

from dance import logger
from dance.utils import SimpleIndexDataset
from dance.utils.metrics import *


def propagation_layer_combination(X, idx, wt, from_logits=True):
Expand Down Expand Up @@ -207,12 +209,15 @@ def fit(self, g_mod1, g_mod2, train_size, cell_type, batch_label, phase_score):
os.mkdir('models')
torch.save(self.model.state_dict(), f'models/model_joint_embedding_{self.args.rnd_seed}.pth')
weight_record = wt.detach()
best_dict = deepcopy(self.model.state_dict())

if min(vals) != min(vals[-10:]):
print('Early stopped.')
break

self.wt = weight_record
self.fitted = True
self.model.load_state_dict(best_dict)

def to(self, device):
"""Performs device conversion.
Expand Down Expand Up @@ -280,9 +285,9 @@ def predict(self, idx):
with torch.no_grad():
X = propagation_layer_combination(inputs, idx, wt)

return self.model.encoder(X)
return self.model.encoder(X)

def score(self, idx, cell_type, phase_score=None, metric='loss'):
def score(self, idx, cell_type, phase_score=None, adata_sol=None, metric='loss'):
"""Score function to get score of prediction.
Parameters
Expand Down Expand Up @@ -324,7 +329,7 @@ def score(self, idx, cell_type, phase_score=None, metric='loss'):
loss4 = mse(output[3], phase_score[idx]).item()

return loss1, loss2, loss3, loss4
else:
elif metric == 'clustering':
emb = self.predict(idx).cpu().numpy()
kmeans = KMeans(n_clusters=10, n_init=5, random_state=200)

Expand All @@ -336,7 +341,14 @@ def score(self, idx, cell_type, phase_score=None, metric='loss'):
ARI_score = round(adjusted_rand_score(true_labels, pred_labels), 3)

# print('ARI: ' + str(ARI_score) + ' NMI: ' + str(NMI_score))
return NMI_score, ARI_score
return {'dance_nmi': NMI_score, 'dance_ari': ARI_score}
elif metric == 'openproblems':
emb = self.predict(idx).cpu().numpy()
assert adata_sol, 'adata_sol is required by `openproblems` evaluation but not provided.'
adata_sol.obsm['X_emb'] = emb
return integration_openproblems_evaluate(adata_sol)
else:
raise NotImplementedError


class ScMoGCN(nn.Module):
Expand Down
34 changes: 23 additions & 11 deletions dance/modules/multi_modality/joint_embedding/scmvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import time
import warnings
from copy import deepcopy

import numpy as np
import scipy.stats as stats
Expand All @@ -32,6 +33,7 @@
from tqdm import trange

from dance.utils.loss import GMM_loss
from dance.utils.metrics import integration_openproblems_evaluate

warnings.filterwarnings("ignore", category=DeprecationWarning)

Expand Down Expand Up @@ -717,10 +719,10 @@ def fit(self, args, train, valid, final_rate, scale_factor, device):
break

if test_like_max > test_loss.item():
best_dict = deepcopy(self.state_dict())
test_like_max = test_loss.item()
epoch_count = 0

save_checkpoint(self)
best_dict = deepcopy(self.state_dict())

print(
str(epoch) + " " + str(loss.item()) + " " + str(test_loss.item()) + " " +
Expand All @@ -747,8 +749,10 @@ def fit(self, args, train, valid, final_rate, scale_factor, device):
duration = time.time() - start
print('Finish training, total time: ' + str(duration) + 's' + " epoch: " + str(reco_epoch_test) + " status: " +
status)
self.load_state_dict(best_dict)


load_checkpoint('./saved_model/model_best.pth.tar', self, device)
# load_checkpoint('./saved_model/model_best.pth.tar', self, device)

def predict(self, X1, X2, out='Z', device='cpu'):
"""Predict function to get prediction.
Expand Down Expand Up @@ -795,7 +799,7 @@ def predict(self, X1, X2, out='Z', device='cpu'):
# output.append(self.get_gamma(z)[0].cpu().detach())
return None

def score(self, X1, X2, labels):
def score(self, X1, X2, labels, adata_sol=None, metric='clustering'):
"""Score function to get score of prediction.
Parameters
Expand All @@ -816,16 +820,24 @@ def score(self, X1, X2, labels):
"""

emb = self.predict(X1, X2).cpu().numpy()
kmeans = KMeans(n_clusters=10, n_init=5, random_state=200)
if metric == 'clustering':
emb = self.predict(X1, X2).cpu().numpy()
kmeans = KMeans(n_clusters=10, n_init=5, random_state=200)

true_labels = labels.numpy()
pred_labels = kmeans.fit_predict(emb)
true_labels = labels.numpy()
pred_labels = kmeans.fit_predict(emb)

NMI_score = round(normalized_mutual_info_score(true_labels, pred_labels, average_method='max'), 3)
ARI_score = round(adjusted_rand_score(true_labels, pred_labels), 3)
NMI_score = round(normalized_mutual_info_score(true_labels, pred_labels, average_method='max'), 3)
ARI_score = round(adjusted_rand_score(true_labels, pred_labels), 3)

return NMI_score, ARI_score
return {'dance_nmi': NMI_score, 'dance_ari': ARI_score}
elif metric == 'openproblems':
emb = self.predict(X1, X2).cpu().numpy()
assert adata_sol, 'adata_sol is required by `openproblems` evaluation but not provided.'
adata_sol.obsm['X_emb'] = emb
return integration_openproblems_evaluate(adata_sol)
else:
raise NotImplementedError


class ProductOfExperts(nn.Module):
Expand Down
Loading

0 comments on commit f44b9d0

Please sign in to comment.