From acfb41fb6ed203572d39c9761e3cde6924b53749 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Apr 2023 16:07:00 +0000 Subject: [PATCH 1/2] serving NVT and TF4Rec model together --- .../01-ETL-with-NVTabular.ipynb | 237 ++++---- .../02-session-based-XLNet-with-PyT.ipynb | 220 ++++--- ...ng-session-based-model-torch-backend.ipynb | 549 +++++++++--------- .../test_getting_started_session_based.py | 31 +- 4 files changed, 527 insertions(+), 510 deletions(-) diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index b43116656d..4185cb5b44 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -64,15 +64,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n" + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", "import glob\n", "\n", "import cudf\n", @@ -188,48 +189,48 @@ " \n", " \n", " 0\n", - " 84190\n", - " 16\n", - " 4\n", - " 0.223325\n", - " 0.121436\n", + " 75363\n", + " 18\n", " 4\n", + " 0.386272\n", + " 0.817845\n", + " 1\n", " \n", " \n", " 1\n", - " 74994\n", - " 18\n", - " 4\n", - " 0.055975\n", - " 0.285401\n", - " 4\n", + " 72281\n", + " 5\n", + " 1\n", + " 0.436883\n", + " 0.407265\n", + " 6\n", " \n", " \n", " 2\n", - " 80080\n", - " 29\n", - " 6\n", - " 0.301125\n", - " 0.274779\n", - " 3\n", + " 75383\n", + " 9\n", + " 2\n", + " 0.629215\n", + " 0.608197\n", + " 2\n", " \n", " \n", " 3\n", - " 87793\n", - " 4\n", + " 84734\n", + " 10\n", + " 2\n", + " 0.355827\n", + " 0.620883\n", " 1\n", - " 0.110288\n", - " 0.653482\n", - " 7\n", " \n", " \n", " 4\n", - " 78898\n", - " 18\n", - " 4\n", - " 0.586547\n", - " 0.859649\n", - " 6\n", + " 81038\n", + " 42\n", + " 8\n", + " 0.265468\n", + " 0.830717\n", + " 7\n", " \n", " \n", "\n", @@ -237,11 +238,11 @@ ], "text/plain": [ " session_id item_id category age_days weekday_sin day\n", - "0 84190 16 4 0.223325 0.121436 4\n", - "1 74994 18 4 0.055975 0.285401 4\n", - "2 80080 29 6 0.301125 0.274779 3\n", - "3 87793 4 1 0.110288 0.653482 7\n", - "4 78898 18 4 0.586547 0.859649 6" + "0 75363 18 4 0.386272 0.817845 1\n", + "1 72281 5 1 0.436883 0.407265 6\n", + "2 75383 9 2 0.629215 0.608197 2\n", + "3 84734 10 2 0.355827 0.620883 1\n", + "4 81038 42 8 0.265468 0.830717 7" ] }, "execution_count": 6, @@ -352,7 +353,7 @@ "dataset = nvt.Dataset(df)\n", "\n", "# Generate statistics for the features and export parquet files\n", - "# this step will generate the schema.pbtxt file\n", + "# this step will generate the schema file\n", "workflow.fit_transform(dataset).to_parquet(os.path.join(INPUT_DATA_DIR, 
\"processed_nvt\"))" ] }, @@ -366,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "78e42cbf-edd6-44af-af23-c026edb578c4", "metadata": {}, "outputs": [ @@ -454,7 +455,7 @@ " \n", " 2\n", " item_id-list\n", - " (Tags.ID, Tags.CATEGORICAL, Tags.ITEM, Tags.LI...\n", + " (Tags.ITEM_ID, Tags.LIST, Tags.ID, Tags.CATEGO...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -464,9 +465,9 @@ " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 502.0\n", + " 496.0\n", " item_id\n", - " 503.0\n", + " 497.0\n", " 52.0\n", " 20.0\n", " 20.0\n", @@ -474,7 +475,7 @@ " \n", " 3\n", " category-list\n", - " (Tags.CATEGORICAL, Tags.LIST)\n", + " (Tags.LIST, Tags.CATEGORICAL)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -484,17 +485,17 @@ " 0.0\n", " .//categories/unique.category.parquet\n", " 0.0\n", - " 184.0\n", + " 142.0\n", " category\n", - " 185.0\n", - " 30.0\n", + " 143.0\n", + " 26.0\n", " 20.0\n", " 20.0\n", " \n", " \n", " 4\n", " age_days-list\n", - " (Tags.CONTINUOUS, Tags.LIST)\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", @@ -514,7 +515,7 @@ " \n", " 5\n", " weekday_sin-list\n", - " (Tags.CONTINUOUS, Tags.LIST)\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", @@ -536,10 +537,10 @@ "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 
'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 502, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 503, 'dimension': 52}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 184, 'name': 'category'}, 'embedding_sizes': {'cardinality': 185, 'dimension': 30}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', 
element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 497, 'dimension': 52}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'embedding_sizes': {'cardinality': 143, 'dimension': 26}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -558,7 +559,7 @@ }, { "cell_type": 
"code", - "execution_count": 10, + "execution_count": 9, "id": "9f498dce-69eb-4f88-8ddd-8629558825df", "metadata": {}, "outputs": [], @@ -584,7 +585,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "12d3e59b", "metadata": {}, "outputs": [], @@ -594,7 +595,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "id": "603fb27a-0c64-43eb-be79-42213944990b", "metadata": {}, "outputs": [], @@ -605,7 +606,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "id": "c537a248-059e-4db9-8b62-9681175f0193", "metadata": { "tags": [] @@ -616,24 +617,24 @@ "output_type": "stream", "text": [ " session_id day-first item_id-list \\\n", - "0 70000 5 [3, 10, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", - "1 70001 2 [21, 18, 61, 70, 41, 15, 47, 2, 54, 33, 16, 6,... \n", - "2 70002 4 [92, 190, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "0 70000 7 [19, 1, 86, 8, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "1 70002 7 [9, 110, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "2 70003 6 [8, 229, 1, 21, 2, 17, 5, 7, 4, 0, 0, 0, 0, 0,... \n", "\n", " category-list \\\n", - "0 [1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "1 [5, 5, 17, 18, 10, 6, 13, 1, 16, 4, 6, 2, 0, 0... \n", - "2 [25, 55, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "0 [4, 1, 17, 3, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "1 [2, 21, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2 [3, 48, 1, 5, 2, 4, 1, 1, 2, 0, 0, 0, 0, 0, 0,... \n", "\n", " age_days-list \\\n", - "0 [0.6733318, 0.36617512, 0.56633216, 0.0, 0.0, ... \n", - "1 [0.2618201, 0.47376513, 0.4943239, 0.47170892,... \n", - "2 [0.7606575, 0.68139946, 0.7614335, 0.0, 0.0, 0... \n", + "0 [0.566933, 0.11411724, 0.6606563, 0.4054278, 0... \n", + "1 [0.88592035, 0.049719226, 0.67839175, 0.0, 0.0... \n", + "2 [0.7854072, 0.84745556, 0.3776744, 0.3021382, ... \n", "\n", " weekday_sin-list \n", - "0 [0.007507112, 0.7772118, 0.89882964, 0.0, 0.0,... 
\n", - "1 [0.9721299, 0.9868918, 0.58013266, 0.54288113,... \n", - "2 [0.13897721, 0.026063459, 0.5617346, 0.0, 0.0,... \n" + "0 [0.8184131, 0.44557166, 0.48090392, 0.35810795... \n", + "1 [0.04808665, 0.33483326, 0.65433854, 0.0, 0.0,... \n", + "2 [0.7141748, 0.40484497, 0.37434393, 0.47856098... \n" ] } ], @@ -643,7 +644,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 13, "id": "6c67a92b", "metadata": {}, "outputs": [ @@ -651,7 +652,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 17.49it/s]\n" + "Creating time-based splits: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 18.53it/s]\n" ] } ], @@ -669,12 +670,12 @@ "id": "0b72337b", "metadata": {}, "source": [ - "## Checking the preprocessed outputs" + "## Check out the preprocessed outputs" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 14, "id": "dd04ec82", "metadata": {}, "outputs": [], @@ -684,7 +685,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "id": "8e5e6358", "metadata": {}, "outputs": [ @@ -719,43 +720,43 @@ " \n", " \n", " 0\n", - " 70004\n", - " [55, 5, 25, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...\n", - " [16, 1, 7, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", - " [0.092122145, 0.95085436, 0.43779024, 0.880414...\n", - " [0.84441566, 0.566204, 0.74057275, 0.42418396,...\n", + " 70020\n", + " [11, 16, 52, 4, 160, 11, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [3, 4, 10, 2, 32, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0...\n", + " [0.30004495, 0.8680668, 0.051910266, 0.0361004...\n", + " [0.9554083, 0.5393047, 0.52827483, 0.018015692...\n", " \n", " \n", " 1\n", - " 70007\n", - 
" [30, 7, 10, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [9, 4, 3, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [0.87233526, 0.6098228, 0.80743825, 0.89595354...\n", - " [0.7175466, 0.12882064, 0.7239913, 0.53173864,...\n", + " 70022\n", + " [28, 1, 18, 76, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [6, 1, 4, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.28977808, 0.89265364, 0.7273247, 0.9952868,...\n", + " [0.6657577, 0.51649666, 0.34749013, 0.6105231,...\n", " \n", " \n", " 2\n", " 70024\n", - " [4, 91, 12, 1, 18, 1, 1, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [2, 24, 4, 1, 5, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,...\n", - " [0.028581467, 0.56572497, 0.43579134, 0.434222...\n", - " [0.84453183, 0.7441511, 0.34946096, 0.35230267...\n", + " [167, 113, 61, 22, 5, 88, 0, 0, 0, 0, 0, 0, 0,...\n", + " [33, 22, 12, 5, 1, 17, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.44703692, 0.9245853, 0.20643276, 0.02649925...\n", + " [0.813814, 0.06711837, 2.8271123e-05, 0.985612...\n", " \n", " \n", " 4\n", - " 70051\n", - " [1, 14, 201, 14, 93, 42, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [1, 3, 54, 3, 28, 10, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [0.5466232, 0.24969716, 0.17476603, 0.39066973...\n", - " [0.9378631, 0.5088918, 0.36846048, 0.71432567,...\n", + " 70032\n", + " [31, 5, 23, 14, 42, 7, 57, 5, 17, 0, 0, 0, 0, ...\n", + " [7, 1, 5, 3, 8, 1, 11, 1, 4, 0, 0, 0, 0, 0, 0,...\n", + " [0.8306253, 0.25969487, 0.4956863, 0.56623936,...\n", + " [0.17520918, 0.8598712, 0.83949673, 0.37930673...\n", " \n", " \n", " 5\n", - " 70059\n", - " [22, 10, 8, 14, 24, 15, 12, 133, 0, 0, 0, 0, 0...\n", - " [5, 3, 2, 3, 7, 6, 4, 35, 0, 0, 0, 0, 0, 0, 0,...\n", - " [0.6290009, 0.77528137, 0.68838376, 0.5259807,...\n", - " [0.10275719, 0.42096895, 0.2630285, 0.31103444...\n", + " 70060\n", + " [5, 49, 39, 62, 4, 18, 8, 40, 0, 0, 0, 0, 0, 0...\n", + " [1, 10, 8, 13, 2, 4, 3, 8, 0, 0, 0, 0, 0, 0, 0...\n", + " [0.8982996, 0.052276224, 0.7122792, 0.35780925...\n", + " [0.16947733, 0.32332653, 0.94814104, 
0.1960088...\n", " \n", " \n", "\n", @@ -763,35 +764,35 @@ ], "text/plain": [ " session_id item_id-list \\\n", - "0 70004 [55, 5, 25, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", - "1 70007 [30, 7, 10, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "2 70024 [4, 91, 12, 1, 18, 1, 1, 0, 0, 0, 0, 0, 0, 0, ... \n", - "4 70051 [1, 14, 201, 14, 93, 42, 0, 0, 0, 0, 0, 0, 0, ... \n", - "5 70059 [22, 10, 8, 14, 24, 15, 12, 133, 0, 0, 0, 0, 0... \n", + "0 70020 [11, 16, 52, 4, 160, 11, 0, 0, 0, 0, 0, 0, 0, ... \n", + "1 70022 [28, 1, 18, 76, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "2 70024 [167, 113, 61, 22, 5, 88, 0, 0, 0, 0, 0, 0, 0,... \n", + "4 70032 [31, 5, 23, 14, 42, 7, 57, 5, 17, 0, 0, 0, 0, ... \n", + "5 70060 [5, 49, 39, 62, 4, 18, 8, 40, 0, 0, 0, 0, 0, 0... \n", "\n", " category-list \\\n", - "0 [16, 1, 7, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "1 [9, 4, 3, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "2 [2, 24, 4, 1, 5, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "4 [1, 3, 54, 3, 28, 10, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "5 [5, 3, 2, 3, 7, 6, 4, 35, 0, 0, 0, 0, 0, 0, 0,... \n", + "0 [3, 4, 10, 2, 32, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", + "1 [6, 1, 4, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2 [33, 22, 12, 5, 1, 17, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "4 [7, 1, 5, 3, 8, 1, 11, 1, 4, 0, 0, 0, 0, 0, 0,... \n", + "5 [1, 10, 8, 13, 2, 4, 3, 8, 0, 0, 0, 0, 0, 0, 0... \n", "\n", " age_days-list \\\n", - "0 [0.092122145, 0.95085436, 0.43779024, 0.880414... \n", - "1 [0.87233526, 0.6098228, 0.80743825, 0.89595354... \n", - "2 [0.028581467, 0.56572497, 0.43579134, 0.434222... \n", - "4 [0.5466232, 0.24969716, 0.17476603, 0.39066973... \n", - "5 [0.6290009, 0.77528137, 0.68838376, 0.5259807,... \n", + "0 [0.30004495, 0.8680668, 0.051910266, 0.0361004... \n", + "1 [0.28977808, 0.89265364, 0.7273247, 0.9952868,... \n", + "2 [0.44703692, 0.9245853, 0.20643276, 0.02649925... \n", + "4 [0.8306253, 0.25969487, 0.4956863, 0.56623936,... 
\n", + "5 [0.8982996, 0.052276224, 0.7122792, 0.35780925... \n", "\n", " weekday_sin-list \n", - "0 [0.84441566, 0.566204, 0.74057275, 0.42418396,... \n", - "1 [0.7175466, 0.12882064, 0.7239913, 0.53173864,... \n", - "2 [0.84453183, 0.7441511, 0.34946096, 0.35230267... \n", - "4 [0.9378631, 0.5088918, 0.36846048, 0.71432567,... \n", - "5 [0.10275719, 0.42096895, 0.2630285, 0.31103444... " + "0 [0.9554083, 0.5393047, 0.52827483, 0.018015692... \n", + "1 [0.6657577, 0.51649666, 0.34749013, 0.6105231,... \n", + "2 [0.813814, 0.06711837, 2.8271123e-05, 0.985612... \n", + "4 [0.17520918, 0.8598712, 0.83949673, 0.37930673... \n", + "5 [0.16947733, 0.32332653, 0.94814104, 0.1960088... " ] }, - "execution_count": 20, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -803,17 +804,17 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "id": "a687f998-8905-42a4-bb92-d1f5244860b6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "594" + "430" ] }, - "execution_count": 21, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb b/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb index 3be757a32c..4a144e04f4 100644 --- a/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb +++ b/examples/getting-started-session-based/02-session-based-XLNet-with-PyT.ipynb @@ -271,7 +271,7 @@ " \n", " 0\n", " item_id-list\n", - " (Tags.ITEM_ID, Tags.ID, Tags.CATEGORICAL, Tags...\n", + " (Tags.ITEM_ID, Tags.ID, Tags.LIST, Tags.CATEGO...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -280,10 +280,10 @@ " 0.0\n", " 0.0\n", " .//categories/unique.item_id.parquet\n", - " 481.0\n", - " 51.0\n", + " 497.0\n", + " 52.0\n", " 0.0\n", - " 480.0\n", + " 496.0\n", " item_id\n", " 20\n", " 20\n", @@ -291,7 +291,7 @@ " \n", " 1\n", " category-list\n", 
- " (Tags.LIST, Tags.CATEGORICAL)\n", + " (Tags.CATEGORICAL, Tags.LIST)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -300,10 +300,10 @@ " 0.0\n", " 0.0\n", " .//categories/unique.category.parquet\n", - " 145.0\n", + " 143.0\n", " 26.0\n", " 0.0\n", - " 144.0\n", + " 142.0\n", " category\n", " 20\n", " 20\n", @@ -311,7 +311,7 @@ " \n", " 2\n", " weekday_sin-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", @@ -331,7 +331,7 @@ " \n", " 3\n", " age_days-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", @@ -353,7 +353,7 @@ "" ], "text/plain": [ - "[{'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 481.0, 'dimension': 51.0}, 'domain': {'min': 0, 'max': 480, 'name': 'item_id'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 145.0, 'dimension': 26.0}, 'domain': {'min': 0, 'max': 144, 'name': 'category'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': 
{'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 497.0, 'dimension': 52.0}, 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 143.0, 'dimension': 26.0}, 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 
'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0.0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, "execution_count": 6, @@ -657,7 +657,7 @@ "
\n", " \n", " \n", - " [65/65 00:01, Epoch 5/5]\n", + " [65/65 00:02, Epoch 5/5]\n", "
\n", " \n", " \n", @@ -669,7 +669,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
505.8123005.772000

" @@ -706,7 +706,7 @@ "

\n", " \n", " \n", - " [6/6 00:16]\n", + " [6/6 00:20]\n", "
\n", " " ], @@ -726,14 +726,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 5.045756816864014\n", - " eval_/next-item/ndcg_at_20 = 0.1720433533191681\n", - " eval_/next-item/ndcg_at_40 = 0.2304922342300415\n", - " eval_/next-item/recall_at_20 = 0.4479166865348816\n", - " eval_/next-item/recall_at_40 = 0.734375\n", - " eval_runtime = 0.1534\n", - " eval_samples_per_second = 1251.373\n", - " eval_steps_per_second = 39.105\n", + " eval_/loss = 5.125837326049805\n", + " eval_/next-item/ndcg_at_20 = 0.15405046939849854\n", + " eval_/next-item/ndcg_at_40 = 0.19994235038757324\n", + " eval_/next-item/recall_at_20 = 0.421875\n", + " eval_/next-item/recall_at_40 = 0.6458333730697632\n", + " eval_runtime = 0.1725\n", + " eval_samples_per_second = 1113.287\n", + " eval_steps_per_second = 34.79\n", "['/workspace/data/sessions_by_day/2/train.parquet']\n", "********************\n", "Launch training for day 2 are:\n", @@ -773,7 +773,7 @@ " \n", " \n", " 50\n", - " 4.857400\n", + " 4.879700\n", " \n", " \n", "

" @@ -800,20 +800,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "finished\n", - "********************\n", - "Eval results for day 3 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.47845983505249\n", - " eval_/next-item/ndcg_at_20 = 0.20010951161384583\n", - " eval_/next-item/ndcg_at_40 = 0.2491903007030487\n", - " eval_/next-item/recall_at_20 = 0.5625\n", - " eval_/next-item/recall_at_40 = 0.8020833730697632\n", - " eval_runtime = 0.159\n", - " eval_samples_per_second = 1207.847\n", - " eval_steps_per_second = 37.745\n" + "finished\n" ] }, { @@ -833,6 +820,19 @@ "name": "stdout", "output_type": "stream", "text": [ + "********************\n", + "Eval results for day 3 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.680967807769775\n", + " eval_/next-item/ndcg_at_20 = 0.17892448604106903\n", + " eval_/next-item/ndcg_at_40 = 0.22572219371795654\n", + " eval_/next-item/recall_at_20 = 0.484375\n", + " eval_/next-item/recall_at_40 = 0.7135416865348816\n", + " eval_runtime = 0.1745\n", + " eval_samples_per_second = 1100.148\n", + " eval_steps_per_second = 34.38\n", "['/workspace/data/sessions_by_day/3/train.parquet']\n", "********************\n", "Launch training for day 3 are:\n", @@ -847,7 +847,7 @@ "

\n", " \n", " \n", - " [65/65 00:02, Epoch 5/5]\n", + " [65/65 00:03, Epoch 5/5]\n", "
\n", " \n", " \n", @@ -859,7 +859,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
504.5826004.581400

" @@ -892,14 +892,19 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.428158283233643\n", - " eval_/next-item/ndcg_at_20 = 0.19852674007415771\n", - " eval_/next-item/ndcg_at_40 = 0.257255494594574\n", - " eval_/next-item/recall_at_20 = 0.4895833432674408\n", - " eval_/next-item/recall_at_40 = 0.7760416865348816\n", - " eval_runtime = 0.1462\n", - " eval_samples_per_second = 1313.302\n", - " eval_steps_per_second = 41.041\n" + " eval_/loss = 4.464942455291748\n", + " eval_/next-item/ndcg_at_20 = 0.2020609825849533\n", + " eval_/next-item/ndcg_at_40 = 0.24695619940757751\n", + " eval_/next-item/recall_at_20 = 0.5260416865348816\n", + " eval_/next-item/recall_at_40 = 0.7447916865348816\n", + " eval_runtime = 0.165\n", + " eval_samples_per_second = 1163.669\n", + " eval_steps_per_second = 36.365\n", + "['/workspace/data/sessions_by_day/4/train.parquet']\n", + "********************\n", + "Launch training for day 4 are:\n", + "********************\n", + "\n" ] }, { @@ -915,17 +920,6 @@ " Total optimization steps = 65\n" ] }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['/workspace/data/sessions_by_day/4/train.parquet']\n", - "********************\n", - "Launch training for day 4 are:\n", - "********************\n", - "\n" - ] - }, { "data": { "text/html": [ @@ -945,7 +939,7 @@ " \n", " \n", " 50\n", - " 4.496400\n", + " 4.519200\n", " \n", " \n", "

" @@ -972,20 +966,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "finished\n", - "********************\n", - "Eval results for day 5 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.330148696899414\n", - " eval_/next-item/ndcg_at_20 = 0.22238120436668396\n", - " eval_/next-item/ndcg_at_40 = 0.2671073079109192\n", - " eval_/next-item/recall_at_20 = 0.5885416865348816\n", - " eval_/next-item/recall_at_40 = 0.8072916865348816\n", - " eval_runtime = 0.1497\n", - " eval_samples_per_second = 1282.976\n", - " eval_steps_per_second = 40.093\n" + "finished\n" ] }, { @@ -1005,6 +986,19 @@ "name": "stdout", "output_type": "stream", "text": [ + "********************\n", + "Eval results for day 5 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.440032958984375\n", + " eval_/next-item/ndcg_at_20 = 0.20033857226371765\n", + " eval_/next-item/ndcg_at_40 = 0.24777457118034363\n", + " eval_/next-item/recall_at_20 = 0.5416666865348816\n", + " eval_/next-item/recall_at_40 = 0.7760416865348816\n", + " eval_runtime = 0.1673\n", + " eval_samples_per_second = 1147.626\n", + " eval_steps_per_second = 35.863\n", "['/workspace/data/sessions_by_day/5/train.parquet']\n", "********************\n", "Launch training for day 5 are:\n", @@ -1031,7 +1025,7 @@ " \n", " \n", " 50\n", - " 4.534100\n", + " 4.479100\n", " \n", " \n", "

" @@ -1058,20 +1052,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "finished\n", - "********************\n", - "Eval results for day 6 are:\t\n", - "\n", - "********************\n", - "\n", - " eval_/loss = 4.478173732757568\n", - " eval_/next-item/ndcg_at_20 = 0.16448986530303955\n", - " eval_/next-item/ndcg_at_40 = 0.22266244888305664\n", - " eval_/next-item/recall_at_20 = 0.4739583432674408\n", - " eval_/next-item/recall_at_40 = 0.7552083730697632\n", - " eval_runtime = 0.1538\n", - " eval_samples_per_second = 1248.547\n", - " eval_steps_per_second = 39.017\n" + "finished\n" ] }, { @@ -1091,6 +1072,19 @@ "name": "stdout", "output_type": "stream", "text": [ + "********************\n", + "Eval results for day 6 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 4.369968891143799\n", + " eval_/next-item/ndcg_at_20 = 0.18754208087921143\n", + " eval_/next-item/ndcg_at_40 = 0.23825709521770477\n", + " eval_/next-item/recall_at_20 = 0.5416666865348816\n", + " eval_/next-item/recall_at_40 = 0.7864583730697632\n", + " eval_runtime = 0.1616\n", + " eval_samples_per_second = 1188.118\n", + " eval_steps_per_second = 37.129\n", "['/workspace/data/sessions_by_day/6/train.parquet']\n", "********************\n", "Launch training for day 6 are:\n", @@ -1105,7 +1099,7 @@ "

\n", " \n", " \n", - " [65/65 00:01, Epoch 5/5]\n", + " [65/65 00:02, Epoch 5/5]\n", "
\n", " \n", " \n", @@ -1117,7 +1111,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
504.4833004.459600

" @@ -1150,14 +1144,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.236471652984619\n", - " eval_/next-item/ndcg_at_20 = 0.21075992286205292\n", - " eval_/next-item/ndcg_at_40 = 0.2642228603363037\n", - " eval_/next-item/recall_at_20 = 0.5572916865348816\n", - " eval_/next-item/recall_at_40 = 0.8125\n", - " eval_runtime = 0.1599\n", - " eval_samples_per_second = 1200.538\n", - " eval_steps_per_second = 37.517\n" + " eval_/loss = 4.5033793449401855\n", + " eval_/next-item/ndcg_at_20 = 0.1925697922706604\n", + " eval_/next-item/ndcg_at_40 = 0.23315435647964478\n", + " eval_/next-item/recall_at_20 = 0.53125\n", + " eval_/next-item/recall_at_40 = 0.7291666865348816\n", + " eval_runtime = 0.1677\n", + " eval_samples_per_second = 1144.864\n", + " eval_steps_per_second = 35.777\n" ] }, { @@ -1191,7 +1185,7 @@ "

\n", " \n", " \n", - " [65/65 00:01, Epoch 5/5]\n", + " [65/65 00:02, Epoch 5/5]\n", "
\n", " \n", " \n", @@ -1203,7 +1197,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
504.4571004.455300

" @@ -1236,14 +1230,14 @@ "\n", "********************\n", "\n", - " eval_/loss = 4.431709289550781\n", - " eval_/next-item/ndcg_at_20 = 0.19320915639400482\n", - " eval_/next-item/ndcg_at_40 = 0.24811463057994843\n", + " eval_/loss = 4.4695820808410645\n", + " eval_/next-item/ndcg_at_20 = 0.1891615390777588\n", + " eval_/next-item/ndcg_at_40 = 0.24406516551971436\n", " eval_/next-item/recall_at_20 = 0.5104166865348816\n", - " eval_/next-item/recall_at_40 = 0.78125\n", - " eval_runtime = 0.1487\n", - " eval_samples_per_second = 1291.383\n", - " eval_steps_per_second = 40.356\n" + " eval_/next-item/recall_at_40 = 0.7760416865348816\n", + " eval_runtime = 0.1806\n", + " eval_samples_per_second = 1063.043\n", + " eval_steps_per_second = 33.22\n" ] } ], @@ -1308,14 +1302,14 @@ "name": "stdout", "output_type": "stream", "text": [ - " eval_/loss = 4.431709289550781\n", - " eval_/next-item/ndcg_at_20 = 0.19320915639400482\n", - " eval_/next-item/ndcg_at_40 = 0.24811463057994843\n", + " eval_/loss = 4.4695820808410645\n", + " eval_/next-item/ndcg_at_20 = 0.1891615390777588\n", + " eval_/next-item/ndcg_at_40 = 0.24406516551971436\n", " eval_/next-item/recall_at_20 = 0.5104166865348816\n", - " eval_/next-item/recall_at_40 = 0.78125\n", - " eval_runtime = 0.1607\n", - " eval_samples_per_second = 1194.43\n", - " eval_steps_per_second = 37.326\n" + " eval_/next-item/recall_at_40 = 0.7760416865348816\n", + " eval_runtime = 0.1776\n", + " eval_samples_per_second = 1081.009\n", + " eval_steps_per_second = 33.782\n" ] } ], diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb index 8f95ef634f..42cc365de2 100644 --- a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb +++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb @@ -81,14 +81,17 @@ "\n", "import cudf\n", 
"import glob\n", + "import numpy as np\n", + "import pandas as pd\n", "import torch \n", "\n", "from transformers4rec import torch as tr\n", "from merlin.io import Dataset\n", "\n", - "from merlin.core.dispatch import make_df # noqa\n", - "from merlin.systems.dag import Ensemble # noqa\n", - "from merlin.systems.dag.ops.pytorch import PredictPyTorch # noqa" + "from merlin.core.dispatch import make_df \n", + "from merlin.systems.dag import Ensemble \n", + "from merlin.systems.dag.ops.pytorch import PredictPyTorch \n", + "from merlin.systems.dag.ops.workflow import TransformWorkflow " ] }, { @@ -212,8 +215,8 @@ " (categorical_module): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(481, 64, padding_idx=0)\n", - " (category-list): Embedding(145, 64, padding_idx=0)\n", + " (item_id-list): Embedding(497, 64, padding_idx=0)\n", + " (category-list): Embedding(143, 64, padding_idx=0)\n", " )\n", " )\n", " )\n", @@ -284,15 +287,15 @@ " (embeddings): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(481, 64, padding_idx=0)\n", - " (category-list): Embedding(145, 64, padding_idx=0)\n", + " (item_id-list): Embedding(497, 64, padding_idx=0)\n", + " (category-list): Embedding(143, 64, padding_idx=0)\n", " )\n", " )\n", - " (item_embedding_table): Embedding(481, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(497, 64, padding_idx=0)\n", " (masking): MaskedLanguageModeling()\n", " (pre): Block(\n", " (module): NextItemPredictionTask(\n", - " (item_embedding_table): Embedding(481, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(497, 64, padding_idx=0)\n", " (log_softmax): LogSoftmax(dim=-1)\n", " )\n", " )\n", @@ -397,13 +400,13 @@ { "data": { "text/plain": [ - "tensor([[ 7, 7, 41, ..., 0, 0, 0],\n", - " [31, 11, 88, ..., 0, 0, 0],\n", - " [17, 31, 6, ..., 0, 0, 0],\n", + 
"tensor([[ 11, 16, 52, ..., 0, 0, 0],\n", + " [ 28, 1, 18, ..., 0, 0, 0],\n", + " [167, 113, 61, ..., 0, 0, 0],\n", " ...,\n", - " [ 2, 35, 0, ..., 0, 0, 0],\n", - " [ 6, 12, 20, ..., 0, 0, 0],\n", - " [31, 10, 11, ..., 0, 0, 0]], device='cuda:0')" + " [ 3, 58, 61, ..., 0, 0, 0],\n", + " [ 1, 21, 0, ..., 0, 0, 0],\n", + " [ 4, 15, 41, ..., 0, 0, 0]], device='cuda:0')" ] }, "execution_count": 10, @@ -556,7 +559,7 @@ " \n", " 2\n", " item_id-list\n", - " (Tags.ID, Tags.ITEM_ID, Tags.ITEM, Tags.CATEGO...\n", + " (Tags.LIST, Tags.ID, Tags.CATEGORICAL, Tags.IT...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -567,16 +570,16 @@ " 0.0\n", " 0.0\n", " .//categories/unique.item_id.parquet\n", - " 481.0\n", - " 51.0\n", + " 497.0\n", + " 52.0\n", " 0.0\n", - " 480.0\n", + " 496.0\n", " item_id\n", " \n", " \n", " 3\n", " category-list\n", - " (Tags.CATEGORICAL, Tags.LIST)\n", + " (Tags.LIST, Tags.CATEGORICAL)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -587,10 +590,10 @@ " 0.0\n", " 0.0\n", " .//categories/unique.category.parquet\n", - " 145.0\n", + " 143.0\n", " 26.0\n", " 0.0\n", - " 144.0\n", + " 142.0\n", " category\n", " \n", " \n", @@ -598,7 +601,7 @@ "" ], "text/plain": [ - "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 
'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 481.0, 'dimension': 51.0}, 'domain': {'min': 0, 'max': 480, 'name': 'item_id'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 145.0, 'dimension': 26.0}, 'domain': {'min': 0, 'max': 144, 'name': 'category'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 497.0, 'dimension': 52.0}, 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', 
element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 143.0, 'dimension': 26.0}, 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, "execution_count": 14, @@ -635,49 +638,55 @@ }, { "cell_type": "markdown", - "id": "e3449615-2120-402d-b5c3-1544ee3224dd", + "id": "845c935c-8bfc-4c92-bba2-2a915edbfc68", "metadata": {}, "source": [ - "We use `PredictPyTorch` operator that takes a pytorch model and packages it correctly for tritonserver to run on the PyTorch backend." + "We want to serve NVT model and our trained session-based model together as an ensemble to the Triton Inference Server. That way we can send raw requests to Triton and return back item scores per session. For that we need to load our save workflow first." 
] }, { "cell_type": "code", "execution_count": 16, - "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", + "id": "0360f457-a198-42c0-b860-3356f7956ab9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['item_id', 'category', 'day', 'age_days', 'weekday_sin', 'session_id']\n" + ] + } + ], "source": [ - "torch_op = input_schema.column_names >> PredictPyTorch(\n", - " traced_model, input_schema, output_schema\n", - ")" + "from nvtabular.workflow import Workflow\n", + "workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, \"workflow_etl\"))\n", + "print(workflow.input_schema.column_names)" ] }, { "cell_type": "markdown", - "id": "f36b0f90-5392-4d7a-9a48-57737cf63cd1", + "id": "766bb0dd-5766-42a9-949f-45b71baef8b4", "metadata": {}, "source": [ - "The last step is to create the ensemble artifacts that Triton Inference Server can consume. To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.\n", - "\n", - "When we create an `Ensemble` object we supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of out graph. After we created the Ensemble we export the graph, supplying an export path for the `ensemble.export` function. This returns an ensemble config which represents the entire inference pipeline and a list of node-specific configs." + "Our workflow has two columns that are not fed to model as input features, so we need to remove these two columns from workflow output schema." 
] }, { "cell_type": "code", "execution_count": 17, - "id": "5b0d14bb-7765-45e8-8fd0-9d508dc3ec14", + "id": "acb30571-92a4-4abe-a808-23019db74d11", "metadata": {}, "outputs": [], "source": [ - "ensemble = Ensemble(torch_op, input_schema)\n", - "ens_config, node_configs = ensemble.export(ens_model_path)" + "del workflow.output_schema.column_schemas['session_id']\n", + "del workflow.output_schema.column_schemas['day-first']" ] }, { "cell_type": "code", "execution_count": 18, - "id": "a3ba86eb-ca25-4a0c-9daf-61c9911b29ab", + "id": "b5c315c7-b4bf-44b0-8807-4cb78efc5279", "metadata": { "tags": [] }, @@ -708,32 +717,68 @@ " dtype\n", " is_list\n", " is_ragged\n", - " properties.triton_scalar_shape\n", - " properties.value_count.min\n", - " properties.value_count.max\n", " properties.num_buckets\n", " properties.freq_threshold\n", " properties.max_size\n", " properties.start_index\n", " properties.cat_path\n", - " properties.embedding_sizes.cardinality\n", - " properties.embedding_sizes.dimension\n", " properties.domain.min\n", " properties.domain.max\n", " properties.domain.name\n", + " properties.embedding_sizes.cardinality\n", + " properties.embedding_sizes.dimension\n", + " properties.value_count.min\n", + " properties.value_count.max\n", " \n", " \n", " \n", " \n", " 0\n", - " weekday_sin-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", - " DType(name='float32', element_type=<ElementTyp...\n", + " item_id-list\n", + " (Tags.LIST, Tags.ID, Tags.CATEGORICAL, Tags.IT...\n", + " DType(name='int64', element_type=<ElementType....\n", + " True\n", + " False\n", + " NaN\n", + " 0.0\n", + " 0.0\n", + " 0.0\n", + " .//categories/unique.item_id.parquet\n", + " 0.0\n", + " 496.0\n", + " item_id\n", + " 497.0\n", + " 52.0\n", + " 20\n", + " 20\n", + " \n", + " \n", + " 1\n", + " category-list\n", + " (Tags.LIST, Tags.CATEGORICAL)\n", + " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", - " []\n", + " NaN\n", + " 0.0\n", + " 0.0\n", + " 0.0\n", + " 
.//categories/unique.category.parquet\n", + " 0.0\n", + " 142.0\n", + " category\n", + " 143.0\n", + " 26.0\n", " 20\n", " 20\n", + " \n", + " \n", + " 2\n", + " age_days-list\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", + " DType(name='float32', element_type=<ElementTyp...\n", + " True\n", + " False\n", " NaN\n", " NaN\n", " NaN\n", @@ -744,17 +789,16 @@ " NaN\n", " NaN\n", " NaN\n", + " 20\n", + " 20\n", " \n", " \n", - " 1\n", - " age_days-list\n", + " 3\n", + " weekday_sin-list\n", " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", - " []\n", - " 20\n", - " 20\n", " NaN\n", " NaN\n", " NaN\n", @@ -765,58 +809,163 @@ " NaN\n", " NaN\n", " NaN\n", + " 20\n", + " 20\n", + " \n", + " \n", + "\n", + "" + ], + "text/plain": [ + "[{'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 497, 'dimension': 52}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'embedding_sizes': {'cardinality': 143, 'dimension': 26}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': 
DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# we only keep the columns that are input to the model\n", + "workflow.output_schema" + ] + }, + { + "cell_type": "markdown", + "id": "e3449615-2120-402d-b5c3-1544ee3224dd", + "metadata": {}, + "source": [ + "For transforming the raw input features during inference, we use [TransformWorkflow](https://github.com/NVIDIA-Merlin/systems/blob/main/merlin/systems/dag/ops/workflow.py) operator that ensures the workflow is correctly saved and packaged with the required config so the server will know how to load it. We use [PredictPyTorch](https://github.com/NVIDIA-Merlin/systems/blob/main/merlin/systems/dag/ops/pytorch.py) operator that takes a pytorch model and packages it correctly for tritonserver to run on the PyTorch backend." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", + "metadata": {}, + "outputs": [], + "source": [ + "torch_op = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictPyTorch(\n", + " traced_model, input_schema, output_schema\n", + ")\n", + "\n", + "ensemble = Ensemble(torch_op, workflow.input_schema)" + ] + }, + { + "cell_type": "markdown", + "id": "f36b0f90-5392-4d7a-9a48-57737cf63cd1", + "metadata": {}, + "source": [ + "The last step is to create the ensemble artifacts that Triton Inference Server can consume. 
To make these artifacts, we import the Ensemble class. The class is responsible for interpreting the graph and exporting the correct files for the server.\n", + "\n", + "When we create an `Ensemble` object we supply the graph and a schema representing the starting input of the graph. The inputs to the ensemble graph are the inputs to the first operator of out graph. After we created the Ensemble we export the graph, supplying an export path for the `ensemble.export` function. This returns an ensemble config which represents the entire inference pipeline and a list of node-specific configs." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5b0d14bb-7765-45e8-8fd0-9d508dc3ec14", + "metadata": {}, + "outputs": [], + "source": [ + "ens_config, node_configs = ensemble.export(ens_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "a3ba86eb-ca25-4a0c-9daf-61c9911b29ab", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", - " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", "
nametagsdtypeis_listis_ragged
0item_id()DType(name='int32', element_type=<ElementType....FalseFalse
1category()DType(name='int32', element_type=<ElementType....FalseFalse
2item_id-list(Tags.ID, Tags.ITEM_ID, Tags.ITEM, Tags.CATEGO...day()DType(name='int64', element_type=<ElementType....TrueFalse[]2020NaN0.00.00.0.//categories/unique.item_id.parquet481.051.00.0480.0item_idFalse
3category-list(Tags.CATEGORICAL, Tags.LIST)age_days()DType(name='float32', element_type=<ElementTyp...FalseFalse
4weekday_sin()DType(name='float32', element_type=<ElementTyp...FalseFalse
5session_id()DType(name='int64', element_type=<ElementType....TrueFalse[]2020NaN0.00.00.0.//categories/unique.category.parquet145.026.00.0144.0categoryFalse
\n", "
" ], "text/plain": [ - "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 481.0, 'dimension': 51.0}, 'domain': {'min': 0, 'max': 480, 'name': 'item_id'}, 'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 145.0, 'dimension': 26.0}, 'domain': {'min': 0, 'max': 144, 'name': 'category'}, 'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'item_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, 
signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'category', 'tags': set(), 'properties': {}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'age_days', 'tags': set(), 'properties': {}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'weekday_sin', 'tags': set(), 'properties': {}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]" ] }, - "execution_count": 18, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -847,7 +996,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 22, "id": "46a86c8d-9ec1-4422-8f8c-4d49e83f6783", "metadata": { "tags": [] @@ -882,7 +1031,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 23, "id": "dda3f852-a019-4bf1-831b-f63b750a1192", "metadata": { "tags": [] @@ -896,18 +1045,19 @@ "\n", "POST /v2/repository/index, headers None\n", "\n", - "\n", - "bytearray(b'[{\"name\":\"0_predictpytorchtriton\",\"version\":\"1\",\"state\":\"READY\"},{\"name\":\"executor_model\",\"version\":\"1\",\"state\":\"READY\"}]')\n" + "\n", + 
"bytearray(b'[{\"name\":\"0_transformworkflowtriton\",\"version\":\"1\",\"state\":\"READY\"},{\"name\":\"1_predictpytorchtriton\",\"version\":\"1\",\"state\":\"READY\"},{\"name\":\"executor_model\",\"version\":\"1\",\"state\":\"READY\"}]')\n" ] }, { "data": { "text/plain": [ - "[{'name': '0_predictpytorchtriton', 'version': '1', 'state': 'READY'},\n", + "[{'name': '0_transformworkflowtriton', 'version': '1', 'state': 'READY'},\n", + " {'name': '1_predictpytorchtriton', 'version': '1', 'state': 'READY'},\n", " {'name': 'executor_model', 'version': '1', 'state': 'READY'}]" ] }, - "execution_count": 20, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -932,55 +1082,21 @@ "metadata": {}, "source": [ "The last step of a machine learning (ML)/deep learning (DL) pipeline is to deploy the model to production, and get responses for a given query or a set of queries.\n", - "In this section, we generate a dataframe that we can serve as a request to TIS. Note that this is a transformed dataframe. We also need out dataset list columns to be padded to the max sequence length that was set in the ETL pipeline.\n", - "\n", - "We do not serve the raw dataframe because in the production setting, we want to transform the input data as done during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to continuous integer before we use the deployed DL model for a prediction. Therefore, we use a transformed dataset that is processed similarly as train set. " + "In this section, we generate a dataframe that we can serve as a request to TIS. We do serve the raw dataframe and in the production setting, we want to transform the input data as done during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to continuous integer before we use the deployed DL model for a prediction." 
] }, { - "cell_type": "code", - "execution_count": 21, - "id": "0acd5649-31fe-4f3f-87a2-2607477638b5", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "eval_batch_size = 32\n", - "eval_paths = os.path.join(OUTPUT_DIR, f\"{1}/valid.parquet\")\n", - "eval_dataset = Dataset(eval_paths, shuffle=False)\n", - "eval_loader = generate_dataloader(schema, eval_dataset, batch_size=eval_batch_size)\n", - "test_dict = next(iter(eval_loader))" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "1a4894de-939f-4c3b-8c76-6f4d6f91d787", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 7, 5, 5, 12, 7, 19, 62, 14, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", - " 0, 0], device='cuda:0')" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "cell_type": "markdown", + "id": "5e3abf91-67e8-4f04-842c-f128a3379569", + "metadata": {}, "source": [ - "test_dict[0]['item_id-list'][0]" + "Let's generate a dataframe with raw input values. We can send this dataframe to Triton as a request." ] }, { "cell_type": "code", - "execution_count": 23, - "id": "0306fc5a-5f54-4a58-b762-97b38908b290", + "execution_count": 24, + "id": "0acd5649-31fe-4f3f-87a2-2607477638b5", "metadata": { "tags": [] }, @@ -989,122 +1105,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "(32, 4)\n" + " session_id item_id category age_days weekday_sin day\n", + "0 74443 23 10 0.100614 0.034311 4\n", + "1 88512 7 3 0.255196 0.414701 2\n" ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
age_days-listweekday_sin-listitem_id-listcategory-list
0[0.30438614, 0.809415, 0.4060944, 0.3965151, 0...[0.1224156, 0.4469866, 0.0006228724, 0.2732989...[7, 5, 5, 12, 7, 19, 62, 14, 4, 0, 0, 0, 0, 0,...[2, 1, 1, 3, 2, 6, 13, 4, 1, 0, 0, 0, 0, 0, 0,...
1[0.96371543, 0.23374352, 0.79393756, 0.0, 0.0,...[0.17755665, 0.07790468, 0.7028925, 0.0, 0.0, ...[23, 19, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[5, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2[0.7407414, 0.2928302, 0.1934419, 0.0, 0.0, 0....[0.12156885, 0.051108807, 0.806525, 0.0, 0.0, ...[12, 14, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...[3, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
3[0.8000404, 0.72552097, 0.8033363, 0.06934566,...[0.37198335, 0.5371064, 0.9486511, 0.7411687, ...[21, 58, 13, 5, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0,...[3, 12, 2, 1, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
4[0.10884159, 0.2825696, 0.60954845, 0.0, 0.0, ...[0.09428788, 0.88109225, 0.15780881, 0.0, 0.0,...[5, 10, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...[1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
\n", - "
" - ], - "text/plain": [ - " age_days-list \\\n", - "0 [0.30438614, 0.809415, 0.4060944, 0.3965151, 0... \n", - "1 [0.96371543, 0.23374352, 0.79393756, 0.0, 0.0,... \n", - "2 [0.7407414, 0.2928302, 0.1934419, 0.0, 0.0, 0.... \n", - "3 [0.8000404, 0.72552097, 0.8033363, 0.06934566,... \n", - "4 [0.10884159, 0.2825696, 0.60954845, 0.0, 0.0, ... \n", - "\n", - " weekday_sin-list \\\n", - "0 [0.1224156, 0.4469866, 0.0006228724, 0.2732989... \n", - "1 [0.17755665, 0.07790468, 0.7028925, 0.0, 0.0, ... \n", - "2 [0.12156885, 0.051108807, 0.806525, 0.0, 0.0, ... \n", - "3 [0.37198335, 0.5371064, 0.9486511, 0.7411687, ... \n", - "4 [0.09428788, 0.88109225, 0.15780881, 0.0, 0.0,... \n", - "\n", - " item_id-list \\\n", - "0 [7, 5, 5, 12, 7, 19, 62, 14, 4, 0, 0, 0, 0, 0,... \n", - "1 [23, 19, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "2 [12, 14, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "3 [21, 58, 13, 5, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "4 [5, 10, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", - "\n", - " category-list \n", - "0 [2, 1, 1, 3, 2, 6, 13, 4, 1, 0, 0, 0, 0, 0, 0,... \n", - "1 [5, 6, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "2 [3, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "3 [3, 12, 2, 1, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", - "4 [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... 
" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "df_cols = {}\n", - "for name, tensor in test_dict[0].items():\n", - " if name in input_schema.column_names:\n", - " df_cols[name] = tensor.cpu().numpy()\n", - " if len(tensor.shape) > 1:\n", - " df_cols[name] = list(df_cols[name])\n", - " \n", - "df = make_df(df_cols)\n", - "print(df.shape)\n", - "df.head()" + "NUM_ROWS =1000\n", + "long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., int(NUM_ROWS)).astype(np.int32), 1, 50000)\n", + "# generate random item interaction features \n", + "df = pd.DataFrame(np.random.randint(70000, 90000, int(NUM_ROWS)), columns=['session_id'])\n", + "df['item_id'] = long_tailed_item_distribution\n", + "\n", + "# generate category mapping for each item-id\n", + "df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32)\n", + "df['age_days'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)\n", + "df['weekday_sin']= np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32)\n", + "\n", + "# generate day mapping for each session \n", + "map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique()))))\n", + "df['day'] = df.session_id.map(map_day)\n", + "\n", + "print(df.head(2))" ] }, { @@ -1117,7 +1140,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "id": "42091c25-7676-414e-bb8c-8432aeb58297", "metadata": { "tags": [] @@ -1127,31 +1150,31 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'next-item': array([[-10.279172 , -3.3285382, -3.423249 , ..., -9.597292 ,\n", - " -10.079497 , -9.894423 ],\n", - " [-10.279955 , -3.328599 , -3.4235704, ..., -9.596973 ,\n", - " -10.079582 , -9.894649 ],\n", - " [-10.28006 , -3.32861 , -3.4236221, ..., -9.596931 ,\n", - " -10.079575 , -9.894657 ],\n", + "{'next-item': array([[ -9.883799 , -3.4167202, -3.5539691, ..., -9.903243 ,\n", + " -10.080429 , -9.736532 
],\n", + " [ -9.883825 , -3.4169102, -3.5542827, ..., -9.902823 ,\n", + " -10.079994 , -9.736897 ],\n", + " [ -9.883803 , -3.4168637, -3.5542285, ..., -9.902873 ,\n", + " -10.080062 , -9.736796 ],\n", " ...,\n", - " [-10.279462 , -3.3285487, -3.423347 , ..., -9.597226 ,\n", - " -10.079483 , -9.894518 ],\n", - " [-10.28092 , -3.3286107, -3.4241762, ..., -9.596828 ,\n", - " -10.07973 , -9.894941 ],\n", - " [-10.280691 , -3.3285236, -3.424143 , ..., -9.596914 ,\n", - " -10.079795 , -9.894926 ]], dtype=float32)}\n" + " [ -9.883751 , -3.4165516, -3.553697 , ..., -9.903563 ,\n", + " -10.080756 , -9.736205 ],\n", + " [ -9.883795 , -3.416768 , -3.5540614, ..., -9.903113 ,\n", + " -10.080299 , -9.736627 ],\n", + " [ -9.8838 , -3.4167717, -3.5540478, ..., -9.90313 ,\n", + " -10.080301 , -9.736641 ]], dtype=float32)}\n" ] } ], "source": [ "from merlin.systems.triton.utils import send_triton_request\n", - "response = send_triton_request(input_schema, df[input_schema.column_names], output_schema.column_names)\n", + "response = send_triton_request(workflow.input_schema, df, output_schema.column_names)\n", "print(response)" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "id": "0fa425a4-9c00-45ed-a4b1-fd75ca4bf819", "metadata": { "tags": [] @@ -1160,10 +1183,10 @@ { "data": { "text/plain": [ - "(32, 481)" + "(33, 497)" ] }, - "execution_count": 25, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1179,7 +1202,7 @@ "source": [ "We return a response for each request in the df. Each row in the `response['next-item']` array corresponds to the logit values per item in the catalog, and one logit score corresponding to the null, OOV and padded items. The first score of each array in each row corresponds to the score for the padded item, OOV or null item. Note that we dont have OOV or null items in our syntheticall generated datasets.\n", "\n", - "This is the end of this suit of examples. 
You successfully performed feature engineering with NVTabular trained transformer architecture based session-based recommendation models with Transformers4Rec deployed a trained model to Triton Inference Server with Torch backend, sent request and got responses from the server. If you would like to learn how to serve a TF4Rec model with Python backend please visit this [example](https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb)." + "This is the end of this suite of examples. You successfully performed feature engineering with NVTabular, trained transformer-architecture-based session-based recommendation models with Transformers4Rec, deployed the saved workflow and the trained model to Triton Inference Server with the Torch backend, sent requests, and got responses from the server. If you would like to learn how to serve a TF4Rec model with the Python backend, please visit this [example](https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb)." 
] } ], diff --git a/tests/integration/notebooks/test_getting_started_session_based.py b/tests/integration/notebooks/test_getting_started_session_based.py index 70c1cf65f2..e2e380b293 100644 --- a/tests/integration/notebooks/test_getting_started_session_based.py +++ b/tests/integration/notebooks/test_getting_started_session_based.py @@ -82,26 +82,25 @@ def test_func(): tb3.execute_cell(list(range(0, NUM_OF_CELLS - 12))) tb3.inject( """ - eval_batch_size = 4 - eval_paths = os.path.join('/tmp/data/sessions_by_day', f"{1}/valid.parquet") - eval_dataset = Dataset(eval_paths, shuffle=False) - eval_loader = generate_dataloader(schema, eval_dataset, batch_size=eval_batch_size) - test_dict = next(iter(eval_loader)) - df_cols = {} - for name, tensor in test_dict[0].items(): - if name in input_schema.column_names: - df_cols[name] = tensor.cpu().numpy() - if len(tensor.shape) > 1: - df_cols[name] = list(df_cols[name]) - df = make_df(df_cols) + NUM_ROWS =1000 + long_tailed_item_distribution = np.clip(np.random.lognormal(3., 1., int(NUM_ROWS)).astype(np.int32), 1, 50000) + df = pd.DataFrame(np.random.randint(70000, 90000, int(NUM_ROWS)), columns=['session_id']) + df['item_id'] = long_tailed_item_distribution + df['category'] = pd.cut(df['item_id'], bins=334, labels=np.arange(1, 335)).astype(np.int32) + df['age_days'] = np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32) + df['weekday_sin']= np.random.uniform(0, 1, int(NUM_ROWS)).astype(np.float32) + map_day = dict(zip(df.session_id.unique(), np.random.randint(1, 10, size=(df.session_id.nunique())))) + df['day'] = df.session_id.map(map_day) + from merlin.systems.triton.utils import run_ensemble_on_tritonserver response = run_ensemble_on_tritonserver( - "/tmp/data/models", ensemble.graph.input_schema, df[input_schema.column_names], output_schema.column_names, 'executor_model' + "/tmp/data/models", workflow.input_schema, df, output_schema.column_names, 'executor_model' ) - response_array = [x.tolist()[0] for x in 
response["next-item"]] + response_array = list(response['next-item'][1]) + cardinality = workflow.output_schema['item_id-list'].properties['embedding_sizes']['cardinality'] """ ) tb3.execute_cell(NUM_OF_CELLS - 3) - batch_size = tb3.ref("eval_batch_size") + item_cardinality = tb3.ref("cardinality") response_array = tb3.ref("response_array") - assert len(response_array) == batch_size + assert len(response_array) == item_cardinality From c11c9522b6f70778a912915255f5e9a466a3b7e4 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 11 Apr 2023 19:47:27 +0000 Subject: [PATCH 2/2] remove deleting columns from workflow script --- .../01-ETL-with-NVTabular.ipynb | 202 ++++++------- ...ng-session-based-model-torch-backend.ipynb | 267 ++++-------------- 2 files changed, 157 insertions(+), 312 deletions(-) diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index 4185cb5b44..0fcc59b992 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -189,47 +189,47 @@ " \n", " \n", " 0\n", - " 75363\n", - " 18\n", - " 4\n", - " 0.386272\n", - " 0.817845\n", + " 77378\n", + " 45\n", + " 11\n", + " 0.278357\n", + " 0.584583\n", " 1\n", " \n", " \n", " 1\n", - " 72281\n", - " 5\n", + " 74811\n", + " 2\n", " 1\n", - " 0.436883\n", - " 0.407265\n", + " 0.646034\n", + " 0.403526\n", " 6\n", " \n", " \n", " 2\n", - " 75383\n", - " 9\n", - " 2\n", - " 0.629215\n", - " 0.608197\n", - " 2\n", + " 74106\n", + " 11\n", + " 3\n", + " 0.586394\n", + " 0.040917\n", + " 4\n", " \n", " \n", " 3\n", - " 84734\n", - " 10\n", - " 2\n", - " 0.355827\n", - " 0.620883\n", - " 1\n", + " 82547\n", + " 47\n", + " 12\n", + " 0.989547\n", + " 0.962060\n", + " 5\n", " \n", " \n", " 4\n", - " 81038\n", - " 42\n", - " 8\n", - " 0.265468\n", - " 0.830717\n", + " 72220\n", + " 17\n", + " 4\n", + " 0.604099\n", + " 0.370005\n", " 
7\n", " \n", " \n", @@ -238,11 +238,11 @@ ], "text/plain": [ " session_id item_id category age_days weekday_sin day\n", - "0 75363 18 4 0.386272 0.817845 1\n", - "1 72281 5 1 0.436883 0.407265 6\n", - "2 75383 9 2 0.629215 0.608197 2\n", - "3 84734 10 2 0.355827 0.620883 1\n", - "4 81038 42 8 0.265468 0.830717 7" + "0 77378 45 11 0.278357 0.584583 1\n", + "1 74811 2 1 0.646034 0.403526 6\n", + "2 74106 11 3 0.586394 0.040917 4\n", + "3 82547 47 12 0.989547 0.962060 5\n", + "4 72220 17 4 0.604099 0.370005 7" ] }, "execution_count": 6, @@ -455,7 +455,7 @@ " \n", " 2\n", " item_id-list\n", - " (Tags.ITEM_ID, Tags.LIST, Tags.ID, Tags.CATEGO...\n", + " (Tags.ID, Tags.LIST, Tags.ITEM, Tags.ITEM_ID, ...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -465,17 +465,17 @@ " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 496.0\n", + " 488.0\n", " item_id\n", - " 497.0\n", - " 52.0\n", + " 489.0\n", + " 51.0\n", " 20.0\n", " 20.0\n", " \n", " \n", " 3\n", " category-list\n", - " (Tags.LIST, Tags.CATEGORICAL)\n", + " (Tags.CATEGORICAL, Tags.LIST)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -485,17 +485,17 @@ " 0.0\n", " .//categories/unique.category.parquet\n", " 0.0\n", - " 142.0\n", + " 164.0\n", " category\n", - " 143.0\n", - " 26.0\n", + " 165.0\n", + " 28.0\n", " 20.0\n", " 20.0\n", " \n", " \n", " 4\n", " age_days-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", @@ -515,7 +515,7 @@ " \n", " 5\n", " weekday_sin-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", @@ -537,7 +537,7 @@ "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, 
shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 497, 'dimension': 52}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'embedding_sizes': {'cardinality': 143, 'dimension': 26}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': 
True, 'is_ragged': False}]" + "[{'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 488, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 489, 'dimension': 51}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 164, 'name': 'category'}, 'embedding_sizes': {'cardinality': 165, 'dimension': 28}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 
'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, "execution_count": 8, @@ -617,24 +617,24 @@ "output_type": "stream", "text": [ " session_id day-first item_id-list \\\n", - "0 70000 7 [19, 1, 86, 8, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "1 70002 7 [9, 110, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "2 70003 6 [8, 229, 1, 21, 2, 17, 5, 7, 4, 0, 0, 0, 0, 0,... \n", + "0 70000 6 [16, 30, 2, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "1 70001 1 [5, 42, 24, 5, 144, 15, 85, 62, 0, 0, 0, 0, 0,... \n", + "2 70002 4 [77, 25, 322, 56, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", "\n", " category-list \\\n", - "0 [4, 1, 17, 3, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "1 [2, 21, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "2 [3, 48, 1, 5, 2, 4, 1, 1, 2, 0, 0, 0, 0, 0, 0,... \n", + "0 [4, 8, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "1 [3, 11, 6, 3, 37, 5, 22, 15, 0, 0, 0, 0, 0, 0,... \n", + "2 [19, 7, 104, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "\n", " age_days-list \\\n", - "0 [0.566933, 0.11411724, 0.6606563, 0.4054278, 0... \n", - "1 [0.88592035, 0.049719226, 0.67839175, 0.0, 0.0... \n", - "2 [0.7854072, 0.84745556, 0.3776744, 0.3021382, ... \n", + "0 [0.59015936, 0.6703282, 0.39532185, 0.9662245,... \n", + "1 [0.57877135, 0.6347805, 0.088149734, 0.6561545... \n", + "2 [0.5376647, 0.32101643, 0.522189, 0.044942442,... \n", "\n", " weekday_sin-list \n", - "0 [0.8184131, 0.44557166, 0.48090392, 0.35810795... \n", - "1 [0.04808665, 0.33483326, 0.65433854, 0.0, 0.0,... \n", - "2 [0.7141748, 0.40484497, 0.37434393, 0.47856098... \n" + "0 [0.35127148, 0.49160567, 0.6861373, 0.67218935... \n", + "1 [0.13620874, 0.3709723, 0.4606402, 0.6132054, ... \n", + "2 [0.48387563, 0.036944088, 0.39473122, 0.618825... 
\n" ] } ], @@ -652,7 +652,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 18.53it/s]\n" + "Creating time-based splits: 100%|██████████| 9/9 [00:00<00:00, 10.90it/s]\n" ] } ], @@ -720,43 +720,43 @@ " \n", " \n", " 0\n", - " 70020\n", - " [11, 16, 52, 4, 160, 11, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [3, 4, 10, 2, 32, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0...\n", - " [0.30004495, 0.8680668, 0.051910266, 0.0361004...\n", - " [0.9554083, 0.5393047, 0.52827483, 0.018015692...\n", + " 70001\n", + " [5, 42, 24, 5, 144, 15, 85, 62, 0, 0, 0, 0, 0,...\n", + " [3, 11, 6, 3, 37, 5, 22, 15, 0, 0, 0, 0, 0, 0,...\n", + " [0.57877135, 0.6347805, 0.088149734, 0.6561545...\n", + " [0.13620874, 0.3709723, 0.4606402, 0.6132054, ...\n", " \n", " \n", " 1\n", - " 70022\n", - " [28, 1, 18, 76, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", - " [6, 1, 4, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", - " [0.28977808, 0.89265364, 0.7273247, 0.9952868,...\n", - " [0.6657577, 0.51649666, 0.34749013, 0.6105231,...\n", + " 70006\n", + " [8, 4, 40, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...\n", + " [2, 1, 10, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.8325506, 0.90652025, 0.10265447, 0.82018495...\n", + " [0.7054766, 0.8457117, 0.62498975, 0.302113, 0...\n", " \n", " \n", " 2\n", - " 70024\n", - " [167, 113, 61, 22, 5, 88, 0, 0, 0, 0, 0, 0, 0,...\n", - " [33, 22, 12, 5, 1, 17, 0, 0, 0, 0, 0, 0, 0, 0,...\n", - " [0.44703692, 0.9245853, 0.20643276, 0.02649925...\n", - " [0.813814, 0.06711837, 2.8271123e-05, 0.985612...\n", + " 70009\n", + " [29, 16, 38, 1, 5, 22, 13, 29, 0, 0, 0, 0, 0, ...\n", + " [8, 4, 10, 1, 3, 6, 4, 8, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.6411028, 0.7770049, 0.6484449, 0.341628, 0....\n", + " [0.3938173, 0.46046025, 0.7886525, 0.54964787,...\n", " \n", " \n", " 4\n", - " 
70032\n", - " [31, 5, 23, 14, 42, 7, 57, 5, 17, 0, 0, 0, 0, ...\n", - " [7, 1, 5, 3, 8, 1, 11, 1, 4, 0, 0, 0, 0, 0, 0,...\n", - " [0.8306253, 0.25969487, 0.4956863, 0.56623936,...\n", - " [0.17520918, 0.8598712, 0.83949673, 0.37930673...\n", + " 70017\n", + " [10, 1, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...\n", + " [2, 1, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [0.8229121, 0.6468095, 0.7113366, 0.0, 0.0, 0....\n", + " [0.85833865, 0.061059132, 0.13624768, 0.0, 0.0...\n", " \n", " \n", " 5\n", - " 70060\n", - " [5, 49, 39, 62, 4, 18, 8, 40, 0, 0, 0, 0, 0, 0...\n", - " [1, 10, 8, 13, 2, 4, 3, 8, 0, 0, 0, 0, 0, 0, 0...\n", - " [0.8982996, 0.052276224, 0.7122792, 0.35780925...\n", - " [0.16947733, 0.32332653, 0.94814104, 0.1960088...\n", + " 70030\n", + " [69, 20, 17, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [17, 5, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.49626696, 0.36561915, 0.29867646, 0.1756632...\n", + " [0.321747, 0.8811348, 0.7009088, 0.27151287, 0...\n", " \n", " \n", "\n", @@ -764,32 +764,32 @@ ], "text/plain": [ " session_id item_id-list \\\n", - "0 70020 [11, 16, 52, 4, 160, 11, 0, 0, 0, 0, 0, 0, 0, ... \n", - "1 70022 [28, 1, 18, 76, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", - "2 70024 [167, 113, 61, 22, 5, 88, 0, 0, 0, 0, 0, 0, 0,... \n", - "4 70032 [31, 5, 23, 14, 42, 7, 57, 5, 17, 0, 0, 0, 0, ... \n", - "5 70060 [5, 49, 39, 62, 4, 18, 8, 40, 0, 0, 0, 0, 0, 0... \n", + "0 70001 [5, 42, 24, 5, 144, 15, 85, 62, 0, 0, 0, 0, 0,... \n", + "1 70006 [8, 4, 40, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", + "2 70009 [29, 16, 38, 1, 5, 22, 13, 29, 0, 0, 0, 0, 0, ... \n", + "4 70017 [10, 1, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", + "5 70030 [69, 20, 17, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "\n", " category-list \\\n", - "0 [3, 4, 10, 2, 32, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", - "1 [6, 1, 4, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", - "2 [33, 22, 12, 5, 1, 17, 0, 0, 0, 0, 0, 0, 0, 0,... 
\n", - "4 [7, 1, 5, 3, 8, 1, 11, 1, 4, 0, 0, 0, 0, 0, 0,... \n", - "5 [1, 10, 8, 13, 2, 4, 3, 8, 0, 0, 0, 0, 0, 0, 0... \n", + "0 [3, 11, 6, 3, 37, 5, 22, 15, 0, 0, 0, 0, 0, 0,... \n", + "1 [2, 1, 10, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2 [8, 4, 10, 1, 3, 6, 4, 8, 0, 0, 0, 0, 0, 0, 0,... \n", + "4 [2, 1, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "5 [17, 5, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", "\n", " age_days-list \\\n", - "0 [0.30004495, 0.8680668, 0.051910266, 0.0361004... \n", - "1 [0.28977808, 0.89265364, 0.7273247, 0.9952868,... \n", - "2 [0.44703692, 0.9245853, 0.20643276, 0.02649925... \n", - "4 [0.8306253, 0.25969487, 0.4956863, 0.56623936,... \n", - "5 [0.8982996, 0.052276224, 0.7122792, 0.35780925... \n", + "0 [0.57877135, 0.6347805, 0.088149734, 0.6561545... \n", + "1 [0.8325506, 0.90652025, 0.10265447, 0.82018495... \n", + "2 [0.6411028, 0.7770049, 0.6484449, 0.341628, 0.... \n", + "4 [0.8229121, 0.6468095, 0.7113366, 0.0, 0.0, 0.... \n", + "5 [0.49626696, 0.36561915, 0.29867646, 0.1756632... \n", "\n", " weekday_sin-list \n", - "0 [0.9554083, 0.5393047, 0.52827483, 0.018015692... \n", - "1 [0.6657577, 0.51649666, 0.34749013, 0.6105231,... \n", - "2 [0.813814, 0.06711837, 2.8271123e-05, 0.985612... \n", - "4 [0.17520918, 0.8598712, 0.83949673, 0.37930673... \n", - "5 [0.16947733, 0.32332653, 0.94814104, 0.1960088... " + "0 [0.13620874, 0.3709723, 0.4606402, 0.6132054, ... \n", + "1 [0.7054766, 0.8457117, 0.62498975, 0.302113, 0... \n", + "2 [0.3938173, 0.46046025, 0.7886525, 0.54964787,... \n", + "4 [0.85833865, 0.061059132, 0.13624768, 0.0, 0.0... \n", + "5 [0.321747, 0.8811348, 0.7009088, 0.27151287, 0... 
" ] }, "execution_count": 15, @@ -811,7 +811,7 @@ { "data": { "text/plain": [ - "430" + "436" ] }, "execution_count": 16, diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb index 42cc365de2..f4ecadc0ee 100644 --- a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb +++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb @@ -215,8 +215,8 @@ " (categorical_module): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(497, 64, padding_idx=0)\n", - " (category-list): Embedding(143, 64, padding_idx=0)\n", + " (item_id-list): Embedding(489, 64, padding_idx=0)\n", + " (category-list): Embedding(165, 64, padding_idx=0)\n", " )\n", " )\n", " )\n", @@ -287,15 +287,15 @@ " (embeddings): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(497, 64, padding_idx=0)\n", - " (category-list): Embedding(143, 64, padding_idx=0)\n", + " (item_id-list): Embedding(489, 64, padding_idx=0)\n", + " (category-list): Embedding(165, 64, padding_idx=0)\n", " )\n", " )\n", - " (item_embedding_table): Embedding(497, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(489, 64, padding_idx=0)\n", " (masking): MaskedLanguageModeling()\n", " (pre): Block(\n", " (module): NextItemPredictionTask(\n", - " (item_embedding_table): Embedding(497, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(489, 64, padding_idx=0)\n", " (log_softmax): LogSoftmax(dim=-1)\n", " )\n", " )\n", @@ -400,13 +400,13 @@ { "data": { "text/plain": [ - "tensor([[ 11, 16, 52, ..., 0, 0, 0],\n", - " [ 28, 1, 18, ..., 0, 0, 0],\n", - " [167, 113, 61, ..., 0, 0, 0],\n", + "tensor([[ 5, 42, 24, ..., 0, 0, 0],\n", + " [ 8, 
4, 40, ..., 0, 0, 0],\n", + " [ 29, 16, 38, ..., 0, 0, 0],\n", " ...,\n", - " [ 3, 58, 61, ..., 0, 0, 0],\n", - " [ 1, 21, 0, ..., 0, 0, 0],\n", - " [ 4, 15, 41, ..., 0, 0, 0]], device='cuda:0')" + " [ 5, 2, 12, ..., 0, 0, 0],\n", + " [ 29, 11, 13, ..., 0, 0, 0],\n", + " [161, 12, 32, ..., 0, 0, 0]], device='cuda:0')" ] }, "execution_count": 10, @@ -559,7 +559,7 @@ " \n", " 2\n", " item_id-list\n", - " (Tags.LIST, Tags.ID, Tags.CATEGORICAL, Tags.IT...\n", + " (Tags.ID, Tags.CATEGORICAL, Tags.ITEM, Tags.LI...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", @@ -570,10 +570,10 @@ " 0.0\n", " 0.0\n", " .//categories/unique.item_id.parquet\n", - " 497.0\n", - " 52.0\n", + " 489.0\n", + " 51.0\n", " 0.0\n", - " 496.0\n", + " 488.0\n", " item_id\n", " \n", " \n", @@ -590,10 +590,10 @@ " 0.0\n", " 0.0\n", " .//categories/unique.category.parquet\n", - " 143.0\n", - " 26.0\n", + " 165.0\n", + " 28.0\n", " 0.0\n", - " 142.0\n", + " 164.0\n", " category\n", " \n", " \n", @@ -601,7 +601,7 @@ "" ], "text/plain": [ - "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 497.0, 'dimension': 52.0}, 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'value_count': {'min': 
20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 143.0, 'dimension': 26.0}, 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 489.0, 'dimension': 51.0}, 'domain': {'min': 0, 'max': 488, 'name': 'item_id'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 
'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 165.0, 'dimension': 28.0}, 'domain': {'min': 0, 'max': 164, 'name': 'category'}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, "execution_count": 14, @@ -666,184 +666,29 @@ }, { "cell_type": "markdown", - "id": "766bb0dd-5766-42a9-949f-45b71baef8b4", + "id": "e3449615-2120-402d-b5c3-1544ee3224dd", "metadata": {}, "source": [ - "Our workflow has two columns that are not fed to model as input features, so we need to remove these two columns from workflow output schema." + "For transforming the raw input features during inference, we use [TransformWorkflow](https://github.com/NVIDIA-Merlin/systems/blob/main/merlin/systems/dag/ops/workflow.py) operator that ensures the workflow is correctly saved and packaged with the required config so the server will know how to load it. We use [PredictPyTorch](https://github.com/NVIDIA-Merlin/systems/blob/main/merlin/systems/dag/ops/pytorch.py) operator that takes a pytorch model and packages it correctly for tritonserver to run on the PyTorch backend." ] }, { "cell_type": "code", "execution_count": 17, - "id": "acb30571-92a4-4abe-a808-23019db74d11", + "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", "metadata": {}, - "outputs": [], - "source": [ - "del workflow.output_schema.column_schemas['session_id']\n", - "del workflow.output_schema.column_schemas['day-first']" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "b5c315c7-b4bf-44b0-8807-4cb78efc5279", - "metadata": { - "tags": [] - }, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametagsdtypeis_listis_raggedproperties.num_bucketsproperties.freq_thresholdproperties.max_sizeproperties.start_indexproperties.cat_pathproperties.domain.minproperties.domain.maxproperties.domain.nameproperties.embedding_sizes.cardinalityproperties.embedding_sizes.dimensionproperties.value_count.minproperties.value_count.max
0item_id-list(Tags.LIST, Tags.ID, Tags.CATEGORICAL, Tags.IT...DType(name='int64', element_type=<ElementType....TrueFalseNaN0.00.00.0.//categories/unique.item_id.parquet0.0496.0item_id497.052.02020
1category-list(Tags.LIST, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueFalseNaN0.00.00.0.//categories/unique.category.parquet0.0142.0category143.026.02020
2age_days-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float32', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN2020
3weekday_sin-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float32', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN2020
\n", - "
" - ], - "text/plain": [ - "[{'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 496, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 497, 'dimension': 52}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 142, 'name': 'category'}, 'embedding_sizes': {'cardinality': 143, 'dimension': 26}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is 
producing the output column 'session_id', which is not being used by any downstream operator in the ensemble graph.\n", + " warnings.warn(\n", + "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'day-first', which is not being used by any downstream operator in the ensemble graph.\n", + " warnings.warn(\n" + ] } ], - "source": [ - "# we only keep the columns that are input to the model\n", - "workflow.output_schema" - ] - }, - { - "cell_type": "markdown", - "id": "e3449615-2120-402d-b5c3-1544ee3224dd", - "metadata": {}, - "source": [ - "For transforming the raw input features during inference, we use [TransformWorkflow](https://github.com/NVIDIA-Merlin/systems/blob/main/merlin/systems/dag/ops/workflow.py) operator that ensures the workflow is correctly saved and packaged with the required config so the server will know how to load it. We use [PredictPyTorch](https://github.com/NVIDIA-Merlin/systems/blob/main/merlin/systems/dag/ops/pytorch.py) operator that takes a pytorch model and packages it correctly for tritonserver to run on the PyTorch backend." 
- ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", - "metadata": {}, - "outputs": [], "source": [ "torch_op = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictPyTorch(\n", " traced_model, input_schema, output_schema\n", @@ -864,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "id": "5b0d14bb-7765-45e8-8fd0-9d508dc3ec14", "metadata": {}, "outputs": [], @@ -874,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 19, "id": "a3ba86eb-ca25-4a0c-9daf-61c9911b29ab", "metadata": { "tags": [] @@ -965,7 +810,7 @@ "[{'name': 'item_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'category', 'tags': set(), 'properties': {}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'age_days', 'tags': set(), 'properties': {}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'weekday_sin', 'tags': set(), 'properties': {}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, 
shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]" ] }, - "execution_count": 21, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -996,7 +841,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 20, "id": "46a86c8d-9ec1-4422-8f8c-4d49e83f6783", "metadata": { "tags": [] @@ -1031,7 +876,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 21, "id": "dda3f852-a019-4bf1-831b-f63b750a1192", "metadata": { "tags": [] @@ -1057,7 +902,7 @@ " {'name': 'executor_model', 'version': '1', 'state': 'READY'}]" ] }, - "execution_count": 23, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -1095,7 +940,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 22, "id": "0acd5649-31fe-4f3f-87a2-2607477638b5", "metadata": { "tags": [] @@ -1106,8 +951,8 @@ "output_type": "stream", "text": [ " session_id item_id category age_days weekday_sin day\n", - "0 74443 23 10 0.100614 0.034311 4\n", - "1 88512 7 3 0.255196 0.414701 2\n" + "0 86940 5 3 0.537970 0.002699 6\n", + "1 81596 17 9 0.782668 0.385270 4\n" ] } ], @@ -1140,7 +985,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 23, "id": "42091c25-7676-414e-bb8c-8432aeb58297", "metadata": { "tags": [] @@ -1150,19 +995,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'next-item': array([[ -9.883799 , -3.4167202, -3.5539691, ..., -9.903243 ,\n", - " -10.080429 , -9.736532 ],\n", - " [ -9.883825 , -3.4169102, -3.5542827, ..., -9.902823 ,\n", - " -10.079994 , -9.736897 ],\n", - " [ -9.883803 , -3.4168637, -3.5542285, ..., -9.902873 ,\n", - " -10.080062 , -9.736796 ],\n", + "{'next-item': array([[-10.3142395, -3.4098535, -3.4116247, ..., -9.233587 ,\n", + " -9.641741 , -10.570887 ],\n", + " [-10.314354 , -3.4099286, -3.4116242, ..., -9.233635 ,\n", + " -9.6416 , -10.570789 ],\n", + " [-10.315022 , -3.4105031, -3.4116468, ..., -9.233878 ,\n", + 
" -9.640664 , -10.570228 ],\n", " ...,\n", - " [ -9.883751 , -3.4165516, -3.553697 , ..., -9.903563 ,\n", - " -10.080756 , -9.736205 ],\n", - " [ -9.883795 , -3.416768 , -3.5540614, ..., -9.903113 ,\n", - " -10.080299 , -9.736627 ],\n", - " [ -9.8838 , -3.4167717, -3.5540478, ..., -9.90313 ,\n", - " -10.080301 , -9.736641 ]], dtype=float32)}\n" + " [-10.314558 , -3.4101386, -3.411638 , ..., -9.233701 ,\n", + " -9.641275 , -10.570612 ],\n", + " [-10.314491 , -3.4100583, -3.4116335, ..., -9.233676 ,\n", + " -9.641399 , -10.570669 ],\n", + " [-10.314852 , -3.4103851, -3.4116478, ..., -9.233808 ,\n", + " -9.640871 , -10.570368 ]], dtype=float32)}\n" ] } ], @@ -1174,7 +1019,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 24, "id": "0fa425a4-9c00-45ed-a4b1-fd75ca4bf819", "metadata": { "tags": [] @@ -1183,10 +1028,10 @@ { "data": { "text/plain": [ - "(33, 497)" + "(23, 489)" ] }, - "execution_count": 26, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" }