Skip to content

Commit

Permalink
Feat: Ensure all features return built-in types
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Sep 30, 2022
1 parent acbba0f commit 384cf12
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 4 deletions.
29 changes: 26 additions & 3 deletions neurom/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
"""
import operator
from enum import Enum
from functools import reduce
from functools import reduce, wraps

import numpy as np

from neurom.core import Population, Morphology, Neurite
from neurom.core.morphology import iter_neurites
Expand Down Expand Up @@ -174,15 +176,36 @@ def _register_feature(namespace: NameSpace, name, func, shape):
def feature(shape, namespace: NameSpace, name=None):
"""Feature decorator to automatically register the feature in the appropriate namespace.
This decorator also ensure that the results of the features are casted to built-in types.
Arguments:
shape(tuple): the expected shape of the feature values
namespace(string): a namespace, see :class:`NameSpace`
name(string): name of the feature, used to access the feature via `neurom.features.get()`.
"""

def inner(func):
_register_feature(namespace, name or func.__name__, func, shape)
return func

@wraps(func)
def scalar_wrapper(*args, **kwargs):
res = func(*args, **kwargs)
try:
return res.tolist()
except AttributeError:
return res

@wraps(func)
def matrix_wrapper(*args, **kwargs):
res = func(*args, **kwargs)
return np.array(res).tolist()

if shape == ():
decorated_func = scalar_wrapper
else:
decorated_func = matrix_wrapper

_register_feature(namespace, name or func.__name__, decorated_func, shape)
return decorated_func

return inner

Expand Down
43 changes: 43 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Configuration for the pytest test suite."""
from pathlib import Path

import pytest
import neurom as nm


@pytest.fixture
def DATA_PATH():
return Path(__file__).parent / "data"


@pytest.fixture
def SWC_PATH(DATA_PATH):
return DATA_PATH / "swc"


@pytest.fixture
def MORPHOLOGY(SWC_PATH):
return nm.load_morphology(SWC_PATH / "test_morph.swc")


@pytest.fixture
def NEURITE(MORPHOLOGY):
return MORPHOLOGY.neurites[0]


@pytest.fixture
def SECTION(NEURITE):
return NEURITE.sections[0]


@pytest.fixture
def NRN_FILES(DATA_PATH):
return [
DATA_PATH / "h5/v1" / f
for f in ("Neuron.h5", "Neuron_2_branch.h5", "bio_neuron-001.h5")
]


@pytest.fixture
def POP(NRN_FILES):
return nm.load_morphologies(NRN_FILES)
59 changes: 59 additions & 0 deletions tests/features/test_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from itertools import chain

import numpy as np
import pytest
from numpy import testing as npt

from neurom import features


def _check_nested_type(data):
"""Check that the given data contains only built-in types.
The data should either be an int or float, or a list or tuple of ints or floats.
"""
if isinstance(data, (list, tuple)):
for i in data:
_check_nested_type(i)
else:
assert isinstance(data, (int, float))


class TestFeatureTypes:
"""Test that all features return raw Python types."""

@pytest.mark.parametrize(
"feature_name",
[
pytest.param(name, id=f"Test type of {name} neurite feature")
for name in features._NEURITE_FEATURES
]
)
def test_neurite_feature_types(self, feature_name, NEURITE):
"""Test neurite features."""
res = features._NEURITE_FEATURES.get(feature_name)(NEURITE)
_check_nested_type(res)

@pytest.mark.parametrize(
"feature_name",
[
pytest.param(name, id=f"Test type of {name} morphology feature")
for name in features._MORPHOLOGY_FEATURES
]
)
def test_morphology_feature_types(self, feature_name, MORPHOLOGY):
"""Test morphology features."""
res = features._MORPHOLOGY_FEATURES.get(feature_name)(MORPHOLOGY)
_check_nested_type(res)

@pytest.mark.parametrize(
"feature_name",
[
pytest.param(name, id=f"Test type of {name} population feature")
for name in features._POPULATION_FEATURES
]
)
def test_population_feature_types(self, feature_name, POP):
"""Test population features."""
res = features._POPULATION_FEATURES.get(feature_name)(POP)
_check_nested_type(res)
2 changes: 1 addition & 1 deletion tests/features/test_neurite.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_section_end_distances():

def test_section_partition_pairs():
part_pairs = [neurite.partition_pairs(s) for s in SIMPLE.neurites]
assert part_pairs == [[(1.0, 1.0)], [(1.0, 1.0)]]
assert part_pairs == [[[1.0, 1.0]], [[1.0, 1.0]]]


def test_section_bif_radial_distances():
Expand Down

0 comments on commit 384cf12

Please sign in to comment.