Skip to content

Commit

Permalink
Change start_index's default value to 0 from 1.
Browse files Browse the repository at this point in the history
Change start_index's default value to 0 from 1. The motivation
for this is to have more Pythonic code, and for increased consistency
with the rest of the updated modules. This implements the suggestion
made by @benfred towards merging the implementation of
issue #1074's feature request.

This change includes modifying what Categorify()'s _encode()
method does, default arg value updates in some other methods,
docstring updating, and updating Categorify()'s tests that
test start_index's behavior.
  • Loading branch information
Adam Lesnikowski committed Sep 10, 2021
1 parent 51ad8a2 commit 0425246
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
21 changes: 12 additions & 9 deletions nvtabular/ops/categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,15 @@ class Categorify(StatOperator):
value will be `max_size - num_buckets -1`. Setting the max_size param means that
freq_threshold should not be given. If the num_buckets parameter is set, it must be
smaller than the max_size value.
start_index: int, default 1
start_index: int, default 0
The start index where Categorify will begin to translate dataframe entries
into integer values. For instance, if our original translated dataframe entries appear
as [[1], [1, 4], [3, 2], [2]], then with a start_index of 16, Categorify will now be
[[16], [16, 19], [18, 17], [17]]. This option is useful to reserve an intial segment
of non-negative translated integers for out-of-vocabulary or other special values.
into integer values, including an initial out-of-vocabulary encoding value.
For instance, if our original translated dataframe entries appear
as [[1], [1, 4], [3, 2], [2]], with an out-of-vocabulary value of 0, then with a
start_index of 16, Categorify will reserve 16 as the out-of-vocabulary encoding value,
and our new translated dataframe entries will now be [[17], [17, 20], [19, 18], [18]].
This parameter is useful to reserve an initial segment of non-negative translated integers
for special user-defined values.
"""

def __init__(
Expand All @@ -197,7 +200,7 @@ def __init__(
num_buckets=None,
vocabs=None,
max_size=0,
start_index=1,
start_index=0,
):

# We need to handle three types of encoding here:
Expand Down Expand Up @@ -1074,7 +1077,7 @@ def _encode(
cat_names=None,
max_size=0,
dtype=None,
start_index=1,
start_index=0,
):
"""The _encode method is responsible for transforming a dataframe by taking the written
out vocabulary file and looking up values to translate inputs to numeric
Expand All @@ -1090,7 +1093,7 @@ def _encode(
na_sentinel : int
Sentinel for NA value. Defaults to -1.
freq_threshold : int
Cateogires with a count or frequency below this threshold will
Categories with a count or frequency below this threshold will
be omitted from the encoding and corresponding data will be
mapped to the "Null" category. Defaults to 0.
search_sorted :
Expand Down Expand Up @@ -1205,7 +1208,7 @@ def _encode(
elif dtype:
labels = labels.astype(dtype, copy=False)

labels = labels + (start_index - 1)
labels = labels + start_index
return labels


Expand Down
15 changes: 9 additions & 6 deletions tests/unit/ops/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,14 +528,17 @@ def test_categorify_lists_with_start_index(tmpdir, cpu, start_index):
else:
compare = df_out["Authors"].to_arrow().to_pylist()

if start_index == 1:
# Note that start_index is the start_index of the range of encoding, which
# includes both an initial value for the encoding for out-of-vocabulary items,
# as well as the values for the rest of the in-vocabulary items.
# In this group of tests below, there are no out-of-vocabulary items, so our start index
# value does not appear in the expected comparison object.
if start_index == 0:
assert compare == [[1], [1, 4], [3, 2], [2]]

if start_index == 2:
elif start_index == 1:
assert compare == [[2], [2, 5], [4, 3], [3]]

if start_index == 16:
assert compare == [[16], [16, 19], [18, 17], [17]]
elif start_index == 16:
assert compare == [[17], [17, 20], [19, 18], [18]]


@pytest.mark.parametrize("cat_names", [[["Author", "Engaging User"]], ["Author", "Engaging User"]])
Expand Down

0 comments on commit 0425246

Please sign in to comment.