Skip to content

Commit

Permalink
Fix factorizing some more.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 30, 2022
1 parent 5646179 commit 7bcf77b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 12 deletions.
21 changes: 10 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def factorize_(
factorized = []
found_groups = []
for groupvar, expect in zip(by, expected_groups):
flat = groupvar.ravel()
if isinstance(expect, pd.IntervalIndex):
# when binning we change expected groups to integers marking the interval
# this makes the reindexing logic simpler.
Expand All @@ -432,21 +433,19 @@ def factorize_(
if groupvar.dtype.kind == "M":
expect = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
# code is -1 for values outside the bounds of all intervals
idx = pd.cut(groupvar.ravel(), bins=expect).codes.copy()
idx = pd.cut(flat, bins=expect).codes.copy()
else:
if expect is not None and reindex:
groups = expect
sorter = np.argsort(expect)
groups = expect[(sorter,)] if sort else expect
idx = np.searchsorted(expect, flat, sorter=sorter)
mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
if not sort:
sorter = np.argsort(expect)
else:
sorter = None
idx = np.searchsorted(expect, groupvar.ravel(), sorter=sorter)
mask = isnull(groupvar.ravel()) | (idx == len(expect))
# TODO: optimize?
# idx is the index in to the sorted array.
# if we didn't want sorting, unsort it back
idx[(idx == len(expect),)] = -1
idx = sorter[(idx,)]
idx[mask] = -1
if not sort:
idx = sorter[idx]
idx[mask] = -1
else:
idx, groups = pd.factorize(groupvar.ravel(), sort=sort)

Expand Down
2 changes: 1 addition & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def assert_equal(a, b):
np.testing.assert_allclose(a, b, equal_nan=True)


@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
@pytest.fixture(scope="module", params=["flox"])
def engine(request):
if request.param == "numba":
try:
Expand Down
59 changes: 59 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,62 @@ def test_factorize_values_outside_bins():
actual = vals[0]
expected = np.array([[-1, -1], [-1, 0], [6, 12], [18, 24], [-1, -1]])
assert_equal(expected, actual)


def test_multiple_groupers():
actual, *_ = groupby_reduce(
np.ones((5, 2)),
np.arange(10).reshape(5, 2),
np.arange(10).reshape(5, 2),
axis=(0, 1),
expected_groups=(
pd.IntervalIndex.from_breaks(np.arange(2, 8, 1)),
pd.IntervalIndex.from_breaks(np.arange(2, 8, 1)),
),
reindex=True,
func="count",
)
expected = np.eye(5, 5)
assert_equal(expected, actual)


def test_factorize_reindex_sorting_strings():
kwargs = dict(
by=(np.array(["El-Nino", "La-Nina", "boo", "Neutral"]),),
axis=-1,
expected_groups=(np.array(["El-Nino", "Neutral", "foo", "La-Nina"]),),
)

expected = factorize_(**kwargs, reindex=True, sort=True)[0]
assert_equal(expected, [0, 1, 4, 2])

expected = factorize_(**kwargs, reindex=True, sort=False)[0]
assert_equal(expected, [0, 3, 4, 1])

expected = factorize_(**kwargs, reindex=False, sort=False)[0]
assert_equal(expected, [0, 1, 2, 3])

expected = factorize_(**kwargs, reindex=False, sort=True)[0]
assert_equal(expected, [0, 1, 3, 2])


def test_factorize_reindex_sorting_ints():
kwargs = dict(
by=(np.array([-10, 1, 10, 2, 3, 5]),),
axis=-1,
expected_groups=(np.array([0, 1, 2, 3, 4, 5]),),
)

expected = factorize_(**kwargs, reindex=True, sort=True)[0]
assert_equal(expected, [6, 1, 6, 2, 3, 5])

expected = factorize_(**kwargs, reindex=True, sort=False)[0]
assert_equal(expected, [6, 1, 6, 2, 3, 5])

kwargs["expected_groups"] = (np.arange(5, -1, -1),)

expected = factorize_(**kwargs, reindex=True, sort=True)[0]
assert_equal(expected, [6, 1, 6, 2, 3, 5])

expected = factorize_(**kwargs, reindex=True, sort=False)[0]
assert_equal(expected, [6, 4, 6, 3, 2, 0])

0 comments on commit 7bcf77b

Please sign in to comment.