Skip to content

Commit

Permalink
Add tests for issue #1074 and update Categorify args
Browse files Browse the repository at this point in the history
Add tests for issue #1074 on adding a start_index argument to
the Categorify op. The motivation for this issue is to allow an
offset for encoding out-of-vocabulary and other special values.

In this commit, we add a test function to test various values of our
new start_index arg in tests/unit/ops/test_ops.py and add a
start_index arg to the class signature for Categorify(),
documentation for its intended function in Categorify()'s
docstring, a start_index arg to FitOptions(), and documentation
for this new argument in FitOptions()'s docstring.
  • Loading branch information
Adam Lesnikowski committed Sep 2, 2021
1 parent b74a1a0 commit 642e5f5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
10 changes: 10 additions & 0 deletions nvtabular/ops/categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ class Categorify(StatOperator):
value will be `max_size - num_buckets -1`. Setting the max_size param means that
freq_threshold should not be given. If the num_buckets parameter is set, it must be
smaller than the max_size value.
start_index: int, default 0
The start index where Categorify will begin to translate dataframe entries
into integer values. For instance, if our original translated dataframe entries appear
as [[1], [1, 4], [3, 2], [2]], then with a start_index of 16, Categorify will now be
[[16], [16, 19], [18, 17], [17]]. This option is useful to reserve an intial segment
of non-negative translated integers for out-of-vocabulary or other special values.
"""

def __init__(
Expand All @@ -191,6 +197,7 @@ def __init__(
num_buckets=None,
vocabs=None,
max_size=0,
start_index=0,
):

# We need to handle three types of encoding here:
Expand Down Expand Up @@ -603,6 +610,8 @@ class FitOptions:
num_buckets:
If specified will also do hashing operation for values that would otherwise be mapped
to as unknown (by freq_limit or max_size parameters)
start_index: int
The index to start mapping our output categorical values to.
"""

col_groups: list
Expand All @@ -617,6 +626,7 @@ class FitOptions:
name_sep: str = "-"
max_size: Optional[Union[int, dict]] = None
num_buckets: Optional[Union[int, dict]] = None
start_index: int = 0

def __post_init__(self):
if not isinstance(self.col_groups, ColumnSelector):
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/ops/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,41 @@ def test_categorify_lists(tmpdir, freq_threshold, cpu, dtype, vocabs):
else:
assert compare == [[1], [1, 0], [0, 2], [2]]

@pytest.mark.parametrize("cpu", _CPU)
@pytest.mark.parametrize("dtype", [None, np.int32, np.int64])
@pytest.mark.parametrize("vocabs", [None, pd.DataFrame({"Authors": [f"User_{x}" for x in "ACBE"]})])
@pytest.mark.parametrize("start_index", [0, 2, 16])
def test_categorify_lists_with_start_index(tmpdir, cpu, dtype, vocabs, start_index):
    """Verify that ``Categorify(start_index=...)`` offsets every encoded value.

    The base encoding of the list column ``Authors`` is
    ``[[1], [1, 4], [3, 2], [2]]``; with ``start_index=k`` every entry must be
    shifted up by ``k``, leaving the ids below ``k`` free for
    out-of-vocabulary or other special values.
    """
    df = dispatch._make_df(
        {
            "Authors": [["User_A"], ["User_A", "User_E"], ["User_B", "User_C"], ["User_C"]],
            "Engaging User": ["User_B", "User_B", "User_A", "User_D"],
            "Post": [1, 2, 3, 4],
        }
    )
    cat_names = ["Authors", "Engaging User"]
    label_name = ["Post"]

    # Bug fix: the parametrized start_index must be forwarded to the op;
    # previously it was omitted, so all cases silently tested the default (0)
    # and the start_index=2 / start_index=16 assertions could never hold.
    cat_features = cat_names >> ops.Categorify(
        out_path=str(tmpdir), dtype=dtype, vocabs=vocabs, start_index=start_index
    )

    workflow = nvt.Workflow(cat_features + label_name)
    df_out = workflow.fit_transform(nvt.Dataset(df, cpu=cpu)).to_ddf().compute()

    # Normalize the list column to plain Python lists for comparison
    # (pandas backend on CPU, pyarrow-backed cudf column on GPU).
    if cpu:
        compare = [list(row) for row in df_out["Authors"].tolist()]
    else:
        compare = df_out["Authors"].to_arrow().to_pylist()

    # Expected output is the base encoding shifted uniformly by start_index:
    #   start_index=0  -> [[1], [1, 4], [3, 2], [2]]
    #   start_index=2  -> [[3], [3, 6], [5, 4], [4]]
    #   start_index=16 -> [[16], [16, 19], [18, 17], [17]]
    base_encoding = [[1], [1, 4], [3, 2], [2]]
    expected = [[value + start_index for value in row] for row in base_encoding]
    assert compare == expected

@pytest.mark.parametrize("cat_names", [[["Author", "Engaging User"]], ["Author", "Engaging User"]])
@pytest.mark.parametrize("kind", ["joint", "combo"])
Expand Down

0 comments on commit 642e5f5

Please sign in to comment.