Skip to content

Commit

Permalink
feat: Add support for importing _matplotlib_compat module
Browse files Browse the repository at this point in the history
This commit adds support for importing the `_matplotlib_compat` module in the `brainunit` package. The module is now imported using the wildcard import syntax, allowing easier access to its contents. This change improves the usability and flexibility of the `brainunit` package.
  • Loading branch information
Routhleck committed Oct 28, 2024
1 parent ea49008 commit e797109
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 35 deletions.
2 changes: 1 addition & 1 deletion brainunit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ._unit_constants import __all__ as _constants_all
from ._unit_shortcuts import *
from ._unit_shortcuts import __all__ as _std_units_all
import _matplotlib_compat
from ._matplotlib_compat import *

__all__ = ['math'] + _common_all + _std_units_all + _constants_all + _base_all + _celsius_all
del _common_all, _std_units_all, _constants_all, _base_all, _celsius_all, _matplotlib_compat
25 changes: 10 additions & 15 deletions brainunit/_matplotlib_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
from __future__ import annotations

import importlib.util
from contextlib import ContextDecorator

import numpy as np

from ._base import Unit, Quantity
from ._base import Quantity, fail_for_dimension_mismatch
from ._unit_common import radian

matplotlib_installed = importlib.util.find_spec('matplotlib') is not None
Expand All @@ -47,11 +46,6 @@ def rad_fn(


class MplQuantityConverter(units.ConversionInterface):
def __init__(self):
# Keep track of original converter in case the context manager is
# used in a nested way.
self._original_converter = {Quantity: units.registry.get(Quantity)}
units.registry[Quantity] = self

@staticmethod
def axisinfo(unit, axis):
Expand All @@ -68,9 +62,16 @@ def axisinfo(unit, axis):
@staticmethod
def convert(val, unit, axis):
if isinstance(val, Quantity):
# check dimension
fail_for_dimension_mismatch(val.unit, unit)
# check unit
if val.unit != unit:
# scale to target unit
return val.to(unit).mantissa
return val.mantissa
elif isinstance(val, list) and val and isinstance(val[0], Quantity):
return [v.mantissa for v in val]
fail_for_dimension_mismatch(val[0].unit, unit)
return [v.to(unit).mantissa if v.unit != unit else v.mantissa for v in val]
else:
return val

Expand All @@ -80,11 +81,5 @@ def default_units(x, axis):
return x.unit
return None

def __enter__(self):
return self

def __exit__(self, type, value, tb):
if self._original_converter[Quantity] is None:
del units.registry[Quantity]
else:
units.registry[Quantity] = self._original_converter[Quantity]
units.registry[Quantity] = MplQuantityConverter()
31 changes: 12 additions & 19 deletions brainunit/_matplotlib_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,29 @@

try:
import matplotlib.pyplot as plt
from matplotlib.units import ConversionError
except ImportError:
pytest.skip("matplotlib is not installed", allow_module_level=True)


def test_quantity_support():
with u.matplotlib_support_quantity():
plt.figure()
plt.plot([1, 2, 3] * u.meter)
plt.show()
plt.figure()
plt.plot([1, 2, 3] * u.meter)
plt.show()

plt.cla()
plt.plot([101, 125, 150] * u.cmeter)
plt.show()
plt.cla()
plt.plot([101, 125, 150] * u.cmeter)
plt.show()

plt.cla()
plt.plot([101, 125, 150] * u.ms, [101, 125, 150] * u.cmeter)
plt.plot([0.1, 0.15, 0.2] * u.second, [111, 135, 160] * u.cmeter)
plt.show()
plt.cla()
plt.plot([101, 125, 150] * u.ms, [101, 125, 150] * u.cmeter)
plt.plot([0.1, 0.15, 0.2] * u.second, [111, 135, 160] * u.cmeter)
plt.show()

with pytest.raises(ConversionError):
plt.cla()
plt.plot([101, 125, 150] * u.ms, [101, 125, 150] * u.cmeter)
plt.plot([0.1, 0.15, 0.2] * u.second, [111, 135, 160] * u.cmeter)
plt.plot([0.1, 0.15, 0.2] * u.second, [131, 155, 180] * u.mA)
plt.show()

with pytest.raises(TypeError):
plt.figure()
plt.plot([1, 2, 3] * u.meter)
plt.show()

plt.cla()
plt.plot([101, 125, 150] * u.cmeter, [1, 2, 3] * u.kgram)
plt.show()

0 comments on commit e797109

Please sign in to comment.