From 3ac29410e2cee2dd13a25bfc568badff5a3638aa Mon Sep 17 00:00:00 2001 From: Remy Date: Wed, 22 Feb 2023 18:35:18 -0500 Subject: [PATCH 01/17] simplifies card cell type profile generation function --- .../modules/spatial/cell_type_deconvo/card.py | 59 +++++-------------- 1 file changed, 14 insertions(+), 45 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index cfd57862..71c998cc 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -160,54 +160,23 @@ def __init__(self, sc_count, sc_meta, ct_varname=None, ct_select=None, sample_va def createscRef(self): """CreatescRef - create reference basis matrix from reference scRNA-seq.""" - countMat = self.sc_count.copy() # cell by gene matrix + count_mat = self.sc_count.copy() sc_meta = self.sc_meta.copy() ct_varname = self.ct_varname - sample_varname = self.sample_varname - if sample_varname is None: + batch_varname = self.sample_varname + if batch_varname is None: sc_meta["sampleID"] = "Sample" - sample_varname = "sampleID" - sample_id = sc_meta[sample_varname].astype(str) - ct_sample_id = sc_meta[ct_varname] + "$*$" + sample_id - rowSums_countMat = countMat.sum(axis=1) - sc_meta["rowSums"] = rowSums_countMat - rowSums_countMat_Ct = sc_meta.groupby([ct_varname, sample_varname])["rowSums"].agg("sum").to_frame() - rowSums_countMat_Ct_Wide = rowSums_countMat_Ct.pivot_table(index=sample_varname, columns=ct_varname, - values="rowSums", aggfunc="sum") - - # create count table by sampleID and cellType - tab = sc_meta.groupby([sample_varname, ct_varname]).size() - tbl = tab.unstack() - - # match column and row names - rowSums_countMat_Ct_Wide = rowSums_countMat_Ct_Wide.reindex_like(tbl) - rowSums_countMat_Ct_Wide = rowSums_countMat_Ct_Wide.reindex(tbl.index) - - # Compute total expression count by sample and cell type - S_JK = rowSums_countMat_Ct_Wide.div(tbl) - S_JK = S_JK.replace(0, np.nan) - 
S_JK = S_JK.replace([np.inf, -np.inf], np.nan) - S = S_JK.mean(axis=0).to_frame().unstack().droplevel(0) - S = S[sc_meta[ct_varname].unique()] - countMat["ct_sample_id"] = ct_sample_id - Theta_S_colMean = countMat.groupby(ct_sample_id).mean(numeric_only=True) - tbl_sample = countMat.groupby([ct_sample_id]).size() - tbl_sample = tbl_sample.reindex_like(Theta_S_colMean) - tbl_sample = tbl_sample.reindex(Theta_S_colMean.index) - Theta_S_colSums = countMat.groupby(ct_sample_id).sum(numeric_only=True) - Theta_S = Theta_S_colSums.copy() - Theta_S["sum"] = Theta_S_colSums.sum(axis=1) - Theta_S = Theta_S[list(Theta_S.columns)[:-1]].div(Theta_S["sum"], axis=0) - grp = [] - for ind in Theta_S.index: - grp.append(ind.split("$*$")[0]) - Theta_S["grp"] = grp - Theta = Theta_S.groupby(grp).mean(numeric_only=True) - Theta = Theta.reindex(sc_meta[ct_varname].unique()) - S = S[Theta.index] - Theta["S"] = S.iloc[0] - basis = Theta[list(Theta.columns)[:-1]].mul(Theta["S"], axis=0) - self.basis = basis + batch_varname = "sampleID" + var_names = [ct_varname, batch_varname] + + count_mat_ct_batch = count_mat.join(sc_meta[var_names]).groupby(var_names).mean(numeric_only=True) + lib_size_ct_batch = count_mat_ct_batch.sum(1) + ct_batch_profile = count_mat_ct_batch.div(lib_size_ct_batch, axis=0) + ct_batch_profile["lib_size"] = lib_size_ct_batch + + ct_profile = ct_batch_profile.droplevel(batch_varname).reset_index().groupby(ct_varname).mean(numeric_only=True) + ct_profile = ct_profile.loc[:, ct_profile.columns != "lib_size"].mul(ct_profile["lib_size"], axis=0) + self.basis = ct_profile def select_ct_marker(self, ict): Basis = self.basis.copy() From 8f6a708f13fb7bda2e7779933f0c02ea7ebdfd5c Mon Sep 17 00:00:00 2001 From: Remy Date: Wed, 22 Feb 2023 22:26:47 -0500 Subject: [PATCH 02/17] simplifies card gene selection function --- .../modules/spatial/cell_type_deconvo/card.py | 35 ++++++++----------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git 
a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 71c998cc..31b24a4c 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -199,27 +199,20 @@ def selectInfo(self, common_gene): List of informative genes. """ - ct_varname = self.ct_varname - Basis = self.basis.copy() - sc_count = self.sc_count.copy() - sc_meta = self.sc_meta.copy() - gene1_list = list() - for ict in Basis.index: - gene1_list.append(self.select_ct_marker(ict)) - gene1 = set(chain(*gene1_list)) - gene1 = list(gene1) - gene1 = [gene for gene in gene1 if gene in common_gene] # intersect with common_gene - counts = sc_count[gene1] - sd_within = pd.DataFrame(columns=counts.columns) - for ict in Basis.index: - series = (counts.loc[list(sc_meta[sc_meta[ct_varname] == ict].index)].var(axis=0).divide(counts.loc[list( - sc_meta[sc_meta[ct_varname] == ict].index)].mean(axis=0))) - series.name = ict - sd_within = pd.concat((sd_within, pd.DataFrame(series).T)) - - sd_within_colMean = sd_within.mean(axis=0).index.to_frame() - genes_to_select = sd_within.mean(axis=0) < sd_within.mean(axis=0).quantile(.99) - genes = list(sd_within_colMean[genes_to_select].index) + # Select marker genes from common genes + gene1 = set() + for ict in self.ct_select: + gene1.update(self.select_ct_marker(ict)) + gene1 = sorted(gene1 & set(common_gene)) + + # Compute coefficient of variation for each gene within each cell type + counts = self.sc_count[gene1].join(self.sc_meta[self.ct_varname]) + cov_mean = (counts.groupby(self.ct_varname).var() / counts.groupby(self.ct_varname).mean()).mean() + + # Remove genes that have high cov + ind = cov_mean < cov_mean.quantile(.99) + genes = sorted(cov_mean[ind].index) + return genes def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: bool = False): From 7aa9bc7a16a5d8fb2bb23908898f50d0e9d72968 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 
00:15:06 -0500 Subject: [PATCH 03/17] update get_ct_profile to cover card basis construction --- .../modules/spatial/cell_type_deconvo/card.py | 21 +++---- .../spatial/cell_type_deconvo/spotlight.py | 10 ++-- dance/transforms/pseudo_gen.py | 57 ++++++++++++++----- 3 files changed, 58 insertions(+), 30 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 31b24a4c..98f6ee8b 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -8,11 +8,10 @@ Nature Biotechnology (2022): 1-11. """ -from itertools import chain - import numpy as np import pandas as pd +from dance.transforms.pseudo_gen import get_ct_profile from dance.utils.matrix import pairwise_distance @@ -160,23 +159,21 @@ def __init__(self, sc_count, sc_meta, ct_varname=None, ct_select=None, sample_va def createscRef(self): """CreatescRef - create reference basis matrix from reference scRNA-seq.""" - count_mat = self.sc_count.copy() sc_meta = self.sc_meta.copy() ct_varname = self.ct_varname batch_varname = self.sample_varname if batch_varname is None: sc_meta["sampleID"] = "Sample" batch_varname = "sampleID" - var_names = [ct_varname, batch_varname] - - count_mat_ct_batch = count_mat.join(sc_meta[var_names]).groupby(var_names).mean(numeric_only=True) - lib_size_ct_batch = count_mat_ct_batch.sum(1) - ct_batch_profile = count_mat_ct_batch.div(lib_size_ct_batch, axis=0) - ct_batch_profile["lib_size"] = lib_size_ct_batch - ct_profile = ct_batch_profile.droplevel(batch_varname).reset_index().groupby(ct_varname).mean(numeric_only=True) - ct_profile = ct_profile.loc[:, ct_profile.columns != "lib_size"].mul(ct_profile["lib_size"], axis=0) - self.basis = ct_profile + ct_profile = get_ct_profile( + self.sc_count.values, + sc_meta[ct_varname].values, + ct_select=self.ct_select, + batch_index=sc_meta[batch_varname].values, + ).T + basis = pd.DataFrame(ct_profile, index=self.ct_select, 
columns=self.sc_count.columns) + self.basis = basis def select_ct_marker(self, ict): Basis = self.basis.copy() diff --git a/dance/modules/spatial/cell_type_deconvo/spotlight.py b/dance/modules/spatial/cell_type_deconvo/spotlight.py index c864e85c..a5cccfd2 100644 --- a/dance/modules/spatial/cell_type_deconvo/spotlight.py +++ b/dance/modules/spatial/cell_type_deconvo/spotlight.py @@ -8,6 +8,8 @@ spots with single-cell transcriptomes." Nucleic Acids Research (2021) """ +from functools import partial + import torch from torch import nn, optim from torchnmf.nmf import NMF @@ -18,7 +20,7 @@ from dance.utils import get_device from dance.utils.wrappers import CastOutputType -get_ct_profile_tensor = CastOutputType(torch.FloatTensor)(get_ct_profile) +get_ct_profile_tensor = CastOutputType(torch.FloatTensor)(partial(get_ct_profile, method="median")) class NNLS(nn.Module): @@ -131,7 +133,7 @@ def _init_model(self, dim_out, ref_count, ref_annot): hid_dim = len(self.ct_select) self.nmf_model = NMF(Vshape=ref_count.T.shape, rank=self.rank).to(self.device) if self.rank == len(self.ct_select): # initialize basis as cell profile - self.nmf_model.H = nn.Parameter(get_ct_profile_tensor(ref_count, ref_annot, self.ct_select)) + self.nmf_model.H = nn.Parameter(get_ct_profile_tensor(ref_count, ref_annot, ct_select=self.ct_select)) self.nnls_reg1 = NNLS(in_dim=self.rank, out_dim=dim_out, bias=self.bias, device=self.device) self.nnls_reg2 = NNLS(in_dim=hid_dim, out_dim=dim_out, bias=self.bias, device=self.device) @@ -145,7 +147,7 @@ def forward(self, ref_annot): # Get cell-topic and mix-topic profiles # Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells) - H_profile = get_ct_profile_tensor(H.cpu().numpy().T, ref_annot, self.ct_select) + H_profile = get_ct_profile_tensor(H.cpu().numpy().T, ref_annot, ct_select=self.ct_select) H_profile = H_profile.to(self.device) # Get mix-topic profiles B: NNLS of basis W onto mix expression Y -- y ~ W*b @@ -183,7 +185,7 @@ 
def fit(self, x, ref_count, ref_annot, lr=1e-3, max_iter=1000): # Get cell-topic and mix-topic profiles # Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells) - self.H_profile = get_ct_profile_tensor(self.H.cpu().numpy().T, ref_annot, self.ct_select) + self.H_profile = get_ct_profile_tensor(self.H.cpu().numpy().T, ref_annot, ct_select=self.ct_select) self.H_profile = self.H_profile.to(self.device) # Get mix-topic profiles B: NNLS of basis W onto mix expression X ~ W*b diff --git a/dance/transforms/pseudo_gen.py b/dance/transforms/pseudo_gen.py index 6c28b260..cbea45d6 100644 --- a/dance/transforms/pseudo_gen.py +++ b/dance/transforms/pseudo_gen.py @@ -7,7 +7,7 @@ from dance import logger as native_logger from dance.data import Data from dance.transforms.base import BaseTransform -from dance.typing import Dict, List, Literal, Logger, Optional, Tuple, Union +from dance.typing import Callable, Dict, List, Literal, Logger, Optional, Tuple, Union class PseudoMixture(BaseTransform): @@ -141,33 +141,62 @@ def get_cell_types(ct_select: Union[Literal["auto"], List[str]], annot: np.ndarr return ct_select +def get_agg_func(name: str, *, default: Optional[str] = None) -> Callable[[np.ndarray], np.ndarray]: + if name == "default": + if default is None: + raise ValueError("Aggregation function name set to 'default' but default option not set") + name = default + + if name == "median": + agg_func = partial(np.median, axis=0) + elif name == "mean": + agg_func = partial(np.mean, axis=0) + else: + raise ValueError(f"Unknown aggregation method {name!r}. 
Available options are: 'median', 'mena'") + + return agg_func + + def get_ct_profile( x: np.ndarray, annot: np.ndarray, - /, + *, + batch_index: Optional[np.ndarray] = None, ct_select: Union[Literal["auto"], List[str]] = "auto", - method: Literal["median", "mean"] = "median", + method: Literal["median", "mean"] = "mean", logger: Optional[Logger] = None, ) -> np.ndarray: - """Return the cell-topic profile matrix (gene x cell-type).""" logger = logger or native_logger ct_select = get_cell_types(ct_select, annot) - - # Get aggregation function - if method == "median": - agg_func = partial(np.median, axis=0) - elif method == "mean": - agg_func = partial(np.mean, axis=0) - else: - raise ValueError(f"Unknown aggregation method {method!r}. Available options are: 'median', 'mena'") + agg_func = get_agg_func(method, default="mean") + if batch_index is None: + batch_index = np.zeros(x.shape[0], dtype=int) # Aggregate profile for each selected cell types logger.info(f"Generating cell-type profiles ({method!r} aggregation) for {ct_select}") - ct_profile = np.zeros((x.shape[1], len(ct_select)), dtype=np.float32) + ct_profile = np.zeros((x.shape[1], len(ct_select)), dtype=np.float32) # gene x cell + for i, ct in enumerate(ct_select): ct_index = np.where(annot == ct)[0] logger.info(f"Aggregating {ct!r} profiles over {ct_index.size:,} samples") - ct_profile[:, i] = agg_func(x[ct_index]) + + # Get features within a cell type + sub_batch_index = batch_index[ct_index] + batches = np.unique(sub_batch_index) + + # Aggregate cell type profile for each batch + sub_ct_profile = np.zeros((batches.size, x.shape[1]), dtype=np.float32) # cell x gene + sub_ct_mean_lib_sizes = np.zeros(batches.size, dtype=np.float32) + for j, batch_id in enumerate(batches): + idx = np.where(sub_batch_index == batch_id)[0] + sub_ct_profile[j] = agg_func(x[ct_index][idx]) + sub_ct_mean_lib_sizes[j] = sub_ct_profile[j].sum() + sub_ct_profile[j] /= sub_ct_mean_lib_sizes[j] + logger.info(f"Number of {ct!r} cells in 
batch {batch_id!r}: {idx.size:,}") + + # Aggregate cell type profile over batches + ct_profile[:, i] = agg_func(sub_ct_profile) * agg_func(sub_ct_mean_lib_sizes) + logger.info("Cell-type profile generated") return ct_profile From fa9e506efca3cd10680a724b6d0e99b6728fd64b Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 10:37:35 -0500 Subject: [PATCH 04/17] feat: implement FilterGenesMarker gene selection transformation --- .../modules/spatial/cell_type_deconvo/card.py | 17 ++--- dance/transforms/__init__.py | 3 +- dance/transforms/filter.py | 76 ++++++++++++++++++- 3 files changed, 82 insertions(+), 14 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 98f6ee8b..533a56de 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -11,6 +11,7 @@ import numpy as np import pandas as pd +from dance.transforms import FilterGenesMarker from dance.transforms.pseudo_gen import get_ct_profile from dance.utils.matrix import pairwise_distance @@ -175,13 +176,6 @@ def createscRef(self): basis = pd.DataFrame(ct_profile, index=self.ct_select, columns=self.sc_count.columns) self.basis = basis - def select_ct_marker(self, ict): - Basis = self.basis.copy() - rest = Basis[Basis.index != ict].to_numpy().mean(axis=0) - FC = np.log(Basis[Basis.index == ict].to_numpy().mean(axis=0) + 1e-6) - np.log(rest + 1e-6) - markers = list(Basis.columns[np.logical_and(FC > 1.25, Basis[Basis.index == ict].to_numpy().mean(axis=0) > 0)]) - return markers - def selectInfo(self, common_gene): """Select Informative Genes used in the deconvolution. 
@@ -197,13 +191,12 @@ def selectInfo(self, common_gene): """ # Select marker genes from common genes - gene1 = set() - for ict in self.ct_select: - gene1.update(self.select_ct_marker(ict)) - gene1 = sorted(gene1 & set(common_gene)) + markers, _ = FilterGenesMarker.get_marker_genes(self.basis.values, self.basis.index.tolist(), + self.basis.columns.tolist()) + selected_genes = sorted(set(markers) & set(common_gene)) # Compute coefficient of variation for each gene within each cell type - counts = self.sc_count[gene1].join(self.sc_meta[self.ct_varname]) + counts = self.sc_count[selected_genes].join(self.sc_meta[self.ct_varname]) cov_mean = (counts.groupby(self.ct_varname).var() / counts.groupby(self.ct_varname).mean()).mean() # Remove genes that have high cov diff --git a/dance/transforms/__init__.py b/dance/transforms/__init__.py index b99b5a61..9dc5b16c 100644 --- a/dance/transforms/__init__.py +++ b/dance/transforms/__init__.py @@ -1,6 +1,6 @@ from dance.transforms import graph from dance.transforms.cell_feature import CellPCA, WeightedFeaturePCA -from dance.transforms.filter import FilterGenesCommon, FilterGenesMatch, FilterGenesPercentile +from dance.transforms.filter import FilterGenesCommon, FilterGenesMarker, FilterGenesMatch, FilterGenesPercentile from dance.transforms.interface import AnnDataTransform from dance.transforms.misc import Compose, RemoveSplit, SaveRaw, SetConfig from dance.transforms.normalize import ScaleFeature @@ -15,6 +15,7 @@ "CellTopicProfile", "Compose", "FilterGenesCommon", + "FilterGenesMarker", "FilterGenesMatch", "FilterGenesPercentile", "GeneStats", diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index 90ea2cb7..027ae1d0 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -1,9 +1,11 @@ import numpy as np +import pandas as pd from anndata import AnnData +from dance import logger as default_logger from dance.exceptions import DevError from dance.transforms.base import BaseTransform -from 
dance.typing import Dict, List, Literal, Optional, Union +from dance.typing import Dict, List, Literal, Logger, Optional, Tuple, Union class FilterGenesCommon(BaseTransform): @@ -175,3 +177,75 @@ def __call__(self, data): self.logger.info(f"{mask.size - mask.sum()} genes removed ({percentile_lo=:.2e}, {percentile_hi=:.2e})") data._data = data.data[:, mask].copy() + + +class FilterGenesMarker(BaseTransform): + + _DISPLAY_ATTRS = ("ct_profile_channel", "subset", "threshold", "eps") + + def __init__( + self, + *, + ct_profile_channel: str = "CellTopicProfile", + subset: bool = True, + threshold: float = 1.25, + eps: float = 1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.ct_profile_channel = ct_profile_channel + self.subset = subset + self.threshold = threshold + self.eps = eps + + @staticmethod + def get_marker_genes( + ct_profile: np.ndarray, + cell_types: List[str], + genes: List[str], + *, + threshold: float = 1.25, + eps: float = 1e-6, + logger: Logger = default_logger, + ) -> Tuple[List[str], pd.DataFrame]: + if (num_cts := len(cell_types)) < 2: + raise ValueError(f"Need at least two cell types to find marker genes, got {num_cts}:\n{cell_types}") + + # Find marker genes for each cell type + marker_gene_ind_df = pd.DataFrame(False, index=genes, columns=cell_types) + for i, ct in enumerate(cell_types): + others = [j for j in range(num_cts) if j != i] + log_fc = np.log(ct_profile[i] + eps) - np.log(ct_profile[others].mean(0) + eps) + markers_idx = np.where(log_fc > threshold)[0] + + if markers_idx.size > 0: + marker_gene_ind_df.iloc[markers_idx, i] = True + markers = marker_gene_ind_df.iloc[markers_idx].index.tolist() + logger.info(f"Found {len(markers):,} marker genes for cell type {ct!r}") + logger.debug(f"{markers=}") + else: + logger.info(f"No marker genes found for cell type {ct!r}") + + # Combine all marker genes + is_marker = marker_gene_ind_df.max(1) + marker_genes = is_marker[is_marker].index.tolist() + logger.info(f"Total number of marker genes 
found: {len(marker_genes):,}") + logger.debug(f"{marker_genes=}") + + return marker_genes, marker_gene_ind_df + + def __call__(self, data): + ct_profile_df = data.get_feature(channel=self.ct_profile_channel, channel_type="varm", return_type="default") + ct_profile = ct_profile_df.values + cell_types = ct_profile_df.index.tolist() + genes = ct_profile_df.columns.tolist() + marker_genes, marker_gene_ind_df = self.get_marker_genes(ct_profile, cell_types, genes, eps=self.eps, + threshold=self.threshold, logger=self.logger) + + # Save marker gene info to data + data.data.varm[self.out] = marker_gene_ind_df + if self.label is not None: + data.data.var[self.label] = marker_gene_ind_df.max(1) + + if self.subset: # inplace subset the variables + data.data._inplace_subset_var(marker_genes) From 63e144805110468f4b66cc0c5e98291fade63482 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 10:44:54 -0500 Subject: [PATCH 05/17] doc: add docstring to FilterGenesMarker --- dance/transforms/filter.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index 027ae1d0..3753e5e6 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -180,6 +180,24 @@ def __call__(self, data): class FilterGenesMarker(BaseTransform): + """Select marker genes based on log fold-change. + + Parameters + ---------- + ct_profile_channel + Name of the ``.varm`` channel that contains the cell-topic profile which will be used to compute the log + fold-changes for each cell-topic (e.g., cell type). + subset + If set to :obj:`True`, then inplace subset the variables to only contain the markers. + label + If set, e.g., to :obj:`'marker'`, then save the marker indicator to the :obj:`.var` column named as + :obj:`marker`. + threshold + Threshold value of the log fold-change above which the gene will be considered as a marker. + eps + A small value that prevents taking log of zeros. 
+ + """ _DISPLAY_ATTRS = ("ct_profile_channel", "subset", "threshold", "eps") @@ -188,6 +206,7 @@ def __init__( *, ct_profile_channel: str = "CellTopicProfile", subset: bool = True, + label: Optional[str] = None, threshold: float = 1.25, eps: float = 1e-6, **kwargs, @@ -195,6 +214,7 @@ def __init__( super().__init__(**kwargs) self.ct_profile_channel = ct_profile_channel self.subset = subset + self.label = label self.threshold = threshold self.eps = eps From fc034eb835fd86fa79f47667bd00c4db4e1b5e28 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 10:48:20 -0500 Subject: [PATCH 06/17] minor format edit --- dance/transforms/filter.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index 3753e5e6..faeb9ced 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -87,8 +87,13 @@ class FilterGenesMatch(BaseTransform): _DISPLAY_ATTRS = ("prefixes", "suffixes") - def __init__(self, prefixes: Optional[List[str]] = None, suffixes: Optional[List[str]] = None, - case_sensitive: bool = False, **kwargs): + def __init__( + self, + prefixes: Optional[List[str]] = None, + suffixes: Optional[List[str]] = None, + case_sensitive: bool = False, + **kwargs, + ): super().__init__(**kwargs) self.prefixes = prefixes or [] @@ -144,8 +149,16 @@ class FilterGenesPercentile(BaseTransform): _DISPLAY_ATTRS = ("min_val", "max_val", "mode") _MODES = ["sum", "cv"] - def __init__(self, min_val: Optional[float] = 1, max_val: Optional[float] = 99, mode: Literal["sum", "cv"] = "sum", - *, channel: Optional[str] = None, channel_type: Optional[str] = None, **kwargs): + def __init__( + self, + min_val: Optional[float] = 1, + max_val: Optional[float] = 99, + mode: Literal["sum", "cv"] = "sum", + *, + channel: Optional[str] = None, + channel_type: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) if (channel is not None) and (channel_type != "layers"): From 
925ae60a788431e312aab129492496aebe16e018 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 11:02:19 -0500 Subject: [PATCH 07/17] update: use _inplace_subset_var to subset variables in FilterGenesPercentile --- dance/transforms/filter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index faeb9ced..9addb588 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -183,13 +183,15 @@ def __call__(self, data): else: raise DevError(f"{self.mode!r} not expected, please inform dev to fix this error.") + self.logger.info(f"Filtering genes based on {self.mode} expression percentiles in layer {self.channel!r}") percentile_lo = np.percentile(gene_summary, self.min_val) percentile_hi = np.percentile(gene_summary, self.max_val) mask = np.logical_and(gene_summary >= percentile_lo, gene_summary <= percentile_hi) - self.logger.info(f"Filtering genes based on {self.mode} expression percentiles in layer {self.channel!r}") - self.logger.info(f"{mask.size - mask.sum()} genes removed ({percentile_lo=:.2e}, {percentile_hi=:.2e})") + selected_genes = data.data.var_names[mask].tolist() + num_removed = mask.size - len(selected_genes) + self.logger.info(f"{num_removed:,} genes removed ({percentile_lo=:.2e}, {percentile_hi=:.2e})") - data._data = data.data[:, mask].copy() + data.data._inplace_subset_var(selected_genes) class FilterGenesMarker(BaseTransform): From 432e93bcc85996bd7218832dadf33e533e40e367 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 11:24:52 -0500 Subject: [PATCH 08/17] feat: add whitelist option to FilterGenesPercentile --- dance/transforms/filter.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index 9addb588..aca630b8 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -143,6 +143,10 @@ class 
FilterGenesPercentile(BaseTransform): channel_type Type of channels specified. Only allow ``None`` (the default setting) or ``layers`` (when ``channel`` is specified). + whitelist_indicators + A list of (or a single) :obj:`.var` columns that indicates the genes to be excluded from the filtering process. + Note that these genes will still be used in the summary stats computation, and thus will still contribute to the + threshold percentile. If not set, then no genes will be excluded from the filtering process. """ @@ -157,6 +161,7 @@ def __init__( *, channel: Optional[str] = None, channel_type: Optional[str] = None, + whitelist_indicators: Optional[Union[str, List[str]]] = None, **kwargs, ): super().__init__(**kwargs) @@ -172,10 +177,12 @@ def __init__( self.mode = mode self.channel = channel self.channel_type = channel_type + self.whitelist_indicators = whitelist_indicators def __call__(self, data): x = data.get_feature(return_type="default", channel=self.channel, channel_type=self.channel_type) + # Compute gene summary stats for filtering if self.mode == "sum": gene_summary = np.array(x.sum(0)).ravel() elif self.mode == "cv": @@ -183,14 +190,32 @@ def __call__(self, data): else: raise DevError(f"{self.mode!r} not expected, please inform dev to fix this error.") + # Get whitelist genes to be excluded from the filtering process + whitelist_gene_set = set() + if self.whitelist_indicators is not None: + columns = self.whitelist_indicators + columns = columns if isinstance(columns, str) else columns + indicators = data.data.var[columns] + # Genes that satisfy any one of the whitelist conditions will be selected as whitelist genes + whitelist_gene_set.update(indicators[indicators.max(1)].index.tolist()) + + # Select genes to be filtered self.logger.info(f"Filtering genes based on {self.mode} expression percentiles in layer {self.channel!r}") percentile_lo = np.percentile(gene_summary, self.min_val) percentile_hi = np.percentile(gene_summary, self.max_val) mask = 
np.logical_and(gene_summary >= percentile_lo, gene_summary <= percentile_hi) - selected_genes = data.data.var_names[mask].tolist() + selected_genes = sorted(data.data.var_names[mask]) + + # Exclude whitelisted genes + if len(whitelist_gene_set) > 0: + orig_num_selected = len(selected_genes) + selected_genes = sorted(set(selected_genes) - whitelist_gene_set) + num_excluded = orig_num_selected - len(selected_genes) + self.logger.info(f"{num_excluded:,} genes originally selected for filtering excluded due to whitelist") + + # Update data num_removed = mask.size - len(selected_genes) self.logger.info(f"{num_removed:,} genes removed ({percentile_lo=:.2e}, {percentile_hi=:.2e})") - data.data._inplace_subset_var(selected_genes) From 2eb6c1984664c9dc31fa946ef84a326b06c14936 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 14:43:15 -0500 Subject: [PATCH 09/17] feat: add rv option to FilterGenesPercentile --- dance/transforms/filter.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index aca630b8..bb455390 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -135,8 +135,8 @@ class FilterGenesPercentile(BaseTransform): max_val Maximum percentile of the summarized expression value above which the genes will be discarded. mode - Summarization mode. Available options are ``[sum|cv]``. ``sum`` calculates the sum of expression values, ``cv`` - uses the coefficient of variation (std / mean). + Summarization mode. Available options are ``[sum|cv|rv]``. ``sum`` calculates the sum of expression values, + ``cv`` uses the coefficient of variation (std / mean), and ``rv`` uses the relative variance (var / mean). channel Which channel, more specificailly, ``layers``, to use. Use the default ``.X`` if not set. If ``channel`` is specified, then need to specify ``channel_type`` to be ``layers`` as well. 
@@ -151,13 +151,13 @@ class FilterGenesPercentile(BaseTransform): """ _DISPLAY_ATTRS = ("min_val", "max_val", "mode") - _MODES = ["sum", "cv"] + _MODES = ["sum", "cv", "rv"] def __init__( self, min_val: Optional[float] = 1, max_val: Optional[float] = 99, - mode: Literal["sum", "cv"] = "sum", + mode: Literal["sum", "cv", "rv"] = "sum", *, channel: Optional[str] = None, channel_type: Optional[str] = None, @@ -187,6 +187,8 @@ def __call__(self, data): gene_summary = np.array(x.sum(0)).ravel() elif self.mode == "cv": gene_summary = np.nan_to_num(np.array(x.std(0) / x.mean(0)), posinf=0, neginf=0).ravel() + elif self.mode == "rv": + gene_summary = np.nan_to_num(np.array(x.var(0) / x.mean(0)), posinf=0, neginf=0).ravel() else: raise DevError(f"{self.mode!r} not expected, please inform dev to fix this error.") From d7a6146ffd75a4a45dfe5310ba6bf8d5c3faa82c Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 14:44:46 -0500 Subject: [PATCH 10/17] fix: set columns to list if only one presented; log fc dimension; genes/celltypes assignment --- dance/transforms/filter.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dance/transforms/filter.py b/dance/transforms/filter.py index bb455390..b80d021d 100644 --- a/dance/transforms/filter.py +++ b/dance/transforms/filter.py @@ -196,7 +196,7 @@ def __call__(self, data): whitelist_gene_set = set() if self.whitelist_indicators is not None: columns = self.whitelist_indicators - columns = columns if isinstance(columns, str) else columns + columns = [columns] if isinstance(columns, str) else columns indicators = data.data.var[columns] # Genes that satisfy any one of the whitelist conditions will be selected as whitelist genes whitelist_gene_set.update(indicators[indicators.max(1)].index.tolist()) @@ -211,9 +211,9 @@ def __call__(self, data): # Exclude whitelisted genes if len(whitelist_gene_set) > 0: orig_num_selected = len(selected_genes) - selected_genes = sorted(set(selected_genes) - 
whitelist_gene_set) - num_excluded = orig_num_selected - len(selected_genes) - self.logger.info(f"{num_excluded:,} genes originally selected for filtering excluded due to whitelist") + selected_genes = sorted(set(selected_genes) | whitelist_gene_set) + num_added = len(selected_genes) - orig_num_selected + self.logger.info(f"{num_added:,} genes originally unselected are being added due to whitelist") # Update data num_removed = mask.size - len(selected_genes) @@ -262,7 +262,7 @@ def __init__( @staticmethod def get_marker_genes( - ct_profile: np.ndarray, + ct_profile: np.ndarray, # gene x cell cell_types: List[str], genes: List[str], *, @@ -277,7 +277,7 @@ def get_marker_genes( marker_gene_ind_df = pd.DataFrame(False, index=genes, columns=cell_types) for i, ct in enumerate(cell_types): others = [j for j in range(num_cts) if j != i] - log_fc = np.log(ct_profile[i] + eps) - np.log(ct_profile[others].mean(0) + eps) + log_fc = np.log(ct_profile[:, i] + eps) - np.log(ct_profile[:, others].mean(1) + eps) markers_idx = np.where(log_fc > threshold)[0] if markers_idx.size > 0: @@ -299,8 +299,8 @@ def get_marker_genes( def __call__(self, data): ct_profile_df = data.get_feature(channel=self.ct_profile_channel, channel_type="varm", return_type="default") ct_profile = ct_profile_df.values - cell_types = ct_profile_df.index.tolist() - genes = ct_profile_df.columns.tolist() + cell_types = ct_profile_df.columns.tolist() + genes = ct_profile_df.index.tolist() marker_genes, marker_gene_ind_df = self.get_marker_genes(ct_profile, cell_types, genes, eps=self.eps, threshold=self.threshold, logger=self.logger) From 3ca1cc7dc59c2bf315a20386c52ac4670ad5ead0 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 14:45:09 -0500 Subject: [PATCH 11/17] addapt batch_key option to CellTopicProfile --- dance/transforms/pseudo_gen.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/dance/transforms/pseudo_gen.py b/dance/transforms/pseudo_gen.py index 
cbea45d6..d805e410 100644 --- a/dance/transforms/pseudo_gen.py +++ b/dance/transforms/pseudo_gen.py @@ -104,6 +104,7 @@ def __init__( *, ct_select: Union[Literal["auto"], List[str]] = "auto", ct_key: str = "cellType", + batch_key: Optional[str] = None, split_name: Optional[str] = None, channel: Optional[str] = None, channel_type: str = "X", @@ -114,7 +115,9 @@ def __init__( self.ct_select = ct_select self.ct_key = ct_key + self.batch_key = batch_key self.split_name = split_name + self.channel = channel self.channel_type = channel_type self.method = method @@ -124,9 +127,15 @@ def __call__(self, data): return_type="numpy") annot = data.get_feature(split_name=self.split_name, channel=self.ct_key, channel_type="obs", return_type="numpy") + if self.batch_key is None: + batch_index = None + else: + batch_index = data.get_feature(split_name=self.split_name, channel=self.batch_key, channel_type="obs", + return_type="numpy") ct_select = get_cell_types(self.ct_select, annot) - ct_profile = get_ct_profile(x, annot, ct_select, self.method, self.logger) + ct_profile = get_ct_profile(x, annot, batch_index=batch_index, ct_select=ct_select, method=self.method, + logger=self.logger) ct_profile_df = pd.DataFrame(ct_profile, index=data.data.var_names, columns=ct_select) data.data.varm[self.out] = ct_profile_df From 96cdea6818c30e765a42be8119d9d9c8bcd106e5 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 14:46:25 -0500 Subject: [PATCH 12/17] refactor: update card to use dance transformations --- .../modules/spatial/cell_type_deconvo/card.py | 114 ++---------------- examples/spatial/cell_type_deconvo/card.py | 20 ++- 2 files changed, 27 insertions(+), 107 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 533a56de..a4fe8751 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -11,9 +11,7 @@ import numpy as np import pandas as pd 
-from dance.transforms import FilterGenesMarker -from dance.transforms.pseudo_gen import get_ct_profile -from dance.utils.matrix import pairwise_distance +from dance.utils.matrix import normalize, pairwise_distance def obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2=None): @@ -115,96 +113,15 @@ class Card: Parameters ---------- - sc_count : pd.DataFrame - Reference single cell RNA-seq counts data. - sc_meta : pd.DataFrame - Reference cell-type label information. - ct_varname : str, optional - Name of the cell-types column. - ct_select : str, optional - Selected cell-types to be considered for deconvolution. - sample_varname : str, optional - Name of the samples column. - minCountGene : int - Minimum number of genes required. - minCountSpot : int - Minimum number of spots required. basis - The basis parameter. - markers - Markers. + The cell-type profile basis. """ - def __init__(self, sc_count, sc_meta, ct_varname=None, ct_select=None, sample_varname=None, minCountGene=100, - minCountSpot=5, basis=None, markers=None): - self.sc_count = sc_count - self.sc_meta = sc_meta - self.ct_varname = ct_varname - self.ct_select = ct_select - self.sample_varname = sample_varname - self.minCountGene = minCountGene - self.minCountSpot = minCountSpot + def __init__(self, basis: pd.DataFrame): self.basis = basis - self.marker = markers self.info_parameters = {} - self.createscRef() # create basis - all_genes = sc_count.columns.tolist() - gene_to_idx = {j: i for i, j in enumerate(all_genes)} - not_mt_genes = [i for i in all_genes if not i.lower().startswith("mt-")] - selected_genes = self.selectInfo(not_mt_genes) - selected_gene_idx = list(map(gene_to_idx.get, selected_genes)) - self.gene_mask = np.zeros(len(all_genes), dtype=np.bool) - self.gene_mask[selected_gene_idx] = True - - def createscRef(self): - """CreatescRef - create reference basis matrix from reference scRNA-seq.""" - sc_meta = self.sc_meta.copy() - ct_varname = 
self.ct_varname - batch_varname = self.sample_varname - if batch_varname is None: - sc_meta["sampleID"] = "Sample" - batch_varname = "sampleID" - - ct_profile = get_ct_profile( - self.sc_count.values, - sc_meta[ct_varname].values, - ct_select=self.ct_select, - batch_index=sc_meta[batch_varname].values, - ).T - basis = pd.DataFrame(ct_profile, index=self.ct_select, columns=self.sc_count.columns) - self.basis = basis - - def selectInfo(self, common_gene): - """Select Informative Genes used in the deconvolution. - - Parameters - ---------- - common_gene : list - Common genes between scRNAseq count data and spatial resolved transcriptomics data. - - Returns - ------- - list - List of informative genes. - - """ - # Select marker genes from common genes - markers, _ = FilterGenesMarker.get_marker_genes(self.basis.values, self.basis.index.tolist(), - self.basis.columns.tolist()) - selected_genes = sorted(set(markers) & set(common_gene)) - - # Compute coefficient of variation for each gene within each cell type - counts = self.sc_count[selected_genes].join(self.sc_meta[self.ct_varname]) - cov_mean = (counts.groupby(self.ct_varname).var() / counts.groupby(self.ct_varname).mean()).mean() - - # Remove genes that have high cov - ind = cov_mean < cov_mean.quantile(.99) - genes = sorted(cov_mean[ind].index) - - return genes - def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: bool = False): """Fit function for model training. @@ -224,14 +141,8 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: Do not use spatial location info if set to True. 
""" - ct_select = self.ct_select - Basis = self.basis.copy() - Basis = Basis.loc[ct_select] - - gene_mask = self.gene_mask & (x.sum(0) > 0) - B = Basis.values[:, gene_mask].copy() # TODO: make it a numpy array - Xinput = x[:, gene_mask].copy() - Xinput_norm = Xinput / Xinput.sum(1, keepdims=True) # TODO: use the normalize util + basis = self.basis.values.copy() + x_norm = normalize(x, axis=1, mode="normalize") # Spatial location if location_free or (spatial == 0).all(): @@ -246,19 +157,20 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: # Initialize the proportion matrix rng = np.random.default_rng(20200107) - Vint1 = rng.dirichlet(np.repeat(10, B.shape[0], axis=0), Xinput_norm.shape[0]) + Vint1 = rng.dirichlet(np.repeat(10, basis.shape[0], axis=0), x_norm.shape[0]) phi = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] + # scale the Xinput_norm and B to speed up the convergence. - Xinput_norm = Xinput_norm * 0.1 / Xinput_norm.sum() - B = B * 0.1 / B.mean() + x_norm = x_norm * 0.1 / x_norm.sum() + b_mat = basis * 0.1 / basis.mean() # Optimization ResList = {} Obj = np.array([]) for iphi in range(len(phi)): - res = CARDref(Xinput=Xinput_norm.T, U=B.T, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon, - V=Vint1, b=np.repeat(0, B.T.shape[1]).reshape(B.T.shape[1], 1), sigma_e2=0.1, - Lambda=np.repeat(10, len(ct_select))) + res = CARDref(Xinput=x_norm.T, U=b_mat.T, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon, + V=Vint1, b=np.repeat(0, b_mat.T.shape[1]).reshape(b_mat.T.shape[1], 1), sigma_e2=0.1, + Lambda=np.repeat(10, basis.shape[0])) ResList[str(iphi)] = res Obj = np.append(Obj, res["Obj"]) self.Obj_hist = Obj @@ -268,7 +180,7 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: print("## Deconvolution Finish! 
...\n") self.info_parameters["phi"] = OptimalPhi - self.algorithm_matrix = {"B": B, "Xinput_norm": Xinput_norm, "Res": OptimalRes} + self.algorithm_matrix = {"B": b_mat, "Xinput_norm": x_norm, "Res": OptimalRes} return self diff --git a/examples/spatial/cell_type_deconvo/card.py b/examples/spatial/cell_type_deconvo/card.py index 0d12a8b6..78b861de 100644 --- a/examples/spatial/cell_type_deconvo/card.py +++ b/examples/spatial/cell_type_deconvo/card.py @@ -3,6 +3,8 @@ from dance.datasets.spatial import CellTypeDeconvoDataset from dance.modules.spatial.cell_type_deconvo.card import Card +from dance.transforms import (CellTopicProfile, FilterGenesCommon, FilterGenesMarker, FilterGenesMatch, + FilterGenesPercentile) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--dataset", default="CARD_synthetic", choices=CellTypeDeconvoDataset.DATASETS) @@ -16,17 +18,23 @@ # Load dataset dataset = CellTypeDeconvoDataset(data_dir=args.datadir, data_id=args.dataset) data = dataset.load_data() +cell_types = data.data.obsm["cell_type_portion"].columns.tolist() + +CellTopicProfile(ct_select=cell_types, ct_key="cellType", batch_key=None, split_name="ref", method="mean", + log_level="INFO")(data) +FilterGenesMatch(prefixes=["mt-"], case_sensitive=False, log_level="INFO")(data) +FilterGenesCommon(split_keys=["ref", "test"], log_level="INFO")(data) +FilterGenesMarker(ct_profile_channel="CellTopicProfile", threshold=1.25, log_level="INFO")(data) +FilterGenesPercentile(min_val=1, max_val=99, mode="rv", log_level="INFO")(data) data.set_config(feature_channel=[None, "spatial"], feature_channel_type=["X", "obsm"], label_channel="cell_type_portion") (x_count, x_spatial), y = data.get_data(split_name="test", return_type="numpy") -cell_types = data.data.obsm["cell_type_portion"].columns.tolist() - -ref_adata = data.get_split_data("ref") -ref_count = ref_adata.to_df() -ref_annot = ref_adata.obs +# TODO: adapt card to use basis.T +# TODO: 
use "auto"/None option for ct_select +basis = data.get_feature(return_type="default", channel="CellTopicProfile", channel_type="varm").T -model = Card(ref_count, ref_annot, ct_varname="cellType", ct_select=cell_types) +model = Card(basis) pred = model.fit_and_predict(x_count, x_spatial, max_iter=args.max_iter, epsilon=args.epsilon, location_free=args.location_free) mse = model.score(pred, y) From 838da0beb749b4539a98fda0b53933629ce6ad13 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 14:50:09 -0500 Subject: [PATCH 13/17] update: pass basis instead of basis.T --- dance/modules/spatial/cell_type_deconvo/card.py | 8 ++++---- examples/spatial/cell_type_deconvo/card.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index a4fe8751..92370957 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -157,7 +157,7 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: # Initialize the proportion matrix rng = np.random.default_rng(20200107) - Vint1 = rng.dirichlet(np.repeat(10, basis.shape[0], axis=0), x_norm.shape[0]) + Vint1 = rng.dirichlet(np.repeat(10, basis.shape[1], axis=0), x_norm.shape[0]) phi = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] # scale the Xinput_norm and B to speed up the convergence. 
@@ -168,9 +168,9 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: ResList = {} Obj = np.array([]) for iphi in range(len(phi)): - res = CARDref(Xinput=x_norm.T, U=b_mat.T, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon, - V=Vint1, b=np.repeat(0, b_mat.T.shape[1]).reshape(b_mat.T.shape[1], 1), sigma_e2=0.1, - Lambda=np.repeat(10, basis.shape[0])) + res = CARDref(Xinput=x_norm.T, U=b_mat, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon, + V=Vint1, b=np.repeat(0, b_mat.shape[1]).reshape(b_mat.shape[1], 1), sigma_e2=0.1, + Lambda=np.repeat(10, basis.shape[1])) ResList[str(iphi)] = res Obj = np.append(Obj, res["Obj"]) self.Obj_hist = Obj diff --git a/examples/spatial/cell_type_deconvo/card.py b/examples/spatial/cell_type_deconvo/card.py index 78b861de..6432322f 100644 --- a/examples/spatial/cell_type_deconvo/card.py +++ b/examples/spatial/cell_type_deconvo/card.py @@ -30,9 +30,8 @@ data.set_config(feature_channel=[None, "spatial"], feature_channel_type=["X", "obsm"], label_channel="cell_type_portion") (x_count, x_spatial), y = data.get_data(split_name="test", return_type="numpy") -# TODO: adapt card to use basis.T # TODO: use "auto"/None option for ct_select -basis = data.get_feature(return_type="default", channel="CellTopicProfile", channel_type="varm").T +basis = data.get_feature(return_type="default", channel="CellTopicProfile", channel_type="varm") model = Card(basis) pred = model.fit_and_predict(x_count, x_spatial, max_iter=args.max_iter, epsilon=args.epsilon, From 1d0811c6fa020006516ff1fbbe781ddfcea667e7 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 14:52:23 -0500 Subject: [PATCH 14/17] update: use auto ct_select --- examples/spatial/cell_type_deconvo/card.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/spatial/cell_type_deconvo/card.py b/examples/spatial/cell_type_deconvo/card.py index 6432322f..609f74c8 100644 --- 
a/examples/spatial/cell_type_deconvo/card.py +++ b/examples/spatial/cell_type_deconvo/card.py @@ -18,9 +18,8 @@ # Load dataset dataset = CellTypeDeconvoDataset(data_dir=args.datadir, data_id=args.dataset) data = dataset.load_data() -cell_types = data.data.obsm["cell_type_portion"].columns.tolist() -CellTopicProfile(ct_select=cell_types, ct_key="cellType", batch_key=None, split_name="ref", method="mean", +CellTopicProfile(ct_select="auto", ct_key="cellType", batch_key=None, split_name="ref", method="mean", log_level="INFO")(data) FilterGenesMatch(prefixes=["mt-"], case_sensitive=False, log_level="INFO")(data) FilterGenesCommon(split_keys=["ref", "test"], log_level="INFO")(data) @@ -30,7 +29,6 @@ data.set_config(feature_channel=[None, "spatial"], feature_channel_type=["X", "obsm"], label_channel="cell_type_portion") (x_count, x_spatial), y = data.get_data(split_name="test", return_type="numpy") -# TODO: use "auto"/None option for ct_select basis = data.get_feature(return_type="default", channel="CellTopicProfile", channel_type="varm") model = Card(basis) From 6cdd031dd0eec3b73f32449eeb38d18661b66263 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 15:05:09 -0500 Subject: [PATCH 15/17] use logger --- .../modules/spatial/cell_type_deconvo/card.py | 28 ++++++++++++++++--- examples/spatial/cell_type_deconvo/card.py | 12 ++------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 92370957..4f8ea2fc 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -11,6 +11,10 @@ import numpy as np import pandas as pd +from dance import logger +from dance.transforms import (CellTopicProfile, Compose, FilterGenesCommon, FilterGenesMarker, FilterGenesMatch, + FilterGenesPercentile, SetConfig) +from dance.typing import LogLevel from dance.utils.matrix import normalize, pairwise_distance @@ -21,6 
+25,7 @@ def obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, temp = (V.T - b @ vecOne.T) @ L @ (V - vecOne @ b.T) logV = -(nSample) * 0.5 * np.sum(np.log(Lambda)) - 0.5 * (np.sum(np.diag(temp) / Lambda)) logSigmaL2 = -(alpha + 1.0) * np.sum(np.log(Lambda)) - np.sum(beta / Lambda) + logger.debug(f"{logX=:5.2e}, {logV=:5.2e}, {logSigmaL2=:5.2e}") return logX + logV + logSigmaL2 @@ -93,11 +98,10 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda): obj = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha) logicalLogL = (obj > obj_old) & ((abs(obj - obj_old) * 2.0 / abs(obj + obj_old)) < epsilon) - # TODO: setup logging and make this debug or info - # print(f"{i=:<4}, {obj=:.5e}, {logX=:5.2e}, {logV=:5.2e}, {logSigmaL2=:5.2e}") + logger.debug(f"{i=:<4}, {obj=:.5e}") if (np.isnan(obj) | (np.sqrt(np.sum((V - V_old) * (V - V_old)) / (nSample * k)) < epsilon) | logicalLogL): if (i > 5): # // run at least 5 iterations - # print(f"Exiting at {i=}") + logger.info(f"Exiting at {i=}") iter_converge = i break else: @@ -122,6 +126,22 @@ def __init__(self, basis: pd.DataFrame): self.basis = basis self.info_parameters = {} + @staticmethod + def preprocessing_pipeline(log_level: LogLevel = "INFO"): + return Compose( + CellTopicProfile(ct_select="auto", ct_key="cellType", batch_key=None, split_name="ref", method="mean"), + FilterGenesMatch(prefixes=["mt-"], case_sensitive=False), + FilterGenesCommon(split_keys=["ref", "test"]), + FilterGenesMarker(ct_profile_channel="CellTopicProfile", threshold=1.25), + FilterGenesPercentile(min_val=1, max_val=99, mode="rv"), + SetConfig({ + "feature_channel": [None, "spatial"], + "feature_channel_type": ["X", "obsm"], + "label_channel": "cell_type_portion", + }), + log_level=log_level, + ) + def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: bool = False): """Fit function for model training. 
@@ -177,7 +197,7 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: Optimal_ix = np.where(Obj == Obj.max())[0][-1] # in case if there are two equal objective function values OptimalPhi = phi[Optimal_ix] OptimalRes = ResList[str(Optimal_ix)] - print("## Deconvolution Finish! ...\n") + logger.info("Deconvolution finished") self.info_parameters["phi"] = OptimalPhi self.algorithm_matrix = {"B": b_mat, "Xinput_norm": x_norm, "Res": OptimalRes} diff --git a/examples/spatial/cell_type_deconvo/card.py b/examples/spatial/cell_type_deconvo/card.py index 609f74c8..7d35ea33 100644 --- a/examples/spatial/cell_type_deconvo/card.py +++ b/examples/spatial/cell_type_deconvo/card.py @@ -3,8 +3,6 @@ from dance.datasets.spatial import CellTypeDeconvoDataset from dance.modules.spatial.cell_type_deconvo.card import Card -from dance.transforms import (CellTopicProfile, FilterGenesCommon, FilterGenesMarker, FilterGenesMatch, - FilterGenesPercentile) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--dataset", default="CARD_synthetic", choices=CellTypeDeconvoDataset.DATASETS) @@ -19,15 +17,9 @@ dataset = CellTypeDeconvoDataset(data_dir=args.datadir, data_id=args.dataset) data = dataset.load_data() -CellTopicProfile(ct_select="auto", ct_key="cellType", batch_key=None, split_name="ref", method="mean", - log_level="INFO")(data) -FilterGenesMatch(prefixes=["mt-"], case_sensitive=False, log_level="INFO")(data) -FilterGenesCommon(split_keys=["ref", "test"], log_level="INFO")(data) -FilterGenesMarker(ct_profile_channel="CellTopicProfile", threshold=1.25, log_level="INFO")(data) -FilterGenesPercentile(min_val=1, max_val=99, mode="rv", log_level="INFO")(data) +preprocessing_pipeline = Card.preprocessing_pipeline() +preprocessing_pipeline(data) -data.set_config(feature_channel=[None, "spatial"], feature_channel_type=["X", "obsm"], - label_channel="cell_type_portion") (x_count, x_spatial), y = 
data.get_data(split_name="test", return_type="numpy") basis = data.get_feature(return_type="default", channel="CellTopicProfile", channel_type="varm") From 9c481a8ed025a6fb12d89c0e2279450dd66fdf4c Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 15:52:49 -0500 Subject: [PATCH 16/17] improve attrs --- .../modules/spatial/cell_type_deconvo/card.py | 38 +++++++------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 4f8ea2fc..98e2293c 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -77,7 +77,6 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda): V_old = V.copy() # Iteration starts - iter_converge = 0 for i in range(max_iter): # logV = 0.0 Lambda = (np.diag(temp) / 2.0 + beta) / (nSample / 2.0 + alpha + 1.0) @@ -102,14 +101,13 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda): if (np.isnan(obj) | (np.sqrt(np.sum((V - V_old) * (V - V_old)) / (nSample * k)) < epsilon) | logicalLogL): if (i > 5): # // run at least 5 iterations logger.info(f"Exiting at {i=}") - iter_converge = i break else: obj_old = obj V_old = V.copy() + pred = V / V.sum(axis=1, keepdims=True) - res = {"V": V, "sigma_e2": sigma_e2, "Lambda": Lambda, "b": b, "Obj": obj, "iter_converge": iter_converge} - return res + return pred, obj class Card: @@ -124,7 +122,8 @@ class Card: def __init__(self, basis: pd.DataFrame): self.basis = basis - self.info_parameters = {} + self.best_phi = None + self.best_obj = -np.inf @staticmethod def preprocessing_pipeline(log_level: LogLevel = "INFO"): @@ -180,30 +179,21 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: Vint1 = rng.dirichlet(np.repeat(10, basis.shape[1], axis=0), x_norm.shape[0]) phi = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99] - # scale the Xinput_norm and B to speed up the 
convergence. + # Scale the Xinput_norm and B to speed up the convergence. x_norm = x_norm * 0.1 / x_norm.sum() b_mat = basis * 0.1 / basis.mean() # Optimization - ResList = {} - Obj = np.array([]) for iphi in range(len(phi)): - res = CARDref(Xinput=x_norm.T, U=b_mat, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon, - V=Vint1, b=np.repeat(0, b_mat.shape[1]).reshape(b_mat.shape[1], 1), sigma_e2=0.1, - Lambda=np.repeat(10, basis.shape[1])) - ResList[str(iphi)] = res - Obj = np.append(Obj, res["Obj"]) - self.Obj_hist = Obj - Optimal_ix = np.where(Obj == Obj.max())[0][-1] # in case if there are two equal objective function values - OptimalPhi = phi[Optimal_ix] - OptimalRes = ResList[str(Optimal_ix)] + res, obj = CARDref(Xinput=x_norm.T, U=b_mat, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, + epsilon=epsilon, V=Vint1, b=np.repeat(0, b_mat.shape[1]).reshape(b_mat.shape[1], 1), + sigma_e2=0.1, Lambda=np.repeat(10, basis.shape[1])) + if obj > self.best_obj: + self.res = res + self.best_obj = obj + self.best_phi = phi logger.info("Deconvolution finished") - self.info_parameters["phi"] = OptimalPhi - self.algorithm_matrix = {"B": b_mat, "Xinput_norm": x_norm, "Res": OptimalRes} - - return self - def predict(self): """Prediction function. @@ -213,9 +203,7 @@ def predict(self): Predictions of cell-type proportions. 
""" - optim_res = self.algorithm_matrix["Res"] - prop_pred = optim_res["V"] / optim_res["V"].sum(axis=1, keepdims=True) - return prop_pred + return self.res def fit_and_predict(self, x, spatial, **kwargs): self.fit(x, spatial, **kwargs) From d9cafdb455b066545fc6baca841fa813d699b6f3 Mon Sep 17 00:00:00 2001 From: Remy Date: Thu, 23 Feb 2023 15:58:11 -0500 Subject: [PATCH 17/17] improve logic --- .../modules/spatial/cell_type_deconvo/card.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/dance/modules/spatial/cell_type_deconvo/card.py b/dance/modules/spatial/cell_type_deconvo/card.py index 98e2293c..cc6129c7 100644 --- a/dance/modules/spatial/cell_type_deconvo/card.py +++ b/dance/modules/spatial/cell_type_deconvo/card.py @@ -53,7 +53,6 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda): updateV_den_k = np.zeros(k) vecOne = np.ones((nSample, 1)) diag_UtU = np.zeros(k) - logicalLogL = False alpha = 1.0 beta = nSample / 2.0 accu_L = 0.0 @@ -73,12 +72,12 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda): colsum_W = colsum_W.reshape(nSample, 1) accu_L = np.sum(L) - obj_old = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2) - V_old = V.copy() - # Iteration starts + obj = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2) for i in range(max_iter): - # logV = 0.0 + obj_old = obj + V_old = V.copy() + Lambda = (np.diag(temp) / 2.0 + beta) / (nSample / 2.0 + alpha + 1.0) if W is not None: b = np.sum(V.T @ L, axis=1, keepdims=True) / accu_L @@ -96,15 +95,14 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda): VtV = V.T @ V obj = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha) - logicalLogL = (obj > obj_old) & ((abs(obj - obj_old) * 2.0 / abs(obj + obj_old)) < epsilon) + logic1 = (obj > obj_old) & ((abs(obj - obj_old) * 2.0 / 
abs(obj + obj_old)) < epsilon) + logic2 = np.sqrt(np.sum((V - V_old) * (V - V_old)) / (nSample * k)) < epsilon + stop_logic = np.isnan(obj) or logic1 or logic2 logger.debug(f"{i=:<4}, {obj=:.5e}") - if (np.isnan(obj) | (np.sqrt(np.sum((V - V_old) * (V - V_old)) / (nSample * k)) < epsilon) | logicalLogL): - if (i > 5): # // run at least 5 iterations - logger.info(f"Exiting at {i=}") - break - else: - obj_old = obj - V_old = V.copy() + if stop_logic and i > 5: + logger.info(f"Exiting at {i=}") + break + pred = V / V.sum(axis=1, keepdims=True) return pred, obj