Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support binary copy in python SDK. #419

Merged
merged 12 commits into from
Mar 20, 2024
7 changes: 5 additions & 2 deletions .github/workflows/python_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ jobs:
run: |
pdm sync -G :all

- name: Install jq
uses: dcarbone/install-jq-action@v2

- name: Start Postgres
env:
GH_TOKEN: ${{ github.token }}
run: |
LATEST_STABLE_VERSION=$(gh release list --repo tensorchord/pgvecto.rs --exclude-drafts --exclude-pre-releases --limit 1 | awk '{print $3}')
docker run --name pgvecto-rs-demo -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d tensorchord/pgvecto-rs:pg15-${LATEST_STABLE_VERSION}
NIGHTLY_VERSION=$(curl 'https://registry.hub.docker.com/v2/repositories/tensorchord/pgvecto-rs/tags/?ordering=last_updated' | jq '.results[].name' | grep pg15 | head -n 1 | sed 's/"//g')
docker run --name pgvecto-rs-demo -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d tensorchord/pgvecto-rs:${NIGHTLY_VERSION}

- name: Run Tests
working-directory: bindings/python
Expand Down
57 changes: 57 additions & 0 deletions bindings/python/examples/psycopg_copy_dense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os

import numpy as np
import psycopg

from pgvecto_rs.psycopg import register_vector

# Connection URL is assembled from the usual DB_* environment variables.
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
    port=os.getenv("DB_PORT", "5432"),
    host=os.getenv("DB_HOST", "localhost"),
    username=os.getenv("DB_USER", "postgres"),
    password=os.getenv("DB_PASS", "mysecretpassword"),
    db_name=os.getenv("DB_NAME", "postgres"),
)

# Connect, enable the extension, register the vector adapters, and
# (re)create the demo table.
with psycopg.connect(URL) as conn:
    conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
    register_vector(conn)
    conn.execute("DROP TABLE IF EXISTS documents;")
    conn.execute(
        "CREATE TABLE documents (id SERIAL PRIMARY KEY, embedding vector(3) NOT NULL);",
    )
    conn.commit()
    try:
        vectors = [
            np.array([1, 2, 3]),
            np.array([1.0, 2.0, 4.0]),
            np.array([1, 3, 4]),
        ]

        # Stream the embeddings through a binary COPY, one row at a time.
        with conn.cursor() as cursor, cursor.copy(
            "COPY documents (embedding) FROM STDIN (FORMAT BINARY)"
        ) as copy:
            for vector in vectors:
                copy.write_row([vector])
            # Plain Python lists are accepted as well as numpy arrays.
            copy.write_row([[1, 3, 5]])
        conn.commit()

        # Read the rows back, requesting binary results.
        cur = conn.execute(
            "SELECT * FROM documents;",
            binary=True,
        )
        for record in cur.fetchall():
            print(record[0], ": ", record[1])

        # output will be:
        # 1 : [1.0, 2.0, 3.0]
        # 2 : [1.0, 2.0, 4.0]
        # 3 : [1.0, 3.0, 4.0]
        # 4 : [1.0, 3.0, 5.0]
    finally:
        # Drop the table
        conn.execute("DROP TABLE IF EXISTS documents;")
        conn.commit()
52 changes: 52 additions & 0 deletions bindings/python/examples/psycopg_copy_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import numpy as np
import psycopg

from pgvecto_rs.psycopg import register_vector
from pgvecto_rs.types import SparseVector

# Connection URL is assembled from the usual DB_* environment variables.
URL = "postgresql://{username}:{password}@{host}:{port}/{db_name}".format(
    port=os.getenv("DB_PORT", "5432"),
    host=os.getenv("DB_HOST", "localhost"),
    username=os.getenv("DB_USER", "postgres"),
    password=os.getenv("DB_PASS", "mysecretpassword"),
    db_name=os.getenv("DB_NAME", "postgres"),
)


# Connect, enable the extension, register the vector adapters, and
# (re)create the demo table with a sparse-vector column.
with psycopg.connect(URL) as conn:
    conn.execute("CREATE EXTENSION IF NOT EXISTS vectors;")
    register_vector(conn)
    conn.execute("DROP TABLE IF EXISTS documents;")
    conn.execute(
        "CREATE TABLE documents (id SERIAL PRIMARY KEY, embedding svector NOT NULL);",
    )
    conn.commit()
    try:
        # Sparse rows to insert; indices/values may be plain lists or
        # numpy arrays.
        rows = [
            SparseVector(3, [0, 2], [1.0, 3.0]),
            SparseVector(3, np.array([0, 1, 2]), [1.0, 2.0, 3.0]),
            SparseVector(3, np.array([1, 2]), np.array([2.0, 3.0])),
        ]
        with conn.cursor() as cursor:
            with cursor.copy(
                "COPY documents (embedding) FROM STDIN (FORMAT BINARY)"
            ) as copy:
                for row in rows:
                    copy.write_row([row])
        conn.pgconn.flush()
        conn.commit()

        # Read the rows back, requesting binary results.
        cur = conn.execute(
            "SELECT * FROM documents;",
            binary=True,
        )
        for record in cur.fetchall():
            print(record[0], ": ", record[1])

        # output will be:
        # 1 : [1.0, 0.0, 3.0]
        # 2 : [1.0, 2.0, 3.0]
        # 3 : [0.0, 2.0, 3.0]
    finally:
        # Drop the table
        conn.execute("DROP TABLE IF EXISTS documents;")
        conn.commit()
20 changes: 20 additions & 0 deletions bindings/python/src/pgvecto_rs/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import numpy as np


Expand All @@ -23,3 +25,21 @@ def __init__(self) -> None:
class VectorDimensionError(PGVectoRsError):
    """Raised when a vector is declared with a non-positive dimension."""

    def __init__(self, dim: int) -> None:
        message = f"vector dimension must be > 0, got {dim}"
        super().__init__(message)


class SparseVectorTypeError(PGVectoRsError):
    """Raised when a whole SparseVector field has an unsupported type."""

    def __init__(
        self, field: str, expected_type: List[type], actual_type: type
    ) -> None:
        # e.g. "indices in SparseVector must be of type list or ndarray, got str"
        expected = " or ".join(t.__name__ for t in expected_type)
        super().__init__(
            f"{field} in SparseVector must be of type {expected}, got {actual_type.__name__}"
        )


class SparseVectorElementTypeError(PGVectoRsError):
    """Raised when an element inside a SparseVector field has a bad type."""

    def __init__(
        self, field: str, expected_type: List[type], actual_type: type
    ) -> None:
        # e.g. "elements of values in SparseVector must be of type int or float, got str"
        expected = " or ".join(t.__name__ for t in expected_type)
        super().__init__(
            f"elements of {field} in SparseVector must be of type {expected}, got {actual_type.__name__}"
        )
75 changes: 67 additions & 8 deletions bindings/python/src/pgvecto_rs/psycopg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,41 @@
from psycopg.pq import Format
from psycopg.types import TypeInfo

from pgvecto_rs.utils.serializer import from_db_str, to_db_str
from pgvecto_rs.types import SparseVector
from pgvecto_rs.utils.serializer import (
from_db_binary,
from_db_binary_sparse,
from_db_str,
to_db_binary,
to_db_binary_sparse,
to_db_str,
)

__all__ = ["register_vector"]


class VectorDumper(Dumper):
class VectorTextDumper(Dumper):
    """Dump dense vectors (list / ndarray) using the text wire format."""

    format = Format.TEXT

    def dump(self, obj):
        # psycopg expects bytes, so encode the text serialization as UTF-8.
        text = to_db_str(obj)
        return text.encode("utf8")


class VectorLoader(Loader):
class VectorBinaryDumper(Dumper):
    """Dump dense vectors (list / ndarray) using the binary wire format."""

    format = Format.BINARY

    def dump(self, obj):
        # Delegate the binary encoding to the shared serializer helper.
        return to_db_binary(obj)


class SparseVectorTextDumper(Dumper):
    """Dump SparseVector values using the binary wire format.

    NOTE(review): the name says "Text" but ``format`` is ``Format.BINARY``
    and ``dump`` emits the binary encoding — ``SparseVectorBinaryDumper``
    would be clearer; other code references this class by name, so it is
    left unchanged here.
    """

    format = Format.BINARY

    def dump(self, obj):
        # Delegate the binary sparse-vector encoding to the serializer helper.
        return to_db_binary_sparse(obj)


class VectorTextLoader(Loader):
format = Format.TEXT

def load(self, data):
Expand All @@ -25,25 +47,62 @@ def load(self, data):
return from_db_str(data.decode("utf8"))


class VectorBinaryLoader(Loader):
    """Load `vector` values sent by the server in binary format."""

    format = Format.BINARY

    def load(self, data):
        # Decode the binary payload via the shared serializer helper.
        vector = from_db_binary(data)
        return vector


class SparseVectorBinaryLoader(Loader):
    """Load sparse `svector` values sent by the server in binary format."""

    format = Format.BINARY

    def load(self, data):
        # Decode the binary payload via the shared serializer helper.
        sparse = from_db_binary_sparse(data)
        return sparse


def register_vector(context: Connection):
    """Register dumpers/loaders for both `vector` and `svector` types.

    Looks up each type's OID on the server, then installs the matching
    adapters on *context*.
    """
    vector_info = TypeInfo.fetch(context, "vector")
    register_vector_info(context, vector_info)
    svector_info = TypeInfo.fetch(context, "svector")
    register_svector_info(context, svector_info)


async def register_vector_async(context: Connection):
    """Async counterpart of register_vector.

    NOTE(review): the annotation says ``Connection``; presumably an async
    connection is expected here since ``TypeInfo.fetch`` is awaited —
    confirm against psycopg's async API.
    """
    vector_info = await TypeInfo.fetch(context, "vector")
    register_vector_info(context, vector_info)
    svector_info = await TypeInfo.fetch(context, "svector")
    register_svector_info(context, svector_info)


def register_vector_info(context: Connection, info: TypeInfo):
    """Install text and binary adapters for the dense `vector` type.

    Args:
        context: connection (or other adapt context) to register on.
        info: TypeInfo fetched for the server-side `vector` type.

    Raises:
        ProgrammingError: if *info* is None (extension type not found).
    """
    if info is None:
        # Fix: the message must be the positional argument; psycopg's
        # `info=` keyword expects a result/diagnostic object, so passing the
        # string there loses the error text.
        raise ProgrammingError("vector type not found in the database")
    info.register(context)

    # Dumper subclasses for text and binary, bound to the server OID.
    vector_text_dumper = type("", (VectorTextDumper,), {"oid": info.oid})
    vector_binary_dumper = type("", (VectorBinaryDumper,), {"oid": info.oid})

    # Register the dumper and loader
    adapters = context.adapters
    adapters.register_dumper(list, vector_text_dumper)
    adapters.register_dumper(ndarray, vector_text_dumper)
    adapters.register_dumper(list, vector_binary_dumper)
    adapters.register_dumper(ndarray, vector_binary_dumper)
    adapters.register_loader(info.oid, VectorTextLoader)
    adapters.register_loader(info.oid, VectorBinaryLoader)


def register_svector_info(context: Connection, info: TypeInfo):
    """Install binary adapters for the sparse `svector` type.

    Args:
        context: connection (or other adapt context) to register on.
        info: TypeInfo fetched for the server-side `svector` type.

    Raises:
        ProgrammingError: if *info* is None (extension type not found).
    """
    if info is None:
        # Fix: the message must be the positional argument; psycopg's
        # `info=` keyword expects a result/diagnostic object, so passing the
        # string there loses the error text.
        raise ProgrammingError("svector type not found in the database")
    info.register(context)

    # Dumper for binary
    svector_binary_dumper = type("", (SparseVectorTextDumper,), {"oid": info.oid})

    # Register the dumper and loader
    adapters = context.adapters
    adapters.register_dumper(SparseVector, svector_binary_dumper)
    adapters.register_loader(info.oid, SparseVectorBinaryLoader)
14 changes: 14 additions & 0 deletions bindings/python/src/pgvecto_rs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import collections

# A sparse vector: total dimension count plus parallel index/value sequences.
SparseVector = collections.namedtuple("SparseVector", "dims indices values")


def print_sparse_vector(sparse_vector):
    """Return the dense string form of *sparse_vector*, e.g. "[1.0, 0.0, 3.0]"."""
    dense = [0.0] * sparse_vector.dims
    for position, value in zip(sparse_vector.indices, sparse_vector.values):
        dense[position] = value
    return str(dense)


# Make str(SparseVector(...)) display the dense representation.
SparseVector.__str__ = print_sparse_vector
66 changes: 66 additions & 0 deletions bindings/python/src/pgvecto_rs/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
BuiltinListTypeError,
NDArrayDimensionError,
NDArrayDtypeError,
SparseVectorElementTypeError,
SparseVectorTypeError,
)
from pgvecto_rs.types import SparseVector


def ignore_none(func):
Expand Down Expand Up @@ -52,3 +55,66 @@ def _func(value: list, *args, **kwargs):
return func(value, *args, **kwargs)

return _func


def check_indices(indices) -> np.ndarray:
    """Validate sparse-vector indices and return them as an ndarray.

    Lists are converted to a uint32 ndarray; integer ndarrays are returned
    unchanged. Raises a SparseVector*Error (or NDArrayDimensionError) on
    any type/shape violation.
    """
    if isinstance(indices, np.ndarray):
        # must be a flat (1-D) array of integers
        if indices.ndim != 1:
            raise NDArrayDimensionError(indices.ndim)
        if not np.issubdtype(indices.dtype, np.integer):
            raise SparseVectorElementTypeError(
                "indices", [int, np.integer], indices.dtype
            )
        return indices
    if isinstance(indices, list):
        # every element must be a plain int before conversion
        for element in indices:
            if not isinstance(element, int):
                raise SparseVectorElementTypeError(
                    "indices", [int, np.integer], type(element)
                )
        return np.array(indices, dtype=np.uint32)
    raise SparseVectorTypeError("indices", [list, np.ndarray], type(indices))


def check_values(values) -> np.ndarray:
    """Validate sparse-vector values and return them as an ndarray.

    Lists are converted to a float32 ndarray; numeric ndarrays are returned
    unchanged. Raises a SparseVector*Error (or NDArrayDimensionError) on
    any type/shape violation.
    """
    if isinstance(values, np.ndarray):
        # must be a flat (1-D) numeric array
        if values.ndim != 1:
            raise NDArrayDimensionError(values.ndim)
        if not np.issubdtype(values.dtype, np.number):
            raise SparseVectorElementTypeError(
                "values", [int, float, np.number], values.dtype
            )
        return values
    if isinstance(values, list):
        # every element must be a plain int/float before conversion
        for element in values:
            if not isinstance(element, (int, float)):
                raise SparseVectorElementTypeError(
                    "values", [int, float, np.number], type(element)
                )
        return np.array(values, dtype=np.float32)
    raise SparseVectorTypeError("values", [list, np.ndarray], type(values))


def validate_sparse_vector(func):
    """Validate sparse vector data type"""

    @wraps(func)
    def _func(vector: SparseVector, *args, **kwargs):
        # Non-SparseVector inputs pass through untouched.
        if not isinstance(vector, SparseVector):
            return func(vector, *args, **kwargs)
        dims, indices, values = vector
        if not isinstance(dims, int):
            raise SparseVectorTypeError("dims", [int], type(dims))
        # Normalize indices/values to ndarrays before calling through.
        checked = SparseVector(dims, check_indices(indices), check_values(values))
        return func(checked, *args, **kwargs)

    return _func
Loading
Loading