Skip to content

Commit

Permalink
Specify output dtype for Normalize op in ETL example to match model e…
Browse files Browse the repository at this point in the history
…xpectations (#523)

* Specify output dtype for Normalize op in ETL to match model

* Convert timestamps to np.float32 in Getting Started ETL Notebook

* Convert normalized cols to np.float32
  • Loading branch information
oliverholworthy committed Nov 11, 2022
1 parent 3b17f6c commit 11be6b1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@
" \n",
"recency_features = session_ts >> ItemRecency() \n",
"# Apply standardization to this continuous feature\n",
"recency_features_norm = recency_features >> nvt.ops.LogOp() >> nvt.ops.Normalize() >> nvt.ops.Rename(name='product_recency_days_log_norm')\n",
"recency_features_norm = recency_features >> nvt.ops.LogOp() >> nvt.ops.Normalize(out_dtype=np.float32) >> nvt.ops.Rename(name='product_recency_days_log_norm')\n",
"\n",
"time_features = (\n",
" session_time +\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@
"\n",
"# generate category mapping for each item-id\n",
"df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)\n",
"df['timestamp/age_days'] = np.random.uniform(0, 1, NUM_ROWS)\n",
"df['timestamp/weekday/sin']= np.random.uniform(0, 1, NUM_ROWS)\n",
"df['timestamp/age_days'] = np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)\n",
"df['timestamp/weekday/sin']= np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)\n",
"\n",
"# generate day mapping for each session \n",
"map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/tutorial/02-ETL-with-NVTabular.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@
"outputs": [],
"source": [
"recency_features = ['event_time_ts'] >> ItemRecency() \n",
"recency_features_norm = recency_features >> nvt.ops.LogOp() >> nvt.ops.Normalize() >> nvt.ops.Rename(name='product_recency_days_log_norm')"
"recency_features_norm = recency_features >> nvt.ops.LogOp() >> nvt.ops.Normalize(out_dtype=np.float32) >> nvt.ops.Rename(name='product_recency_days_log_norm')"
]
},
{
Expand Down Expand Up @@ -536,7 +536,7 @@
"outputs": [],
"source": [
"# Smoothing price long-tailed distribution and applying standardization\n",
"price_log = ['price'] >> nvt.ops.LogOp() >> nvt.ops.Normalize() >> nvt.ops.Rename(name='price_log_norm')"
"price_log = ['price'] >> nvt.ops.LogOp() >> nvt.ops.Normalize(out_dtype=np.float32) >> nvt.ops.Rename(name='price_log_norm')"
]
},
{
Expand Down

0 comments on commit 11be6b1

Please sign in to comment.