Skip to content

Commit

Permalink
Add tests for missed branches.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed May 9, 2018
1 parent 0a6b906 commit c7a7ca3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sparse/coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def __init__(self, coords, data=None, shape=None, has_duplicates=True,
else:
dtype = np.uint8
self.coords = self.coords.astype(dtype)
assert not self.shape or len(data) == self.coords.shape[1]
assert not self.shape or (len(data) == self.coords.shape[1] and
len(shape) == self.coords.shape[0])

if not sorted:
self._sort_indices()
Expand Down Expand Up @@ -397,7 +398,8 @@ def from_iter(cls, x, shape=None):
coords = np.empty((ndim, 0), dtype=np.uint8)
data = np.empty((0,))

return COO(coords, data, shape=(), sorted=True, has_duplicates=False)
return COO(coords, data, shape=() if shape is None else shape,
sorted=True, has_duplicates=False)

if not isinstance(x[0][0], Iterable):
coords = np.stack(x[1], axis=0)
Expand Down
10 changes: 10 additions & 0 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,3 +1524,13 @@ def test_invalid_shape_error():

with pytest.raises(ValueError):
COO(s, shape=(2, 3))


def test_invalid_iterable_error():
with pytest.raises(ValueError):
x = [(3, 4, 5)]
COO.from_iter(x)

with pytest.raises(ValueError):
x = [((2.3, 4.5), 3.2)]
COO.from_iter(x)

0 comments on commit c7a7ca3

Please sign in to comment.