From 556adb5662bd9853be5815b5b1d53736e99e2c3f Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Wed, 13 Mar 2019 09:28:08 -0500 Subject: [PATCH] add missing axis argument to np.stack call in ustack --- unyt/array.py | 2 +- unyt/tests/test_unyt_array.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/unyt/array.py b/unyt/array.py index 863fed64..004a7f02 100644 --- a/unyt/array.py +++ b/unyt/array.py @@ -2141,7 +2141,7 @@ def ustack(arrs, axis=0): [[1 2 3] [2 3 4]] km """ - v = np.stack(arrs) + v = np.stack(arrs, axis=axis) v = _validate_numpy_wrapper_units(v, arrs) return v diff --git a/unyt/tests/test_unyt_array.py b/unyt/tests/test_unyt_array.py index 7a34e27f..545fb774 100644 --- a/unyt/tests/test_unyt_array.py +++ b/unyt/tests/test_unyt_array.py @@ -43,7 +43,9 @@ udot, uintersect1d, unorm, + ustack, uunion1d, + uvstack, loadtxt, savetxt, ) @@ -1863,9 +1865,12 @@ def test_numpy_wrappers(): a1 = unyt_array([1, 2, 3], "cm") a2 = unyt_array([2, 3, 4, 5, 6], "cm") a3 = unyt_array([[1, 2, 3], [4, 5, 6]], "cm") + a4 = unyt_array([7, 8, 9, 10, 11], "cm") catenate_answer = [1, 2, 3, 2, 3, 4, 5, 6] intersect_answer = [2, 3] union_answer = [1, 2, 3, 4, 5, 6] + vstack_answer = [[2, 3, 4, 5, 6], [7, 8, 9, 10, 11]] + vstack_answer_last_axis = [[2, 7], [3, 8], [4, 9], [5, 10], [6, 11]] cross_answer = [-2, 4, -2] norm_answer = np.sqrt(1 ** 2 + 2 ** 2 + 3 ** 2) arr_norm_answer = [norm_answer, np.sqrt(4 ** 2 + 5 ** 2 + 6 ** 2)] @@ -1897,6 +1902,15 @@ def test_numpy_wrappers(): uconcatenate((a1, a2.v)) with pytest.raises(RuntimeError): uconcatenate((a1.to("m"), a2)) + assert_array_equal(unyt_array(vstack_answer, "cm"), uvstack([a2, a4])) + assert_array_equal(vstack_answer, np.vstack([a2, a4])) + assert_array_equal(unyt_array(vstack_answer, "cm"), ustack([a2, a4])) + assert_array_equal(vstack_answer, np.stack([a2, a4])) + + assert_array_equal( + unyt_array(vstack_answer_last_axis, "cm"), ustack([a2, a4], axis=-1) + ) + assert_array_equal(vstack_answer_last_axis, np.stack([a2, a4], axis=-1)) def test_dimensionless_conversion():