diff --git a/.gitignore b/.gitignore
index da2d03ca..d35c0bea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,5 @@ ENV/
#*.swc
#*.json
MANIFEST
+
+tmp/
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 4161dfbe..d82b4cc6 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -364,6 +364,15 @@ Functions to compare morphology.
navis.synblast
navis.persistence_distances
+Utilities for creating your own score matrices for NBLAST can be found in ``navis.nbl.smat``:
+
+.. autosummary::
+ :toctree: generated/
+
+ navis.nbl.smat.Lookup2d
+ navis.nbl.smat.Digitizer
+ navis.nbl.smat.LookupDistDotBuilder
+
Polarity metrics
----------------
.. autosummary::
diff --git a/docs/source/gallery.rst b/docs/source/gallery.rst
index 62a61f4b..4da347ad 100644
--- a/docs/source/gallery.rst
+++ b/docs/source/gallery.rst
@@ -72,6 +72,7 @@ function. There you will also find more examples.
tutorials/nblast
tutorials/nblast_flycircuit
tutorials/nblast_hemibrain
+ tutorials/smat
.. raw:: html
diff --git a/docs/source/tutorials/smat.ipynb b/docs/source/tutorials/smat.ipynb
new file mode 100644
index 00000000..45672bb2
--- /dev/null
+++ b/docs/source/tutorials/smat.ipynb
@@ -0,0 +1,491 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ".. _smat_intro:\n",
+ "\n",
+ "NBLAST score matrix generation\n",
+ "******************************\n",
+ "\n",
+ "NBLAST calculates the similarity between neurons' morphology.\n",
+ "For more information on NBLAST in ``navis``, see the other tutorials.\n",
+ "\n",
+ "The core of the algorithm is the function which converts point matches\n",
+ "(a distance and a dot product) into a score\n",
+ "expressing how likely they are to have come from the same point cloud.\n",
+ "This is generally a 2D lookup table, referred to as the score matrix,\n",
+ "generated by \"training\" it on neurons known to be matching or non-matching.\n",
+ "\n",
+ "``navis`` provides (and uses by default) the score matrix used in the original publication (Costa et al., 2016).\n",
+ "This works quite well in many cases.\n",
+ "However, how appropriate it is for your data depends on a number of factors:\n",
+ "\n",
+ "* How big your neurons are (commonly addressed by scaling the distance axis of the built-in score matrix)\n",
+ "* How you have pre-processed your neurons (pruning dendrites, resampling etc.)\n",
+ "* Your actual task (matching left-right pairs, finding lineages etc.)\n",
+ "* How distinct you expect your matches and non-matches to be (e.g. how large a body volume you're drawing neurons from)\n",
+ "\n",
+ "Utilities in ``navis.nbl.smat`` allow you to train your own score matrix."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import navis\n",
+ "from navis.nbl.smat import Lookup2d, LookupDistDotBuilder, Digitizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "``Lookup2d`` is the lookup table, and can be given to any nblast-related class or function.\n",
+ "Each axis represented by a ``Digitizer``, which converts continuous values into discrete indices.\n",
+ "These indices are used to look up score values in an array.\n",
+ "\n",
+ "The ``LookupDistDotBuilder`` is a class which generates ``Lookup2d`` instances from training data.\n",
+ "\n",
+ "First, we need some training data.\n",
+ "Let's augment our set of example neurons by randomly mutating them:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "# use a replicable random number generator\n",
+ "RNG = np.random.default_rng(2021)\n",
+ "\n",
+ "\n",
+ "def augment_neuron(\n",
+ " nrn: navis.TreeNeuron, scale_sigma=0.1, translation_sigma=50, jitter_sigma=10\n",
+ "):\n",
+ " nrn = nrn.copy(deepcopy=True)\n",
+ " nrn.name += \"_aug\"\n",
+ " dims = list(\"xyz\")\n",
+ " coords = nrn.nodes[dims].to_numpy()\n",
+ "\n",
+ " # translate whole neuron\n",
+ " coords += RNG.normal(scale=translation_sigma, size=coords.shape[-1])\n",
+ " # jitter individual coordinates\n",
+ " coords += RNG.normal(scale=jitter_sigma, size=coords.shape)\n",
+ " # rescale\n",
+ " mean = np.mean(coords, axis=0)\n",
+ " coords -= mean\n",
+ " coords *= RNG.normal(loc=1.0, scale=scale_sigma)\n",
+ " coords += mean\n",
+ "\n",
+ " nrn.nodes[dims] = coords\n",
+ " return nrn\n",
+ "\n",
+ "\n",
+ "original = list(navis.example_neurons())\n",
+ "jittered = [augment_neuron(n) for n in original]\n",
+ "\n",
+ "dotprops = [navis.make_dotprops(n, k=5, resample=False) for n in original + jittered]\n",
+ "matching_pairs = [[idx, idx + len(original)] for idx in range(len(original))]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The score matrix builder needs some neurons as a list of ``navis.Dotprops`` objects,\n",
+ "and then to know which neurons should match with each other as indices into that list.\n",
+ "It's assumed that matches are relatively rare among the total set of possible pairings,\n",
+ "so non-matching pairs are drawn randomly (although a non-matching list can be given explicitly).\n",
+ "\n",
+ "Then it needs to know where to draw the boundaries between bins in the output lookup table.\n",
+ "These can be given explicitly as a list of 2 ``Digitizer`` s,\n",
+ "or can be inferred from the data: bins will be drawn to evenly partition the matching neuron scores.\n",
+ "\n",
+ "The resulting ``Lookup2d`` can be imported/exported as a ``pandas.DataFrame`` for ease of viewing and storing."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " [0.0,0.14588644700379452) | \n",
+ " [0.14588644700379452,0.29912252023057806) | \n",
+ " [0.29912252023057806,0.4802508950780032) | \n",
+ " [0.4802508950780032,0.7351365037506351) | \n",
+ " [0.7351365037506351,0.9988406500803395) | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " [2.0140850483155233,57.8493661406481) | \n",
+ " 0.486661 | \n",
+ " 0.379042 | \n",
+ " 0.387260 | \n",
+ " 0.445648 | \n",
+ " 0.855775 | \n",
+ "
\n",
+ " \n",
+ " [57.8493661406481,81.31283353333698) | \n",
+ " 0.404766 | \n",
+ " 0.343702 | \n",
+ " 0.406445 | \n",
+ " 0.417802 | \n",
+ " 0.799595 | \n",
+ "
\n",
+ " \n",
+ " [81.31283353333698,104.08576202392578) | \n",
+ " 0.257971 | \n",
+ " 0.311827 | \n",
+ " 0.217626 | \n",
+ " 0.248743 | \n",
+ " 1.013926 | \n",
+ "
\n",
+ " \n",
+ " [104.08576202392578,128.14104591262304) | \n",
+ " 0.255728 | \n",
+ " 0.089663 | \n",
+ " 0.171599 | \n",
+ " 0.115997 | \n",
+ " 1.296510 | \n",
+ "
\n",
+ " \n",
+ " [128.14104591262304,155.36119651794434) | \n",
+ " -0.136171 | \n",
+ " -0.107249 | \n",
+ " -0.125751 | \n",
+ " 0.252883 | \n",
+ " 1.987175 | \n",
+ "
\n",
+ " \n",
+ " [155.36119651794434,202.6728515625) | \n",
+ " -0.575078 | \n",
+ " -0.448307 | \n",
+ " -0.475147 | \n",
+ " -0.221016 | \n",
+ " 1.407061 | \n",
+ "
\n",
+ " \n",
+ " [202.6728515625,395.9569088293945) | \n",
+ " -1.025938 | \n",
+ " -0.948679 | \n",
+ " -0.863801 | \n",
+ " -0.620512 | \n",
+ " 0.054148 | \n",
+ "
\n",
+ " \n",
+ " [395.9569088293945,4709.61474609375) | \n",
+ " -0.615558 | \n",
+ " -0.737251 | \n",
+ " -0.679764 | \n",
+ " -0.454779 | \n",
+ " -0.197384 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " [0.0,0.14588644700379452) \\\n",
+ "[2.0140850483155233,57.8493661406481) 0.486661 \n",
+ "[57.8493661406481,81.31283353333698) 0.404766 \n",
+ "[81.31283353333698,104.08576202392578) 0.257971 \n",
+ "[104.08576202392578,128.14104591262304) 0.255728 \n",
+ "[128.14104591262304,155.36119651794434) -0.136171 \n",
+ "[155.36119651794434,202.6728515625) -0.575078 \n",
+ "[202.6728515625,395.9569088293945) -1.025938 \n",
+ "[395.9569088293945,4709.61474609375) -0.615558 \n",
+ "\n",
+ " [0.14588644700379452,0.29912252023057806) \\\n",
+ "[2.0140850483155233,57.8493661406481) 0.379042 \n",
+ "[57.8493661406481,81.31283353333698) 0.343702 \n",
+ "[81.31283353333698,104.08576202392578) 0.311827 \n",
+ "[104.08576202392578,128.14104591262304) 0.089663 \n",
+ "[128.14104591262304,155.36119651794434) -0.107249 \n",
+ "[155.36119651794434,202.6728515625) -0.448307 \n",
+ "[202.6728515625,395.9569088293945) -0.948679 \n",
+ "[395.9569088293945,4709.61474609375) -0.737251 \n",
+ "\n",
+ " [0.29912252023057806,0.4802508950780032) \\\n",
+ "[2.0140850483155233,57.8493661406481) 0.387260 \n",
+ "[57.8493661406481,81.31283353333698) 0.406445 \n",
+ "[81.31283353333698,104.08576202392578) 0.217626 \n",
+ "[104.08576202392578,128.14104591262304) 0.171599 \n",
+ "[128.14104591262304,155.36119651794434) -0.125751 \n",
+ "[155.36119651794434,202.6728515625) -0.475147 \n",
+ "[202.6728515625,395.9569088293945) -0.863801 \n",
+ "[395.9569088293945,4709.61474609375) -0.679764 \n",
+ "\n",
+ " [0.4802508950780032,0.7351365037506351) \\\n",
+ "[2.0140850483155233,57.8493661406481) 0.445648 \n",
+ "[57.8493661406481,81.31283353333698) 0.417802 \n",
+ "[81.31283353333698,104.08576202392578) 0.248743 \n",
+ "[104.08576202392578,128.14104591262304) 0.115997 \n",
+ "[128.14104591262304,155.36119651794434) 0.252883 \n",
+ "[155.36119651794434,202.6728515625) -0.221016 \n",
+ "[202.6728515625,395.9569088293945) -0.620512 \n",
+ "[395.9569088293945,4709.61474609375) -0.454779 \n",
+ "\n",
+ " [0.7351365037506351,0.9988406500803395) \n",
+ "[2.0140850483155233,57.8493661406481) 0.855775 \n",
+ "[57.8493661406481,81.31283353333698) 0.799595 \n",
+ "[81.31283353333698,104.08576202392578) 1.013926 \n",
+ "[104.08576202392578,128.14104591262304) 1.296510 \n",
+ "[128.14104591262304,155.36119651794434) 1.987175 \n",
+ "[155.36119651794434,202.6728515625) 1.407061 \n",
+ "[202.6728515625,395.9569088293945) 0.054148 \n",
+ "[395.9569088293945,4709.61474609375) -0.197384 "
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "builder = LookupDistDotBuilder(\n",
+ " dotprops, matching_pairs, use_alpha=True, seed=2021\n",
+ ").with_bin_counts([8, 5])\n",
+ "smat = builder.build()\n",
+ "as_table = smat.to_dataframe()\n",
+ "as_table"
+ ]
+ },
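+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The ``Lookup2d`` is itself the score function: calling it with a distance\n",
+ "and a dot product returns the score for the bin that pair falls into.\n",
+ "A minimal sketch, using the matrix trained above\n",
+ "(the exact value depends on the training data):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# score for a point match at distance 10 with an (alpha-scaled) dot product of 0.9\n",
+ "smat(10, 0.9)"
+ ]
+ },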
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we can have this score matrix, we can use it for a problem which can be solved by NBLAST:\n",
+ "we've mixed up a bag of neurons which look very similar to some of our examples,\n",
+ "and need to know which they match with."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "NBLAST is optimized for data in microns and it looks like your queries are not in microns.\n",
+ "NBLAST is optimized for data in microns and it looks like your targets are not in microns.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b55a73a0832e495596df888b4e7d879d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Preparing: 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3a5c8d6bb2d54fd99d395e3ed492d37a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "NBlasting: 0%| | 0/5 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " 754538881_aug | \n",
+ " 722817260_aug | \n",
+ " 1734350908_aug | \n",
+ " 1734350788_aug | \n",
+ " 754534424_aug | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1734350788 | \n",
+ " -0.271504 | \n",
+ " -0.361381 | \n",
+ " -0.272018 | \n",
+ " 0.159048 | \n",
+ " -0.325920 | \n",
+ "
\n",
+ " \n",
+ " 1734350908 | \n",
+ " -0.400432 | \n",
+ " -0.491376 | \n",
+ " 0.858478 | \n",
+ " -0.437819 | \n",
+ " -0.216206 | \n",
+ "
\n",
+ " \n",
+ " 722817260 | \n",
+ " -0.157433 | \n",
+ " 0.127931 | \n",
+ " -0.365195 | \n",
+ " -0.407436 | \n",
+ " -0.393940 | \n",
+ "
\n",
+ " \n",
+ " 754534424 | \n",
+ " -0.237769 | \n",
+ " -0.413532 | \n",
+ " -0.193794 | \n",
+ " -0.401148 | \n",
+ " 0.857898 | \n",
+ "
\n",
+ " \n",
+ " 754538881 | \n",
+ " -0.021153 | \n",
+ " -0.213349 | \n",
+ " -0.200214 | \n",
+ " -0.303390 | \n",
+ " -0.115017 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " 754538881_aug 722817260_aug 1734350908_aug 1734350788_aug \\\n",
+ "1734350788 -0.271504 -0.361381 -0.272018 0.159048 \n",
+ "1734350908 -0.400432 -0.491376 0.858478 -0.437819 \n",
+ "722817260 -0.157433 0.127931 -0.365195 -0.407436 \n",
+ "754534424 -0.237769 -0.413532 -0.193794 -0.401148 \n",
+ "754538881 -0.021153 -0.213349 -0.200214 -0.303390 \n",
+ "\n",
+ " 754534424_aug \n",
+ "1734350788 -0.325920 \n",
+ "1734350908 -0.216206 \n",
+ "722817260 -0.393940 \n",
+ "754534424 0.857898 \n",
+ "754538881 -0.115017 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "original_dps = dotprops[:len(original)]\n",
+ "new_dps = [\n",
+ " navis.make_dotprops(\n",
+ " augment_neuron(n),\n",
+ " k=5,\n",
+ " resample=False\n",
+ " )\n",
+ " for n in original\n",
+ "]\n",
+ "RNG.shuffle(new_dps)\n",
+ "\n",
+ "result = navis.nblast(\n",
+ " original_dps, new_dps,\n",
+ " use_alpha=True, scores=\"mean\", normalized=True,\n",
+ " smat=smat,\n",
+ " n_cores=1,\n",
+ ")\n",
+ "result.index = [dp.name for dp in original_dps]\n",
+ "result.columns = [dp.name for dp in new_dps]\n",
+ "result"
+ ]
+ },
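+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick programmatic check (a minimal sketch over the ``result`` table above),\n",
+ "we can ask whether each augmented neuron's best-scoring original is its own progenitor:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# for each column (augmented neuron), the row label (original) with the highest score\n",
+ "best = result.idxmax()\n",
+ "all(orig + \"_aug\" == aug for aug, orig in best.items())"
+ ]
+ },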
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can see that while there aren't any particularly good scores\n",
+ "(this is a very small amount of training data, and one would normally preprocess the neurons),\n",
+ "in each case the original's best match is its augmented partner."
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "97618690babd5eee2c893391179393f7dcc498027dfffda6cb7b8d10b95474fd"
+ },
+ "kernelspec": {
+ "display_name": "Python 3.9.4 64-bit ('navis': pyenv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.4"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/navis/core/dotprop.py b/navis/core/dotprop.py
index c83b04f9..aaa8175c 100644
--- a/navis/core/dotprop.py
+++ b/navis/core/dotprop.py
@@ -332,7 +332,9 @@ def dist_dots(self,
other: 'Dotprops',
alpha: bool = False,
distance_upper_bound: Optional[float] = None,
- **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+ **kwargs) -> Union[
+ Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]
+ ]:
"""Query this Dotprops against another.
This function is mainly for ``navis.nblast``.
@@ -622,3 +624,6 @@ def to_skeleton(self,
tn._soma = self._soma
return tn
+
+ def __len__(self):
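+ """Return the number of points in this Dotprops."""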
+ return len(self.points)
diff --git a/navis/nbl/base.py b/navis/nbl/base.py
index c3ad4203..9e51344d 100644
--- a/navis/nbl/base.py
+++ b/navis/nbl/base.py
@@ -17,6 +17,7 @@
import pandas as pd
from abc import ABC, abstractmethod
+from typing import Union, List
from .. import utils, config
@@ -24,6 +25,9 @@
FLOAT_DTYPES = {16: np.float16, 32: np.float32, 64: np.float64, None: None}
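+# A single integer index, or an arbitrarily nested list of such indices,
+# mirroring the nesting of whatever was appended.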
+NestedIndices = Union[int, List['NestedIndices']]
+
+
class Blaster(ABC):
"""Base class for blasting."""
@@ -37,7 +41,7 @@ def __init__(self, dtype=np.float64, progress=True):
self.ids = []
@abstractmethod
- def append(self, neurons):
+ def append(self, neurons) -> NestedIndices:
"""Append neurons."""
pass
@@ -157,3 +161,6 @@ def all_by_all(self, scores='forward'):
res.loc[:, :] = np.dstack((res, res.T)).max(axis=2)
return res
+
+ def __len__(self):
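+ """Return the number of neurons appended so far."""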
+ return len(self.neurons)
diff --git a/navis/nbl/nblast_funcs.py b/navis/nbl/nblast_funcs.py
index 32dae6fa..9bf1ffe1 100644
--- a/navis/nbl/nblast_funcs.py
+++ b/navis/nbl/nblast_funcs.py
@@ -16,7 +16,7 @@
import numbers
import os
-import uuid
+import operator
import numpy as np
import pandas as pd
@@ -25,9 +25,11 @@
from typing import Union, Optional, List
from typing_extensions import Literal
+from navis.nbl.smat import Lookup2d, smat_fcwb
+
from .. import utils, config
from ..core import NeuronList, Dotprops, make_dotprops
-from .base import Blaster
+from .base import Blaster, NestedIndices
__all__ = ['nblast', 'nblast_smart', 'nblast_allbyall', 'sim_to_dist']
@@ -37,75 +39,6 @@
logger = config.logger
-class ScoringFunction:
- """Class representing scoring function."""
-
- def __init__(self, smat):
- if isinstance(smat, type(None)):
- self.scoring_function = self.pass_through
- elif isinstance(smat, (pd.DataFrame, str)):
- self.parse_matrix(smat)
- self.scoring_function = self.score_lookup
- else:
- raise TypeError
-
- def __call__(self, dist, dot):
- return self.scoring_function(dist, dot)
-
- @property
- def max_dist(self):
- """Max distance considered by the scoring matrix.
-
- Returns ``None`` if pass through.
- """
- if self.scoring_function == self.pass_through:
- return None
- # The last bin is always `np.inf`, so we need the second last bin
- return self.dist_bins[-2]
-
- def pass_through(self, dist, dot):
- """Pass-through scores if no scoring matrix."""
- return dist * dot
-
- def score_lookup(self, dist, dot):
- return self.cells[
- np.digitize(dist, self.dist_bins),
- np.digitize(dot, self.dot_bins),
- ]
-
- def parse_matrix(self, smat):
- """Parse matrix."""
- if isinstance(smat, str):
- smat = pd.read_csv(smat, index_col=0)
-
- if not isinstance(smat, pd.DataFrame):
- raise TypeError(f'Excepted filepath or DataFrame, got "{type(smat)}"')
-
- self.cells = smat.to_numpy()
-
- self.dist_thresholds = [self.parse_interval(s) for s in smat.index]
- # Make sure right bin is open
- self.dist_thresholds[-1] = np.inf
- self.dist_bins = np.array(self.dist_thresholds, float)
-
- self.dot_thresholds = [self.parse_interval(s) for s in smat.columns]
- # Make sure right bin is open
- self.dot_thresholds[-1] = np.inf
- self.dot_bins = np.array(self.dot_thresholds, float)
-
- def parse_interval(self, s):
- """Strip brackets and parse right interval.
-
- Example
- -------
- >>> parse_intervals("(0,0.1]") # doctest: +SKIP
- 0.1
- """
- return float(s.strip("([])").split(",")[-1])
-
-
-NeuronId = Union[int, str, uuid.UUID]
-
class NBlaster(Blaster):
"""Implements version 2 of the NBLAST algorithm.
@@ -122,10 +55,16 @@ class NBlaster(Blaster):
normalized : bool
If True, will normalize scores by the best possible score
(i.e. self-self) of the query neuron.
- smat : str | pd.DataFrame
- Score matrix. If 'auto' (default), will use scoring matrices
+ smat : navis.nbl.smat.Lookup2d | pd.DataFrame | str
+ How to convert the point match pairs into an NBLAST score,
+ usually by a lookup table.
+ If 'auto' (default), will use scoring matrices
from FCWB. Same behaviour as in R's nat.nblast
- implementation. If ``smat=None`` the scores will be
+ implementation.
+ Dataframes will be used to build a ``Lookup2d``.
+ If ``limit_dist`` is not given,
+ will attempt to infer from the first axis of the lookup table.
+ If ``smat=None`` the scores will be
generated as the product of the distances and the dotproduct
of the vectors of nearest-neighbor pairs.
limit_dist : float | "auto" | None
@@ -148,27 +87,30 @@ def __init__(self, use_alpha=False, normalized=True, smat='auto',
self.approx_nn = approx_nn
self.desc = "NBlasting"
- if smat == 'auto':
- if self.use_alpha:
- smat = pd.read_csv(f'{smat_path}/smat_alpha_fcwb.csv',
- index_col=0)
- else:
- smat = pd.read_csv(f'{smat_path}/smat_fcwb.csv',
- index_col=0)
-
- self.score_fn = ScoringFunction(smat)
-
- if limit_dist == 'auto':
- self.distance_upper_bound = self.score_fn.max_dist
+ if smat is None:
+ self.score_fn = operator.mul
+ elif smat == 'auto':
+ self.score_fn = smat_fcwb(self.use_alpha)
+ elif isinstance(smat, pd.DataFrame):
+ self.score_fn = Lookup2d.from_dataframe(smat)
+ else:
+ self.score_fn = smat
+
+ if limit_dist == "auto":
+ try:
+ self.distance_upper_bound = self.score_fn.axes[0]._max
+ except AttributeError:
+ logger.warning("Could not infer distance upper bound from scoring function")
+ self.distance_upper_bound = None
else:
self.distance_upper_bound = limit_dist
- def append(self, dotprops) -> Union[List[NeuronId], NeuronId]:
+ def append(self, dotprops) -> NestedIndices:
"""Append dotprops.
- Returns the ID of the appended dotprops.
+ Returns the numerical index of the appended dotprops.
If dotprops is a (possibly nested) sequence of dotprops,
- return a (possibly nested) list of IDs.
+ return a (possibly nested) list of indices.
"""
if isinstance(dotprops, Dotprops):
return self._append_dotprops(dotprops)
@@ -178,12 +120,13 @@ def append(self, dotprops) -> Union[List[NeuronId], NeuronId]:
except TypeError: # i.e. not iterable
raise ValueError(f"Expected Dotprops or iterable thereof; got {type(dotprops)}")
- def _append_dotprops(self, dotprops: Dotprops) -> NeuronId:
+ def _append_dotprops(self, dotprops: Dotprops) -> int:
+ next_id = len(self)
self.neurons.append(dotprops)
self.ids.append(dotprops.id)
# Calculate score for self hit
self.self_hits.append(self.calc_self_hit(dotprops))
- return dotprops.id
+ return next_id
def calc_self_hit(self, dotprops):
"""Non-normalized value for self hit."""
@@ -195,7 +138,7 @@ def calc_self_hit(self, dotprops):
dots = np.repeat(1, len(dotprops.points)) * np.sqrt(alpha)
return self.score_fn(dists, dots).sum()
- def single_query_target(self, q_idx, t_idx, scores='forward'):
+ def single_query_target(self, q_idx: int, t_idx: int, scores='forward'):
"""Query single target against single target."""
# Take a short-cut if this is a self-self comparison
if q_idx == t_idx:
diff --git a/navis/nbl/smat.py b/navis/nbl/smat.py
new file mode 100644
index 00000000..8aa9ee8c
--- /dev/null
+++ b/navis/nbl/smat.py
@@ -0,0 +1,875 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from itertools import permutations
+import sys
+import os
+from collections import Counter
+from concurrent.futures import ProcessPoolExecutor
+from typing import (
+ Generic,
+ Hashable,
+ Iterator,
+ Mapping,
+ Optional,
+ Sequence,
+ Callable,
+ List,
+ Iterable,
+ Any,
+ Tuple,
+ TypeVar,
+ Union,
+)
+import logging
+from pathlib import Path
+from functools import lru_cache
+from copy import deepcopy
+import operator
+import math
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+
+from ..core.neurons import Dotprops
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_SEED = 1991
+
+epsilon = sys.float_info.epsilon
+cpu_count = max(1, (os.cpu_count() or 2) - 1)
+
+fp = Path(__file__).resolve().parent
+smat_path = fp / "score_mats"
+
+
+def chunksize(it_len, cpu_count, min_chunk=50):
+ return max(min_chunk, int(it_len / (cpu_count * 4)))
+
+
+def yield_not_same(pairs: Iterable[Tuple[Any, Any]]) -> Iterator[Tuple[Any, Any]]:
+ for a, b in pairs:
+ if a != b:
+ yield a, b
+
+
+def concat_results(results: Iterable[List[np.ndarray]]) -> List[np.ndarray]:
+ intermediate = defaultdict(list)
+ for result_lst in results:
+ for idx, array in enumerate(result_lst):
+ intermediate[idx].append(array)
+
+ return [np.concatenate(arrs) for arrs in intermediate.values()]
+
+
+DotpropKey = Hashable
+
+
+class LookupNdBuilder:
+ def __init__(
+ self,
+ dotprops: Union[List[Dotprops], Mapping[DotpropKey, Dotprops]],
+ matching_lists: List[List[DotpropKey]],
+ match_fn: Callable[[Dotprops, Dotprops], List[np.ndarray]],
+ nonmatching_list: Optional[List[DotpropKey]] = None,
+ seed: int = DEFAULT_SEED,
+ ) -> None:
+ f"""Class for building an N-dimensional score lookup for NBLAST.
+
+ Once instantiated, the axes of the lookup table must be defined.
+ Call ``.with_digitizers()`` to manually define them,
+ or ``.with_bin_counts()`` to learn them from the matched-pair data.
+
+ Then call ``.build()`` to build the lookup table.
+
+ Parameters
+ ----------
+
+ dotprops : dict or list of Dotprops
+ An indexable, consistently-ordered sequence of all neurons
+ which will be used as the training set, as Dotprops objects.
+ matching_lists : list of lists of indices into dotprops
+ Lists of neurons, as indices into ``dotprops``, which should be considered matches.
+ match_fn : Callable[[Dotprops, Dotprops], List[np.ndarray[float]]]
+ Function taking 2 arguments,
+ both instances of ``navis.core.neurons.Dotprops``,
+ and returning a list of 1D ``numpy.ndarray``s of floats.
+ The length of the list determines the dimensionality of the lookup table.
+ The length of the ``array``s must be the same
+ as the number of points in the first argument.
+ This function returns values describing the quality of
+ point matches from a query to a target neuron.
+ nonmatching_list : list of indices into dotprops, optional
+ List of neurons, as indices into ``dotprops``,
+ which should not be considered matches.
+ If not given, all ``dotprops`` will be used
+ (on the assumption that matches are a small subset of possible pairs).
+ seed : int, optional
+ Non-matching pairs are drawn at random using this seed,
+ by default ``DEFAULT_SEED`` (1991).
+ """
+ self.dotprops = dotprops
+ self.matching_lists = matching_lists
+ self._nonmatching_list = nonmatching_list
+ self.match_fn = match_fn
+
+ self.digitizers: Optional[List[Digitizer]] = None
+ self.bin_counts: Optional[List[int]] = None
+
+ self.seed = seed
+ self._ndim: Optional[int] = None
+
+ @property
+ def ndim(self) -> int:
+ if self._ndim is None:
+ # list() so that dict_keys views (from Mapping inputs) can be sliced
+ idx1, idx2 = list(self._dotprop_keys())[:2]
+ self._ndim = len(self._query(idx1, idx2))
+ self._query.cache_clear()
+ return self._ndim
+
+ def with_digitizers(self, digitizers: List[Digitizer]):
+ """Specify the axes of the output lookup table directly.
+
+ Parameters
+ ----------
+ digitizers : List[Digitizer]
+
+ Returns
+ -------
+ self
+ For chaining convenience.
+ """
+ if len(digitizers) != self.ndim:
+ raise ValueError(
+ f"Match function returns {self.ndim} values "
+ f"but provided {len(digitizers)} digitizers"
+ )
+
+ self.digitizers = digitizers
+ self.bin_counts = None
+ return self
+
+ def with_bin_counts(self, bin_counts: List[int]):
+ """Specify the number of bins on each axis of the output lookup table.
+
+ The bin boundaries will be determined by evenly partitioning the data
+ from the matched pairs into quantiles, in each dimension.
+
+ Parameters
+ ----------
+ bin_counts : List[int]
+
+ Returns
+ -------
+ self
+ For chaining convenience.
+ """
+ if len(bin_counts) != self.ndim:
+ raise ValueError(
+ f"Match function returns {self.ndim} values "
+ f"but provided {len(bin_counts)} bin counts"
+ )
+
+ self.bin_counts = bin_counts
+ self.digitizers = None
+ return self
+
+ def _dotprop_keys(self) -> Sequence[DotpropKey]:
+ """Get all indices into dotprops instance member"""
+ try:
+ return self.dotprops.keys()
+ except AttributeError:
+ return range(len(self.dotprops))
+
+ @property
+ def nonmatching(self) -> List[DotpropKey]:
+ """Indices of nonmatching set of neurons"""
+ if self._nonmatching_list is None:
+ return list(self._dotprop_keys())
+ return self._nonmatching_list
+
+ def _yield_matching_pairs(self) -> Iterator[Tuple[DotpropKey, DotpropKey]]:
+ """Yield all index pairs within all matching pairs"""
+ for ms in self.matching_lists:
+ yield from yield_not_same(permutations(ms, 2))
+
+ def _yield_nonmatching_pairs(self) -> Iterator[Tuple[DotpropKey, DotpropKey]]:
+ """Yield all index pairs within nonmatching list"""
+ # todo: this could be much better, use meshgrid or shuffle index arrays
+ return yield_not_same(permutations(self.nonmatching, 2))
+
+ def _empty_counts(self) -> np.ndarray:
+ """Create an empty array in which to store counts; shape determined by digitizer sizes."""
+ shape = [len(b) for b in self.digitizers]
+ return np.zeros(shape, int)
+
+ @lru_cache(None)
+ def _query(self, q_idx, t_idx) -> List[np.ndarray]:
+ """Get the results of applying the match function to dotprops specified by indices"""
+ return self.match_fn(self.dotprops[q_idx], self.dotprops[t_idx])
+
+ def _query_many(self, idx_pairs, threads=None) -> Iterator[List[np.ndarray]]:
+ """Yield results from querying many pairs of dotprop indices"""
+ if threads is None or (threads == 0 and cpu_count == 1):
+ for q_idx, t_idx in idx_pairs:
+ yield self._query(q_idx, t_idx)
+ return
+
+ threads = threads or cpu_count
+ idx_pairs = np.asarray(idx_pairs)
+ chunks = chunksize(len(idx_pairs), threads)
+
+ with ProcessPoolExecutor(threads) as exe:
+ yield from exe.map(
+ self._query, idx_pairs[:, 0], idx_pairs[:, 1], chunksize=chunks
+ )
+
+ def _query_to_idxs(self, q_idx, t_idx, counts=None):
+ """Produce a digitized counts array from a given query-target pair"""
+ return self._count_results(self._query(q_idx, t_idx), counts)
+
+ def _count_results(self, results: List[np.ndarray], counts=None):
+ """Convert raw match function ouput into a digitized counts array.
+
+ Requires digitizers.
+ """
+ idxs = [dig(r) for dig, r in zip(self.digitizers, results)]
+
+ if counts is None:
+ counts = self._empty_counts()
+
+ for idx in zip(*idxs):
+ counts[idx] += 1
+
+ return counts
+
+ def _counts_array(self, idx_pairs, threads=None, cache=False):
+ """Convert index pairs into a digitized counts array.
+
+ Requires digitizers.
+ """
+ counts = self._empty_counts()
+ if threads is None or (threads == 0 and cpu_count == 1):
+ for q_idx, t_idx in idx_pairs:
+ counts = self._query_to_idxs(q_idx, t_idx, counts)
+ return counts
+
+ threads = threads or cpu_count
+ idx_pairs = np.asarray(idx_pairs, dtype=int)
+ chunks = chunksize(len(idx_pairs), threads)
+
+ # because digitizing is not necessarily free,
+ # keep this parallelisation separate to that in _query_many
+ with ProcessPoolExecutor(threads) as exe:
+ for these_counts in exe.map(
+ self._query_to_idxs,
+ idx_pairs[:, 0],
+ idx_pairs[:, 1],
+ chunksize=chunks,
+ ):
+ counts += these_counts
+ if not cache:
+ self._query.cache_clear()
+
+ return counts
+
+ def _pick_nonmatching_pairs(self, n_matching_qual_vals):
+ """Using the seeded RNG, pick which nonmatching pairs to use."""
+ # pre-calculating which pairs we're going to use,
+ # rather than drawing them as we need them,
+ # means that we can parallelise the later step more effectively.
+ # Slowdowns here are practically meaningless
+ # because of how long distdot calculation will take
+ all_nonmatching_pairs = list(self._yield_nonmatching_pairs())
+ nonmatching_pairs = []
+ n_nonmatching_qual_vals = 0
+ rng = np.random.default_rng(self.seed)
+ while n_nonmatching_qual_vals < n_matching_qual_vals:
+ idx = rng.integers(0, len(all_nonmatching_pairs))
+ nonmatching_pair = all_nonmatching_pairs.pop(idx)
+ nonmatching_pairs.append(nonmatching_pair)
+ n_nonmatching_qual_vals += len(self.dotprops[nonmatching_pair[0]])
+
+ return nonmatching_pairs
+
+ def _get_pairs(self):
+ matching_pairs = list(set(self._yield_matching_pairs()))
+ # need to know the eventual distdot count
+ # so we know how many non-matching pairs to draw
+ q_idx_count = Counter(p[0] for p in matching_pairs)
+ n_matching_qual_vals = sum(
+ len(self.dotprops[q_idx]) * n_reps for q_idx, n_reps in q_idx_count.items()
+ )
+
+ nonmatching_pairs = self._pick_nonmatching_pairs(n_matching_qual_vals)
+ return matching_pairs, nonmatching_pairs
+
+ def _build(self, threads, cache=False) -> Tuple[List[Digitizer], np.ndarray]:
+ if self.digitizers is None and self.bin_counts is None:
+ raise ValueError(
+ "Builder needs either digitizers or bin_counts; see with_* methods"
+ )
+
+ matching_pairs, nonmatching_pairs = self._get_pairs()
+
+ if self.digitizers:
+ match_counts = self._counts_array(matching_pairs, threads, cache)
+ else:
+ match_results = concat_results(self._query_many(matching_pairs, threads))
+ self.digitizers = [
+ Digitizer.from_data(data, nbins)
+ for data, nbins in zip(match_results, self.bin_counts)
+ ]
+ if not cache:
+ self._query.cache_clear()
+ match_counts = self._count_results(match_results)
+
+ nonmatch_counts = self._counts_array(nonmatching_pairs, threads)
+
+ # account for there being different total numbers of match/nonmatch dist dots
+ matching_factor = nonmatch_counts.sum() / match_counts.sum()
+ if np.any(match_counts + nonmatch_counts == 0):
+ logger.warning("Some lookup cells have no data in them")
+
+ cells = np.log2(
+ (match_counts * matching_factor + epsilon) / (nonmatch_counts + epsilon)
+ )
+
+ return self.digitizers, cells
+
+ def build(self, threads=None, cache=False) -> LookupNd:
+ """Build the score matrix.
+
+ All non-identical neuron pairs within all matching sets are selected,
+ and distdots calculated for those pairs.
+ Then, the minimum number of non-matching pairs are randomly drawn
+ so that at least as many distdots can be calculated for non-matching
+ pairs.
+
+ In each bin of the score matrix, the log2 odds ratio of a distdot
+ in that bin belonging to a match vs. non-match is calculated.
+
+ Parameters
+ ----------
+ threads : int, optional
+ If None, act in serial.
+ If 0, use cpu_count - 1.
+ Otherwise, use the given value.
+ cache : bool, optional
+ If True, keep the cached pairwise query results after building,
+ by default False.
+
+ Returns
+ -------
+ LookupNd
+ """
+ dig, cells = self._build(threads, cache)
+ return LookupNd(dig, cells)
+
+
+def dist_dot(q: Dotprops, t: Dotprops):
+ return list(q.dist_dots(t))
+
+
+def dist_dot_alpha(q: Dotprops, t: Dotprops):
+ dist, dot, alpha = q.dist_dots(t, alpha=True)
+ return [dist, dot * np.sqrt(alpha)]
+
+
+class LookupDistDotBuilder(LookupNdBuilder):
+ def __init__(
+ self,
+ dotprops: Union[List[Dotprops], Mapping[DotpropKey, Dotprops]],
+ matching_lists: List[List[DotpropKey]],
+ nonmatching_list: Optional[List[DotpropKey]] = None,
+ use_alpha: bool = False,
+ seed: int = DEFAULT_SEED,
+ ):
+ f"""Class for building a 2-dimensional score lookup for NBLAST.
+
+ The lookup table's two dimensions are:
+
+ 1. The distances between best-matching points
+ 2. The dot products of direction vectors around those points,
+ optionally scaled by the colinearity ``alpha``.
+
+ Parameters
+ ----------
+ dotprops : dict or list of Dotprops
+ An indexable sequence of all neurons which will be used as the training set,
+ as Dotprops objects.
+ matching_lists : list of lists of indices into dotprops
+ Lists of neurons, as indices into ``dotprops``, which should be considered matches.
+ nonmatching_list : list of indices into dotprops, optional
+ List of neurons, as indices into ``dotprops``,
+ which should not be considered matches.
+ If not given, all ``dotprops`` will be used
+ (on the assumption that matches are a small subset of possible pairs).
+ use_alpha : bool, optional
+ If true, multiply the dot product by the geometric mean
+ of the matched points' alpha values
+ (i.e. ``sqrt(alpha1 * alpha2)``).
+ seed : int, optional
+ Non-matching pairs are drawn at random using this seed,
+ by default ``DEFAULT_SEED`` (1991).
+ """
+ match_fn = dist_dot_alpha if use_alpha else dist_dot
+ super().__init__(
+ dotprops,
+ matching_lists,
+ match_fn,
+ nonmatching_list,
+ seed,
+ )
+ self._ndim = 2
+
+ def build(self, threads=None, cache=False) -> Lookup2d:
+ (dig0, dig1), cells = self._build(threads, cache)
+ return Lookup2d(dig0, dig1, cells)
+
+
+def is_monotonically_increasing(lst):
+ for prev_idx, item in enumerate(lst[1:]):
+ if item <= lst[prev_idx]:
+ return False
+ return True
+
+
+def parse_boundary(item: str):
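+ """Parse a half-open interval string like ``"[1.0,2.0)"``.
+
+ Returns ``((lower, upper), right)``, where ``right`` indicates whether
+ the interval is closed on the right (``"(...]"``) rather than the left (``"[...)"``).
+ """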
+ explicit_interval = item[0] + item[-1]
+ if explicit_interval == "[)":
+ right = False
+ elif explicit_interval == "(]":
+ right = True
+ else:
+ raise ValueError(
+ f"Enclosing characters '{explicit_interval}' do not match a half-open interval"
+ )
+ return tuple(float(i) for i in item[1:-1].split(",")), right
+
+
+T = TypeVar("T")
+
+
+class LookupAxis(ABC, Generic[T]):
+ """Class converting some data into a linear index."""
+
+ @abstractmethod
+ def __len__(self) -> int:
+ """Number of bins represented by this instance."""
+ pass
+
+ @abstractmethod
+ def __call__(self, value: Union[T, Sequence[T]]) -> Union[int, Sequence[int]]:
+ """Convert some data into a linear index.
+
+ Parameters
+ ----------
+ value : Union[T, Sequence[T]]
+ Value to convert into an index
+
+ Returns
+ -------
+ Union[int, Sequence[int]]
+ If a scalar was given, return a scalar; otherwise, a numpy array of ints.
+ """
+ pass
+
+
+class SimpleLookup(LookupAxis[Hashable]):
+ def __init__(self, items: List[Hashable]):
+ """Look up in a list of items and return their index.
+
+ Parameters
+ ----------
+ items : List[Hashable]
+ The item's position in the list is the index which will be returned.
+
+ Raises
+ ------
+ ValueError
+ If the items are not unique.
+ """
+ self.items = {item: idx for idx, item in enumerate(items)}
+ if len(self.items) != len(items):
+ raise ValueError("Items are not unique")
+
+ def __len__(self) -> int:
+ return len(self.items)
+
+ def __call__(self, value: Union[Hashable, Sequence[Hashable]]) -> Union[int, Sequence[int]]:
+ if np.isscalar(value):
+ return self.items[value]
+ else:
+ return np.array([self.items[v] for v in value], int)
+
+
+class Digitizer(LookupAxis[float]):
+ def __init__(
+ self,
+ boundaries: Sequence[float],
+ clip: Tuple[bool, bool] = (True, True),
+ right=False,
+ ):
+ """Class converting continuous values into discrete indices.
+
+ Parameters
+ ----------
+ boundaries : Sequence[float]
+ N boundaries specifying N-1 bins.
+ Must be monotonically increasing.
+ clip : Tuple[bool, bool], optional
+ Whether to set the bottom and top boundaries to -infinity and infinity respectively,
+ effectively clipping incoming values: by default (True, True).
+ False means "add a new bin for out-of-range values".
+ right : bool, optional
+ Whether bins should include their right (rather than left) boundary,
+ by default False
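+
+ Examples
+ --------
+ A minimal sketch; with default clipping the outermost bounds
+ become -inf and inf:
+
+ >>> dig = Digitizer([0.0, 1.0, 2.0, 3.0])
+ >>> dig(2.5) # doctest: +SKIP
+ 2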
+ """
+ self.right = right
+
+ boundaries = list(boundaries)
+ self._min = -math.inf
+ if clip[0]:
+ self._min = boundaries[0]
+ boundaries[0] = -math.inf
+ elif boundaries[0] != -math.inf:
+ self._min = -math.inf
+ boundaries.insert(0, -math.inf)
+
+ self._max = math.inf
+ if clip[1]:
+ self._max = boundaries[-1]
+ boundaries[-1] = math.inf
+ elif boundaries[-1] != math.inf:
+ boundaries.append(math.inf)
+
+ if not is_monotonically_increasing(boundaries):
+ raise ValueError("Boundaries are not monotonically increasing")
+
+ self.boundaries = np.asarray(boundaries)
+
+ def __len__(self):
+ return len(self.boundaries) - 1
+
+ def __call__(self, value: float):
+ # searchsorted is marginally faster than digitize as it skips monotonicity checks
+ return (
+ np.searchsorted(
+ self.boundaries, value, side="left" if self.right else "right"
+ )
+ - 1
+ )
+
+ def to_strings(self) -> List[str]:
+ if self.right:
+ lb = "("
+ rb = "]"
+ else:
+ lb = "["
+ rb = ")"
+
+ b = self.boundaries.copy()
+ b[0] = self._min
+ b[-1] = self._max
+ return [
+ f"{lb}{lower},{upper}{rb}"
+ for lower, upper in zip(b[:-1], b[1:])
+ ]
+
+ @classmethod
+ def from_strings(cls, interval_strs: Sequence[str]):
+ """Set digitizer boundaries based on a sequence of interval expressions.
+
+ e.g. ``["(0, 1]", "(1, 5]", "(5, 10]"]``
+
+ The lowermost and uppermost boundaries are converted to -infinity and infinity respectively.
+
+ Parameters
+ ----------
+ interval_strs : Sequence[str]
+ Strings representing intervals, which must abut and have open/closed boundaries
+ specified by brackets.
+
+ Returns
+ -------
+ Digitizer
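+
+ Examples
+ --------
+ A sketch with three abutting left-closed intervals:
+
+ >>> dig = Digitizer.from_strings(["[1,2)", "[2,3)", "[3,4)"])
+ >>> dig(2.5) # doctest: +SKIP
+ 1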
+ """
+ bounds: List[float] = []
+ last_upper = None
+ last_right = None
+ for item in interval_strs:
+ (lower, upper), right = parse_boundary(item)
+ bounds.append(float(lower))
+
+ if last_right is not None:
+ if right != last_right:
+ raise ValueError("Inconsistent half-open interval")
+ else:
+ last_right = right
+
+ if last_upper is not None:
+ if lower != last_upper:
+ raise ValueError("Half-open intervals do not abut")
+
+ last_upper = upper
+
+ bounds.append(float(last_upper))
+ return cls(bounds, right=last_right)
+
+ @classmethod
+ def from_linear(cls, lower: float, upper: float, nbins: int, right=False):
+ """Choose digitizer boundaries spaced linearly between two values.
+
+ Input values will be clipped to fit within the given interval.
+
+ Parameters
+ ----------
+ lower : float
+ Lowest value
+ upper : float
+ Highest value
+ nbins : int
+ Number of bins
+ right : bool, optional
+ Whether bins should include their right (rather than left) boundary,
+ by default False
+
+ Returns
+ -------
+ Digitizer
+ """
+ arr = np.linspace(lower, upper, nbins + 1, endpoint=True)
+ return cls(arr, right=right)
+
+ @classmethod
+ def from_geom(cls, lowest_upper: float, highest_lower: float, nbins: int, right=False):
+ """Choose digitizer boundaries in a geometric sequence.
+
+ Additional bins will be added above and below the given values.
+
+ Parameters
+ ----------
+ lowest_upper : float
+ Upper bound of the lowest bin.
+ The lower bound of the lowest bin is often 0, which cannot be represented in a nontrivial geometric sequence.
+ highest_lower : float
+ Lower bound of the highest bin.
+ nbins : int
+ Number of bins
+ right : bool, optional
+ Whether bins should include their right (rather than left) boundary,
+ by default False
+
+ Returns
+ -------
+ Digitizer
+ """
+ arr = np.geomspace(lowest_upper, highest_lower, nbins - 1, True)
+ return cls(arr, clip=(False, False), right=right)
+
+ @classmethod
+ def from_data(cls, data: Sequence[float], nbins: int, right=False):
+ """Choose digitizer boundaries to evenly partition the given values.
+
+ Parameters
+ ----------
+ data : Sequence[float]
+ Data which should be evenly partitioned by the resulting digitizer.
+ nbins : int
+ Number of bins
+ right : bool, optional
+ Whether bins should include their right (rather than left) boundary,
+ by default False
+
+ Returns
+ -------
+ Digitizer
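+
+ Examples
+ --------
+ Quartile bins over toy data (a sketch):
+
+ >>> dig = Digitizer.from_data([0.0, 1.0, 2.0, 3.0, 4.0], nbins=4)
+ >>> len(dig)
+ 4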
+ """
+ arr = np.quantile(data, np.linspace(0, 1, nbins + 1, True))
+ return cls(arr, right=right)
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Digitizer):
+ return NotImplemented
+ return self.right == other.right and np.allclose(
+ self.boundaries, other.boundaries
+ )
+
+
+class LookupNd:
+ def __init__(self, axes: List[LookupAxis], cells: np.ndarray):
+ if [len(b) for b in axes] != list(cells.shape):
+ raise ValueError("boundaries and cells have inconsistent bin counts")
+ self.axes = axes
+ self.cells = cells
+
+ def __call__(self, *args):
+ if len(args) != len(self.axes):
+ raise TypeError(
+ f"Lookup takes {len(self.axes)} arguments but {len(args)} were given"
+ )
+
+ idxs = tuple(d(arg) for d, arg in zip(self.axes, args))
+ out = self.cells[idxs]
+ return out
+
+
+class Lookup2d(LookupNd):
+ """Convenience class inheriting from LookupNd for the common 2D float case.
+ Provides IO with pandas DataFrames.
+ """
+
+ def __init__(self, axis0: Digitizer, axis1: Digitizer, cells: np.ndarray):
+ """2D lookup table for convert NBLAST matches to scores.
+
+ Commonly read from a ``pandas.DataFrame``
+ or trained on data using a ``LookupDistDotBuilder``.
+
+ Parameters
+ ----------
+ axis0 : Digitizer
+ How to convert continuous values into an index for the first axis.
+ axis1 : Digitizer
+ How to convert continuous values into an index for the second axis.
+ cells : np.ndarray
+ Values to look up in the table.
+ """
+ super().__init__([axis0, axis1], cells)
+
+ def to_dataframe(self) -> pd.DataFrame:
+ """Convert the lookup table into a ``pandas.DataFrame``.
+
+ From there, it can be shared, saved, and so on.
+
+ The index and column labels describe the intervals represented by that axis.
+
+ Returns
+ -------
+ pd.DataFrame
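+
+ Examples
+ --------
+ A round trip through CSV (``my_smat.csv`` is a hypothetical path):
+
+ >>> smat.to_dataframe().to_csv("my_smat.csv") # doctest: +SKIP
+ >>> smat2 = Lookup2d.from_dataframe(
+ ... pd.read_csv("my_smat.csv", index_col=0)) # doctest: +SKIP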
+ """
+ return pd.DataFrame(
+ self.cells,
+ self.axes[0].to_strings(),
+ self.axes[1].to_strings(),
+ )
+
+ @classmethod
+ def from_dataframe(cls, df: pd.DataFrame):
+ f"""Parse score matrix from a dataframe with string index and column labels.
+
+ Expects the index and column labels to specify an interval
+ like ``f"[{{lower}},{{upper}})"``.
+ Will replace the lowermost and uppermost bound with -inf and inf
+ if they are not already.
+ """
+ return cls(
+ Digitizer.from_strings(df.index),
+ Digitizer.from_strings(df.columns),
+ df.to_numpy(),
+ )
+
+
+@lru_cache(maxsize=None)
+def _smat_fcwb(alpha=False):
+ # cached private function defers construction
+ # until needed (speeding startup),
+ # but avoids repeated reads (speeding later uses)
+ fname = ("smat_fcwb.csv", "smat_alpha_fcwb.csv")[alpha]
+ fpath = smat_path / fname
+
+ return Lookup2d.from_dataframe(pd.read_csv(fpath, index_col=0))
+
+
+def smat_fcwb(alpha=False):
+ # deepcopied so that mutations do not propagate to cache
+ return deepcopy(_smat_fcwb(alpha))
+
+
+def check_score_fn(fn: Callable, nargs=2, scalar=True, array=True):
+ """Checks functionally that the callable can be used as a score function.
+
+ Parameters
+ ----------
+ fn : Callable
+ The score function to check.
+ nargs : optional int, default 2
+ How many positional arguments the score function should have.
+ scalar : optional bool, default True
+ Check that the function can be used on ``nargs`` scalars.
+ array : optional bool, default True
+ Check that the function can be used on ``nargs`` 1D ``numpy.ndarray``s.
+
+ Raises
+ ------
+ ValueError
+ If the score function is not appropriate.
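+
+ Examples
+ --------
+ >>> import operator
+ >>> check_score_fn(operator.mul) # the ``smat=None`` fallback passes silently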
+ """
+ if scalar:
+ scalars = [0.5] * nargs
+ if not isinstance(fn(*scalars), float):
+ raise ValueError("smat does not take 2 floats and return a float")
+
+ if array:
+ test_arr = np.array([0.5] * 3)
+ arrs = [test_arr] * nargs
+ try:
+ out = fn(*arrs)
+ except Exception as e:
+ raise ValueError(f"Failed to use smat with numpy arrays: {e}")
+
+ if out.shape != test_arr.shape:
+ raise ValueError(
+ f"smat produced inconsistent shape: input {test_arr.shape}; output {out.shape}"
+ )
+
+
+SCORE_FN_DESCR = """
+NBLAST score functions take 2 floats or N-length numpy arrays of floats
+(for matched dotprop points/tangents, distance and dot product;
+the latter possibly scaled by the geometric mean of the alpha colinearity values)
+and return a float or N-length numpy array of floats.
+""".strip().replace(
+ "\n", " "
+)
+
+
+def parse_score_fn(smat, alpha=False):
+ f"""Interpret ``smat`` as a score function.
+ Primarily for backwards compatibility.
+ {SCORE_FN_DESCR}
+ Parameters
+ ----------
+ smat : None | "auto" | str | os.PathLike | pandas.DataFrame | Callable[[float, float], float]
+ If ``None``, use ``operator.mul``.
+ If ``"auto"``, use ``navis.nbl.smat.smat_fcwb(alpha)``.
+ If a dataframe, use ``navis.nbl.smat.Lookup2d.from_dataframe(smat)``.
+ If another string or path-like, load the CSV into a dataframe and use as above.
+ Also checks the signature of the callable.
+ Raises an error, probably a ValueError, if it can't be interpreted.
+ alpha : optional bool, default False
+ If ``smat`` is ``"auto"``, choose whether to use the FCWB matrices
+ with or without alpha.
+
+ Returns
+ -------
+ Callable
+
+ Raises
+ ------
+ ValueError
+ If score function cannot be interpreted.
+ """
+ if smat is None:
+ smat = operator.mul
+ elif smat == "auto":
+ smat = smat_fcwb(alpha)
+
+ if isinstance(smat, (str, os.PathLike)):
+ smat = pd.read_csv(smat, index_col=0)
+
+ if isinstance(smat, pd.DataFrame):
+ smat = Lookup2d.from_dataframe(smat)
+
+ if not callable(smat):
+ raise ValueError(
+ "smat should be a callable, a path, a pandas.DataFrame, or 'auto'"
+ )
+
+ check_score_fn(smat)
+
+ return smat
diff --git a/navis/nbl/synblast_funcs.py b/navis/nbl/synblast_funcs.py
index c0b7812d..433039ea 100644
--- a/navis/nbl/synblast_funcs.py
+++ b/navis/nbl/synblast_funcs.py
@@ -15,6 +15,7 @@
"""Module contains functions implementing SyNBLAST."""
import os
+import operator
import numpy as np
import pandas as pd
@@ -27,10 +28,11 @@
from .. import config, utils
from ..core import NeuronList, BaseNeuron
-from .base import Blaster
+from .base import Blaster, NestedIndices
+from .smat import Lookup2d
-from .nblast_funcs import (check_microns, find_optimal_partition, ScoringFunction,
- nblast_preflight)
+from .nblast_funcs import (check_microns, find_optimal_partition,
+ nblast_preflight, smat_fcwb)
__all__ = ['synblast']
@@ -59,10 +61,14 @@ class SynBlaster(Blaster):
by_type : bool
If True will only compare synapses with the same value in
the "type" column.
- smat : str | pd.DataFrame
- Score matrix. If 'auto' (default), will use scoring matrices
+ smat : navis.nbl.smat.Lookup2d | pd.DataFrame | str
+ How to convert the point match pairs into an NBLAST score,
+ usually by a lookup table.
+ If 'auto' (default), will use scoring matrices
from FCWB. Same behaviour as in R's nat.nblast
- implementation. If ``smat=None`` the scores will be
+ implementation.
+ Dataframes will be used to build a ``Lookup2d``.
+ If ``smat=None`` the scores will be
generated as the product of the distances and the dotproduct
of the vectors of nearest-neighbor pairs.
progress : bool
@@ -77,15 +83,19 @@ def __init__(self, normalized=True, by_type=True,
self.normalized = normalized
self.by_type = by_type
- if smat == 'auto':
- smat = pd.read_csv(f'{smat_path}/smat_fcwb.csv',
- index_col=0)
+ if smat is None:
+ self.score_fn = operator.mul
+ elif smat == 'auto':
+ self.score_fn = smat_fcwb()
+ elif isinstance(smat, pd.DataFrame):
+ self.score_fn = Lookup2d.from_dataframe(smat)
+ else:
+ self.score_fn = smat
- self.score_fn = ScoringFunction(smat)
self.ids = []
- def append(self, neuron, id=None):
- """Append neurons/connector tables, returning ids of added objects"""
+ def append(self, neuron, id=None) -> NestedIndices:
+ """Append neurons/connector tables, returning numerical indices of added objects"""
if isinstance(neuron, pd.DataFrame):
return self._append_connectors(neuron, id)
@@ -102,10 +112,11 @@ def append(self, neuron, id=None):
f"{type(neuron)}"
)
- def _append_connectors(self, connectors: pd.DataFrame, id):
+ def _append_connectors(self, connectors: pd.DataFrame, id) -> int:
if id is None:
raise ValueError("Explicit non-None id required for appending connectors")
+ next_idx = len(self)
self.ids.append(id)
self.neurons.append({})
if not self.by_type:
@@ -123,13 +134,13 @@ def _append_connectors(self, connectors: pd.DataFrame, id):
# Calculate score for self hit
self.self_hits.append(self.calc_self_hit(connectors))
- return id
+ return next_idx
def calc_self_hit(self, cn):
"""Non-normalized value for self hit."""
return cn.shape[0] * self.score_fn(0, 1)
- def single_query_target(self, q_idx, t_idx, scores='forward'):
+ def single_query_target(self, q_idx: int, t_idx: int, scores='forward'):
"""Query single target against single target."""
# Take a short-cut if this is a self-self comparison
if q_idx == t_idx:
diff --git a/tests/conftest.py b/tests/conftest.py
index a35954d7..ec872834 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,6 @@
from pathlib import Path
import os
+from typing import List
import pandas as pd
import pytest
@@ -9,7 +10,7 @@
import navis
-@pytest.fixture
+@pytest.fixture(scope="session")
def data_dir():
return Path(__file__).resolve().parent.parent / "navis" / "data"
@@ -17,8 +18,8 @@ def data_dir():
@pytest.fixture(
params=["Path", "pathstr", "swcstr", "textbuffer", "rawbuffer", "DataFrame"]
)
-def swc_source(request, data_dir: Path):
- swc_path: Path = data_dir / "swc" / "722817260.swc"
+def swc_source(request, swc_paths: List[Path]):
+ swc_path: Path = swc_paths[0]
if request.param == "Path":
yield swc_path
elif request.param == "pathstr":
@@ -42,9 +43,9 @@ def swc_source(request, data_dir: Path):
@pytest.fixture(
params=["dirstr", "dirpath", "list", "listwithdir"],
)
-def swc_source_multi(request, data_dir: Path):
- dpath = data_dir / "swc"
- fpath = dpath / "722817260.swc"
+def swc_source_multi(request, swc_paths: List[Path]):
+ fpath = swc_paths[0]
+ dpath = fpath.parent
if request.param == "dirstr":
yield str(dpath)
elif request.param == "dirpath":
@@ -74,3 +75,43 @@ def voxel_nrrd_path(tmp_path):
nrrd.write(os.fspath(path), data, header)
return path
+
+
+def data_paths(dpath, glob="*"):
+ return sorted(dpath.glob(glob))
+
+
+@pytest.fixture(scope="session")
+def swc_paths(data_dir: Path):
+ return data_paths(data_dir / "swc", "*.swc")
+
+
+@pytest.fixture(scope="session")
+def gml_paths(data_dir: Path):
+ return data_paths(data_dir / "gml", "*.gml")
+
+
+@pytest.fixture(scope="session")
+def obj_paths(data_dir: Path):
+ return data_paths(data_dir / "obj", "*.obj")
+
+
+@pytest.fixture(scope="session")
+def synapses_paths(data_dir: Path):
+ return data_paths(data_dir / "synapses", "*.csv")
+
+
+@pytest.fixture(scope="session")
+def volumes_paths(data_dir: Path):
+ return data_paths(data_dir / "volumes", "*.obj")
+
+
+@pytest.fixture
+def treeneuron_dfs(swc_paths, synapses_paths):
+ swc_reader = navis.io.swc_io.SwcReader()
+ out = []
+ for swc_path, syn_path in zip(swc_paths, synapses_paths):
+ neuron = swc_reader.read_file_path(swc_path)
+ neuron.connectors = pd.read_csv(syn_path)
+ out.append(neuron)
+ return out
diff --git a/tests/test_nbl/__init__.py b/tests/test_nbl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_nbl/test_smat.py b/tests/test_nbl/test_smat.py
new file mode 100644
index 00000000..04d2d4f6
--- /dev/null
+++ b/tests/test_nbl/test_smat.py
@@ -0,0 +1,141 @@
+from navis import Dotprops
+import pytest
+
+import numpy as np
+
+from navis.nbl.smat import (
+ Digitizer, LookupNd, Lookup2d, LookupDistDotBuilder
+)
+
+
+SMALLEST_DIM_SIZE = 3
+SEED = 1991
+
+
+def lookup_args(ndim):
+ f"""
+ Create arguments for an ND lookup table.
+ The first dimension is of size {SMALLEST_DIM_SIZE},
+ and subsequent dimensions are 1 longer than the previous.
+ The data in the cells are the sequence from 0
+ to the size of the array.
+ The boundaries are 0 to the length of the dimension,
+ with the left and rightmost values replaced with -inf and inf respectively.
+
+ Examples
+ --------
+ >>> lookup_args(2)
+ (
+ [
+ array([-inf, 1, 2, inf]),
+ array([-inf, 1, 2, 3, inf]),
+ ],
+ array([
+ [ 0, 1, 2, 3 ],
+ [ 4, 5, 6, 7 ],
+ [ 8, 9, 10, 11 ],
+ ]),
+ )
+ """
+ shape = tuple(range(SMALLEST_DIM_SIZE, SMALLEST_DIM_SIZE + ndim))
+ cells = np.arange(np.prod(shape)).reshape(shape)
+ digitizers = [Digitizer(np.arange(s + 1, dtype=float)) for s in shape]
+ return digitizers, cells
+
+
+def fmt_array(arg):
+ if np.isscalar(arg):
+ return str(arg)
+ else:
+ return "[" + ",".join(fmt_array(v) for v in arg) + "]"
+
+
+@pytest.mark.parametrize(
+ ["ndim"], [[1], [2], [3], [4], [5]], ids=lambda x: f"{x}D"
+)
+@pytest.mark.parametrize(
+ ["arg"],
+ [
+ (-1000,),
+ (0,),
+ (1,),
+ (1.5,),
+ (2,),
+ (1000,),
+ ([-1000, 0, 1, 1.5, 2, 1000],),
+ ],
+ ids=fmt_array,
+)
+def test_lookupNd(ndim, arg):
+ lookup = LookupNd(*lookup_args(ndim))
+
+ args = [arg for _ in range(ndim)]
+ expected_arr_idx = np.floor([
+ np.clip(arg, 0, dim + SMALLEST_DIM_SIZE - 1) for dim in range(ndim)
+ ]).astype(int)
+ expected_val = np.ravel_multi_index(
+ tuple(expected_arr_idx), lookup.cells.shape
+ )
+
+ response = lookup(*args)
+ assert np.all(response == expected_val)
+
+
+def test_lookup2d_roundtrip():
+ digs, cells = lookup_args(2)
+ lookup = Lookup2d(*digs, cells=cells)
+ df = lookup.to_dataframe()
+ lookup2 = Lookup2d.from_dataframe(df)
+ assert np.allclose(lookup.cells, lookup2.cells)
+ for b1, b2 in zip(lookup.axes, lookup2.axes):
+ assert b1 == b2
+
+
+def prepare_lookupdistdotbuilder(neurons, alpha=False, k=5):
+ dotprops = [Dotprops(n.nodes[["x", "y", "z"]], k) for n in neurons]
+ n_orig = len(dotprops)
+
+ # make jittered copies of these neurons
+ rng = np.random.default_rng(SEED)
+ jitter_sigma = 50
+ matching_lists = []
+ for idx, dp in enumerate(dotprops[:]):
+ dotprops.append(
+ Dotprops(
+ dp.points + rng.normal(0, jitter_sigma, dp.points.shape), k
+ )
+ )
+ # assign each neuron its jittered self as a match
+ matching_lists.append([idx, idx + n_orig])
+
+ # original neurons should all not match each other
+ nonmatching = list(range(n_orig))
+
+ # max distance between any 2 points in the data
+ # for calculating dist boundaries
+ max_dist = np.linalg.norm(
+ np.ptp(
+ np.concatenate([dp.points for dp in dotprops], axis=0), axis=0,
+ )
+ )
+
+ return LookupDistDotBuilder(
+ dotprops,
+ matching_lists,
+ nonmatching,
+ alpha,
+ seed=SEED + 1,
+ ).with_digitizers([
+ Digitizer.from_geom(10, max_dist, 5),
+ Digitizer.from_linear(0, 1, 5),
+ ])
+
+
+@pytest.mark.parametrize(["alpha"], [(True,), (False,)])
+@pytest.mark.parametrize(["threads"], [(0,), (2,), (None,)])
+def test_lookupdistdotbuilder_builds(treeneuron_dfs, threads, alpha):
+ builder = prepare_lookupdistdotbuilder(treeneuron_dfs, alpha)
+ lookup = builder.build(threads)
+ # `pytest -rP` to see output
+ print(lookup.to_dataframe())