Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: support milvus-client iterator #2461

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions examples/iterator/iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from pymilvus.milvus_client.milvus_client import MilvusClient
from pymilvus import (
FieldSchema, CollectionSchema, DataType,
)
import numpy as np

collection_name = "test_milvus_client_iterator"
prepare_new_data = True
clean_exist = True

USER_ID = "id"
AGE = "age"
DEPOSIT = "deposit"
PICTURE = "picture"
DIM = 8
NUM_ENTITIES = 10000
rng = np.random.default_rng(seed=19530)


def test_query_iterator(milvus_client: MilvusClient):
# test query iterator
expr = f"10 <= {AGE} <= 25"
output_fields = [USER_ID, AGE]
queryIt = milvus_client.query_iterator(collection_name, filter=expr, batch_size=50, output_fields=output_fields)
page_idx = 0
while True:
res = queryIt.next()
if len(res) == 0:
print("query iteration finished, close")
queryIt.close()
break
for i in range(len(res)):
print(res[i])
page_idx += 1
print(f"page{page_idx}-------------------------")

def test_search_iterator(milvus_client: MilvusClient):
vector_to_search = rng.random((1, DIM), np.float32)
search_iterator = milvus_client.search_iterator(collection_name, data=vector_to_search, batch_size=100, anns_field=PICTURE)

page_idx = 0
while True:
res = search_iterator.next()
if len(res) == 0:
print("query iteration finished, close")
search_iterator.close()
break
for i in range(len(res)):
print(res[i])
page_idx += 1
print(f"page{page_idx}-------------------------")


def main():
milvus_client = MilvusClient("http://localhost:19530")
if milvus_client.has_collection(collection_name) and clean_exist:
milvus_client.drop_collection(collection_name)
print(f"dropped existed collection{collection_name}")

if not milvus_client.has_collection(collection_name):
fields = [
FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False),
FieldSchema(name=AGE, dtype=DataType.INT64),
FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM)
]
schema = CollectionSchema(fields)
milvus_client.create_collection(collection_name, dimension=DIM, schema=schema)

if prepare_new_data:
entities = []
for i in range(NUM_ENTITIES):
entity = {
USER_ID: i,
AGE: (i % 100),
DEPOSIT: float(i),
PICTURE: rng.random((1, DIM))[0]
}
entities.append(entity)
milvus_client.insert(collection_name, entities)
milvus_client.flush(collection_name)
print(f"Finish flush collections:{collection_name}")

index_params = milvus_client.prepare_index_params()

index_params.add_index(
field_name=PICTURE,
index_type='IVF_FLAT',
metric_type='L2',
params={"nlist": 1024}
)
milvus_client.create_index(collection_name, index_params)
milvus_client.load_collection(collection_name)
test_query_iterator(milvus_client=milvus_client)
test_search_iterator(milvus_client=milvus_client)


if __name__ == '__main__':
main()
23 changes: 23 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,26 @@ def is_scipy_sparse(cls, data: Any):
"csr_array",
"spmatrix",
]


def is_sparse_vector_type(data_type: DataType) -> bool:
return data_type == data_type.SPARSE_FLOAT_VECTOR


dense_vector_type_set = {DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR}


def is_dense_vector_type(data_type: DataType) -> bool:
return data_type in dense_vector_type_set


def is_float_vector_type(data_type: DataType):
return is_sparse_vector_type(data_type) or is_dense_vector_type(data_type)


def is_binary_vector_type(data_type: DataType):
return data_type == DataType.BINARY_VECTOR


def is_vector_type(data_type: DataType):
return is_float_vector_type(data_type) or is_binary_vector_type(data_type)
124 changes: 124 additions & 0 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
OmitZeroDict,
construct_cost_extra,
)
from pymilvus.client.utils import is_vector_type
from pymilvus.exceptions import (
DataTypeNotMatchException,
ErrorCode,
MilvusException,
ParamError,
PrimaryKeyException,
)
from pymilvus.orm import utility
from pymilvus.orm.collection import CollectionSchema
from pymilvus.orm.connections import connections
from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED
from pymilvus.orm.iterator import QueryIterator, SearchIterator
from pymilvus.orm.types import DataType

from .index import IndexParams
Expand Down Expand Up @@ -480,6 +484,126 @@ def query(

return res

def query_iterator(
self,
collection_name: str,
batch_size: Optional[int] = 1000,
limit: Optional[int] = UNLIMITED,
filter: Optional[str] = "",
output_fields: Optional[List[str]] = None,
partition_names: Optional[List[str]] = None,
timeout: Optional[float] = None,
**kwargs,
):
if filter is not None and not isinstance(filter, str):
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

conn = self._get_connection()
# set up schema for iterator
try:
schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
except Exception as ex:
logger.error("Failed to describe collection: %s", collection_name)
raise ex from ex

return QueryIterator(
connection=conn,
collection_name=collection_name,
batch_size=batch_size,
limit=limit,
expr=filter,
output_fields=output_fields,
partition_names=partition_names,
schema=schema_dict,
timeout=timeout,
**kwargs,
)

def search_iterator(
self,
collection_name: str,
data: Union[List[list], list],
batch_size: Optional[int] = 1000,
filter: Optional[str] = None,
limit: Optional[int] = UNLIMITED,
output_fields: Optional[List[str]] = None,
search_params: Optional[dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
round_decimal: int = -1,
**kwargs,
):
if filter is not None and not isinstance(filter, str):
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

conn = self._get_connection()
# set up schema for iterator
try:
schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
except Exception as ex:
logger.error("Failed to describe collection: %s", collection_name)
raise ex from ex
# if anns_field is not provided
# if only one vector field, use to search
# if multiple vector fields, raise exception and abort
if anns_field is None or anns_field == "":
vec_field = None
fields = schema_dict[FIELDS]
vec_field_count = 0
for field in fields:
if is_vector_type(field[TYPE]):
vec_field_count += 1
vec_field = field
if vec_field is None:
raise MilvusException(
code=ErrorCode.UNEXPECTED_ERROR,
message="there should be at least one vector field in milvus collection",
)
if vec_field_count > 1:
raise MilvusException(
code=ErrorCode.UNEXPECTED_ERROR,
message="must specify anns_field when there are more than one vector field",
)
anns_field = vec_field["name"]
if anns_field is None or anns_field == "":
raise MilvusException(
code=ErrorCode.UNEXPECTED_ERROR,
message=f"cannot get anns_field name for search iterator, got:{anns_field}",
)
# set up metrics type for search_iterator which is mandatory
if search_params is None:
search_params = {}
if METRIC_TYPE not in search_params:
indexes = conn.list_indexes(collection_name)
for index in indexes:
if anns_field == index.index_name:
params = index.params
for param in params:
if param.key == METRIC_TYPE:
search_params[METRIC_TYPE] = param.value
if METRIC_TYPE not in search_params:
raise MilvusException(
ParamError, f"Cannot set up metrics type for anns_field:{anns_field}"
)

return SearchIterator(
connection=self._get_connection(),
collection_name=collection_name,
data=data,
ann_field=anns_field,
param=search_params,
batch_size=batch_size,
limit=limit,
expr=filter,
partition_names=partition_names,
output_fields=output_fields,
timeout=timeout,
round_decimal=round_decimal,
schema=schema_dict,
**kwargs,
)

def get(
self,
collection_name: str,
Expand Down
1 change: 1 addition & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MILVUS_LIMIT = "limit"
BATCH_SIZE = "batch_size"
ID = "id"
TYPE = "type"
METRIC_TYPE = "metric_type"
PARAMS = "params"
DISTANCE = "distance"
Expand Down
Loading