Skip to content

Commit

Permalink
unify cell-type annotation dataloader and example scripts format (#137)
Browse files Browse the repository at this point in the history
* remove unused params

* fix cell_label_to_df, add option to encode single label as string

* refactor svm data loading

* update default options and remove unused args

* implement percentile gene filter transform

* refactor actinn data loading

* fix cell-type name mapping, add info log

* add cv option to FilterGenesPercentile

* fix cell_label_to_df: skip write if no match

* turn ACTINN normalization off by default

* remove unused functions, imports, and params

* fix get_map_dict: obtain mapped values

* improve training log info in actinn

* improve url dict format

* add mapped test cell-type info to debug log

* add option to change global log level

* update actinn example settings and benchmark scores

Disable redundant normalization

* update celltypist and singlecellnet dataloader

* remove unused functions

* update scdeepsort dataloader

* format fixes and remove unused functions

* remove data_type, unified

* sort parsed args

* unify cell-type annotation example script format
  • Loading branch information
RemyLau authored Jan 20, 2023
1 parent e1d7b50 commit 7349425
Show file tree
Hide file tree
Showing 13 changed files with 404 additions and 1,152 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ Note: the data split modality of DeepImpute is different from ScGNN and GraphSCI

| Model | Evaluation Metric | Mouse Brain 2695 (current/reported) | Mouse Spleen 1759 (current/reported) | Mouse Kidney 203 (current/reported) |
| ------------- | ----------------- | ----------------------------------- | ------------------------------------ | ----------------------------------- |
| scDeepsort | ACC | 0.363/0.363 | 0.965 /0.965 | 0.901/0.911 |
| scDeepsort | ACC | 0.542/0.363 | 0.969/0.965 | 0.847/0.911 |
| Celltypist\* | ACC | 0.680/0.666 | 0.966/0.848 | 0.879/0.832 |
| singleCellNet | ACC | 0.693/0.803 | 0.975/0.975 | 0.795/0.842 |
| ACTINN | ACC | 0.860/0.778 | 0.516/0.236 | 0.829/0.798 |
| ACTINN | ACC | 0.727/0.778 | 0.657/0.236 | 0.762/0.798 |
| SVM | ACC | 0.683/0.683 | 0.056/0.049 | 0.704/0.695 |

Note: * Benchmark datasets were renormalized before running the original implementation of Celltypist to match its input format requirements.
Expand Down
252 changes: 127 additions & 125 deletions dance/datasets/singlemodality.py

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions dance/modules/single_modality/cell_type_annotation/actinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ def compute_loss(self, z, y):
for i in range(0, len(self.model), 2): # TODO: replace with weight_decay
loss += self.lambd * torch.sum(self.model[i].weight**2) / 2

if self.print_cost:
print(loss)

return loss

def random_batches(self, x, y, batch_size=32, seed=None):
Expand Down Expand Up @@ -165,14 +162,20 @@ def fit(self, x_train, y_train, seed=None):
epoch_seed = seed if seed is None else seed + epoch
batches = self.random_batches(x_train, y_train, self.batch_size, epoch_seed)

tot_cost = tot_size = 0
for batch_x, batch_y in batches:
batch_cost = self.compute_loss(self.forward(batch_x), batch_y)
tot_cost += batch_cost.item()
tot_size += 1

optimizer.zero_grad()
batch_cost.backward()
optimizer.step()
lr_scheduler.step()

if (epoch % 10 == 0) and self.print_cost:
print(f"Epoch: {epoch:>4d} Loss: {tot_cost / tot_size:6.4f}")

print("Parameters have been trained!")

@torch.no_grad()
Expand Down
2 changes: 2 additions & 0 deletions dance/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dance.transforms import graph
from dance.transforms.cell_feature import CellPCA, WeightedFeaturePCA
from dance.transforms.filter import FilterGenesPercentile
from dance.transforms.interface import AnnDataTransform
from dance.transforms.spatial_feature import MorphologyFeature, SMEFeature

__all__ = [
"AnnDataTransform",
"CellPCA",
"FilterGenesPercentile",
"MorphologyFeature",
"SMEFeature",
"WeightedFeaturePCA",
Expand Down
65 changes: 65 additions & 0 deletions dance/transforms/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

from dance.exceptions import DevError
from dance.transforms.base import BaseTransform
from dance.typing import Literal, Optional


class FilterGenesPercentile(BaseTransform):
    """Filter genes based on percentiles of the summarized gene expressions.

    Each gene (column of the feature matrix) is summarized into a single value, either the sum of its expression
    values or its coefficient of variation. Genes whose summary falls outside the ``[min_val, max_val]`` percentile
    range of all gene summaries are dropped from the data object.

    """

    # NOTE(review): presumably consumed by ``BaseTransform`` for its repr -- confirm against the base class.
    _DISPLAY_ATTRS = ("min_val", "max_val", "mode")
    # Summarization modes accepted by ``__init__``.
    _MODES = ["sum", "cv"]

    def __init__(self, min_val: Optional[float] = 1, max_val: Optional[float] = 99, mode: Literal["sum", "cv"] = "sum",
                 *, channel: Optional[str] = None, channel_type: Optional[str] = None, **kwargs):
        """Initialize FilterGenesPercentile.

        Parameters
        ----------
        min_val
            Minimum percentile of the summarized expression value below which the genes will be discarded. If
            ``None``, no lower bound is applied.
        max_val
            Maximum percentile of the summarized expression value above which the genes will be discarded. If
            ``None``, no upper bound is applied.
        mode
            Summarization mode. Available options are `[sum|cv]`. `sum` calculates the sum of expression values, `cv`
            uses the coefficient of variation (std / mean).
        channel
            Which channel, more specifically, `layers`, to use. Use the default `.X` if not set. If `channel` is
            specified, then need to specify `channel_type` to be `layers` as well.
        channel_type
            Type of channels specified. Only allow `None` (the default setting) or `layers` (when `channel` is
            specified).

        Raises
        ------
        ValueError
            If `channel` is given without `channel_type="layers"`, or if `mode` is not one of the supported modes.

        """
        super().__init__(**kwargs)

        if (channel is not None) and (channel_type != "layers"):
            raise ValueError(f"Only X layers is available for filtering genes, specified {channel_type=!r}")

        if mode not in self._MODES:
            raise ValueError(f"Unknown summarization mode {mode!r}, available options are {self._MODES}")

        self.min_val = min_val
        self.max_val = max_val
        self.mode = mode
        self.channel = channel
        self.channel_type = channel_type

    def __call__(self, data):
        """Remove genes whose summarized expression lies outside the configured percentile range.

        The data object is modified in place: its underlying ``_data`` is replaced by a copy restricted to the
        retained genes.

        Parameters
        ----------
        data
            Data object exposing ``get_feature`` and ``data`` (presumably a dance data object wrapping an AnnData
            -- confirm against caller).

        Raises
        ------
        DevError
            If ``self.mode`` is not a handled mode (should be unreachable given the check in ``__init__``).

        """
        x = data.get_feature(return_type="default", channel=self.channel, channel_type=self.channel_type)

        if self.mode == "sum":
            # ``np.array(...).ravel()`` flattens the per-gene result regardless of whether ``sum`` returned a
            # 1-d array or a 2-d (1, n_genes) matrix.
            gene_summary = np.array(x.sum(0)).ravel()
        elif self.mode == "cv":
            # Genes with zero mean yield inf/nan here; those are mapped to 0 right below, so the accompanying
            # floating point warnings carry no information and are silenced.
            with np.errstate(divide="ignore", invalid="ignore"):
                cv = np.array(x.std(0) / x.mean(0))
            gene_summary = np.nan_to_num(cv, posinf=0, neginf=0).ravel()
        else:
            raise DevError(f"{self.mode!r} not expected, please inform dev to fix this error.")

        # ``None`` means "no bound" on that side; using +/-inf keeps the mask logic and the log format uniform.
        percentile_lo = -np.inf if self.min_val is None else np.percentile(gene_summary, self.min_val)
        percentile_hi = np.inf if self.max_val is None else np.percentile(gene_summary, self.max_val)
        mask = np.logical_and(gene_summary >= percentile_lo, gene_summary <= percentile_hi)
        self.logger.info(f"Filtering genes based on {self.mode} expression percentiles in layer {self.channel!r}")
        self.logger.info(f"{mask.size - mask.sum()} genes removed ({percentile_lo=:.2e}, {percentile_hi=:.2e})")

        # Column-subset the underlying data and copy so the result does not remain a view of the original.
        data._data = data.data[:, mask].copy()
Loading

0 comments on commit 7349425

Please sign in to comment.