diff --git a/.gitignore b/.gitignore
index da2d03ca..d35c0bea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,5 @@ ENV/
 #*.swc
 #*.json
 MANIFEST
+
+tmp/
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 4161dfbe..d82b4cc6 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -364,6 +364,15 @@ Functions to compare morphology.
     navis.synblast
     navis.persistence_distances

+Utilities for creating your own score matrices for NBLAST can be found in ``navis.nbl.smat``:
+
+.. autosummary::
+    :toctree: generated/
+
+    navis.nbl.smat.Lookup2d
+    navis.nbl.smat.Digitizer
+    navis.nbl.smat.LookupDistDotBuilder
+
 Polarity metrics
 ----------------
 .. autosummary::
diff --git a/docs/source/gallery.rst b/docs/source/gallery.rst
index 62a61f4b..4da347ad 100644
--- a/docs/source/gallery.rst
+++ b/docs/source/gallery.rst
@@ -72,6 +72,7 @@ function. There you will also find more examples.
     tutorials/nblast
     tutorials/nblast_flycircuit
     tutorials/nblast_hemibrain
+    tutorials/smat

 .. raw:: html
diff --git a/docs/source/tutorials/smat.ipynb b/docs/source/tutorials/smat.ipynb
new file mode 100644
index 00000000..45672bb2
--- /dev/null
+++ b/docs/source/tutorials/smat.ipynb
@@ -0,0 +1,491 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    ".. _smat_intro:\n",
+    "\n",
+    "NBLAST score matrix generation\n",
+    "******************************\n",
+    "\n",
+    "NBLAST calculates the similarity between neurons' morphologies.\n",
+    "For more information on NBLAST in ``navis``, see the other tutorials.\n",
+    "\n",
+    "The core of the algorithm is the function which converts point matches\n",
+    "(a distance and a dot product) into a score\n",
+    "expressing how likely they are to have come from the same point cloud.\n",
+    "This is generally a 2D lookup table, referred to as the score matrix,\n",
+    "generated by \"training\" it on pairs of neurons known to be matching or non-matching.\n",
+    "\n",
+    "``navis`` provides (and uses by default) the score matrix used in the original publication (Costa et al., 2016).\n",
+    "This works quite well in many cases.\n",
+    "However, how appropriate it is for your data depends on a number of factors:\n",
+    "\n",
+    "* How big your neurons are (commonly addressed by scaling the distance axis of the built-in score matrix)\n",
+    "* How you have pre-processed your neurons (pruning dendrites, resampling etc.)\n",
+    "* Your actual task (matching left-right pairs, finding lineages etc.)\n",
+    "* How distinct you expect your matches and non-matches to be (e.g. how large a body volume you're drawing neurons from)\n",
+    "\n",
+    "Utilities in ``navis.nbl.smat`` allow you to train your own score matrix."
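[Editorial note: to make the score matrix's role concrete, here is a minimal sketch of querying the default FCWB lookup table that this PR exposes. ``smat_fcwb`` and the callable ``Lookup2d`` it returns are introduced in ``navis/nbl/smat.py`` below; the distance and dot-product values are arbitrary illustrations, not values from the PR.]

    import numpy as np
    from navis.nbl.smat import smat_fcwb

    # The default score matrix is a Lookup2d: call it with distances and
    # dot products of matched points to get per-match log-odds scores.
    score_fn = smat_fcwb(alpha=False)
    dists = np.array([1.0, 50.0, 300.0])  # distances between matched points
    dots = np.array([0.9, 0.5, 0.1])      # absolute dot products of tangents
    print(score_fn(dists, dots))          # high values = evidence for a match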
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import navis\n",
+    "from navis.nbl.smat import Lookup2d, LookupDistDotBuilder, Digitizer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "``Lookup2d`` is the lookup table, and can be given to any NBLAST-related class or function.\n",
+    "Each axis is represented by a ``Digitizer``, which converts continuous values into discrete indices.\n",
+    "These indices are used to look up score values in an array.\n",
+    "\n",
+    "The ``LookupDistDotBuilder`` is a class which generates ``Lookup2d`` instances from training data.\n",
+    "\n",
+    "First, we need some training data.\n",
+    "Let's augment our set of example neurons by randomly mutating them:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "# use a seeded random number generator so the results are reproducible\n",
+    "RNG = np.random.default_rng(2021)\n",
+    "\n",
+    "\n",
+    "def augment_neuron(\n",
+    "    nrn: navis.TreeNeuron, scale_sigma=0.1, translation_sigma=50, jitter_sigma=10\n",
+    "):\n",
+    "    nrn = nrn.copy(deepcopy=True)\n",
+    "    nrn.name += \"_aug\"\n",
+    "    dims = list(\"xyz\")\n",
+    "    coords = nrn.nodes[dims].to_numpy()\n",
+    "\n",
+    "    # translate whole neuron\n",
+    "    coords += RNG.normal(scale=translation_sigma, size=coords.shape[-1])\n",
+    "    # jitter individual coordinates\n",
+    "    coords += RNG.normal(scale=jitter_sigma, size=coords.shape)\n",
+    "    # rescale\n",
+    "    mean = np.mean(coords, axis=0)\n",
+    "    coords -= mean\n",
+    "    coords *= RNG.normal(loc=1.0, scale=scale_sigma)\n",
+    "    coords += mean\n",
+    "\n",
+    "    nrn.nodes[dims] = coords\n",
+    "    return nrn\n",
+    "\n",
+    "\n",
+    "original = list(navis.example_neurons())\n",
+    "jittered = [augment_neuron(n) for n in original]\n",
+    "\n",
+    "dotprops = [navis.make_dotprops(n, k=5, resample=False) for n in original + jittered]\n",
+    "matching_pairs = [[idx, idx + len(original)] for idx in range(len(original))]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The score matrix builder needs some neurons as a list of ``navis.Dotprops`` objects,\n",
+    "and to know which neurons should match with each other, given as indices into that list.\n",
+    "It's assumed that matches are relatively rare among the total set of possible pairings,\n",
+    "so non-matching pairs are drawn randomly (although a non-matching list can be given explicitly).\n",
+    "\n",
+    "Then it needs to know where to draw the boundaries between bins in the output lookup table.\n",
+    "These can be given explicitly as a list of 2 ``Digitizer`` s (an explicit sketch follows the next markdown cell),\n",
+    "or can be inferred from the data: bins will be drawn to evenly partition the matching neuron scores.\n",
+    "\n",
+    "The resulting ``Lookup2d`` can be imported/exported as a ``pandas.DataFrame`` for ease of viewing and storing."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
[0.0,0.14588644700379452)[0.14588644700379452,0.29912252023057806)[0.29912252023057806,0.4802508950780032)[0.4802508950780032,0.7351365037506351)[0.7351365037506351,0.9988406500803395)
[2.0140850483155233,57.8493661406481)0.4866610.3790420.3872600.4456480.855775
[57.8493661406481,81.31283353333698)0.4047660.3437020.4064450.4178020.799595
[81.31283353333698,104.08576202392578)0.2579710.3118270.2176260.2487431.013926
[104.08576202392578,128.14104591262304)0.2557280.0896630.1715990.1159971.296510
[128.14104591262304,155.36119651794434)-0.136171-0.107249-0.1257510.2528831.987175
[155.36119651794434,202.6728515625)-0.575078-0.448307-0.475147-0.2210161.407061
[202.6728515625,395.9569088293945)-1.025938-0.948679-0.863801-0.6205120.054148
[395.9569088293945,4709.61474609375)-0.615558-0.737251-0.679764-0.454779-0.197384
\n", + "
" + ], + "text/plain": [ + " [0.0,0.14588644700379452) \\\n", + "[2.0140850483155233,57.8493661406481) 0.486661 \n", + "[57.8493661406481,81.31283353333698) 0.404766 \n", + "[81.31283353333698,104.08576202392578) 0.257971 \n", + "[104.08576202392578,128.14104591262304) 0.255728 \n", + "[128.14104591262304,155.36119651794434) -0.136171 \n", + "[155.36119651794434,202.6728515625) -0.575078 \n", + "[202.6728515625,395.9569088293945) -1.025938 \n", + "[395.9569088293945,4709.61474609375) -0.615558 \n", + "\n", + " [0.14588644700379452,0.29912252023057806) \\\n", + "[2.0140850483155233,57.8493661406481) 0.379042 \n", + "[57.8493661406481,81.31283353333698) 0.343702 \n", + "[81.31283353333698,104.08576202392578) 0.311827 \n", + "[104.08576202392578,128.14104591262304) 0.089663 \n", + "[128.14104591262304,155.36119651794434) -0.107249 \n", + "[155.36119651794434,202.6728515625) -0.448307 \n", + "[202.6728515625,395.9569088293945) -0.948679 \n", + "[395.9569088293945,4709.61474609375) -0.737251 \n", + "\n", + " [0.29912252023057806,0.4802508950780032) \\\n", + "[2.0140850483155233,57.8493661406481) 0.387260 \n", + "[57.8493661406481,81.31283353333698) 0.406445 \n", + "[81.31283353333698,104.08576202392578) 0.217626 \n", + "[104.08576202392578,128.14104591262304) 0.171599 \n", + "[128.14104591262304,155.36119651794434) -0.125751 \n", + "[155.36119651794434,202.6728515625) -0.475147 \n", + "[202.6728515625,395.9569088293945) -0.863801 \n", + "[395.9569088293945,4709.61474609375) -0.679764 \n", + "\n", + " [0.4802508950780032,0.7351365037506351) \\\n", + "[2.0140850483155233,57.8493661406481) 0.445648 \n", + "[57.8493661406481,81.31283353333698) 0.417802 \n", + "[81.31283353333698,104.08576202392578) 0.248743 \n", + "[104.08576202392578,128.14104591262304) 0.115997 \n", + "[128.14104591262304,155.36119651794434) 0.252883 \n", + "[155.36119651794434,202.6728515625) -0.221016 \n", + "[202.6728515625,395.9569088293945) -0.620512 \n", + "[395.9569088293945,4709.61474609375) -0.454779 \n", + "\n", + " [0.7351365037506351,0.9988406500803395) \n", + "[2.0140850483155233,57.8493661406481) 0.855775 \n", + "[57.8493661406481,81.31283353333698) 0.799595 \n", + "[81.31283353333698,104.08576202392578) 1.013926 \n", + "[104.08576202392578,128.14104591262304) 1.296510 \n", + "[128.14104591262304,155.36119651794434) 1.987175 \n", + "[155.36119651794434,202.6728515625) 1.407061 \n", + "[202.6728515625,395.9569088293945) 0.054148 \n", + "[395.9569088293945,4709.61474609375) -0.197384 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "builder = LookupDistDotBuilder(\n", + " dotprops, matching_pairs, use_alpha=True, seed=2021\n", + ").with_bin_counts([8, 5])\n", + "smat = builder.build()\n", + "as_table = smat.to_dataframe()\n", + "as_table" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we can have this score matrix, we can use it for a problem which can be solved by NBLAST:\n", + "we've mixed up a bag of neurons which look very similar to some of our examples,\n", + "and need to know which they match with." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "NBLAST is optimized for data in microns and it looks like your queries are not in microns.\n",
+      "NBLAST is optimized for data in microns and it looks like your targets are not in microns.\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b55a73a0832e495596df888b4e7d879d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Preparing:   0%|          | 0/1 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "(HTML rendering of the NBLAST scores DataFrame was stripped during conversion; see the text/plain output below.)
\n", + "" + ], + "text/plain": [ + " 754538881_aug 722817260_aug 1734350908_aug 1734350788_aug \\\n", + "1734350788 -0.271504 -0.361381 -0.272018 0.159048 \n", + "1734350908 -0.400432 -0.491376 0.858478 -0.437819 \n", + "722817260 -0.157433 0.127931 -0.365195 -0.407436 \n", + "754534424 -0.237769 -0.413532 -0.193794 -0.401148 \n", + "754538881 -0.021153 -0.213349 -0.200214 -0.303390 \n", + "\n", + " 754534424_aug \n", + "1734350788 -0.325920 \n", + "1734350908 -0.216206 \n", + "722817260 -0.393940 \n", + "754534424 0.857898 \n", + "754538881 -0.115017 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "original_dps = dotprops[:len(original)]\n", + "new_dps = [\n", + " navis.make_dotprops(\n", + " augment_neuron(n),\n", + " k=5,\n", + " resample=False\n", + " )\n", + " for n in original\n", + "]\n", + "RNG.shuffle(new_dps)\n", + "\n", + "result = navis.nblast(\n", + " original_dps, new_dps,\n", + " use_alpha=True, scores=\"mean\", normalized=True,\n", + " smat=smat,\n", + " n_cores=1,\n", + ")\n", + "result.index = [dp.name for dp in original_dps]\n", + "result.columns = [dp.name for dp in new_dps]\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that while there aren't any particularly good scores\n", + "(this is a very small amount of training data, and one would normally preprocess the neurons),\n", + "in each case the original's best match is its augmented partner." + ] + } + ], + "metadata": { + "interpreter": { + "hash": "97618690babd5eee2c893391179393f7dcc498027dfffda6cb7b8d10b95474fd" + }, + "kernelspec": { + "display_name": "Python 3.9.4 64-bit ('navis': pyenv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py index c83b04f9..aaa8175c 100644 --- a/navis/core/dotprop.py +++ b/navis/core/dotprop.py @@ -332,7 +332,9 @@ def dist_dots(self, other: 'Dotprops', alpha: bool = False, distance_upper_bound: Optional[float] = None, - **kwargs) -> Tuple[np.ndarray, np.ndarray]: + **kwargs) -> Union[ + Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray] + ]: """Query this Dotprops against another. This function is mainly for ``navis.nblast``. @@ -622,3 +624,6 @@ def to_skeleton(self, tn._soma = self._soma return tn + + def __len__(self): + return len(self.points) diff --git a/navis/nbl/base.py b/navis/nbl/base.py index c3ad4203..9e51344d 100644 --- a/navis/nbl/base.py +++ b/navis/nbl/base.py @@ -17,6 +17,7 @@ import pandas as pd from abc import ABC, abstractmethod +from typing import Union, List from .. 
import utils, config @@ -24,6 +25,9 @@ FLOAT_DTYPES = {16: np.float16, 32: np.float32, 64: np.float64, None: None} +NestedIndices = Union[int, List['NestedIndices']] + + class Blaster(ABC): """Base class for blasting.""" @@ -37,7 +41,7 @@ def __init__(self, dtype=np.float64, progress=True): self.ids = [] @abstractmethod - def append(self, neurons): + def append(self, neurons) -> NestedIndices: """Append neurons.""" pass @@ -157,3 +161,6 @@ def all_by_all(self, scores='forward'): res.loc[:, :] = np.dstack((res, res.T)).max(axis=2) return res + + def __len__(self): + return len(self.neurons) diff --git a/navis/nbl/nblast_funcs.py b/navis/nbl/nblast_funcs.py index 32dae6fa..9bf1ffe1 100644 --- a/navis/nbl/nblast_funcs.py +++ b/navis/nbl/nblast_funcs.py @@ -16,7 +16,7 @@ import numbers import os -import uuid +import operator import numpy as np import pandas as pd @@ -25,9 +25,11 @@ from typing import Union, Optional, List from typing_extensions import Literal +from navis.nbl.smat import Lookup2d, smat_fcwb + from .. import utils, config from ..core import NeuronList, Dotprops, make_dotprops -from .base import Blaster +from .base import Blaster, NestedIndices __all__ = ['nblast', 'nblast_smart', 'nblast_allbyall', 'sim_to_dist'] @@ -37,75 +39,6 @@ logger = config.logger -class ScoringFunction: - """Class representing scoring function.""" - - def __init__(self, smat): - if isinstance(smat, type(None)): - self.scoring_function = self.pass_through - elif isinstance(smat, (pd.DataFrame, str)): - self.parse_matrix(smat) - self.scoring_function = self.score_lookup - else: - raise TypeError - - def __call__(self, dist, dot): - return self.scoring_function(dist, dot) - - @property - def max_dist(self): - """Max distance considered by the scoring matrix. - - Returns ``None`` if pass through. - """ - if self.scoring_function == self.pass_through: - return None - # The last bin is always `np.inf`, so we need the second last bin - return self.dist_bins[-2] - - def pass_through(self, dist, dot): - """Pass-through scores if no scoring matrix.""" - return dist * dot - - def score_lookup(self, dist, dot): - return self.cells[ - np.digitize(dist, self.dist_bins), - np.digitize(dot, self.dot_bins), - ] - - def parse_matrix(self, smat): - """Parse matrix.""" - if isinstance(smat, str): - smat = pd.read_csv(smat, index_col=0) - - if not isinstance(smat, pd.DataFrame): - raise TypeError(f'Excepted filepath or DataFrame, got "{type(smat)}"') - - self.cells = smat.to_numpy() - - self.dist_thresholds = [self.parse_interval(s) for s in smat.index] - # Make sure right bin is open - self.dist_thresholds[-1] = np.inf - self.dist_bins = np.array(self.dist_thresholds, float) - - self.dot_thresholds = [self.parse_interval(s) for s in smat.columns] - # Make sure right bin is open - self.dot_thresholds[-1] = np.inf - self.dot_bins = np.array(self.dot_thresholds, float) - - def parse_interval(self, s): - """Strip brackets and parse right interval. - - Example - ------- - >>> parse_intervals("(0,0.1]") # doctest: +SKIP - 0.1 - """ - return float(s.strip("([])").split(",")[-1]) - - -NeuronId = Union[int, str, uuid.UUID] - class NBlaster(Blaster): """Implements version 2 of the NBLAST algorithm. @@ -122,10 +55,16 @@ class NBlaster(Blaster): normalized : bool If True, will normalize scores by the best possible score (i.e. self-self) of the query neuron. - smat : str | pd.DataFrame - Score matrix. 
If 'auto' (default), will use scoring matrices
+    smat :          navis.nbl.smat.Lookup2d | pd.DataFrame | str
+                    How to convert the point match pairs into an NBLAST score,
+                    usually by a lookup table.
+                    If 'auto' (default), will use scoring matrices
                     from FCWB. Same behaviour as in R's nat.nblast
-                    implementation. If ``smat=None`` the scores will be
+                    implementation.
+                    Dataframes will be used to build a ``Lookup2d``.
+                    If ``limit_dist`` is not given,
+                    will attempt to infer from the first axis of the lookup table.
+                    If ``smat=None`` the scores will be
                     generated as the product of the distances and the
                     dotproduct of the vectors of nearest-neighbor pairs.
     limit_dist :    float | "auto" | None
@@ -148,27 +87,30 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto',
         self.approx_nn = approx_nn
         self.desc = "NBlasting"

-        if smat == 'auto':
-            if self.use_alpha:
-                smat = pd.read_csv(f'{smat_path}/smat_alpha_fcwb.csv',
-                                   index_col=0)
-            else:
-                smat = pd.read_csv(f'{smat_path}/smat_fcwb.csv',
-                                   index_col=0)
-
-        self.score_fn = ScoringFunction(smat)
-
-        if limit_dist == 'auto':
-            self.distance_upper_bound = self.score_fn.max_dist
+        if smat is None:
+            self.score_fn = operator.mul
+        elif smat == 'auto':
+            self.score_fn = smat_fcwb(self.use_alpha)
+        elif isinstance(smat, pd.DataFrame):
+            self.score_fn = Lookup2d.from_dataframe(smat)
+        else:
+            self.score_fn = smat
+
+        if limit_dist == "auto":
+            try:
+                self.distance_upper_bound = self.score_fn.axes[0]._max
+            except AttributeError:
+                logger.warning("Could not infer distance upper bound from scoring function")
+                self.distance_upper_bound = None
         else:
             self.distance_upper_bound = limit_dist

-    def append(self, dotprops) -> Union[List[NeuronId], NeuronId]:
+    def append(self, dotprops) -> NestedIndices:
         """Append dotprops.

-        Returns the ID of the appended dotprops.
+        Returns the numerical index of the appended dotprops.
         If dotprops is a (possibly nested) sequence of dotprops,
-        return a (possibly nested) list of IDs.
+        return a (possibly nested) list of indices.
         """
         if isinstance(dotprops, Dotprops):
             return self._append_dotprops(dotprops)
@@ -178,12 +120,13 @@ def append(self, dotprops) -> Union[List[NeuronId], NeuronId]:
         except TypeError:  # i.e. not iterable
            raise ValueError(f"Expected Dotprops or iterable thereof; got {type(dotprops)}")

-    def _append_dotprops(self, dotprops: Dotprops) -> NeuronId:
+    def _append_dotprops(self, dotprops: Dotprops) -> int:
+        next_id = len(self)
         self.neurons.append(dotprops)
         self.ids.append(dotprops.id)
         # Calculate score for self hit
         self.self_hits.append(self.calc_self_hit(dotprops))
-        return dotprops.id
+        return next_id

     def calc_self_hit(self, dotprops):
         """Non-normalized value for self hit."""
@@ -195,7 +138,7 @@ def calc_self_hit(self, dotprops):
         dots = np.repeat(1, len(dotprops.points)) * np.sqrt(alpha)
         return self.score_fn(dists, dots).sum()

-    def single_query_target(self, q_idx, t_idx, scores='forward'):
+    def single_query_target(self, q_idx: int, t_idx: int, scores='forward'):
         """Query single target against single target."""
         # Take a short-cut if this is a self-self comparison
         if q_idx == t_idx:
diff --git a/navis/nbl/smat.py b/navis/nbl/smat.py
new file mode 100644
index 00000000..8aa9ee8c
--- /dev/null
+++ b/navis/nbl/smat.py
@@ -0,0 +1,875 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from itertools import permutations
+import sys
+import os
+from collections import Counter
+from concurrent.futures import ProcessPoolExecutor
+from typing import (
+    Generic,
+    Hashable,
+    Iterator,
+    Mapping,
+    Optional,
+    Sequence,
+    Callable,
+    List,
+    Iterable,
+    Any,
+    Tuple,
+    TypeVar,
+    Union,
+)
+import logging
+from pathlib import Path
+from functools import lru_cache
+from copy import deepcopy
+import operator
+import math
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+
+from ..core.neurons import Dotprops
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_SEED = 1991
+
+epsilon = sys.float_info.epsilon
+cpu_count = max(1, (os.cpu_count() or 2) - 1)
+
+fp = Path(__file__).resolve().parent
+smat_path = fp / "score_mats"
+
+
+def chunksize(it_len, cpu_count, min_chunk=50):
+    return max(min_chunk, int(it_len / (cpu_count * 4)))
+
+
+def yield_not_same(pairs: Iterable[Tuple[Any, Any]]) -> Iterator[Tuple[Any, Any]]:
+    for a, b in pairs:
+        if a != b:
+            yield a, b
+
+
+def concat_results(results: Iterable[List[np.ndarray]]) -> List[np.ndarray]:
+    intermediate = defaultdict(list)
+    for result_lst in results:
+        for idx, array in enumerate(result_lst):
+            intermediate[idx].append(array)
+
+    return [np.concatenate(arrs) for arrs in intermediate.values()]
+
+
+DotpropKey = Hashable
+
+
+class LookupNdBuilder:
+    def __init__(
+        self,
+        dotprops: Union[List[Dotprops], Mapping[DotpropKey, Dotprops]],
+        matching_lists: List[List[DotpropKey]],
+        match_fn: Callable[[Dotprops, Dotprops], List[np.ndarray]],
+        nonmatching_list: Optional[List[DotpropKey]] = None,
+        seed: int = DEFAULT_SEED,
+    ) -> None:
+        f"""Class for building an N-dimensional score lookup for NBLAST.
+
+        Once instantiated, the axes of the lookup table must be defined.
+        Call ``.with_digitizers()`` to manually define them,
+        or ``.with_bin_counts()`` to learn them from the matched-pair data.
+
+        Then call ``.build()`` to build the lookup table.
+
+        Parameters
+        ----------
+        dotprops : dict or list of Dotprops
+            An indexable, consistently-ordered sequence of all neurons
+            which will be used as the training set, as Dotprops objects.
+        matching_lists : list of lists of indices into dotprops
+            Lists of neurons, as indices into ``dotprops``, which should be considered matches.
+        match_fn : Callable[[Dotprops, Dotprops], List[np.ndarray[float]]]
+            Function taking 2 arguments,
+            both instances of ``navis.core.neurons.Dotprops``,
+            and returning a list of 1D ``numpy.ndarray``s of floats.
+            The length of the list determines the dimensionality of the lookup table.
+            The length of the ``array``s must be the same
+            as the number of points in the first argument.
+            This function returns values describing the quality of
+            point matches from a query to a target neuron.
+        nonmatching_list : list of indices into dotprops, optional
+            List of neurons, as indices into ``dotprops``,
+            which should not be considered matches.
+            If not given, all ``dotprops`` will be used
+            (on the assumption that matches are a small subset of possible pairs).
+        seed : int, optional
+            Non-matching pairs are drawn at random using this seed,
+            by default {DEFAULT_SEED}
+        """
+        self.dotprops = dotprops
+        self.matching_lists = matching_lists
+        self._nonmatching_list = nonmatching_list
+        self.match_fn = match_fn
+
+        self.digitizers: Optional[List[Digitizer]] = None
+        self.bin_counts: Optional[List[int]] = None
+
+        self.seed = seed
+        self._ndim: Optional[int] = None
+
+    @property
+    def ndim(self) -> int:
+        if self._ndim is None:
+            # list() so that Mapping keys, which are not sliceable, work too
+            idx1, idx2 = list(self._dotprop_keys())[:2]
+            self._ndim = len(self._query(idx1, idx2))
+            self._query.cache_clear()
+        return self._ndim
+
+    def with_digitizers(self, digitizers: List[Digitizer]):
+        """Specify the axes of the output lookup table directly.
+
+        Parameters
+        ----------
+        digitizers : List[Digitizer]
+
+        Returns
+        -------
+        self
+            For chaining convenience.
+        """
+        if len(digitizers) != self.ndim:
+            raise ValueError(
+                f"Match function returns {self.ndim} values "
+                f"but provided {len(digitizers)} digitizers"
+            )
+
+        self.digitizers = digitizers
+        self.bin_counts = None
+        return self
+
+    def with_bin_counts(self, bin_counts: List[int]):
+        """Specify the number of bins on each axis of the output lookup table.
+
+        The bin boundaries will be determined by evenly partitioning the data
+        from the matched pairs into quantiles, in each dimension.
+
+        Parameters
+        ----------
+        bin_counts : List[int]
+
+        Returns
+        -------
+        self
+            For chaining convenience.
+ """ + if len(bin_counts) != self.ndim: + raise ValueError( + f"Match function returns {self.ndim} values " + f"but provided {len(bin_counts)} bin counts" + ) + + self.bin_counts = bin_counts + self.digitizers = None + return self + + def _dotprop_keys(self) -> Sequence[DotpropKey]: + """Get all indices into dotprops instance member""" + try: + return self.dotprops.keys() + except AttributeError: + return range(len(self.dotprops)) + + @property + def nonmatching(self) -> List[DotpropKey]: + """Indices of nonmatching set of neurons""" + if self._nonmatching_list is None: + return list(self._dotprop_keys()) + return self._nonmatching_list + + def _yield_matching_pairs(self) -> Iterator[Tuple[DotpropKey, DotpropKey]]: + """Yield all index pairs within all matching pairs""" + for ms in self.matching_lists: + yield from yield_not_same(permutations(ms, 2)) + + def _yield_nonmatching_pairs(self) -> Iterator[Tuple[DotpropKey, DotpropKey]]: + """Yield all index pairs within nonmatching list""" + # todo: this could be much better, use meshgrid or shuffle index arrays + return yield_not_same(permutations(self.nonmatching, 2)) + + def _empty_counts(self) -> np.ndarray: + """Create an empty array in which to store counts; shape determined by digitizer sizes.""" + shape = [len(b) for b in self.digitizers] + return np.zeros(shape, int) + + @lru_cache(None) + def _query(self, q_idx, t_idx) -> List[np.ndarray]: + """Get the results of applying the match function to dotprops specified by indices""" + return self.match_fn(self.dotprops[q_idx], self.dotprops[t_idx]) + + def _query_many(self, idx_pairs, threads=None) -> Iterator[List[np.ndarray]]: + """Yield results from querying many pairs of dotprop indices""" + if threads is None or threads == 0 and cpu_count == 1: + for q_idx, t_idx in idx_pairs: + yield self._query(q_idx, t_idx) + return + + threads = threads or cpu_count + idx_pairs = np.asarray(idx_pairs) + chunks = chunksize(len(idx_pairs), threads) + + with ProcessPoolExecutor(threads) as exe: + yield from exe.map( + self._query, idx_pairs[:, 0], idx_pairs[:, 1], chunksize=chunks + ) + + def _query_to_idxs(self, q_idx, t_idx, counts=None): + """Produce a digitized counts array from a given query-target pair""" + return self._count_results(self._query(q_idx, t_idx), counts) + + def _count_results(self, results: List[np.ndarray], counts=None): + """Convert raw match function ouput into a digitized counts array. + + Requires digitizers. + """ + idxs = [dig(r) for dig, r in zip(self.digitizers, results)] + + if counts is None: + counts = self._empty_counts() + + for idx in zip(*idxs): + counts[idx] += 1 + + return counts + + def _counts_array(self, idx_pairs, threads=None, cache=False): + """Convert index pairs into a digitized counts array. + + Requires digitizers. 
+ """ + counts = self._empty_counts() + if threads is None or threads == 0 and cpu_count == 1: + for q_idx, t_idx in idx_pairs: + counts = self._query_to_idxs(q_idx, t_idx, counts) + return counts + + threads = threads or cpu_count + idx_pairs = np.asarray(idx_pairs, dtype=int) + chunks = chunksize(len(idx_pairs), threads) + + # because digitizing is not necessarily free, + # keep this parallelisation separate to that in _query_many + with ProcessPoolExecutor(threads) as exe: + for these_counts in exe.map( + self._query_to_idxs, + idx_pairs[:, 0], + idx_pairs[:, 1], + chunksize=chunks, + ): + counts += these_counts + if not cache: + self._query.cache_clear + + return counts + + def _pick_nonmatching_pairs(self, n_matching_qual_vals): + """Using the seeded RNG, pick which nonmatching pairs to use.""" + # pre-calculating which pairs we're going to use, + # rather than drawing them as we need them, + # means that we can parallelise the later step more effectively. + # Slowdowns here are practically meaningless + # because of how long distdot calculation will take + all_nonmatching_pairs = list(self._yield_nonmatching_pairs()) + nonmatching_pairs = [] + n_nonmatching_qual_vals = 0 + rng = np.random.default_rng(self.seed) + while n_nonmatching_qual_vals < n_matching_qual_vals: + idx = rng.integers(0, len(all_nonmatching_pairs)) + nonmatching_pair = all_nonmatching_pairs.pop(idx) + nonmatching_pairs.append(nonmatching_pair) + n_nonmatching_qual_vals += len(self.dotprops[nonmatching_pair[0]]) + + return nonmatching_pairs + + def _get_pairs(self): + matching_pairs = list(set(self._yield_matching_pairs())) + # need to know the eventual distdot count + # so we know how many non-matching pairs to draw + q_idx_count = Counter(p[0] for p in matching_pairs) + n_matching_qual_vals = sum( + len(self.dotprops[q_idx]) * n_reps for q_idx, n_reps in q_idx_count.items() + ) + + nonmatching_pairs = self._pick_nonmatching_pairs(n_matching_qual_vals) + return matching_pairs, nonmatching_pairs + + def _build(self, threads, cache=False) -> Tuple[List[Digitizer], np.ndarray]: + if self.digitizers is None and self.bin_counts is None: + raise ValueError( + "Builder needs either digitizers or bin_counts; see with_* methods" + ) + + matching_pairs, nonmatching_pairs = self._get_pairs() + + if self.digitizers: + match_counts = self._counts_array(matching_pairs, threads, cache) + else: + match_results = concat_results(self._query_many(matching_pairs, threads)) + self.digitizers = [ + Digitizer.from_data(data, nbins) + for data, nbins in zip(match_results, self.bin_counts) + ] + if not cache: + self._query.cache_clear() + match_counts = self._count_results(match_results) + + nonmatch_counts = self._counts_array(nonmatching_pairs, threads) + + # account for there being different total numbers of match/nonmatch dist dots + matching_factor = nonmatch_counts.sum() / match_counts.sum() + if np.any(match_counts + nonmatch_counts == 0): + logger.warning("Some lookup cells have no data in them") + + cells = np.log2( + (match_counts * matching_factor + epsilon) / (nonmatch_counts + epsilon) + ) + + return self.digitizers, cells + + def build(self, threads=None, cache=False) -> LookupNd: + """Build the score matrix. + + All non-identical neuron pairs within all matching sets are selected, + and distdots calculated for those pairs. + Then, the minimum number of non-matching pairs are randomly drawn + so that at least as many distdots can be calculated for non-matching + pairs. 
+
+        In each bin of the score matrix, the log2 odds ratio of a distdot
+        in that bin belonging to a match vs. non-match is calculated.
+
+        Parameters
+        ----------
+        threads : int, optional
+            If None, act in serial.
+            If 0, use cpu_count - 1.
+            Otherwise, use the given value.
+        cache : bool, optional
+            If True, keep the cached pairwise match function results
+            after building (uses more memory); by default False.
+
+        Returns
+        -------
+        LookupNd
+        """
+        dig, cells = self._build(threads, cache)
+        return LookupNd(dig, cells)
+
+
+def dist_dot(q: Dotprops, t: Dotprops):
+    return list(q.dist_dots(t))
+
+
+def dist_dot_alpha(q: Dotprops, t: Dotprops):
+    dist, dot, alpha = q.dist_dots(t, alpha=True)
+    return [dist, dot * np.sqrt(alpha)]
+
+
+class LookupDistDotBuilder(LookupNdBuilder):
+    def __init__(
+        self,
+        dotprops: Union[List[Dotprops], Mapping[DotpropKey, Dotprops]],
+        matching_lists: List[List[DotpropKey]],
+        nonmatching_list: Optional[List[DotpropKey]] = None,
+        use_alpha: bool = False,
+        seed: int = DEFAULT_SEED,
+    ):
+        f"""Class for building a 2-dimensional score lookup for NBLAST.
+
+        The scores are
+
+        1. The distances between best-matching points
+        2. The dot products of direction vectors around those points,
+           optionally scaled by the colinearity ``alpha``.
+
+        Parameters
+        ----------
+        dotprops : dict or list of Dotprops
+            An indexable sequence of all neurons which will be used as the training set,
+            as Dotprops objects.
+        matching_lists : list of lists of indices into dotprops
+            Lists of neurons, as indices into ``dotprops``, which should be considered matches.
+        nonmatching_list : list of indices into dotprops, optional
+            List of neurons, as indices into ``dotprops``,
+            which should not be considered matches.
+            If not given, all ``dotprops`` will be used
+            (on the assumption that matches are a small subset of possible pairs).
+        use_alpha : bool, optional
+            If true, multiply the dot product by the geometric mean
+            of the matched points' alpha values
+            (i.e. ``sqrt(alpha1 * alpha2)``).
+        seed : int, optional
+            Non-matching pairs are drawn at random using this seed,
+            by default {DEFAULT_SEED}
+        """
+        match_fn = dist_dot_alpha if use_alpha else dist_dot
+        super().__init__(
+            dotprops,
+            matching_lists,
+            match_fn,
+            nonmatching_list,
+            seed,
+        )
+        self._ndim = 2
+
+    def build(self, threads=None, cache=False) -> Lookup2d:
+        (dig0, dig1), cells = self._build(threads, cache)
+        return Lookup2d(dig0, dig1, cells)
+
+
+def is_monotonically_increasing(lst):
+    for prev_idx, item in enumerate(lst[1:]):
+        if item <= lst[prev_idx]:
+            return False
+    return True
+
+
+def parse_boundary(item: str):
+    explicit_interval = item[0] + item[-1]
+    if explicit_interval == "[)":
+        right = False
+    elif explicit_interval == "(]":
+        right = True
+    else:
+        raise ValueError(
+            f"Enclosing characters '{explicit_interval}' do not match a half-open interval"
+        )
+    return tuple(float(i) for i in item[1:-1].split(",")), right
+
+
+T = TypeVar("T")
+
+
+class LookupAxis(ABC, Generic[T]):
+    """Class converting some data into a linear index."""
+
+    @abstractmethod
+    def __len__(self) -> int:
+        """Number of bins represented by this instance."""
+        pass
+
+    @abstractmethod
+    def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]:
+        """Convert some data into a linear index.
+
+        Parameters
+        ----------
+        value : Union[T, Sequence[T]]
+            Value to convert into an index
+
+        Returns
+        -------
+        Union[int, Sequence[int]]
+            If a scalar was given, return a scalar; otherwise, a numpy array of ints.
+ """ + pass + + +class SimpleLookup(LookupAxis[Hashable]): + def __init__(self, items: List[Hashable]): + """Look up in a list of items and return their index. + + Parameters + ---------- + items : List[Hashable] + The item's position in the list is the index which will be returned. + + Raises + ------ + ValueError + items are non-unique. + """ + self.items = {item: idx for idx, item in enumerate(items)} + if len(self.items) != len(items): + raise ValueError("Items are not unique") + + def __len__(self) -> int: + return len(self.items) + + def __call__(self, value: Union[Hashable, Sequence[Hashable]]) -> Union[int, Sequence[int]]: + if np.isscalar(value): + return self.items[value] + else: + return np.array([self.items[v] for v in value], int) + + +class Digitizer(LookupAxis[float]): + def __init__( + self, + boundaries: Sequence[float], + clip: Tuple[bool, bool] = (True, True), + right=False, + ): + """Class converting continuous values into discrete indices. + + Parameters + ---------- + boundaries : Sequence[float] + N boundaries specifying N-1 bins. + Must be monotonically increasing. + clip : Tuple[bool, bool], optional + Whether to set the bottom and top boundaries to -infinity and infinity respectively, + effectively clipping incoming values: by default (True, True). + False means "add a new bin for out-of-range values". + right : bool, optional + Whether bins should include their right (rather than left) boundary, + by default False + """ + self.right = right + + boundaries = list(boundaries) + self._min = -math.inf + if clip[0]: + self._min = boundaries[0] + boundaries[0] = -math.inf + elif boundaries[0] != -math.inf: + self._min = -math.inf + boundaries.insert(0, -math.inf) + + self._max = math.inf + if clip[1]: + self._max = boundaries[-1] + boundaries[-1] = math.inf + elif boundaries[-1] != math.inf: + boundaries.append(math.inf) + + if not is_monotonically_increasing(boundaries): + raise ValueError("Boundaries are not monotonically increasing") + + self.boundaries = np.asarray(boundaries) + + def __len__(self): + return len(self.boundaries) - 1 + + def __call__(self, value: float): + # searchsorted is marginally faster than digitize as it skips monotonicity checks + return ( + np.searchsorted( + self.boundaries, value, side="left" if self.right else "right" + ) + - 1 + ) + + def to_strings(self) -> List[str]: + if self.right: + lb = "(" + rb = "]" + else: + lb = "[" + rb = ")" + + b = self.boundaries.copy() + b[0] = self._min + b[-1] = self._max + return [ + f"{lb}{lower},{upper}{rb}" + for lower, upper in zip(b[:-1], b[1:]) + ] + + @classmethod + def from_strings(cls, interval_strs: Sequence[str]): + """Set digitizer boundaries based on a sequence of interval expressions. + + e.g. ``["(0, 1]", "(1, 5]", "(5, 10]"]`` + + The lowermost and uppermost boundaries are converted to -infinity and infinity respectively. + + Parameters + ---------- + bound_strs : Sequence[str] + Strings representing intervals, which must abut and have open/closed boundaries + specified by brackets. 
+
+        Returns
+        -------
+        Digitizer
+        """
+        bounds: List[float] = []
+        last_upper = None
+        last_right = None
+        for item in interval_strs:
+            (lower, upper), right = parse_boundary(item)
+            bounds.append(float(lower))
+
+            if last_right is not None:
+                if right != last_right:
+                    raise ValueError("Inconsistent half-open interval")
+            else:
+                last_right = right
+
+            if last_upper is not None:
+                if lower != last_upper:
+                    raise ValueError("Half-open intervals do not abut")
+
+            last_upper = upper
+
+        bounds.append(float(last_upper))
+        return cls(bounds, right=last_right)
+
+    @classmethod
+    def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
+        """Choose digitizer boundaries spaced linearly between two values.
+
+        Input values will be clipped to fit within the given interval.
+
+        Parameters
+        ----------
+        lower : float
+            Lowest value
+        upper : float
+            Highest value
+        nbins : int
+            Number of bins
+        right : bool, optional
+            Whether bins should include their right (rather than left) boundary,
+            by default False
+
+        Returns
+        -------
+        Digitizer
+        """
+        arr = np.linspace(lower, upper, nbins + 1, endpoint=True)
+        return cls(arr, right=right)
+
+    @classmethod
+    def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False):
+        """Choose digitizer boundaries in a geometric sequence.
+
+        Additional bins will be added above and below the given values.
+
+        Parameters
+        ----------
+        lowest_upper : float
+            Upper bound of the lowest bin.
+            The lower bound of the lowest bin is often 0, which cannot be represented in a nontrivial geometric sequence.
+        highest_lower : float
+            Lower bound of the highest bin.
+        nbins : int
+            Number of bins
+        right : bool, optional
+            Whether bins should include their right (rather than left) boundary,
+            by default False
+
+        Returns
+        -------
+        Digitizer
+        """
+        arr = np.geomspace(lowest_upper, highest_lower, nbins - 1, True)
+        return cls(arr, clip=(False, False), right=right)
+
+    @classmethod
+    def from_data(cls, data: Sequence[float], nbins: int, right=False):
+        """Choose digitizer boundaries to evenly partition the given values.
+
+        Parameters
+        ----------
+        data : Sequence[float]
+            Data which should be evenly partitioned by the resulting digitizer.
+        nbins : int
+            Number of bins
+        right : bool, optional
+            Whether bins should include their right (rather than left) boundary,
+            by default False
+
+        Returns
+        -------
+        Digitizer
+        """
+        arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True))
+        return cls(arr, right=right)
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, Digitizer):
+            return NotImplemented
+        return self.right == other.right and np.allclose(
+            self.boundaries, other.boundaries
+        )
+
+
+class LookupNd:
+    def __init__(self, axes: List[LookupAxis], cells: np.ndarray):
+        if [len(b) for b in axes] != list(cells.shape):
+            raise ValueError("axes and cells have inconsistent bin counts")
+        self.axes = axes
+        self.cells = cells
+
+    def __call__(self, *args):
+        if len(args) != len(self.axes):
+            raise TypeError(
+                f"Lookup takes {len(self.axes)} arguments but {len(args)} were given"
+            )
+
+        idxs = tuple(d(arg) for d, arg in zip(self.axes, args))
+        out = self.cells[idxs]
+        return out
+
+
+class Lookup2d(LookupNd):
+    """Convenience class inheriting from LookupNd for the common 2D float case.
+    Provides IO with pandas DataFrames.
+    """
+
+    def __init__(self, axis0: Digitizer, axis1: Digitizer, cells: np.ndarray):
+        """2D lookup table for converting NBLAST matches to scores.
+
+        Commonly read from a ``pandas.DataFrame``
+        or trained on data using a ``LookupDistDotBuilder``.
+
+        Parameters
+        ----------
+        axis0 : Digitizer
+            How to convert continuous values into an index for the first axis.
+        axis1 : Digitizer
+            How to convert continuous values into an index for the second axis.
+        cells : np.ndarray
+            Values to look up in the table.
+        """
+        super().__init__([axis0, axis1], cells)
+
+    def to_dataframe(self) -> pd.DataFrame:
+        """Convert the lookup table into a ``pandas.DataFrame``.
+
+        From there, it can be shared, saved, and so on.
+
+        The index and column labels describe the intervals represented by that axis.
+
+        Returns
+        -------
+        pd.DataFrame
+        """
+        return pd.DataFrame(
+            self.cells,
+            self.axes[0].to_strings(),
+            self.axes[1].to_strings(),
+        )
+
+    @classmethod
+    def from_dataframe(cls, df: pd.DataFrame):
+        f"""Parse score matrix from a dataframe with string index and column labels.
+
+        Expects the index and column labels to specify an interval
+        like ``f"[{{lower}},{{upper}})"``.
+        Will replace the lowermost and uppermost bound with -inf and inf
+        if they are not already.
+        """
+        return cls(
+            Digitizer.from_strings(df.index),
+            Digitizer.from_strings(df.columns),
+            df.to_numpy(),
+        )
+
+
+@lru_cache(maxsize=None)
+def _smat_fcwb(alpha=False):
+    # cached private function defers construction
+    # until needed (speeding startup),
+    # but avoids repeated reads (speeding later uses)
+    fname = ("smat_fcwb.csv", "smat_alpha_fcwb.csv")[alpha]
+    fpath = smat_path / fname
+
+    return Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0))
+
+
+def smat_fcwb(alpha=False):
+    # deepcopied so that mutations do not propagate to cache
+    return deepcopy(_smat_fcwb(alpha))
+
+
+def check_score_fn(fn: Callable, nargs=2, scalar=True, array=True):
+    """Checks functionally that the callable can be used as a score function.
+
+    Parameters
+    ----------
+    fn : Callable
+        The candidate score function.
+    nargs : optional int, default 2
+        How many positional arguments the score function should have.
+    scalar : optional bool, default True
+        Check that the function can be used on ``nargs`` scalars.
+    array : optional bool, default True
+        Check that the function can be used on ``nargs`` 1D ``numpy.ndarray``s.
+
+    Raises
+    ------
+    ValueError
+        If the score function is not appropriate.
+    """
+    if scalar:
+        scalars = [0.5] * nargs
+        if not isinstance(fn(*scalars), float):
+            raise ValueError("smat does not take 2 floats and return a float")
+
+    if array:
+        test_arr = np.array([0.5] * 3)
+        arrs = [test_arr] * nargs
+        try:
+            out = fn(*arrs)
+        except Exception as e:
+            raise ValueError(f"Failed to use smat with numpy arrays: {e}")
+
+        if out.shape != test_arr.shape:
+            raise ValueError(
+                f"smat produced inconsistent shape: input {test_arr.shape}; output {out.shape}"
+            )
+
+
+SCORE_FN_DESCR = """
+NBLAST score functions take 2 floats or N-length numpy arrays of floats
+(for matched dotprop points/tangents, distance and dot product;
+the latter possibly scaled by the geometric mean of the alpha colinearity values)
+and return a float or N-length numpy array of floats.
+""".strip().replace(
+    "\n", " "
+)
+
+
+def parse_score_fn(smat, alpha=False):
+    f"""Interpret ``smat`` as a score function.
+
+    Primarily for backwards compatibility.
+
+    {SCORE_FN_DESCR}
+
+    Parameters
+    ----------
+    smat : None | "auto" | str | os.PathLike | pandas.DataFrame | Callable[[float, float], float]
+        If ``None``, use ``operator.mul``.
+        If ``"auto"``, use ``navis.nbl.smat.smat_fcwb(alpha)``.
+        If a dataframe, use ``navis.nbl.smat.Lookup2d.from_dataframe(smat)``.
+        If another string or path-like, load a dataframe from that CSV and use it as above.
+        If a callable, check its signature and use it as-is.
+        Raises an error, probably a ValueError, if it can't be interpreted.
+    alpha : optional bool, default False
+        If ``smat`` is ``"auto"``, choose whether to use the FCWB matrices
+        with or without alpha.
+
+    Returns
+    -------
+    Callable
+
+    Raises
+    ------
+    ValueError
+        If score function cannot be interpreted.
+    """
+    if smat is None:
+        smat = operator.mul
+    elif smat == "auto":
+        smat = smat_fcwb(alpha)
+
+    if isinstance(smat, (str, os.PathLike)):
+        smat = pd.read_csv(smat, index_col=0)
+
+    if isinstance(smat, pd.DataFrame):
+        smat = Lookup2d.from_dataframe(smat)
+
+    if not callable(smat):
+        raise ValueError(
+            "smat should be a callable, a path, a pandas.DataFrame, or 'auto'"
+        )
+
+    check_score_fn(smat)
+
+    return smat
diff --git a/navis/nbl/synblast_funcs.py b/navis/nbl/synblast_funcs.py
index c0b7812d..433039ea 100644
--- a/navis/nbl/synblast_funcs.py
+++ b/navis/nbl/synblast_funcs.py
@@ -15,6 +15,7 @@
 """Module contains functions implementing SyNBLAST."""

 import os
+import operator

 import numpy as np
 import pandas as pd
@@ -27,10 +28,11 @@

 from .. import config, utils
 from ..core import NeuronList, BaseNeuron
-from .base import Blaster
+from .base import Blaster, NestedIndices
+from .smat import Lookup2d

-from .nblast_funcs import (check_microns, find_optimal_partition, ScoringFunction,
-                           nblast_preflight)
+from .nblast_funcs import (check_microns, find_optimal_partition,
+                           nblast_preflight, smat_fcwb)

 __all__ = ['synblast']
@@ -59,10 +61,14 @@ class SynBlaster(Blaster):
     by_type :       bool
                     If True will only compare synapses with the same value in
                     the "type" column.
-    smat :          str | pd.DataFrame
-                    Score matrix. If 'auto' (default), will use scoring matrices
+    smat :          navis.nbl.smat.Lookup2d | pd.DataFrame | str
+                    How to convert the point match pairs into an NBLAST score,
+                    usually by a lookup table.
+                    If 'auto' (default), will use scoring matrices
                     from FCWB. Same behaviour as in R's nat.nblast
-                    implementation. If ``smat=None`` the scores will be
+                    implementation.
+                    Dataframes will be used to build a ``Lookup2d``.
+                    If ``smat=None`` the scores will be
                     generated as the product of the distances and the
                     dotproduct of the vectors of nearest-neighbor pairs.
progress : bool @@ -77,15 +83,19 @@ def __init__(self, normalized=True, by_type=True, self.normalized = normalized self.by_type = by_type - if smat == 'auto': - smat = pd.read_csv(f'{smat_path}/smat_fcwb.csv', - index_col=0) + if smat is None: + self.score_fn = operator.mul + elif smat == 'auto': + self.score_fn = smat_fcwb() + elif isinstance(smat, pd.DataFrame): + self.score_fn = Lookup2d.from_dataframe(smat) + else: + self.score_fn = smat - self.score_fn = ScoringFunction(smat) self.ids = [] - def append(self, neuron, id=None): - """Append neurons/connector tables, returning ids of added objects""" + def append(self, neuron, id=None) -> NestedIndices: + """Append neurons/connector tables, returning numerical indices of added objects""" if isinstance(neuron, pd.DataFrame): return self._append_connectors(neuron, id) @@ -102,10 +112,11 @@ def append(self, neuron, id=None): f"{type(neuron)}" ) - def _append_connectors(self, connectors: pd.DataFrame, id): + def _append_connectors(self, connectors: pd.DataFrame, id) -> int: if id is None: raise ValueError("Explicit non-None id required for appending connectors") + next_idx = len(self) self.ids.append(id) self.neurons.append({}) if not self.by_type: @@ -123,13 +134,13 @@ def _append_connectors(self, connectors: pd.DataFrame, id): # Calculate score for self hit self.self_hits.append(self.calc_self_hit(connectors)) - return id + return next_idx def calc_self_hit(self, cn): """Non-normalized value for self hit.""" return cn.shape[0] * self.score_fn(0, 1) - def single_query_target(self, q_idx, t_idx, scores='forward'): + def single_query_target(self, q_idx: int, t_idx: int, scores='forward'): """Query single target against single target.""" # Take a short-cut if this is a self-self comparison if q_idx == t_idx: diff --git a/tests/conftest.py b/tests/conftest.py index a35954d7..ec872834 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from pathlib import Path import os +from typing import List import pandas as pd import pytest @@ -9,7 +10,7 @@ import navis -@pytest.fixture +@pytest.fixture(scope="session") def data_dir(): return Path(__file__).resolve().parent.parent / "navis" / "data" @@ -17,8 +18,8 @@ def data_dir(): @pytest.fixture( params=["Path", "pathstr", "swcstr", "textbuffer", "rawbuffer", "DataFrame"] ) -def swc_source(request, data_dir: Path): - swc_path: Path = data_dir / "swc" / "722817260.swc" +def swc_source(request, swc_paths: List[Path]): + swc_path: Path = swc_paths[0] if request.param == "Path": yield swc_path elif request.param == "pathstr": @@ -42,9 +43,9 @@ def swc_source(request, data_dir: Path): @pytest.fixture( params=["dirstr", "dirpath", "list", "listwithdir"], ) -def swc_source_multi(request, data_dir: Path): - dpath = data_dir / "swc" - fpath = dpath / "722817260.swc" +def swc_source_multi(request, swc_paths: List[Path]): + fpath = swc_paths[0] + dpath = fpath.parent if request.param == "dirstr": yield str(dpath) elif request.param == "dirpath": @@ -74,3 +75,43 @@ def voxel_nrrd_path(tmp_path): nrrd.write(os.fspath(path), data, header) return path + + +def data_paths(dpath, glob="*"): + return sorted(dpath.glob(glob)) + + +@pytest.fixture(scope="session") +def swc_paths(data_dir: Path): + return data_paths(data_dir / "swc", "*.swc") + + +@pytest.fixture(scope="session") +def gml_paths(data_dir: Path): + return data_paths(data_dir / "gml", "*.gml") + + +@pytest.fixture(scope="session") +def obj_paths(data_dir: Path): + return data_paths(data_dir / "obj", "*.obj") + + 
+@pytest.fixture(scope="session") +def synapses_paths(data_dir: Path): + return data_paths(data_dir / "synapses", "*.csv") + + +@pytest.fixture(scope="session") +def volumes_paths(data_dir: Path): + return data_paths(data_dir / "volumes", "*.obj") + + +@pytest.fixture +def treeneuron_dfs(swc_paths, synapses_paths): + swc_reader = navis.io.swc_io.SwcReader() + out = [] + for swc_path, syn_path in zip(swc_paths, synapses_paths): + neuron = swc_reader.read_file_path(swc_path) + neuron.connectors = pd.read_csv(syn_path) + out.append(neuron) + return out diff --git a/tests/test_nbl/__init__.py b/tests/test_nbl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_nbl/test_smat.py b/tests/test_nbl/test_smat.py new file mode 100644 index 00000000..04d2d4f6 --- /dev/null +++ b/tests/test_nbl/test_smat.py @@ -0,0 +1,141 @@ +from navis import Dotprops +import pytest + +import numpy as np + +from navis.nbl.smat import ( + Digitizer, LookupNd, Lookup2d, LookupDistDotBuilder +) + + +SMALLEST_DIM_SIZE = 3 +SEED = 1991 + + +def lookup_args(ndim): + f""" + Create arguments for an ND lookup table. + The first dimension is of size {SMALLEST_DIM_SIZE}, + and subsequent dimensions are 1 longer than the previous. + The data in the cells are the sequence from 0 + to the size of the array. + The boundaries are 0 to the length of the dimension, + with the left and rightmost values replaced with -inf and inf respectively. + + Examples + -------- + >>> lookup_args(2) + ( + [ + array([-inf, 1, 2, inf]), + array([-inf, 1, 2, 3, inf]), + ], + array([ + [ 0, 1, 2, 3 ], + [ 4, 5, 6, 7 ], + [ 8, 9, 10, 11 ], + ]), + ) + """ + shape = tuple(range(SMALLEST_DIM_SIZE, SMALLEST_DIM_SIZE + ndim)) + cells = np.arange(np.product(shape)).reshape(shape) + digitizers = [Digitizer(np.arange(s + 1, dtype=float)) for s in shape] + return digitizers, cells + + +def fmt_array(arg): + if np.isscalar(arg): + return str(arg) + else: + return "[" + ",".join(fmt_array(v) for v in arg) + "]" + + +@pytest.mark.parametrize( + ["ndim"], [[1], [2], [3], [4], [5]], ids=lambda x: f"{x}D" +) +@pytest.mark.parametrize( + ["arg"], + [ + (-1000,), + (0,), + (1,), + (1.5,), + (2,), + (1000,), + ([-1000, 0, 1, 1.5, 2, 1000],), + ], + ids=fmt_array, +) +def test_lookupNd(ndim, arg): + lookup = LookupNd(*lookup_args(ndim)) + + args = [arg for _ in range(ndim)] + expected_arr_idx = np.floor([ + np.clip(arg, 0, dim + SMALLEST_DIM_SIZE - 1) for dim in range(ndim) + ]).astype(int) + expected_val = np.ravel_multi_index( + tuple(expected_arr_idx), lookup.cells.shape + ) + + response = lookup(*args) + assert np.all(response == expected_val) + + +def test_lookup2d_roundtrip(): + digs, cells = lookup_args(2) + lookup = Lookup2d(*digs, cells=cells) + df = lookup.to_dataframe() + lookup2 = Lookup2d.from_dataframe(df) + assert np.allclose(lookup.cells, lookup2.cells) + for b1, b2 in zip(lookup.axes, lookup2.axes): + assert b1 == b2 + + +def prepare_lookupdistdotbuilder(neurons, alpha=False, k=5): + k = 5 + dotprops = [Dotprops(n.nodes[["x", "y", "z"]], k) for n in neurons] + n_orig = len(dotprops) + + # make jittered copies of these neurons + rng = np.random.default_rng(SEED) + jitter_sigma = 50 + matching_lists = [] + for idx, dp in enumerate(dotprops[:]): + dotprops.append( + Dotprops( + dp.points + rng.normal(0, jitter_sigma, dp.points.shape), k + ) + ) + # assign each neuron its jittered self as a match + matching_lists.append([idx, idx + n_orig]) + + # original neurons should all not match each other + nonmatching = list(range(n_orig)) 
+ + # max distance between any 2 points in the data + # for calculating dist boundaries + max_dist = np.linalg.norm( + np.ptp( + np.concatenate([dp.points for dp in dotprops], axis=0), axis=0, + ) + ) + + return LookupDistDotBuilder( + dotprops, + matching_lists, + nonmatching, + alpha, + seed=SEED + 1, + ).with_digitizers([ + Digitizer.from_geom(10, max_dist, 5), + Digitizer.from_linear(0, 1, 5), + ]) + + +@pytest.mark.parametrize(["alpha"], [(True,), (False,)]) +@pytest.mark.parametrize(["threads"], [(0,), (2,), (None,)]) +def test_lookupdistdotbuilder_builds(treeneuron_dfs, threads, alpha): + builder = prepare_lookupdistdotbuilder(treeneuron_dfs, alpha) + lookup = builder.build(threads) + # `pytest -rP` to see output + print(lookup.to_dataframe())
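[Editorial sketch, not part of this PR's test suite: the ``Digitizer`` string serialisation that backs ``Lookup2d.to_dataframe``/``from_dataframe`` lends itself to a round-trip check in the same style as ``test_lookup2d_roundtrip`` above. Only names already imported in this test module are used.]

    def test_digitizer_string_roundtrip():
        # Bin boundaries should survive a to_strings()/from_strings() round trip;
        # the outermost bounds are clipped to +/-inf internally either way.
        dig = Digitizer([0.0, 1.0, 5.0, 10.0])
        assert Digitizer.from_strings(dig.to_strings()) == dig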