Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unify cell-type annotation dataloader and example scripts format #137

Merged
merged 24 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c34ed79
remove unused params
RemyLau Jan 18, 2023
2bc8bcd
fix cell_label_to_df, add option to encode single label as string
RemyLau Jan 18, 2023
f06faad
refactor svm data loading
RemyLau Jan 18, 2023
43d1211
update default options and remove unused args
RemyLau Jan 18, 2023
17939c0
implement percentile gene filter transform
RemyLau Jan 18, 2023
42537e4
refactor actinn data loading
RemyLau Jan 18, 2023
3c691a6
fix cell-type name mapping, add info log
RemyLau Jan 19, 2023
3ce2625
add cv option to FilterGenesPercentile
RemyLau Jan 19, 2023
d51eecd
fix cell_label_to_df: skip write if no match
RemyLau Jan 19, 2023
2844dd1
turn ACTINN normalization off by default
RemyLau Jan 19, 2023
2df3a9a
remove unused functions, imports, and params
RemyLau Jan 19, 2023
efa6fdb
fix get_map_dict: obtain mapped values
RemyLau Jan 19, 2023
22918c2
improve training log info in actinn
RemyLau Jan 19, 2023
f968e3d
improve url dict format
RemyLau Jan 19, 2023
a751fcd
add mapped test cell-type info to debug log
RemyLau Jan 19, 2023
689cee9
add option to change global log level
RemyLau Jan 19, 2023
8d01c59
update actinn example settings and benchmark scores
RemyLau Jan 19, 2023
1494de6
update celltypist and singlecellnet dataloader
RemyLau Jan 19, 2023
b4bf4e3
remove unused functions
RemyLau Jan 19, 2023
f041629
update scdeepsort dataloader
RemyLau Jan 19, 2023
95ed505
format fixes and remove unused functions
RemyLau Jan 19, 2023
508f927
remove data_type, unified
RemyLau Jan 19, 2023
dce607e
sort parsed args
RemyLau Jan 19, 2023
5fa2b6b
unify cell-type annotation example script format
RemyLau Jan 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ Note: the data split modality of DeepImpute is different from ScGNN and GraphSCI

| Model | Evaluation Metric | Mouse Brain 2695 (current/reported) | Mouse Spleen 1759 (current/reported) | Mouse Kidney 203 (current/reported) |
| ------------- | ----------------- | ----------------------------------- | ------------------------------------ | ----------------------------------- |
| scDeepsort | ACC | 0.363/0.363 | 0.965 /0.965 | 0.901/0.911 |
| scDeepsort | ACC | 0.542/0.363 | 0.969/0.965 | 0.847/0.911 |
| Celltypist\* | ACC | 0.680/0.666 | 0.966/0.848 | 0.879/0.832 |
| singleCellNet | ACC | 0.693/0.803 | 0.975/0.975 | 0.795/0.842 |
| ACTINN | ACC | 0.860/0.778 | 0.516/0.236 | 0.829/0.798 |
| ACTINN | ACC | 0.727/0.778 | 0.657/0.236 | 0.762/0.798 |
| SVM | ACC | 0.683/0.683 | 0.056/0.049 | 0.704/0.695 |

Note: * Benchmark datasets were renormalied before running the original implementation of Celltypist to match its form requirements.
Expand Down
252 changes: 127 additions & 125 deletions dance/datasets/singlemodality.py

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions dance/modules/single_modality/cell_type_annotation/actinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ def compute_loss(self, z, y):
for i in range(0, len(self.model), 2): # TODO: replace with weight_decay
loss += self.lambd * torch.sum(self.model[i].weight**2) / 2

if self.print_cost:
print(loss)

return loss

def random_batches(self, x, y, batch_size=32, seed=None):
Expand Down Expand Up @@ -165,14 +162,20 @@ def fit(self, x_train, y_train, seed=None):
epoch_seed = seed if seed is None else seed + epoch
batches = self.random_batches(x_train, y_train, self.batch_size, epoch_seed)

tot_cost = tot_size = 0
for batch_x, batch_y in batches:
batch_cost = self.compute_loss(self.forward(batch_x), batch_y)
tot_cost += batch_cost.item()
tot_size += 1

optimizer.zero_grad()
batch_cost.backward()
optimizer.step()
lr_scheduler.step()

if (epoch % 10 == 0) and self.print_cost:
print(f"Epoch: {epoch:>4d} Loss: {tot_cost / tot_size:6.4f}")

print("Parameters have been trained!")

@torch.no_grad()
Expand Down
2 changes: 2 additions & 0 deletions dance/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dance.transforms import graph
from dance.transforms.cell_feature import CellPCA, WeightedFeaturePCA
from dance.transforms.filter import FilterGenesPercentile
from dance.transforms.interface import AnnDataTransform
from dance.transforms.spatial_feature import MorphologyFeature, SMEFeature

__all__ = [
"AnnDataTransform",
"CellPCA",
"FilterGenesPercentile",
"MorphologyFeature",
"SMEFeature",
"WeightedFeaturePCA",
Expand Down
65 changes: 65 additions & 0 deletions dance/transforms/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

from dance.exceptions import DevError
from dance.transforms.base import BaseTransform
from dance.typing import Literal, Optional


class FilterGenesPercentile(BaseTransform):
"""Filter genes based on percentiles of the summarized gene expressions."""

_DISPLAY_ATTRS = ("min_val", "max_val", "mode")
_MODES = ["sum", "cv"]

def __init__(self, min_val: Optional[float] = 1, max_val: Optional[float] = 99, mode: Literal["sum", "cv"] = "sum",
*, channel: Optional[str] = None, channel_type: Optional[str] = None, **kwargs):
"""Initialize FilterGenesPercentile.

Parameters
----------
min_val
Minimum percentile of the summarized expression value below which the genes will be discarded.
max_val
Maximum percentile of the summarized expression value above which the genes will be discarded.
mode
Summarization mode. Available options are `[sum|cv]`. `sum` calculates the sum of expression values, `cv`
uses the coefficient of variation (std / mean).
channel
Which channel, more specificailly, `layers`, to use. Use the default `.X` if not set. If `channel` is
specified, then need to specify `channel_type` to be `layers` as well.
channel_type
Type of channels specified. Only allow `None` (the default setting) or `layers` (when `channel` is
specified).

"""
super().__init__(**kwargs)

if (channel is not None) and (channel_type != "layers"):
raise ValueError(f"Only X layers is available for filtering genes, specified {channel_type=!r}")

if mode not in self._MODES:
raise ValueError(f"Unknown summarization mode {mode!r}, available options are {self._MODES}")

self.min_val = min_val
self.max_val = max_val
self.mode = mode
self.channel = channel
self.channel_type = channel_type

def __call__(self, data):
x = data.get_feature(return_type="default", channel=self.channel, channel_type=self.channel_type)

if self.mode == "sum":
gene_summary = np.array(x.sum(0)).ravel()
elif self.mode == "cv":
gene_summary = np.nan_to_num(np.array(x.std(0) / x.mean(0)), posinf=0, neginf=0).ravel()
else:
raise DevError(f"{self.mode!r} not expected, please inform dev to fix this error.")

percentile_lo = np.percentile(gene_summary, self.min_val)
percentile_hi = np.percentile(gene_summary, self.max_val)
mask = np.logical_and(gene_summary >= percentile_lo, gene_summary <= percentile_hi)
self.logger.info(f"Filtering genes based on {self.mode} expression percentiles in layer {self.channel!r}")
self.logger.info(f"{mask.size - mask.sum()} genes removed ({percentile_lo=:.2e}, {percentile_hi=:.2e})")

data._data = data.data[:, mask].copy()
Loading