From 4950bf1f52562a9f0fe337f613724637af5da578 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Sun, 12 Sep 2021 23:44:57 -0400 Subject: [PATCH] feat: support subtraction between WeightedSum views --- src/boost_histogram/_internal/view.py | 4 ++-- tests/test_histogram.py | 17 ++++++++--------- tests/test_views.py | 22 ++++++++++++++++------ 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/boost_histogram/_internal/view.py b/src/boost_histogram/_internal/view.py index 788704bf..b06ab245 100644 --- a/src/boost_histogram/_internal/view.py +++ b/src/boost_histogram/_internal/view.py @@ -130,14 +130,14 @@ def __array_ufunc__( # Addition of two views if input_0.dtype == input_1.dtype: - if ufunc in {np.add}: + if ufunc in {np.add, np.subtract}: ufunc( input_0["value"], input_1["value"], out=result["value"], **kwargs, ) - ufunc( + np.add( input_0["variance"], input_1["variance"], out=result["variance"], diff --git a/tests/test_histogram.py b/tests/test_histogram.py index 108324f7..7478cd01 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -404,8 +404,6 @@ def test_add_2d_w(flow): def test_sub_2d(flow, count_storage): - if count_storage in {bh.storage.AtomicInt64, bh.storage.Weight}: - pytest.skip("Storage does not support subtraction") h0 = bh.Histogram( bh.axis.Integer(-1, 2, underflow=flow, overflow=flow), @@ -423,15 +421,16 @@ def test_sub_2d(flow, count_storage): m = h0.values(flow=True).copy() - h = h0.copy() - h -= h0 - assert h.values(flow=True) == approx(m * 0) + if count_storage not in {bh.storage.AtomicInt64, bh.storage.Weight}: + h = h0.copy() + h -= h0 + assert h.values(flow=True) == approx(m * 0) - h -= h0 - assert h.values(flow=True) == approx(-m) + h -= h0 + assert h.values(flow=True) == approx(-m) - h2 = h0 - (h0 + h0 + h0) - assert h2.values(flow=True) == approx(-2 * m) + h2 = h0 - (h0 + h0 + h0) + assert h2.values(flow=True) == approx(-2 * m) h3 = h0 - h0.view(flow=True) * 4 assert h3.values(flow=True) == approx(-3 * m) diff --git a/tests/test_views.py b/tests/test_views.py index c2f5c746..95c6cb95 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -65,9 +65,14 @@ def test_view_add(v): assert_allclose(v2.value, [2, 5, 4, 3]) assert_allclose(v2.variance, [4, 7, 6, 5]) - v += 2 - assert_allclose(v.value, [2, 5, 4, 3]) - assert_allclose(v.variance, [4, 7, 6, 5]) + v2 = v.copy() + v2 += 2 + assert_allclose(v2.value, [2, 5, 4, 3]) + assert_allclose(v2.variance, [4, 7, 6, 5]) + + v2 = v + v + assert_allclose(v2.value, v.value * 2) + assert_allclose(v2.variance, v.variance * 2) def test_view_sub(v): @@ -83,9 +88,14 @@ def test_view_sub(v): assert_allclose(v2.value, [1, -2, -1, 0]) assert_allclose(v2.variance, [1, 4, 3, 2]) - v -= 2 - assert_allclose(v.value, [-2, 1, 0, -1]) - assert_allclose(v.variance, [4, 7, 6, 5]) + v2 = v.copy() + v2 -= 2 + assert_allclose(v2.value, [-2, 1, 0, -1]) + assert_allclose(v2.variance, [4, 7, 6, 5]) + + v2 = v - v + assert_allclose(v2.value, [0, 0, 0, 0]) + assert_allclose(v2.variance, v.variance * 2) def test_view_unary(v):