Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update card to use dance transformations #213

Merged
merged 17 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 55 additions & 185 deletions dance/modules/spatial/cell_type_deconvo/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
Nature Biotechnology (2022): 1-11.

"""
from itertools import chain

import numpy as np
import pandas as pd

from dance.utils.matrix import pairwise_distance
from dance import logger
from dance.transforms import (CellTopicProfile, Compose, FilterGenesCommon, FilterGenesMarker, FilterGenesMatch,
FilterGenesPercentile, SetConfig)
from dance.typing import LogLevel
from dance.utils.matrix import normalize, pairwise_distance


def obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2=None):
Expand All @@ -23,6 +25,7 @@ def obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne,
temp = (V.T - b @ vecOne.T) @ L @ (V - vecOne @ b.T)
logV = -(nSample) * 0.5 * np.sum(np.log(Lambda)) - 0.5 * (np.sum(np.diag(temp) / Lambda))
logSigmaL2 = -(alpha + 1.0) * np.sum(np.log(Lambda)) - np.sum(beta / Lambda)
logger.debug(f"{logX=:5.2e}, {logV=:5.2e}, {logSigmaL2=:5.2e}")
return logX + logV + logSigmaL2


Expand Down Expand Up @@ -50,7 +53,6 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda):
updateV_den_k = np.zeros(k)
vecOne = np.ones((nSample, 1))
diag_UtU = np.zeros(k)
logicalLogL = False
alpha = 1.0
beta = nSample / 2.0
accu_L = 0.0
Expand All @@ -70,13 +72,12 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda):
colsum_W = colsum_W.reshape(nSample, 1)
accu_L = np.sum(L)

obj_old = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2)
V_old = V.copy()

# Iteration starts
iter_converge = 0
obj = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha, sigma_e2)
for i in range(max_iter):
# logV = 0.0
obj_old = obj
V_old = V.copy()

Lambda = (np.diag(temp) / 2.0 + beta) / (nSample / 2.0 + alpha + 1.0)
if W is not None:
b = np.sum(V.T @ L, axis=1, keepdims=True) / accu_L
Expand All @@ -94,164 +95,49 @@ def CARDref(Xinput, U, W, phi, max_iter, epsilon, V, b, sigma_e2, Lambda):
VtV = V.T @ V
obj = obj_func(trac_xxt, UtXV, UtU, VtV, mGene, nSample, b, Lambda, beta, vecOne, V, L, alpha)

logicalLogL = (obj > obj_old) & ((abs(obj - obj_old) * 2.0 / abs(obj + obj_old)) < epsilon)
# TODO: setup logging and make this debug or info
# print(f"{i=:<4}, {obj=:.5e}, {logX=:5.2e}, {logV=:5.2e}, {logSigmaL2=:5.2e}")
if (np.isnan(obj) | (np.sqrt(np.sum((V - V_old) * (V - V_old)) / (nSample * k)) < epsilon) | logicalLogL):
if (i > 5): # // run at least 5 iterations
# print(f"Exiting at {i=}")
iter_converge = i
break
else:
obj_old = obj
V_old = V.copy()
logic1 = (obj > obj_old) & ((abs(obj - obj_old) * 2.0 / abs(obj + obj_old)) < epsilon)
logic2 = np.sqrt(np.sum((V - V_old) * (V - V_old)) / (nSample * k)) < epsilon
stop_logic = np.isnan(obj) or logic1 or logic2
logger.debug(f"{i=:<4}, {obj=:.5e}")
if stop_logic and i > 5:
logger.info(f"Exiting at {i=}")
break

pred = V / V.sum(axis=1, keepdims=True)

res = {"V": V, "sigma_e2": sigma_e2, "Lambda": Lambda, "b": b, "Obj": obj, "iter_converge": iter_converge}
return res
return pred, obj


class Card:
"""The CARD cell-type deconvolution model.

Parameters
----------
sc_count : pd.DataFrame
Reference single cell RNA-seq counts data.
sc_meta : pd.DataFrame
Reference cell-type label information.
ct_varname : str, optional
Name of the cell-types column.
ct_select : str, optional
Selected cell-types to be considered for deconvolution.
sample_varname : str, optional
Name of the samples column.
minCountGene : int
Minimum number of genes required.
minCountSpot : int
Minimum number of spots required.
basis
The basis parameter.
markers
Markers.
The cell-type profile basis.

"""

def __init__(self, sc_count, sc_meta, ct_varname=None, ct_select=None, sample_varname=None, minCountGene=100,
minCountSpot=5, basis=None, markers=None):
self.sc_count = sc_count
self.sc_meta = sc_meta
self.ct_varname = ct_varname
self.ct_select = ct_select
self.sample_varname = sample_varname
self.minCountGene = minCountGene
self.minCountSpot = minCountSpot
def __init__(self, basis: pd.DataFrame):
self.basis = basis
self.marker = markers
self.info_parameters = {}

self.createscRef() # create basis
all_genes = sc_count.columns.tolist()
gene_to_idx = {j: i for i, j in enumerate(all_genes)}
not_mt_genes = [i for i in all_genes if not i.lower().startswith("mt-")]
selected_genes = self.selectInfo(not_mt_genes)
selected_gene_idx = list(map(gene_to_idx.get, selected_genes))
self.gene_mask = np.zeros(len(all_genes), dtype=np.bool)
self.gene_mask[selected_gene_idx] = True

def createscRef(self):
"""CreatescRef - create reference basis matrix from reference scRNA-seq."""
countMat = self.sc_count.copy() # cell by gene matrix
sc_meta = self.sc_meta.copy()
ct_varname = self.ct_varname
sample_varname = self.sample_varname
if sample_varname is None:
sc_meta["sampleID"] = "Sample"
sample_varname = "sampleID"
sample_id = sc_meta[sample_varname].astype(str)
ct_sample_id = sc_meta[ct_varname] + "$*$" + sample_id
rowSums_countMat = countMat.sum(axis=1)
sc_meta["rowSums"] = rowSums_countMat
rowSums_countMat_Ct = sc_meta.groupby([ct_varname, sample_varname])["rowSums"].agg("sum").to_frame()
rowSums_countMat_Ct_Wide = rowSums_countMat_Ct.pivot_table(index=sample_varname, columns=ct_varname,
values="rowSums", aggfunc="sum")

# create count table by sampleID and cellType
tab = sc_meta.groupby([sample_varname, ct_varname]).size()
tbl = tab.unstack()

# match column and row names
rowSums_countMat_Ct_Wide = rowSums_countMat_Ct_Wide.reindex_like(tbl)
rowSums_countMat_Ct_Wide = rowSums_countMat_Ct_Wide.reindex(tbl.index)

# Compute total expression count by sample and cell type
S_JK = rowSums_countMat_Ct_Wide.div(tbl)
S_JK = S_JK.replace(0, np.nan)
S_JK = S_JK.replace([np.inf, -np.inf], np.nan)
S = S_JK.mean(axis=0).to_frame().unstack().droplevel(0)
S = S[sc_meta[ct_varname].unique()]
countMat["ct_sample_id"] = ct_sample_id
Theta_S_colMean = countMat.groupby(ct_sample_id).mean(numeric_only=True)
tbl_sample = countMat.groupby([ct_sample_id]).size()
tbl_sample = tbl_sample.reindex_like(Theta_S_colMean)
tbl_sample = tbl_sample.reindex(Theta_S_colMean.index)
Theta_S_colSums = countMat.groupby(ct_sample_id).sum(numeric_only=True)
Theta_S = Theta_S_colSums.copy()
Theta_S["sum"] = Theta_S_colSums.sum(axis=1)
Theta_S = Theta_S[list(Theta_S.columns)[:-1]].div(Theta_S["sum"], axis=0)
grp = []
for ind in Theta_S.index:
grp.append(ind.split("$*$")[0])
Theta_S["grp"] = grp
Theta = Theta_S.groupby(grp).mean(numeric_only=True)
Theta = Theta.reindex(sc_meta[ct_varname].unique())
S = S[Theta.index]
Theta["S"] = S.iloc[0]
basis = Theta[list(Theta.columns)[:-1]].mul(Theta["S"], axis=0)
self.basis = basis

def select_ct_marker(self, ict):
Basis = self.basis.copy()
rest = Basis[Basis.index != ict].to_numpy().mean(axis=0)
FC = np.log(Basis[Basis.index == ict].to_numpy().mean(axis=0) + 1e-6) - np.log(rest + 1e-6)
markers = list(Basis.columns[np.logical_and(FC > 1.25, Basis[Basis.index == ict].to_numpy().mean(axis=0) > 0)])
return markers

def selectInfo(self, common_gene):
"""Select Informative Genes used in the deconvolution.

Parameters
----------
common_gene : list
Common genes between scRNAseq count data and spatial resolved transcriptomics data.

Returns
-------
list
List of informative genes.
self.best_phi = None
self.best_obj = -np.inf

"""
ct_varname = self.ct_varname
Basis = self.basis.copy()
sc_count = self.sc_count.copy()
sc_meta = self.sc_meta.copy()
gene1_list = list()
for ict in Basis.index:
gene1_list.append(self.select_ct_marker(ict))
gene1 = set(chain(*gene1_list))
gene1 = list(gene1)
gene1 = [gene for gene in gene1 if gene in common_gene] # intersect with common_gene
counts = sc_count[gene1]
sd_within = pd.DataFrame(columns=counts.columns)
for ict in Basis.index:
series = (counts.loc[list(sc_meta[sc_meta[ct_varname] == ict].index)].var(axis=0).divide(counts.loc[list(
sc_meta[sc_meta[ct_varname] == ict].index)].mean(axis=0)))
series.name = ict
sd_within = pd.concat((sd_within, pd.DataFrame(series).T))

sd_within_colMean = sd_within.mean(axis=0).index.to_frame()
genes_to_select = sd_within.mean(axis=0) < sd_within.mean(axis=0).quantile(.99)
genes = list(sd_within_colMean[genes_to_select].index)
return genes
@staticmethod
def preprocessing_pipeline(log_level: LogLevel = "INFO"):
return Compose(
CellTopicProfile(ct_select="auto", ct_key="cellType", batch_key=None, split_name="ref", method="mean"),
FilterGenesMatch(prefixes=["mt-"], case_sensitive=False),
FilterGenesCommon(split_keys=["ref", "test"]),
FilterGenesMarker(ct_profile_channel="CellTopicProfile", threshold=1.25),
FilterGenesPercentile(min_val=1, max_val=99, mode="rv"),
SetConfig({
"feature_channel": [None, "spatial"],
"feature_channel_type": ["X", "obsm"],
"label_channel": "cell_type_portion",
}),
log_level=log_level,
)

def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free: bool = False):
"""Fit function for model training.
Expand All @@ -272,14 +158,8 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free:
Do not use spatial location info if set to True.

"""
ct_select = self.ct_select
Basis = self.basis.copy()
Basis = Basis.loc[ct_select]

gene_mask = self.gene_mask & (x.sum(0) > 0)
B = Basis.values[:, gene_mask].copy() # TODO: make it a numpy array
Xinput = x[:, gene_mask].copy()
Xinput_norm = Xinput / Xinput.sum(1, keepdims=True) # TODO: use the normalize util
basis = self.basis.values.copy()
x_norm = normalize(x, axis=1, mode="normalize")

# Spatial location
if location_free or (spatial == 0).all():
Expand All @@ -294,31 +174,23 @@ def fit(self, x, spatial, max_iter=100, epsilon=1e-4, sigma=0.1, location_free:

# Initialize the proportion matrix
rng = np.random.default_rng(20200107)
Vint1 = rng.dirichlet(np.repeat(10, B.shape[0], axis=0), Xinput_norm.shape[0])
Vint1 = rng.dirichlet(np.repeat(10, basis.shape[1], axis=0), x_norm.shape[0])
phi = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
# scale the Xinput_norm and B to speed up the convergence.
Xinput_norm = Xinput_norm * 0.1 / Xinput_norm.sum()
B = B * 0.1 / B.mean()

# Scale the Xinput_norm and B to speed up the convergence.
x_norm = x_norm * 0.1 / x_norm.sum()
b_mat = basis * 0.1 / basis.mean()

# Optimization
ResList = {}
Obj = np.array([])
for iphi in range(len(phi)):
res = CARDref(Xinput=Xinput_norm.T, U=B.T, W=kernel_mat, phi=phi[iphi], max_iter=max_iter, epsilon=epsilon,
V=Vint1, b=np.repeat(0, B.T.shape[1]).reshape(B.T.shape[1], 1), sigma_e2=0.1,
Lambda=np.repeat(10, len(ct_select)))
ResList[str(iphi)] = res
Obj = np.append(Obj, res["Obj"])
self.Obj_hist = Obj
Optimal_ix = np.where(Obj == Obj.max())[0][-1] # in case if there are two equal objective function values
OptimalPhi = phi[Optimal_ix]
OptimalRes = ResList[str(Optimal_ix)]
print("## Deconvolution Finish! ...\n")

self.info_parameters["phi"] = OptimalPhi
self.algorithm_matrix = {"B": B, "Xinput_norm": Xinput_norm, "Res": OptimalRes}

return self
res, obj = CARDref(Xinput=x_norm.T, U=b_mat, W=kernel_mat, phi=phi[iphi], max_iter=max_iter,
epsilon=epsilon, V=Vint1, b=np.repeat(0, b_mat.shape[1]).reshape(b_mat.shape[1], 1),
sigma_e2=0.1, Lambda=np.repeat(10, basis.shape[1]))
if obj > self.best_obj:
self.res = res
self.best_obj = obj
self.best_phi = phi
logger.info("Deconvolution finished")

def predict(self):
"""Prediction function.
Expand All @@ -329,9 +201,7 @@ def predict(self):
Predictions of cell-type proportions.

"""
optim_res = self.algorithm_matrix["Res"]
prop_pred = optim_res["V"] / optim_res["V"].sum(axis=1, keepdims=True)
return prop_pred
return self.res

def fit_and_predict(self, x, spatial, **kwargs):
self.fit(x, spatial, **kwargs)
Expand Down
10 changes: 6 additions & 4 deletions dance/modules/spatial/cell_type_deconvo/spotlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
spots with single-cell transcriptomes." Nucleic Acids Research (2021)

"""
from functools import partial

import torch
from torch import nn, optim
from torchnmf.nmf import NMF
Expand All @@ -18,7 +20,7 @@
from dance.utils import get_device
from dance.utils.wrappers import CastOutputType

get_ct_profile_tensor = CastOutputType(torch.FloatTensor)(get_ct_profile)
get_ct_profile_tensor = CastOutputType(torch.FloatTensor)(partial(get_ct_profile, method="median"))


class NNLS(nn.Module):
Expand Down Expand Up @@ -131,7 +133,7 @@ def _init_model(self, dim_out, ref_count, ref_annot):
hid_dim = len(self.ct_select)
self.nmf_model = NMF(Vshape=ref_count.T.shape, rank=self.rank).to(self.device)
if self.rank == len(self.ct_select): # initialize basis as cell profile
self.nmf_model.H = nn.Parameter(get_ct_profile_tensor(ref_count, ref_annot, self.ct_select))
self.nmf_model.H = nn.Parameter(get_ct_profile_tensor(ref_count, ref_annot, ct_select=self.ct_select))

self.nnls_reg1 = NNLS(in_dim=self.rank, out_dim=dim_out, bias=self.bias, device=self.device)
self.nnls_reg2 = NNLS(in_dim=hid_dim, out_dim=dim_out, bias=self.bias, device=self.device)
Expand All @@ -145,7 +147,7 @@ def forward(self, ref_annot):

# Get cell-topic and mix-topic profiles
# Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells)
H_profile = get_ct_profile_tensor(H.cpu().numpy().T, ref_annot, self.ct_select)
H_profile = get_ct_profile_tensor(H.cpu().numpy().T, ref_annot, ct_select=self.ct_select)
H_profile = H_profile.to(self.device)

# Get mix-topic profiles B: NNLS of basis W onto mix expression Y -- y ~ W*b
Expand Down Expand Up @@ -183,7 +185,7 @@ def fit(self, x, ref_count, ref_annot, lr=1e-3, max_iter=1000):

# Get cell-topic and mix-topic profiles
# Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells)
self.H_profile = get_ct_profile_tensor(self.H.cpu().numpy().T, ref_annot, self.ct_select)
self.H_profile = get_ct_profile_tensor(self.H.cpu().numpy().T, ref_annot, ct_select=self.ct_select)
self.H_profile = self.H_profile.to(self.device)

# Get mix-topic profiles B: NNLS of basis W onto mix expression X ~ W*b
Expand Down
3 changes: 2 additions & 1 deletion dance/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dance.transforms import graph
from dance.transforms.cell_feature import CellPCA, WeightedFeaturePCA
from dance.transforms.filter import FilterGenesCommon, FilterGenesMatch, FilterGenesPercentile
from dance.transforms.filter import FilterGenesCommon, FilterGenesMarker, FilterGenesMatch, FilterGenesPercentile
from dance.transforms.interface import AnnDataTransform
from dance.transforms.misc import Compose, RemoveSplit, SaveRaw, SetConfig
from dance.transforms.normalize import ScaleFeature
Expand All @@ -15,6 +15,7 @@
"CellTopicProfile",
"Compose",
"FilterGenesCommon",
"FilterGenesMarker",
"FilterGenesMatch",
"FilterGenesPercentile",
"GeneStats",
Expand Down
Loading