Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: create preprocessing pipelines for Spotlight and SpatialDecon #211

Merged
merged 9 commits into from
Feb 22, 2023
6 changes: 3 additions & 3 deletions dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,11 @@ def get_feature(self, *, split_name: Optional[str] = None, return_type: FeatType

# Extract specific split
if split_name is not None:
if channel_type in ["X", "raw_X", "obsm", "obsp", "layers"]:
if channel_type in ["X", "raw_X", "obs", "obsm", "obsp", "layers"]:
idx = self.get_split_idx(split_name, error_on_miss=True)
feature = feature[idx][:, idx] if channel_type == "obsp" else feature[idx]
elif isinstance(channel_type, str) and channel_type.startswith("var"):
logger.warning(f"Indexing option for {channel_type} not implemented yet.")
else:
logger.warning(f"Indexing option for {channel_type!r} not implemented yet.")

# Convert to other data types if needed
if return_type == "torch":
Expand Down
4 changes: 2 additions & 2 deletions dance/datasets/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def _raw_to_dance(self, raw_data: Tuple[pd.DataFrame, ...]):
obs=pd.DataFrame(index=count_matrix.index.tolist()),
var=pd.DataFrame(index=count_matrix.columns.tolist()),
)
adata_inf.obsm["cell_type_portion"] = cell_type_portion
adata_inf.obsm["spatial"] = spatial
adata_inf.obsm["cell_type_portion"] = cell_type_portion.astype(np.float32)
adata_inf.obsm["spatial"] = spatial.astype(np.float32)
adata_ref = AnnData(
ref_count.values,
dtype=np.float32,
Expand Down
77 changes: 19 additions & 58 deletions dance/modules/spatial/cell_type_deconvo/spatialdecon.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
transcriptomic data." Nature Communications (2022)
"""
import numpy as np
import torch
import torch.nn as nn
from torch import optim

from dance.transforms import CellTopicProfile, Compose, SetConfig
from dance.typing import LogLevel
from dance.utils import get_device
from dance.utils.matrix import normalize

Expand Down Expand Up @@ -44,39 +45,6 @@ def forward(self, pred, true):
return loss


def cell_topic_profile(X, groups, ct_select, axis=0, method='median'):
    """Compute cell topic profile matrix.

    Every selected cell type is summarized by reducing, feature-wise, over the
    rows of ``X`` whose label in ``groups`` equals that cell type.

    Parameters
    ----------
    X : numpy.ndarray, 2-d
        Expression matrix with one row per cell (cell x gene). Rows are picked
        by label and reduced along axis 0.
    groups : sequence
        Cell-type label of each row of ``X`` (labels must be hashable).
    ct_select : sequence
        Cell types to profile; fixes the column order of the output.
    axis : int, optional
        Unused. Kept only so the signature stays backward compatible.
    method : str, optional
        ``"median"`` (default) uses the per-type median; any other value falls
        back to the per-type mean.

    Returns
    -------
    X_profile : numpy.ndarray
        Cell profile matrix from the scRNA-seq reference (gene x cell-type).
        A cell type with no matching rows yields a column of NaN.

    """
    reduce_fn = np.median if method == "median" else np.mean

    # Bucket row indices by label in a single pass so that each cell type is a
    # dictionary lookup rather than a rescan of every label.
    rows_by_label = {}
    for row, label in enumerate(groups):
        rows_by_label.setdefault(label, []).append(row)

    X_profile = np.array([
        reduce_fn(X[rows_by_label.get(ct, []), :], axis=0)
        for ct in ct_select
    ]).T
    return X_profile


class SpatialDecon:
"""SpatialDecon.
Expand All @@ -101,31 +69,24 @@ class SpatialDecon:
"""

def __init__(self, sc_count, sc_annot, ct_varname, ct_select, sc_profile=None, bias=False, init_bias=None,
device="auto"):
def __init__(self, ct_select, sc_profile=None, bias=False, init_bias=None, device="auto"):
super().__init__()

self.device = get_device(device)

# TODO: extract to preprocessing transformation and remove ct_select from input
# Subset sc samples on selected cell types (mutual between sc and mix cell data)
ct_select_ix = sc_annot[sc_annot[ct_varname].isin(ct_select)].index
self.sc_annot = sc_annot.loc[ct_select_ix]
self.sc_count = sc_count.loc[ct_select_ix]
cellTypes = self.sc_annot[ct_varname].values.tolist()

# Construct a cell profile matrix if not provided
if sc_profile is None:
self.ref_sc_profile = cell_topic_profile(self.sc_count.values, cellTypes, ct_select, method='median')
else:
self.ref_sc_profile = sc_profile

self.ct_select = ct_select
self.bias = bias
self.init_bias = init_bias
self.model = None

@staticmethod
def preprocessing_pipeline(ct_select, ct_profile_split: str = "ref", log_level: LogLevel = "INFO"):
return Compose(
CellTopicProfile(ct_select=ct_select, split_name="ref"),
SetConfig({"label_channel": "cell_type_portion"}),
log_level=log_level,
)

def _init_model(self, num_cells: int, bias: bool = True):
num_cell_types = self.ref_sc_profile.shape[1]
num_cell_types = len(self.ct_select)
model = nn.Linear(in_features=num_cell_types, out_features=num_cells, bias=self.bias)
if self.init_bias is not None:
model.bias = nn.Parameter(torch.Tensor(self.init_bias.values.T.copy()))
Expand All @@ -148,7 +109,7 @@ def predict(self):
proportion_preds = normalize(weights, mode="normalize", axis=1)
return proportion_preds

def fit(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
def fit(self, x, ct_profile, lr=1e-4, max_iter=500, print_res=False, print_period=100):
"""fit function for model training.
Parameters
Expand All @@ -165,17 +126,17 @@ def fit(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
Indicates number of iterations until training results print.
"""
ref_ct_profile = ct_profile.to(self.device)
mix_count = x.T.to(self.device)
self._init_model(x.shape[0])
ref_sc_profile = torch.FloatTensor(self.ref_sc_profile).to(self.device)
mix_count = torch.FloatTensor(x.T).to(self.device)

criterion = MSLELoss()
optimizer = optim.Adam(self.model.parameters(), lr=lr)

self.model.train()
for iteration in range(max_iter):
iteration += 1
mix_pred = self.model(ref_sc_profile)
mix_pred = self.model(ref_ct_profile)

loss = criterion(mix_pred, mix_count)
self.loss = loss
Expand All @@ -189,9 +150,9 @@ def fit(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
if iteration % print_period == 0:
print(f"Epoch: {iteration:02}/{max_iter} Loss: {loss.item():.5e}")

def fit_and_predict(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
def fit_and_predict(self, *args, **kwargs):
"""Fit parameters and return cell-type portion predictions."""
self.fit(x, lr=lr, max_iter=max_iter, print_res=print_res, print_period=print_period)
self.fit(*args, **kwargs)
pred = self.predict()
return pred

Expand Down
79 changes: 20 additions & 59 deletions dance/modules/spatial/cell_type_deconvo/spotlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,17 @@
spots with single-cell transcriptomes." Nucleic Acids Research (2021)
"""
import numpy as np
import torch
from torch import nn, optim
from torchnmf.nmf import NMF

from dance.transforms import SetConfig
from dance.transforms.pseudo_gen import get_ct_profile
from dance.typing import LogLevel
from dance.utils import get_device
from dance.utils.wrappers import CastOutputType


def cell_topic_profile(x, groups, ct_select, axis=0, method="median"):
    """Cell topic profile.

    Every selected cell type is summarized by reducing, feature-wise, over the
    rows of ``x`` whose label in ``groups`` equals that cell type.

    Parameters
    ----------
    x : numpy.ndarray, 2-d
        Matrix with one row per cell (cell x feature). Rows are picked by label
        and reduced along axis 0.
    groups : sequence
        Cell-type label of each row of ``x`` (labels must be hashable).
    ct_select : sequence
        Cell-types to profile; fixes the column order of the output.
    axis : int
        Unused. Kept only so the signature stays backward compatible.
    method : str
        ``"median"`` (default) uses the per-type median; any other value falls
        back to the per-type mean.

    Returns
    -------
    x_profile : torch.Tensor, 2-d
        Cell profile matrix from scRNA-seq reference (feature x cell-type).
        A cell type with no matching rows yields a column of NaN.

    """
    reduce_fn = np.median if method == "median" else np.mean

    # Bucket row indices by label in a single pass so that each cell type is a
    # dictionary lookup rather than a rescan of every label.
    rows_by_label = {}
    for row, label in enumerate(groups):
        rows_by_label.setdefault(label, []).append(row)

    x_profile = np.array([
        reduce_fn(x[rows_by_label.get(ct, []), :], axis=0)
        for ct in ct_select
    ]).T
    return torch.Tensor(x_profile)
get_ct_profile_tensor = CastOutputType(torch.FloatTensor)(get_ct_profile)


class NNLS(nn.Module):
Expand Down Expand Up @@ -144,46 +116,36 @@ class SPOTlight:
"""

def __init__(self, ref_count, ref_annot, ct_varname, ct_select, rank=2, sc_profile=None, bias=False, init_bias=None,
init="random", device="auto"):
def __init__(self, ct_select, rank=2, sc_profile=None, bias=False, init_bias=None, init="random", device="auto"):
super().__init__()
self.device = get_device(device)
self.bias = bias
self.ct_select = ct_select
self.rank = rank

# TODO: extract to preprocessing transformation and remove ct_select from input
# Subset sc samples on selected cell types (mutual between sc and mix cell data)
ct_select_ix = ref_annot[ref_annot[ct_varname].isin(ct_select)].index
self.ref_annot = ref_annot.loc[ct_select_ix]
self.ref_count = ref_count.loc[ct_select_ix]
self.ref_annot = self.ref_annot[ct_varname].values.tolist()

# Construct a cell profile matrix if not provided
if sc_profile is None:
self.ref_sc_profile = cell_topic_profile(self.ref_count.values, self.ref_annot, ct_select, method="median")
else:
self.ref_sc_profile = sc_profile
@staticmethod
def preprocessing_pipeline(log_level: LogLevel = "INFO"):
return SetConfig({"label_channel": "cell_type_portion"}, log_level=log_level)

def _init_model(self, dim_out):
def _init_model(self, dim_out, ref_count, ref_annot):
hid_dim = len(self.ct_select)
self.nmf_model = NMF(Vshape=self.ref_count.T.shape, rank=self.rank).to(self.device)
self.nmf_model = NMF(Vshape=ref_count.T.shape, rank=self.rank).to(self.device)
if self.rank == len(self.ct_select): # initialize basis as cell profile
self.nmf_model.H = nn.Parameter(torch.Tensor(self.ref_sc_profile))
self.nmf_model.H = nn.Parameter(get_ct_profile_tensor(ref_count, ref_annot, self.ct_select))

self.nnls_reg1 = NNLS(in_dim=self.rank, out_dim=dim_out, bias=self.bias, device=self.device)
self.nnls_reg2 = NNLS(in_dim=hid_dim, out_dim=dim_out, bias=self.bias, device=self.device)

self.model = nn.Sequential(self.nmf_model, self.nnls_reg1, self.nnls_reg2)

def forward(self):
def forward(self, ref_annot):
# Get NMF decompositions
W = self.nmf_model.H.clone()
H = self.nmf_model.W.clone().T

# Get cell-topic and mix-topic profiles
# Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells)
H_profile = cell_topic_profile(H.cpu().numpy().T, self.ref_annot, self.ct_select, method="median")
H_profile = get_ct_profile_tensor(H.cpu().numpy().T, ref_annot, self.ct_select)
H_profile = H_profile.to(self.device)

# Get mix-topic profiles B: NNLS of basis W onto mix expression Y -- y ~ W*b
Expand All @@ -195,7 +157,7 @@ def forward(self):

return (W, H_profile, B, P)

def fit(self, x, lr=1e-3, max_iter=1000):
def fit(self, x, ref_count, ref_annot, lr=1e-3, max_iter=1000):
"""Fit function for model training.
Parameters
Expand All @@ -208,9 +170,9 @@ def fit(self, x, lr=1e-3, max_iter=1000):
Maximum iterations allowed for matrix factorization solver.
"""
self._init_model(x.shape[0])
x = torch.FloatTensor(x.T).to(self.device)
y = torch.FloatTensor(self.ref_count.values.T).to(self.device)
self._init_model(x.shape[0], ref_count, ref_annot)
x = x.T.to(self.device)
y = torch.FloatTensor(ref_count.T).to(self.device)

# Run NMF on scRNA X
self.nmf_model.fit(y, max_iter=max_iter)
Expand All @@ -221,7 +183,7 @@ def fit(self, x, lr=1e-3, max_iter=1000):

# Get cell-topic and mix-topic profiles
# Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells)
self.H_profile = cell_topic_profile(self.H.cpu().numpy().T, self.ref_annot, self.ct_select, method="median")
self.H_profile = get_ct_profile_tensor(self.H.cpu().numpy().T, ref_annot, self.ct_select)
self.H_profile = self.H_profile.to(self.device)

# Get mix-topic profiles B: NNLS of basis W onto mix expression X ~ W*b
Expand All @@ -244,8 +206,7 @@ def predict(self):
Predicted cell-type proportions (cell x cell-type).
"""
W, H_profile, B, P = self.forward()
pred = P / torch.sum(P, axis=0, keepdims=True).clamp(min=1e-6)
pred = self.P / torch.sum(self.P, axis=0, keepdims=True).clamp(min=1e-6)
return pred.T

def fit_and_predict(self, *args, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions dance/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from dance.transforms.filter import FilterGenesCommon, FilterGenesMatch, FilterGenesPercentile
from dance.transforms.interface import AnnDataTransform
from dance.transforms.misc import Compose, SaveRaw, SetConfig
from dance.transforms.pseudo_gen import CellTopicProfile
from dance.transforms.scn_feature import SCNFeature
from dance.transforms.spatial_feature import MorphologyFeature, SMEFeature
from dance.transforms.stats import GeneStats

__all__ = [
"AnnDataTransform",
"CellPCA",
"CellTopicProfile",
"Compose",
"FilterGenesCommon",
"FilterGenesMatch",
Expand Down
Loading