diff --git a/examples/iterator/iterator.py b/examples/iterator/iterator.py new file mode 100644 index 000000000..aa87815b5 --- /dev/null +++ b/examples/iterator/iterator.py @@ -0,0 +1,99 @@ +from pymilvus.milvus_client.milvus_client import MilvusClient +from pymilvus import ( + FieldSchema, CollectionSchema, DataType, +) +import numpy as np + +collection_name = "test_milvus_client_iterator" +prepare_new_data = True +clean_exist = True + +USER_ID = "id" +AGE = "age" +DEPOSIT = "deposit" +PICTURE = "picture" +DIM = 8 +NUM_ENTITIES = 10000 +rng = np.random.default_rng(seed=19530) + + +def test_query_iterator(milvus_client: MilvusClient): + # test query iterator + expr = f"10 <= {AGE} <= 25" + output_fields = [USER_ID, AGE] + queryIt = milvus_client.query_iterator(collection_name, filter=expr, batch_size=50, output_fields=output_fields) + page_idx = 0 + while True: + res = queryIt.next() + if len(res) == 0: + print("query iteration finished, close") + queryIt.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + +def test_search_iterator(milvus_client: MilvusClient): + vector_to_search = rng.random((1, DIM), np.float32) + search_iterator = milvus_client.search_iterator(collection_name, data=vector_to_search, batch_size=100, anns_field=PICTURE) + + page_idx = 0 + while True: + res = search_iterator.next() + if len(res) == 0: + print("query iteration finished, close") + search_iterator.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + + +def main(): + milvus_client = MilvusClient("http://localhost:19530") + if milvus_client.has_collection(collection_name) and clean_exist: + milvus_client.drop_collection(collection_name) + print(f"dropped existed collection{collection_name}") + + if not milvus_client.has_collection(collection_name): + fields = [ + FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name=AGE, dtype=DataType.INT64), + FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE), + FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM) + ] + schema = CollectionSchema(fields) + milvus_client.create_collection(collection_name, dimension=DIM, schema=schema) + + if prepare_new_data: + entities = [] + for i in range(NUM_ENTITIES): + entity = { + USER_ID: i, + AGE: (i % 100), + DEPOSIT: float(i), + PICTURE: rng.random((1, DIM))[0] + } + entities.append(entity) + milvus_client.insert(collection_name, entities) + milvus_client.flush(collection_name) + print(f"Finish flush collections:{collection_name}") + + index_params = milvus_client.prepare_index_params() + + index_params.add_index( + field_name=PICTURE, + index_type='IVF_FLAT', + metric_type='L2', + params={"nlist": 1024} + ) + milvus_client.create_index(collection_name, index_params) + milvus_client.load_collection(collection_name) + test_query_iterator(milvus_client=milvus_client) + test_search_iterator(milvus_client=milvus_client) + + +if __name__ == '__main__': + main() diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 3ff7703dd..0d03ca272 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -379,3 +379,26 @@ def is_scipy_sparse(cls, data: Any): "csr_array", "spmatrix", ] + + +def is_sparse_vector_type(data_type: DataType) -> bool: + return data_type == data_type.SPARSE_FLOAT_VECTOR + + +dense_vector_type_set = {DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR} + + +def is_dense_vector_type(data_type: DataType) -> bool: + return data_type in dense_vector_type_set + + +def is_float_vector_type(data_type: DataType): + return is_sparse_vector_type(data_type) or is_dense_vector_type(data_type) + + +def is_binary_vector_type(data_type: DataType): + return data_type == DataType.BINARY_VECTOR + + +def is_vector_type(data_type: DataType): + return is_float_vector_type(data_type) or is_binary_vector_type(data_type) diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index 7e2b8f764..d7f35b046 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -13,8 +13,10 @@ OmitZeroDict, construct_cost_extra, ) +from pymilvus.client.utils import is_vector_type from pymilvus.exceptions import ( DataTypeNotMatchException, + ErrorCode, MilvusException, ParamError, PrimaryKeyException, @@ -22,6 +24,8 @@ from pymilvus.orm import utility from pymilvus.orm.collection import CollectionSchema from pymilvus.orm.connections import connections +from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED +from pymilvus.orm.iterator import QueryIterator, SearchIterator from pymilvus.orm.types import DataType from .index import IndexParams @@ -480,6 +484,126 @@ def query( return res + def query_iterator( + self, + collection_name: str, + batch_size: Optional[int] = 1000, + limit: Optional[int] = UNLIMITED, + filter: Optional[str] = "", + output_fields: Optional[List[str]] = None, + partition_names: Optional[List[str]] = None, + timeout: Optional[float] = None, + **kwargs, + ): + if filter is not None and not isinstance(filter, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) + + conn = self._get_connection() + # set up schema for iterator + try: + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) + except Exception as ex: + logger.error("Failed to describe collection: %s", collection_name) + raise ex from ex + + return QueryIterator( + connection=conn, + collection_name=collection_name, + batch_size=batch_size, + limit=limit, + expr=filter, + output_fields=output_fields, + partition_names=partition_names, + schema=schema_dict, + timeout=timeout, + **kwargs, + ) + + def search_iterator( + self, + collection_name: str, + data: Union[List[list], list], + batch_size: Optional[int] = 1000, + filter: Optional[str] = None, + limit: Optional[int] = UNLIMITED, + output_fields: Optional[List[str]] = None, + search_params: Optional[dict] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + anns_field: Optional[str] = None, + round_decimal: int = -1, + **kwargs, + ): + if filter is not None and not isinstance(filter, str): + raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) + + conn = self._get_connection() + # set up schema for iterator + try: + schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs) + except Exception as ex: + logger.error("Failed to describe collection: %s", collection_name) + raise ex from ex + # if anns_field is not provided + # if only one vector field, use to search + # if multiple vector fields, raise exception and abort + if anns_field is None or anns_field == "": + vec_field = None + fields = schema_dict[FIELDS] + vec_field_count = 0 + for field in fields: + if is_vector_type(field[TYPE]): + vec_field_count += 1 + vec_field = field + if vec_field is None: + raise MilvusException( + code=ErrorCode.UNEXPECTED_ERROR, + message="there should be at least one vector field in milvus collection", + ) + if vec_field_count > 1: + raise MilvusException( + code=ErrorCode.UNEXPECTED_ERROR, + message="must specify anns_field when there are more than one vector field", + ) + anns_field = vec_field["name"] + if anns_field is None or anns_field == "": + raise MilvusException( + code=ErrorCode.UNEXPECTED_ERROR, + message=f"cannot get anns_field name for search iterator, got:{anns_field}", + ) + # set up metrics type for search_iterator which is mandatory + if search_params is None: + search_params = {} + if METRIC_TYPE not in search_params: + indexes = conn.list_indexes(collection_name) + for index in indexes: + if anns_field == index.index_name: + params = index.params + for param in params: + if param.key == METRIC_TYPE: + search_params[METRIC_TYPE] = param.value + if METRIC_TYPE not in search_params: + raise MilvusException( + ParamError, f"Cannot set up metrics type for anns_field:{anns_field}" + ) + + return SearchIterator( + connection=self._get_connection(), + collection_name=collection_name, + data=data, + ann_field=anns_field, + param=search_params, + batch_size=batch_size, + limit=limit, + expr=filter, + partition_names=partition_names, + output_fields=output_fields, + timeout=timeout, + round_decimal=round_decimal, + schema=schema_dict, + **kwargs, + ) + def get( self, collection_name: str, diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index b4980204d..b5a88ecb4 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -29,6 +29,7 @@ MILVUS_LIMIT = "limit" BATCH_SIZE = "batch_size" ID = "id" +TYPE = "type" METRIC_TYPE = "metric_type" PARAMS = "params" DISTANCE = "distance"