From 57d11a290365b1232a48812cb6ee7bda9ffccb7b Mon Sep 17 00:00:00 2001 From: David Gardner <96306125+dagardner-nv@users.noreply.github.com> Date: Wed, 1 May 2024 12:54:45 -0700 Subject: [PATCH] Truncate strings exceeding max_length when inserting to Milvus (#1665) * Adds new helper methods to `morpheus.io.utils`, `cudf_string_cols_exceed_max_bytes` and `truncate_string_cols_by_bytes` * When `truncate_long_strings=True` `MilvusVectorDBResourceService` will truncate all `VARCHAR` fields according to the schema's `max_length` * Add `truncate_long_strings=True` in config for `vdb_upload` pipeline * Set C++ mode to default for example LLM pipelines * Remove issues 1650 & 1651 from `known_issues.md` Closes #1650 Closes #1651 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md). - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - David Gardner (https://github.com/dagardner-nv) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/Morpheus/pull/1665 --- docs/source/extra_info/known_issues.md | 2 - examples/llm/cli.py | 2 +- examples/llm/vdb_upload/pipeline.py | 19 ++- examples/llm/vdb_upload/vdb_utils.py | 3 +- morpheus/io/utils.py | 96 ++++++++++++ .../service/vdb/milvus_vector_db_service.py | 73 ++++++--- morpheus/stages/inference/inference_stage.py | 8 +- morpheus/utils/type_aliases.py | 1 + tests/conftest.py | 6 + tests/io/test_io_utils.py | 134 +++++++++++++++++ tests/test_milvus_vector_db_service.py | 138 ++++++++++++++++++ .../milvus_string_collection_conf.json | 3 + 12 files changed, 457 insertions(+), 28 deletions(-) create mode 100755 tests/io/test_io_utils.py create mode 100644 tests/tests_data/service/milvus_string_collection_conf.json diff --git a/docs/source/extra_info/known_issues.md b/docs/source/extra_info/known_issues.md index 014fac3471..9eeb53508e 100644 --- a/docs/source/extra_info/known_issues.md +++ b/docs/source/extra_info/known_issues.md @@ -19,7 +19,5 @@ limitations under the License. - TrainAEStage fails with a Segmentation fault ([#1641](https://github.com/nv-morpheus/Morpheus/pull/1641)) - vdb_upload example pipeline triggers an internal error in Triton ([#1649](https://github.com/nv-morpheus/Morpheus/pull/1649)) -- vdb_upload example pipeline error on inserting large strings ([#1650](https://github.com/nv-morpheus/Morpheus/pull/1650)) -- vdb_upload example pipeline only works with C++ mode disabled ([#1651](https://github.com/nv-morpheus/Morpheus/pull/1651)) Refer to [open issues in the Morpheus project](https://github.com/nv-morpheus/Morpheus/issues) diff --git a/examples/llm/cli.py b/examples/llm/cli.py index 1ea9198dc1..c8aea20320 100644 --- a/examples/llm/cli.py +++ b/examples/llm/cli.py @@ -32,7 +32,7 @@ callback=parse_log_level, help="Specify the logging level to use.") @click.option('--use_cpp', - default=False, + default=True, type=bool, help=("Whether or not to use C++ node and message types or to prefer python. " "Only use as a last resort if bugs are encountered")) diff --git a/examples/llm/vdb_upload/pipeline.py b/examples/llm/vdb_upload/pipeline.py index 494446d16c..5d5fbee8e4 100644 --- a/examples/llm/vdb_upload/pipeline.py +++ b/examples/llm/vdb_upload/pipeline.py @@ -19,7 +19,9 @@ from vdb_upload.helper import process_vdb_sources from morpheus.config import Config +from morpheus.messages import ControlMessage from morpheus.pipeline.pipeline import Pipeline +from morpheus.pipeline.stage_decorator import stage from morpheus.stages.general.monitor_stage import MonitorStage from morpheus.stages.general.trigger_stage import TriggerStage from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage @@ -78,6 +80,20 @@ def pipeline(pipeline_config: Config, monitor_2 = pipe.add_stage( MonitorStage(pipeline_config, description="Inference rate", unit="events", delayed_start=True)) + @stage + def embedding_tensor_to_df(message: ControlMessage, *, embedding_tensor_name='probs') -> ControlMessage: + """ + Copies the probs tensor to the 'embedding' field of the dataframe. + """ + msg_meta = message.payload() + with msg_meta.mutable_dataframe() as df: + embedding_tensor = message.tensors().get_tensor(embedding_tensor_name) + df['embedding'] = embedding_tensor.tolist() + + return message + + embedding_tensor_to_df_stage = pipe.add_stage(embedding_tensor_to_df(pipeline_config)) + vector_db = pipe.add_stage(WriteToVectorDBStage(pipeline_config, **vdb_config)) monitor_3 = pipe.add_stage( @@ -96,7 +112,8 @@ def pipeline(pipeline_config: Config, pipe.add_edge(nlp_stage, monitor_1) pipe.add_edge(monitor_1, embedding_stage) pipe.add_edge(embedding_stage, monitor_2) - pipe.add_edge(monitor_2, vector_db) + pipe.add_edge(monitor_2, embedding_tensor_to_df_stage) + pipe.add_edge(embedding_tensor_to_df_stage, vector_db) pipe.add_edge(vector_db, monitor_3) start_time = time.time() diff --git a/examples/llm/vdb_upload/vdb_utils.py b/examples/llm/vdb_upload/vdb_utils.py index d3aed615d7..7740acbc7c 100644 --- a/examples/llm/vdb_upload/vdb_utils.py +++ b/examples/llm/vdb_upload/vdb_utils.py @@ -315,14 +315,15 @@ def build_cli_configs(source_type, cli_vdb_conf = { # Vector db upload has some significant transaction overhead, batch size here should be as large as possible 'batch_size': 16384, - 'resource_name': vector_db_resource_name, 'embedding_size': embedding_size, 'recreate': True, + 'resource_name': vector_db_resource_name, 'resource_schemas': { vector_db_resource_name: build_defualt_milvus_config(embedding_size) if (vector_db_service == 'milvus') else None, }, 'service': vector_db_service, + 'truncate_long_strings': True, 'uri': vector_db_uri, } diff --git a/morpheus/io/utils.py b/morpheus/io/utils.py index 7c4cfce260..d8b286a8e8 100644 --- a/morpheus/io/utils.py +++ b/morpheus/io/utils.py @@ -14,7 +14,16 @@ # limitations under the License. """IO utilities.""" +import logging + +import pandas as pd + +import cudf + from morpheus.utils.type_aliases import DataFrameType +from morpheus.utils.type_aliases import SeriesType + +logger = logging.getLogger(__name__) def filter_null_data(x: DataFrameType): @@ -31,3 +40,90 @@ def filter_null_data(x: DataFrameType): return x return x[~x['data'].isna()] + + +def cudf_string_cols_exceed_max_bytes(df: cudf.DataFrame, column_max_bytes: dict[str, int]) -> bool: + """ + Checks a cudf DataFrame for string columns that exceed a maximum number of bytes and thus need to be truncated by + calling `truncate_string_cols_by_bytes`. + + This method utilizes a cudf method `Series.str.byte_count()` method that pandas lacks, which can avoid a costly + call to truncate_string_cols_by_bytes. + + Parameters + ---------- + df : DataFrameType + The dataframe to check. + column_max_bytes: dict[str, int] + A mapping of string column names to the maximum number of bytes for each column. + + Returns + ------- + bool + True if truncation is needed, False otherwise. + """ + if not isinstance(df, cudf.DataFrame): + raise ValueError("Expected cudf DataFrame") + + for (col, max_bytes) in column_max_bytes.items(): + series: cudf.Series = df[col] + + assert series.dtype == 'object' + + if series.str.byte_count().max() > max_bytes: + return True + + return False + + +def truncate_string_cols_by_bytes(df: DataFrameType, + column_max_bytes: dict[str, int], + warn_on_truncate: bool = True) -> bool: + """ + Truncates all string columns in a dataframe to a maximum number of bytes. This operation is performed in-place on + the dataframe. + + Parameters + ---------- + df : DataFrameType + The dataframe to truncate. + column_max_bytes: dict[str, int] + A mapping of string column names to the maximum number of bytes for each column. + warn_on_truncate: bool, default True + Whether to log a warning when truncating a column. + + Returns + ------- + bool + True if truncation was performed, False otherwise. + """ + + performed_truncation = False + is_cudf = isinstance(df, cudf.DataFrame) + + for (col, max_bytes) in column_max_bytes.items(): + series: SeriesType = df[col] + + if is_cudf: + series: pd.Series = series.to_pandas() + + assert series.dtype == 'object', f"Expected string column '{col}'" + + encoded_series = series.str.encode(encoding='utf-8', errors='strict') + if encoded_series.str.len().max() > max_bytes: + performed_truncation = True + if warn_on_truncate: + logger.warning("Truncating column '%s' to %d bytes", col, max_bytes) + + truncated_series = encoded_series.str.slice(0, max_bytes) + + # There is a possibility that slicing by max_len will slice a multi-byte character in half setting + # errors='ignore' will cause the resulting string to be truncated after the last full character + decoded_series = truncated_series.str.decode(encoding='utf-8', errors='ignore') + + if is_cudf: + df[col] = cudf.Series.from_pandas(decoded_series) + else: + df[col] = decoded_series + + return performed_truncation diff --git a/morpheus/service/vdb/milvus_vector_db_service.py b/morpheus/service/vdb/milvus_vector_db_service.py index 37cd82d1ba..09c68f15cd 100644 --- a/morpheus/service/vdb/milvus_vector_db_service.py +++ b/morpheus/service/vdb/milvus_vector_db_service.py @@ -20,18 +20,24 @@ import typing from functools import wraps -import pandas as pd - import cudf +from morpheus.io.utils import cudf_string_cols_exceed_max_bytes +from morpheus.io.utils import truncate_string_cols_by_bytes from morpheus.service.vdb.vector_db_service import VectorDBResourceService from morpheus.service.vdb.vector_db_service import VectorDBService +from morpheus.utils.type_aliases import DataFrameType logger = logging.getLogger(__name__) IMPORT_EXCEPTION = None IMPORT_ERROR_MESSAGE = "MilvusVectorDBResourceService requires the milvus and pymilvus packages to be installed." +# Milvus has a max string length in bytes of 65,535. Multi-byte characters like "ñ" will have a string length of 1, the +# byte length encoded as UTF-8 will be 2 +# https://milvus.io/docs/limitations.md#Length-of-a-string +MAX_STRING_LENGTH_BYTES = 65_535 + try: import pymilvus from pymilvus.orm.mutation import MutationResult @@ -222,9 +228,11 @@ class MilvusVectorDBResourceService(VectorDBResourceService): Name of the resource. client : MilvusClient An instance of the MilvusClient for interaction with the Milvus Vector Database. + truncate_long_strings : bool, optional + When true, truncate strings values that are longer than the max length of the field """ - def __init__(self, name: str, client: "MilvusClient") -> None: + def __init__(self, name: str, client: "MilvusClient", truncate_long_strings: bool = False) -> None: if IMPORT_EXCEPTION is not None: raise ImportError(IMPORT_ERROR_MESSAGE) from IMPORT_EXCEPTION @@ -239,13 +247,24 @@ def __init__(self, name: str, client: "MilvusClient") -> None: self._vector_field = None self._fillna_fields_dict = {} + # Mapping of field name to max length for string fields + self._fields_max_length: dict[str, int] = {} + for field in self._fields: if field.dtype == pymilvus.DataType.FLOAT_VECTOR: self._vector_field = field.name else: + # Intentionally excluding pymilvus.DataType.STRING, in our current version it isn't supported, and in + # some database systems string types don't have a max length. + if field.dtype == pymilvus.DataType.VARCHAR: + max_length = field.params.get('max_length') + if max_length is not None: + self._fields_max_length[field.name] = max_length if not field.auto_id: self._fillna_fields_dict[field.name] = field.dtype + self._truncate_long_strings = truncate_long_strings + self._collection.load() def _set_up_collection(self): @@ -275,13 +294,13 @@ def insert(self, data: list[list] | list[dict], **kwargs: dict[str, typing.Any]) return self._insert_result_to_dict(result=result) - def insert_dataframe(self, df: typing.Union[cudf.DataFrame, pd.DataFrame], **kwargs: dict[str, typing.Any]) -> dict: + def insert_dataframe(self, df: DataFrameType, **kwargs: dict[str, typing.Any]) -> dict: """ Insert a dataframe entires into the vector database. Parameters ---------- - df : typing.Union[cudf.DataFrame, pd.DataFrame] + df : DataFrameType Dataframe to be inserted into the collection. **kwargs : dict[str, typing.Any] Extra keyword arguments specific to the vector database implementation. @@ -291,10 +310,6 @@ def insert_dataframe(self, df: typing.Union[cudf.DataFrame, pd.DataFrame], **kwa dict Returns response content as a dictionary. """ - - if isinstance(df, cudf.DataFrame): - df = df.to_pandas() - # Ensure that there are no None values in the DataFrame entries. for field_name, dtype in self._fillna_fields_dict.items(): if dtype in (pymilvus.DataType.VARCHAR, pymilvus.DataType.STRING): @@ -311,11 +326,24 @@ def insert_dataframe(self, df: typing.Union[cudf.DataFrame, pd.DataFrame], **kwa else: logger.info("Skipped checking 'None' in the field: %s, with datatype: %s", field_name, dtype) + needs_truncate = self._truncate_long_strings + if needs_truncate and isinstance(df, cudf.DataFrame): + # Cudf specific optimization, we can avoid a costly call to truncate_string_cols_by_bytes if all of the + # string columns are already below the max length + needs_truncate = cudf_string_cols_exceed_max_bytes(df, self._fields_max_length) + # From the schema, this is the list of columns we need, excluding any auto_id columns column_names = [field.name for field in self._fields if not field.auto_id] + collection_df = df[column_names] + if isinstance(collection_df, cudf.DataFrame): + collection_df = collection_df.to_pandas() + + if needs_truncate: + truncate_string_cols_by_bytes(collection_df, self._fields_max_length, warn_on_truncate=True) + # Note: dataframe columns has to be in the order of collection schema fields.s - result = self._collection.insert(data=df[column_names], **kwargs) + result = self._collection.insert(data=collection_df, **kwargs) self._collection.flush() return self._insert_result_to_dict(result=result) @@ -575,6 +603,8 @@ class MilvusVectorDBService(VectorDBService): The port number for connecting to the Milvus server. alias : str, optional Alias for the Milvus connection, by default "default". + truncate_long_strings : bool, optional + When true, truncate strings values that are longer than the max length of the field **kwargs : dict Additional keyword arguments specific to the Milvus connection configuration. """ @@ -589,13 +619,17 @@ def __init__(self, password: str = "", db_name: str = "", token: str = "", + truncate_long_strings: bool = False, **kwargs: dict[str, typing.Any]): + self._truncate_long_strings = truncate_long_strings self._client = MilvusClient(uri=uri, user=user, password=password, db_name=db_name, token=token, **kwargs) def load_resource(self, name: str, **kwargs: dict[str, typing.Any]) -> MilvusVectorDBResourceService: - - return MilvusVectorDBResourceService(name=name, client=self._client, **kwargs) + return MilvusVectorDBResourceService(name=name, + client=self._client, + truncate_long_strings=self._truncate_long_strings, + **kwargs) def has_store_object(self, name: str) -> bool: """ @@ -688,7 +722,7 @@ def create(self, name: str, overwrite: bool = False, **kwargs: dict[str, typing. for part in partition_conf["partitions"]: self._client.create_partition(collection_name=name, partition_name=part["name"], timeout=timeout) - def _build_schema_conf(self, df: typing.Union[cudf.DataFrame, pd.DataFrame]) -> list[dict]: + def _build_schema_conf(self, df: DataFrameType) -> list[dict]: fields = [] # Always add a primary key @@ -708,7 +742,7 @@ def _build_schema_conf(self, df: typing.Union[cudf.DataFrame, pd.DataFrame]) -> } if (field_dict["dtype"] == pymilvus.DataType.VARCHAR): - field_dict["max_length"] = 65_535 + field_dict["max_length"] = MAX_STRING_LENGTH_BYTES if (field_dict["dtype"] == pymilvus.DataType.FLOAT_VECTOR or field_dict["dtype"] == pymilvus.DataType.BINARY_VECTOR): @@ -726,7 +760,7 @@ def _build_schema_conf(self, df: typing.Union[cudf.DataFrame, pd.DataFrame]) -> def create_from_dataframe(self, name: str, - df: typing.Union[cudf.DataFrame, pd.DataFrame], + df: DataFrameType, overwrite: bool = False, **kwargs: dict[str, typing.Any]) -> None: """ @@ -736,7 +770,7 @@ def create_from_dataframe(self, ---------- name : str Name of the collection. - df : Union[cudf.DataFrame, pd.DataFrame] + df : DataFrameType The dataframe to create the collection from. overwrite : bool, optional Whether to overwrite the collection if it already exists. Default is False. @@ -797,10 +831,7 @@ def insert(self, name: str, data: list[list] | list[dict], **kwargs: dict[str, return resource.insert(data, **kwargs) @with_collection_lock - def insert_dataframe(self, - name: str, - df: typing.Union[cudf.DataFrame, pd.DataFrame], - **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: + def insert_dataframe(self, name: str, df: DataFrameType, **kwargs: dict[str, typing.Any]) -> dict[str, typing.Any]: """ Converts dataframe to rows and insert to a collection in the Milvus vector database. @@ -808,7 +839,7 @@ def insert_dataframe(self, ---------- name : str Name of the collection to be inserted. - df : typing.Union[cudf.DataFrame, pd.DataFrame] + df : DataFrameType Dataframe to be inserted in the collection. **kwargs : dict[str, typing.Any] Additional keyword arguments containing collection configuration. diff --git a/morpheus/stages/inference/inference_stage.py b/morpheus/stages/inference/inference_stage.py index 8b1fa75d3a..ab12afe4d3 100644 --- a/morpheus/stages/inference/inference_stage.py +++ b/morpheus/stages/inference/inference_stage.py @@ -286,8 +286,12 @@ def set_output_fut(resp: TensorMemory, inner_batch, batch_future: mrc.Future): if (_df is not None and not _df.empty): _message_meta = CppMessageMeta(df=_df) _message.payload(_message_meta) - _message.tensors().set_tensor("probs", output_message.get_probs_tensor()) - print(_df) + + response_tensors = output_message.tensors + cm_tensors = _message.tensors() + for (name, tensor) in response_tensors.items(): + cm_tensors.set_tensor(name, tensor) + output_message = _message return output_message diff --git a/morpheus/utils/type_aliases.py b/morpheus/utils/type_aliases.py index f944c3f9cb..cd394664e6 100644 --- a/morpheus/utils/type_aliases.py +++ b/morpheus/utils/type_aliases.py @@ -20,3 +20,4 @@ import cudf DataFrameType = typing.Union[pd.DataFrame, cudf.DataFrame] +SeriesType = typing.Union[pd.Series, cudf.Series] diff --git a/tests/conftest.py b/tests/conftest.py index 1f8f0ef425..30cc8f869d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1035,6 +1035,12 @@ def simple_collection_config_fixture(): yield load_json_file(filename="service/milvus_simple_collection_conf.json") +@pytest.fixture(scope="session", name="string_collection_config") +def string_collection_config_fixture(): + from _utils import load_json_file + yield load_json_file(filename="service/milvus_string_collection_conf.json") + + @pytest.fixture(name="nemollm", scope='session') def nemollm_fixture(fail_missing: bool): """ diff --git a/tests/io/test_io_utils.py b/tests/io/test_io_utils.py new file mode 100755 index 0000000000..1ad46b75cb --- /dev/null +++ b/tests/io/test_io_utils.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + +import pytest + +import cudf + +from _utils.dataset_manager import DatasetManager +from morpheus.io import utils as io_utils +from morpheus.utils.type_aliases import DataFrameType + +MULTI_BYTE_STRINGS = ["ñäμɛ", "Moρφέας", "taç"] + + +def _mk_df(df_class: Callable[..., DataFrameType], data: dict[str, list[str]]) -> DataFrameType: + """ + Create a dataframe with a 'data' column containing the given data, and some other columns with different data types + """ + num_rows = len(data[list(data.keys())[0]]) + + float_col = [] + int_col = [] + short_str_col = [] + for i in range(num_rows): + float_col.append(i) + int_col.append(i) + short_str_col.append(f"{i}"[0:3]) + + df_data = data.copy() + df_data.update({"float_col": float_col, "int_col": int_col, "short_str_col": short_str_col}) + + return df_class(df_data) + + +@pytest.mark.parametrize( + "data, max_bytes, expected", + [({ + "data": MULTI_BYTE_STRINGS[:] + }, { + "data": 8 + }, True), ({ + "data": MULTI_BYTE_STRINGS[:], "ignored_col": ["a" * 20, "b" * 20, "c" * 20] + }, { + "data": 12 + }, False), ({ + "data": MULTI_BYTE_STRINGS[:] + }, { + "data": 20 + }, False), ({ + "data": ["." * 20] + }, { + "data": 19 + }, True), ({ + "data": ["." * 20] + }, { + "data": 20 + }, False), ({ + "data": ["." * 20] + }, { + "data": 21 + }, False)]) +def test_cudf_needs_truncate(data: list[str], max_bytes: int, expected: bool): + df = _mk_df(cudf.DataFrame, data) + assert io_utils.cudf_string_cols_exceed_max_bytes(df, max_bytes) is expected + + +@pytest.mark.parametrize("warn_on_truncate", [True, False]) +@pytest.mark.parametrize( + "data, max_bytes, expected_data", + [({ + "multibyte_strings": MULTI_BYTE_STRINGS[:], "ascii_strings": ["a" * 20, "b" * 21, "c" * 19] + }, { + "multibyte_strings": 4, "ascii_strings": 20 + }, { + "multibyte_strings": ["ñä", "Moρ", "taç"], "ascii_strings": ["a" * 20, "b" * 20, "c" * 19] + }), + ({ + "data": MULTI_BYTE_STRINGS[:], "ignored_col": ["a" * 20, "b" * 20, "c" * 20] + }, { + "data": 5 + }, { + "data": ["ñä", "Moρ", "taç"], "ignored_col": ["a" * 20, "b" * 20, "c" * 20] + }), ({ + "data": MULTI_BYTE_STRINGS[:] + }, { + "data": 8 + }, { + "data": ["ñäμɛ", "Moρφέ", "taç"] + }), ({ + "data": MULTI_BYTE_STRINGS[:] + }, { + "data": 9 + }, { + "data": ["ñäμɛ", "Moρφέ", "taç"] + }), ({ + "data": MULTI_BYTE_STRINGS[:] + }, { + "data": 12 + }, { + "data": MULTI_BYTE_STRINGS[:] + })]) +def test_truncate_string_cols_by_bytes(dataset: DatasetManager, + data: dict[str, list[str]], + max_bytes: int, + expected_data: dict[str, list[str]], + warn_on_truncate: bool): + df = _mk_df(dataset.df_class, data) + + expect_truncation = (data != expected_data) + expected_df_class = dataset.df_class + + expected_df = _mk_df(expected_df_class, expected_data) + + performed_truncation = io_utils.truncate_string_cols_by_bytes(df, max_bytes, warn_on_truncate=warn_on_truncate) + + assert performed_truncation is expect_truncation + assert isinstance(df, expected_df_class) + + dataset.assert_df_equal(df, expected_df) diff --git a/tests/test_milvus_vector_db_service.py b/tests/test_milvus_vector_db_service.py index 723e7e7f8e..3d0548176d 100644 --- a/tests/test_milvus_vector_db_service.py +++ b/tests/test_milvus_vector_db_service.py @@ -16,14 +16,18 @@ import json import random +import string import numpy as np import pymilvus import pytest from pymilvus import DataType +from pymilvus import MilvusException import cudf +from _utils.dataset_manager import DatasetManager +from morpheus.service.vdb.milvus_vector_db_service import MAX_STRING_LENGTH_BYTES from morpheus.service.vdb.milvus_vector_db_service import FieldSchemaEncoder from morpheus.service.vdb.milvus_vector_db_service import MilvusVectorDBService @@ -71,6 +75,45 @@ def sample_field_fixture(): return pymilvus.FieldSchema(name="test_field", dtype=pymilvus.DataType.INT64) +def _mk_long_string(source_chars: str) -> str: + """ + Yields a string longer than MAX_STRING_LENGTH_BYTES from source chars + """ + source_chars_byte_len = len(source_chars.encode("utf-8")) + source_data = list(source_chars) + + byte_len = 0 + long_str_data = [] + while byte_len <= MAX_STRING_LENGTH_BYTES: + long_str_data.extend(source_data) + byte_len += source_chars_byte_len + + return "".join(long_str_data) + + +@pytest.fixture(scope="module", name="long_ascii_string") +def long_ascii_string_fixture(): + """ + Yields a string longer than MAX_STRING_LENGTH_BYTES containing only ascii (single-byte) characters + """ + return _mk_long_string(string.ascii_letters) + + +@pytest.fixture(scope="module", name="long_multibyte_string") +def long_multibyte_string_fixture(): + """ + Yields a string longer than MAX_STRING_LENGTH_BYTES containing a mix of single and multi-byte characters + """ + return _mk_long_string("Moρφέας") + + +def _truncate_string_by_bytes(s: str, max_bytes: int) -> str: + """ + Truncates a string to the given number of bytes + """ + return s.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore") + + @pytest.mark.milvus def test_create_and_drop_collection(idx_part_collection_config: dict, milvus_service: MilvusVectorDBService): collection_name = "test_collection" @@ -467,3 +510,98 @@ def test_fse_from_dict(): result = FieldSchemaEncoder.from_dict(data) assert result.name == "test_field" assert result.dtype == pymilvus.DataType.INT64 + + +@pytest.mark.milvus +@pytest.mark.slow +@pytest.mark.parametrize("use_multi_byte_strings", [True, False], ids=["multi_byte", "ascii"]) +@pytest.mark.parametrize("truncate_long_strings", [True, False], ids=["truncate", "no_truncate"]) +@pytest.mark.parametrize("exceed_max_str_len", [True, False], ids=["exceed_max_len", "within_max_len"]) +def test_insert_dataframe(milvus_server_uri: str, + string_collection_config: dict, + dataset: DatasetManager, + use_multi_byte_strings: bool, + truncate_long_strings: bool, + exceed_max_str_len: bool, + long_ascii_string: str, + long_multibyte_string: str): + num_rows = 10 + collection_name = "test_insert_dataframe" + + milvus_service = MilvusVectorDBService(uri=milvus_server_uri, truncate_long_strings=truncate_long_strings) + + # Make sure to drop any existing collection from previous runs. + milvus_service.drop(collection_name) + + # Create a collection. + milvus_service.create(collection_name, **string_collection_config) + + short_str_col_len = -1 + long_str_col_len = -1 + for field_conf in string_collection_config["schema_conf"]["schema_fields"]: + if field_conf["name"] == "short_str_col": + short_str_col_len = field_conf["params"]["max_length"] + + elif field_conf["name"] == "long_str_col": + long_str_col_len = field_conf["params"]["max_length"] + + assert short_str_col_len > 0, "short_str_col length is not set" + assert long_str_col_len == MAX_STRING_LENGTH_BYTES, "long_str_col length is not set to MAX_STRING_LENGTH_BYTES" + + # Construct the dataframe. + ids = [] + embedding_data = [] + long_str_col = [] + short_str_col = [] + + if use_multi_byte_strings: + long_str = long_multibyte_string + else: + long_str = long_ascii_string + + short_str = long_str[:7] + if not exceed_max_str_len: + short_str = _truncate_string_by_bytes(short_str, short_str_col_len) + long_str = _truncate_string_by_bytes(long_str, MAX_STRING_LENGTH_BYTES) + + for i in range(num_rows): + ids.append(i) + embedding_data.append([i / 10.0] * 3) + + long_str_col.append(long_str) + short_str_col.append(short_str) + + df = dataset.df_class({ + "id": ids, "embedding": embedding_data, "long_str_col": long_str_col, "short_str_col": short_str_col + }) + + expected_long_str = [] + for long_str in long_str_col: + if truncate_long_strings: + expected_long_str.append( + long_str.encode("utf-8")[:MAX_STRING_LENGTH_BYTES].decode("utf-8", errors="ignore")) + else: + expected_long_str.append(long_str) + + expected_df = dataset.df_class({ + "id": ids, "embedding": embedding_data, "long_str_col": expected_long_str, "short_str_col": short_str_col + }) + + if (exceed_max_str_len and (not truncate_long_strings)): + with pytest.raises(MilvusException, match="string exceeds max length"): + milvus_service.insert_dataframe(collection_name, df) + + return # Skip the rest of the test if the string column exceeds the maximum length. + + milvus_service.insert_dataframe(collection_name, df) + + # Retrieve inserted data by primary keys. + retrieved_data = milvus_service.retrieve_by_keys(collection_name, ids) + assert len(retrieved_data) == num_rows + + # Clean up the collection. + milvus_service.drop(collection_name) + + result_df = dataset.df_class(retrieved_data) + + dataset.compare_df(result_df, expected_df) diff --git a/tests/tests_data/service/milvus_string_collection_conf.json b/tests/tests_data/service/milvus_string_collection_conf.json new file mode 100644 index 0000000000..a75970a361 --- /dev/null +++ b/tests/tests_data/service/milvus_string_collection_conf.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:adbc34ae22c1037c8308b5521a01597a81d0ea117cc691e72566b463c0be6e9a +size 1083