diff --git a/brainunit/__init__.py b/brainunit/__init__.py index e218c83..188b603 100644 --- a/brainunit/__init__.py +++ b/brainunit/__init__.py @@ -30,7 +30,7 @@ from ._unit_constants import __all__ as _constants_all from ._unit_shortcuts import * from ._unit_shortcuts import __all__ as _std_units_all -import _matplotlib_compat +from ._matplotlib_compat import * __all__ = ['math'] + _common_all + _std_units_all + _constants_all + _base_all + _celsius_all del _common_all, _std_units_all, _constants_all, _base_all, _celsius_all, _matplotlib_compat diff --git a/brainunit/_matplotlib_compat.py b/brainunit/_matplotlib_compat.py index e529006..c955692 100644 --- a/brainunit/_matplotlib_compat.py +++ b/brainunit/_matplotlib_compat.py @@ -16,11 +16,10 @@ from __future__ import annotations import importlib.util -from contextlib import ContextDecorator import numpy as np -from ._base import Unit, Quantity +from ._base import Quantity, fail_for_dimension_mismatch from ._unit_common import radian matplotlib_installed = importlib.util.find_spec('matplotlib') is not None @@ -47,11 +46,6 @@ def rad_fn( class MplQuantityConverter(units.ConversionInterface): - def __init__(self): - # Keep track of original converter in case the context manager is - # used in a nested way. - self._original_converter = {Quantity: units.registry.get(Quantity)} - units.registry[Quantity] = self @staticmethod def axisinfo(unit, axis): @@ -68,9 +62,16 @@ def axisinfo(unit, axis): @staticmethod def convert(val, unit, axis): if isinstance(val, Quantity): + # check dimension + fail_for_dimension_mismatch(val.unit, unit) + # check unit + if val.unit != unit: + # scale to target unit + return val.to(unit).mantissa return val.mantissa elif isinstance(val, list) and val and isinstance(val[0], Quantity): - return [v.mantissa for v in val] + fail_for_dimension_mismatch(val[0].unit, unit) + return [v.to(unit).mantissa if v.unit != unit else v.mantissa for v in val] else: return val @@ -80,11 +81,5 @@ def default_units(x, axis): return x.unit return None - def __enter__(self): - return self - def __exit__(self, type, value, tb): - if self._original_converter[Quantity] is None: - del units.registry[Quantity] - else: - units.registry[Quantity] = self._original_converter[Quantity] + units.registry[Quantity] = MplQuantityConverter() diff --git a/brainunit/_matplotlib_compat_test.py b/brainunit/_matplotlib_compat_test.py index 24de767..4d95fbb 100644 --- a/brainunit/_matplotlib_compat_test.py +++ b/brainunit/_matplotlib_compat_test.py @@ -4,36 +4,29 @@ try: import matplotlib.pyplot as plt + from matplotlib.units import ConversionError except ImportError: pytest.skip("matplotlib is not installed", allow_module_level=True) def test_quantity_support(): - with u.matplotlib_support_quantity(): - plt.figure() - plt.plot([1, 2, 3] * u.meter) - plt.show() + plt.figure() + plt.plot([1, 2, 3] * u.meter) + plt.show() - plt.cla() - plt.plot([101, 125, 150] * u.cmeter) - plt.show() + plt.cla() + plt.plot([101, 125, 150] * u.cmeter) + plt.show() - plt.cla() - plt.plot([101, 125, 150] * u.ms, [101, 125, 150] * u.cmeter) - plt.plot([0.1, 0.15, 0.2] * u.second, [111, 135, 160] * u.cmeter) - plt.show() + plt.cla() + plt.plot([101, 125, 150] * u.ms, [101, 125, 150] * u.cmeter) + plt.plot([0.1, 0.15, 0.2] * u.second, [111, 135, 160] * u.cmeter) + plt.show() + with pytest.raises(ConversionError): plt.cla() plt.plot([101, 125, 150] * u.ms, [101, 125, 150] * u.cmeter) plt.plot([0.1, 0.15, 0.2] * u.second, [111, 135, 160] * u.cmeter) plt.plot([0.1, 0.15, 0.2] * u.second, [131, 155, 180] * u.mA) plt.show() - with pytest.raises(TypeError): - plt.figure() - plt.plot([1, 2, 3] * u.meter) - plt.show() - - plt.cla() - plt.plot([101, 125, 150] * u.cmeter, [1, 2, 3] * u.kgram) - plt.show()