Skip to content

Commit

Permalink
refactor: create preprocessing pipelines for Spotlight and SpatialDec…
Browse files Browse the repository at this point in the history
…on (#211)

* fix: add obs to split branch

* feat: implement CellTypeProfile transform

* feat: implement CastOutputType wrapper

* add Callable and Logger types

* update spotlight to use CellTypeProfile

* update spatialdecon to use CellTypeProfile

* transpose ct_profile

* return torch tensors

* create preprocessing pipelines for Spotlight and SpatialDecon
  • Loading branch information
RemyLau authored Feb 22, 2023
1 parent 5a86543 commit 3ba0bac
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 154 deletions.
6 changes: 3 additions & 3 deletions dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,11 @@ def get_feature(self, *, split_name: Optional[str] = None, return_type: FeatType

# Extract specific split
if split_name is not None:
if channel_type in ["X", "raw_X", "obsm", "obsp", "layers"]:
if channel_type in ["X", "raw_X", "obs", "obsm", "obsp", "layers"]:
idx = self.get_split_idx(split_name, error_on_miss=True)
feature = feature[idx][:, idx] if channel_type == "obsp" else feature[idx]
elif isinstance(channel_type, str) and channel_type.startswith("var"):
logger.warning(f"Indexing option for {channel_type} not implemented yet.")
else:
logger.warning(f"Indexing option for {channel_type!r} not implemented yet.")

# Convert to other data types if needed
if return_type == "torch":
Expand Down
4 changes: 2 additions & 2 deletions dance/datasets/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ def _raw_to_dance(self, raw_data: Tuple[pd.DataFrame, ...]):
obs=pd.DataFrame(index=count_matrix.index.tolist()),
var=pd.DataFrame(index=count_matrix.columns.tolist()),
)
adata_inf.obsm["cell_type_portion"] = cell_type_portion
adata_inf.obsm["spatial"] = spatial
adata_inf.obsm["cell_type_portion"] = cell_type_portion.astype(np.float32)
adata_inf.obsm["spatial"] = spatial.astype(np.float32)
adata_ref = AnnData(
ref_count.values,
dtype=np.float32,
Expand Down
77 changes: 19 additions & 58 deletions dance/modules/spatial/cell_type_deconvo/spatialdecon.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
transcriptomic data." Nature Communications (2022)
"""
import numpy as np
import torch
import torch.nn as nn
from torch import optim

from dance.transforms import CellTopicProfile, Compose, SetConfig
from dance.typing import LogLevel
from dance.utils import get_device
from dance.utils.matrix import normalize

Expand Down Expand Up @@ -44,39 +45,6 @@ def forward(self, pred, true):
return loss


def cell_topic_profile(X, groups, ct_select, axis=0, method='median'):
"""Compute cell topic profile matrix.
Parameters
----------
X : torch 2-d tensor
Gene expression matrix (gene x cell).
groups : int
Cell-type labels of each sample cell in X.
ct_select:
Cell types to profile.
method : string optional
Method for reduction of cell-types for cell profile, default median.
Returns
-------
X_profile : torch 2-d tensor
Cell profile matrix from scRNA-seq reference (gene x cell-type).
"""
if method == "median":
X_profile = np.array([
np.median(X[[i for i in range(len(groups)) if groups[i] == ct_select[j]], :], axis=0)
for j in range(len(ct_select))
]).T
else:
X_profile = np.array([
np.mean(X[[i for i in range(len(groups)) if groups[i] == ct_select[j]], :], axis=0)
for j in range(len(ct_select))
]).T
return X_profile


class SpatialDecon:
"""SpatialDecon.
Expand All @@ -101,31 +69,24 @@ class SpatialDecon:
"""

def __init__(self, sc_count, sc_annot, ct_varname, ct_select, sc_profile=None, bias=False, init_bias=None,
device="auto"):
def __init__(self, ct_select, sc_profile=None, bias=False, init_bias=None, device="auto"):
super().__init__()

self.device = get_device(device)

# TODO: extract to preprocessing transformation and remove ct_select from input
# Subset sc samples on selected cell types (mutual between sc and mix cell data)
ct_select_ix = sc_annot[sc_annot[ct_varname].isin(ct_select)].index
self.sc_annot = sc_annot.loc[ct_select_ix]
self.sc_count = sc_count.loc[ct_select_ix]
cellTypes = self.sc_annot[ct_varname].values.tolist()

# Construct a cell profile matrix if not profided
if sc_profile is None:
self.ref_sc_profile = cell_topic_profile(self.sc_count.values, cellTypes, ct_select, method='median')
else:
self.ref_sc_profile = sc_profile

self.ct_select = ct_select
self.bias = bias
self.init_bias = init_bias
self.model = None

@staticmethod
def preprocessing_pipeline(ct_select, ct_profile_split: str = "ref", log_level: LogLevel = "INFO"):
return Compose(
CellTopicProfile(ct_select=ct_select, split_name="ref"),
SetConfig({"label_channel": "cell_type_portion"}),
log_level=log_level,
)

def _init_model(self, num_cells: int, bias: bool = True):
num_cell_types = self.ref_sc_profile.shape[1]
num_cell_types = len(self.ct_select)
model = nn.Linear(in_features=num_cell_types, out_features=num_cells, bias=self.bias)
if self.init_bias is not None:
model.bias = nn.Parameter(torch.Tensor(self.init_bias.values.T.copy()))
Expand All @@ -148,7 +109,7 @@ def predict(self):
proportion_preds = normalize(weights, mode="normalize", axis=1)
return proportion_preds

def fit(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
def fit(self, x, ct_profile, lr=1e-4, max_iter=500, print_res=False, print_period=100):
"""fit function for model training.
Parameters
Expand All @@ -165,17 +126,17 @@ def fit(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
Indicates number of iterations until training results print.
"""
ref_ct_profile = ct_profile.to(self.device)
mix_count = x.T.to(self.device)
self._init_model(x.shape[0])
ref_sc_profile = torch.FloatTensor(self.ref_sc_profile).to(self.device)
mix_count = torch.FloatTensor(x.T).to(self.device)

criterion = MSLELoss()
optimizer = optim.Adam(self.model.parameters(), lr=lr)

self.model.train()
for iteration in range(max_iter):
iteration += 1
mix_pred = self.model(ref_sc_profile)
mix_pred = self.model(ref_ct_profile)

loss = criterion(mix_pred, mix_count)
self.loss = loss
Expand All @@ -189,9 +150,9 @@ def fit(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
if iteration % print_period == 0:
print(f"Epoch: {iteration:02}/{max_iter} Loss: {loss.item():.5e}")

def fit_and_predict(self, x, lr=1e-4, max_iter=500, print_res=False, print_period=100):
def fit_and_predict(self, *args, **kwargs):
"""Fit parameters and return cell-type portion predictions."""
self.fit(x, lr=lr, max_iter=max_iter, print_res=print_res, print_period=print_period)
self.fit(*args, **kwargs)
pred = self.predict()
return pred

Expand Down
79 changes: 20 additions & 59 deletions dance/modules/spatial/cell_type_deconvo/spotlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,17 @@
spots with single-cell transcriptomes." Nucleic Acids Research (2021)
"""
import numpy as np
import torch
from torch import nn, optim
from torchnmf.nmf import NMF

from dance.transforms import SetConfig
from dance.transforms.pseudo_gen import get_ct_profile
from dance.typing import LogLevel
from dance.utils import get_device
from dance.utils.wrappers import CastOutputType


def cell_topic_profile(x, groups, ct_select, axis=0, method="median"):
"""Cell topic profile.
Parameters
----------
x : torch.Tensor
Gene expression matrix (gene x cells).
groups : int
Cell-type labels of each sample cell in x.
ct_select:
Cell-types to profile.
method : str
Method for reduction of cell-types for cell profile, default median.
Returns
-------
x_profile : torch 2-d tensor
cell profile matrix from scRNA-seq reference.
"""
if method == "median":
x_profile = np.array([
np.median(x[[i for i in range(len(groups)) if groups[i] == ct_select[j]], :], axis=0)
for j in range(len(ct_select))
]).T
else:
x_profile = np.array([
np.mean(x[[i for i in range(len(groups)) if groups[i] == ct_select[j]], :], axis=0)
for j in range(len(ct_select))
]).T
return torch.Tensor(x_profile)
get_ct_profile_tensor = CastOutputType(torch.FloatTensor)(get_ct_profile)


class NNLS(nn.Module):
Expand Down Expand Up @@ -144,46 +116,36 @@ class SPOTlight:
"""

def __init__(self, ref_count, ref_annot, ct_varname, ct_select, rank=2, sc_profile=None, bias=False, init_bias=None,
init="random", device="auto"):
def __init__(self, ct_select, rank=2, sc_profile=None, bias=False, init_bias=None, init="random", device="auto"):
super().__init__()
self.device = get_device(device)
self.bias = bias
self.ct_select = ct_select
self.rank = rank

# TODO: extract to preprocessing transformation and remove ct_select from input
# Subset sc samples on selected cell types (mutual between sc and mix cell data)
ct_select_ix = ref_annot[ref_annot[ct_varname].isin(ct_select)].index
self.ref_annot = ref_annot.loc[ct_select_ix]
self.ref_count = ref_count.loc[ct_select_ix]
self.ref_annot = self.ref_annot[ct_varname].values.tolist()

# Construct a cell profile matrix if not profided
if sc_profile is None:
self.ref_sc_profile = cell_topic_profile(self.ref_count.values, self.ref_annot, ct_select, method="median")
else:
self.ref_sc_profile = sc_profile
@staticmethod
def preprocessing_pipeline(log_level: LogLevel = "INFO"):
return SetConfig({"label_channel": "cell_type_portion"}, log_level=log_level)

def _init_model(self, dim_out):
def _init_model(self, dim_out, ref_count, ref_annot):
hid_dim = len(self.ct_select)
self.nmf_model = NMF(Vshape=self.ref_count.T.shape, rank=self.rank).to(self.device)
self.nmf_model = NMF(Vshape=ref_count.T.shape, rank=self.rank).to(self.device)
if self.rank == len(self.ct_select): # initialize basis as cell profile
self.nmf_model.H = nn.Parameter(torch.Tensor(self.ref_sc_profile))
self.nmf_model.H = nn.Parameter(get_ct_profile_tensor(ref_count, ref_annot, self.ct_select))

self.nnls_reg1 = NNLS(in_dim=self.rank, out_dim=dim_out, bias=self.bias, device=self.device)
self.nnls_reg2 = NNLS(in_dim=hid_dim, out_dim=dim_out, bias=self.bias, device=self.device)

self.model = nn.Sequential(self.nmf_model, self.nnls_reg1, self.nnls_reg2)

def forward(self):
def forward(self, ref_annot):
# Get NMF decompositions
W = self.nmf_model.H.clone()
H = self.nmf_model.W.clone().T

# Get cell-topic and mix-topic profiles
# Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells)
H_profile = cell_topic_profile(H.cpu().numpy().T, self.ref_annot, self.ct_select, method="median")
H_profile = get_ct_profile_tensor(H.cpu().numpy().T, ref_annot, self.ct_select)
H_profile = H_profile.to(self.device)

# Get mix-topic profiles B: NNLS of basis W onto mix expression Y -- y ~ W*b
Expand All @@ -195,7 +157,7 @@ def forward(self):

return (W, H_profile, B, P)

def fit(self, x, lr=1e-3, max_iter=1000):
def fit(self, x, ref_count, ref_annot, lr=1e-3, max_iter=1000):
"""Fit function for model training.
Parameters
Expand All @@ -208,9 +170,9 @@ def fit(self, x, lr=1e-3, max_iter=1000):
Maximum iterations allowed for matrix factorization solver.
"""
self._init_model(x.shape[0])
x = torch.FloatTensor(x.T).to(self.device)
y = torch.FloatTensor(self.ref_count.values.T).to(self.device)
self._init_model(x.shape[0], ref_count, ref_annot)
x = x.T.to(self.device)
y = torch.FloatTensor(ref_count.T).to(self.device)

# Run NMF on scRNA X
self.nmf_model.fit(y, max_iter=max_iter)
Expand All @@ -221,7 +183,7 @@ def fit(self, x, lr=1e-3, max_iter=1000):

# Get cell-topic and mix-topic profiles
# Get cell-topic profiles H_profile: cell-type group medians of coef H (topic x cells)
self.H_profile = cell_topic_profile(self.H.cpu().numpy().T, self.ref_annot, self.ct_select, method="median")
self.H_profile = get_ct_profile_tensor(self.H.cpu().numpy().T, ref_annot, self.ct_select)
self.H_profile = self.H_profile.to(self.device)

# Get mix-topic profiles B: NNLS of basis W onto mix expression X ~ W*b
Expand All @@ -244,8 +206,7 @@ def predict(self):
Predicted cell-type proportions (cell x cell-type).
"""
W, H_profile, B, P = self.forward()
pred = P / torch.sum(P, axis=0, keepdims=True).clamp(min=1e-6)
pred = self.P / torch.sum(self.P, axis=0, keepdims=True).clamp(min=1e-6)
return pred.T

def fit_and_predict(self, *args, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions dance/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from dance.transforms.filter import FilterGenesCommon, FilterGenesMatch, FilterGenesPercentile
from dance.transforms.interface import AnnDataTransform
from dance.transforms.misc import Compose, SaveRaw, SetConfig
from dance.transforms.pseudo_gen import CellTopicProfile
from dance.transforms.scn_feature import SCNFeature
from dance.transforms.spatial_feature import MorphologyFeature, SMEFeature
from dance.transforms.stats import GeneStats

__all__ = [
"AnnDataTransform",
"CellPCA",
"CellTopicProfile",
"Compose",
"FilterGenesCommon",
"FilterGenesMatch",
Expand Down
Loading

0 comments on commit 3ba0bac

Please sign in to comment.