Skip to content

Commit

Permalink
Merge pull request #73 from ngoldbaum/vstack-fix
Browse files Browse the repository at this point in the history
add missing axis argument to np.stack call in ustack
  • Loading branch information
Nathan Goldbaum authored Mar 13, 2019
2 parents e194826 + 556adb5 commit e9bbdf5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion unyt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,7 +2141,7 @@ def ustack(arrs, axis=0):
[[1 2 3]
[2 3 4]] km
"""
v = np.stack(arrs)
v = np.stack(arrs, axis=axis)
v = _validate_numpy_wrapper_units(v, arrs)
return v

Expand Down
14 changes: 14 additions & 0 deletions unyt/tests/test_unyt_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
udot,
uintersect1d,
unorm,
ustack,
uunion1d,
uvstack,
loadtxt,
savetxt,
)
Expand Down Expand Up @@ -1863,9 +1865,12 @@ def test_numpy_wrappers():
a1 = unyt_array([1, 2, 3], "cm")
a2 = unyt_array([2, 3, 4, 5, 6], "cm")
a3 = unyt_array([[1, 2, 3], [4, 5, 6]], "cm")
a4 = unyt_array([7, 8, 9, 10, 11], "cm")
catenate_answer = [1, 2, 3, 2, 3, 4, 5, 6]
intersect_answer = [2, 3]
union_answer = [1, 2, 3, 4, 5, 6]
vstack_answer = [[2, 3, 4, 5, 6], [7, 8, 9, 10, 11]]
vstack_answer_last_axis = [[2, 7], [3, 8], [4, 9], [5, 10], [6, 11]]
cross_answer = [-2, 4, -2]
norm_answer = np.sqrt(1 ** 2 + 2 ** 2 + 3 ** 2)
arr_norm_answer = [norm_answer, np.sqrt(4 ** 2 + 5 ** 2 + 6 ** 2)]
Expand Down Expand Up @@ -1897,6 +1902,15 @@ def test_numpy_wrappers():
uconcatenate((a1, a2.v))
with pytest.raises(RuntimeError):
uconcatenate((a1.to("m"), a2))
assert_array_equal(unyt_array(vstack_answer, "cm"), uvstack([a2, a4]))
assert_array_equal(vstack_answer, np.vstack([a2, a4]))
assert_array_equal(unyt_array(vstack_answer, "cm"), ustack([a2, a4]))
assert_array_equal(vstack_answer, np.stack([a2, a4]))

assert_array_equal(
unyt_array(vstack_answer_last_axis, "cm"), ustack([a2, a4], axis=-1)
)
assert_array_equal(vstack_answer_last_axis, np.stack([a2, a4], axis=-1))


def test_dimensionless_conversion():
Expand Down

0 comments on commit e9bbdf5

Please sign in to comment.