Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NBLAST score matrix calculation #28

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions navis/core/neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2637,3 +2637,6 @@ def to_skeleton(self, scale_vec: float = 1) -> TreeNeuron:
tn._soma = self._soma

return tn

def __len__(self):
return len(self.points)
1 change: 1 addition & 0 deletions navis/nbl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@

from .nblast_funcs import nblast, nblast_allbyall, nblast_smart
from .synblast_funcs import synblast
from .smat import Lookup2d, smat_fcwb, parse_score_fn

__all__ = ['nblast', 'nblast_allbyall', 'nblast_smart', 'synblast']
87 changes: 49 additions & 38 deletions navis/nbl/nblast_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numbers
import os
from warnings import warn

import numpy as np
import pandas as pd
Expand All @@ -23,30 +24,47 @@
from typing import Union, Optional
from typing_extensions import Literal

from .smat import Lookup2d, parse_score_fn, SCORE_FN_DESCR
from .. import core, utils
from ..core import NeuronList, Dotprops, make_dotprops
from .. import config

__all__ = ['nblast', 'nblast_smart', 'nblast_allbyall', 'sim_to_dist']

fp = os.path.dirname(__file__)
smat_path = os.path.join(fp, 'score_mats')

logger = config.logger


class ScoringFunction:
"""Class representing scoring function."""
"""[DEPRECATED] Class representing scoring function.

Use a ``navis.nbl.smat.Lookup2d`` for score functions
like those used in Costa et al. 2016
(see the ``.from_dataframe()`` class method
to load it from pandas).
The actual score matrices from FCWB can be accessed with
``navis.nbl.smat.smat_fcwb()``.

Otherwise, use any compatible callable,
e.g. ``operator.mul`` (used by this class for ``smat=None``).
"""

clbarnes marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, smat):
if isinstance(smat, type(None)):
self._deprecation_warning("operator.mul")
self.scoring_function = self.pass_through
elif isinstance(smat, (pd.DataFrame, str)):
self.parse_matrix(smat)
self.scoring_function = self.score_lookup
else:
raise TypeError

def _deprecation_warning(self, alternative: str):
warn(
f"{type(self).__qualname__} is deprecated, use {alternative}.",
FutureWarning,
)

def __call__(self, dist, dot):
return self.scoring_function(dist, dot)

Expand All @@ -63,10 +81,15 @@ def score_lookup(self, dist, dot):
def parse_matrix(self, smat):
"""Parse matrix."""
if isinstance(smat, str):
self._deprecation_warning(
f"{Lookup2d.__qualname__}.from_dataframe(pandas.read_csv(smat, index_col=0))"
)
smat = pd.read_csv(smat, index_col=0)
else:
self._deprecation_warning(f"{Lookup2d.__qualname__}.from_dataframe(smat)")

if not isinstance(smat, pd.DataFrame):
raise TypeError(f'Excepted filepath or DataFrame, got "{type(smat)}"')
raise TypeError(f'Expected filepath or DataFrame, got "{type(smat)}"')

self.cells = smat.to_numpy()

Expand All @@ -92,27 +115,28 @@ def parse_interval(self, s):


class NBlaster:
"""Implements version 2 of the NBLAST algorithm.
f"""Implements version 2 of the NBLAST algorithm.

Please note that some properties are computed on initialization and
changing parameters (e.g. ``use_alpha``) at a later stage will mess things
up!

The highly flexible ``smat`` argument converts raw point match parameters
into a single score representing how good that match is.
Most simply, it is an NBLAST score function.
{SCORE_FN_DESCR}

Parameters
----------
use_alpha : bool
Whether or not to use alpha values for the scoring.
If True, the dotproduct of nearest neighbor vectors will
be scaled by ``sqrt(alpha1 * alpha2)``.
normalzed : bool
normalized : bool
If True, will normalize scores by the best possible score
(i.e. self-self) of the query neuron.
smat : str | pd.DataFrame
Score matrix. If 'auto' (default), will use scoring matrices
from FCWB. Same behaviour as in R's nat.nblast
implementation. If ``smat=None`` the scores will be
generated as the product of the distances and the dotproduct
of the vectors of nearest-neighbor pairs.
smat : Callable[[float, float], float] | str | os.PathLike | pd.DataFrame, default "auto"
See ``navis.nbl.smat.parse_score_fn``.
progress : bool
If True, will show a progress bar.

Expand All @@ -124,15 +148,7 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto', progress=True)
self.normalized = normalized
self.progress = progress

if smat == 'auto':
if self.use_alpha:
smat = pd.read_csv(f'{smat_path}/smat_alpha_fcwb.csv',
index_col=0)
else:
smat = pd.read_csv(f'{smat_path}/smat_fcwb.csv',
index_col=0)

self.score_fn = ScoringFunction(smat)
self.score_fn = parse_score_fn(smat, use_alpha)

self.self_hits = []
self.dotprops = []
Expand Down Expand Up @@ -357,12 +373,8 @@ def nblast_smart(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotpro
that have lots of branches.
normalized : bool, optional
Whether to return normalized NBLAST scores.
smat : str | pd.DataFrame
Score matrix. If 'auto' (default), will use scoring matrices
from FCWB. Same behaviour as in R's nat.nblast
implementation. If ``smat=None`` the scores will be
generated as the product of the distances and the dotproduct
of the vectors of nearest-neighbor pairs.
smat : str | pd.DataFrame, default "auto"
Score matrix. See ``navis.nbl.smat.parse_score_fn``.
progress : bool
Whether to show progress bars.

Expand Down Expand Up @@ -412,6 +424,7 @@ def nblast_smart(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotpro
A synapse-based variant of NBLAST.

"""
smat = parse_score_fn(smat, use_alpha)
utils.eval_param(criterion, name='criterion',
allowed_values=("percentile", "score", "N"))

Expand Down Expand Up @@ -624,15 +637,12 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'],
individual processes.
use_alpha : bool, optional
Emphasizes neurons' straight parts (backbone) over parts
that have lots of branches.
that have lots of branches.
normalized : bool, optional
Whether to return normalized NBLAST scores.
smat : str | pd.DataFrame
Score matrix. If 'auto' (default), will use scoring matrices
from FCWB. Same behaviour as in R's nat.nblast
implementation. If ``smat=None`` the scores will be
generated as the product of the distances and the dotproduct
of the vectors of nearest-neighbor pairs.
Score matrix.
See ``navis.nbl.smat.parse_score_fn``.
progress : bool
Whether to show progress bars.

Expand Down Expand Up @@ -677,6 +687,8 @@ def nblast(query: Union['core.TreeNeuron', 'core.NeuronList', 'core.Dotprops'],
A synapse-based variant of NBLAST.

"""
smat = parse_score_fn(smat, use_alpha)

# Check if query or targets are in microns
# Note this test can return `None` if it can't be determined
if check_microns(query) is False:
Expand Down Expand Up @@ -784,11 +796,8 @@ def nblast_allbyall(x: NeuronList,
normalized : bool, optional
Whether to return normalized NBLAST scores.
smat : str | pd.DataFrame, optional
Score matrix. If 'auto' (default), will use scoring matrices
from FCWB. Same behaviour as in R's nat.nblast
implementation. If ``smat=None`` the scores will be
generated as the product of the distances and the dotproduct
of the vectors of nearest-neighbor pairs.
Score matrix.
See ``navis.nbl.smat.parse_score_fn``.
progress : bool
Whether to show progress bars.

Expand Down Expand Up @@ -828,6 +837,8 @@ def nblast_allbyall(x: NeuronList,
For generic query -> target nblasts.

"""
smat = parse_score_fn(smat, use_alpha)

# Check if query or targets are in microns
# Note this test can return `None` if it can't be determined
if check_microns(x) is False:
Expand Down
Loading