Skip to content

Commit

Permalink
feat: support subtraction between WeightedSum views
Browse files Browse the repository at this point in the history
  • Loading branch information
henryiii committed Sep 13, 2021
1 parent 830c6da commit 4950bf1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/boost_histogram/_internal/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ def __array_ufunc__(

# Addition of two views
if input_0.dtype == input_1.dtype:
if ufunc in {np.add}:
if ufunc in {np.add, np.subtract}:
ufunc(
input_0["value"],
input_1["value"],
out=result["value"],
**kwargs,
)
ufunc(
np.add(
input_0["variance"],
input_1["variance"],
out=result["variance"],
Expand Down
17 changes: 8 additions & 9 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,6 @@ def test_add_2d_w(flow):


def test_sub_2d(flow, count_storage):
if count_storage in {bh.storage.AtomicInt64, bh.storage.Weight}:
pytest.skip("Storage does not support subtraction")

h0 = bh.Histogram(
bh.axis.Integer(-1, 2, underflow=flow, overflow=flow),
Expand All @@ -423,15 +421,16 @@ def test_sub_2d(flow, count_storage):

m = h0.values(flow=True).copy()

h = h0.copy()
h -= h0
assert h.values(flow=True) == approx(m * 0)
if count_storage not in {bh.storage.AtomicInt64, bh.storage.Weight}:
h = h0.copy()
h -= h0
assert h.values(flow=True) == approx(m * 0)

h -= h0
assert h.values(flow=True) == approx(-m)
h -= h0
assert h.values(flow=True) == approx(-m)

h2 = h0 - (h0 + h0 + h0)
assert h2.values(flow=True) == approx(-2 * m)
h2 = h0 - (h0 + h0 + h0)
assert h2.values(flow=True) == approx(-2 * m)

h3 = h0 - h0.view(flow=True) * 4
assert h3.values(flow=True) == approx(-3 * m)
Expand Down
22 changes: 16 additions & 6 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,14 @@ def test_view_add(v):
assert_allclose(v2.value, [2, 5, 4, 3])
assert_allclose(v2.variance, [4, 7, 6, 5])

v += 2
assert_allclose(v.value, [2, 5, 4, 3])
assert_allclose(v.variance, [4, 7, 6, 5])
v2 = v.copy()
v2 += 2
assert_allclose(v2.value, [2, 5, 4, 3])
assert_allclose(v2.variance, [4, 7, 6, 5])

v2 = v + v
assert_allclose(v2.value, v.value * 2)
assert_allclose(v2.variance, v.variance * 2)


def test_view_sub(v):
Expand All @@ -83,9 +88,14 @@ def test_view_sub(v):
assert_allclose(v2.value, [1, -2, -1, 0])
assert_allclose(v2.variance, [1, 4, 3, 2])

v -= 2
assert_allclose(v.value, [-2, 1, 0, -1])
assert_allclose(v.variance, [4, 7, 6, 5])
v2 = v.copy()
v2 -= 2
assert_allclose(v2.value, [-2, 1, 0, -1])
assert_allclose(v2.variance, [4, 7, 6, 5])

v2 = v - v
assert_allclose(v2.value, [0, 0, 0, 0])
assert_allclose(v2.variance, v.variance * 2)


def test_view_unary(v):
Expand Down

0 comments on commit 4950bf1

Please sign in to comment.