From f1c95a56c3331fbd383957d040c7bdd5331d8d79 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 23 Nov 2022 22:24:48 +0000 Subject: [PATCH 1/4] serving tfrec with pyt backend example --- .../01-ETL-with-NVTabular.ipynb | 671 +++++--- .../02-session-based-XLNet-with-PyT.ipynb | 342 ++-- ...ssion-based-model-with-Torch-backend.ipynb | 1473 +++++++++++++++++ .../getting-started-session-based/schema.pb | 87 - 4 files changed, 2131 insertions(+), 442 deletions(-) create mode 100644 examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb delete mode 100644 examples/getting-started-session-based/schema.pb diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index 632fe0ba34..d071156144 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -104,7 +104,7 @@ "metadata": {}, "outputs": [], "source": [ - "NUM_ROWS = 10000000\n", + "NUM_ROWS = 100000\n", "long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., NUM_ROWS).astype(np.int32), 1, 50000)\n", "\n", "# generate random item interaction features \n", @@ -113,8 +113,8 @@ "\n", "# generate category mapping for each item-id\n", "df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)\n", - "df['timestamp/age_days'] = np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)\n", - "df['timestamp/weekday/sin']= np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)\n", + "df['age_days'] = np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)\n", + "df['weekday_sin']= np.random.uniform(0, 1, NUM_ROWS).astype(np.float32)\n", "\n", "# generate day mapping for each session \n", "map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))\n", @@ -159,75 +159,68 @@ " session_id\n", " item_id\n", " category\n", - " timestamp/age_days\n", - " timestamp/weekday/sin\n", + " age_days\n", + " weekday_sin\n", " day\n", " \n", " \n", " \n", " \n", " 0\n", - " 88303\n", + " 88504\n", + " 26\n", " 4\n", + " 0.342991\n", + " 0.144433\n", " 1\n", - " 0.627299\n", - " 0.059239\n", - " 2\n", " \n", " \n", " 1\n", - " 79291\n", - " 15\n", - " 2\n", - " 0.693606\n", - " 0.128668\n", + " 85107\n", + " 13\n", " 2\n", + " 0.156982\n", + " 0.722122\n", + " 4\n", " \n", " \n", " 2\n", - " 75485\n", + " 89499\n", " 37\n", " 5\n", - " 0.062722\n", - " 0.111661\n", - " 7\n", + " 0.389054\n", + " 0.321258\n", + " 9\n", " \n", " \n", " 3\n", - " 85283\n", - " 42\n", - " 5\n", - " 0.744100\n", - " 0.480346\n", - " 1\n", + " 88602\n", + " 20\n", + " 3\n", + " 0.258130\n", + " 0.491159\n", + " 2\n", " \n", " \n", " 4\n", - " 84407\n", - " 5\n", + " 84113\n", + " 7\n", " 1\n", - " 0.622424\n", - " 0.989467\n", - " 3\n", + " 0.519515\n", + " 0.110561\n", + " 7\n", " \n", " \n", "\n", "" ], "text/plain": [ - " session_id item_id category timestamp/age_days timestamp/weekday/sin \\\n", - "0 88303 4 1 0.627299 0.059239 \n", - "1 79291 15 2 0.693606 0.128668 \n", - "2 75485 37 5 0.062722 0.111661 \n", - "3 85283 42 5 0.744100 0.480346 \n", - "4 84407 5 1 0.622424 0.989467 \n", - "\n", - " day \n", - "0 2 \n", - "1 2 \n", - "2 7 \n", - "3 1 \n", - "4 3 " + " session_id item_id category age_days weekday_sin day\n", + "0 88504 26 4 0.342991 0.144433 1\n", + "1 85107 13 2 0.156982 0.722122 4\n", + "2 89499 37 5 0.389054 0.321258 9\n", + "3 88602 20 3 0.258130 0.491159 2\n", + "4 84113 
7 1 0.519515 0.110561 7" ] }, "execution_count": 5, @@ -283,11 +276,13 @@ } ], "source": [ + "SESSIONS_MAX_LENGTH =20\n", + "\n", "# Categorify categorical features\n", "categ_feats = ['session_id', 'item_id', 'category'] >> nvt.ops.Categorify(start_index=1)\n", "\n", "# Define Groupby Workflow\n", - "groupby_feats = categ_feats + ['day', 'timestamp/age_days', 'timestamp/weekday/sin']\n", + "groupby_feats = categ_feats + ['day', 'age_days', 'weekday_sin']\n", "\n", "# Group interaction features by session\n", "groupby_features = groupby_feats >> nvt.ops.Groupby(\n", @@ -296,25 +291,29 @@ " \"item_id\": [\"list\", \"count\"],\n", " \"category\": [\"list\"], \n", " \"day\": [\"first\"],\n", - " \"timestamp/age_days\": [\"list\"],\n", - " 'timestamp/weekday/sin': [\"list\"],\n", + " \"age_days\": [\"list\"],\n", + " 'weekday_sin': [\"list\"],\n", " },\n", " name_sep=\"-\")\n", "\n", "# Select and truncate the sequential features\n", - "sequence_features_truncated = (groupby_features['category-list']) >> nvt.ops.ListSlice(0,20) >> nvt.ops.Rename(postfix = '_trim')\n", + "sequence_features_truncated = (\n", + " groupby_features['category-list']\n", + " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) \n", + " >> nvt.ops.ValueCount()\n", + ")\n", "\n", "sequence_features_truncated_item = (\n", " groupby_features['item_id-list']\n", - " >> nvt.ops.ListSlice(0,20) \n", - " >> nvt.ops.Rename(postfix = '_trim')\n", + " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) \n", " >> TagAsItemID()\n", + " >> nvt.ops.ValueCount()\n", ") \n", "sequence_features_truncated_cont = (\n", - " groupby_features['timestamp/age_days-list', 'timestamp/weekday/sin-list'] \n", - " >> nvt.ops.ListSlice(0,20) \n", - " >> nvt.ops.Rename(postfix = '_trim')\n", + " groupby_features['age_days-list', 'weekday_sin-list'] \n", + " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) \n", " >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])\n", + " >> nvt.ops.ValueCount()\n", ")\n", "\n", "# Filter out sessions with length 1 (not valid for next-item prediction training and evaluation)\n", @@ -328,8 +327,10 @@ " \n", "filtered_sessions = selected_features >> nvt.ops.Filter(f=lambda df: df[\"item_id-count\"] >= MINIMUM_SESSION_LENGTH)\n", "\n", + "seq_feats_list = filtered_sessions['item_id-list', 'category-list', 'age_days-list', 'weekday_sin-list'] >> nvt.ops.ValueCount()\n", "\n", - "workflow = nvt.Workflow(filtered_sessions)\n", + "\n", + "workflow = nvt.Workflow(filtered_sessions['session_id', 'day-first', 'item_id-count'] + seq_feats_list)\n", "dataset = nvt.Dataset(df, cpu=False)\n", "# Generate statistics for the features\n", "workflow.fit(dataset)\n", @@ -366,75 +367,75 @@ " \n", " \n", " \n", - " item_id-count\n", - " day-first\n", " session_id\n", - " item_id-list_trim\n", - " category-list_trim\n", - " timestamp/age_days-list_trim\n", - " timestamp/weekday/sin-list_trim\n", + " day-first\n", + " item_id-count\n", + " item_id-list\n", + " category-list\n", + " age_days-list\n", + " weekday_sin-list\n", " \n", " \n", " \n", " \n", " 0\n", - " 586\n", - " 5\n", " 2\n", - " [13, 8, 39, 14, 23, 2, 4, 83, 25, 34, 17, 4, 1...\n", - " [3, 3, 6, 2, 4, 2, 2, 11, 4, 5, 4, 2, 15, 18, ...\n", - " [0.3265768, 0.41545194, 0.52078074, 0.7723212,...\n", - " [0.5636411, 0.14788395, 0.6995017, 0.010999571...\n", + " 9\n", + " 78\n", + " [19, 15, 26, 6, 24, 44, 19, 29, 14, 42, 33, 11...\n", + " [4, 4, 5, 2, 5, 7, 4, 5, 2, 7, 2, 3, 2, 3, 2, ...\n", + " [0.0042280755, 0.40522072, 0.42538044, 0.97327...\n", + " [0.6292621, 0.1172376, 0.18633945, 0.8232658, ...\n", " 
\n", " \n", " 1\n", - " 586\n", - " 6\n", " 3\n", - " [20, 14, 8, 15, 10, 56, 73, 22, 18, 52, 27, 42...\n", - " [4, 2, 3, 3, 3, 8, 10, 4, 4, 7, 5, 6, 2, 9, 9,...\n", - " [0.13561419, 0.035071343, 0.58149755, 0.159483...\n", - " [0.3359673, 0.6002685, 0.84561634, 0.04078535,...\n", + " 2\n", + " 76\n", + " [26, 85, 9, 57, 17, 7, 9, 41, 37, 11, 10, 16, ...\n", + " [5, 13, 2, 9, 4, 3, 2, 7, 6, 3, 3, 4, 7, 5, 2,...\n", + " [0.07340249, 0.2910817, 0.010784109, 0.8495507...\n", + " [0.52355134, 0.83093345, 0.8837344, 0.38942775...\n", " \n", " \n", " 2\n", - " 584\n", - " 2\n", " 4\n", - " [24, 12, 178, 23, 23, 218, 92, 5, 55, 85, 10, ...\n", - " [4, 3, 22, 4, 4, 28, 12, 2, 8, 11, 3, 3, 4, 3,...\n", - " [0.39440218, 0.12561888, 0.27249986, 0.6201667...\n", - " [0.90376323, 0.75177085, 0.6668168, 0.0828298,...\n", + " 6\n", + " 76\n", + " [13, 13, 50, 64, 105, 16, 78, 17, 19, 34, 8, 1...\n", + " [3, 3, 8, 10, 16, 4, 12, 4, 4, 6, 3, 4, 3, 6, ...\n", + " [0.29271448, 0.59962034, 0.042938035, 0.730446...\n", + " [0.8610789, 0.058191676, 0.806903, 0.79222715,...\n", " \n", " \n", "\n", "" ], "text/plain": [ - " item_id-count day-first session_id \\\n", - "0 586 5 2 \n", - "1 586 6 3 \n", - "2 584 2 4 \n", + " session_id day-first item_id-count \\\n", + "0 2 9 78 \n", + "1 3 2 76 \n", + "2 4 6 76 \n", "\n", - " item_id-list_trim \\\n", - "0 [13, 8, 39, 14, 23, 2, 4, 83, 25, 34, 17, 4, 1... \n", - "1 [20, 14, 8, 15, 10, 56, 73, 22, 18, 52, 27, 42... \n", - "2 [24, 12, 178, 23, 23, 218, 92, 5, 55, 85, 10, ... \n", + " item_id-list \\\n", + "0 [19, 15, 26, 6, 24, 44, 19, 29, 14, 42, 33, 11... \n", + "1 [26, 85, 9, 57, 17, 7, 9, 41, 37, 11, 10, 16, ... \n", + "2 [13, 13, 50, 64, 105, 16, 78, 17, 19, 34, 8, 1... \n", "\n", - " category-list_trim \\\n", - "0 [3, 3, 6, 2, 4, 2, 2, 11, 4, 5, 4, 2, 15, 18, ... \n", - "1 [4, 2, 3, 3, 3, 8, 10, 4, 4, 7, 5, 6, 2, 9, 9,... \n", - "2 [4, 3, 22, 4, 4, 28, 12, 2, 8, 11, 3, 3, 4, 3,... \n", + " category-list \\\n", + "0 [4, 4, 5, 2, 5, 7, 4, 5, 2, 7, 2, 3, 2, 3, 2, ... \n", + "1 [5, 13, 2, 9, 4, 3, 2, 7, 6, 3, 3, 4, 7, 5, 2,... \n", + "2 [3, 3, 8, 10, 16, 4, 12, 4, 4, 6, 3, 4, 3, 6, ... \n", "\n", - " timestamp/age_days-list_trim \\\n", - "0 [0.3265768, 0.41545194, 0.52078074, 0.7723212,... \n", - "1 [0.13561419, 0.035071343, 0.58149755, 0.159483... \n", - "2 [0.39440218, 0.12561888, 0.27249986, 0.6201667... \n", + " age_days-list \\\n", + "0 [0.0042280755, 0.40522072, 0.42538044, 0.97327... \n", + "1 [0.07340249, 0.2910817, 0.010784109, 0.8495507... \n", + "2 [0.29271448, 0.59962034, 0.042938035, 0.730446... \n", "\n", - " timestamp/weekday/sin-list_trim \n", - "0 [0.5636411, 0.14788395, 0.6995017, 0.010999571... \n", - "1 [0.3359673, 0.6002685, 0.84561634, 0.04078535,... \n", - "2 [0.90376323, 0.75177085, 0.6668168, 0.0828298,... " + " weekday_sin-list \n", + "0 [0.6292621, 0.1172376, 0.18633945, 0.8232658, ... \n", + "1 [0.52355134, 0.83093345, 0.8837344, 0.38942775... \n", + "2 [0.8610789, 0.058191676, 0.806903, 0.79222715,... " ] }, "execution_count": 7, @@ -457,11 +458,205 @@ { "cell_type": "code", "execution_count": 8, - "id": "ff88e98f", + "id": "78e42cbf-edd6-44af-af23-c026edb578c4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nametagsdtypeis_listis_raggedproperties.num_bucketsproperties.freq_thresholdproperties.max_sizeproperties.start_indexproperties.cat_pathproperties.domain.minproperties.domain.maxproperties.domain.nameproperties.embedding_sizes.cardinalityproperties.embedding_sizes.dimensionproperties.value_count.minproperties.value_count.max
0session_id(Tags.CATEGORICAL)int64FalseFalseNaN0.00.01.0.//categories/unique.session_id.parquet0.020001.0session_id20002.0410.0NaNNaN
1day-first()int64FalseFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2item_id-count(Tags.CATEGORICAL)int32FalseFalseNaN0.00.01.0.//categories/unique.item_id.parquet0.0828.0item_id829.069.0NaNNaN
3item_id-list(Tags.ITEM_ID, Tags.ITEM, Tags.CATEGORICAL, Ta...int64TrueFalseNaN0.00.01.0.//categories/unique.item_id.parquet0.0828.0item_id829.069.020.020.0
4category-list(Tags.CATEGORICAL, Tags.LIST)int64TrueFalseNaN0.00.01.0.//categories/unique.category.parquet0.0172.0category173.029.020.020.0
5age_days-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
6weekday_sin-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
\n", + "
" + ], + "text/plain": [ + "[{'name': 'session_id', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.session_id.parquet', 'domain': {'min': 0, 'max': 20001, 'name': 'session_id'}, 'embedding_sizes': {'cardinality': 20002, 'dimension': 410}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 828, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 829, 'dimension': 69}}, 'dtype': dtype('int32'), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 828, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 829, 'dimension': 69}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 172, 'name': 'category'}, 'embedding_sizes': {'cardinality': 173, 'dimension': 29}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "workflow.save('workflow_etl')" + "workflow.output_schema" ] }, { @@ -491,6 +686,16 @@ "workflow.fit_transform(dataset).to_parquet(os.path.join(INPUT_DATA_DIR, \"processed_nvt\"))" ] }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9f498dce-69eb-4f88-8ddd-8629558825df", + "metadata": {}, + "outputs": [], + "source": [ + "workflow.save('workflow_etl')" + ] + }, { "cell_type": "markdown", "id": "02a41961", @@ -509,7 +714,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "12d3e59b", "metadata": {}, "outputs": [], @@ -520,7 +725,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "6c67a92b", "metadata": {}, "outputs": [ @@ -528,7 +733,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|██████████| 9/9 [00:01<00:00, 5.79it/s]\n" + "Creating time-based splits: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 9.56it/s]\n" ] } ], @@ -551,7 +756,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "dd04ec82", "metadata": {}, "outputs": [], @@ -561,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "8e5e6358", "metadata": {}, "outputs": [ @@ -586,59 +791,59 @@ " \n", " \n", " \n", - " item_id-count\n", " session_id\n", - " 
item_id-list_trim\n", - " category-list_trim\n", - " timestamp/age_days-list_trim\n", - " timestamp/weekday/sin-list_trim\n", + " item_id-count\n", + " item_id-list\n", + " category-list\n", + " age_days-list\n", + " weekday_sin-list\n", " \n", " \n", " \n", " \n", " 0\n", - " 578\n", - " 8\n", - " [17, 3, 130, 56, 35, 4, 12, 48, 21, 6, 13, 12,...\n", - " [4, 2, 17, 8, 5, 2, 3, 7, 2, 2, 3, 3, 3, 6, 2,...\n", - " [0.91336805, 0.7539333, 0.6858618, 0.91122335,...\n", - " [0.4654125, 0.8024907, 0.15122412, 0.9189323, ...\n", + " 15\n", + " 73\n", + " [7, 15, 7, 18, 49, 106, 6, 2, 9, 9, 10, 31, 2,...\n", + " [3, 4, 3, 4, 8, 16, 2, 2, 2, 2, 3, 6, 2, 3, 3,...\n", + " [0.83090544, 0.80508804, 0.77800703, 0.1136483...\n", + " [0.73860276, 0.71227425, 0.82740945, 0.7208882...\n", " \n", " \n", " 1\n", - " 570\n", - " 24\n", - " [102, 5, 28, 12, 7, 11, 34, 2, 21, 11, 20, 4, ...\n", - " [13, 2, 5, 3, 3, 3, 5, 2, 2, 3, 4, 2, 5, 13, 4...\n", - " [0.16131416, 0.7624795, 0.5117769, 0.06776055,...\n", - " [0.17060736, 0.23287642, 0.5058551, 0.28743693...\n", + " 33\n", + " 72\n", + " [21, 15, 11, 34, 13, 2, 35, 16, 8, 16, 2, 72, ...\n", + " [2, 4, 3, 6, 3, 2, 6, 4, 3, 4, 2, 11, 4, 5, 6,...\n", + " [0.091410436, 0.39586508, 0.5213629, 0.73786, ...\n", + " [0.81503123, 0.8240082, 0.13869828, 0.00287529...\n", " \n", " \n", " 2\n", - " 567\n", - " 32\n", - " [11, 24, 2, 32, 6, 26, 14, 4, 5, 17, 17, 10, 8...\n", - " [3, 4, 2, 5, 2, 4, 2, 2, 2, 4, 4, 3, 3, 4, 2, ...\n", - " [0.62564784, 0.997358, 0.8010653, 0.027112987,...\n", - " [0.6030678, 0.25616208, 0.9580145, 0.99706334,...\n", + " 50\n", + " 71\n", + " [6, 113, 59, 16, 12, 209, 3, 181, 40, 6, 13, 3...\n", + " [2, 17, 9, 4, 3, 29, 2, 26, 7, 2, 3, 6, 9, 27,...\n", + " [0.0334595, 0.7818315, 0.9511686, 0.48197943, ...\n", + " [0.5987661, 0.2299972, 0.57004195, 0.93016136,...\n", " \n", " \n", " 4\n", - " 565\n", - " 43\n", - " [18, 16, 22, 32, 13, 73, 5, 21, 11, 61, 78, 73...\n", - " [4, 3, 4, 5, 3, 10, 2, 2, 3, 8, 10, 10, 6, 4, ...\n", - " [0.7378176, 0.16968544, 0.7315238, 0.95425814,...\n", - " [0.34397638, 0.8673334, 0.55496854, 0.9816106,...\n", + " 54\n", + " 71\n", + " [121, 23, 21, 46, 18, 68, 39, 103, 53, 15, 3, ...\n", + " [18, 5, 2, 7, 4, 11, 7, 15, 9, 4, 2, 2, 4, 5, ...\n", + " [0.24764787, 0.7796014, 0.6935816, 0.37522456,...\n", + " [0.7501101, 0.8012419, 0.6419888, 0.18589461, ...\n", " \n", " \n", " 5\n", - " 563\n", - " 57\n", - " [65, 15, 85, 11, 7, 60, 46, 11, 9, 48, 2, 19, ...\n", - " [9, 3, 11, 3, 3, 8, 7, 3, 2, 7, 2, 4, 7, 3, 2,...\n", - " [0.84140944, 0.032678757, 0.5808043, 0.9555968...\n", - " [0.1885734, 0.68167686, 0.5295532, 0.85896724,...\n", + " 60\n", + " 70\n", + " [4, 25, 62, 11, 94, 14, 13, 28, 8, 38, 47, 6, ...\n", + " [2, 5, 10, 3, 14, 2, 3, 5, 3, 6, 8, 2, 17, 3, ...\n", + " [0.42777818, 0.93942195, 0.7336651, 0.6372542,...\n", + " [0.6628856, 0.19497228, 0.34514892, 0.7849939,...\n", " \n", " \n", " ...\n", @@ -650,125 +855,125 @@ " ...\n", " \n", " \n", - " 2152\n", - " 438\n", - " 19951\n", - " [8, 25, 12, 63, 33, 40, 22, 28, 77, 33, 13, 8,...\n", - " [3, 4, 3, 9, 2, 6, 4, 5, 10, 2, 3, 3, 2, 5, 2,...\n", - " [0.8278093, 0.15857665, 0.36844572, 0.19620946...\n", - " [0.1730732, 0.8465068, 0.15297464, 0.46283653,...\n", + " 2234\n", + " 19970\n", + " 31\n", + " [16, 21, 8, 16, 17, 118, 146, 32, 16, 89, 52, ...\n", + " [4, 2, 3, 4, 4, 17, 22, 6, 4, 14, 8, 2, 2, 17,...\n", + " [0.39309493, 0.80350375, 0.41615465, 0.7130491...\n", + " [0.4812473, 0.19307572, 0.31647742, 0.48890728...\n", " \n", " \n", - " 2153\n", - " 
437\n", - " 19961\n", - " [23, 40, 77, 89, 33, 38, 19, 16, 13, 60, 11, 5...\n", - " [4, 6, 10, 12, 2, 6, 4, 3, 3, 8, 3, 2, 2, 8, 6...\n", - " [0.34355995, 0.16667433, 0.3479699, 0.6915854,...\n", - " [0.7232413, 0.54277384, 0.9209216, 0.15295857,...\n", + " 2235\n", + " 19979\n", + " 30\n", + " [17, 13, 227, 22, 13, 43, 20, 65, 8, 13, 3, 36...\n", + " [4, 3, 33, 4, 3, 7, 4, 10, 3, 3, 2, 6, 2, 2, 7...\n", + " [0.80101305, 0.87425673, 0.072974995, 0.940252...\n", + " [0.9547868, 0.22636005, 0.14399427, 0.25720155...\n", " \n", " \n", - " 2155\n", - " 436\n", - " 19970\n", - " [4, 37, 9, 87, 11, 18, 111, 64, 6, 7, 3, 22, 4...\n", - " [2, 6, 2, 12, 3, 4, 15, 9, 2, 3, 2, 4, 2, 4, 2...\n", - " [0.86520904, 0.2337783, 0.7927252, 0.60708684,...\n", - " [0.37351412, 0.9064367, 0.78618735, 0.14689812...\n", + " 2236\n", + " 19980\n", + " 30\n", + " [12, 26, 16, 50, 23, 113, 38, 4, 17, 5, 22, 15...\n", + " [3, 5, 4, 8, 5, 17, 6, 2, 4, 3, 4, 4, 3, 3, 3,...\n", + " [0.8601423, 0.94999486, 0.19948259, 0.00668230...\n", + " [0.53358024, 0.03624035, 0.36104643, 0.6872657...\n", " \n", " \n", - " 2156\n", - " 434\n", - " 19977\n", - " [4, 19, 58, 28, 11, 140, 3, 93, 32, 42, 7, 5, ...\n", - " [2, 4, 8, 5, 3, 18, 2, 12, 5, 6, 3, 2, 4, 3, 3...\n", - " [0.9443977, 0.50952923, 0.7083005, 0.6272203, ...\n", - " [0.01576173, 0.5609052, 0.9132833, 0.72388875,...\n", + " 2237\n", + " 19986\n", + " 29\n", + " [154, 18, 28, 216, 173, 148, 22, 72, 28, 21, 2...\n", + " [24, 4, 5, 32, 25, 22, 4, 11, 5, 2, 5, 2, 2, 2...\n", + " [0.5032084, 0.22921799, 0.5817064, 0.31098822,...\n", + " [0.93975395, 0.9232329, 0.292329, 0.034402855,...\n", " \n", " \n", - " 2157\n", - " 422\n", - " 19997\n", - " [37, 32, 29, 18, 9, 13, 53, 7, 19, 20, 36, 27,...\n", - " [6, 5, 5, 4, 2, 3, 8, 3, 4, 4, 6, 5, 4, 2, 3, ...\n", - " [0.61663723, 0.7966605, 0.67182475, 0.7436706,...\n", - " [0.07727963, 0.6455043, 0.5547871, 0.9132467, ...\n", + " 2239\n", + " 19998\n", + " 27\n", + " [7, 14, 5, 14, 46, 31, 11, 7, 3, 10, 54, 2, 56...\n", + " [3, 2, 3, 2, 7, 6, 3, 3, 2, 3, 9, 2, 9, 3, 3, ...\n", + " [0.2078503, 0.6980298, 0.29598948, 0.49848178,...\n", + " [0.28777036, 0.6233945, 0.48009974, 0.9278188,...\n", " \n", " \n", "\n", - "

1724 rows × 6 columns

\n", + "

1794 rows × 6 columns

\n", "" ], "text/plain": [ - " item_id-count session_id \\\n", - "0 578 8 \n", - "1 570 24 \n", - "2 567 32 \n", - "4 565 43 \n", - "5 563 57 \n", - "... ... ... \n", - "2152 438 19951 \n", - "2153 437 19961 \n", - "2155 436 19970 \n", - "2156 434 19977 \n", - "2157 422 19997 \n", + " session_id item_id-count \\\n", + "0 15 73 \n", + "1 33 72 \n", + "2 50 71 \n", + "4 54 71 \n", + "5 60 70 \n", + "... ... ... \n", + "2234 19970 31 \n", + "2235 19979 30 \n", + "2236 19980 30 \n", + "2237 19986 29 \n", + "2239 19998 27 \n", "\n", - " item_id-list_trim \\\n", - "0 [17, 3, 130, 56, 35, 4, 12, 48, 21, 6, 13, 12,... \n", - "1 [102, 5, 28, 12, 7, 11, 34, 2, 21, 11, 20, 4, ... \n", - "2 [11, 24, 2, 32, 6, 26, 14, 4, 5, 17, 17, 10, 8... \n", - "4 [18, 16, 22, 32, 13, 73, 5, 21, 11, 61, 78, 73... \n", - "5 [65, 15, 85, 11, 7, 60, 46, 11, 9, 48, 2, 19, ... \n", + " item_id-list \\\n", + "0 [7, 15, 7, 18, 49, 106, 6, 2, 9, 9, 10, 31, 2,... \n", + "1 [21, 15, 11, 34, 13, 2, 35, 16, 8, 16, 2, 72, ... \n", + "2 [6, 113, 59, 16, 12, 209, 3, 181, 40, 6, 13, 3... \n", + "4 [121, 23, 21, 46, 18, 68, 39, 103, 53, 15, 3, ... \n", + "5 [4, 25, 62, 11, 94, 14, 13, 28, 8, 38, 47, 6, ... \n", "... ... \n", - "2152 [8, 25, 12, 63, 33, 40, 22, 28, 77, 33, 13, 8,... \n", - "2153 [23, 40, 77, 89, 33, 38, 19, 16, 13, 60, 11, 5... \n", - "2155 [4, 37, 9, 87, 11, 18, 111, 64, 6, 7, 3, 22, 4... \n", - "2156 [4, 19, 58, 28, 11, 140, 3, 93, 32, 42, 7, 5, ... \n", - "2157 [37, 32, 29, 18, 9, 13, 53, 7, 19, 20, 36, 27,... \n", + "2234 [16, 21, 8, 16, 17, 118, 146, 32, 16, 89, 52, ... \n", + "2235 [17, 13, 227, 22, 13, 43, 20, 65, 8, 13, 3, 36... \n", + "2236 [12, 26, 16, 50, 23, 113, 38, 4, 17, 5, 22, 15... \n", + "2237 [154, 18, 28, 216, 173, 148, 22, 72, 28, 21, 2... \n", + "2239 [7, 14, 5, 14, 46, 31, 11, 7, 3, 10, 54, 2, 56... \n", "\n", - " category-list_trim \\\n", - "0 [4, 2, 17, 8, 5, 2, 3, 7, 2, 2, 3, 3, 3, 6, 2,... \n", - "1 [13, 2, 5, 3, 3, 3, 5, 2, 2, 3, 4, 2, 5, 13, 4... \n", - "2 [3, 4, 2, 5, 2, 4, 2, 2, 2, 4, 4, 3, 3, 4, 2, ... \n", - "4 [4, 3, 4, 5, 3, 10, 2, 2, 3, 8, 10, 10, 6, 4, ... \n", - "5 [9, 3, 11, 3, 3, 8, 7, 3, 2, 7, 2, 4, 7, 3, 2,... \n", + " category-list \\\n", + "0 [3, 4, 3, 4, 8, 16, 2, 2, 2, 2, 3, 6, 2, 3, 3,... \n", + "1 [2, 4, 3, 6, 3, 2, 6, 4, 3, 4, 2, 11, 4, 5, 6,... \n", + "2 [2, 17, 9, 4, 3, 29, 2, 26, 7, 2, 3, 6, 9, 27,... \n", + "4 [18, 5, 2, 7, 4, 11, 7, 15, 9, 4, 2, 2, 4, 5, ... \n", + "5 [2, 5, 10, 3, 14, 2, 3, 5, 3, 6, 8, 2, 17, 3, ... \n", "... ... \n", - "2152 [3, 4, 3, 9, 2, 6, 4, 5, 10, 2, 3, 3, 2, 5, 2,... \n", - "2153 [4, 6, 10, 12, 2, 6, 4, 3, 3, 8, 3, 2, 2, 8, 6... \n", - "2155 [2, 6, 2, 12, 3, 4, 15, 9, 2, 3, 2, 4, 2, 4, 2... \n", - "2156 [2, 4, 8, 5, 3, 18, 2, 12, 5, 6, 3, 2, 4, 3, 3... \n", - "2157 [6, 5, 5, 4, 2, 3, 8, 3, 4, 4, 6, 5, 4, 2, 3, ... \n", + "2234 [4, 2, 3, 4, 4, 17, 22, 6, 4, 14, 8, 2, 2, 17,... \n", + "2235 [4, 3, 33, 4, 3, 7, 4, 10, 3, 3, 2, 6, 2, 2, 7... \n", + "2236 [3, 5, 4, 8, 5, 17, 6, 2, 4, 3, 4, 4, 3, 3, 3,... \n", + "2237 [24, 4, 5, 32, 25, 22, 4, 11, 5, 2, 5, 2, 2, 2... \n", + "2239 [3, 2, 3, 2, 7, 6, 3, 3, 2, 3, 9, 2, 9, 3, 3, ... \n", "\n", - " timestamp/age_days-list_trim \\\n", - "0 [0.91336805, 0.7539333, 0.6858618, 0.91122335,... \n", - "1 [0.16131416, 0.7624795, 0.5117769, 0.06776055,... \n", - "2 [0.62564784, 0.997358, 0.8010653, 0.027112987,... \n", - "4 [0.7378176, 0.16968544, 0.7315238, 0.95425814,... \n", - "5 [0.84140944, 0.032678757, 0.5808043, 0.9555968... 
\n", + " age_days-list \\\n", + "0 [0.83090544, 0.80508804, 0.77800703, 0.1136483... \n", + "1 [0.091410436, 0.39586508, 0.5213629, 0.73786, ... \n", + "2 [0.0334595, 0.7818315, 0.9511686, 0.48197943, ... \n", + "4 [0.24764787, 0.7796014, 0.6935816, 0.37522456,... \n", + "5 [0.42777818, 0.93942195, 0.7336651, 0.6372542,... \n", "... ... \n", - "2152 [0.8278093, 0.15857665, 0.36844572, 0.19620946... \n", - "2153 [0.34355995, 0.16667433, 0.3479699, 0.6915854,... \n", - "2155 [0.86520904, 0.2337783, 0.7927252, 0.60708684,... \n", - "2156 [0.9443977, 0.50952923, 0.7083005, 0.6272203, ... \n", - "2157 [0.61663723, 0.7966605, 0.67182475, 0.7436706,... \n", + "2234 [0.39309493, 0.80350375, 0.41615465, 0.7130491... \n", + "2235 [0.80101305, 0.87425673, 0.072974995, 0.940252... \n", + "2236 [0.8601423, 0.94999486, 0.19948259, 0.00668230... \n", + "2237 [0.5032084, 0.22921799, 0.5817064, 0.31098822,... \n", + "2239 [0.2078503, 0.6980298, 0.29598948, 0.49848178,... \n", "\n", - " timestamp/weekday/sin-list_trim \n", - "0 [0.4654125, 0.8024907, 0.15122412, 0.9189323, ... \n", - "1 [0.17060736, 0.23287642, 0.5058551, 0.28743693... \n", - "2 [0.6030678, 0.25616208, 0.9580145, 0.99706334,... \n", - "4 [0.34397638, 0.8673334, 0.55496854, 0.9816106,... \n", - "5 [0.1885734, 0.68167686, 0.5295532, 0.85896724,... \n", + " weekday_sin-list \n", + "0 [0.73860276, 0.71227425, 0.82740945, 0.7208882... \n", + "1 [0.81503123, 0.8240082, 0.13869828, 0.00287529... \n", + "2 [0.5987661, 0.2299972, 0.57004195, 0.93016136,... \n", + "4 [0.7501101, 0.8012419, 0.6419888, 0.18589461, ... \n", + "5 [0.6628856, 0.19497228, 0.34514892, 0.7849939,... \n", "... ... \n", - "2152 [0.1730732, 0.8465068, 0.15297464, 0.46283653,... \n", - "2153 [0.7232413, 0.54277384, 0.9209216, 0.15295857,... \n", - "2155 [0.37351412, 0.9064367, 0.78618735, 0.14689812... \n", - "2156 [0.01576173, 0.5609052, 0.9132833, 0.72388875,... \n", - "2157 [0.07727963, 0.6455043, 0.5547871, 0.9132467, ... \n", + "2234 [0.4812473, 0.19307572, 0.31647742, 0.48890728... \n", + "2235 [0.9547868, 0.22636005, 0.14399427, 0.25720155... \n", + "2236 [0.53358024, 0.03624035, 0.36104643, 0.6872657... \n", + "2237 [0.93975395, 0.9232329, 0.292329, 0.034402855,... \n", + "2239 [0.28777036, 0.6233945, 0.48009974, 0.9278188,... \n", "\n", - "[1724 rows x 6 columns]" + "[1794 rows x 6 columns]" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb b/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb index 943bb47fcb..0f03820b46 100644 --- a/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb +++ b/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb @@ -106,10 +106,71 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "3ba89970", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n", + " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (NDCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (DCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (AvgPrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (PrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. 
If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (RecallAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n" + ] + } + ], "source": [ "import os\n", "\n", @@ -170,18 +231,18 @@ "output_type": "stream", "text": [ "feature {\n", - " name: \"item_id-count\"\n", + " name: \"session_id\"\n", " type: INT\n", " int_domain {\n", - " name: \"item_id\"\n", - " max: 1322\n", + " name: \"session_id\"\n", + " max: 19870\n", " is_categorical: true\n", " }\n", " annotation {\n", " tag: \"categorical\"\n", " extra_metadata {\n", " type_url: \"type.googleapis.com/google.protobuf.Struct\"\n", - " value: \"\\n\\021\\n\\013num_buckets\\022\\002\\010\\000\\n\\033\\n\\016freq_threshold\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\025\\n\\010max_size\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\030\\n\\013start_index\\022\\t\\021\\000\\000\\000\\000\\000\\000\\360?\\n<\\n\\010cat_path\\0220\\032.workflow_etl/categories/unique.item_id.parquet\\nG\\n\\017embedding_sizes\\0224*2\\n\\030\\n\\013cardinality\\022\\t\\021\\000\\000\\000\\000\\000\\250\\224@\\n\\026\\n\\tdimension\\022\\t\\021\\000\\000\\000\\000\\000\\200V@\\n\\034\\n\\017dtype_item_size\\022\\t\\021\\000\\000\\000\\000\\000\\000@@\\n\\r\\n\\007is_list\\022\\002 \\000\\n\\017\\n\\tis_ragged\\022\\002 \\000\"\n", + " value: \"\\n\\021\\n\\013num_buckets\\022\\002\\010\\000\\n\\033\\n\\016freq_threshold\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\025\\n\\010max_size\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\030\\n\\013start_index\\022\\t\\021\\000\\000\\000\\000\\000\\000\\360?\\n5\\n\\010cat_path\\022)\\032\\'.//categories/unique.session_id.parquet\\nG\\n\\017embedding_sizes\\0224*2\\n\\030\\n\\013cardinality\\022\\t\\021\\000\\000\\000\\000\\300g\\323@\\n\\026\\n\\tdimension\\022\\t\\021\\000\\000\\000\\000\\000\\200y@\\n\\034\\n\\017dtype_item_size\\022\\t\\021\\000\\000\\000\\000\\000\\000P@\\n\\r\\n\\007is_list\\022\\002 \\000\\n\\017\\n\\tis_ragged\\022\\002 \\000\"\n", " }\n", " }\n", "}\n", @@ -204,10 +265,10 @@ "outputs": [], "source": [ "# You can select a subset of features for training\n", - "schema = schema.select_by_name(['item_id-list_trim', \n", - " 'category-list_trim', \n", - " 'timestamp/weekday/sin-list_trim',\n", - " 'timestamp/age_days-list_trim'])" + "schema = schema.select_by_name(['item_id-list', \n", + " 'category-list', \n", + " 'weekday_sin-list',\n", + 
" 'age_days-list'])" ] }, { @@ -285,10 +346,47 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "ed749ca8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (NDCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (DCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (RecallAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n" + ] + } + ], "source": [ "# Define XLNetConfig class and set default parameters for HF XLNet config \n", "transformer_config = tr.XLNetConfig.build(\n", @@ -450,8 +548,6 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.\n", - " warnings.warn(\n", "***** Running training *****\n", " Num examples = 1664\n", " Num Epochs = 5\n", @@ -491,7 +587,7 @@ " \n", " \n", " 50\n", - " 6.715900\n", + " 5.821000\n", " \n", " \n", "

" @@ -521,22 +617,14 @@ "finished\n" ] }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ "\n", "

\n", " \n", - " \n", - " [6/6 00:09]\n", + " \n", + " [6/6 00:10]\n", "
\n", " " ], @@ -556,14 +644,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 5.670037746429443\n", - " eval_/next-item/ndcg_at_20 = 0.14301326870918274\n", - " eval_/next-item/ndcg_at_40 = 0.20679126679897308\n", - " eval_/next-item/recall_at_20 = 0.359375\n", - " eval_/next-item/recall_at_40 = 0.6666666865348816\n", - " eval_runtime = 0.128\n", - " eval_samples_per_second = 1500.117\n", - " eval_steps_per_second = 46.879\n", + " eval_/loss = 5.006950855255127\n", + " eval_/next-item/ndcg_at_20 = 0.17477868497371674\n", + " eval_/next-item/ndcg_at_40 = 0.24376630783081055\n", + " eval_/next-item/recall_at_20 = 0.421875\n", + " eval_/next-item/recall_at_40 = 0.7552083730697632\n", + " eval_runtime = 0.1473\n", + " eval_samples_per_second = 1303.363\n", + " eval_steps_per_second = 40.73\n", "['/workspace/data/sessions_by_day/2/train.parquet']\n", "********************\n", "Launch training for day 2 are:\n", @@ -603,7 +691,7 @@ " \n", " \n", " 50\n", - " 5.340300\n", + " 4.896300\n", " \n", " \n", "

" @@ -630,25 +718,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "finished\n", - "********************\n", - "Eval results for day 3 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.734586238861084\n", - " eval_/next-item/ndcg_at_20 = 0.1766236275434494\n", - " eval_/next-item/ndcg_at_40 = 0.2341984361410141\n", - " eval_/next-item/recall_at_20 = 0.5104166865348816\n", - " eval_/next-item/recall_at_40 = 0.7916666865348816\n", - " eval_runtime = 0.1261\n", - " eval_samples_per_second = 1522.894\n", - " eval_steps_per_second = 47.59\n", - "['/workspace/data/sessions_by_day/3/train.parquet']\n", - "********************\n", - "Launch training for day 3 are:\n", - "********************\n", - "\n" + "finished\n" ] }, { @@ -664,6 +734,30 @@ " Total optimization steps = 65\n" ] }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "********************\n", + "Eval results for day 3 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.632648944854736\n", + " eval_/next-item/ndcg_at_20 = 0.18322871625423431\n", + " eval_/next-item/ndcg_at_40 = 0.24265889823436737\n", + " eval_/next-item/recall_at_20 = 0.4427083432674408\n", + " eval_/next-item/recall_at_40 = 0.7291666865348816\n", + " eval_runtime = 0.2087\n", + " eval_samples_per_second = 919.825\n", + " eval_steps_per_second = 28.745\n", + "['/workspace/data/sessions_by_day/3/train.parquet']\n", + "********************\n", + "Launch training for day 3 are:\n", + "********************\n", + "\n" + ] + }, { "data": { "text/html": [ @@ -683,7 +777,7 @@ " \n", " \n", " 50\n", - " 4.742700\n", + " 4.577600\n", " \n", " \n", "

" @@ -716,14 +810,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.577661037445068\n", - " eval_/next-item/ndcg_at_20 = 0.187037393450737\n", - " eval_/next-item/ndcg_at_40 = 0.22624635696411133\n", - " eval_/next-item/recall_at_20 = 0.546875\n", - " eval_/next-item/recall_at_40 = 0.7395833730697632\n", - " eval_runtime = 0.1279\n", - " eval_samples_per_second = 1501.392\n", - " eval_steps_per_second = 46.919\n", + " eval_/loss = 4.358114719390869\n", + " eval_/next-item/ndcg_at_20 = 0.2185230255126953\n", + " eval_/next-item/ndcg_at_40 = 0.2660380005836487\n", + " eval_/next-item/recall_at_20 = 0.578125\n", + " eval_/next-item/recall_at_40 = 0.8072916865348816\n", + " eval_runtime = 0.1733\n", + " eval_samples_per_second = 1107.591\n", + " eval_steps_per_second = 34.612\n", "['/workspace/data/sessions_by_day/4/train.parquet']\n", "********************\n", "Launch training for day 4 are:\n", @@ -763,7 +857,7 @@ " \n", " \n", " 50\n", - " 4.573400\n", + " 4.497000\n", " \n", " \n", "

" @@ -796,19 +890,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.446651935577393\n", - " eval_/next-item/ndcg_at_20 = 0.19330403208732605\n", - " eval_/next-item/ndcg_at_40 = 0.2473829984664917\n", - " eval_/next-item/recall_at_20 = 0.5208333730697632\n", - " eval_/next-item/recall_at_40 = 0.7864583730697632\n", - " eval_runtime = 0.1305\n", - " eval_samples_per_second = 1471.204\n", - " eval_steps_per_second = 45.975\n", - "['/workspace/data/sessions_by_day/5/train.parquet']\n", - "********************\n", - "Launch training for day 5 are:\n", - "********************\n", - "\n" + " eval_/loss = 4.520808696746826\n", + " eval_/next-item/ndcg_at_20 = 0.2083323895931244\n", + " eval_/next-item/ndcg_at_40 = 0.2511211037635803\n", + " eval_/next-item/recall_at_20 = 0.53125\n", + " eval_/next-item/recall_at_40 = 0.7395833730697632\n", + " eval_runtime = 0.173\n", + " eval_samples_per_second = 1109.768\n", + " eval_steps_per_second = 34.68\n" ] }, { @@ -824,6 +913,17 @@ " Total optimization steps = 65\n" ] }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['/workspace/data/sessions_by_day/5/train.parquet']\n", + "********************\n", + "Launch training for day 5 are:\n", + "********************\n", + "\n" + ] + }, { "data": { "text/html": [ @@ -843,7 +943,7 @@ " \n", " \n", " 50\n", - " 4.544500\n", + " 4.491200\n", " \n", " \n", "

" @@ -876,19 +976,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.383015155792236\n", - " eval_/next-item/ndcg_at_20 = 0.20541004836559296\n", - " eval_/next-item/ndcg_at_40 = 0.26060688495635986\n", - " eval_/next-item/recall_at_20 = 0.5416666865348816\n", - " eval_/next-item/recall_at_40 = 0.8125\n", - " eval_runtime = 0.126\n", - " eval_samples_per_second = 1523.488\n", - " eval_steps_per_second = 47.609\n", - "['/workspace/data/sessions_by_day/6/train.parquet']\n", - "********************\n", - "Launch training for day 6 are:\n", - "********************\n", - "\n" + " eval_/loss = 4.366371154785156\n", + " eval_/next-item/ndcg_at_20 = 0.21815048158168793\n", + " eval_/next-item/ndcg_at_40 = 0.2546561360359192\n", + " eval_/next-item/recall_at_20 = 0.5885416865348816\n", + " eval_/next-item/recall_at_40 = 0.765625\n", + " eval_runtime = 0.1743\n", + " eval_samples_per_second = 1101.611\n", + " eval_steps_per_second = 34.425\n" ] }, { @@ -896,12 +991,23 @@ "output_type": "stream", "text": [ "***** Running training *****\n", - " Num examples = 1792\n", + " Num examples = 1664\n", " Num Epochs = 5\n", " Instantaneous batch size per device = 128\n", " Total train batch size (w. parallel, distributed & accumulation) = 128\n", " Gradient Accumulation steps = 1\n", - " Total optimization steps = 70\n" + " Total optimization steps = 65\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['/workspace/data/sessions_by_day/6/train.parquet']\n", + "********************\n", + "Launch training for day 6 are:\n", + "********************\n", + "\n" ] }, { @@ -910,8 +1016,8 @@ "\n", "

\n", " \n", - " \n", - " [70/70 00:01, Epoch 5/5]\n", + " \n", + " [65/65 00:01, Epoch 5/5]\n", "
\n", " \n", " \n", @@ -923,7 +1029,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
504.5118004.488000

" @@ -956,16 +1062,16 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.352958679199219\n", - " eval_/next-item/ndcg_at_20 = 0.19154103100299835\n", - " eval_/next-item/ndcg_at_40 = 0.24716795980930328\n", - " eval_/next-item/recall_at_20 = 0.5223214626312256\n", - " eval_/next-item/recall_at_40 = 0.7901785969734192\n", - " eval_runtime = 0.1382\n", - " eval_samples_per_second = 1620.526\n", - " eval_steps_per_second = 50.641\n", - "CPU times: user 40.2 s, sys: 503 ms, total: 40.7 s\n", - "Wall time: 11.6 s\n" + " eval_/loss = 4.4079389572143555\n", + " eval_/next-item/ndcg_at_20 = 0.2032829374074936\n", + " eval_/next-item/ndcg_at_40 = 0.24693170189857483\n", + " eval_/next-item/recall_at_20 = 0.5416666865348816\n", + " eval_/next-item/recall_at_40 = 0.75\n", + " eval_runtime = 0.1598\n", + " eval_samples_per_second = 1201.645\n", + " eval_steps_per_second = 37.551\n", + "CPU times: user 44 s, sys: 560 ms, total: 44.6 s\n", + "Wall time: 13.2 s\n" ] } ], @@ -1021,7 +1127,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Saving model checkpoint to ./tmp/checkpoint-71\n", + "Saving model checkpoint to ./tmp/checkpoint-66\n", "Trainer.model is not a `PreTrainedModel`, only saving its state dict.\n" ] } @@ -1072,26 +1178,18 @@ "id": "2c6867d8", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/cudf/core/frame.py:384: UserWarning: The deep parameter is ignored and is only included for pandas compatibility.\n", - " warnings.warn(\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - " eval_/loss = 4.352958679199219\n", - " eval_/next-item/ndcg_at_20 = 0.19154103100299835\n", - " eval_/next-item/ndcg_at_40 = 0.24716795980930328\n", - " eval_/next-item/recall_at_20 = 0.5223214626312256\n", - " eval_/next-item/recall_at_40 = 0.7901785969734192\n", - " eval_runtime = 0.1481\n", - " eval_samples_per_second = 1512.78\n", - " eval_steps_per_second = 47.274\n" + " eval_/loss = 4.4079389572143555\n", + " eval_/next-item/ndcg_at_20 = 0.2032829374074936\n", + " eval_/next-item/ndcg_at_40 = 0.24693170189857483\n", + " eval_/next-item/recall_at_20 = 0.5416666865348816\n", + " eval_/next-item/recall_at_40 = 0.75\n", + " eval_runtime = 0.1713\n", + " eval_samples_per_second = 1121.166\n", + " eval_steps_per_second = 35.036\n" ] } ], diff --git a/examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb new file mode 100644 index 0000000000..77aa95adee --- /dev/null +++ b/examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb @@ -0,0 +1,1473 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "97250792", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2022 NVIDIA Corporation. 
All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a2228da",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "# Serving a Session-based Model with the Torch Backend"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "599efc90",
   "metadata": {},
   "source": [
    "### Import the required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ba89970",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n",
      "  warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n",
      "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
      " not been set for this class (NDCGAt). The property determines if `update` by\n",
      " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
      " achieved and we recommend setting this to `False`.\n",
      " We provide an checking function\n",
      " `from torchmetrics.utilities import check_forward_full_state_property`\n",
      " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
      " default for now) or if `full_state_update=False` can be used safely.\n",
      " \n",
      " warnings.warn(*args, **kwargs)\n",
      "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
      " not been set for this class (DCGAt). The property determines if `update` by\n",
      " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
      " achieved and we recommend setting this to `False`.\n",
      " We provide an checking function\n",
      " `from torchmetrics.utilities import check_forward_full_state_property`\n",
      " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
      " default for now) or if `full_state_update=False` can be used safely.\n",
      " \n",
      " warnings.warn(*args, **kwargs)\n",
      "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
      " not been set for this class (AvgPrecisionAt). The property determines if `update` by\n",
      " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
      " achieved and we recommend setting this to `False`.\n",
      " We provide an checking function\n",
      " `from torchmetrics.utilities import check_forward_full_state_property`\n",
      " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
      " default for now) or if `full_state_update=False` can be used safely.\n",
      " \n",
      " warnings.warn(*args, **kwargs)\n",
      "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
      " not been set for this class (PrecisionAt). The property determines if `update` by\n",
      " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
      " achieved and we recommend setting this to `False`.\n",
      " We provide an checking function\n",
      " `from torchmetrics.utilities import check_forward_full_state_property`\n",
      " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
      " default for now) or if `full_state_update=False` can be used safely.\n",
      " \n",
      " warnings.warn(*args, **kwargs)\n",
      "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
      " not been set for this class (RecallAt). The property determines if `update` by\n",
      " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
      " achieved and we recommend setting this to `False`.\n",
      " We provide an checking function\n",
      " `from torchmetrics.utilities import check_forward_full_state_property`\n",
      " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
      " default for now) or if `full_state_update=False` can be used safely.\n",
      " \n",
      " warnings.warn(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "import cudf\n",
    "import glob\n",
    "import torch\n",
    "\n",
    "from transformers4rec import torch as tr\n",
    "from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt\n",
    "from transformers4rec.torch.utils.examples_utils import wipe_memory\n",
    "from merlin.io import Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aea2a0c5",
   "metadata": {},
   "source": [
    "The Transformers4Rec library relies on a schema object to automatically build all the layers needed to represent, normalize and aggregate the input features. As you can see below, `schema.pbtxt` is a text protobuf file that contains metadata about the features, including statistics such as cardinality and min/max values, and it tags each feature based on its characteristics and dtype (e.g., categorical, continuous, list, integer)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a510b6ef",
   "metadata": {},
   "source": [
    "### Set the schema object"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a0518a-eb01-4ac4-9c6d-36b328985765",
   "metadata": {},
   "source": [
    "We create the schema object by reading the `schema.pbtxt` file generated by the NVTabular pipeline in the previous notebook, `01-ETL-with-NVTabular`.\n",
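    "\n",
    "The tags in the schema can also drive feature selection. Besides the name-based `select_by_name` call used a couple of cells below, a tag-based selection along these lines should work once the schema is loaded (a sketch; it assumes this `Schema` class exposes a `select_by_tag` method that accepts a list of tag names, as in other Transformers4Rec examples):\n",
    "\n",
    "```python\n",
    "# Hypothetical alternative to select_by_name: keep every feature that\n",
    "# carries a given tag instead of listing the feature names explicitly.\n",
    "list_features = schema.select_by_tag([\"categorical\", \"continuous\"])\n",
    "```"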
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9d1299fa", + "metadata": {}, + "outputs": [], + "source": [ + "from merlin_standard_lib import Schema\n", + "# import merlin.io\n", + "# from merlin.models.utils import schema_utils\n", + "# from merlin.schema import Schema, Tags\n", + "# from merlin.schema.io.tensorflow_metadata import TensorflowMetadata\n", + "# from merlin.schema import Schema\n", + "SCHEMA_PATH = os.environ.get(\"INPUT_SCHEMA_PATH\", \"/workspace/data/processed_nvt/schema.pbtxt\")\n", + "schema = Schema().from_proto_text(SCHEMA_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "868f0317-d140-40d5-b4bd-29a27e12077b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'session_id', 'type': 'INT', 'int_domain': {'name': 'session_id', 'max': '19877', 'is_categorical': True}, 'annotation': {'tag': ['categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.session_id.parquet\", \"embedding_sizes\": {\"cardinality\": 19878.0, \"dimension\": 409.0}, \"dtype_item_size\": 64.0, \"is_list\": false, \"is_ragged\": false}']}}, {'name': 'day-first', 'type': 'INT', 'annotation': {'comment': ['{\"dtype_item_size\": 64.0, \"is_list\": false, \"is_ragged\": false}']}}, {'name': 'item_id-count', 'type': 'INT', 'int_domain': {'name': 'item_id', 'max': '506', 'is_categorical': True}, 'annotation': {'tag': ['categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.item_id.parquet\", \"embedding_sizes\": {\"cardinality\": 507.0, \"dimension\": 52.0}, \"dtype_item_size\": 32.0, \"is_list\": false, \"is_ragged\": false}']}}, {'name': 'item_id-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'item_id', 'max': '506', 'is_categorical': True}, 'annotation': {'tag': ['list', 'item_id', 'item', 'id', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.item_id.parquet\", \"embedding_sizes\": {\"cardinality\": 507.0, \"dimension\": 52.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'category-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'category', 'max': '137', 'is_categorical': True}, 'annotation': {'tag': ['list', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.category.parquet\", \"embedding_sizes\": {\"cardinality\": 138.0, \"dimension\": 25.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'age_days-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'weekday_sin-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "schema" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b4f426f0", + "metadata": {}, + "outputs": [], + "source": [ + "# # You can select a 
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "b4f426f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Select a subset of features for training\n",
+ "schema = schema.select_by_name(['item_id-list', \n",
+ " 'category-list',\n",
+ " 'weekday_sin-list',\n",
+ " 'age_days-list'\n",
+ " ])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "31bd0f44-ecfe-489a-88ac-032b5a512622",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'name': 'item_id-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'item_id', 'max': '506', 'is_categorical': True}, 'annotation': {'tag': ['list', 'item_id', 'item', 'id', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.item_id.parquet\", \"embedding_sizes\": {\"cardinality\": 507.0, \"dimension\": 52.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'category-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'category', 'max': '137', 'is_categorical': True}, 'annotation': {'tag': ['list', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.category.parquet\", \"embedding_sizes\": {\"cardinality\": 138.0, \"dimension\": 25.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'age_days-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'weekday_sin-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}]"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "schema"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "06cacefa",
+ "metadata": {},
+ "source": [
+ "### Define the sequential input module"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "b38d30d7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "inputs = tr.TabularSequenceFeatures.from_schema(\n",
+ " schema,\n",
+ " max_sequence_length=15,\n",
+ " continuous_projection=64,\n",
+ " d_output=100,\n",
+ " masking=\"causal\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "ed749ca8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
+ " not been set for this class (NDCGAt). The property determines if `update` by\n",
+ " default needs access to the full metric state. 
If this is not the case, significant speedups can be\n",
+ " achieved and we recommend setting this to `False`.\n",
+ " We provide an checking function\n",
+ " `from torchmetrics.utilities import check_forward_full_state_property`\n",
+ " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
+ " default for now) or if `full_state_update=False` can be used safely.\n",
+ " \n",
+ " warnings.warn(*args, **kwargs)\n",
+ "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
+ " not been set for this class (DCGAt). The property determines if `update` by\n",
+ " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
+ " achieved and we recommend setting this to `False`.\n",
+ " We provide an checking function\n",
+ " `from torchmetrics.utilities import check_forward_full_state_property`\n",
+ " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
+ " default for now) or if `full_state_update=False` can be used safely.\n",
+ " \n",
+ " warnings.warn(*args, **kwargs)\n",
+ "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
+ " not been set for this class (RecallAt). The property determines if `update` by\n",
+ " default needs access to the full metric state. If this is not the case, significant speedups can be\n",
+ " achieved and we recommend setting this to `False`.\n",
+ " We provide an checking function\n",
+ " `from torchmetrics.utilities import check_forward_full_state_property`\n",
+ " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n",
+ " default for now) or if `full_state_update=False` can be used safely.\n",
+ " \n",
+ " warnings.warn(*args, **kwargs)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define the XLNetConfig class and set default parameters for the HF XLNet config\n",
+ "transformer_config = tr.XLNetConfig.build(\n",
+ " d_model=64, n_head=4, n_layer=2, total_seq_length=20\n",
+ ")\n",
+ "# Define the model block including: inputs, masking, projection and transformer block.\n",
+ "body = tr.SequentialBlock(\n",
+ " inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)\n",
+ ")\n",
+ "\n",
+ "# Define the evaluation top-N metrics and the cut-offs\n",
+ "metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True), \n",
+ " RecallAt(top_ks=[20, 40], labels_onehot=True)]\n",
+ "\n",
+ "# Define a head with the next-item prediction task\n",
+ "head = tr.Head(\n",
+ " body,\n",
+ " tr.NextItemPredictionTask(weight_tying=True, metrics=metrics),\n",
+ " inputs=inputs,\n",
+ ")\n",
+ "\n",
+ "# Get the end-to-end Model class\n",
+ "model = tr.Model(head)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a57335ff",
+ "metadata": {},
+ "source": [
+ "Note that we can easily define an RNN-based model inside the `SequentialBlock` instead of a Transformer-based model, as sketched below. You can explore this [tutorial](https://github.com/NVIDIA-Merlin/Transformers4Rec/tree/main/examples/tutorial) for a GRU-based model example.\n",
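+ "\n",
+ "As a sketch (the hyperparameters are illustrative and mirror the tutorial's pattern rather than a tested configuration), the Transformer body above could be swapped for a GRU like this:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "body = tr.SequentialBlock(\n",
+ "    inputs,\n",
+ "    tr.MLPBlock([64]),\n",
+ "    # tr.Block wraps a plain torch module and declares its output size\n",
+ "    tr.Block(torch.nn.GRU(input_size=64, hidden_size=64, num_layers=1), [None, 20, 64]),\n",
+ ")\n",
+ "```"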
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16d51e39",
+ "metadata": {},
+ "source": [
+ "### Train the model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f26d7aec",
+ "metadata": {},
+ "source": [
+ "We use the NVTabular PyTorch Dataloader for optimized loading of multiple features from input parquet files. You can learn more about this data loader [here](https://nvidia-merlin.github.io/NVTabular/main/training/pytorch.html)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02fd4c22",
+ "metadata": {},
+ "source": [
+ "### **Set Training arguments**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "693974df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers4rec.config.trainer import T4RecTrainingArguments\n",
+ "from transformers4rec.torch import Trainer\n",
+ "# Set hyperparameters for training\n",
+ "train_args = T4RecTrainingArguments(data_loader_engine='nvtabular', \n",
+ " dataloader_drop_last = True,\n",
+ " gradient_accumulation_steps = 1,\n",
+ " per_device_train_batch_size = 128, \n",
+ " per_device_eval_batch_size = 32,\n",
+ " output_dir = \"./tmp\", \n",
+ " learning_rate=0.0005,\n",
+ " lr_scheduler_type='cosine', \n",
+ " learning_rate_num_cosine_cycles_by_epoch=1.5,\n",
+ " num_train_epochs=5,\n",
+ " max_sequence_length=20, \n",
+ " report_to = [],\n",
+ " logging_steps=50,\n",
+ " no_cuda=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "445ece64",
+ "metadata": {},
+ "source": [
+ "Note that we add the argument `data_loader_engine='nvtabular'` to automatically load the features needed for training based on the schema. The default value is `nvtabular`, which provides optimized GPU-based data loading. Optionally, the PyarrowDataLoader (`pyarrow`) can be used as a basic alternative, but it is slower and only suitable for small datasets, since the full data is loaded into CPU memory."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "32554ea0",
+ "metadata": {},
+ "source": [
+ "## Daily Fine-Tuning: Training over a time window"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef883061",
+ "metadata": {},
+ "source": [
+ "Here we perform daily fine-tuning: we train on the first day's data and evaluate on the second day's, then resume training on the second day's data and evaluate on the third day's, and so forth."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9f452d09",
+ "metadata": {},
+ "source": [
+ "We have extended the HuggingFace transformers `Trainer` class (PyTorch only) to support evaluation of RecSys metrics. In this example, the evaluation of the session-based recommendation model is performed using traditional Top-N ranking metrics such as Normalized Discounted Cumulative Gain (NDCG@20) and Hit Rate (HR@20). NDCG accounts for the rank of the relevant item in the recommendation list and is a more fine-grained metric than HR, which only verifies whether the relevant item is among the top-N items. HR@N is equivalent to Recall@N when there is only one relevant item in the recommendation list.\n",
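+ "\n",
+ "For intuition, here is a minimal NumPy sketch (not the Trainer's actual implementation) of Recall@K and NDCG@K for next-item prediction with a single relevant item per row; `logits` of shape `(batch, n_items)` and integer `labels` are assumed:\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "def recall_and_ndcg_at_k(logits, labels, k=20):\n",
+ "    # rank all items by score, highest first, and keep the top-k\n",
+ "    topk = np.argsort(-logits, axis=1)[:, :k]\n",
+ "    hits = topk == labels[:, None]\n",
+ "    recall = hits.any(axis=1).mean()\n",
+ "    # with a single relevant item, NDCG reduces to 1/log2(rank+2) at the hit position\n",
+ "    ranks = hits.argmax(axis=1)\n",
+ "    ndcg = np.where(hits.any(axis=1), 1.0 / np.log2(ranks + 2), 0.0).mean()\n",
+ "    return recall, ndcg\n",
+ "```"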
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2283f788", + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate the T4Rec Trainer, which manages training and evaluation for the PyTorch API\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=train_args,\n", + " schema=schema,\n", + " compute_metrics=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d515127c", + "metadata": {}, + "source": [ + "- Define the output folder of the processed parquet files" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ae313150", + "metadata": {}, + "outputs": [], + "source": [ + "INPUT_DATA_DIR = os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data\")\n", + "OUTPUT_DIR = os.environ.get(\"OUTPUT_DIR\", f\"{INPUT_DATA_DIR}/sessions_by_day\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8ae51de0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "***** Running training *****\n", + " Num examples = 1664\n", + " Num Epochs = 5\n", + " Instantaneous batch size per device = 128\n", + " Total train batch size (w. parallel, distributed & accumulation) = 128\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 65\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['/workspace/data/sessions_by_day/1/train.parquet']\n", + "********************\n", + "Launch training for day 1 are:\n", + "********************\n", + "\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [65/65 00:01, Epoch 5/5]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
505.731000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "finished\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [6/6 00:02]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "********************\n", + "Eval results for day 2 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.892789840698242\n", + " eval_/next-item/ndcg_at_20 = 0.2019212245941162\n", + " eval_/next-item/ndcg_at_40 = 0.24986284971237183\n", + " eval_/next-item/recall_at_20 = 0.5104166865348816\n", + " eval_/next-item/recall_at_40 = 0.7447916865348816\n", + " eval_runtime = 0.1608\n", + " eval_samples_per_second = 1193.821\n", + " eval_steps_per_second = 37.307\n", + "['/workspace/data/sessions_by_day/2/train.parquet']\n", + "********************\n", + "Launch training for day 2 are:\n", + "********************\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "***** Running training *****\n", + " Num examples = 1664\n", + " Num Epochs = 5\n", + " Instantaneous batch size per device = 128\n", + " Total train batch size (w. parallel, distributed & accumulation) = 128\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 65\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [65/65 00:01, Epoch 5/5]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
504.795200

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "finished\n", + "********************\n", + "Eval results for day 3 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.611485481262207\n", + " eval_/next-item/ndcg_at_20 = 0.1681433618068695\n", + " eval_/next-item/ndcg_at_40 = 0.22220981121063232\n", + " eval_/next-item/recall_at_20 = 0.484375\n", + " eval_/next-item/recall_at_40 = 0.7447916865348816\n", + " eval_runtime = 0.1687\n", + " eval_samples_per_second = 1138.188\n", + " eval_steps_per_second = 35.568\n", + "CPU times: user 15.9 s, sys: 247 ms, total: 16.1 s\n", + "Wall time: 5.28 s\n" + ] + } + ], + "source": [ + "%%time\n", + "start_time_window_index = 1\n", + "final_time_window_index = 3\n", + "#Iterating over days of one week\n", + "for time_index in range(start_time_window_index, final_time_window_index):\n", + " # Set data \n", + " time_index_train = time_index\n", + " time_index_eval = time_index + 1\n", + " train_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_train}/train.parquet\"))\n", + " eval_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))\n", + " print(train_paths)\n", + " \n", + " # Train on day related to time_index \n", + " print('*'*20)\n", + " print(\"Launch training for day %s are:\" %time_index)\n", + " print('*'*20 + '\\n')\n", + " trainer.train_dataset_or_path = train_paths\n", + " trainer.reset_lr_scheduler()\n", + " trainer.train()\n", + " trainer.state.global_step +=1\n", + " print('finished')\n", + " \n", + " # Evaluate on the following day\n", + " trainer.eval_dataset_or_path = eval_paths\n", + " train_metrics = trainer.evaluate(metric_key_prefix='eval')\n", + " print('*'*20)\n", + " print(\"Eval results for day %s are:\\t\" %time_index_eval)\n", + " print('\\n' + '*'*20 + '\\n')\n", + " for key in sorted(train_metrics.keys()):\n", + " print(\" %s = %s\" % (key, str(train_metrics[key]))) \n", + " wipe_memory()" + ] + }, + { + "cell_type": "markdown", + "id": "6a8d7bd8", + "metadata": {}, + "source": [ + "### Re-compute eval metrics of validation data" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "3a34be66", + "metadata": {}, + "outputs": [], + "source": [ + "eval_data_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))\n", + "\n", + "# set new data from day 7\n", + "eval_metrics = trainer.evaluate(eval_dataset=eval_data_paths, metric_key_prefix='eval')\n", + "for key in sorted(eval_metrics.keys()):\n", + " print(\" %s = %s\" % (key, str(eval_metrics[key])))" + ] + }, + { + "cell_type": "markdown", + "id": "4a26a649", + "metadata": {}, + "source": [ + "That's it! \n", + "You have just trained your session-based recommendation model using Transformers4Rec." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e516a78d-2e1a-4124-ba46-f60b245d3329", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Model(\n", + " (heads): ModuleList(\n", + " (0): Head(\n", + " (body): SequentialBlock(\n", + " (0): TabularSequenceFeatures(\n", + " (to_merge): ModuleDict(\n", + " (continuous_module): SequentialBlock(\n", + " (0): ContinuousFeatures(\n", + " (filter_features): FilterFeatures()\n", + " (_aggregation): ConcatFeatures()\n", + " )\n", + " (1): SequentialBlock(\n", + " (0): DenseBlock(\n", + " (0): Linear(in_features=2, out_features=64, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (2): AsTabular()\n", + " )\n", + " (categorical_module): SequenceEmbeddingFeatures(\n", + " (filter_features): FilterFeatures()\n", + " (embedding_tables): ModuleDict(\n", + " (item_id-list): Embedding(507, 64, padding_idx=0)\n", + " (category-list): Embedding(138, 64, padding_idx=0)\n", + " )\n", + " )\n", + " )\n", + " (_aggregation): ConcatFeatures()\n", + " (projection_module): SequentialBlock(\n", + " (0): DenseBlock(\n", + " (0): Linear(in_features=192, out_features=100, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (_masking): CausalLanguageModeling()\n", + " )\n", + " (1): SequentialBlock(\n", + " (0): DenseBlock(\n", + " (0): Linear(in_features=100, out_features=64, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (2): TansformerBlock(\n", + " (transformer): XLNetModel(\n", + " (word_embedding): Embedding(1, 64)\n", + " (layer): ModuleList(\n", + " (0): XLNetLayer(\n", + " (rel_attn): XLNetRelativeAttention(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (ff): XLNetFeedForward(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", + " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (1): XLNetLayer(\n", + " (rel_attn): XLNetRelativeAttention(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (ff): XLNetFeedForward(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", + " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (masking): CausalLanguageModeling()\n", + " )\n", + " )\n", + " (prediction_task_dict): ModuleDict(\n", + " (next-item): NextItemPredictionTask(\n", + " (sequence_summary): SequenceSummary(\n", + " (summary): Identity()\n", + " (activation): Identity()\n", + " (first_dropout): Identity()\n", + " (last_dropout): Identity()\n", + " )\n", + " (metrics): ModuleList(\n", + " (0): NDCGAt()\n", + " (1): RecallAt()\n", + " )\n", + " (loss): NLLLoss()\n", + " (embeddings): SequenceEmbeddingFeatures(\n", + " (filter_features): FilterFeatures()\n", + " (embedding_tables): ModuleDict(\n", + " (item_id-list): Embedding(507, 64, padding_idx=0)\n", + " (category-list): Embedding(138, 64, padding_idx=0)\n", + " )\n", + " )\n", + " 
(item_embedding_table): Embedding(507, 64, padding_idx=0)\n", + " (masking): CausalLanguageModeling()\n", + " (pre): Block(\n", + " (module): NextItemPredictionTask(\n", + " (item_embedding_table): Embedding(507, 64, padding_idx=0)\n", + " (log_softmax): LogSoftmax(dim=-1)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e1e140ac-afd6-455d-9057-1bbd07116a9b", + "metadata": {}, + "outputs": [], + "source": [ + "model.hf_format = False" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ef88f601-c7c0-4244-84f3-ee257b579205", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nametagsdtypeis_listis_raggedproperties.int_domain.minproperties.int_domain.max
0age_days-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
1weekday_sin-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
2item_id-list(Tags.ID, Tags.ITEM, Tags.CATEGORICAL, Tags.IT...int64TrueFalse0506
3category-list(Tags.CATEGORICAL, Tags.LIST)int64TrueFalse0137
\n", + "
" + ], + "text/plain": [ + "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 506}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 137}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.input_schema" + ] + }, + { + "cell_type": "markdown", + "id": "409c17c2-c81d-4f41-9577-bd380ae10921", + "metadata": {}, + "source": [ + "Create a dict of tensors" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "6273c8e5-db62-4cc6-a4f3-945155a463d6", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset(train_paths[0])\n", + "trainer.train_dataset_or_path = dataset\n", + "loader = trainer.get_train_dataloader()\n", + "train_dict = next(iter(loader))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "43e789a3-2423-44eb-ae5d-0c154557424f", + "metadata": {}, + "outputs": [], + "source": [ + "traced_model = torch.jit.trace(model, train_dict, strict=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "8cdbc288-baf5-4beb-a3ac-5fcb7315125c", + "metadata": {}, + "outputs": [], + "source": [ + "assert isinstance(traced_model, torch.jit.TopLevelTracedModule)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "896f4d80-3703-42af-a48e-2b81e839006d", + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.allclose(\n", + " model(train_dict),\n", + " traced_model(train_dict),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "c6a814aa-4954-404b-bd2d-161ed8066f4e", + "metadata": {}, + "outputs": [], + "source": [ + "input_schema = model.input_schema\n", + "output_schema = model.output_schema" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "757cd0c5-f581-488b-a8de-b8d1188820d6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nametagsdtypeis_listis_raggedproperties.int_domain.minproperties.int_domain.max
0age_days-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
1weekday_sin-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
2item_id-list(Tags.ID, Tags.ITEM, Tags.CATEGORICAL, Tags.IT...int64TrueFalse0506
3category-list(Tags.CATEGORICAL, Tags.LIST)int64TrueFalse0137
\n", + "
" + ], + "text/plain": [ + "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 506}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 137}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_schema" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from merlin.core.dispatch import make_df # noqa\n", + "from merlin.systems.dag import Ensemble # noqa\n", + "from merlin.systems.dag.ops.pytorch import PredictPyTorch # noqa\n", + "from merlin.systems.triton.utils import run_ensemble_on_tritonserver # noqa\n", + "\n", + "torch_op = input_schema.column_names >> PredictPyTorch(\n", + " traced_model, input_schema, output_schema\n", + ")\n", + "\n", + "ensemble = Ensemble(torch_op, input_schema)\n", + "ens_config, node_configs = ensemble.export(str('./models'))" + ] + }, + { + "cell_type": "markdown", + "id": "5faba154-d4b2-4424-a1b2-badd2227e66e", + "metadata": {}, + "source": [ + "Create a dataframe to send as a request. We need a dataset where the list columns are padded to the max sequence lenght that was set in the ETL pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "c1ce3a29-5578-41ca-a033-abc4507adfef", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(128, 4)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
age_days-listweekday_sin-listitem_id-listcategory-list
0[0.37403864, 0.42758772, 0.93743354, 0.0, 0.0,...[0.9351001, 0.91299504, 0.9785595, 0.0, 0.0, 0...[30, 24, 200, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...[7, 6, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
1[0.9327483, 0.36575532, 0.13967341, 0.45479113...[0.5046626, 0.17492707, 0.12539314, 0.6640924,...[190, 10, 7, 55, 27, 3, 6, 184, 0, 0, 0, 0, 0,...[36, 3, 2, 11, 6, 4, 2, 36, 0, 0, 0, 0, 0, 0, ...
2[0.57168996, 0.48532194, 0.89944935, 0.2171675...[0.93685514, 0.5638695, 0.76670134, 0.6797855,...[153, 8, 46, 58, 21, 19, 31, 15, 4, 104, 0, 0,...[28, 2, 10, 11, 5, 5, 7, 3, 2, 18, 0, 0, 0, 0,...
3[0.8520663, 0.6690395, 0.92268515, 0.99163777,...[0.58499664, 0.45736608, 0.88926136, 0.9139287...[19, 28, 23, 34, 18, 10, 0, 0, 0, 0, 0, 0, 0, ...[5, 6, 6, 7, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4[0.67542243, 0.65952307, 0.7467189, 0.6136317,...[0.09077961, 0.7920753, 0.35881928, 0.8545563,...[17, 27, 70, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...[5, 6, 14, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
\n", + "
" + ], + "text/plain": [ + " age_days-list \\\n", + "0 [0.37403864, 0.42758772, 0.93743354, 0.0, 0.0,... \n", + "1 [0.9327483, 0.36575532, 0.13967341, 0.45479113... \n", + "2 [0.57168996, 0.48532194, 0.89944935, 0.2171675... \n", + "3 [0.8520663, 0.6690395, 0.92268515, 0.99163777,... \n", + "4 [0.67542243, 0.65952307, 0.7467189, 0.6136317,... \n", + "\n", + " weekday_sin-list \\\n", + "0 [0.9351001, 0.91299504, 0.9785595, 0.0, 0.0, 0... \n", + "1 [0.5046626, 0.17492707, 0.12539314, 0.6640924,... \n", + "2 [0.93685514, 0.5638695, 0.76670134, 0.6797855,... \n", + "3 [0.58499664, 0.45736608, 0.88926136, 0.9139287... \n", + "4 [0.09077961, 0.7920753, 0.35881928, 0.8545563,... \n", + "\n", + " item_id-list \\\n", + "0 [30, 24, 200, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "1 [190, 10, 7, 55, 27, 3, 6, 184, 0, 0, 0, 0, 0,... \n", + "2 [153, 8, 46, 58, 21, 19, 31, 15, 4, 104, 0, 0,... \n", + "3 [19, 28, 23, 34, 18, 10, 0, 0, 0, 0, 0, 0, 0, ... \n", + "4 [17, 27, 70, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "\n", + " category-list \n", + "0 [7, 6, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "1 [36, 3, 2, 11, 6, 4, 2, 36, 0, 0, 0, 0, 0, 0, ... \n", + "2 [28, 2, 10, 11, 5, 5, 7, 3, 2, 18, 0, 0, 0, 0,... \n", + "3 [5, 6, 6, 7, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "4 [5, 6, 14, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... " + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset = Dataset(eval_paths[0])\n", + "# trainer.test_dataset_or_path = dataset\n", + "loader = trainer.get_test_dataloader(dataset)\n", + "test_dict = next(iter(loader))\n", + "\n", + "df_cols = {}\n", + "for name, tensor in train_dict.items():\n", + " if name in input_schema.column_names:\n", + " dtype = input_schema[name].dtype\n", + "\n", + " df_cols[name] = tensor.cpu().numpy().astype(dtype)\n", + " if len(tensor.shape) > 1:\n", + " df_cols[name] = list(df_cols[name])\n", + "\n", + "df = make_df(df_cols)\n", + "print(df.shape)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "617348d4-3493-4b68-ba9a-da9543147628", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I1123 22:15:22.928458 2118 pinned_memory_manager.cc:240] Pinned memory pool is created at '0x7f428e000000' with size 268435456\n", + "I1123 22:15:22.928899 2118 cuda_memory_manager.cc:105] CUDA memory pool is created on device 0 with size 67108864\n", + "I1123 22:15:22.931602 2118 model_lifecycle.cc:459] loading: 0_predictpytorch:1\n", + "I1123 22:15:23.299444 2118 libtorch.cc:1983] TRITONBACKEND_Initialize: pytorch\n", + "I1123 22:15:23.299463 2118 libtorch.cc:1993] Triton TRITONBACKEND API version: 1.10\n", + "I1123 22:15:23.299469 2118 libtorch.cc:1999] 'pytorch' TRITONBACKEND API version: 1.10\n", + "I1123 22:15:23.299488 2118 libtorch.cc:2032] TRITONBACKEND_ModelInitialize: 0_predictpytorch (version 1)\n", + "W1123 22:15:23.300039 2118 libtorch.cc:284] skipping model configuration auto-complete for '0_predictpytorch': not supported for pytorch backend\n", + "I1123 22:15:23.300768 2118 libtorch.cc:313] Optimized execution is enabled for model instance '0_predictpytorch'\n", + "I1123 22:15:23.300780 2118 libtorch.cc:332] Cache Cleaning is disabled for model instance '0_predictpytorch'\n", + "I1123 22:15:23.300786 2118 libtorch.cc:349] Inference Mode is enabled for model instance '0_predictpytorch'\n", + "I1123 22:15:23.300790 2118 libtorch.cc:444] NvFuser is not specified for model instance 
'0_predictpytorch'\n", + "I1123 22:15:23.301026 2118 libtorch.cc:2076] TRITONBACKEND_ModelInstanceInitialize: 0_predictpytorch (GPU device 0)\n", + "I1123 22:15:24.229933 2118 model_lifecycle.cc:693] successfully loaded '0_predictpytorch' version 1\n", + "I1123 22:15:24.230204 2118 model_lifecycle.cc:459] loading: ensemble_model:1\n", + "I1123 22:15:24.230490 2118 model_lifecycle.cc:693] successfully loaded 'ensemble_model' version 1\n", + "I1123 22:15:24.230584 2118 server.cc:561] \n", + "+------------------+------+\n", + "| Repository Agent | Path |\n", + "+------------------+------+\n", + "+------------------+------+\n", + "\n", + "I1123 22:15:24.230668 2118 server.cc:588] \n", + "+---------+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "| Backend | Path | Config |\n", + "+---------+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "| pytorch | /opt/tritonserver/backends/pytorch/libtriton_pytorch.so | {\"cmdline\":{\"auto-complete-config\":\"true\",\"min-compute-capability\":\"6.000000\",\"backend-directory\":\"/opt/tritonserver/backends\",\"default-max-batch-size\":\"4\"}} |\n", + "+---------+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "\n", + "I1123 22:15:24.230751 2118 server.cc:631] \n", + "+------------------+---------+--------+\n", + "| Model | Version | Status |\n", + "+------------------+---------+--------+\n", + "| 0_predictpytorch | 1 | READY |\n", + "| ensemble_model | 1 | READY |\n", + "+------------------+---------+--------+\n", + "\n", + "I1123 22:15:24.282945 2118 metrics.cc:650] Collecting metrics for GPU 0: Quadro GV100\n", + "I1123 22:15:24.283260 2118 tritonserver.cc:2214] \n", + "+----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "| Option | Value |\n", + "+----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "| server_id | triton |\n", + "| server_version | 2.25.0 |\n", + "| server_extensions | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data statistics trace |\n", + "| model_repository_path[0] | ./models |\n", + "| model_control_mode | MODE_NONE |\n", + "| strict_model_config | 0 |\n", + "| rate_limit | OFF |\n", + "| pinned_memory_pool_byte_size | 268435456 |\n", + "| cuda_memory_pool_byte_size{0} | 67108864 |\n", + "| response_cache_byte_size | 0 |\n", + "| min_supported_compute_capability | 6.0 |\n", + "| strict_readiness | 1 |\n", + "| exit_timeout | 30 |\n", + 
"+----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", + "\n", + "I1123 22:15:24.285008 2118 grpc_server.cc:4610] Started GRPCInferenceService at localhost:8001\n", + "I1123 22:15:24.285227 2118 http_server.cc:3316] Started HTTPService at 0.0.0.0:8000\n", + "I1123 22:15:24.326845 2118 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Signal (2) received.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I1123 22:15:26.364164 2118 server.cc:262] Waiting for in-flight requests to complete.\n", + "I1123 22:15:26.364179 2118 server.cc:278] Timeout 30: Found 0 model versions that have in-flight inferences\n", + "I1123 22:15:26.364255 2118 server.cc:293] All models are stopped, unloading models\n", + "I1123 22:15:26.364263 2118 server.cc:300] Timeout 30: Found 2 live models and 0 in-flight non-inference requests\n", + "I1123 22:15:26.364287 2118 model_lifecycle.cc:578] successfully unloaded 'ensemble_model' version 1\n", + "I1123 22:15:26.364592 2118 libtorch.cc:2110] TRITONBACKEND_ModelInstanceFinalize: delete instance state\n", + "I1123 22:15:26.372137 2118 libtorch.cc:2055] TRITONBACKEND_ModelFinalize: delete model state\n", + "I1123 22:15:26.372333 2118 model_lifecycle.cc:578] successfully unloaded '0_predictpytorch' version 1\n", + "I1123 22:15:27.364444 2118 server.cc:300] Timeout 29: Found 0 live models and 0 in-flight non-inference requests\n" + ] + } + ], + "source": [ + "# ===========================================\n", + "# Send request to Triton and check response\n", + "# ===========================================\n", + "response = run_ensemble_on_tritonserver(\n", + " './models', input_schema, df[input_schema.column_names], output_schema.column_names, \"ensemble_model\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "430555ac-d427-48ac-b93a-da2ea41b86d0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'next-item': array([[-7.9947233, -8.747803 , -3.4068425, ..., -7.914096 , -8.571636 ,\n", + " -7.815837 ],\n", + " [-7.977386 , -8.727583 , -3.3388414, ..., -7.9091067, -8.508778 ,\n", + " -7.7708635],\n", + " [-8.000487 , -8.737406 , -3.4030848, ..., -7.921053 , -8.557445 ,\n", + " -7.798526 ],\n", + " ...,\n", + " [-7.998163 , -8.739789 , -3.3824148, ..., -7.9103565, -8.550226 ,\n", + " -7.8002963],\n", + " [-7.9968286, -8.753717 , -3.3801503, ..., -7.9066863, -8.55961 ,\n", + " -7.794828 ],\n", + " [-8.01243 , -8.753323 , -3.3656597, ..., -7.8982997, -8.546498 ,\n", + " -7.7921886]], dtype=float32)}" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response" + ] + }, + { + "cell_type": "markdown", + "id": "d0581195-d496-4536-a93b-3071ed9088ea", + "metadata": {}, + "source": [ + "We return a response for each request in the df. Each row in the `response['next-item']` array corresponds to the logit values per item in the catalog and for the OOV item." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "26bbdcb4-1347-46bd-a3eb-1c140f8bacd6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 507)" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response['next-item'].shape" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "7b543a88d374ac88bf8df97911b380f671b13649694a5b49eb21e60fd27eb479" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/getting-started-session-based/schema.pb b/examples/getting-started-session-based/schema.pb deleted file mode 100644 index c749503aa2..0000000000 --- a/examples/getting-started-session-based/schema.pb +++ /dev/null @@ -1,87 +0,0 @@ -feature { - name: "session_id" - type: INT - int_domain { - name: "session_id" - min: 1 - max: 100001 - is_categorical: false - } - annotation { - tag: "groupby_col" - } -} -feature { - name: "category-list_trim" - value_count { - min: 2 - max: 20 - } - type: INT - int_domain { - name: "category-list_trim" - min: 1 - max: 400 - is_categorical: true - } - annotation { - tag: "list" - tag: "categorical" - tag: "item" - } -} -feature { - name: "item_id-list_trim" - value_count { - min: 2 - max: 20 - } - type: INT - int_domain { - name: "item_id/list" - min: 1 - max: 50005 - is_categorical: true - } - annotation { - tag: "item_id" - tag: "list" - tag: "categorical" - tag: "item" - } -} -feature { - name: "timestamp/age_days-list_trim" - value_count { - min: 2 - max: 20 - } - type: FLOAT - float_domain { - name: "timestamp/age_days-list_trim" - min: 0.0000003 - max: 0.9999999 - } - annotation { - tag: "continuous" - tag: "list" - } -} -feature { - name: "timestamp/weekday/sin-list_trim" - value_count { - min: 2 - max: 20 - } - type: FLOAT - float_domain { - name: "timestamp/weekday-sin_trim" - min: 0.0000003 - max: 0.9999999 - } - annotation { - tag: "continuous" - tag: "time" - tag: "list" - } -} \ No newline at end of file From e05307bdeea76b31ecaa3b1731ef9a2a1aab3add Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Dec 2022 23:32:17 +0000 Subject: [PATCH 2/4] update nbs add schema file back --- .../01-ETL-with-NVTabular.ipynb | 487 ++++---- .../02-session-based-XLNet-with-PyT.ipynb | 315 ++--- ...ng-session-based-model-torch-backend.ipynb | 1011 +++++++++++++++++ .../getting-started-session-based/schema.pb | 87 ++ 4 files changed, 1544 insertions(+), 356 deletions(-) create mode 100644 examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb create mode 100644 examples/getting-started-session-based/schema.pb diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index d071156144..c71f62ebdb 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -132,7 +132,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "9617e30c", + "id": "d5fa9320-7935-4a4c-ad7d-8f9e2b90801c", "metadata": {}, "outputs": [ { @@ -167,48 +167,48 @@ " \n", " \n", " 0\n", - " 88504\n", - " 
26\n", - " 4\n", - " 0.342991\n", - " 0.144433\n", + " 84344\n", + " 5\n", + " 2\n", + " 0.197794\n", + " 0.220711\n", " 1\n", " \n", " \n", " 1\n", - " 85107\n", - " 13\n", + " 79183\n", + " 26\n", + " 7\n", + " 0.659679\n", + " 0.554893\n", " 2\n", - " 0.156982\n", - " 0.722122\n", - " 4\n", " \n", " \n", " 2\n", - " 89499\n", - " 37\n", + " 76110\n", + " 7\n", + " 2\n", + " 0.545001\n", + " 0.476261\n", " 5\n", - " 0.389054\n", - " 0.321258\n", - " 9\n", " \n", " \n", " 3\n", - " 88602\n", - " 20\n", - " 3\n", - " 0.258130\n", - " 0.491159\n", + " 86269\n", + " 78\n", + " 21\n", + " 0.231765\n", + " 0.040279\n", " 2\n", " \n", " \n", " 4\n", - " 84113\n", - " 7\n", - " 1\n", - " 0.519515\n", - " 0.110561\n", - " 7\n", + " 73974\n", + " 90\n", + " 24\n", + " 0.321135\n", + " 0.082030\n", + " 5\n", " \n", " \n", "\n", @@ -216,11 +216,11 @@ ], "text/plain": [ " session_id item_id category age_days weekday_sin day\n", - "0 88504 26 4 0.342991 0.144433 1\n", - "1 85107 13 2 0.156982 0.722122 4\n", - "2 89499 37 5 0.389054 0.321258 9\n", - "3 88602 20 3 0.258130 0.491159 2\n", - "4 84113 7 1 0.519515 0.110561 7" + "0 84344 5 2 0.197794 0.220711 1\n", + "1 79183 26 7 0.659679 0.554893 2\n", + "2 76110 7 2 0.545001 0.476261 5\n", + "3 86269 78 21 0.231765 0.040279 2\n", + "4 73974 90 24 0.321135 0.082030 5" ] }, "execution_count": 5, @@ -245,7 +245,7 @@ "id": "139de226", "metadata": {}, "source": [ - "Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers `(0, ..., |C|)`, where `|C|` is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that in the `Categorify` op, we set `start_index=1`; the reason for that is, we want the encoded null values to start from `1` instead of `0` because we reserve `0` for padding the sequence features." + "Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers `(0, ..., |C|)`, where `|C|` is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that the `Categorify` op encodes OOVs or nulls to `0` automatically. In our synthetic dataset we do not have any nulls. On the other hand `0` is also used for padding the sequences in input block, thefore, you can set `start_index=1` arg in the Categorify op if you want the encoded null or OOV values to start from `1` instead of `0` because we reserve `0` for padding the sequence features." 
] }, { @@ -279,7 +279,7 @@ "SESSIONS_MAX_LENGTH =20\n", "\n", "# Categorify categorical features\n", - "categ_feats = ['session_id', 'item_id', 'category'] >> nvt.ops.Categorify(start_index=1)\n", + "categ_feats = ['session_id', 'item_id', 'category'] >> nvt.ops.Categorify()\n", "\n", "# Define Groupby Workflow\n", "groupby_feats = categ_feats + ['day', 'age_days', 'weekday_sin']\n", @@ -331,6 +331,7 @@ "\n", "\n", "workflow = nvt.Workflow(filtered_sessions['session_id', 'day-first', 'item_id-count'] + seq_feats_list)\n", + "\n", "dataset = nvt.Dataset(df, cpu=False)\n", "# Generate statistics for the features\n", "workflow.fit(dataset)\n", @@ -379,33 +380,33 @@ " \n", " \n", " 0\n", - " 2\n", - " 9\n", - " 78\n", - " [19, 15, 26, 6, 24, 44, 19, 29, 14, 42, 33, 11...\n", - " [4, 4, 5, 2, 5, 7, 4, 5, 2, 7, 2, 3, 2, 3, 2, ...\n", - " [0.0042280755, 0.40522072, 0.42538044, 0.97327...\n", - " [0.6292621, 0.1172376, 0.18633945, 0.8232658, ...\n", + " 1\n", + " 1\n", + " 18\n", + " [4, 11, 17, 29, 18, 1, 3, 48, 11, 9, 16, 5, 19...\n", + " [1, 3, 6, 8, 5, 1, 1, 13, 3, 2, 6, 2, 4, 52, 1...\n", + " [0.14561038, 0.9393455, 0.012047833, 0.658193,...\n", + " [0.23078364, 0.99029666, 0.89728844, 0.9642181...\n", " \n", " \n", " 1\n", - " 3\n", " 2\n", - " 76\n", - " [26, 85, 9, 57, 17, 7, 9, 41, 37, 11, 10, 16, ...\n", - " [5, 13, 2, 9, 4, 3, 2, 7, 6, 3, 3, 4, 7, 5, 2,...\n", - " [0.07340249, 0.2910817, 0.010784109, 0.8495507...\n", - " [0.52355134, 0.83093345, 0.8837344, 0.38942775...\n", + " 4\n", + " 15\n", + " [97, 7, 44, 24, 31, 23, 41, 245, 11, 3, 28, 11...\n", + " [28, 2, 12, 7, 9, 7, 10, 61, 3, 1, 8, 3, 18, 2...\n", + " [0.54006344, 0.71162707, 0.2320292, 0.49496385...\n", + " [0.72449577, 0.35770282, 0.13853826, 0.0450636...\n", " \n", " \n", " 2\n", - " 4\n", - " 6\n", - " 76\n", - " [13, 13, 50, 64, 105, 16, 78, 17, 19, 34, 8, 1...\n", - " [3, 3, 8, 10, 16, 4, 12, 4, 4, 6, 3, 4, 3, 6, ...\n", - " [0.29271448, 0.59962034, 0.042938035, 0.730446...\n", - " [0.8610789, 0.058191676, 0.806903, 0.79222715,...\n", + " 3\n", + " 9\n", + " 15\n", + " [5, 27, 26, 111, 97, 50, 3, 4, 7, 31, 29, 23, ...\n", + " [2, 8, 7, 27, 28, 13, 1, 1, 2, 9, 8, 7, 2, 11, 1]\n", + " [0.14291424, 0.11157788, 0.7810709, 0.11342292...\n", + " [0.35786915, 0.467376, 0.34360662, 0.50400823,...\n", " \n", " \n", "\n", @@ -413,29 +414,29 @@ ], "text/plain": [ " session_id day-first item_id-count \\\n", - "0 2 9 78 \n", - "1 3 2 76 \n", - "2 4 6 76 \n", + "0 1 1 18 \n", + "1 2 4 15 \n", + "2 3 9 15 \n", "\n", " item_id-list \\\n", - "0 [19, 15, 26, 6, 24, 44, 19, 29, 14, 42, 33, 11... \n", - "1 [26, 85, 9, 57, 17, 7, 9, 41, 37, 11, 10, 16, ... \n", - "2 [13, 13, 50, 64, 105, 16, 78, 17, 19, 34, 8, 1... \n", + "0 [4, 11, 17, 29, 18, 1, 3, 48, 11, 9, 16, 5, 19... \n", + "1 [97, 7, 44, 24, 31, 23, 41, 245, 11, 3, 28, 11... \n", + "2 [5, 27, 26, 111, 97, 50, 3, 4, 7, 31, 29, 23, ... \n", "\n", " category-list \\\n", - "0 [4, 4, 5, 2, 5, 7, 4, 5, 2, 7, 2, 3, 2, 3, 2, ... \n", - "1 [5, 13, 2, 9, 4, 3, 2, 7, 6, 3, 3, 4, 7, 5, 2,... \n", - "2 [3, 3, 8, 10, 16, 4, 12, 4, 4, 6, 3, 4, 3, 6, ... \n", + "0 [1, 3, 6, 8, 5, 1, 1, 13, 3, 2, 6, 2, 4, 52, 1... \n", + "1 [28, 2, 12, 7, 9, 7, 10, 61, 3, 1, 8, 3, 18, 2... \n", + "2 [2, 8, 7, 27, 28, 13, 1, 1, 2, 9, 8, 7, 2, 11, 1] \n", "\n", " age_days-list \\\n", - "0 [0.0042280755, 0.40522072, 0.42538044, 0.97327... \n", - "1 [0.07340249, 0.2910817, 0.010784109, 0.8495507... \n", - "2 [0.29271448, 0.59962034, 0.042938035, 0.730446... \n", + "0 [0.14561038, 0.9393455, 0.012047833, 0.658193,... 
\n", + "1 [0.54006344, 0.71162707, 0.2320292, 0.49496385... \n", + "2 [0.14291424, 0.11157788, 0.7810709, 0.11342292... \n", "\n", " weekday_sin-list \n", - "0 [0.6292621, 0.1172376, 0.18633945, 0.8232658, ... \n", - "1 [0.52355134, 0.83093345, 0.8837344, 0.38942775... \n", - "2 [0.8610789, 0.058191676, 0.806903, 0.79222715,... " + "0 [0.23078364, 0.99029666, 0.89728844, 0.9642181... \n", + "1 [0.72449577, 0.35770282, 0.13853826, 0.0450636... \n", + "2 [0.35786915, 0.467376, 0.34360662, 0.50400823,... " ] }, "execution_count": 7, @@ -447,6 +448,34 @@ "sessions_gdf.head(3)" ] }, + { + "cell_type": "code", + "execution_count": 8, + "id": "cf6a0dae-4999-495c-b154-f91de0da9c33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "session_id int64\n", + "day-first int64\n", + "item_id-count int32\n", + "item_id-list list\n", + "category-list list\n", + "age_days-list list\n", + "weekday_sin-list list\n", + "dtype: object" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sessions_gdf.dtypes" + ] + }, { "cell_type": "markdown", "id": "2458c28f", @@ -457,7 +486,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "78e42cbf-edd6-44af-af23-c026edb578c4", "metadata": {}, "outputs": [ @@ -512,13 +541,13 @@ " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", + " 0.0\n", " .//categories/unique.session_id.parquet\n", " 0.0\n", - " 20001.0\n", + " 19875.0\n", " session_id\n", - " 20002.0\n", - " 410.0\n", + " 19876.0\n", + " 408.0\n", " NaN\n", " NaN\n", " \n", @@ -552,63 +581,63 @@ " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", + " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 828.0\n", + " 489.0\n", " item_id\n", - " 829.0\n", - " 69.0\n", + " 490.0\n", + " 51.0\n", " NaN\n", " NaN\n", " \n", " \n", " 3\n", " item_id-list\n", - " (Tags.ITEM_ID, Tags.ITEM, Tags.CATEGORICAL, Ta...\n", + " (Tags.ITEM, Tags.ITEM_ID, Tags.ID, Tags.LIST, ...\n", " int64\n", " True\n", - " False\n", + " True\n", " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", + " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 828.0\n", + " 489.0\n", " item_id\n", - " 829.0\n", - " 69.0\n", - " 20.0\n", - " 20.0\n", + " 490.0\n", + " 51.0\n", + " 2.0\n", + " 18.0\n", " \n", " \n", " 4\n", " category-list\n", - " (Tags.CATEGORICAL, Tags.LIST)\n", + " (Tags.LIST, Tags.CATEGORICAL)\n", " int64\n", " True\n", - " False\n", + " True\n", " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", + " 0.0\n", " .//categories/unique.category.parquet\n", " 0.0\n", - " 172.0\n", + " 176.0\n", " category\n", - " 173.0\n", + " 177.0\n", " 29.0\n", - " 20.0\n", - " 20.0\n", + " 2.0\n", + " 18.0\n", " \n", " \n", " 5\n", " age_days-list\n", - " (Tags.CONTINUOUS, Tags.LIST)\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", " float32\n", " True\n", - " False\n", + " True\n", " NaN\n", " NaN\n", " NaN\n", @@ -619,16 +648,16 @@ " NaN\n", " NaN\n", " NaN\n", - " 20.0\n", - " 20.0\n", + " 2.0\n", + " 18.0\n", " \n", " \n", " 6\n", " weekday_sin-list\n", - " (Tags.CONTINUOUS, Tags.LIST)\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", " float32\n", " True\n", - " False\n", + " True\n", " NaN\n", " NaN\n", " NaN\n", @@ -639,18 +668,18 @@ " NaN\n", " NaN\n", " NaN\n", - " 20.0\n", - " 20.0\n", + " 2.0\n", + " 18.0\n", " \n", " \n", "\n", "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.session_id.parquet', 'domain': {'min': 0, 'max': 20001, 
'name': 'session_id'}, 'embedding_sizes': {'cardinality': 20002, 'dimension': 410}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 828, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 829, 'dimension': 69}}, 'dtype': dtype('int32'), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 828, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 829, 'dimension': 69}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 172, 'name': 'category'}, 'embedding_sizes': {'cardinality': 173, 'dimension': 29}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'session_id', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.session_id.parquet', 'domain': {'min': 0, 'max': 19875, 'name': 'session_id'}, 'embedding_sizes': {'cardinality': 19876, 'dimension': 408}}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': dtype('int64'), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 489, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 490, 'dimension': 51}}, 'dtype': dtype('int32'), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 489, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 490, 'dimension': 51}, 'value_count': {'min': 2, 'max': 18}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 176, 'name': 'category'}, 'embedding_sizes': {'cardinality': 177, 'dimension': 29}, 'value_count': {'min': 2, 'max': 18}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': True}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 18}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': True}, {'name': 
'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 18}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': True}]" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -669,7 +698,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "e71dd078-6116-4ac2-ba6f-0207aaa8d417", "metadata": {}, "outputs": [ @@ -688,7 +717,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "9f498dce-69eb-4f88-8ddd-8629558825df", "metadata": {}, "outputs": [], @@ -714,7 +743,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "12d3e59b", "metadata": {}, "outputs": [], @@ -725,7 +754,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "6c67a92b", "metadata": {}, "outputs": [ @@ -733,7 +762,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 9.56it/s]\n" + "Creating time-based splits: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 22.48it/s]\n" ] } ], @@ -756,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "dd04ec82", "metadata": {}, "outputs": [], @@ -766,7 +795,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "8e5e6358", "metadata": {}, "outputs": [ @@ -802,48 +831,48 @@ " \n", " \n", " 0\n", - " 15\n", - " 73\n", - " [7, 15, 7, 18, 49, 106, 6, 2, 9, 9, 10, 31, 2,...\n", - " [3, 4, 3, 4, 8, 16, 2, 2, 2, 2, 3, 6, 2, 3, 3,...\n", - " [0.83090544, 0.80508804, 0.77800703, 0.1136483...\n", - " [0.73860276, 0.71227425, 0.82740945, 0.7208882...\n", + " 1\n", + " 18\n", + " [4, 11, 17, 29, 18, 1, 3, 48, 11, 9, 16, 5, 19...\n", + " [1, 3, 6, 8, 5, 1, 1, 13, 3, 2, 6, 2, 4, 52, 1...\n", + " [0.14561038, 0.9393455, 0.012047833, 0.658193,...\n", + " [0.23078364, 0.99029666, 0.89728844, 0.9642181...\n", " \n", " \n", " 1\n", - " 33\n", - " 72\n", - " [21, 15, 11, 34, 13, 2, 35, 16, 8, 16, 2, 72, ...\n", - " [2, 4, 3, 6, 3, 2, 6, 4, 3, 4, 2, 11, 4, 5, 6,...\n", - " [0.091410436, 0.39586508, 0.5213629, 0.73786, ...\n", - " [0.81503123, 0.8240082, 0.13869828, 0.00287529...\n", + " 4\n", + " 15\n", + " [36, 49, 9, 95, 12, 26, 35, 185, 43, 14, 19, 2...\n", + " [11, 13, 2, 24, 3, 7, 9, 55, 12, 3, 4, 7, 1, 1...\n", + " [0.4289175, 0.41714236, 0.6593241, 0.7470034, ...\n", + " [0.5122762, 0.11083387, 0.26527187, 0.77329, 0...\n", " \n", " \n", " 2\n", - " 50\n", - " 71\n", - " [6, 113, 59, 16, 12, 209, 3, 181, 40, 6, 13, 3...\n", - " [2, 17, 9, 4, 3, 29, 2, 26, 7, 2, 3, 6, 9, 27,...\n", - " [0.0334595, 0.7818315, 0.9511686, 0.48197943, ...\n", - " [0.5987661, 0.2299972, 0.57004195, 0.93016136,...\n", + " 30\n", + " 13\n", + " [29, 12, 18, 20, 46, 77, 7, 7, 26, 21, 111, 2, 1]\n", + " [8, 3, 5, 5, 12, 20, 2, 2, 7, 5, 27, 1, 1]\n", + " [0.079172395, 0.26267487, 0.9678789, 0.601294,...\n", + " [0.1440753, 0.5550622, 0.18317387, 0.06565472,...\n", " \n", " \n", " 4\n", - " 54\n", - " 71\n", - " [121, 23, 21, 46, 18, 68, 39, 103, 53, 15, 3, ...\n", - " [18, 5, 2, 7, 4, 11, 7, 15, 9, 4, 2, 2, 4, 5, ...\n", - " [0.24764787, 0.7796014, 0.6935816, 0.37522456,...\n", - " [0.7501101, 0.8012419, 0.6419888, 0.18589461, ...\n", + " 
63\n", + " 12\n", + " [5, 44, 10, 5, 26, 6, 193, 11, 13, 9, 10, 60]\n", + " [2, 12, 3, 2, 7, 2, 50, 3, 4, 2, 3, 15]\n", + " [0.7659222, 0.9388312, 0.28288805, 0.75763357,...\n", + " [0.99894804, 0.038836945, 0.85671306, 0.345418...\n", " \n", " \n", " 5\n", - " 60\n", - " 70\n", - " [4, 25, 62, 11, 94, 14, 13, 28, 8, 38, 47, 6, ...\n", - " [2, 5, 10, 3, 14, 2, 3, 5, 3, 6, 8, 2, 17, 3, ...\n", - " [0.42777818, 0.93942195, 0.7336651, 0.6372542,...\n", - " [0.6628856, 0.19497228, 0.34514892, 0.7849939,...\n", + " 75\n", + " 12\n", + " [48, 21, 5, 40, 4, 182, 36, 39, 54, 37, 8, 116]\n", + " [13, 5, 2, 10, 1, 48, 11, 10, 14, 11, 4, 31]\n", + " [0.4896862, 0.7550025, 0.92395943, 0.4152636, ...\n", + " [0.08632153, 0.82823294, 0.50390047, 0.4975271...\n", " \n", " \n", " ...\n", @@ -855,125 +884,125 @@ " ...\n", " \n", " \n", - " 2234\n", - " 19970\n", - " 31\n", - " [16, 21, 8, 16, 17, 118, 146, 32, 16, 89, 52, ...\n", - " [4, 2, 3, 4, 4, 17, 22, 6, 4, 14, 8, 2, 2, 17,...\n", - " [0.39309493, 0.80350375, 0.41615465, 0.7130491...\n", - " [0.4812473, 0.19307572, 0.31647742, 0.48890728...\n", + " 2111\n", + " 19151\n", + " 2\n", + " [24, 19]\n", + " [7, 4]\n", + " [0.3092607, 0.25387767]\n", + " [0.6523481, 0.059806556]\n", " \n", " \n", - " 2235\n", - " 19979\n", - " 30\n", - " [17, 13, 227, 22, 13, 43, 20, 65, 8, 13, 3, 36...\n", - " [4, 3, 33, 4, 3, 7, 4, 10, 3, 3, 2, 6, 2, 2, 7...\n", - " [0.80101305, 0.87425673, 0.072974995, 0.940252...\n", - " [0.9547868, 0.22636005, 0.14399427, 0.25720155...\n", + " 2112\n", + " 19173\n", + " 2\n", + " [60, 37]\n", + " [15, 11]\n", + " [0.82798934, 0.054636054]\n", + " [0.84105706, 0.52476853]\n", " \n", " \n", - " 2236\n", - " 19980\n", - " 30\n", - " [12, 26, 16, 50, 23, 113, 38, 4, 17, 5, 22, 15...\n", - " [3, 5, 4, 8, 5, 17, 6, 2, 4, 3, 4, 4, 3, 3, 3,...\n", - " [0.8601423, 0.94999486, 0.19948259, 0.00668230...\n", - " [0.53358024, 0.03624035, 0.36104643, 0.6872657...\n", + " 2113\n", + " 19188\n", + " 2\n", + " [10, 21]\n", + " [3, 5]\n", + " [0.92787683, 0.5812024]\n", + " [0.13824013, 0.74283314]\n", " \n", " \n", - " 2237\n", - " 19986\n", - " 29\n", - " [154, 18, 28, 216, 173, 148, 22, 72, 28, 21, 2...\n", - " [24, 4, 5, 32, 25, 22, 4, 11, 5, 2, 5, 2, 2, 2...\n", - " [0.5032084, 0.22921799, 0.5817064, 0.31098822,...\n", - " [0.93975395, 0.9232329, 0.292329, 0.034402855,...\n", + " 2114\n", + " 19194\n", + " 2\n", + " [4, 158]\n", + " [1, 41]\n", + " [0.22679287, 0.024510423]\n", + " [0.9538698, 0.4295912]\n", " \n", " \n", - " 2239\n", - " 19998\n", - " 27\n", - " [7, 14, 5, 14, 46, 31, 11, 7, 3, 10, 54, 2, 56...\n", - " [3, 2, 3, 2, 7, 6, 3, 3, 2, 3, 9, 2, 9, 3, 3, ...\n", - " [0.2078503, 0.6980298, 0.29598948, 0.49848178,...\n", - " [0.28777036, 0.6233945, 0.48009974, 0.9278188,...\n", + " 2115\n", + " 19204\n", + " 2\n", + " [37, 16]\n", + " [11, 6]\n", + " [0.5972207, 0.11343666]\n", + " [0.81323135, 0.46290976]\n", " \n", " \n", "\n", - "

1794 rows × 6 columns

\n", + "

1687 rows × 6 columns

\n", "" ], "text/plain": [ " session_id item_id-count \\\n", - "0 15 73 \n", - "1 33 72 \n", - "2 50 71 \n", - "4 54 71 \n", - "5 60 70 \n", + "0 1 18 \n", + "1 4 15 \n", + "2 30 13 \n", + "4 63 12 \n", + "5 75 12 \n", "... ... ... \n", - "2234 19970 31 \n", - "2235 19979 30 \n", - "2236 19980 30 \n", - "2237 19986 29 \n", - "2239 19998 27 \n", + "2111 19151 2 \n", + "2112 19173 2 \n", + "2113 19188 2 \n", + "2114 19194 2 \n", + "2115 19204 2 \n", "\n", " item_id-list \\\n", - "0 [7, 15, 7, 18, 49, 106, 6, 2, 9, 9, 10, 31, 2,... \n", - "1 [21, 15, 11, 34, 13, 2, 35, 16, 8, 16, 2, 72, ... \n", - "2 [6, 113, 59, 16, 12, 209, 3, 181, 40, 6, 13, 3... \n", - "4 [121, 23, 21, 46, 18, 68, 39, 103, 53, 15, 3, ... \n", - "5 [4, 25, 62, 11, 94, 14, 13, 28, 8, 38, 47, 6, ... \n", + "0 [4, 11, 17, 29, 18, 1, 3, 48, 11, 9, 16, 5, 19... \n", + "1 [36, 49, 9, 95, 12, 26, 35, 185, 43, 14, 19, 2... \n", + "2 [29, 12, 18, 20, 46, 77, 7, 7, 26, 21, 111, 2, 1] \n", + "4 [5, 44, 10, 5, 26, 6, 193, 11, 13, 9, 10, 60] \n", + "5 [48, 21, 5, 40, 4, 182, 36, 39, 54, 37, 8, 116] \n", "... ... \n", - "2234 [16, 21, 8, 16, 17, 118, 146, 32, 16, 89, 52, ... \n", - "2235 [17, 13, 227, 22, 13, 43, 20, 65, 8, 13, 3, 36... \n", - "2236 [12, 26, 16, 50, 23, 113, 38, 4, 17, 5, 22, 15... \n", - "2237 [154, 18, 28, 216, 173, 148, 22, 72, 28, 21, 2... \n", - "2239 [7, 14, 5, 14, 46, 31, 11, 7, 3, 10, 54, 2, 56... \n", + "2111 [24, 19] \n", + "2112 [60, 37] \n", + "2113 [10, 21] \n", + "2114 [4, 158] \n", + "2115 [37, 16] \n", "\n", " category-list \\\n", - "0 [3, 4, 3, 4, 8, 16, 2, 2, 2, 2, 3, 6, 2, 3, 3,... \n", - "1 [2, 4, 3, 6, 3, 2, 6, 4, 3, 4, 2, 11, 4, 5, 6,... \n", - "2 [2, 17, 9, 4, 3, 29, 2, 26, 7, 2, 3, 6, 9, 27,... \n", - "4 [18, 5, 2, 7, 4, 11, 7, 15, 9, 4, 2, 2, 4, 5, ... \n", - "5 [2, 5, 10, 3, 14, 2, 3, 5, 3, 6, 8, 2, 17, 3, ... \n", + "0 [1, 3, 6, 8, 5, 1, 1, 13, 3, 2, 6, 2, 4, 52, 1... \n", + "1 [11, 13, 2, 24, 3, 7, 9, 55, 12, 3, 4, 7, 1, 1... \n", + "2 [8, 3, 5, 5, 12, 20, 2, 2, 7, 5, 27, 1, 1] \n", + "4 [2, 12, 3, 2, 7, 2, 50, 3, 4, 2, 3, 15] \n", + "5 [13, 5, 2, 10, 1, 48, 11, 10, 14, 11, 4, 31] \n", "... ... \n", - "2234 [4, 2, 3, 4, 4, 17, 22, 6, 4, 14, 8, 2, 2, 17,... \n", - "2235 [4, 3, 33, 4, 3, 7, 4, 10, 3, 3, 2, 6, 2, 2, 7... \n", - "2236 [3, 5, 4, 8, 5, 17, 6, 2, 4, 3, 4, 4, 3, 3, 3,... \n", - "2237 [24, 4, 5, 32, 25, 22, 4, 11, 5, 2, 5, 2, 2, 2... \n", - "2239 [3, 2, 3, 2, 7, 6, 3, 3, 2, 3, 9, 2, 9, 3, 3, ... \n", + "2111 [7, 4] \n", + "2112 [15, 11] \n", + "2113 [3, 5] \n", + "2114 [1, 41] \n", + "2115 [11, 6] \n", "\n", " age_days-list \\\n", - "0 [0.83090544, 0.80508804, 0.77800703, 0.1136483... \n", - "1 [0.091410436, 0.39586508, 0.5213629, 0.73786, ... \n", - "2 [0.0334595, 0.7818315, 0.9511686, 0.48197943, ... \n", - "4 [0.24764787, 0.7796014, 0.6935816, 0.37522456,... \n", - "5 [0.42777818, 0.93942195, 0.7336651, 0.6372542,... \n", + "0 [0.14561038, 0.9393455, 0.012047833, 0.658193,... \n", + "1 [0.4289175, 0.41714236, 0.6593241, 0.7470034, ... \n", + "2 [0.079172395, 0.26267487, 0.9678789, 0.601294,... \n", + "4 [0.7659222, 0.9388312, 0.28288805, 0.75763357,... \n", + "5 [0.4896862, 0.7550025, 0.92395943, 0.4152636, ... \n", "... ... \n", - "2234 [0.39309493, 0.80350375, 0.41615465, 0.7130491... \n", - "2235 [0.80101305, 0.87425673, 0.072974995, 0.940252... \n", - "2236 [0.8601423, 0.94999486, 0.19948259, 0.00668230... \n", - "2237 [0.5032084, 0.22921799, 0.5817064, 0.31098822,... \n", - "2239 [0.2078503, 0.6980298, 0.29598948, 0.49848178,... 
\n", + "2111 [0.3092607, 0.25387767] \n", + "2112 [0.82798934, 0.054636054] \n", + "2113 [0.92787683, 0.5812024] \n", + "2114 [0.22679287, 0.024510423] \n", + "2115 [0.5972207, 0.11343666] \n", "\n", " weekday_sin-list \n", - "0 [0.73860276, 0.71227425, 0.82740945, 0.7208882... \n", - "1 [0.81503123, 0.8240082, 0.13869828, 0.00287529... \n", - "2 [0.5987661, 0.2299972, 0.57004195, 0.93016136,... \n", - "4 [0.7501101, 0.8012419, 0.6419888, 0.18589461, ... \n", - "5 [0.6628856, 0.19497228, 0.34514892, 0.7849939,... \n", + "0 [0.23078364, 0.99029666, 0.89728844, 0.9642181... \n", + "1 [0.5122762, 0.11083387, 0.26527187, 0.77329, 0... \n", + "2 [0.1440753, 0.5550622, 0.18317387, 0.06565472,... \n", + "4 [0.99894804, 0.038836945, 0.85671306, 0.345418... \n", + "5 [0.08632153, 0.82823294, 0.50390047, 0.4975271... \n", "... ... \n", - "2234 [0.4812473, 0.19307572, 0.31647742, 0.48890728... \n", - "2235 [0.9547868, 0.22636005, 0.14399427, 0.25720155... \n", - "2236 [0.53358024, 0.03624035, 0.36104643, 0.6872657... \n", - "2237 [0.93975395, 0.9232329, 0.292329, 0.034402855,... \n", - "2239 [0.28777036, 0.6233945, 0.48009974, 0.9278188,... \n", + "2111 [0.6523481, 0.059806556] \n", + "2112 [0.84105706, 0.52476853] \n", + "2113 [0.13824013, 0.74283314] \n", + "2114 [0.9538698, 0.4295912] \n", + "2115 [0.81323135, 0.46290976] \n", "\n", - "[1794 rows x 6 columns]" + "[1687 rows x 6 columns]" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -988,7 +1017,7 @@ "id": "a6461a96", "metadata": {}, "source": [ - "You have just created session-level features to train a session-based recommendation model using NVTabular. Now you can move to the the next notebook,`02-session-based-XLNet-with-PyT.ipynb` to train a session-based recommendation model using [XLNet](https://arxiv.org/abs/1906.08237), one of the state-of-the-art NLP model." + "You have just created session-level features to train a session-based recommendation model using NVTabular. Now you can move to the the next notebook,`02-session-based-XLNet-with-PyT.ipynb` to train a session-based recommendation model using [XLNet](https://arxiv.org/abs/1906.08237), one of the state-of-the-art NLP model. Please shut down this kernel to free the GPU memory before you start the next one." 
] } ], diff --git a/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb b/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb index 0f03820b46..1b9faca2f2 100644 --- a/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb +++ b/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb @@ -235,14 +235,14 @@ " type: INT\n", " int_domain {\n", " name: \"session_id\"\n", - " max: 19870\n", + " max: 19875\n", " is_categorical: true\n", " }\n", " annotation {\n", " tag: \"categorical\"\n", " extra_metadata {\n", " type_url: \"type.googleapis.com/google.protobuf.Struct\"\n", - " value: \"\\n\\021\\n\\013num_buckets\\022\\002\\010\\000\\n\\033\\n\\016freq_threshold\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\025\\n\\010max_size\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\030\\n\\013start_index\\022\\t\\021\\000\\000\\000\\000\\000\\000\\360?\\n5\\n\\010cat_path\\022)\\032\\'.//categories/unique.session_id.parquet\\nG\\n\\017embedding_sizes\\0224*2\\n\\030\\n\\013cardinality\\022\\t\\021\\000\\000\\000\\000\\300g\\323@\\n\\026\\n\\tdimension\\022\\t\\021\\000\\000\\000\\000\\000\\200y@\\n\\034\\n\\017dtype_item_size\\022\\t\\021\\000\\000\\000\\000\\000\\000P@\\n\\r\\n\\007is_list\\022\\002 \\000\\n\\017\\n\\tis_ragged\\022\\002 \\000\"\n", + " value: \"\\n\\021\\n\\013num_buckets\\022\\002\\010\\000\\n\\033\\n\\016freq_threshold\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\025\\n\\010max_size\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n\\030\\n\\013start_index\\022\\t\\021\\000\\000\\000\\000\\000\\000\\000\\000\\n5\\n\\010cat_path\\022)\\032\\'.//categories/unique.session_id.parquet\\nG\\n\\017embedding_sizes\\0224*2\\n\\030\\n\\013cardinality\\022\\t\\021\\000\\000\\000\\000\\000i\\323@\\n\\026\\n\\tdimension\\022\\t\\021\\000\\000\\000\\000\\000\\200y@\\n\\034\\n\\017dtype_item_size\\022\\t\\021\\000\\000\\000\\000\\000\\000P@\\n\\r\\n\\007is_list\\022\\002 \\000\\n\\017\\n\\tis_ragged\\022\\002 \\000\"\n", " }\n", " }\n", "}\n", @@ -306,8 +306,8 @@ " schema,\n", " max_sequence_length=20,\n", " continuous_projection=64,\n", - " d_output=100,\n", " masking=\"mlm\",\n", + " d_output=100,\n", ")" ] }, @@ -587,7 +587,7 @@ " \n", " \n", " 50\n", - " 5.821000\n", + " 5.804900\n", " \n", " \n", "

" @@ -623,8 +623,8 @@ "\n", "

\n", " \n", - " \n", - " [6/6 00:10]\n", + " \n", + " [6/6 00:12]\n", "
\n", " " ], @@ -644,14 +644,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 5.006950855255127\n", - " eval_/next-item/ndcg_at_20 = 0.17477868497371674\n", - " eval_/next-item/ndcg_at_40 = 0.24376630783081055\n", - " eval_/next-item/recall_at_20 = 0.421875\n", - " eval_/next-item/recall_at_40 = 0.7552083730697632\n", - " eval_runtime = 0.1473\n", - " eval_samples_per_second = 1303.363\n", - " eval_steps_per_second = 40.73\n", + " eval_/loss = 4.9857892990112305\n", + " eval_/next-item/ndcg_at_20 = 0.17793045938014984\n", + " eval_/next-item/ndcg_at_40 = 0.2293066680431366\n", + " eval_/next-item/recall_at_20 = 0.4895833432674408\n", + " eval_/next-item/recall_at_40 = 0.7395833730697632\n", + " eval_runtime = 0.1585\n", + " eval_samples_per_second = 1211.104\n", + " eval_steps_per_second = 37.847\n", "['/workspace/data/sessions_by_day/2/train.parquet']\n", "********************\n", "Launch training for day 2 are:\n", @@ -691,7 +691,7 @@ " \n", " \n", " 50\n", - " 4.896300\n", + " 4.859800\n", " \n", " \n", "

" @@ -718,7 +718,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "finished\n" + "finished\n", + "********************\n", + "Eval results for day 3 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.595378398895264\n", + " eval_/next-item/ndcg_at_20 = 0.18038052320480347\n", + " eval_/next-item/ndcg_at_40 = 0.23606626689434052\n", + " eval_/next-item/recall_at_20 = 0.4791666865348816\n", + " eval_/next-item/recall_at_40 = 0.75\n", + " eval_runtime = 0.1612\n", + " eval_samples_per_second = 1190.843\n", + " eval_steps_per_second = 37.214\n" ] }, { @@ -738,19 +751,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "********************\n", - "Eval results for day 3 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.632648944854736\n", - " eval_/next-item/ndcg_at_20 = 0.18322871625423431\n", - " eval_/next-item/ndcg_at_40 = 0.24265889823436737\n", - " eval_/next-item/recall_at_20 = 0.4427083432674408\n", - " eval_/next-item/recall_at_40 = 0.7291666865348816\n", - " eval_runtime = 0.2087\n", - " eval_samples_per_second = 919.825\n", - " eval_steps_per_second = 28.745\n", "['/workspace/data/sessions_by_day/3/train.parquet']\n", "********************\n", "Launch training for day 3 are:\n", @@ -777,7 +777,7 @@ " \n", " \n", " 50\n", - " 4.577600\n", + " 4.569400\n", " \n", " \n", "

" @@ -810,14 +810,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.358114719390869\n", - " eval_/next-item/ndcg_at_20 = 0.2185230255126953\n", - " eval_/next-item/ndcg_at_40 = 0.2660380005836487\n", - " eval_/next-item/recall_at_20 = 0.578125\n", - " eval_/next-item/recall_at_40 = 0.8072916865348816\n", - " eval_runtime = 0.1733\n", - " eval_samples_per_second = 1107.591\n", - " eval_steps_per_second = 34.612\n", + " eval_/loss = 4.384099960327148\n", + " eval_/next-item/ndcg_at_20 = 0.2163243591785431\n", + " eval_/next-item/ndcg_at_40 = 0.26256000995635986\n", + " eval_/next-item/recall_at_20 = 0.5364583730697632\n", + " eval_/next-item/recall_at_40 = 0.7604166865348816\n", + " eval_runtime = 0.154\n", + " eval_samples_per_second = 1246.411\n", + " eval_steps_per_second = 38.95\n", "['/workspace/data/sessions_by_day/4/train.parquet']\n", "********************\n", "Launch training for day 4 are:\n", @@ -857,7 +857,7 @@ " \n", " \n", " 50\n", - " 4.497000\n", + " 4.507400\n", " \n", " \n", "

" @@ -890,14 +890,94 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.520808696746826\n", - " eval_/next-item/ndcg_at_20 = 0.2083323895931244\n", - " eval_/next-item/ndcg_at_40 = 0.2511211037635803\n", - " eval_/next-item/recall_at_20 = 0.53125\n", - " eval_/next-item/recall_at_40 = 0.7395833730697632\n", - " eval_runtime = 0.173\n", - " eval_samples_per_second = 1109.768\n", - " eval_steps_per_second = 34.68\n" + " eval_/loss = 4.448643207550049\n", + " eval_/next-item/ndcg_at_20 = 0.18624266982078552\n", + " eval_/next-item/ndcg_at_40 = 0.22514502704143524\n", + " eval_/next-item/recall_at_20 = 0.5364583730697632\n", + " eval_/next-item/recall_at_40 = 0.7239583730697632\n", + " eval_runtime = 0.1597\n", + " eval_samples_per_second = 1202.447\n", + " eval_steps_per_second = 37.576\n", + "['/workspace/data/sessions_by_day/5/train.parquet']\n", + "********************\n", + "Launch training for day 5 are:\n", + "********************\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "***** Running training *****\n", + " Num examples = 1536\n", + " Num Epochs = 5\n", + " Instantaneous batch size per device = 128\n", + " Total train batch size (w. parallel, distributed & accumulation) = 128\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 60\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [60/60 00:01, Epoch 5/5]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Step | Training Loss
50 | 4.484300

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "finished\n", + "********************\n", + "Eval results for day 6 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.46337890625\n", + " eval_/next-item/ndcg_at_20 = 0.18136751651763916\n", + " eval_/next-item/ndcg_at_40 = 0.23689508438110352\n", + " eval_/next-item/recall_at_20 = 0.4947916865348816\n", + " eval_/next-item/recall_at_40 = 0.765625\n", + " eval_runtime = 0.161\n", + " eval_samples_per_second = 1192.384\n", + " eval_steps_per_second = 37.262\n" ] }, { @@ -917,9 +997,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "['/workspace/data/sessions_by_day/5/train.parquet']\n", + "['/workspace/data/sessions_by_day/6/train.parquet']\n", "********************\n", - "Launch training for day 5 are:\n", + "Launch training for day 6 are:\n", "********************\n", "\n" ] @@ -943,7 +1023,7 @@ " \n", " \n", " 50\n", - " 4.491200\n", + " 4.458900\n", " \n", " \n", "

" @@ -972,18 +1052,18 @@ "text": [ "finished\n", "********************\n", - "Eval results for day 6 are:\t\n", + "Eval results for day 7 are:\t\n", "\n", "********************\n", "\n", - " eval_/loss = 4.366371154785156\n", - " eval_/next-item/ndcg_at_20 = 0.21815048158168793\n", - " eval_/next-item/ndcg_at_40 = 0.2546561360359192\n", - " eval_/next-item/recall_at_20 = 0.5885416865348816\n", - " eval_/next-item/recall_at_40 = 0.765625\n", - " eval_runtime = 0.1743\n", - " eval_samples_per_second = 1101.611\n", - " eval_steps_per_second = 34.425\n" + " eval_/loss = 4.451930046081543\n", + " eval_/next-item/ndcg_at_20 = 0.19292211532592773\n", + " eval_/next-item/ndcg_at_40 = 0.237293541431427\n", + " eval_/next-item/recall_at_20 = 0.5104166865348816\n", + " eval_/next-item/recall_at_40 = 0.7291666865348816\n", + " eval_runtime = 0.1492\n", + " eval_samples_per_second = 1286.892\n", + " eval_steps_per_second = 40.215\n" ] }, { @@ -1003,9 +1083,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "['/workspace/data/sessions_by_day/6/train.parquet']\n", + "['/workspace/data/sessions_by_day/7/train.parquet']\n", "********************\n", - "Launch training for day 6 are:\n", + "Launch training for day 7 are:\n", "********************\n", "\n" ] @@ -1029,7 +1109,7 @@ " \n", " \n", " 50\n", - " 4.488000\n", + " 4.430300\n", " \n", " \n", "

" @@ -1058,27 +1138,27 @@ "text": [ "finished\n", "********************\n", - "Eval results for day 7 are:\t\n", + "Eval results for day 8 are:\t\n", "\n", "********************\n", "\n", - " eval_/loss = 4.4079389572143555\n", - " eval_/next-item/ndcg_at_20 = 0.2032829374074936\n", - " eval_/next-item/ndcg_at_40 = 0.24693170189857483\n", - " eval_/next-item/recall_at_20 = 0.5416666865348816\n", - " eval_/next-item/recall_at_40 = 0.75\n", - " eval_runtime = 0.1598\n", - " eval_samples_per_second = 1201.645\n", - " eval_steps_per_second = 37.551\n", - "CPU times: user 44 s, sys: 560 ms, total: 44.6 s\n", - "Wall time: 13.2 s\n" + " eval_/loss = 4.416452407836914\n", + " eval_/next-item/ndcg_at_20 = 0.20186899602413177\n", + " eval_/next-item/ndcg_at_40 = 0.2410435527563095\n", + " eval_/next-item/recall_at_20 = 0.5520833730697632\n", + " eval_/next-item/recall_at_40 = 0.7447916865348816\n", + " eval_runtime = 0.1614\n", + " eval_samples_per_second = 1189.826\n", + " eval_steps_per_second = 37.182\n", + "CPU times: user 50.8 s, sys: 731 ms, total: 51.6 s\n", + "Wall time: 15 s\n" ] } ], "source": [ "%%time\n", "start_time_window_index = 1\n", - "final_time_window_index = 7\n", + "final_time_window_index = 8\n", "#Iterating over days of one week\n", "for time_index in range(start_time_window_index, final_time_window_index):\n", " # Set data \n", @@ -1111,93 +1191,75 @@ }, { "cell_type": "markdown", - "id": "a33e6662", + "id": "6a8d7bd8", "metadata": {}, "source": [ - "### Save the model" + "### Re-compute evaluation metrics of the validation data" ] }, { "cell_type": "code", "execution_count": 12, - "id": "396057bb", + "id": "3a34be66", + "metadata": {}, + "outputs": [], + "source": [ + "eval_data_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2c6867d8", "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Saving model checkpoint to ./tmp/checkpoint-66\n", - "Trainer.model is not a `PreTrainedModel`, only saving its state dict.\n" + " eval_/loss = 4.416452407836914\n", + " eval_/next-item/ndcg_at_20 = 0.20186899602413177\n", + " eval_/next-item/ndcg_at_40 = 0.2410435527563095\n", + " eval_/next-item/recall_at_20 = 0.5520833730697632\n", + " eval_/next-item/recall_at_40 = 0.7447916865348816\n", + " eval_runtime = 0.1455\n", + " eval_samples_per_second = 1319.597\n", + " eval_steps_per_second = 41.237\n" ] } ], "source": [ - "trainer._save_model_and_checkpoint(save_model_class=True)" + "# set new data from day 7\n", + "eval_metrics = trainer.evaluate(eval_dataset=eval_data_paths, metric_key_prefix='eval')\n", + "for key in sorted(eval_metrics.keys()):\n", + " print(\" %s = %s\" % (key, str(eval_metrics[key])))" ] }, { "cell_type": "markdown", - "id": "41bb7767", + "id": "78b520bd-f6e3-460e-b06d-1bf480cae3cf", "metadata": {}, "source": [ - "### Reload the model" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "33139496", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.load_model_trainer_states_from_checkpoint('./tmp/checkpoint-%s'%trainer.state.global_step)" + "### Save the model" ] }, { "cell_type": "markdown", - "id": "6a8d7bd8", + "id": "766cd46b-25e7-4d89-9675-152e02103f33", "metadata": {}, "source": [ - "### Re-compute evaluation metrics of the validation data" + "Let's save the model to be able to load it back at inference step. 
Using `model.save()`, we save the model as a pkl file in the given path." ] }, { "cell_type": "code", "execution_count": 14, - "id": "3a34be66", + "id": "d9aebd79-e6e5-4309-a692-917b7b0283d0", "metadata": {}, "outputs": [], "source": [ - "eval_data_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "2c6867d8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " eval_/loss = 4.4079389572143555\n", - " eval_/next-item/ndcg_at_20 = 0.2032829374074936\n", - " eval_/next-item/ndcg_at_40 = 0.24693170189857483\n", - " eval_/next-item/recall_at_20 = 0.5416666865348816\n", - " eval_/next-item/recall_at_40 = 0.75\n", - " eval_runtime = 0.1713\n", - " eval_samples_per_second = 1121.166\n", - " eval_steps_per_second = 35.036\n" - ] - } - ], - "source": [ - "# set new data from day 7\n", - "eval_metrics = trainer.evaluate(eval_dataset=eval_data_paths, metric_key_prefix='eval')\n", - "for key in sorted(eval_metrics.keys()):\n", - " print(\" %s = %s\" % (key, str(eval_metrics[key])))" + "model_path= os.environ.get(\"OUTPUT_DIR\", f\"{INPUT_DATA_DIR}/saved_model\")\n", + "model.save(model_path)" ] }, { @@ -1205,8 +1267,7 @@ "id": "4a26a649", "metadata": {}, "source": [ - "That's it! \n", - "You have just trained your session-based recommendation model using Transformers4Rec." + "That's it! You have just trained your session-based recommendation model using Transformers4Rec. Now you can move on to the next notebook `03-serving-session-based-model-torch-backend`. Please shut down this kernel to free the GPU memory before you start the next one." ] }, { diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb new file mode 100644 index 0000000000..4dd7671b27 --- /dev/null +++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb @@ -0,0 +1,1011 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "97250792", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2022 NVIDIA Corporation. 
All Rights Reserved.\n",
+    "#\n",
+    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "# you may not use this file except in compliance with the License.\n",
+    "# You may obtain a copy of the License at\n",
+    "#\n",
+    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
+    "#\n",
+    "# Unless required by applicable law or agreed to in writing, software\n",
+    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "# See the License for the specific language governing permissions and\n",
+    "# limitations under the License.\n",
+    "# =============================================================================="
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0a2228da",
+   "metadata": {},
+   "source": [
+    "\n",
+    "\n",
+    "# Serving a Session-based Recommendation model with Torch Backend"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4127de28-a7ce-4ff7-8dcc-1575a70ca7c8",
+   "metadata": {},
+   "source": [
+    "This notebook was created using the latest stable `merlin-pytorch` container.\n",
+    "\n",
+    "Before running this notebook, you should already have executed the `01-ETL-with-NVTabular.ipynb` and `02-session-based-XLNet-with-PyT.ipynb` notebooks and saved the NVTabular workflow and the trained model.\n",
+    "\n",
+    "In this notebook, you will learn how to serve a trained Transformer-based PyTorch model on Triton Inference Server (TIS) with the Torch backend, using the [Merlin Systems](https://github.com/NVIDIA-Merlin/systems) library.\n",
+    "\n",
+    "NVIDIA [Triton Inference Server](https://github.com/triton-inference-server/server) (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks, such as TensorFlow and PyTorch.\n",
+    "\n",
+    "If you would like to learn how to serve a Transformers4Rec model with the Python backend, please visit this [example](https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "599efc90",
+   "metadata": {},
+   "source": [
+    "### Import required libraries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "3ba89970",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n",
+      "  warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n",
+      "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n",
+      "                not been set for this class (NDCGAt). The property determines if `update` by\n",
+      "                default needs access to the full metric state. 
If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (DCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (AvgPrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (PrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (RecallAt). The property determines if `update` by\n", + " default needs access to the full metric state. 
If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n" + ] + } + ], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", + "\n", + "import cudf\n", + "import glob\n", + "import torch \n", + "\n", + "from transformers4rec import torch as tr\n", + "from merlin.io import Dataset\n", + "\n", + "from merlin.core.dispatch import make_df # noqa\n", + "from merlin.systems.dag import Ensemble # noqa\n", + "from merlin.systems.dag.ops.pytorch import PredictPyTorch # noqa\n", + "from merlin.systems.dag.ops.workflow import TransformWorkflow\n", + "from merlin.systems.triton.utils import run_ensemble_on_tritonserver # noqa" + ] + }, + { + "cell_type": "markdown", + "id": "ac2aebe1-66a9-4c71-9bbc-1874625bc4e8", + "metadata": { + "tags": [] + }, + "source": [ + "We define the paths" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5bc4a16b-32be-42b7-9b6f-e4f4cde3f345", + "metadata": {}, + "outputs": [], + "source": [ + "INPUT_DATA_DIR = os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data\")\n", + "OUTPUT_DIR = os.environ.get(\"OUTPUT_DIR\", f\"{INPUT_DATA_DIR}/sessions_by_day\")\n", + "model_path= os.environ.get(\"OUTPUT_DIR\", f\"{INPUT_DATA_DIR}/saved_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "a510b6ef", + "metadata": {}, + "source": [ + "### Set the schema object" + ] + }, + { + "cell_type": "markdown", + "id": "30a0518a-eb01-4ac4-9c6d-36b328985765", + "metadata": {}, + "source": [ + "We create the schema object by reading the `schema.pbtxt` file generated by NVTabular pipeline in the previous, `01-ETL-with-NVTabular`, notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9d1299fa", + "metadata": {}, + "outputs": [], + "source": [ + "from merlin_standard_lib import Schema\n", + "SCHEMA_PATH = os.environ.get(\"INPUT_SCHEMA_PATH\", \"/workspace/data/processed_nvt/schema.pbtxt\")\n", + "schema = Schema().from_proto_text(SCHEMA_PATH)" + ] + }, + { + "cell_type": "markdown", + "id": "81764b19-4495-45c0-9cb5-6d937962d2bc", + "metadata": {}, + "source": [ + "We need to load the saved model to be able to serve it on TIS." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3bbc34a4-2277-4961-8061-59aadaa5116c", + "metadata": {}, + "outputs": [], + "source": [ + "import cloudpickle\n", + "loaded_model = cloudpickle.load(\n", + " open(os.path.join(model_path, \"t4rec_model_class.pkl\"), \"rb\")\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "c68ab3c7-a576-4fa0-a6f3-fe2bd11effc9", + "metadata": {}, + "source": [ + "Switch the model to eval mode. We call `model.eval()` before tracing to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this might yield inconsistent inference results." 
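+    "\n",
+    "\n",
+    "As a minimal illustration of why this matters (a toy sketch, independent of the model above), note how a dropout layer behaves differently in the two modes:\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "drop = torch.nn.Dropout(p=0.5)\n",
+    "x = torch.ones(1, 6)\n",
+    "drop.train()\n",
+    "print(drop(x))  # training mode: roughly half the values are zeroed, the rest scaled by 1/(1-p)\n",
+    "drop.eval()\n",
+    "print(drop(x))  # eval mode: dropout is a no-op, so the output is deterministic\n",
+    "```"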
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e516a78d-2e1a-4124-ba46-f60b245d3329", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Model(\n", + " (heads): ModuleList(\n", + " (0): Head(\n", + " (body): SequentialBlock(\n", + " (0): TabularSequenceFeatures(\n", + " (to_merge): ModuleDict(\n", + " (continuous_module): SequentialBlock(\n", + " (0): ContinuousFeatures(\n", + " (filter_features): FilterFeatures()\n", + " (_aggregation): ConcatFeatures()\n", + " )\n", + " (1): SequentialBlock(\n", + " (0): DenseBlock(\n", + " (0): Linear(in_features=2, out_features=64, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (2): AsTabular()\n", + " )\n", + " (categorical_module): SequenceEmbeddingFeatures(\n", + " (filter_features): FilterFeatures()\n", + " (embedding_tables): ModuleDict(\n", + " (item_id-list): Embedding(490, 64, padding_idx=0)\n", + " (category-list): Embedding(177, 64, padding_idx=0)\n", + " )\n", + " )\n", + " )\n", + " (_aggregation): ConcatFeatures()\n", + " (projection_module): SequentialBlock(\n", + " (0): DenseBlock(\n", + " (0): Linear(in_features=192, out_features=100, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (_masking): MaskedLanguageModeling()\n", + " )\n", + " (1): SequentialBlock(\n", + " (0): DenseBlock(\n", + " (0): Linear(in_features=100, out_features=64, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (2): TansformerBlock(\n", + " (transformer): XLNetModel(\n", + " (word_embedding): Embedding(1, 64)\n", + " (layer): ModuleList(\n", + " (0): XLNetLayer(\n", + " (rel_attn): XLNetRelativeAttention(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (ff): XLNetFeedForward(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", + " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (1): XLNetLayer(\n", + " (rel_attn): XLNetRelativeAttention(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (ff): XLNetFeedForward(\n", + " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", + " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", + " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.3, inplace=False)\n", + " )\n", + " (masking): MaskedLanguageModeling()\n", + " )\n", + " )\n", + " (prediction_task_dict): ModuleDict(\n", + " (next-item): NextItemPredictionTask(\n", + " (sequence_summary): SequenceSummary(\n", + " (summary): Identity()\n", + " (activation): Identity()\n", + " (first_dropout): Identity()\n", + " (last_dropout): Identity()\n", + " )\n", + " (metrics): ModuleList(\n", + " (0): NDCGAt()\n", + " (1): RecallAt()\n", + " )\n", + " (loss): NLLLoss()\n", + " (embeddings): SequenceEmbeddingFeatures(\n", + " (filter_features): FilterFeatures()\n", + " (embedding_tables): ModuleDict(\n", + " (item_id-list): Embedding(490, 64, padding_idx=0)\n", + " (category-list): Embedding(177, 64, padding_idx=0)\n", + " )\n", + " )\n", + " 
(item_embedding_table): Embedding(490, 64, padding_idx=0)\n", + " (masking): MaskedLanguageModeling()\n", + " (pre): Block(\n", + " (module): NextItemPredictionTask(\n", + " (item_embedding_table): Embedding(490, 64, padding_idx=0)\n", + " (log_softmax): LogSoftmax(dim=-1)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = loaded_model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "409c17c2-c81d-4f41-9577-bd380ae10921", + "metadata": {}, + "source": [ + "### Trace the model\n", + "\n", + "One common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in a high performance environment like C++. TorchScript is actually the recommended model format for scaled inference and deployment. We serve the model with the PyTorch backend that is used to execute TorchScript models. All models created in PyTorch using the python API must be traced/scripted to produce a TorchScript model. For tracing the model, we use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) api that takes the model as a Python function or torch.nn.Module, and an example input that will be passed to the function while tracing." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6273c8e5-db62-4cc6-a4f3-945155a463d6", + "metadata": {}, + "outputs": [], + "source": [ + "train_paths = os.path.join(OUTPUT_DIR, f\"{1}/train.parquet\")\n", + "dataset = Dataset(train_paths)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ef5867e4-769a-4af9-aed7-38586a75f18f", + "metadata": {}, + "outputs": [], + "source": [ + "sparse_max = {'age_days-list': 20,\n", + " 'weekday_sin-list': 20,\n", + " 'item_id-list': 20,\n", + " 'category-list': 20}\n", + "\n", + "from transformers4rec.torch.utils.data_utils import MerlinDataLoader\n", + "\n", + "def generate_dataloader(schema, dataset, batch_size=128, seq_length=20):\n", + " loader = MerlinDataLoader.from_schema(\n", + " schema,\n", + " dataset,\n", + " batch_size=batch_size,\n", + " max_sequence_length=seq_length,\n", + " shuffle=False,\n", + " sparse_as_dense=True,\n", + " sparse_max=sparse_max\n", + " )\n", + " return loader" + ] + }, + { + "cell_type": "markdown", + "id": "9ddb4b5f-9b8c-487b-9fa1-cd4b6d3c090f", + "metadata": {}, + "source": [ + "Create a dict of tensors to feed it as example inputs in the `torch.jit.trace()`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1552deb1-1605-4e80-b70f-36f92d64afbd", + "metadata": {}, + "outputs": [], + "source": [ + "loader = generate_dataloader(schema, dataset)\n", + "train_dict = next(iter(loader))" + ] + }, + { + "cell_type": "markdown", + "id": "c8f18ec0-fabe-494e-a963-8cae5882b9d1", + "metadata": {}, + "source": [ + "Let's check out the `item_id-list` column in the `train_dict` dictionary." 
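+    "\n",
+    "\n",
+    "Because the loader was created with `sparse_as_dense=True` and a `sparse_max` of 20, every list feature comes back as a dense tensor padded with zeros up to length 20, which you will see as trailing zeros below. A small sketch (added for illustration, assuming id 0 is only used for padding) to quantify that padding:\n",
+    "\n",
+    "```python\n",
+    "print(train_dict[\"item_id-list\"].shape)                   # (batch_size, 20)\n",
+    "print((train_dict[\"item_id-list\"] == 0).sum(dim=1)[:5])   # padded positions in the first 5 sessions\n",
+    "```"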
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "46e676f8-0d82-42ec-b27d-04aa2b7c6a6e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 4, 11, 17, ..., 30, 0, 0],\n", + " [36, 49, 9, ..., 0, 0, 0],\n", + " [29, 12, 18, ..., 0, 0, 0],\n", + " ...,\n", + " [21, 23, 40, ..., 0, 0, 0],\n", + " [13, 67, 32, ..., 0, 0, 0],\n", + " [93, 11, 41, ..., 0, 0, 0]], device='cuda:0')" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dict['item_id-list']" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "43e789a3-2423-44eb-ae5d-0c154557424f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "traced_model = torch.jit.trace(model, train_dict, strict=True)" + ] + }, + { + "cell_type": "markdown", + "id": "a404bdb6-6e19-4899-8172-214adef384a8", + "metadata": {}, + "source": [ + "Generate model input and output schemas to feed in the `PredictPyTorch` operator below." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c6a814aa-4954-404b-bd2d-161ed8066f4e", + "metadata": {}, + "outputs": [], + "source": [ + "input_schema = model.input_schema\n", + "output_schema = model.output_schema" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "757cd0c5-f581-488b-a8de-b8d1188820d6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 | name | tags | dtype | is_list | is_ragged | properties.int_domain.min | properties.int_domain.max
0 | age_days-list | (Tags.LIST, Tags.CONTINUOUS) | float32 | True | False | 0 | 0
1 | weekday_sin-list | (Tags.LIST, Tags.CONTINUOUS) | float32 | True | False | 0 | 0
2 | item_id-list | (Tags.ITEM, Tags.CATEGORICAL, Tags.LIST, Tags.... | int64 | True | False | 0 | 489
3 | category-list | (Tags.CATEGORICAL, Tags.LIST) | int64 | True | False | 0 | 176
\n", + "
" + ], + "text/plain": [ + "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 489}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 176}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_schema" + ] + }, + { + "cell_type": "markdown", + "id": "6f8a8bfc-c6e1-44dd-ab65-cbdab2135e8e", + "metadata": {}, + "source": [ + "Let's create a folder that we can store the exported models and the config files." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6b2deb2b-e223-4b5d-b655-810e1aefa7e8", + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "ens_model_path = os.environ.get(\"ens_model_path\", f\"{INPUT_DATA_DIR}/models\")\n", + "# Make sure we have a clean stats space for Dask\n", + "if os.path.isdir(ens_model_path):\n", + " shutil.rmtree(ens_model_path)\n", + "os.mkdir(ens_model_path)" + ] + }, + { + "cell_type": "markdown", + "id": "e3449615-2120-402d-b5c3-1544ee3224dd", + "metadata": {}, + "source": [ + "We use `PredictPyTorch` operator that takes a pytorch model and packages it correctly for tritonserver to run on the PyTorch backend." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", + "metadata": {}, + "outputs": [], + "source": [ + "torch_op = input_schema.column_names >> PredictPyTorch(\n", + " traced_model, input_schema, output_schema\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f36b0f90-5392-4d7a-9a48-57737cf63cd1", + "metadata": {}, + "source": [ + "The last step is to create the ensemble artifacts that Triton Inference Server can consume. To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.\n", + "\n", + "When we create an `Ensemble` object we supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of out graph. After we created the Ensemble we export the graph, supplying an export path for the `ensemble.export` function. This returns an ensemble config which represents the entire inference pipeline and a list of node-specific configs." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "5b0d14bb-7765-45e8-8fd0-9d508dc3ec14", + "metadata": {}, + "outputs": [], + "source": [ + "ensemble = Ensemble(torch_op, input_schema)\n", + "ens_config, node_configs = ensemble.export(ens_model_path)" + ] + }, + { + "cell_type": "markdown", + "id": "a36169a5-f218-44b5-b034-7d299ce718ed", + "metadata": {}, + "source": [ + "## Starting Triton Server" + ] + }, + { + "cell_type": "markdown", + "id": "eb507766-ac9b-4c8c-8339-c1c951648428", + "metadata": {}, + "source": [ + "It is time to deploy all the models as an ensemble model to Triton Inference Serve TIS. After we export the ensemble, we are ready to start the TIS. 
You can start Triton Server with the following command on your terminal:\n",
+    "\n",
+    "`tritonserver --model-repository=<path_to_models>`\n",
+    "\n",
+    "For the `--model-repository` argument, specify the same path as the `export_path` that you specified previously in the `ensemble.export` method. This command will launch the server and load all the models to the server. Once all the models are loaded successfully, you should see a READY status printed out in the terminal for each loaded model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "46a86c8d-9ec1-4422-8f8c-4d49e83f6783",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "client created.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import tritonclient.http as client\n",
+    "\n",
+    "# Create a triton client\n",
+    "try:\n",
+    "    triton_client = client.InferenceServerClient(url=\"localhost:8000\", verbose=True)\n",
+    "    print(\"client created.\")\n",
+    "except Exception as e:\n",
+    "    print(\"channel creation failed: \" + str(e))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bab41dd0-0155-4d97-8bdf-2451177d46f1",
+   "metadata": {},
+   "source": [
+    "After we create the client and verify it is connected to the server instance, we can communicate with the server and ensure all the models are loaded correctly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "dda3f852-a019-4bf1-831b-f63b750a1192",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "GET /v2/health/live, headers None\n",
+      "\n",
+      "POST /v2/repository/index, headers None\n",
+      "\n",
+      "\n",
+      "bytearray(b'[{\"name\":\"0_predictpytorch\",\"version\":\"1\",\"state\":\"READY\"},{\"name\":\"ensemble_model\",\"version\":\"1\",\"state\":\"READY\"}]')\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "[{'name': '0_predictpytorch', 'version': '1', 'state': 'READY'},\n",
+       " {'name': 'ensemble_model', 'version': '1', 'state': 'READY'}]"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# ensure triton is in a good state\n",
+    "triton_client.is_server_live()\n",
+    "triton_client.get_model_repository_index()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "415146b1-9e9c-4c72-a9b9-697041ba27ef",
+   "metadata": {},
+   "source": [
+    "### Send request to Triton and get the response"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5f218f19-63d8-498f-aec8-b5dfa56ca3f3",
+   "metadata": {},
+   "source": [
+    "The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the model to production and get responses for a given query or a set of queries.\n",
+    "In this section, we generate a dataframe that we can serve as a request to TIS. Note that this is a transformed dataframe. We also need our dataset's list columns to be padded to the max sequence length that was set in the ETL pipeline.\n",
+    "\n",
+    "We do not serve the raw dataframe because, in a production setting, we want to transform the input data exactly as it was transformed during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to contiguous integers before we use the deployed DL model for a prediction. Therefore, we use a transformed dataset that is processed in the same way as the train set. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "0acd5649-31fe-4f3f-87a2-2607477638b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(32, 4)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 | age_days-list | weekday_sin-list | item_id-list | category-list
0 | [0.47162586, 0.9286565, 0.8286713, 0.7888927, ... | [0.8412576, 0.4806973, 0.26910275, 0.7267265, ... | [82, 2, 69, 101, 2, 26, 25, 20, 32, 5, 1, 5, 0... | [21, 1, 19, 27, 1, 7, 7, 5, 4, 2, 1, 2, 0, 0, ...
1 | [0.3916677, 0.92222315, 0.23826516, 0.7312075,... | [0.11462939, 0.5831296, 0.8735699, 0.6819625, ... | [10, 1, 2, 67, 15, 64, 193, 41, 16, 15, 0, 0, ... | [3, 1, 1, 18, 6, 16, 50, 10, 6, 6, 0, 0, 0, 0,...
2 | [0.18798462, 0.9052097, 0.63714063, 0.5033676,... | [0.43194288, 0.45363078, 0.81598556, 0.8821798... | [31, 8, 6, 21, 11, 18, 48, 25, 21, 22, 0, 0, 0... | [9, 4, 2, 5, 3, 5, 13, 7, 5, 5, 0, 0, 0, 0, 0,...
3 | [0.8022848, 0.74478865, 0.25033632, 0.5172018,... | [0.4466279, 0.9708794, 0.77920324, 0.11763577,... | [9, 7, 24, 3, 2, 17, 36, 58, 7, 6, 0, 0, 0, 0,... | [2, 2, 7, 1, 1, 6, 11, 15, 2, 2, 0, 0, 0, 0, 0...
4 | [0.37031004, 0.09800986, 0.8980817, 0.9975142,... | [0.29334334, 0.11195588, 0.7300642, 0.27448815... | [12, 30, 11, 64, 33, 19, 26, 4, 51, 0, 0, 0, 0... | [3, 8, 3, 16, 9, 4, 7, 1, 14, 0, 0, 0, 0, 0, 0...
\n", + "
" + ], + "text/plain": [ + " age_days-list \\\n", + "0 [0.47162586, 0.9286565, 0.8286713, 0.7888927, ... \n", + "1 [0.3916677, 0.92222315, 0.23826516, 0.7312075,... \n", + "2 [0.18798462, 0.9052097, 0.63714063, 0.5033676,... \n", + "3 [0.8022848, 0.74478865, 0.25033632, 0.5172018,... \n", + "4 [0.37031004, 0.09800986, 0.8980817, 0.9975142,... \n", + "\n", + " weekday_sin-list \\\n", + "0 [0.8412576, 0.4806973, 0.26910275, 0.7267265, ... \n", + "1 [0.11462939, 0.5831296, 0.8735699, 0.6819625, ... \n", + "2 [0.43194288, 0.45363078, 0.81598556, 0.8821798... \n", + "3 [0.4466279, 0.9708794, 0.77920324, 0.11763577,... \n", + "4 [0.29334334, 0.11195588, 0.7300642, 0.27448815... \n", + "\n", + " item_id-list \\\n", + "0 [82, 2, 69, 101, 2, 26, 25, 20, 32, 5, 1, 5, 0... \n", + "1 [10, 1, 2, 67, 15, 64, 193, 41, 16, 15, 0, 0, ... \n", + "2 [31, 8, 6, 21, 11, 18, 48, 25, 21, 22, 0, 0, 0... \n", + "3 [9, 7, 24, 3, 2, 17, 36, 58, 7, 6, 0, 0, 0, 0,... \n", + "4 [12, 30, 11, 64, 33, 19, 26, 4, 51, 0, 0, 0, 0... \n", + "\n", + " category-list \n", + "0 [21, 1, 19, 27, 1, 7, 7, 5, 4, 2, 1, 2, 0, 0, ... \n", + "1 [3, 1, 1, 18, 6, 16, 50, 10, 6, 6, 0, 0, 0, 0,... \n", + "2 [9, 4, 2, 5, 3, 5, 13, 7, 5, 5, 0, 0, 0, 0, 0,... \n", + "3 [2, 2, 7, 1, 1, 6, 11, 15, 2, 2, 0, 0, 0, 0, 0... \n", + "4 [3, 8, 3, 16, 9, 4, 7, 1, 14, 0, 0, 0, 0, 0, 0... " + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "eval_paths = os.path.join(OUTPUT_DIR, f\"{1}/valid.parquet\")\n", + "eval_dataset = Dataset(eval_paths, shuffle=False)\n", + "eval_loader = generate_dataloader(schema, eval_dataset, batch_size=32)\n", + "test_dict = next(iter(eval_loader))\n", + "\n", + "df_cols = {}\n", + "for name, tensor in test_dict.items():\n", + " if name in input_schema.column_names:\n", + " dtype = input_schema[name].dtype\n", + "\n", + " df_cols[name] = tensor.cpu().numpy().astype(dtype)\n", + " if len(tensor.shape) > 1:\n", + " df_cols[name] = list(df_cols[name])\n", + "\n", + "df = make_df(df_cols)\n", + "print(df.shape)\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "52832509-6444-4fd2-8834-a02eec429e14", + "metadata": {}, + "source": [ + "Once our models are successfully loaded to the TIS, we can now easily send a request to TIS and get a response for our query with send_triton_request utility function." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "id": "42091c25-7676-414e-bb8c-8432aeb58297",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from merlin.systems.triton.utils import send_triton_request\n",
+    "response = send_triton_request(input_schema, df[input_schema.column_names], output_schema.column_names)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "id": "430555ac-d427-48ac-b93a-da2ea41b86d0",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'next-item': array([[ -9.547688 ,  -3.3062646,  -3.5631118, ...,  -9.250848 ,\n",
+       "         -9.964631 , -10.170028 ],\n",
+       "       [ -9.54715  ,  -3.306591 ,  -3.5619133, ...,  -9.250428 ,\n",
+       "         -9.9653635, -10.169994 ],\n",
+       "       [ -9.547378 ,  -3.306465 ,  -3.562431 , ...,  -9.25062  ,\n",
+       "         -9.964928 , -10.170043 ],\n",
+       "       ...,\n",
+       "       [ -9.54731  ,  -3.306385 ,  -3.5624921, ...,  -9.250595 ,\n",
+       "         -9.9651   , -10.169995 ],\n",
+       "       [ -9.54718  ,  -3.3064706,  -3.5621893, ...,  -9.25048  ,\n",
+       "         -9.965321 , -10.169971 ],\n",
+       "       [ -9.546991 ,  -3.3065925,  -3.5617704, ...,  -9.25033  ,\n",
+       "         -9.965618 , -10.169941 ]], dtype=float32)}"
+      ]
+     },
+     "execution_count": 43,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "response"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "id": "0fa425a4-9c00-45ed-a4b1-fd75ca4bf819",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(32, 490)"
+      ]
+     },
+     "execution_count": 44,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "response['next-item'].shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d0581195-d496-4536-a93b-3071ed9088ea",
+   "metadata": {},
+   "source": [
+    "We return a response for each request in the df. Each row in the `response['next-item']` array contains the logit values for every item in the catalog, plus one logit corresponding to the null, OOV, and padded items. The first score in each row corresponds to that padded, OOV, or null item. Note that we don't have OOV or null items in our synthetically generated dataset."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5ab16c64-4371-4696-b2d6-3bff66e67fdb",
+   "metadata": {},
+   "source": [
+    "This is the end of this suite of examples. You successfully performed feature engineering with NVTabular, trained a Transformer-based session-based recommendation model with Transformers4Rec, deployed the trained model to Triton Inference Server, and sent requests and received responses from the server. As one final, optional step, the short sketch below decodes the returned logits into top-k predictions.\n",
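+    "\n",
+    "\n",
+    "This is a small sketch added for illustration, using only objects already defined above:\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "logits = response[\"next-item\"]              # shape: (num_requests, item_cardinality)\n",
+    "top5 = np.argsort(-logits, axis=1)[:, :5]   # encoded ids of the 5 highest-scoring items per request\n",
+    "print(top5[:3])\n",
+    "```"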
+ ] + } + ], + "metadata": { + "interpreter": { + "hash": "7b543a88d374ac88bf8df97911b380f671b13649694a5b49eb21e60fd27eb479" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/getting-started-session-based/schema.pb b/examples/getting-started-session-based/schema.pb new file mode 100644 index 0000000000..c9ffce3d68 --- /dev/null +++ b/examples/getting-started-session-based/schema.pb @@ -0,0 +1,87 @@ +feature { + name: "session_id" + type: INT + int_domain { + name: "session_id" + min: 1 + max: 100001 + is_categorical: false + } + annotation { + tag: "groupby_col" + } +} +feature { + name: "category-list" + value_count { + min: 2 + max: 20 + } + type: INT + int_domain { + name: "category-list" + min: 1 + max: 400 + is_categorical: true + } + annotation { + tag: "list" + tag: "categorical" + tag: "item" + } +} +feature { + name: "item_id-list" + value_count { + min: 2 + max: 20 + } + type: INT + int_domain { + name: "item_id-list" + min: 1 + max: 50005 + is_categorical: true + } + annotation { + tag: "item_id" + tag: "list" + tag: "categorical" + tag: "item" + } +} +feature { + name: "age_days-list" + value_count { + min: 2 + max: 20 + } + type: FLOAT + float_domain { + name: "age_days-list" + min: 0.0000003 + max: 0.9999999 + } + annotation { + tag: "continuous" + tag: "list" + } +} +feature { + name: "weekday_sin-list" + value_count { + min: 2 + max: 20 + } + type: FLOAT + float_domain { + name: "weekday_sin-list" + min: 0.0000003 + max: 0.9999999 + } + annotation { + tag: "continuous" + tag: "time" + tag: "list" + } +} \ No newline at end of file From 92bd2d94e8d4ae6bed48aef5d90f0630773ea2f6 Mon Sep 17 00:00:00 2001 From: rnyak Date: Thu, 1 Dec 2022 15:35:32 -0800 Subject: [PATCH 3/4] delete old nb --- ...ssion-based-model-with-Torch-backend.ipynb | 1473 ----------------- 1 file changed, 1473 deletions(-) delete mode 100644 examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb diff --git a/examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb deleted file mode 100644 index 77aa95adee..0000000000 --- a/examples/getting-started-session-based/03-serving-session-based-model-with-Torch-backend.ipynb +++ /dev/null @@ -1,1473 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "97250792", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright 2022 NVIDIA Corporation. 
All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "# ==============================================================================" - ] - }, - { - "cell_type": "markdown", - "id": "0a2228da", - "metadata": {}, - "source": [ - "\n", - "\n", - "# Session-based Recommendation with XLNET" - ] - }, - { - "cell_type": "markdown", - "id": "599efc90", - "metadata": {}, - "source": [ - "### Imports required libraries" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "3ba89970", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n", - " warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (NDCGAt). The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (DCGAt). The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (AvgPrecisionAt). 
The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (PrecisionAt). The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (RecallAt). The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n" - ] - } - ], - "source": [ - "import os\n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", - "\n", - "import cudf\n", - "import glob\n", - "import torch \n", - "\n", - "from transformers4rec import torch as tr\n", - "from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt\n", - "from transformers4rec.torch.utils.examples_utils import wipe_memory\n", - "from merlin.io import Dataset" - ] - }, - { - "cell_type": "markdown", - "id": "aea2a0c5", - "metadata": {}, - "source": [ - "Transformers4Rec library relies on a schema object to automatically build all necessary layers to represent, normalize and aggregate input features. As you can see below, `schema.pb` is a protobuf file that contains metadata including statistics about features such as cardinality, min and max values and also tags features based on their characteristics and dtypes (e.g., categorical, continuous, list, integer)." - ] - }, - { - "cell_type": "markdown", - "id": "a510b6ef", - "metadata": {}, - "source": [ - "### Set the schema object" - ] - }, - { - "cell_type": "markdown", - "id": "30a0518a-eb01-4ac4-9c6d-36b328985765", - "metadata": {}, - "source": [ - "We create the schema object by reading the `schema.pbtxt` file generated by NVTabular pipeline in the previous, `01-ETL-with-NVTabular`, notebook." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "9d1299fa", - "metadata": {}, - "outputs": [], - "source": [ - "from merlin_standard_lib import Schema\n", - "# import merlin.io\n", - "# from merlin.models.utils import schema_utils\n", - "# from merlin.schema import Schema, Tags\n", - "# from merlin.schema.io.tensorflow_metadata import TensorflowMetadata\n", - "# from merlin.schema import Schema\n", - "SCHEMA_PATH = os.environ.get(\"INPUT_SCHEMA_PATH\", \"/workspace/data/processed_nvt/schema.pbtxt\")\n", - "schema = Schema().from_proto_text(SCHEMA_PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "868f0317-d140-40d5-b4bd-29a27e12077b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'session_id', 'type': 'INT', 'int_domain': {'name': 'session_id', 'max': '19877', 'is_categorical': True}, 'annotation': {'tag': ['categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.session_id.parquet\", \"embedding_sizes\": {\"cardinality\": 19878.0, \"dimension\": 409.0}, \"dtype_item_size\": 64.0, \"is_list\": false, \"is_ragged\": false}']}}, {'name': 'day-first', 'type': 'INT', 'annotation': {'comment': ['{\"dtype_item_size\": 64.0, \"is_list\": false, \"is_ragged\": false}']}}, {'name': 'item_id-count', 'type': 'INT', 'int_domain': {'name': 'item_id', 'max': '506', 'is_categorical': True}, 'annotation': {'tag': ['categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.item_id.parquet\", \"embedding_sizes\": {\"cardinality\": 507.0, \"dimension\": 52.0}, \"dtype_item_size\": 32.0, \"is_list\": false, \"is_ragged\": false}']}}, {'name': 'item_id-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'item_id', 'max': '506', 'is_categorical': True}, 'annotation': {'tag': ['list', 'item_id', 'item', 'id', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.item_id.parquet\", \"embedding_sizes\": {\"cardinality\": 507.0, \"dimension\": 52.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'category-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'category', 'max': '137', 'is_categorical': True}, 'annotation': {'tag': ['list', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.category.parquet\", \"embedding_sizes\": {\"cardinality\": 138.0, \"dimension\": 25.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'age_days-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'weekday_sin-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "schema" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "b4f426f0", - "metadata": {}, - "outputs": [], - "source": [ - "# # You can select a 
subset of features for training\n", - "\n", - "# You can select a subset of features for training\n", - "schema = schema.select_by_name(['item_id-list', \n", - " 'category-list',\n", - " 'weekday_sin-list',\n", - " 'age_days-list'\n", - " ])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "31bd0f44-ecfe-489a-88ac-032b5a512622", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'name': 'item_id-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'item_id', 'max': '506', 'is_categorical': True}, 'annotation': {'tag': ['list', 'item_id', 'item', 'id', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.item_id.parquet\", \"embedding_sizes\": {\"cardinality\": 507.0, \"dimension\": 52.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'category-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'INT', 'int_domain': {'name': 'category', 'max': '137', 'is_categorical': True}, 'annotation': {'tag': ['list', 'categorical'], 'comment': ['{\"num_buckets\": null, \"freq_threshold\": 0.0, \"max_size\": 0.0, \"start_index\": 1.0, \"cat_path\": \".//categories/unique.category.parquet\", \"embedding_sizes\": {\"cardinality\": 138.0, \"dimension\": 25.0}, \"dtype_item_size\": 64.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'age_days-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}, {'name': 'weekday_sin-list', 'value_count': {'min': '2', 'max': '15'}, 'type': 'FLOAT', 'annotation': {'tag': ['list', 'continuous'], 'comment': ['{\"dtype_item_size\": 32.0, \"is_list\": true, \"is_ragged\": true}']}}]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "schema" - ] - }, - { - "cell_type": "markdown", - "id": "06cacefa", - "metadata": {}, - "source": [ - "### Define the sequential input module" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "b38d30d7", - "metadata": {}, - "outputs": [], - "source": [ - "inputs = tr.TabularSequenceFeatures.from_schema(\n", - " schema,\n", - " max_sequence_length=15,\n", - " continuous_projection=64,\n", - " d_output=100,\n", - " masking=\"causal\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "ed749ca8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (NDCGAt). The property determines if `update` by\n", - " default needs access to the full metric state. 
If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (DCGAt). The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n", - "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", - " not been set for this class (RecallAt). The property determines if `update` by\n", - " default needs access to the full metric state. If this is not the case, significant speedups can be\n", - " achieved and we recommend setting this to `False`.\n", - " We provide an checking function\n", - " `from torchmetrics.utilities import check_forward_full_state_property`\n", - " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", - " default for now) or if `full_state_update=False` can be used safely.\n", - " \n", - " warnings.warn(*args, **kwargs)\n" - ] - } - ], - "source": [ - "# Define XLNetConfig class and set default parameters for HF XLNet config \n", - "transformer_config = tr.XLNetConfig.build(\n", - " d_model=64, n_head=4, n_layer=2, total_seq_length=20\n", - ")\n", - "# Define the model block including: inputs, masking, projection and transformer block.\n", - "body = tr.SequentialBlock(\n", - " inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)\n", - ")\n", - "\n", - "# Defines the evaluation top-N metrics and the cut-offs\n", - "metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True), \n", - " RecallAt(top_ks=[20, 40], labels_onehot=True)]\n", - "\n", - "# Define a head related to next item prediction task \n", - "head = tr.Head(\n", - " body,\n", - " tr.NextItemPredictionTask(weight_tying=True, metrics=metrics),\n", - " inputs=inputs,\n", - ")\n", - "\n", - "# Get the end-to-end Model class \n", - "model = tr.Model(head)" - ] - }, - { - "cell_type": "markdown", - "id": "a57335ff", - "metadata": {}, - "source": [ - "Note that we can easily define an RNN-based model inside the `SequentialBlock` instead of a Transformer-based model. You can explore this [tutorial](https://github.com/NVIDIA-Merlin/Transformers4Rec/tree/main/examples/tutorial) for a GRU-based model example." 
- ] - }, - { - "cell_type": "markdown", - "id": "16d51e39", - "metadata": {}, - "source": [ - "### Train the model " - ] - }, - { - "cell_type": "markdown", - "id": "f26d7aec", - "metadata": {}, - "source": [ - "We use the NVTabular PyTorch Dataloader for optimized loading of multiple features from input parquet files. You can learn more about this data loader [here](https://nvidia-merlin.github.io/NVTabular/main/training/pytorch.html)." - ] - }, - { - "cell_type": "markdown", - "id": "02fd4c22", - "metadata": {}, - "source": [ - "### **Set Training arguments**" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "693974df", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers4rec.config.trainer import T4RecTrainingArguments\n", - "from transformers4rec.torch import Trainer\n", - "# Set hyperparameters for training \n", - "train_args = T4RecTrainingArguments(data_loader_engine='nvtabular', \n", - " dataloader_drop_last = True,\n", - " gradient_accumulation_steps = 1,\n", - " per_device_train_batch_size = 128, \n", - " per_device_eval_batch_size = 32,\n", - " output_dir = \"./tmp\", \n", - " learning_rate=0.0005,\n", - " lr_scheduler_type='cosine', \n", - " learning_rate_num_cosine_cycles_by_epoch=1.5,\n", - " num_train_epochs=5,\n", - " max_sequence_length=20, \n", - " report_to = [],\n", - " logging_steps=50,\n", - " no_cuda=False)" - ] - }, - { - "cell_type": "markdown", - "id": "445ece64", - "metadata": {}, - "source": [ - "Note that we add an argument `data_loader_engine='nvtabular'` to automatically load the features needed for training using the schema. The default value is nvtabular for optimized GPU-based data-loading. Optionally a PyarrowDataLoader (pyarrow) can also be used as a basic option, but it is slower and works only for small datasets, as the full data is loaded to CPU memory." - ] - }, - { - "cell_type": "markdown", - "id": "32554ea0", - "metadata": {}, - "source": [ - "## Daily Fine-Tuning: Training over a time window" - ] - }, - { - "cell_type": "markdown", - "id": "ef883061", - "metadata": {}, - "source": [ - "Here we do daily fine-tuning meaning that we use the first day to train and second day to evaluate, then we use the second day data to train the model by resuming from the first step, and evaluate on the third day, so on so forth." - ] - }, - { - "cell_type": "markdown", - "id": "9f452d09", - "metadata": {}, - "source": [ - "We have extended the HuggingFace transformers `Trainer` class (PyTorch only) to support evaluation of RecSys metrics. In this example, the evaluation of the session-based recommendation model is performed using traditional Top-N ranking metrics such as Normalized Discounted Cumulative Gain (NDCG@20) and Hit Rate (HR@20). NDCG accounts for rank of the relevant item in the recommendation list and is a more fine-grained metric than HR, which only verifies whether the relevant item is among the top-n items. HR@n is equivalent to Recall@n when there is only one relevant item in the recommendation list." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2283f788", - "metadata": {}, - "outputs": [], - "source": [ - "# Instantiate the T4Rec Trainer, which manages training and evaluation for the PyTorch API\n", - "trainer = Trainer(\n", - " model=model,\n", - " args=train_args,\n", - " schema=schema,\n", - " compute_metrics=True,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "d515127c", - "metadata": {}, - "source": [ - "- Define the output folder of the processed parquet files" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "ae313150", - "metadata": {}, - "outputs": [], - "source": [ - "INPUT_DATA_DIR = os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data\")\n", - "OUTPUT_DIR = os.environ.get(\"OUTPUT_DIR\", f\"{INPUT_DATA_DIR}/sessions_by_day\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "8ae51de0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "***** Running training *****\n", - " Num examples = 1664\n", - " Num Epochs = 5\n", - " Instantaneous batch size per device = 128\n", - " Total train batch size (w. parallel, distributed & accumulation) = 128\n", - " Gradient Accumulation steps = 1\n", - " Total optimization steps = 65\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['/workspace/data/sessions_by_day/1/train.parquet']\n", - "********************\n", - "Launch training for day 1 are:\n", - "********************\n", - "\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [65/65 00:01, Epoch 5/5]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
505.731000

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "\n", - "Training completed. Do not forget to share your model on huggingface.co/models =)\n", - "\n", - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "finished\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "

\n", - " \n", - " \n", - " [6/6 00:02]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "********************\n", - "Eval results for day 2 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.892789840698242\n", - " eval_/next-item/ndcg_at_20 = 0.2019212245941162\n", - " eval_/next-item/ndcg_at_40 = 0.24986284971237183\n", - " eval_/next-item/recall_at_20 = 0.5104166865348816\n", - " eval_/next-item/recall_at_40 = 0.7447916865348816\n", - " eval_runtime = 0.1608\n", - " eval_samples_per_second = 1193.821\n", - " eval_steps_per_second = 37.307\n", - "['/workspace/data/sessions_by_day/2/train.parquet']\n", - "********************\n", - "Launch training for day 2 are:\n", - "********************\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "***** Running training *****\n", - " Num examples = 1664\n", - " Num Epochs = 5\n", - " Instantaneous batch size per device = 128\n", - " Total train batch size (w. parallel, distributed & accumulation) = 128\n", - " Gradient Accumulation steps = 1\n", - " Total optimization steps = 65\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [65/65 00:01, Epoch 5/5]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
504.795200

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "\n", - "Training completed. Do not forget to share your model on huggingface.co/models =)\n", - "\n", - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "finished\n", - "********************\n", - "Eval results for day 3 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.611485481262207\n", - " eval_/next-item/ndcg_at_20 = 0.1681433618068695\n", - " eval_/next-item/ndcg_at_40 = 0.22220981121063232\n", - " eval_/next-item/recall_at_20 = 0.484375\n", - " eval_/next-item/recall_at_40 = 0.7447916865348816\n", - " eval_runtime = 0.1687\n", - " eval_samples_per_second = 1138.188\n", - " eval_steps_per_second = 35.568\n", - "CPU times: user 15.9 s, sys: 247 ms, total: 16.1 s\n", - "Wall time: 5.28 s\n" - ] - } - ], - "source": [ - "%%time\n", - "start_time_window_index = 1\n", - "final_time_window_index = 3\n", - "#Iterating over days of one week\n", - "for time_index in range(start_time_window_index, final_time_window_index):\n", - " # Set data \n", - " time_index_train = time_index\n", - " time_index_eval = time_index + 1\n", - " train_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_train}/train.parquet\"))\n", - " eval_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))\n", - " print(train_paths)\n", - " \n", - " # Train on day related to time_index \n", - " print('*'*20)\n", - " print(\"Launch training for day %s are:\" %time_index)\n", - " print('*'*20 + '\\n')\n", - " trainer.train_dataset_or_path = train_paths\n", - " trainer.reset_lr_scheduler()\n", - " trainer.train()\n", - " trainer.state.global_step +=1\n", - " print('finished')\n", - " \n", - " # Evaluate on the following day\n", - " trainer.eval_dataset_or_path = eval_paths\n", - " train_metrics = trainer.evaluate(metric_key_prefix='eval')\n", - " print('*'*20)\n", - " print(\"Eval results for day %s are:\\t\" %time_index_eval)\n", - " print('\\n' + '*'*20 + '\\n')\n", - " for key in sorted(train_metrics.keys()):\n", - " print(\" %s = %s\" % (key, str(train_metrics[key]))) \n", - " wipe_memory()" - ] - }, - { - "cell_type": "markdown", - "id": "6a8d7bd8", - "metadata": {}, - "source": [ - "### Re-compute eval metrics of validation data" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "3a34be66", - "metadata": {}, - "outputs": [], - "source": [ - "eval_data_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))\n", - "\n", - "# set new data from day 7\n", - "eval_metrics = trainer.evaluate(eval_dataset=eval_data_paths, metric_key_prefix='eval')\n", - "for key in sorted(eval_metrics.keys()):\n", - " print(\" %s = %s\" % (key, str(eval_metrics[key])))" - ] - }, - { - "cell_type": "markdown", - "id": "4a26a649", - "metadata": {}, - "source": [ - "That's it! \n", - "You have just trained your session-based recommendation model using Transformers4Rec." 
- ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "e516a78d-2e1a-4124-ba46-f60b245d3329", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Model(\n", - " (heads): ModuleList(\n", - " (0): Head(\n", - " (body): SequentialBlock(\n", - " (0): TabularSequenceFeatures(\n", - " (to_merge): ModuleDict(\n", - " (continuous_module): SequentialBlock(\n", - " (0): ContinuousFeatures(\n", - " (filter_features): FilterFeatures()\n", - " (_aggregation): ConcatFeatures()\n", - " )\n", - " (1): SequentialBlock(\n", - " (0): DenseBlock(\n", - " (0): Linear(in_features=2, out_features=64, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " )\n", - " )\n", - " (2): AsTabular()\n", - " )\n", - " (categorical_module): SequenceEmbeddingFeatures(\n", - " (filter_features): FilterFeatures()\n", - " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(507, 64, padding_idx=0)\n", - " (category-list): Embedding(138, 64, padding_idx=0)\n", - " )\n", - " )\n", - " )\n", - " (_aggregation): ConcatFeatures()\n", - " (projection_module): SequentialBlock(\n", - " (0): DenseBlock(\n", - " (0): Linear(in_features=192, out_features=100, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " )\n", - " )\n", - " (_masking): CausalLanguageModeling()\n", - " )\n", - " (1): SequentialBlock(\n", - " (0): DenseBlock(\n", - " (0): Linear(in_features=100, out_features=64, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " )\n", - " )\n", - " (2): TansformerBlock(\n", - " (transformer): XLNetModel(\n", - " (word_embedding): Embedding(1, 64)\n", - " (layer): ModuleList(\n", - " (0): XLNetLayer(\n", - " (rel_attn): XLNetRelativeAttention(\n", - " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (ff): XLNetFeedForward(\n", - " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", - " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", - " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (1): XLNetLayer(\n", - " (rel_attn): XLNetRelativeAttention(\n", - " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (ff): XLNetFeedForward(\n", - " (layer_norm): LayerNorm((64,), eps=0.03, elementwise_affine=True)\n", - " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", - " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " )\n", - " (dropout): Dropout(p=0.3, inplace=False)\n", - " )\n", - " (masking): CausalLanguageModeling()\n", - " )\n", - " )\n", - " (prediction_task_dict): ModuleDict(\n", - " (next-item): NextItemPredictionTask(\n", - " (sequence_summary): SequenceSummary(\n", - " (summary): Identity()\n", - " (activation): Identity()\n", - " (first_dropout): Identity()\n", - " (last_dropout): Identity()\n", - " )\n", - " (metrics): ModuleList(\n", - " (0): NDCGAt()\n", - " (1): RecallAt()\n", - " )\n", - " (loss): NLLLoss()\n", - " (embeddings): SequenceEmbeddingFeatures(\n", - " (filter_features): FilterFeatures()\n", - " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(507, 64, padding_idx=0)\n", - " (category-list): Embedding(138, 64, padding_idx=0)\n", - " )\n", - " )\n", - " 
(item_embedding_table): Embedding(507, 64, padding_idx=0)\n", - " (masking): CausalLanguageModeling()\n", - " (pre): Block(\n", - " (module): NextItemPredictionTask(\n", - " (item_embedding_table): Embedding(507, 64, padding_idx=0)\n", - " (log_softmax): LogSoftmax(dim=-1)\n", - " )\n", - " )\n", - " )\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = model.cuda()\n", - "model.eval()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e1e140ac-afd6-455d-9057-1bbd07116a9b", - "metadata": {}, - "outputs": [], - "source": [ - "model.hf_format = False" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "ef88f601-c7c0-4244-84f3-ee257b579205", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/html": [ - "

\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametagsdtypeis_listis_raggedproperties.int_domain.minproperties.int_domain.max
0age_days-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
1weekday_sin-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
2item_id-list(Tags.ID, Tags.ITEM, Tags.CATEGORICAL, Tags.IT...int64TrueFalse0506
3category-list(Tags.CATEGORICAL, Tags.LIST)int64TrueFalse0137
\n", - "
" - ], - "text/plain": [ - "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 506}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 137}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}]" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.input_schema" - ] - }, - { - "cell_type": "markdown", - "id": "409c17c2-c81d-4f41-9577-bd380ae10921", - "metadata": {}, - "source": [ - "Create a dict of tensors" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "6273c8e5-db62-4cc6-a4f3-945155a463d6", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = Dataset(train_paths[0])\n", - "trainer.train_dataset_or_path = dataset\n", - "loader = trainer.get_train_dataloader()\n", - "train_dict = next(iter(loader))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "43e789a3-2423-44eb-ae5d-0c154557424f", - "metadata": {}, - "outputs": [], - "source": [ - "traced_model = torch.jit.trace(model, train_dict, strict=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "8cdbc288-baf5-4beb-a3ac-5fcb7315125c", - "metadata": {}, - "outputs": [], - "source": [ - "assert isinstance(traced_model, torch.jit.TopLevelTracedModule)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "896f4d80-3703-42af-a48e-2b81e839006d", - "metadata": {}, - "outputs": [], - "source": [ - "assert torch.allclose(\n", - " model(train_dict),\n", - " traced_model(train_dict),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "c6a814aa-4954-404b-bd2d-161ed8066f4e", - "metadata": {}, - "outputs": [], - "source": [ - "input_schema = model.input_schema\n", - "output_schema = model.output_schema" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "757cd0c5-f581-488b-a8de-b8d1188820d6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametagsdtypeis_listis_raggedproperties.int_domain.minproperties.int_domain.max
0age_days-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
1weekday_sin-list(Tags.CONTINUOUS, Tags.LIST)float32TrueFalse00
2item_id-list(Tags.ID, Tags.ITEM, Tags.CATEGORICAL, Tags.IT...int64TrueFalse0506
3category-list(Tags.CATEGORICAL, Tags.LIST)int64TrueFalse0137
\n", - "
" - ], - "text/plain": [ - "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': dtype('float32'), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 506}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 137}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': False}]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "input_schema" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:148: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from merlin.core.dispatch import make_df # noqa\n", - "from merlin.systems.dag import Ensemble # noqa\n", - "from merlin.systems.dag.ops.pytorch import PredictPyTorch # noqa\n", - "from merlin.systems.triton.utils import run_ensemble_on_tritonserver # noqa\n", - "\n", - "torch_op = input_schema.column_names >> PredictPyTorch(\n", - " traced_model, input_schema, output_schema\n", - ")\n", - "\n", - "ensemble = Ensemble(torch_op, input_schema)\n", - "ens_config, node_configs = ensemble.export(str('./models'))" - ] - }, - { - "cell_type": "markdown", - "id": "5faba154-d4b2-4424-a1b2-badd2227e66e", - "metadata": {}, - "source": [ - "Create a dataframe to send as a request. We need a dataset where the list columns are padded to the max sequence lenght that was set in the ETL pipeline." - ] - }, - { - "cell_type": "code", - "execution_count": 66, - "id": "c1ce3a29-5578-41ca-a033-abc4507adfef", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(128, 4)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
age_days-listweekday_sin-listitem_id-listcategory-list
0[0.37403864, 0.42758772, 0.93743354, 0.0, 0.0,...[0.9351001, 0.91299504, 0.9785595, 0.0, 0.0, 0...[30, 24, 200, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...[7, 6, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
1[0.9327483, 0.36575532, 0.13967341, 0.45479113...[0.5046626, 0.17492707, 0.12539314, 0.6640924,...[190, 10, 7, 55, 27, 3, 6, 184, 0, 0, 0, 0, 0,...[36, 3, 2, 11, 6, 4, 2, 36, 0, 0, 0, 0, 0, 0, ...
2[0.57168996, 0.48532194, 0.89944935, 0.2171675...[0.93685514, 0.5638695, 0.76670134, 0.6797855,...[153, 8, 46, 58, 21, 19, 31, 15, 4, 104, 0, 0,...[28, 2, 10, 11, 5, 5, 7, 3, 2, 18, 0, 0, 0, 0,...
3[0.8520663, 0.6690395, 0.92268515, 0.99163777,...[0.58499664, 0.45736608, 0.88926136, 0.9139287...[19, 28, 23, 34, 18, 10, 0, 0, 0, 0, 0, 0, 0, ...[5, 6, 6, 7, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4[0.67542243, 0.65952307, 0.7467189, 0.6136317,...[0.09077961, 0.7920753, 0.35881928, 0.8545563,...[17, 27, 70, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...[5, 6, 14, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
\n", - "
" - ], - "text/plain": [ - " age_days-list \\\n", - "0 [0.37403864, 0.42758772, 0.93743354, 0.0, 0.0,... \n", - "1 [0.9327483, 0.36575532, 0.13967341, 0.45479113... \n", - "2 [0.57168996, 0.48532194, 0.89944935, 0.2171675... \n", - "3 [0.8520663, 0.6690395, 0.92268515, 0.99163777,... \n", - "4 [0.67542243, 0.65952307, 0.7467189, 0.6136317,... \n", - "\n", - " weekday_sin-list \\\n", - "0 [0.9351001, 0.91299504, 0.9785595, 0.0, 0.0, 0... \n", - "1 [0.5046626, 0.17492707, 0.12539314, 0.6640924,... \n", - "2 [0.93685514, 0.5638695, 0.76670134, 0.6797855,... \n", - "3 [0.58499664, 0.45736608, 0.88926136, 0.9139287... \n", - "4 [0.09077961, 0.7920753, 0.35881928, 0.8545563,... \n", - "\n", - " item_id-list \\\n", - "0 [30, 24, 200, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "1 [190, 10, 7, 55, 27, 3, 6, 184, 0, 0, 0, 0, 0,... \n", - "2 [153, 8, 46, 58, 21, 19, 31, 15, 4, 104, 0, 0,... \n", - "3 [19, 28, 23, 34, 18, 10, 0, 0, 0, 0, 0, 0, 0, ... \n", - "4 [17, 27, 70, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "\n", - " category-list \n", - "0 [7, 6, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "1 [36, 3, 2, 11, 6, 4, 2, 36, 0, 0, 0, 0, 0, 0, ... \n", - "2 [28, 2, 10, 11, 5, 5, 7, 3, 2, 18, 0, 0, 0, 0,... \n", - "3 [5, 6, 6, 7, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "4 [5, 6, 14, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... " - ] - }, - "execution_count": 66, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset = Dataset(eval_paths[0])\n", - "# trainer.test_dataset_or_path = dataset\n", - "loader = trainer.get_test_dataloader(dataset)\n", - "test_dict = next(iter(loader))\n", - "\n", - "df_cols = {}\n", - "for name, tensor in train_dict.items():\n", - " if name in input_schema.column_names:\n", - " dtype = input_schema[name].dtype\n", - "\n", - " df_cols[name] = tensor.cpu().numpy().astype(dtype)\n", - " if len(tensor.shape) > 1:\n", - " df_cols[name] = list(df_cols[name])\n", - "\n", - "df = make_df(df_cols)\n", - "print(df.shape)\n", - "df.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "id": "617348d4-3493-4b68-ba9a-da9543147628", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "I1123 22:15:22.928458 2118 pinned_memory_manager.cc:240] Pinned memory pool is created at '0x7f428e000000' with size 268435456\n", - "I1123 22:15:22.928899 2118 cuda_memory_manager.cc:105] CUDA memory pool is created on device 0 with size 67108864\n", - "I1123 22:15:22.931602 2118 model_lifecycle.cc:459] loading: 0_predictpytorch:1\n", - "I1123 22:15:23.299444 2118 libtorch.cc:1983] TRITONBACKEND_Initialize: pytorch\n", - "I1123 22:15:23.299463 2118 libtorch.cc:1993] Triton TRITONBACKEND API version: 1.10\n", - "I1123 22:15:23.299469 2118 libtorch.cc:1999] 'pytorch' TRITONBACKEND API version: 1.10\n", - "I1123 22:15:23.299488 2118 libtorch.cc:2032] TRITONBACKEND_ModelInitialize: 0_predictpytorch (version 1)\n", - "W1123 22:15:23.300039 2118 libtorch.cc:284] skipping model configuration auto-complete for '0_predictpytorch': not supported for pytorch backend\n", - "I1123 22:15:23.300768 2118 libtorch.cc:313] Optimized execution is enabled for model instance '0_predictpytorch'\n", - "I1123 22:15:23.300780 2118 libtorch.cc:332] Cache Cleaning is disabled for model instance '0_predictpytorch'\n", - "I1123 22:15:23.300786 2118 libtorch.cc:349] Inference Mode is enabled for model instance '0_predictpytorch'\n", - "I1123 22:15:23.300790 2118 libtorch.cc:444] NvFuser is not specified for model instance 
'0_predictpytorch'\n", - "I1123 22:15:23.301026 2118 libtorch.cc:2076] TRITONBACKEND_ModelInstanceInitialize: 0_predictpytorch (GPU device 0)\n", - "I1123 22:15:24.229933 2118 model_lifecycle.cc:693] successfully loaded '0_predictpytorch' version 1\n", - "I1123 22:15:24.230204 2118 model_lifecycle.cc:459] loading: ensemble_model:1\n", - "I1123 22:15:24.230490 2118 model_lifecycle.cc:693] successfully loaded 'ensemble_model' version 1\n", - "I1123 22:15:24.230584 2118 server.cc:561] \n", - "+------------------+------+\n", - "| Repository Agent | Path |\n", - "+------------------+------+\n", - "+------------------+------+\n", - "\n", - "I1123 22:15:24.230668 2118 server.cc:588] \n", - "+---------+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "| Backend | Path | Config |\n", - "+---------+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "| pytorch | /opt/tritonserver/backends/pytorch/libtriton_pytorch.so | {\"cmdline\":{\"auto-complete-config\":\"true\",\"min-compute-capability\":\"6.000000\",\"backend-directory\":\"/opt/tritonserver/backends\",\"default-max-batch-size\":\"4\"}} |\n", - "+---------+---------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "\n", - "I1123 22:15:24.230751 2118 server.cc:631] \n", - "+------------------+---------+--------+\n", - "| Model | Version | Status |\n", - "+------------------+---------+--------+\n", - "| 0_predictpytorch | 1 | READY |\n", - "| ensemble_model | 1 | READY |\n", - "+------------------+---------+--------+\n", - "\n", - "I1123 22:15:24.282945 2118 metrics.cc:650] Collecting metrics for GPU 0: Quadro GV100\n", - "I1123 22:15:24.283260 2118 tritonserver.cc:2214] \n", - "+----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "| Option | Value |\n", - "+----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "| server_id | triton |\n", - "| server_version | 2.25.0 |\n", - "| server_extensions | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data statistics trace |\n", - "| model_repository_path[0] | ./models |\n", - "| model_control_mode | MODE_NONE |\n", - "| strict_model_config | 0 |\n", - "| rate_limit | OFF |\n", - "| pinned_memory_pool_byte_size | 268435456 |\n", - "| cuda_memory_pool_byte_size{0} | 67108864 |\n", - "| response_cache_byte_size | 0 |\n", - "| min_supported_compute_capability | 6.0 |\n", - "| strict_readiness | 1 |\n", - "| exit_timeout | 30 |\n", - 
"+----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n", - "\n", - "I1123 22:15:24.285008 2118 grpc_server.cc:4610] Started GRPCInferenceService at localhost:8001\n", - "I1123 22:15:24.285227 2118 http_server.cc:3316] Started HTTPService at 0.0.0.0:8000\n", - "I1123 22:15:24.326845 2118 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Signal (2) received.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "I1123 22:15:26.364164 2118 server.cc:262] Waiting for in-flight requests to complete.\n", - "I1123 22:15:26.364179 2118 server.cc:278] Timeout 30: Found 0 model versions that have in-flight inferences\n", - "I1123 22:15:26.364255 2118 server.cc:293] All models are stopped, unloading models\n", - "I1123 22:15:26.364263 2118 server.cc:300] Timeout 30: Found 2 live models and 0 in-flight non-inference requests\n", - "I1123 22:15:26.364287 2118 model_lifecycle.cc:578] successfully unloaded 'ensemble_model' version 1\n", - "I1123 22:15:26.364592 2118 libtorch.cc:2110] TRITONBACKEND_ModelInstanceFinalize: delete instance state\n", - "I1123 22:15:26.372137 2118 libtorch.cc:2055] TRITONBACKEND_ModelFinalize: delete model state\n", - "I1123 22:15:26.372333 2118 model_lifecycle.cc:578] successfully unloaded '0_predictpytorch' version 1\n", - "I1123 22:15:27.364444 2118 server.cc:300] Timeout 29: Found 0 live models and 0 in-flight non-inference requests\n" - ] - } - ], - "source": [ - "# ===========================================\n", - "# Send request to Triton and check response\n", - "# ===========================================\n", - "response = run_ensemble_on_tritonserver(\n", - " './models', input_schema, df[input_schema.column_names], output_schema.column_names, \"ensemble_model\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 70, - "id": "430555ac-d427-48ac-b93a-da2ea41b86d0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'next-item': array([[-7.9947233, -8.747803 , -3.4068425, ..., -7.914096 , -8.571636 ,\n", - " -7.815837 ],\n", - " [-7.977386 , -8.727583 , -3.3388414, ..., -7.9091067, -8.508778 ,\n", - " -7.7708635],\n", - " [-8.000487 , -8.737406 , -3.4030848, ..., -7.921053 , -8.557445 ,\n", - " -7.798526 ],\n", - " ...,\n", - " [-7.998163 , -8.739789 , -3.3824148, ..., -7.9103565, -8.550226 ,\n", - " -7.8002963],\n", - " [-7.9968286, -8.753717 , -3.3801503, ..., -7.9066863, -8.55961 ,\n", - " -7.794828 ],\n", - " [-8.01243 , -8.753323 , -3.3656597, ..., -7.8982997, -8.546498 ,\n", - " -7.7921886]], dtype=float32)}" - ] - }, - "execution_count": 70, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "markdown", - "id": "d0581195-d496-4536-a93b-3071ed9088ea", - "metadata": {}, - "source": [ - "We return a response for each request in the df. Each row in the `response['next-item']` array corresponds to the logit values per item in the catalog and for the OOV item." 
- ] - }, - { - "cell_type": "code", - "execution_count": 69, - "id": "26bbdcb4-1347-46bd-a3eb-1c140f8bacd6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(128, 507)" - ] - }, - "execution_count": 69, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response['next-item'].shape" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "7b543a88d374ac88bf8df97911b380f671b13649694a5b49eb21e60fd27eb479" - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 071037f51588133db9af963c16eee73aeeb0c268 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 2 Dec 2022 21:05:19 +0000 Subject: [PATCH 4/4] update logo link, fix text --- .../01-ETL-with-NVTabular.ipynb | 2 +- ...ing-session-based-model-torch-backend.ipynb | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index c71f62ebdb..b9e120b78e 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -245,7 +245,7 @@ "id": "139de226", "metadata": {}, "source": [ - "Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers `(0, ..., |C|)`, where `|C|` is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that the `Categorify` op encodes OOVs or nulls to `0` automatically. In our synthetic dataset we do not have any nulls. On the other hand `0` is also used for padding the sequences in input block, thefore, you can set `start_index=1` arg in the Categorify op if you want the encoded null or OOV values to start from `1` instead of `0` because we reserve `0` for padding the sequence features." + "Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers `(0, ..., |C|)`, where `|C|` is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that the `Categorify` op encodes OOVs or nulls to `0` automatically. In our synthetic dataset we do not have any nulls. On the other hand `0` is also used for padding the sequences in input block, therefore, you can set `start_index=1` arg in the Categorify op if you want the encoded null or OOV values to start from `1` instead of `0` because we reserve `0` for padding the sequence features." 
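To make the `start_index=1` behavior described above concrete, here is a minimal sketch, assuming NVTabular is installed; the tiny frame, its values, and the `None` null are made up for illustration, while the column names mirror the ones used in these notebooks:

```python
import pandas as pd
import nvtabular as nvt
from merlin.io import Dataset

# Toy stand-in frame; the None simulates a null interaction that
# Categorify has to encode.
df = pd.DataFrame({"item_id": [10, 10, 42, None], "category": [1, 2, 2, 1]})

# With start_index=1, encoded ids begin at 1 and the null/OOV bucket maps
# to 1 rather than 0, leaving 0 free as the padding value for the
# sequence features.
cats = ["item_id", "category"] >> nvt.ops.Categorify(start_index=1)
workflow = nvt.Workflow(cats)
encoded = workflow.fit_transform(Dataset(df)).to_ddf().compute()
print(encoded)
```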
] },
{
diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb
index 4dd7671b27..9d89edd5a9 100644
--- a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb
+++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb
@@ -28,7 +28,7 @@
"id": "0a2228da",
"metadata": {},
"source": [
- "\n",
+ "\n",
"\n",
"# Serving a Session-based Recommendation model with Torch Backend"
]
},
@@ -38,15 +38,13 @@
"id": "4127de28-a7ce-4ff7-8dcc-1575a70ca7c8",
"metadata": {},
"source": [
- "This notebook is created using the latest stable `merlin-pytorch` container.\n",
+ "This notebook is created using the latest stable [merlin-pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-pytorch/tags) container.\n",
"\n",
- "At this point, when you reach out to this notebook, we expect that you have already executed the `01-ETL-with-NVTabular.ipynb` and `02-session-based-XLNet-with-PyT.ipynb` notebooks, and saved the NVT workflow and the trained model.\n",
+ "By the time you reach this notebook, we expect that you have already executed the `01-ETL-with-NVTabular.ipynb` and `02-session-based-XLNet-with-PyT.ipynb` notebooks, and saved the NVT workflow and the trained session-based model.\n",
"\n",
- "In this notebook, you are going to learn how you can serve a trained Transformer-based PyTorch model on Triton Inference Server (TIS) with Torch backend using [Merlin systems](https://github.com/NVIDIA-Merlin/systems) library. \n",
+ "In this notebook, you are going to learn how you can serve a trained Transformer-based PyTorch model on NVIDIA [Triton Inference Server](https://github.com/triton-inference-server/server) (TIS) with the Torch backend using the [Merlin systems](https://github.com/NVIDIA-Merlin/systems) library. One common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in a high-performance environment like C++. [TorchScript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html) is the recommended model format for scaled inference and deployment. The TIS [PyTorch (LibTorch) backend](https://github.com/triton-inference-server/pytorch_backend) is designed to run TorchScript models using the PyTorch C++ API.\n",
"\n",
- "NVIDIA [Triton Inference Server](https://github.com/triton-inference-server/server) (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.\n",
- "\n",
- "If you would like to learn how to serve a TF4Rec model with Python backend please visit this [example](https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb)."
+ "[Triton Inference Server](https://github.com/triton-inference-server/server) (TIS) simplifies the deployment of AI models at scale in production. TIS provides a cloud and edge inferencing solution optimized for both CPUs and GPUs. It supports a number of different machine learning frameworks, such as TensorFlow and PyTorch."
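Since tracing is the step that turns the trained model into something the LibTorch backend can execute, here is a minimal sketch of the round trip described above; `ToyNextItem`, its sizes, and the example input are hypothetical stand-ins for the real Transformers4Rec model:

```python
import torch


class ToyNextItem(torch.nn.Module):
    """Toy stand-in for a trained session-based next-item model."""

    def __init__(self, n_items: int = 507, d_model: int = 64):
        super().__init__()
        self.embedding = torch.nn.Embedding(n_items, d_model, padding_idx=0)
        self.head = torch.nn.Linear(d_model, n_items)

    def forward(self, item_ids: torch.Tensor) -> torch.Tensor:
        # item_ids: (batch, seq_len), right-padded with 0
        hidden = self.embedding(item_ids).mean(dim=1)
        return self.head(hidden)  # (batch, n_items) logits


model = ToyNextItem().eval()
example_input = torch.randint(1, 507, (4, 20))

# Tracing records the operations executed on the example input and yields a
# TorchScript module that Triton's PyTorch (LibTorch) backend can load.
traced = torch.jit.trace(model, example_input)
assert isinstance(traced, torch.jit.TopLevelTracedModule)
torch.jit.save(traced, "model.pt")
```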
] },
{
@@ -360,7 +358,7 @@
"source": [
"### Trace the model\n",
"\n",
- "One common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in a high performance environment like C++. TorchScript is actually the recommended model format for scaled inference and deployment. We serve the model with the PyTorch backend that is used to execute TorchScript models. All models created in PyTorch using the python API must be traced/scripted to produce a TorchScript model. For tracing the model, we use [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) api that takes the model as a Python function or torch.nn.Module, and an example input that will be passed to the function while tracing."
+ "We serve the model with the PyTorch backend, which executes TorchScript models. All models created in PyTorch using the Python API must be traced/scripted to produce a TorchScript model. For tracing the model, we use the [torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) API, which takes the model as a Python function or torch.nn.Module, and an example input that will be passed to the function while tracing."
]
},
{
@@ -972,7 +970,7 @@
"id": "d0581195-d496-4536-a93b-3071ed9088ea",
"metadata": {},
"source": [
- "We return a response for each request in the df. Each row in the `response['next-item']` array corresponds to the logit values per item in the catalog, and one logit score correspondig to the null, OOV and padded items. The first score of each array in each row corresponds to the score for the padded item, OOV or null item. Note that we dont have OOV or null items in our syntheticall generated datasets."
+ "We return a response for each request in the df. Each row in the `response['next-item']` array corresponds to the logit values per item in the catalog, plus one logit score corresponding to the null, OOV, and padded items. The first score of each array in each row is the score for the padded, OOV, or null item. Note that we don't have OOV or null items in our synthetically generated dataset."
]
},
{
@@ -980,7 +978,7 @@
"id": "5ab16c64-4371-4696-b2d6-3bff66e67fdb",
"metadata": {},
"source": [
- "This is the end of this suit of examples. You successfully performed feature engineering with NVTabular trained transformer architecture based session-based recommendation models with Transformers4Rec deployed a trained model to Triton Inference Server, sent request and got responses from the server."
+ "This is the end of this suite of examples. You successfully performed feature engineering with NVTabular, trained a transformer-based session-based recommendation model with Transformers4Rec, deployed the trained model to Triton Inference Server with the Torch backend, sent requests, and received responses from the server. If you would like to learn how to serve a Transformers4Rec model with the Python backend, please visit this [example](https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb)."
]
}
],
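Beyond the `send_triton_request` helper used in the notebooks, the deployed ensemble can also be queried with the plain Triton client library. The following is a hedged sketch, assuming Triton is running locally with the exported `ensemble_model`; the batch size, padded sequence length, and random feature values are placeholders, while the tensor names and dtypes follow the input schema shown earlier:

```python
import numpy as np
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype

client = httpclient.InferenceServerClient(url="localhost:8000")

batch, seq_len = 2, 20  # assumed padded sequence length from the ETL step
features = {
    "item_id-list": np.random.randint(1, 507, (batch, seq_len)).astype(np.int64),
    "category-list": np.random.randint(1, 138, (batch, seq_len)).astype(np.int64),
    "age_days-list": np.random.rand(batch, seq_len).astype(np.float32),
    "weekday_sin-list": np.random.rand(batch, seq_len).astype(np.float32),
}

# Wrap each padded feature in an InferInput carrying the matching Triton dtype.
inputs = []
for name, array in features.items():
    infer_input = httpclient.InferInput(name, list(array.shape), np_to_triton_dtype(array.dtype))
    infer_input.set_data_from_numpy(array)
    inputs.append(infer_input)

result = client.infer(
    "ensemble_model",
    inputs,
    outputs=[httpclient.InferRequestedOutput("next-item")],
)
print(result.as_numpy("next-item").shape)  # one row of logits per request in the batch
```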