diff --git a/cmeutils/tests/test_visualize.py b/cmeutils/tests/test_visualize.py index f0374a9..650d8df 100644 --- a/cmeutils/tests/test_visualize.py +++ b/cmeutils/tests/test_visualize.py @@ -51,6 +51,10 @@ def test_set_bad_frame(self, p3ht_fresnel): with pytest.raises(ValueError): p3ht_fresnel.frame = 10 + def test_set_bad_view(self, p3ht_fresnel): + with pytest.raises(ValueError): + p3ht_fresnel.view_axis = 10 + def test_bad_color_dict(self, p3ht_fresnel): with pytest.raises(ValueError): p3ht_fresnel.color_dict = np.array([0.1, 0.1, 0.1]) diff --git a/cmeutils/visualize.py b/cmeutils/visualize.py index 73ec5a2..a128886 100644 --- a/cmeutils/visualize.py +++ b/cmeutils/visualize.py @@ -226,7 +226,7 @@ def view_axis(self): def view_axis(self, value): # TODO Assert is 1,3 array new_view_axis = np.asarray(value) - if len(new_view_axis) != 0: + if new_view_axis.shape != (3,): raise ValueError("View axis must be a 3x1 array") self._view_axis = np.asarray(new_view_axis)