Skip to content

Commit

Permalink
format fixes and improve inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
RemyLau committed Feb 4, 2024
1 parent dc36dc7 commit 821ad7b
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 83 deletions.
164 changes: 103 additions & 61 deletions dance/transforms/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dance.transforms.base import BaseTransform
from dance.transforms.interface import AnnDataTransform
from dance.typing import Dict, GeneSummaryMode, List, Literal, Logger, LogLevel, Optional, Tuple, Union
from dance.utils import default


def get_count(count_or_ratio: Optional[Union[float, int]], total: int) -> Optional[int]:
Expand Down Expand Up @@ -895,7 +896,9 @@ class FilterGenesScanpyOrder(BaseTransform):
Parameters
----------
order
Order of (min_counts,min_cells,max_counts,max_cells),e.g["min_counts","min_cells","max_counts","max_cells"]
Order of (min_counts, min_cells, max_counts, max_cells). For example,
``["min_counts", "min_cells", "max_counts", "max_cells"]`` or ``["max_counts", "min_cells"]``.
If not set, will be set by default to ``["min_counts", "min_cells", "max_counts", "max_cells"]``.
min_counts
Minimum number of counts required for a gene to be kept.
min_cells
Expand All @@ -913,14 +916,24 @@ class FilterGenesScanpyOrder(BaseTransform):
"""

def __init__(self, order: "list[str]" = ["min_counts", "min_cells", "max_counts",
"max_cells"], min_counts: Optional[int] = None,
min_cells: Optional[Union[float, int]] = None, max_counts: Optional[int] = None,
max_cells: Optional[Union[float, int]] = None, split_name: Optional[str] = None,
channel: Optional[str] = None, channel_type: Optional[str] = "X", **kwargs):
def __init__(
self,
order: Optional[List[str]] = None,
min_counts: Optional[int] = None,
min_cells: Optional[Union[float, int]] = None,
max_counts: Optional[int] = None,
max_cells: Optional[Union[float, int]] = None,
split_name: Optional[str] = None,
channel: Optional[str] = None,
channel_type: Optional[str] = "X",
**kwargs,
):
super().__init__(**kwargs)
self.filter_genes_order = order
self.logger.info(f"choose filter_genes_order f{self.filter_genes_order}")
self.filter_genes_order = default(
order,
["min_counts", "min_cells", "max_counts", "max_cells"],
)
self.logger.info(f"Filter genes order: {self.filter_genes_order}")
geneParameterDict = {
"min_counts": min_counts,
"min_cells": min_cells,
Expand All @@ -932,9 +945,13 @@ def __init__(self, order: "list[str]" = ["min_counts", "min_cells", "max_counts"
self.geneScanpyOrderDict = {}
for key in geneParameterDict.keys():
if key in self.filter_genes_order:
self.geneScanpyOrderDict[key] = FilterGenesScanpy(**{key:
geneParameterDict[key]}, split_name=split_name,
channel=channel, channel_type=channel_type, **kwargs)
self.geneScanpyOrderDict[key] = FilterGenesScanpy(
**{key: geneParameterDict[key]},
split_name=split_name,
channel=channel,
channel_type=channel_type,
**kwargs,
)
else:
self.logger.warning(f"{key} not in order,It makes no sense to set {key}")

Expand All @@ -945,15 +962,19 @@ def __call__(self, data: Data):


@register_preprocessor("filter", "gene")
class HighlyVariableGenesRawCount(BaseTransform):
"""Layer If provided, use `data.data.layers[layer]` for expression values instead of
`data.data.X`.
class HighlyVariableGenesRawCount(AnnDataTransform):
"""Filter for highly variable genes using raw count matrix.
Parameters
----------
layer
If provided, then use `data.data.layers[layer]` for expression values instead of the
default ``data.data.X``.
n_top_genes
Number of highly-variable genes to keep. Mandatory if `flavor='seurat_v3'`.
Number of highly-variable genes to keep.
span
The fraction of the data (cells) used when estimating the variance in the loess
model fit if `flavor='seurat_v3'`.
model fit if `flavor="seurat_v3"`.
subset
Inplace subset to highly-variable genes if `True` otherwise merely indicate
highly variable genes.
Expand All @@ -964,32 +985,35 @@ class HighlyVariableGenesRawCount(BaseTransform):
This simple process avoids the selection of batch-specific genes and acts as a
lightweight batch correction method. For all flavors, genes are first sorted
by how many batches they are a HVG. For dispersion-based flavors ties are broken
by normalized dispersion. If `flavor = 'seurat_v3'`, ties are broken by the median
by normalized dispersion. If `flavor = "seurat_v3"`, ties are broken by the median
(across batches) rank based on within-batch normalized variance.
check_values
Check if counts in selected layer are integers. A Warning is returned if set to True.
Only used if `flavor='seurat_v3'`.
Only used if `flavor="seurat_v3"`.
See also
--------
This is a wrapper for
https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html
"""

def __init__(self, layer: Optional[str] = None, n_top_genes: Optional[int] = 1000, span: Optional[float] = 0.3,
subset: bool = True, inplace: bool = True, batch_key: Optional[str] = None, check_values: bool = True,
**kwargs):
super().__init__(**kwargs)
self.transform = AnnDataTransform(
sc.pp.highly_variable_genes, layer=layer, n_top_genes=n_top_genes, batch_key=batch_key,
check_values=check_values, span=span, subset=subset, inplace=inplace, flavor='seurat_v3'
) #see https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html

def __call__(self, data: Data):
self.transform(data)
super().__init__(sc.pp.highly_variable_genes, layer=layer, n_top_genes=n_top_genes, batch_key=batch_key,
check_values=check_values, span=span, subset=subset, inplace=inplace, flavor="seurat_v3")


@register_preprocessor("filter", "gene")
class HighlyVariableGenesLogarithmizedByTopGenes(BaseTransform):
"""Layer If provided, use `data.data.layers[layer]` for expression values instead of
`data.data.X`.
class HighlyVariableGenesLogarithmizedByTopGenes(AnnDataTransform):
"""Filter for highly variable genes based on top genes.
Parameters
----------
layer
If provided, then use data.data.layers[layer]` for expression values instead of the
default `data.data.X`.
n_top_genes
Number of highly-variable genes to keep.
n_bins
Expand All @@ -1011,28 +1035,32 @@ class HighlyVariableGenesLogarithmizedByTopGenes(BaseTransform):
This simple process avoids the selection of batch-specific genes and acts as a
lightweight batch correction method. For all flavors, genes are first sorted
by how many batches they are a HVG. For dispersion-based flavors ties are broken
by normalized dispersion. If `flavor = 'seurat_v3'`, ties are broken by the median
by normalized dispersion. If `flavor = "seurat_v3"`, ties are broken by the median
(across batches) rank based on within-batch normalized variance.
See also
--------
This is a wrapper for
https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html
"""

def __init__(self, layer: Optional[str] = None, n_top_genes: Optional[int] = 1000, n_bins: int = 20,
flavor: Literal['seurat', 'cell_ranger'] = 'seurat', subset: bool = True, inplace: bool = True,
flavor: Literal["seurat", "cell_ranger"] = "seurat", subset: bool = True, inplace: bool = True,
batch_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.transform = AnnDataTransform(sc.pp.highly_variable_genes, layer=layer, n_top_genes=n_top_genes,
n_bins=n_bins, flavor=flavor, subset=subset, inplace=inplace,
batch_key=batch_key)

def __call__(self, data: Data):
self.transform(data)
super().__init__(sc.pp.highly_variable_genes, layer=layer, n_top_genes=n_top_genes, n_bins=n_bins,
flavor=flavor, subset=subset, inplace=inplace, batch_key=batch_key)


@register_preprocessor("filter", "gene")
class HighlyVariableGenesLogarithmizedByMeanAndDisp(BaseTransform):
"""Layer If provided, use `data.data.layers[layer]` for expression values instead of
`data.data.X`.
class HighlyVariableGenesLogarithmizedByMeanAndDisp(AnnDataTransform):
"""Filter for highly variable genes based on mean and dispersion.
Parameters
----------
layer
If provided, then use data.data.layers[layer]` for expression values instead of the
default `data.data.X`.
min_mean
min_mean
max_mean
Expand Down Expand Up @@ -1060,22 +1088,23 @@ class HighlyVariableGenesLogarithmizedByMeanAndDisp(BaseTransform):
This simple process avoids the selection of batch-specific genes and acts as a
lightweight batch correction method. For all flavors, genes are first sorted
by how many batches they are a HVG. For dispersion-based flavors ties are broken
by normalized dispersion. If `flavor = 'seurat_v3'`, ties are broken by the median
by normalized dispersion. If `flavor = "seurat_v3"`, ties are broken by the median
(across batches) rank based on within-batch normalized variance.
See also
--------
This is a wrapper for
https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html
"""

def __init__(self, layer: Optional[str] = None, min_disp: Optional[float] = 0.5, max_disp: Optional[float] = np.inf,
min_mean: Optional[float] = 0.0125, max_mean: Optional[float] = 3, n_bins: int = 20,
flavor: Literal['seurat', 'cell_ranger'] = 'seurat', subset: bool = True, inplace: bool = True,
flavor: Literal["seurat", "cell_ranger"] = "seurat", subset: bool = True, inplace: bool = True,
batch_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.transform = AnnDataTransform(sc.pp.highly_variable_genes, layer=layer, min_disp=min_disp,
max_disp=max_disp, min_mean=min_mean, max_mean=max_mean, n_bins=n_bins,
flavor=flavor, subset=subset, inplace=inplace, batch_key=batch_key)

def __call__(self, data: Data):
self.transform(data)
super().__init__(sc.pp.highly_variable_genes, layer=layer, min_disp=min_disp, max_disp=max_disp,
min_mean=min_mean, max_mean=max_mean, n_bins=n_bins, flavor=flavor, subset=subset,
inplace=inplace, batch_key=batch_key)


@register_preprocessor("filter", "cell")
Expand All @@ -1087,7 +1116,9 @@ class FilterCellsScanpyOrder(BaseTransform):
Parameters
----------
order
Order of (min_counts,min_genes,max_counts,max_genes),e.g["min_counts","min_genes","max_counts","max_genes"]
Order of (min_counts, min_cells, max_counts, max_cells). For example,
``["min_counts", "min_genes", "max_counts", "max_genes"]`` or ``["max_counts", "min_genes"]``.
If not set, will be set by default to ``["min_counts", "min_genes", "max_counts", "max_genes"]``.
min_counts
Minimum number of counts required for a cell to be kept.
min_genes
Expand All @@ -1105,14 +1136,21 @@ class FilterCellsScanpyOrder(BaseTransform):
"""

def __init__(self, order: "list[str]" = ["min_counts", "min_genes", "max_counts",
"max_genes"], min_counts: Optional[int] = None,
min_genes: Optional[Union[float, int]] = None, max_counts: Optional[int] = None,
max_genes: Optional[Union[float, int]] = None, split_name: Optional[str] = None,
channel: Optional[str] = None, channel_type: Optional[str] = "X", **kwargs):
def __init__(
self,
order: Optional[List[str]] = None,
min_counts: Optional[int] = None,
min_genes: Optional[Union[float, int]] = None,
max_counts: Optional[int] = None,
max_genes: Optional[Union[float, int]] = None,
split_name: Optional[str] = None,
channel: Optional[str] = None,
channel_type: Optional[str] = "X",
**kwargs,
):
super().__init__(**kwargs)
self.filter_cells_order = order
self.logger.info(f"choose filter_cells_order f{self.filter_cells_order}")
self.filter_cells_order = default(order, ["min_counts", "min_genes", "max_counts", "max_genes"])
self.logger.info(f"Filter cells order: {self.filter_cells_order}")
cellParameterDict = {
"min_counts": min_counts,
"min_genes": min_genes,
Expand All @@ -1124,9 +1162,13 @@ def __init__(self, order: "list[str]" = ["min_counts", "min_genes", "max_counts"
self.cellScanpyOrderDict = {}
for key in cellParameterDict.keys():
if key in self.filter_cells_order:
self.cellScanpyOrderDict[key] = FilterCellsScanpy(**{key:
cellParameterDict[key]}, split_name=split_name,
channel=channel, channel_type=channel_type, **kwargs)
self.cellScanpyOrderDict[key] = FilterCellsScanpy(
**{key: cellParameterDict[key]},
split_name=split_name,
channel=channel,
channel_type=channel_type,
**kwargs,
)
else:
self.logger.warning(f"{key} not in order,It makes no sense to set {key}")

Expand Down
42 changes: 20 additions & 22 deletions dance/transforms/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def info(n, th, mu, y, w):


@register_preprocessor("normalize")
class Log1P(BaseTransform):
class Log1P(AnnDataTransform):
"""Logarithmize the data matrix.
Computes :math:`X = \\log(X + 1)`,
Expand All @@ -513,20 +513,21 @@ class Log1P(BaseTransform):
obsm
Entry of obsm to transform.
See also
--------
This is a wrapper for
https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.log1p.html
"""

def __init__(self, base: Optional[Number] = None, copy: bool = False, chunked: bool = None,
chunk_size: Optional[int] = None, layer: Optional[str] = None, obsm: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.transform = AnnDataTransform(sc.pp.log1p, base=base, chunked=chunked, chunk_size=chunk_size, layer=layer,
obsm=obsm, copy=copy, **kwargs)

def __call__(self, data: Data):
self.transform(data=data)
super().__init__(sc.pp.log1p, base=base, chunked=chunked, chunk_size=chunk_size, layer=layer, obsm=obsm,
copy=copy, **kwargs)


@register_preprocessor("normalize")
class NormalizeTotal(BaseTransform):
class NormalizeTotal(AnnDataTransform):
"""Normalize counts per cell.
Normalize each cell by total counts over all genes,
Expand All @@ -536,10 +537,7 @@ class NormalizeTotal(BaseTransform):
If max_fraction is less than 1.0, very highly expressed genes are excluded
from the computation of the normalization factor (size factor) for each
cell. This is meaningful as these can strongly influence the resulting
normalized values for all other genes [Weinreb17]_.
Similar functions are used, for example, by Seurat [Satija15]_, Cell Ranger
[Zheng17]_ or SPRING [Weinreb17]_.
normalized values for all other genes.
Params
------
Expand Down Expand Up @@ -567,20 +565,20 @@ class NormalizeTotal(BaseTransform):
copy
Whether to modify copied input object. Not compatible with inplace=False.
See also
--------
This is a wrapper for
https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.normalize_total.html
"""

def __init__(self, target_sum: Optional[float] = None, max_fraction: float = 0.05, key_added: Optional[str] = None,
layer: Optional[str] = None, layers: Union[Literal['all'], Iterable[str]] = None,
layer_norm: Optional[str] = None, inplace: bool = True, copy: bool = False, **kwargs):
super().__init__(sc.pp.normalize_total, target_sum=target_sum, key_added=key_added, layer=layer, layers=layers,
layer_norm=layer_norm, inplace=inplace, copy=copy, exclude_highly_expressed=True,
max_fraction=max_fraction)

super().__init__(**kwargs)
self.transform = AnnDataTransform(sc.pp.normalize_total, target_sum=target_sum, key_added=key_added,
layer=layer, layers=layers, layer_norm=layer_norm, inplace=inplace, copy=copy,
exclude_highly_expressed=True, max_fraction=max_fraction)
self.logger.info("max_fraction must be valid")
if max_fraction == 1.0:
self.logger.info(
"When max_fraction is equal to 1.0, it is equivalent to setting exclude_highly_expressed=False.")

def __call__(self, data: Data):
self.transform(data)
self.logger.info("max_fraction set to 1.0, this is equivalent to setting exclude_highly_expressed=False.")

0 comments on commit 821ad7b

Please sign in to comment.