Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor multi dataset object for matching modality tasks #296

Merged
merged 10 commits into from
Jun 26, 2023
2 changes: 1 addition & 1 deletion dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, data: Union[anndata.AnnData, mudata.MuData], train_size: Opti

# Store data and pass through some main properties over
self._data = data
for prop in self._DATA_CHANNELS + ["X"]:
for prop in self._DATA_CHANNELS + ["X", "mod"]:
assert not hasattr(self, prop)
setattr(self, prop, getattr(data, prop))

Expand Down
125 changes: 79 additions & 46 deletions dance/datasets/multimodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,21 @@ def mod_data_paths(self) -> List[str]:
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_mod1.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_mod2.h5ad"),
]
else:
elif self.TASK == "predict_modality":
paths = [
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod1.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod2.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod1.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod2.h5ad"),
]
elif self.TASK == "match_modality":
paths = [
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod1.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_mod2.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_sol.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod1.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_mod2.h5ad"),
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_sol.h5ad"),
]
return paths

Expand Down Expand Up @@ -197,82 +206,106 @@ class ModalityMatchingDataset(MultiModalityDataset):
}
AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP))

def __init__(self, subtask, root="./data"):
def __init__(self, subtask, root="./data", preprocess=None, pkl_path=None):
# TODO: factor out preprocess
self.preprocess = preprocess
self.pkl_path = pkl_path
super().__init__(subtask, root)
self.preprocessed = False

def load_sol(self):
assert (self.loaded)
self.train_sol = ad.read_h5ad(
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_train_sol.h5ad"))
self.test_sol = ad.read_h5ad(
osp.join(self.root, self.subtask, f"{self.subtask}.censor_dataset.output_test_sol.h5ad"))
self.modalities[1] = self.modalities[1][self.train_sol.to_df().values.argmax(1)]
return self
def _load_raw_data(self):
    """Read every modality file listed in ``self.mod_data_paths``.

    Returns
    -------
    list
        One object per path, as returned by ``ad.read_h5ad``, in the same
        order as ``self.mod_data_paths``.

    """
    # TODO: merge to MultiModalityDataset?
    loaded = []
    for path in self.mod_data_paths:
        # Log before reading so a slow or failing read is attributable.
        logger.info(f"Loading {path}")
        adata = ad.read_h5ad(path)
        loaded.append(adata)
    return loaded

def _raw_to_dance(self, raw_data):
    """Assemble the (optionally preprocessed) raw modalities into a ``Data`` object.

    Parameters
    ----------
    raw_data
        Six-item sequence that ``_maybe_preprocess`` unpacks as: train mod1,
        train mod2, train solution, test mod1, test mod2, test solution.

    Returns
    -------
    Data
        Wraps a ``MuData`` with keys ``"mod1"`` and ``"mod2"``. Train cells
        come first, followed by test cells; ``train_size`` is the number of
        rows in the train mod1 object.

    """
    train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label = self._maybe_preprocess(raw_data)

    # Align matched cells. This has to happen BEFORE the concatenation below:
    # reindexing ``train_mod2`` after ``mod2`` has been built leaves ``mod2``
    # unaligned, because the reindexed object is never used again.
    # NOTE(review): presumably row i of the train solution marks (via its
    # max entry) the mod2 cell matching mod1 cell i -- confirm against the
    # dataset specification.
    train_mod2 = train_mod2[train_label.to_df().values.argmax(1)]

    mod1 = ad.concat((train_mod1, test_mod1))
    mod2 = ad.concat((train_mod2, test_mod2))
    mod1.var_names_make_unique()
    mod2.var_names_make_unique()
    # Both modalities share mod1's obs names in the combined object.
    mod2.obs_names = mod1.obs_names
    train_size = train_mod1.shape[0]

    # Train rows get a placeholder label of 0; test rows get the column
    # index of the max entry in the test solution matrix.
    mod1.obsm["labels"] = np.concatenate([np.zeros(train_size), np.argmax(test_label.X.toarray(), 1)])

    # Combine modalities into mudata
    mdata = md.MuData({"mod1": mod1, "mod2": mod2})
    mdata.var_names_make_unique()

    data = Data(mdata, train_size=train_size)

    return data

def preprocess(self, kind="pca", pkl_path=None, selection_threshold=10000):
def _maybe_preprocess(self, raw_data, selection_threshold=10000):
train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label = raw_data
modalities = [train_mod1, train_mod2, test_mod1, test_mod2]

# TODO: support other two subtasks
assert self.subtask in ("openproblems_bmmc_cite_phase2_rna",
"openproblems_bmmc_multiome_phase2_rna"), "Currently not available."

if kind == "pca":
if pkl_path and (not osp.exists(pkl_path)):
if self.preprocess == "pca":
if self.pkl_path and osp.exists(self.pkl_path):
with open(self.pkl_path, "rb") as f:
preprocessed_features = pickle.load(f)

else:
if self.subtask == "openproblems_bmmc_cite_phase2_rna":
lsi_transformer_gex = lsiTransformer(n_components=256, drop_first=True)
m1_train = lsi_transformer_gex.fit_transform(self.modalities[0]).values
m1_test = lsi_transformer_gex.transform(self.modalities[2]).values
m2_train = self.modalities[1].X.toarray()
m2_test = self.modalities[3].X.toarray()
m1_train = lsi_transformer_gex.fit_transform(modalities[0]).values
m1_test = lsi_transformer_gex.transform(modalities[2]).values
m2_train = modalities[1].X.toarray()
m2_test = modalities[3].X.toarray()

elif self.subtask == "openproblems_bmmc_multiome_phase2_rna":
lsi_transformer_gex = lsiTransformer(n_components=256, drop_first=True)
m1_train = lsi_transformer_gex.fit_transform(self.modalities[0]).values
m1_test = lsi_transformer_gex.transform(self.modalities[2]).values
m1_train = lsi_transformer_gex.fit_transform(modalities[0]).values
m1_test = lsi_transformer_gex.transform(modalities[2]).values
lsi_transformer_atac = lsiTransformer(n_components=512, drop_first=True)
m2_train = lsi_transformer_atac.fit_transform(self.modalities[1]).values
m2_test = lsi_transformer_atac.transform(self.modalities[3]).values
m2_train = lsi_transformer_atac.fit_transform(modalities[1]).values
m2_test = lsi_transformer_atac.transform(modalities[3]).values

else:
raise ValueError(f"Unrecognized subtask name: {self.subtask}")

self.preprocessed_features = {
preprocessed_features = {
"mod1_train": m1_train,
"mod2_train": m2_train,
"mod1_test": m1_test,
"mod2_test": m2_test
}
self.modalities[0].obsm["X_pca"] = self.preprocessed_features["mod1_train"]
self.modalities[1].obsm["X_pca"] = self.preprocessed_features["mod2_train"]
self.modalities[2].obsm["X_pca"] = self.preprocessed_features["mod1_test"]
self.modalities[3].obsm["X_pca"] = self.preprocessed_features["mod2_test"]
pickle.dump(self.preprocessed_features, open(pkl_path, "wb"))

else:
self.preprocessed_features = pickle.load(open(pkl_path, "rb"))
self.modalities[0].obsm["X_pca"] = self.preprocessed_features["mod1_train"]
self.modalities[1].obsm["X_pca"] = self.preprocessed_features["mod2_train"]
self.modalities[2].obsm["X_pca"] = self.preprocessed_features["mod1_test"]
self.modalities[3].obsm["X_pca"] = self.preprocessed_features["mod2_test"]
elif kind == "feature_selection":
if self.pkl_path:
with open(self.pkl_path, "wb") as f:
pickle.dump(preprocessed_features, f)

modalities[0].obsm["X_pca"] = preprocessed_features["mod1_train"]
modalities[1].obsm["X_pca"] = preprocessed_features["mod2_train"]
modalities[2].obsm["X_pca"] = preprocessed_features["mod1_test"]
modalities[3].obsm["X_pca"] = preprocessed_features["mod2_test"]

elif self.preprocess == "feature_selection":
for i in range(2):
if self.modalities[i].shape[1] > selection_threshold:
sc.pp.highly_variable_genes(self.modalities[i], layer="counts", flavor="seurat_v3",
if modalities[i].shape[1] > selection_threshold:
sc.pp.highly_variable_genes(modalities[i], layer="counts", flavor="seurat_v3",
n_top_genes=selection_threshold)
self.modalities[i + 2].var["highly_variable"] = self.modalities[i].var["highly_variable"]
self.modalities[i] = self.modalities[i][:, self.modalities[i].var["highly_variable"]]
self.modalities[i + 2] = self.modalities[i + 2][:, self.modalities[i + 2].var["highly_variable"]]
modalities[i + 2].var["highly_variable"] = modalities[i].var["highly_variable"]
modalities[i] = modalities[i][:, modalities[i].var["highly_variable"]]
modalities[i + 2] = modalities[i + 2][:, modalities[i + 2].var["highly_variable"]]

else:
logger.info("Preprocessing method not supported.")
return self

logger.info("Preprocessing done.")
self.preprocessed = True
return self

def get_preprocessed_features(self):
assert self.preprocessed, "Transformed features do not exist."
return self.preprocessed_features
train_mod1, train_mod2, test_mod1, test_mod2 = modalities
return train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label


class JointEmbeddingNIPSDataset(MultiModalityDataset):
Expand Down
2 changes: 1 addition & 1 deletion dance/modules/multi_modality/match_modality/cmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,6 @@ def fit(self, train_mod1, train_mod2, aux_labels=None, checkpoint_directory='./c
while True:
print('Iteration: ', iterations)
for it, batch_idx in enumerate(train_loader):
self._update_learning_rate()
mod1, mod2 = train_mod1[batch_idx], train_mod2[batch_idx]
for _ in range(num_disc):
self._dis_update(mod1, mod2, hyperparameters)
Expand All @@ -580,6 +579,7 @@ def fit(self, train_mod1, train_mod2, aux_labels=None, checkpoint_directory='./c
aux_labels[batch_idx], variational=False)
else:
self._gen_update(mod1, mod2, mod1, mod2, hyperparameters, variational=False)
self._update_learning_rate()
print('Matching score:', self.score(train_mod1[val_idx], train_mod2[val_idx],
torch.arange(val_idx.shape[0])))

Expand Down
5 changes: 0 additions & 5 deletions dance/modules/multi_modality/match_modality/scmogcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ class ScMoGCNWrapper:
def __init__(self, args, layers, temp=1):
self.model = ScMoGCN(args, layers, temp).to(args.device)
self.args = args
self.fitted = False
wt1 = torch.tensor([0.] * (args.layers - 1)).to(args.device).requires_grad_(True)
wt2 = torch.tensor([0.] * (args.layers - 1)).to(args.device).requires_grad_(True)
self.wt = [wt1, wt2]
Expand Down Expand Up @@ -149,7 +148,6 @@ def load(self, path, map_location=None):
None.

"""
self.fitted = True
if map_location is not None:
self.model.load_state_dict(torch.load(path, map_location=map_location))
else:
Expand Down Expand Up @@ -268,7 +266,6 @@ def fit(self, g_mod1, g_mod2, labels1, labels2, train_size):
break

logger.info(f'Valid: {maxval}')
self.fitted = True

self.wt = weight_record
return self
Expand All @@ -294,8 +291,6 @@ def predict(self, idx, enhance=False, batch1=None, batch2=None):

"""
# inputs: [train_mod1, train_mod2], idx: valid_idx, labels: [sol, sol.T], wt: [wt0, wt1]
if not self.fitted:
raise RuntimeError('Model not fitted yet.')
self.model.eval()

with torch.no_grad():
Expand Down
2 changes: 1 addition & 1 deletion dance/transforms/graph/cell_feature_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __call__(self, data):
g = dgl.bipartite_from_scipy(feat, utype='cell', etype='cell2feature', vtype='feature', eweight_name='weight')
g.nodes['cell'].data['id'] = torch.arange(feat.shape[0]).long()
g.nodes['feature'].data['id'] = torch.arange(feat.shape[1]).long()
g = AddReverse(copy_edata=True, sym_new_etype=True)(g)
g = dgl.AddReverse(copy_edata=True, sym_new_etype=True)(g)
if self.mod is None:
data.data.uns['g'] = g
else:
Expand Down
Loading