Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions client/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,34 @@ def test_insert_rejects_invalid_vector(client):
)


def test_insert_batch_success(client, mock_connection):
mock_connection.call.return_value = Mock(
ids=[
Mock(id=Mock(value="p1")),
Mock(id=Mock(value="p2")),
]
)

point_ids = client.insert_batch(
points=[
(DenseVector([1, 2, 3]), Payload.text("hello")),
(DenseVector([4, 5, 6]), Payload.text("world")),
]
)

assert point_ids == ["p1", "p2"]


def test_insert_batch_rejects_invalid_vector(client):
with pytest.raises(TypeError):
client.insert_batch(
points=[
(DenseVector([1, 2, 3]), Payload.text("hello")),
([4, 5, 6], Payload.text("world")),
]
)


# Get

def test_get_point_success(client, mock_connection):
Expand Down Expand Up @@ -118,6 +146,43 @@ def test_search_invalid_vector(client):
)


def test_search_batch_success(client, mock_connection):
mock_connection.call.return_value = Mock(
results=[
Mock(
result_point_ids=[
Mock(id=Mock(value="p1")),
Mock(id=Mock(value="p2")),
]
),
Mock(
result_point_ids=[
Mock(id=Mock(value="p3")),
]
),
]
)

results = client.search_batch(
queries=[
(DenseVector([1, 2, 3]), Similarity.COSINE, 2),
(DenseVector([4, 5, 6]), Similarity.COSINE, 1),
]
)

assert results == [["p1", "p2"], ["p3"]]


def test_search_batch_rejects_invalid_vector(client):
with pytest.raises(TypeError):
client.search_batch(
queries=[
(DenseVector([1, 2, 3]), Similarity.COSINE, 2),
([4, 5, 6], Similarity.COSINE, 1),
]
)


# Close

def test_close_closes_connection(client, mock_connection):
Expand Down
71 changes: 60 additions & 11 deletions client/python/vortexdb/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Sequence

from vortexdb.connection import GRPCConnection
from vortexdb.config import VortexDBConfig
Expand Down Expand Up @@ -38,11 +38,7 @@ def insert(self, *, vector: DenseVector, payload: Payload) -> str:
Insert a vector with payload.
Returns: point_id (str)
"""
if not isinstance(vector, DenseVector):
raise TypeError(
"vector must be a DenseVector. "
"Use: DenseVector([1.0, 2.0, 3.0])"
)
self._validate_dense_vector(vector)

request = proto.build_insert_request(
vector=vector,
Expand All @@ -56,6 +52,29 @@ def insert(self, *, vector: DenseVector, payload: Payload) -> str:

return response.id.value

def insert_batch(
self,
*,
points: Sequence[tuple[DenseVector, Payload]],
) -> List[str]:
"""
Insert multiple vectors with payloads.
Returns: List of point IDs
"""
for vector, _ in points:
self._validate_dense_vector(vector)

request = proto.build_batch_insert_request(
points=list(points),
)

response = self._conn.call(
self._conn.stub.InsertVectorsBatch,
request,
)

return [pid.id.value for pid in response.ids]

def get(self, *, point_id: str) -> Point | None:
"""
Retrieve a point by ID.
Expand Down Expand Up @@ -95,11 +114,7 @@ def search(
Search for nearest neighbors.
Returns: List of point IDs
"""
if not isinstance(vector, DenseVector):
raise TypeError(
"vector must be a DenseVector. "
"Use: DenseVector([1.0, 2.0, 3.0])"
)
self._validate_dense_vector(vector)

request = proto.build_search_request(
vector=vector,
Expand All @@ -114,6 +129,40 @@ def search(

return [pid.id.value for pid in response.result_point_ids]

def search_batch(
self,
*,
queries: Sequence[tuple[DenseVector, Similarity, int]],
) -> List[List[str]]:
"""
Search nearest neighbors for multiple query vectors.
Returns: List of result point ID lists
"""
for vector, _, _ in queries:
self._validate_dense_vector(vector)

request = proto.build_batch_search_request(
queries=list(queries),
)

response = self._conn.call(
self._conn.stub.SearchPointsBatch,
request,
)

return [
[pid.id.value for pid in result.result_point_ids]
for result in response.results
]

@staticmethod
def _validate_dense_vector(vector: DenseVector) -> None:
if not isinstance(vector, DenseVector):
raise TypeError(
"vector must be a DenseVector. "
"Use: DenseVector([1.0, 2.0, 3.0])"
)

def close(self) -> None:
"""
Close the gRPC connection.
Expand Down
28 changes: 18 additions & 10 deletions client/python/vortexdb/grpc/vector_db_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

96 changes: 91 additions & 5 deletions client/python/vortexdb/grpc/vector_db_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import warnings

from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
from vortexdb.grpc import vector_db_pb2 as vector__db__pb2
from . import vector_db_pb2 as vector__db__pb2

GRPC_GENERATED_VERSION = '1.76.0'
GRPC_GENERATED_VERSION = '1.81.1'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

Expand All @@ -26,7 +26,7 @@
)


class VectorDBStub(object):
class VectorDBStub:
"""Missing associated documentation comment in .proto file."""

def __init__(self, channel):
Expand Down Expand Up @@ -55,9 +55,19 @@ def __init__(self, channel):
request_serializer=vector__db__pb2.SearchRequest.SerializeToString,
response_deserializer=vector__db__pb2.SearchResponse.FromString,
_registered_method=True)
self.InsertVectorsBatch = channel.unary_unary(
'/vectordb.VectorDB/InsertVectorsBatch',
request_serializer=vector__db__pb2.InsertVectorsBatchRequest.SerializeToString,
response_deserializer=vector__db__pb2.InsertVectorsBatchResponse.FromString,
_registered_method=True)
self.SearchPointsBatch = channel.unary_unary(
'/vectordb.VectorDB/SearchPointsBatch',
request_serializer=vector__db__pb2.SearchPointsBatchRequest.SerializeToString,
response_deserializer=vector__db__pb2.SearchPointsBatchResponse.FromString,
_registered_method=True)


class VectorDBServicer(object):
class VectorDBServicer:
"""Missing associated documentation comment in .proto file."""

def InsertVector(self, request, context):
Expand Down Expand Up @@ -88,6 +98,18 @@ def SearchPoints(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def InsertVectorsBatch(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SearchPointsBatch(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_VectorDBServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -111,6 +133,16 @@ def add_VectorDBServicer_to_server(servicer, server):
request_deserializer=vector__db__pb2.SearchRequest.FromString,
response_serializer=vector__db__pb2.SearchResponse.SerializeToString,
),
'InsertVectorsBatch': grpc.unary_unary_rpc_method_handler(
servicer.InsertVectorsBatch,
request_deserializer=vector__db__pb2.InsertVectorsBatchRequest.FromString,
response_serializer=vector__db__pb2.InsertVectorsBatchResponse.SerializeToString,
),
'SearchPointsBatch': grpc.unary_unary_rpc_method_handler(
servicer.SearchPointsBatch,
request_deserializer=vector__db__pb2.SearchPointsBatchRequest.FromString,
response_serializer=vector__db__pb2.SearchPointsBatchResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'vectordb.VectorDB', rpc_method_handlers)
Expand All @@ -119,7 +151,7 @@ def add_VectorDBServicer_to_server(servicer, server):


# This class is part of an EXPERIMENTAL API.
class VectorDB(object):
class VectorDB:
"""Missing associated documentation comment in .proto file."""

@staticmethod
Expand Down Expand Up @@ -229,3 +261,57 @@ def SearchPoints(request,
timeout,
metadata,
_registered_method=True)

@staticmethod
def InsertVectorsBatch(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/vectordb.VectorDB/InsertVectorsBatch',
vector__db__pb2.InsertVectorsBatchRequest.SerializeToString,
vector__db__pb2.InsertVectorsBatchResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def SearchPointsBatch(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/vectordb.VectorDB/SearchPointsBatch',
vector__db__pb2.SearchPointsBatchRequest.SerializeToString,
vector__db__pb2.SearchPointsBatchResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
Loading
Loading