Skip to content

Commit

Permalink
Merge pull request #127 from PixelgenTechnologies/feature/exe-1525-re…
Browse files Browse the repository at this point in the history
…factor-analysis

Feature/exe 1525 refactor analysis
  • Loading branch information
johandahlberg authored Apr 24, 2024
2 parents 7c0b937 + c9c8d79 commit a79d130
Show file tree
Hide file tree
Showing 10 changed files with 660 additions and 87 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Renaming of component metrics in adata
* Use MPX graph compatible permutation strategy when calculating Moran's I.
* Marker filtering is now done after count transformation in polarization score calculation.
* Use common analysis engine to orchestrate running different "per component" analyses, such as
  polarization and colocalization analysis (yielding a roughly 3x speed-up over the previous approach).

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ exclude = ["docs/conf.py"]

[tool.ruff.lint.per-file-ignores]
"pixelator/report/__init__.py" = ["E501"]
"**/tests/**" = ["D101", "D102", "D103", "D200", "D202", "D205", "D212" , "D400", "D401", "D403", "D404", "D415"]
"**/tests/**" = ["D101", "D102", "D103", "D105", "D107", "D200", "D202", "D205", "D212" , "D400", "D401", "D403", "D404", "D415"]
# Since click uses a different layout for the docs strings to generate the
# cli docs, we ignore these rules here.
"src/pixelator/cli/**" = ["D200", "D212", "D400", "D415"]
Expand Down
92 changes: 20 additions & 72 deletions src/pixelator/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from dataclasses import asdict, dataclass
from pathlib import Path

from pixelator.analysis.colocalization import colocalization_scores
from pixelator.analysis.analysis_engine import PerComponentAnalysis, run_analysis
from pixelator.analysis.colocalization.types import TransformationTypes
from pixelator.analysis.polarization import polarization_scores
from pixelator.analysis.polarization.types import PolarizationNormalizationTypes
from pixelator.pixeldataset import (
PixelDataset,
Expand Down Expand Up @@ -41,19 +40,10 @@ def analyse_pixels(
output: str,
output_prefix: str,
metrics_file: str,
compute_polarization: bool,
compute_colocalization: bool,
use_full_bipartite: bool,
polarization_normalization: PolarizationNormalizationTypes,
polarization_n_permutations: int,
polarization_min_marker_count: int,
colocalization_transformation: TransformationTypes,
colocalization_neighbourhood_size: int,
colocalization_n_permutations: int,
colocalization_min_region_count: int,
verbose: bool,
analysis_to_run: list[PerComponentAnalysis],
) -> None:
"""Calculate Moran's I statistics for a PixelDataset.
"""Run analysis functions on a PixelDataset.
This function takes a pxl file that has been generated
with `pixelator annotate`. The function then uses the `edge list` and
Expand All @@ -65,27 +55,10 @@ def analyse_pixels(
:param output: the path to the output file
:param output_prefix: the prefix to prepend to the output file
:param metrics_file: the path to a JSON file to write metrics
:param compute_polarization: compute polarization scores when True
:param compute_colocalization: compute colocalization scores when True
:param use_full_bipartite: use the bipartite graph instead of the
one-node-projection (UPIA)
:param polarization_normalization: the method to use to normalize the
antibody counts (raw, log1p, or clr)
:param polarization_n_permutations: Select number of permutations used to
calculate empirical p-values of the
polarization scores
:param polarization_min_marker_count: the minimum number of counts of a marker to calculate
the Moran's I statistic
:param colocalization_transformation: Select a transformation method to use
for the colocalization
:param colocalization_neighbourhood_size: Set the size of the neighbourhood to
consider when computing the colocalization
:param colocalization_n_permutations: Select number of permutations used to
calculate empirical p-values of the
colocalization scores
:param colocalization_min_region_count: The minimum size of the region (e.g. number
of counts in the neighbourhood) required
for it to be considered
:param analysis_to_run: a list of analysis functions (`PerComponentAnalysis` instances) to apply
to each component
:param verbose: run if verbose mode when true
:returns: None
:rtype: None
Expand All @@ -95,52 +68,27 @@ def analyse_pixels(

# load the PixelDataset
dataset = PixelDataset.from_file(input)
edgelist = dataset.edgelist

metrics = {} # type: ignore
names_of_analyses = {analysis.ANALYSIS_NAME for analysis in analysis_to_run}

compute_polarization = "yes" if "polarization" in names_of_analyses else "no"
compute_colocalization = "yes" if "colocalization" in names_of_analyses else "no"

metrics = dict()
metrics["polarization"] = "yes" if compute_polarization else "no"
metrics["colocalization"] = "yes" if compute_colocalization else "no"

# polarization scores
if compute_polarization:
# obtain polarization scores
scores = polarization_scores(
edgelist=edgelist,
use_full_bipartite=use_full_bipartite,
normalization=polarization_normalization,
n_permutations=polarization_n_permutations,
min_marker_count=polarization_min_marker_count,
)
dataset.polarization = scores

# colocalization scores
if compute_colocalization:
# obtain colocalization scores
scores = colocalization_scores(
edgelist=edgelist,
use_full_bipartite=use_full_bipartite,
transformation=colocalization_transformation,
neighbourhood_size=colocalization_neighbourhood_size,
n_permutations=colocalization_n_permutations,
min_region_count=colocalization_min_region_count,
)
dataset.colocalization = scores
dataset = run_analysis(
pxl_dataset=dataset,
analysis_to_run=analysis_to_run,
use_full_bipartite=use_full_bipartite,
)

dataset.metadata["analysis"] = {
"params": asdict(
AnalysisParameters(
compute_colocalization=compute_colocalization,
compute_polarization=compute_polarization,
use_full_bipartite=use_full_bipartite,
polarization_normalization=polarization_normalization,
polarization_n_permutations=polarization_n_permutations,
polarization_min_marker_count=polarization_min_marker_count,
colocalization_transformation=colocalization_transformation,
colocalization_neighbourhood_size=colocalization_neighbourhood_size,
colocalization_n_permutations=colocalization_n_permutations,
colocalization_min_region_count=colocalization_min_region_count,
)
)
"params": {
analysis.ANALYSIS_NAME: analysis.parameters()
for analysis in analysis_to_run
}
}
# save dataset
dataset.save(
Expand Down
184 changes: 184 additions & 0 deletions src/pixelator/analysis/analysis_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Analysis engine capable of running a list of analysis functions on each component in a pixeldataset.
Copyright © 2024 Pixelgen Technologies AB.
"""

import logging
from collections import defaultdict
from concurrent.futures import as_completed
from functools import partial
from queue import Queue
from typing import Callable, Iterable, Protocol

import pandas as pd

from pixelator.graph import Graph
from pixelator.pixeldataset import PixelDataset
from pixelator.utils import (
    get_process_pool_executor,
)

logger = logging.getLogger(__name__)


class PerComponentAnalysis(Protocol):
    """Protocol for analysis functions that are run on each component in a PixelDataset."""

    # Unique name of the analysis (e.g. "polarization"); used as the key when
    # routing per-component results back to the analysis that produced them,
    # and when recording the analysis parameters in the run metadata.
    ANALYSIS_NAME: str

    def run_on_component(self, component: Graph, component_id: str) -> pd.DataFrame:
        """Run the analysis on this component.

        :param component: the component graph to run the analysis on
        :param component_id: the id of the component
        :returns: a dataframe of analysis results for this component
        """
        ...

    def concatenate_data(self, data: Iterable[pd.DataFrame]) -> pd.DataFrame:
        """Concatenate the per-component results into one dataframe.

        Override this if you need custom concatenation behavior.

        :param data: the per-component result dataframes
        :returns: the row-wise concatenation of `data`
        :raises ValueError: if `data` contains no dataframes to concatenate
        """
        try:
            return pd.concat(data, axis=0)
        except ValueError:
            logger.error("No data was found to compute %s", self.ANALYSIS_NAME)
            # Bare raise preserves the original traceback.
            raise

    def post_process_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """Post process the data (e.g. adjust p-values).

        Override this if your data needs post processing.
        """
        return data

    def add_to_pixel_dataset(
        self, data: pd.DataFrame, pxl_dataset: PixelDataset
    ) -> PixelDataset:
        """Add the data in the right place in the `pxl_dataset`."""
        ...

    def parameters(self) -> dict:
        """Return the parameters of the `PerComponentAnalysis`.

        This is used e.g. to store the parameters of the analysis
        in the run metadata.
        """
        # vars(self) exposes the instance attributes, i.e. the configuration
        # the concrete analysis was constructed with.
        return {self.ANALYSIS_NAME: vars(self)}


class _AnalysisManager:
    """Analysis manager that can run a number of analyses across a stream of components.

    The analysis manager is responsible for hooking up the analysis functions and making
    them run on each component in the stream. The main workflow it uses is outlined in the
    `execute` method.
    """

    def __init__(
        self,
        analysis_to_run: Iterable[PerComponentAnalysis],
        component_stream: Iterable[tuple[str, Graph]],
    ):
        """Create an analysis manager.

        :param analysis_to_run: the analyses to apply to each component
        :param component_stream: a stream of (component id, component graph) pairs
        """
        # Index the analyses by name so that results can be routed back to
        # the analysis instance that produced them.
        self.analysis_to_run = {
            analysis.ANALYSIS_NAME: analysis for analysis in analysis_to_run
        }
        self.component_stream = component_stream

    def _prepare_computation(
        self,
    ) -> Iterable[tuple[str, Callable[[Graph, str], pd.DataFrame]]]:
        """Yield one (analysis name, no-arg callable) pair per component/analysis combination."""
        for component_id, component_graph in self.component_stream:
            for analysis_name, analysis in self.analysis_to_run.items():
                yield (
                    analysis_name,
                    partial(
                        analysis.run_on_component,
                        component=component_graph,
                        component_id=component_id,
                    ),
                )

    def _execute_computations_in_parallel(self, prepared_computations):
        """Run the prepared computations in a process pool.

        Yields (analysis name, result dataframe) pairs as the computations
        complete.
        """
        with get_process_pool_executor() as executor:
            future_to_name = {}
            for analysis_name, func in prepared_computations:
                logger.debug("Putting %s in the queue for analysis", analysis_name)
                future_to_name[executor.submit(func)] = analysis_name
            # Block until each future completes instead of busy-polling:
            # the previous implementation re-enqueued unfinished futures in
            # a spin loop, burning CPU while waiting for results.
            for future in as_completed(future_to_name):
                analysis_name = future_to_name[future]
                logger.debug("Future for %s is done", analysis_name)
                yield (analysis_name, future.result())

    def _post_process(self, per_component_results):
        """Group the results per analysis, then concatenate and post-process them."""
        concatenated_data = defaultdict(list)
        for key, data in per_component_results:
            concatenated_data[key].append(data)

        for key, data_list in concatenated_data.items():
            analysis = self.analysis_to_run[key]
            yield (
                key,
                analysis.post_process_data(analysis.concatenate_data(data_list)),
            )

    def _add_to_pixel_dataset(self, post_processed_data, pxl_dataset: PixelDataset):
        """Let each analysis attach its post-processed result to the dataset."""
        for key, data in post_processed_data:
            pxl_dataset = self.analysis_to_run[key].add_to_pixel_dataset(
                data, pxl_dataset
            )
        return pxl_dataset

    def execute(self, pixel_dataset) -> PixelDataset:
        """Execute the analysis on the provided pixel dataset."""
        prepared_computations = self._prepare_computation()
        per_component_results = self._execute_computations_in_parallel(
            prepared_computations
        )
        post_processed_data = self._post_process(per_component_results)
        return self._add_to_pixel_dataset(post_processed_data, pixel_dataset)


def edgelist_to_component_stream(
    dataset: PixelDataset, use_full_bipartite: bool
) -> Iterable[tuple[str, Graph]]:
    """Stream (component id, component graph) pairs built from the dataset's edgelist."""
    # Materialize the lazy edgelist once, then split it into one frame
    # per component.
    components = (
        dataset.edgelist_lazy.collect()
        .partition_by(by="component", as_dict=True)
    )
    for component_id, component_df in components.items():
        component_graph = Graph.from_edgelist(
            edgelist=component_df.lazy(),
            add_marker_counts=True,
            simplify=True,
            use_full_bipartite=use_full_bipartite,
        )
        yield (component_id, component_graph)


def run_analysis(
    pxl_dataset: PixelDataset,
    analysis_to_run: list[PerComponentAnalysis],
    use_full_bipartite: bool = False,
) -> PixelDataset:
    """Run the provided list of `PerComponentAnalysis` on the components in the `pxl_dataset`.

    :param pxl_dataset: The PixelDataset to run the analysis on.
    :param analysis_to_run: A list of `PerComponentAnalysis` to run on the components in the `pxl_dataset`.
    :param use_full_bipartite: Whether to use the full bipartite graph when creating the components.
    :returns: A `PixelDataset` instance with the provided analysis added to it.
    """
    # Nothing to do if no analyses were requested; return the dataset as-is.
    if not analysis_to_run:
        logger.warning("No analysis functions were provided")
        return pxl_dataset

    component_stream = edgelist_to_component_stream(
        pxl_dataset, use_full_bipartite=use_full_bipartite
    )
    manager = _AnalysisManager(analysis_to_run, component_stream=component_stream)
    return manager.execute(pxl_dataset)
Loading

0 comments on commit a79d130

Please sign in to comment.