Skip to content

Commit

Permalink
clean up spatial deconv datasets; improve ct_select alignment (#120)
Browse files Browse the repository at this point in the history
* wip

* fix ct_select variable, improve log msg

* improve pseudo_spatial_process

- improve comments and add docstring
- use real labels if available instead of filling with copies from pseudo spots

* use aligned y

* remove unused functions

* improve format, use sc.pp

* specify dtype when constructing adata and use copy of a df instead of a view

* remove unused imports in dstg_graph

* adapt default ct_select in spatial dataset to card example script

* adapt default ct_select in spatial dataset to spotlight example script

* adapt default ct_select in spatial dataset to spatialdecon example script

* add toto notes

* remove unused imports and deprecated spatial dataset classes
  • Loading branch information
RemyLau authored Jan 9, 2023
1 parent 7957197 commit 46537af
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 331 deletions.
165 changes: 24 additions & 141 deletions dance/datasets/spatial.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import glob
import os
import os.path as osp
import warnings
from pprint import pformat

import anndata
import cv2
import pandas as pd
import rdata
import scanpy as sc

from dance import logger
from dance.data import download_file, download_unzip, unzip_file

IGNORED_FILES = ["readme.txt"]
Expand Down Expand Up @@ -38,29 +37,6 @@
"toy2": "https://www.dropbox.com/sh/eqkcm344p5d1akr/AAAPs0Z0S7yFC5ML8Kcd5eU9a?dl=1",
}

card_simulation_dataset = {
"sim_noise0_rep1":
"https://www.dropbox.com/s/aujusznwbq4xa99/sim.pseudo.MOB.n10.cellType6.Mixnoise0.repeat1.RData?dl=1",
"sim_noise0_rep2":
"https://www.dropbox.com/s/bjyagojp3opkyxd/sim.pseudo.MOB.n10.cellType6.Mixnoise0.repeat2.RData?dl=1",
"sim_noise0_rep3":
"https://www.dropbox.com/s/k9n07ujvqtroj72/sim.pseudo.MOB.n10.cellType6.Mixnoise0.repeat3.RData?dl=1",
"sim_noise0_rep4":
"https://www.dropbox.com/s/3m49dq387he776x/sim.pseudo.MOB.n10.cellType6.Mixnoise0.repeat4.RData?dl=1",
"sim_noise0_rep5":
"https://www.dropbox.com/s/4wjsl2ids1q16b2/sim.pseudo.MOB.n10.cellType6.Mixnoise0.repeat5.RData?dl=1",
"sim_noise3_rep1":
"https://www.dropbox.com/s/z6ehj48q0vcxf13/sim.pseudo.MOB.n10.cellType6.Mixnoise3.repeat1.RData?dl=1",
"sim_noise3_rep2":
"https://www.dropbox.com/s/2mzjt4pv1f5ucs2/sim.pseudo.MOB.n10.cellType6.Mixnoise3.repeat2.RData?dl=1",
"sim_noise3_rep3":
"https://www.dropbox.com/s/65ixbnu6x65o8ee/sim.pseudo.MOB.n10.cellType6.Mixnoise3.repeat3.RData?dl=1",
"sim_noise3_rep4":
"https://www.dropbox.com/s/hrmwoi14wta0ida/sim.pseudo.MOB.n10.cellType6.Mixnoise3.repeat4.RData?dl=1",
"sim_noise3_rep5":
"https://www.dropbox.com/s/0txpltfj2p3dz9v/sim.pseudo.MOB.n10.cellType6.Mixnoise3.repeat5.RData?dl=1",
}


class SpotDataset:

Expand Down Expand Up @@ -139,67 +115,6 @@ def load_data(self):
return image, adata, spatial, spatial_pixel, label


class CellTypeDeconvoDataset:

def __init__(self, data_id="toy1", data_dir="data/spatial", build_graph_fn="default"):
self.data_id = data_id
self.data_dir = data_dir + "/{}".format(data_id)
self.data_url = cellDeconvo_dataset[data_id]
self.load_data()
self.adj = None

def get_all_data(self):
# provide an interface to get all data at one time
print("All data includes {} cellDeconvo datasets: {}".format(len(cellDeconvo_dataset),
",".join(cellDeconvo_dataset.keys())))
res = {}
for each_dataset in cellDeconvo_dataset.keys():
res[each_dataset] = cellDeconvo_dataset(each_dataset)
return res

def download_data(self):
# judge whether a file exists or not
isdownload = download_file(self.data_url, self.data_dir + "/{}.zip".format(self.data_id))
if isdownload:
unzip_file(self.data_dir + "/{}.zip".format(self.data_id), self.data_dir)
return self

def is_complete(self):
check = [self.data_dir + "/mix_count.*", self.data_dir + "/ref_sc_count.*"]

for i in check:
if not glob.glob(i):
print("lack {}".format(i))
return False
return True

def load_data(self):
if self.is_complete():
pass
else:
self.download_data()

self.data = {}
files = os.listdir(self.data_dir + "/")
for f in files:
DataPath = self.data_dir + "/" + f
filename = f.split(".")[0]
ext = f.split(".")[1]
if ext == "csv":
data = pd.read_csv(DataPath, header=0, index_col=0)
self.data[filename] = data
elif ext == "h5ad":
data = sc.read_h5ad(DataPath)
self.data[filename] = data
self.data[filename + "_annot"] = data.obs
else:
print("unsupported file type. Please use csv or h5ad file types.")

print("load data successfully....")

return self


class CellTypeDeconvoDatasetLite:

def __init__(self, data_id="GSE174746", data_dir="data/spatial", build_graph_fn="default"):
Expand Down Expand Up @@ -228,67 +143,35 @@ def _load_data(self):
else:
warnings.warn(f"Unsupported file type {ext!r}. Use csv or h5ad file types.")

def load_data(self):
def load_data(self, subset_common_celltypes: bool = True):
"""Load raw data.
Parameters
----------
subset_common_celltypes
If set to True, then subset both the reference and the real data to contain only cell types that are
present in both reference and real.
"""
ref_count = self.data["ref_sc_count"]
ref_annot = self.data["ref_sc_annot"]
count_matrix = self.data["mix_count"]
cell_type_portion = self.data["true_p"]
if (spatial := self.data.get("spatial_location")) is None:
spatial = pd.DataFrame(0, index=count_matrix.index, columns=["x", "y"])
return ref_count, ref_annot, count_matrix, cell_type_portion, spatial

# Obtain cell type info and subset to common cell types between ref and real if needed
ref_celltypes = set(ref_annot["cellType"].unique().tolist())
real_celltypes = set(cell_type_portion.columns.tolist())
logger.info(f"Number of cell types: reference = {len(ref_celltypes)}, real = {len(real_celltypes)}")
if subset_common_celltypes:
common_celltypes = sorted(ref_celltypes & real_celltypes)
logger.info(f"Subsetting to common cell types (n={len(common_celltypes)}):\n{pformat(common_celltypes)}")

class CARDSimulationRDataset:
ref_sc_count_url: str = "https://www.dropbox.com/s/wchoppxcsulk8ev/split2_ref_sc_count.h5ad?dl=1"
ref_sc_annot_url: str = "https://www.dropbox.com/s/irpvco2ffisvxvk/split2_ref_sc_annot.csv?dl=1"
idx = ref_annot[ref_annot["cellType"].isin(common_celltypes)].index
ref_annot = ref_annot.loc[idx]
ref_count = ref_count.loc[idx]

def __init__(self, data_id="sim_noise0_rep1", data_dir="data/spatial/card_simulation", build_graph_fn="default"):
self.data_id = data_id
self.data_dir = osp.join(data_dir, data_id)
self.data_path = osp.join(data_dir, f"{data_id}.RData")
self.ref_sc_count_path = osp.join(data_dir, "ref_sc_count.h5ad")
self.ref_sc_annot_path = osp.join(data_dir, "ref_sc_annot.csv")
self.data_url = card_simulation_dataset[data_id]
self.load_data()
self.adj = None

def get_all_data(self):
# TODO: make classmethod, make data url dict as class attrs
dataset_info = "\n\t".join(list(card_simulation_dataset))
print(f"Total of {(len(card_simulation_dataset))} datasets:\n{dataset_info}")
return {i: CARDSimulationRDataset(i) for i in card_simulation_dataset}

def download_data(self):
download_file(self.ref_sc_count_url, self.ref_sc_count_path)
download_file(self.ref_sc_annot_url, self.ref_sc_annot_path)
download_file(self.data_url, self.data_path)

def is_complete(self):
check = [self.data_path, self.ref_sc_count_path, self.ref_sc_annot_path]
return all(map(osp.exists, check))
cell_type_portion = cell_type_portion[common_celltypes]

def load_data(self):
if not self.is_complete():
self.download_data()

raw = rdata.conversion.convert(rdata.parser.parse_file(self.data_path))["spatial.pseudo"]
ref_sc_count = anndata.read_h5ad(self.ref_sc_count_path).to_df().T
ref_sc_annot = pd.read_csv(self.ref_sc_annot_path, index_col=0)
spatial_count = raw["pseudo.data"].to_pandas().T
spatial_location = (spatial_count.reset_index()["dim_1"].str.split("x", expand=True).set_index(
spatial_count.index).rename({
0: "x",
1: "y"
}, axis=1).astype(float))
true_p = raw["true.p"].to_pandas()

# TODO: directly save as attrs instead of a dict?
self.data = {
"ref_sc_count": ref_sc_count,
"ref_sc_annot": ref_sc_annot,
"spatial_count": spatial_count,
"spatial_location": spatial_location,
"true_p": true_p,
}

return self
return ref_count, ref_annot, count_matrix, cell_type_portion, spatial
5 changes: 0 additions & 5 deletions dance/modules/spatial/cell_type_deconvo/dstg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,12 @@
import time

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter

from dance.transforms.graph.dstg_graph import compute_dstg_adj
from dance.transforms.preprocess import pseudo_spatial_process
from dance.utils.matrix import normalize


class GraphConvolution(nn.Module):
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
Expand Down
1 change: 1 addition & 0 deletions dance/modules/spatial/cell_type_deconvo/spatialdecon.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(self, sc_count, sc_annot, ct_varname, ct_select, sc_profile=None, b

self.device = device

# TODO: extract to preprocessing transformation and remove ct_select from input
# Subset sc samples on selected cell types (mutual between sc and mix cell data)
ct_select_ix = sc_annot[sc_annot[ct_varname].isin(ct_select)].index
self.sc_annot = sc_annot.loc[ct_select_ix]
Expand Down
1 change: 1 addition & 0 deletions dance/modules/spatial/cell_type_deconvo/spotlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(self, ref_count, ref_annot, ct_varname, ct_select, rank=2, sc_profi
self.ct_select = ct_select
self.rank = rank

# TODO: extract to preprocessing transformation and remove ct_select from input
# Subset sc samples on selected cell types (mutual between sc and mix cell data)
ct_select_ix = ref_annot[ref_annot[ct_varname].isin(ct_select)].index
self.ref_annot = ref_annot.loc[ct_select_ix]
Expand Down
Loading

0 comments on commit 46537af

Please sign in to comment.