Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: DeepImpute & GraphSCI imputation modules (dataset + transform) #223

Merged
merged 26 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b12dd26
add feature feature graph class
WenzhuoTang Feb 22, 2023
a9aaa30
Add mask script, reformat graphsci
WenzhuoTang Feb 23, 2023
7ef6cdb
Refactor graphsci
WenzhuoTang Feb 24, 2023
b7d0ea8
Refactor deepimpute
WenzhuoTang Feb 25, 2023
5c30b2a
feat: implement FilterGenesScanpy and FilterCellsScanpy extensions
RemyLau Feb 26, 2023
56cb0af
add new transformation to init
RemyLau Feb 26, 2023
9fecb3f
fix: correctly handle cells/genes
RemyLau Feb 26, 2023
63cd48d
fix: filter via sc.pp.filter_xxxx
RemyLau Feb 26, 2023
72c6c90
format fixes
RemyLau Feb 27, 2023
2447143
fix: mask generation logic; support sparse features
RemyLau Feb 27, 2023
75dd5c9
refactor imputation
WenzhuoTang Feb 27, 2023
a09a562
update graphsci preprocessing pipeline
WenzhuoTang Feb 27, 2023
c37c1d0
Revert "update graphsci preprocessing pipeline"
WenzhuoTang Feb 27, 2023
dd7722b
update scale factor of graphsci
WenzhuoTang Feb 27, 2023
0211db1
fix ratio constraint
WenzhuoTang Feb 28, 2023
3b11521
fix typo
WenzhuoTang Feb 28, 2023
9229bdb
parameter tuning
WenzhuoTang Feb 28, 2023
9d22a45
update imputation results
WenzhuoTang Feb 28, 2023
7c41767
update readme
WenzhuoTang Feb 28, 2023
6ab35a1
update readme
WenzhuoTang Feb 28, 2023
f26332d
update readme
WenzhuoTang Feb 28, 2023
58584e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2023
1b3c160
use osp
RemyLau Feb 28, 2023
0bbf057
fix: revert changes to scdeepsort dataset
RemyLau Feb 28, 2023
f5bc31f
fix: path and error
RemyLau Feb 28, 2023
c19ca66
add random state, use split permutation instead of dataloader
RemyLau Feb 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,10 @@ pip install -e .
| NN | DeepImpute | DeepImpute: an accurate, fast, and scalable deep neural network method to impute single-cell RNA-seq data | 2019 | ✅ |
| NN + TF | Saver-X | Transfer learning in single-cell transcriptomics improves data denoising and pattern discovery | 2019 | P1 |

| Model | Evaluation Metric | Mouse Brain (current/reported) | Mouse Embryo (current/reported) |
| ---------- | ----------------- | ------------------------------ | ------------------------------- |
| DeepImpute | MSE | 0.12 / N/A | 0.12 / N/A |
| ScGNN | MSE | 0.47 / N/A | 1.10 / N/A |
| GraphSCI | MSE | 0.42 / N/A | 0.87 / N/A |

Note: the data split modality of DeepImpute is different from ScGNN and GraphSCI, so the results are not comparable.
| Model | Evaluation Metric | Mouse Brain (current/reported) | Mouse Embryo (current/reported) | PBMC (current/reported) |
| ---------- | ----------------- | ------------------------------ | ------------------------------- | ----------------------- |
| DeepImpute | RMSE | 0.87 / N/A | 1.20 / N/A | 2.30 / N/A |
| GraphSCI | RMSE | 1.55 / N/A | 1.81 / N/A | 3.68 / N/A |

#### 2) Cell Type Annotation

Expand Down
232 changes: 108 additions & 124 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import collections
import glob
import os
import os.path as osp
import pprint
import shutil
import sys
from dataclasses import dataclass
from glob import glob

import anndata as ad
import h5py
import numpy as np
import pandas as pd
import scanpy as sc
from scipy.sparse import csr_matrix

from dance import logger
from dance.data import Data
from dance.datasets.base import BaseDataset
from dance.registers import register_dataset
from dance.transforms.preprocess import load_imputation_data_internal
from dance.typing import Dict, List, Optional, Set, Tuple
from dance.utils.download import download_file, download_unzip
from dance.utils.preprocess import cell_label_to_df
Expand All @@ -25,7 +25,7 @@
@register_dataset("scdeepsort")
class ScDeepSortDataset(BaseDataset):

_DISPLAY_ATTRS = ("species", "tissue", "train_dataset", "test_dataset")
_DISPLAY_ATTRS = ("species", "tissue", "dataset", "test_dataset")
ALL_URL_DICT: Dict[str, str] = {
"train_human_cell_atlas": "https://www.dropbox.com/s/1itq1pokplbqxhx?dl=1",
"test_human_test_data": "https://www.dropbox.com/s/gpxjnnvwyblv3xb?dl=1",
Expand All @@ -52,12 +52,12 @@ class ScDeepSortDataset(BaseDataset):
"test_mouse_Kidney203_data.csv": "https://www.dropbox.com/s/kmos1ceubumgmpj?dl=1",
} # yapf: disable

def __init__(self, full_download=False, train_dataset=None, test_dataset=None, species=None, tissue=None,
def __init__(self, full_download=False, dataset=None, test_dataset=None, species=None, tissue=None,
train_dir="train", test_dir="test", map_path="map", data_dir="./"):
super().__init__(data_dir, full_download)

self.data_dir = data_dir
self.train_dataset = train_dataset
self.dataset = dataset
self.test_dataset = test_dataset
RemyLau marked this conversation as resolved.
Show resolved Hide resolved
self.species = species
self.tissue = tissue
Expand Down Expand Up @@ -138,15 +138,15 @@ def is_complete(self):
def _load_raw_data(self, ct_col: str = "Cell_type") -> Tuple[ad.AnnData, List[Set[str]], List[str], int]:
species = self.species
tissue = self.tissue
train_dataset_ids = self.train_dataset
dataset_ids = self.dataset
test_dataset_ids = self.test_dataset
data_dir = self.data_dir
train_dir = osp.join(data_dir, self.train_dir)
test_dir = osp.join(data_dir, self.test_dir)
map_path = osp.join(data_dir, self.map_path, self.species)

# Load raw data
train_feat_paths, train_label_paths = self._get_data_paths(train_dir, species, tissue, train_dataset_ids)
train_feat_paths, train_label_paths = self._get_data_paths(train_dir, species, tissue, dataset_ids)
test_feat_paths, test_label_paths = self._get_data_paths(test_dir, species, tissue, test_dataset_ids)
train_feat, test_feat = (self._load_dfs(paths, transpose=True) for paths in (train_feat_paths, test_feat_paths))
train_label, test_label = (self._load_dfs(paths) for paths in (train_label_paths, test_label_paths))
Expand Down Expand Up @@ -280,41 +280,41 @@ def _raw_to_dance(self, raw_data: Tuple[ad.AnnData, np.ndarray]):
return data


@dataclass
class ImputationDatasetParams:
data_dir = None
random_seed = None
min_counts = None
train_dataset = None
test_dataset = None
gpu = None
filetype = None
@register_dataset("imputation")
class ImputationDataset(BaseDataset):

URL = {
"pbmc_data": "https://www.dropbox.com/s/brj3orsjbhnhawa/5k.zip?dl=0",
"mouse_embryo_data": "https://www.dropbox.com/s/8ftx1bydoy7kn6p/GSE65525.zip?dl=0",
"mouse_brain_data": "https://www.dropbox.com/s/zzpotaayy2i29hk/neuron_10k.zip?dl=0",
"human_stemcell_data": "https://www.dropbox.com/s/g2qua2j3rqcngn6/GSE75748.zip?dl=0"
}

class ImputationDataset():
DATASET_TO_FILE = {
"pbmc_data": "5k_pbmc_protein_v3_filtered_feature_bc_matrix.h5",
"mouse_embryo_data": [
osp.join("GSE65525", i)
for i in [
"GSM1599494_ES_d0_main.csv",
"GSM1599497_ES_d2_LIFminus.csv",
"GSM1599498_ES_d4_LIFminus.csv",
"GSM1599499_ES_d7_LIFminus.csv",
]
],
"mouse_brain_data": "neuron_10k_v3_filtered_feature_bc_matrix.h5",
"human_stemcell_data": "GSE75748/GSE75748_sc_time_course_ec.csv.gz"
} # yapf: disable

def __init__(self, random_seed=10, gpu=-1, filetype=None, data_dir="data", train_dataset="human_stemcell",
test_dataset="pbmc", min_counts=1):
self.params = ImputationDatasetParams
self.params.data_dir = data_dir
self.params.random_seed = random_seed
self.params.min_counts = min_counts
self.params.train_dataset = train_dataset
self.params.test_dataset = test_dataset
self.params.gpu = gpu
self.params.filetype = filetype
def __init__(self, data_dir="data", dataset="human_stemcell", train_size=0.1):
super().__init__(data_dir, full_download=False)
self.data_dir = data_dir
self.dataset = dataset
self.train_size = train_size
RemyLau marked this conversation as resolved.
Show resolved Hide resolved

def download_all_data(self):
def download(self):

gene_class = ["pbmc_data", "mouse_brain_data", "mouse_embryo_data", "human_stemcell_data"]

url = {
"pbmc_data": "https://www.dropbox.com/s/brj3orsjbhnhawa/5k.zip?dl=0",
"mouse_embryo_data": "https://www.dropbox.com/s/8ftx1bydoy7kn6p/GSE65525.zip?dl=0",
"mouse_brain_data": "https://www.dropbox.com/s/zzpotaayy2i29hk/neuron_10k.zip?dl=0",
"human_stemcell_data": "https://www.dropbox.com/s/g2qua2j3rqcngn6/GSE75748.zip?dl=0"
}

file_name = {
"pbmc_data": "5k.zip?dl=0",
"mouse_embryo_data": "GSE65525.zip?dl=0",
Expand All @@ -329,63 +329,39 @@ def download_all_data(self):
"human_stemcell_data": "GSE75748"
}

dataset_to_file = {
"pbmc_data":
"5k_pbmc_protein_v3_filtered_feature_bc_matrix.h5",
"mouse_embryo_data":
list(
map(lambda x: "GSE65525/" + x, [
"GSM1599494_ES_d0_main.csv", "GSM1599497_ES_d2_LIFminus.csv", "GSM1599498_ES_d4_LIFminus.csv",
"GSM1599499_ES_d7_LIFminus.csv"
])),
"mouse_brain_data":
"neuron_10k_v3_filtered_feature_bc_matrix.h5",
"human_stemcell_data":
"GSE75748/GSE75748_sc_time_course_ec.csv.gz"
}
self.params.dataset_to_file = dataset_to_file
if sys.platform != 'win32':
if not osp.exists(self.params.data_dir):
os.system("mkdir " + self.params.data_dir)
if not osp.exists(self.params.data_dir + "/train"):
os.system("mkdir " + self.params.data_dir + "/train")
if not osp.exists(self.data_dir):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • TODO later: use the download_unzip utility instead of doing these system calls.

os.system("mkdir " + self.data_dir)
if not osp.exists(self.data_dir + "/train"):
os.system("mkdir " + self.data_dir + "/train")

for class_name in gene_class:
if not any(
list(
map(osp.exists,
glob.glob(self.params.data_dir + "/train/" + class_name + "/" +
dl_files[class_name])))):
os.system("mkdir " + self.params.data_dir + "/train/" + class_name)
os.system("wget " + url[class_name]) # assumes linux... mac needs to install
if not any(map(osp.exists, glob(osp.join(self.data_dir, "train", class_name, dl_files[class_name])))):
os.system("mkdir " + self.data_dir + "/train/" + class_name)
os.system("wget " + self.URL[class_name]) # assumes linux... mac needs to install
os.system("unzip " + file_name[class_name])
os.system("rm " + file_name[class_name])
os.system("mv " + dl_files[class_name] + " " + self.params.data_dir + "/train/" + class_name + "/")
os.system("cp -r " + self.params.data_dir + "/train/ " + self.params.data_dir + "/test")
os.system("mv " + dl_files[class_name] + " " + self.data_dir + "/train/" + class_name + "/")
os.system("cp -r " + self.data_dir + "/train/ " + self.data_dir + "/test")
if sys.platform == 'win32':
if not osp.exists(self.params.data_dir):
os.system("mkdir " + self.params.data_dir)
if not osp.exists(self.params.data_dir + "/train"):
os.mkdir(self.params.data_dir + "/train")
if not osp.exists(self.data_dir):
os.system("mkdir " + self.data_dir)
if not osp.exists(self.data_dir + "/train"):
os.mkdir(self.data_dir + "/train")
for class_name in gene_class:
if not any(
list(
map(osp.exists,
glob.glob(self.params.data_dir + "/train/" + class_name + "/" +
dl_files[class_name])))):
os.mkdir(self.params.data_dir + "/train/" + class_name)
os.system("curl " + url[class_name])
if not any(map(osp.exists, glob(osp.join(self.data_dir, "train", class_name, dl_files[class_name])))):
os.mkdir(self.data_dir + "/train/" + class_name)
os.system("curl " + self.URL[class_name])
os.system("tar -xf " + file_name[class_name])
os.system("del -R " + file_name[class_name])
os.system("move " + dl_files[class_name] + " " + self.params.data_dir + "/train/" + class_name +
"/")
os.system("copy /r " + self.params.data_dir + "/train/ " + self.params.data_dir + "/test")
os.system("move " + dl_files[class_name] + " " + self.data_dir + "/train/" + class_name + "/")
os.system("copy /r " + self.data_dir + "/train/ " + self.data_dir + "/test")

def is_complete(self):
# check whether data is complete or not
check = [
self.params.data_dir + "/train",
self.params.data_dir + "/test",
self.data_dir + "/train",
self.data_dir + "/test",
]

for i in check:
Expand All @@ -394,48 +370,56 @@ def is_complete(self):
return False
return True

def load_data(self, model_params, model='GraphSCI'):
# Load data from existing h5ad files, or download files and load data.
if self.is_complete():
pass
def _load_raw_data(self) -> ad.AnnData:
if self.dataset[-5:] != '_data':
dataset = self.dataset + '_data'
else:
dataset = self.dataset

if self.dataset == 'mouse_embryo' or self.dataset == 'mouse_embryo_data':
for i in range(len(self.DATASET_TO_FILE[dataset])):
fname = self.DATASET_TO_FILE[dataset][i]
data_path = f'{self.data_dir}/train/{dataset}/{fname}'
if i == 0:
counts = pd.read_csv(data_path, header=None, index_col=0)
time = pd.Series(np.zeros(counts.shape[1]))
else:
x = pd.read_csv(data_path, header=None, index_col=0)
time = pd.concat([time, pd.Series(np.zeros(x.shape[1])) + i])
counts = pd.concat([counts, x], axis=1)
time = pd.DataFrame(time)
time.columns = ['time']
counts = counts.T
counts.index = [i for i in range(counts.shape[0])]
adata = ad.AnnData(csr_matrix(counts.values))
adata.var_names = counts.columns.tolist()
adata.obs['time'] = time.to_numpy()
else:
self.download_all_data()
assert self.is_complete()

data_dict = load_imputation_data_internal(self.params, model_params, model=model)
self.params.num_cells = data_dict['num_cells']
self.params.num_genes = data_dict['num_genes']
self.params.train_data = data_dict['train_data']
self.params.test_data = data_dict['test_data']
self.params.adata = data_dict['adata']

if model == 'GraphSCI':
self.params.train_data_raw = data_dict['train_data_raw']
self.params.test_data_raw = data_dict['test_data_raw']
self.params.adj_train = data_dict['adj_train']
self.params.adj_test = data_dict['adj_test']
self.params.adj_train_false = data_dict['adj_train_false']
self.params.adj_norm_train = data_dict['adj_norm_train']
self.params.adj_norm_test = data_dict['adj_norm_test']
self.params.size_factors = data_dict['size_factors']
self.params.train_size_factors = data_dict['train_size_factors']
self.params.test_size_factors = data_dict['test_size_factors']
self.params.train_size_factors = data_dict['train_size_factors']
self.params.test_size_factors = data_dict['test_size_factors']
self.params.test_idx = data_dict['test_idx']
if model == 'DeepImpute':
self.params.X_train = self.params.train_data[0]
self.params.Y_train = self.params.train_data[1]
self.params.X_test = self.params.test_data[0]
self.params.Y_test = self.params.test_data[1]
self.params.inputgenes = data_dict['predictors']
self.params.targetgenes = data_dict['targets']
self.params.total_counts = data_dict['total_counts']
self.params.true_counts = data_dict['true_counts']
self.params.genes_to_impute = data_dict['genes_to_impute']
if model == 'scGNN':
self.params.genelist = data_dict['genelist']
self.params.celllist = data_dict['celllist']
self.params.test_idx = data_dict['test_idx']

return self
data_path = f'{self.data_dir}/train/{dataset}/{self.DATASET_TO_FILE[dataset]}'
RemyLau marked this conversation as resolved.
Show resolved Hide resolved
if not os.path.exists(data_path):
raise NotImplementedError

if self.DATASET_TO_FILE[dataset][-3:] == 'csv':
counts = pd.read_csv(data_path, index_col=0, header=None)
counts = counts.T
adata = ad.AnnData(csr_matrix(counts.values))
# adata.obs_names = ["%d"%i for i in range(adata.shape[0])]
adata.obs_names = counts.index.tolist()
adata.var_names = counts.columns.tolist()
if self.DATASET_TO_FILE[dataset][-2:] == 'gz':
counts = pd.read_csv(data_path, index_col=0, compression='gzip', header=0)
counts = counts.T
adata = ad.AnnData(csr_matrix(counts.values))
# adata.obs_names = ["%d" % i for i in range(adata.shape[0])]
adata.obs_names = counts.index.tolist()
adata.var_names = counts.columns.tolist()
elif self.DATASET_TO_FILE[dataset][-2:] == 'h5':
adata = sc.read_10x_h5(data_path)
adata.var_names_make_unique()

return adata

def _raw_to_dance(self, raw_data: ad.AnnData):
    """Wrap the raw AnnData object into a dance ``Data`` object.

    The training split size is the first ``train_size`` fraction of the
    observations (truncated to an integer count).
    """
    num_train = int(raw_data.n_obs * self.train_size)
    return Data(raw_data, train_size=num_train)
Loading