diff --git a/client/python/tests/test_client.py b/client/python/tests/test_client.py index a752320..1fb4c5b 100644 --- a/client/python/tests/test_client.py +++ b/client/python/tests/test_client.py @@ -54,6 +54,34 @@ def test_insert_rejects_invalid_vector(client): ) +def test_insert_batch_success(client, mock_connection): + mock_connection.call.return_value = Mock( + ids=[ + Mock(id=Mock(value="p1")), + Mock(id=Mock(value="p2")), + ] + ) + + point_ids = client.insert_batch( + points=[ + (DenseVector([1, 2, 3]), Payload.text("hello")), + (DenseVector([4, 5, 6]), Payload.text("world")), + ] + ) + + assert point_ids == ["p1", "p2"] + + +def test_insert_batch_rejects_invalid_vector(client): + with pytest.raises(TypeError): + client.insert_batch( + points=[ + (DenseVector([1, 2, 3]), Payload.text("hello")), + ([4, 5, 6], Payload.text("world")), + ] + ) + + # Get def test_get_point_success(client, mock_connection): @@ -118,6 +146,43 @@ def test_search_invalid_vector(client): ) +def test_search_batch_success(client, mock_connection): + mock_connection.call.return_value = Mock( + results=[ + Mock( + result_point_ids=[ + Mock(id=Mock(value="p1")), + Mock(id=Mock(value="p2")), + ] + ), + Mock( + result_point_ids=[ + Mock(id=Mock(value="p3")), + ] + ), + ] + ) + + results = client.search_batch( + queries=[ + (DenseVector([1, 2, 3]), Similarity.COSINE, 2), + (DenseVector([4, 5, 6]), Similarity.COSINE, 1), + ] + ) + + assert results == [["p1", "p2"], ["p3"]] + + +def test_search_batch_rejects_invalid_vector(client): + with pytest.raises(TypeError): + client.search_batch( + queries=[ + (DenseVector([1, 2, 3]), Similarity.COSINE, 2), + ([4, 5, 6], Similarity.COSINE, 1), + ] + ) + + # Close def test_close_closes_connection(client, mock_connection): diff --git a/client/python/vortexdb/client.py b/client/python/vortexdb/client.py index 38e0553..a73a455 100644 --- a/client/python/vortexdb/client.py +++ b/client/python/vortexdb/client.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Sequence from vortexdb.connection import GRPCConnection from vortexdb.config import VortexDBConfig @@ -38,11 +38,7 @@ def insert(self, *, vector: DenseVector, payload: Payload) -> str: Insert a vector with payload. Returns: point_id (str) """ - if not isinstance(vector, DenseVector): - raise TypeError( - "vector must be a DenseVector. " - "Use: DenseVector([1.0, 2.0, 3.0])" - ) + self._validate_dense_vector(vector) request = proto.build_insert_request( vector=vector, @@ -56,6 +52,29 @@ def insert(self, *, vector: DenseVector, payload: Payload) -> str: return response.id.value + def insert_batch( + self, + *, + points: Sequence[tuple[DenseVector, Payload]], + ) -> List[str]: + """ + Insert multiple vectors with payloads. + Returns: List of point IDs + """ + for vector, _ in points: + self._validate_dense_vector(vector) + + request = proto.build_batch_insert_request( + points=list(points), + ) + + response = self._conn.call( + self._conn.stub.InsertVectorsBatch, + request, + ) + + return [pid.id.value for pid in response.ids] + def get(self, *, point_id: str) -> Point | None: """ Retrieve a point by ID. @@ -95,11 +114,7 @@ def search( Search for nearest neighbors. Returns: List of point IDs """ - if not isinstance(vector, DenseVector): - raise TypeError( - "vector must be a DenseVector. " - "Use: DenseVector([1.0, 2.0, 3.0])" - ) + self._validate_dense_vector(vector) request = proto.build_search_request( vector=vector, @@ -114,6 +129,40 @@ def search( return [pid.id.value for pid in response.result_point_ids] + def search_batch( + self, + *, + queries: Sequence[tuple[DenseVector, Similarity, int]], + ) -> List[List[str]]: + """ + Search nearest neighbors for multiple query vectors. + Returns: List of result point ID lists + """ + for vector, _, _ in queries: + self._validate_dense_vector(vector) + + request = proto.build_batch_search_request( + queries=list(queries), + ) + + response = self._conn.call( + self._conn.stub.SearchPointsBatch, + request, + ) + + return [ + [pid.id.value for pid in result.result_point_ids] + for result in response.results + ] + + @staticmethod + def _validate_dense_vector(vector: DenseVector) -> None: + if not isinstance(vector, DenseVector): + raise TypeError( + "vector must be a DenseVector. " + "Use: DenseVector([1.0, 2.0, 3.0])" + ) + def close(self) -> None: """ Close the gRPC connection. diff --git a/client/python/vortexdb/grpc/vector_db_pb2.py b/client/python/vortexdb/grpc/vector_db_pb2.py index 2b8cbb8..91d3021 100644 --- a/client/python/vortexdb/grpc/vector_db_pb2.py +++ b/client/python/vortexdb/grpc/vector_db_pb2.py @@ -2,7 +2,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: vector-db.proto -# Protobuf Python Version: 6.31.1 +# Protobuf Python Version: 6.33.5 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -12,8 +12,8 @@ _runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, 6, - 31, - 1, + 33, + 5, '', 'vector-db.proto' ) @@ -25,17 +25,17 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fvector-db.proto\x12\x08vectordb\x1a\x1bgoogle/protobuf/empty.proto\"\x15\n\x04UUID\x12\r\n\x05value\x18\x01 \x01(\t\"`\n\x13InsertVectorRequest\x12%\n\x06vector\x18\x01 \x01(\x0b\x32\x15.vectordb.DenseVector\x12\"\n\x07payload\x18\x02 \x01(\x0b\x32\x11.vectordb.Payload\"u\n\rSearchRequest\x12+\n\x0cquery_vector\x18\x01 \x01(\x0b\x32\x15.vectordb.DenseVector\x12(\n\nsimilarity\x18\x02 \x01(\x0e\x32\x14.vectordb.Similarity\x12\r\n\x05limit\x18\x03 \x01(\x04\"=\n\x0eSearchResponse\x12+\n\x10result_point_ids\x18\x01 \x03(\x0b\x32\x11.vectordb.PointID\"\x1d\n\x0b\x44\x65nseVector\x12\x0e\n\x06values\x18\x01 \x03(\x02\"q\n\x05Point\x12\x1d\n\x02id\x18\x01 \x01(\x0b\x32\x11.vectordb.PointID\x12\"\n\x07payload\x18\x02 \x01(\x0b\x32\x11.vectordb.Payload\x12%\n\x06vector\x18\x03 \x01(\x0b\x32\x15.vectordb.DenseVector\"%\n\x07PointID\x12\x1a\n\x02id\x18\x01 \x01(\x0b\x32\x0e.vectordb.UUID\"G\n\x07Payload\x12+\n\x0c\x63ontent_type\x18\x01 \x01(\x0e\x32\x15.vectordb.ContentType\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t*C\n\nSimilarity\x12\r\n\tEuclidean\x10\x00\x12\r\n\tManhattan\x10\x01\x12\x0b\n\x07Hamming\x10\x02\x12\n\n\x06\x43osine\x10\x03*\"\n\x0b\x43ontentType\x12\t\n\x05Image\x10\x00\x12\x08\n\x04Text\x10\x01\x32\x81\x02\n\x08VectorDB\x12\x42\n\x0cInsertVector\x12\x1d.vectordb.InsertVectorRequest\x1a\x11.vectordb.PointID\"\x00\x12:\n\x0b\x44\x65letePoint\x12\x11.vectordb.PointID\x1a\x16.google.protobuf.Empty\"\x00\x12\x30\n\x08GetPoint\x12\x11.vectordb.PointID\x1a\x0f.vectordb.Point\"\x00\x12\x43\n\x0cSearchPoints\x12\x17.vectordb.SearchRequest\x1a\x18.vectordb.SearchResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fvector-db.proto\x12\x08vectordb\x1a\x1bgoogle/protobuf/empty.proto\"\x15\n\x04UUID\x12\r\n\x05value\x18\x01 \x01(\t\"`\n\x13InsertVectorRequest\x12%\n\x06vector\x18\x01 \x01(\x0b\x32\x15.vectordb.DenseVector\x12\"\n\x07payload\x18\x02 \x01(\x0b\x32\x11.vectordb.Payload\"u\n\rSearchRequest\x12+\n\x0cquery_vector\x18\x01 \x01(\x0b\x32\x15.vectordb.DenseVector\x12(\n\nsimilarity\x18\x02 \x01(\x0e\x32\x14.vectordb.Similarity\x12\r\n\x05limit\x18\x03 \x01(\x04\"=\n\x0eSearchResponse\x12+\n\x10result_point_ids\x18\x01 \x03(\x0b\x32\x11.vectordb.PointID\"\x1d\n\x0b\x44\x65nseVector\x12\x0e\n\x06values\x18\x01 \x03(\x02\"q\n\x05Point\x12\x1d\n\x02id\x18\x01 \x01(\x0b\x32\x11.vectordb.PointID\x12\"\n\x07payload\x18\x02 \x01(\x0b\x32\x11.vectordb.Payload\x12%\n\x06vector\x18\x03 \x01(\x0b\x32\x15.vectordb.DenseVector\"%\n\x07PointID\x12\x1a\n\x02id\x18\x01 \x01(\x0b\x32\x0e.vectordb.UUID\"G\n\x07Payload\x12+\n\x0c\x63ontent_type\x18\x01 \x01(\x0e\x32\x15.vectordb.ContentType\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\"K\n\x19InsertVectorsBatchRequest\x12.\n\x07vectors\x18\x01 \x03(\x0b\x32\x1d.vectordb.InsertVectorRequest\"<\n\x1aInsertVectorsBatchResponse\x12\x1e\n\x03ids\x18\x01 \x03(\x0b\x32\x11.vectordb.PointID\"D\n\x18SearchPointsBatchRequest\x12(\n\x07queries\x18\x01 \x03(\x0b\x32\x17.vectordb.SearchRequest\"F\n\x19SearchPointsBatchResponse\x12)\n\x07results\x18\x01 \x03(\x0b\x32\x18.vectordb.SearchResponse*C\n\nSimilarity\x12\r\n\tEuclidean\x10\x00\x12\r\n\tManhattan\x10\x01\x12\x0b\n\x07Hamming\x10\x02\x12\n\n\x06\x43osine\x10\x03*\"\n\x0b\x43ontentType\x12\t\n\x05Image\x10\x00\x12\x08\n\x04Text\x10\x01\x32\xc4\x03\n\x08VectorDB\x12\x42\n\x0cInsertVector\x12\x1d.vectordb.InsertVectorRequest\x1a\x11.vectordb.PointID\"\x00\x12:\n\x0b\x44\x65letePoint\x12\x11.vectordb.PointID\x1a\x16.google.protobuf.Empty\"\x00\x12\x30\n\x08GetPoint\x12\x11.vectordb.PointID\x1a\x0f.vectordb.Point\"\x00\x12\x43\n\x0cSearchPoints\x12\x17.vectordb.SearchRequest\x1a\x18.vectordb.SearchResponse\"\x00\x12\x61\n\x12InsertVectorsBatch\x12#.vectordb.InsertVectorsBatchRequest\x1a$.vectordb.InsertVectorsBatchResponse\"\x00\x12^\n\x11SearchPointsBatch\x12\".vectordb.SearchPointsBatchRequest\x1a#.vectordb.SearchPointsBatchResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'vector_db_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_SIMILARITY']._serialized_start=619 - _globals['_SIMILARITY']._serialized_end=686 - _globals['_CONTENTTYPE']._serialized_start=688 - _globals['_CONTENTTYPE']._serialized_end=722 + _globals['_SIMILARITY']._serialized_start=900 + _globals['_SIMILARITY']._serialized_end=967 + _globals['_CONTENTTYPE']._serialized_start=969 + _globals['_CONTENTTYPE']._serialized_end=1003 _globals['_UUID']._serialized_start=58 _globals['_UUID']._serialized_end=79 _globals['_INSERTVECTORREQUEST']._serialized_start=81 @@ -52,6 +52,14 @@ _globals['_POINTID']._serialized_end=544 _globals['_PAYLOAD']._serialized_start=546 _globals['_PAYLOAD']._serialized_end=617 - _globals['_VECTORDB']._serialized_start=725 - _globals['_VECTORDB']._serialized_end=982 + _globals['_INSERTVECTORSBATCHREQUEST']._serialized_start=619 + _globals['_INSERTVECTORSBATCHREQUEST']._serialized_end=694 + _globals['_INSERTVECTORSBATCHRESPONSE']._serialized_start=696 + _globals['_INSERTVECTORSBATCHRESPONSE']._serialized_end=756 + _globals['_SEARCHPOINTSBATCHREQUEST']._serialized_start=758 + _globals['_SEARCHPOINTSBATCHREQUEST']._serialized_end=826 + _globals['_SEARCHPOINTSBATCHRESPONSE']._serialized_start=828 + _globals['_SEARCHPOINTSBATCHRESPONSE']._serialized_end=898 + _globals['_VECTORDB']._serialized_start=1006 + _globals['_VECTORDB']._serialized_end=1458 # @@protoc_insertion_point(module_scope) diff --git a/client/python/vortexdb/grpc/vector_db_pb2_grpc.py b/client/python/vortexdb/grpc/vector_db_pb2_grpc.py index edc3c8f..1442f69 100644 --- a/client/python/vortexdb/grpc/vector_db_pb2_grpc.py +++ b/client/python/vortexdb/grpc/vector_db_pb2_grpc.py @@ -4,9 +4,9 @@ import warnings from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from vortexdb.grpc import vector_db_pb2 as vector__db__pb2 +from . import vector_db_pb2 as vector__db__pb2 -GRPC_GENERATED_VERSION = '1.76.0' +GRPC_GENERATED_VERSION = '1.81.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -26,7 +26,7 @@ ) -class VectorDBStub(object): +class VectorDBStub: """Missing associated documentation comment in .proto file.""" def __init__(self, channel): @@ -55,9 +55,19 @@ def __init__(self, channel): request_serializer=vector__db__pb2.SearchRequest.SerializeToString, response_deserializer=vector__db__pb2.SearchResponse.FromString, _registered_method=True) + self.InsertVectorsBatch = channel.unary_unary( + '/vectordb.VectorDB/InsertVectorsBatch', + request_serializer=vector__db__pb2.InsertVectorsBatchRequest.SerializeToString, + response_deserializer=vector__db__pb2.InsertVectorsBatchResponse.FromString, + _registered_method=True) + self.SearchPointsBatch = channel.unary_unary( + '/vectordb.VectorDB/SearchPointsBatch', + request_serializer=vector__db__pb2.SearchPointsBatchRequest.SerializeToString, + response_deserializer=vector__db__pb2.SearchPointsBatchResponse.FromString, + _registered_method=True) -class VectorDBServicer(object): +class VectorDBServicer: """Missing associated documentation comment in .proto file.""" def InsertVector(self, request, context): @@ -88,6 +98,18 @@ def SearchPoints(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def InsertVectorsBatch(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SearchPointsBatch(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_VectorDBServicer_to_server(servicer, server): rpc_method_handlers = { @@ -111,6 +133,16 @@ def add_VectorDBServicer_to_server(servicer, server): request_deserializer=vector__db__pb2.SearchRequest.FromString, response_serializer=vector__db__pb2.SearchResponse.SerializeToString, ), + 'InsertVectorsBatch': grpc.unary_unary_rpc_method_handler( + servicer.InsertVectorsBatch, + request_deserializer=vector__db__pb2.InsertVectorsBatchRequest.FromString, + response_serializer=vector__db__pb2.InsertVectorsBatchResponse.SerializeToString, + ), + 'SearchPointsBatch': grpc.unary_unary_rpc_method_handler( + servicer.SearchPointsBatch, + request_deserializer=vector__db__pb2.SearchPointsBatchRequest.FromString, + response_serializer=vector__db__pb2.SearchPointsBatchResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'vectordb.VectorDB', rpc_method_handlers) @@ -119,7 +151,7 @@ def add_VectorDBServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. -class VectorDB(object): +class VectorDB: """Missing associated documentation comment in .proto file.""" @staticmethod @@ -229,3 +261,57 @@ def SearchPoints(request, timeout, metadata, _registered_method=True) + + @staticmethod + def InsertVectorsBatch(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/vectordb.VectorDB/InsertVectorsBatch', + vector__db__pb2.InsertVectorsBatchRequest.SerializeToString, + vector__db__pb2.InsertVectorsBatchResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SearchPointsBatch(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/vectordb.VectorDB/SearchPointsBatch', + vector__db__pb2.SearchPointsBatchRequest.SerializeToString, + vector__db__pb2.SearchPointsBatchResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/client/python/vortexdb/protoutils.py b/client/python/vortexdb/protoutils.py index 8562b18..48e6ffb 100644 --- a/client/python/vortexdb/protoutils.py +++ b/client/python/vortexdb/protoutils.py @@ -1,6 +1,7 @@ from vortexdb.grpc import vector_db_pb2 from vortexdb.models import DenseVector, Payload, Similarity + def build_insert_request( *, vector: DenseVector, @@ -11,11 +12,25 @@ def build_insert_request( payload=payload.to_proto(), ) + +def build_batch_insert_request( + *, + points: list[tuple[DenseVector, Payload]], +) -> vector_db_pb2.InsertVectorsBatchRequest: + return vector_db_pb2.InsertVectorsBatchRequest( + vectors=[ + build_insert_request(vector=vector, payload=payload) + for vector, payload in points + ] + ) + + def build_point_id_request(point_id: str) -> vector_db_pb2.PointID: return vector_db_pb2.PointID( id=vector_db_pb2.UUID(value=point_id) ) + def build_search_request( *, vector: DenseVector, @@ -27,3 +42,15 @@ def build_search_request( similarity=similarity.to_proto(), limit=limit, ) + + +def build_batch_search_request( + *, + queries: list[tuple[DenseVector, Similarity, int]], +) -> vector_db_pb2.SearchPointsBatchRequest: + return vector_db_pb2.SearchPointsBatchRequest( + queries=[ + build_search_request(vector=vector, similarity=similarity, limit=limit) + for vector, similarity, limit in queries + ] + ) diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 5a94828..0a8f548 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,7 +1,7 @@ use defs::{DbError, Dimension, IndexedVector, SearchQueryInput, Similarity, SnapshottableDb}; use defs::{DenseVector, Payload, Point, PointId, PointInput}; -use index::hnsw::HnswIndex; -use index::kd_tree::index::KDTree; +use index::hnsw::{HnswConfig, HnswIndex}; +use index::kd_tree::{KDTree, KDTreeConfig}; use std::path::{Path, PathBuf}; use tempfile::tempdir; // use std::sync::atomic::{AtomicU64, Ordering}; @@ -225,6 +225,8 @@ pub struct DbConfig { pub data_path: PathBuf, pub dimension: Dimension, pub similarity: Similarity, + pub hnsw_config: HnswConfig, + pub kd_tree_config: KDTreeConfig, } #[derive(Debug)] @@ -256,10 +258,14 @@ pub fn init_api(config: DbConfig) -> Result { // Initialize the vector index let index: Arc> = match config.index_type { IndexType::Flat => Arc::new(RwLock::new(FlatIndex::new())), - IndexType::KDTree => Arc::new(RwLock::new(KDTree::build_empty(config.dimension))), - IndexType::HNSW => Arc::new(RwLock::new(HnswIndex::new( + IndexType::KDTree => Arc::new(RwLock::new(KDTree::build_empty_with_config( + config.dimension, + config.kd_tree_config, + ))), + IndexType::HNSW => Arc::new(RwLock::new(HnswIndex::with_config( config.similarity, config.dimension, + config.hnsw_config, ))), }; @@ -297,6 +303,8 @@ mod tests { data_path: temp_dir.path().to_path_buf(), dimension: 3, similarity: Similarity::Cosine, + hnsw_config: HnswConfig::default(), + kd_tree_config: KDTreeConfig::default(), }; (init_api(config).unwrap(), temp_dir) } diff --git a/crates/grpc/src/tests.rs b/crates/grpc/src/tests.rs index ff04820..5b00870 100644 --- a/crates/grpc/src/tests.rs +++ b/crates/grpc/src/tests.rs @@ -7,7 +7,7 @@ use crate::service::{VectorDBService, run_server}; use crate::utils::ServerEndpoint; use api::DbConfig; use defs::Similarity; -use index::IndexType; +use index::{IndexType, hnsw::HnswConfig, kd_tree::KDTreeConfig}; use std::net::SocketAddr; use std::sync::Arc; use storage::StorageType; @@ -35,6 +35,8 @@ async fn start_test_server() -> Result<(SocketAddr, TempDir), Box Self { + let max_connections = 16; + Self { + max_connections, + max_connections_0: 2 * max_connections, + max_layer: 16, + ef_construction: 200, + ef: 100, + } + } +} + pub struct HnswIndex { // Construction/search parameters pub ef_construction: usize, @@ -26,11 +48,19 @@ pub struct HnswIndex { impl HnswIndex { pub fn new(similarity: Similarity, data_dimension: Dimension) -> Self { - let max_connections = 16; - let max_connections_0 = 32; // M0 = 2 * M (common default) - let max_layer = 16; - let ef_construction = 200; - let ef = 100; + Self::with_config(similarity, data_dimension, HnswConfig::default()) + } + + pub fn with_config( + similarity: Similarity, + data_dimension: Dimension, + config: HnswConfig, + ) -> Self { + let max_connections = config.max_connections.max(2); + let max_connections_0 = config.max_connections_0.max(max_connections); + let max_layer = config.max_layer.max(1); + let ef_construction = config.ef_construction.max(1); + let ef = config.ef.max(1); let level_generator = LevelGenerator::from_m(max_connections); let index = PointIndexation { @@ -83,7 +113,7 @@ impl VectorIndex for HnswIndex { let new_id: PointId = vector.id; - let mut query_vec = vector.vector.clone(); + let mut query_vec = vector.vector; self.normalize_if_cosine(&mut query_vec); self.cache.insert(new_id, query_vec.clone()); @@ -293,4 +323,23 @@ impl HnswIndex { } } } + + pub(super) fn distance(&self, a: &[f32], b: &[f32]) -> f32 { + debug_assert_eq!(a.len(), b.len()); + match self.similarity { + Similarity::Euclidean => a + .iter() + .zip(b.iter()) + .map(|(&x, &y)| { + let d = x - y; + d * d + }) + .sum(), + Similarity::Cosine => { + let dot = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum::(); + 1.0 - dot + } + Similarity::Manhattan | Similarity::Hamming => distance(a, b, self.similarity), + } + } } diff --git a/crates/index/src/hnsw/mod.rs b/crates/index/src/hnsw/mod.rs index 0cedc17..129985a 100644 --- a/crates/index/src/hnsw/mod.rs +++ b/crates/index/src/hnsw/mod.rs @@ -6,7 +6,7 @@ pub mod search; pub mod serialize; pub mod types; use defs::Magic; -pub use index::HnswIndex; +pub use index::{HnswConfig, HnswIndex}; pub const HNSW_MAGIC_BYTES: Magic = [0x02, 0x01, 0x03, 0x00]; diff --git a/crates/index/src/hnsw/search.rs b/crates/index/src/hnsw/search.rs index bdc6640..8be71db 100644 --- a/crates/index/src/hnsw/search.rs +++ b/crates/index/src/hnsw/search.rs @@ -5,7 +5,6 @@ use std::collections::HashSet; use defs::{OrdF32, PointId}; use crate::Result; -use crate::distance; use super::index::HnswIndex; @@ -23,7 +22,7 @@ impl HnswIndex { let mut current = ep; loop { let cur_vec = self.get_vec(current)?; - let mut best_score = distance(query, cur_vec, self.similarity); + let mut best_score = self.distance(query, cur_vec); let mut best_id = current; let empty: &[PointId] = &[]; @@ -46,7 +45,7 @@ impl HnswIndex { continue; } let n_vec = self.get_vec(n)?; - let score = distance(query, n_vec, self.similarity); + let score = self.distance(query, n_vec); if score < best_score { best_score = score; best_id = n; @@ -91,7 +90,7 @@ impl HnswIndex { .unwrap_or(ep), }; - let ep_score = distance(query, self.get_vec(seed)?, self.similarity); + let ep_score = self.distance(query, self.get_vec(seed)?); candidates.push((Reverse(OrdF32::new(ep_score)), seed)); w_heap.push((OrdF32::new(ep_score), seed)); visited.insert(seed); @@ -114,7 +113,7 @@ impl HnswIndex { .unwrap_or(empty); for &n in neighbors { - if visited.contains(&n) { + if !visited.insert(n) { continue; } // Skip deleted neighbors @@ -124,8 +123,7 @@ impl HnswIndex { continue; } - visited.insert(n); - let score = distance(query, self.get_vec(n)?, self.similarity); + let score = self.distance(query, self.get_vec(n)?); let score = OrdF32::new(score); candidates.push((Reverse(score), n)); if w_heap.len() < ef_construction { @@ -168,7 +166,7 @@ impl HnswIndex { let cand_vec = self.get_vec(cand_id)?; for &r_id in &result { let r_vec = self.get_vec(r_id)?; - let cand_to_r = distance(cand_vec, r_vec, self.similarity); + let cand_to_r = self.distance(cand_vec, r_vec); if cand_to_r < cand_dist_to_q { continue 'outer; } @@ -273,7 +271,7 @@ impl HnswIndex { let mut scored: Vec<(PointId, f32)> = Vec::with_capacity(merged.len()); for nid in merged { - let d = distance(center_vec, self.get_vec(nid)?, self.similarity); + let d = self.distance(center_vec, self.get_vec(nid)?); scored.push((nid, d)); } scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); diff --git a/crates/index/src/kd_tree/helpers.rs b/crates/index/src/kd_tree/helpers.rs index 8934ae7..e7ffe11 100644 --- a/crates/index/src/kd_tree/helpers.rs +++ b/crates/index/src/kd_tree/helpers.rs @@ -4,16 +4,13 @@ use super::types::KDTreeNode; use defs::IndexedVector; impl KDTree { - pub const BALANCE_THRESHOLD: f32 = 0.7; - pub const DELETE_REBUILD_RATIO: f32 = 0.25; - /// Checks if a node is unbalanced based on the balance threshold - pub fn is_unbalanced(node: &KDTreeNode) -> bool { + pub fn is_unbalanced(&self, node: &KDTreeNode) -> bool { let left_size = node.left.as_ref().map_or(0, |n| n.subtree_size); let right_size = node.right.as_ref().map_or(0, |n| n.subtree_size); let max_child = left_size.max(right_size); - max_child as f32 > Self::BALANCE_THRESHOLD * node.subtree_size as f32 + max_child as f32 > self.config.balance_threshold * node.subtree_size as f32 } /// Recursively collects non-deleted vectors from the tree @@ -37,7 +34,9 @@ impl KDTree { } /// Checks if the tree should be globally rebuilt based on deletion ratio - pub fn should_rebuild_global(total_nodes: usize, deleted_count: usize) -> bool { - total_nodes > 0 && (deleted_count as f32 / total_nodes as f32) > Self::DELETE_REBUILD_RATIO + pub fn should_rebuild_global(&self) -> bool { + self.total_nodes > 0 + && (self.deleted_count as f32 / self.total_nodes as f32) + > self.config.delete_rebuild_ratio } } diff --git a/crates/index/src/kd_tree/index.rs b/crates/index/src/kd_tree/index.rs index bf4ed7a..31b6c6e 100644 --- a/crates/index/src/kd_tree/index.rs +++ b/crates/index/src/kd_tree/index.rs @@ -9,6 +9,21 @@ use std::{ }; use uuid::Uuid; +#[derive(Debug, Clone, Copy)] +pub struct KDTreeConfig { + pub balance_threshold: f32, + pub delete_rebuild_ratio: f32, +} + +impl Default for KDTreeConfig { + fn default() -> Self { + Self { + balance_threshold: 0.7, + delete_rebuild_ratio: 0.25, + } + } +} + pub struct KDTree { pub dim: usize, pub root: Option>, @@ -17,22 +32,35 @@ pub struct KDTree { // Rebuild tracking pub total_nodes: usize, pub deleted_count: usize, + pub config: KDTreeConfig, } impl KDTree { // Build an empty index with no points pub fn build_empty(dim: usize) -> Self { + Self::build_empty_with_config(dim, KDTreeConfig::default()) + } + + pub fn build_empty_with_config(dim: usize, config: KDTreeConfig) -> Self { KDTree { dim, root: None, point_ids: HashSet::new(), total_nodes: 0, deleted_count: 0, + config: Self::sanitize_config(config), } } // Builds the vector index from provided vectors, there should atleast be single vector for dim calculation - pub fn build(mut vectors: Vec) -> Result { + pub fn build(vectors: Vec) -> Result { + Self::build_with_config(vectors, KDTreeConfig::default()) + } + + pub fn build_with_config( + mut vectors: Vec, + config: KDTreeConfig, + ) -> Result { if vectors.is_empty() { Err(IndexError::NotInitialized) } else { @@ -50,10 +78,18 @@ impl KDTree { point_ids, total_nodes: vectors.len(), deleted_count: 0, + config: Self::sanitize_config(config), }) } } + fn sanitize_config(config: KDTreeConfig) -> KDTreeConfig { + KDTreeConfig { + balance_threshold: config.balance_threshold.clamp(0.5, 1.0), + delete_rebuild_ratio: config.delete_rebuild_ratio.clamp(0.0, 1.0), + } + } + // Builds the tree recursively with given vectors and returns the pointer of the root node pub fn build_recursive( vectors: &mut [IndexedVector], @@ -246,7 +282,7 @@ impl KDTree { // Check root first (depth 0) if let Some(node) = current - && Self::is_unbalanced(node) + && self.is_unbalanced(node) { unbalanced_depth = Some(0); } @@ -267,7 +303,7 @@ impl KDTree { // Check the child node we just moved to (at depth idx + 1) if let Some(child) = current - && Self::is_unbalanced(child) + && self.is_unbalanced(child) { unbalanced_depth = Some(idx + 1); break; @@ -289,7 +325,7 @@ impl KDTree { self.point_ids.remove(point_id); } - if Self::should_rebuild_global(self.total_nodes, self.deleted_count) + if self.should_rebuild_global() && let Some(root) = self.root.take() { let mut vectors = Self::collect_active_vectors(*root); diff --git a/crates/index/src/kd_tree/mod.rs b/crates/index/src/kd_tree/mod.rs index 4a29b61..61233f4 100644 --- a/crates/index/src/kd_tree/mod.rs +++ b/crates/index/src/kd_tree/mod.rs @@ -9,3 +9,5 @@ pub mod types; mod tests; pub const KD_TREE_MAGIC_BYTES: Magic = [0x00, 0x01, 0x02, 0x00]; + +pub use index::{KDTree, KDTreeConfig}; diff --git a/crates/index/src/kd_tree/serialize.rs b/crates/index/src/kd_tree/serialize.rs index a651d10..397d629 100644 --- a/crates/index/src/kd_tree/serialize.rs +++ b/crates/index/src/kd_tree/serialize.rs @@ -101,6 +101,7 @@ impl KDTree { point_ids: non_deleted, total_nodes: metadata.total_nodes, deleted_count: metadata.deleted_count, + config: Default::default(), }) } } diff --git a/crates/server/src/config.rs b/crates/server/src/config.rs index e593afa..dfcd0f4 100644 --- a/crates/server/src/config.rs +++ b/crates/server/src/config.rs @@ -1,7 +1,7 @@ use api::DbConfig; use defs::Similarity; use dotenv::dotenv; -use index::IndexType; +use index::{IndexType, hnsw::HnswConfig, kd_tree::KDTreeConfig}; use snafu::prelude::*; use std::env; use std::fs; @@ -12,6 +12,12 @@ use tracing::{Level, event}; const DEFAULT_HTTP_PORT: &str = "3000"; const DEFAULT_GRPC_PORT: &str = "50051"; +const DEFAULT_HNSW_M: usize = 16; +const DEFAULT_HNSW_MAX_LAYER: usize = 16; +const DEFAULT_HNSW_EF_CONSTRUCTION: usize = 200; +const DEFAULT_HNSW_EF: usize = 100; +const DEFAULT_KD_TREE_BALANCE_THRESHOLD: f32 = 0.7; +const DEFAULT_KD_TREE_DELETE_REBUILD_RATIO: f32 = 0.25; #[derive(Debug)] pub struct ServerConfig { @@ -197,12 +203,33 @@ impl ServerConfig { } }; + let hnsw_m = load_usize_env("HNSW_M", DEFAULT_HNSW_M); + let hnsw_config = HnswConfig { + max_connections: hnsw_m, + max_connections_0: load_usize_env("HNSW_M0", 2 * hnsw_m), + max_layer: load_usize_env("HNSW_MAX_LAYER", DEFAULT_HNSW_MAX_LAYER), + ef_construction: load_usize_env("HNSW_EF_CONSTRUCTION", DEFAULT_HNSW_EF_CONSTRUCTION), + ef: load_usize_env("HNSW_EF", DEFAULT_HNSW_EF), + }; + let kd_tree_config = KDTreeConfig { + balance_threshold: load_f32_env( + "KD_TREE_BALANCE_THRESHOLD", + DEFAULT_KD_TREE_BALANCE_THRESHOLD, + ), + delete_rebuild_ratio: load_f32_env( + "KD_TREE_DELETE_REBUILD_RATIO", + DEFAULT_KD_TREE_DELETE_REBUILD_RATIO, + ), + }; + let db_config = DbConfig { storage_type, index_type, data_path, dimension, similarity, + hnsw_config, + kd_tree_config, }; Ok(ServerConfig { @@ -215,3 +242,35 @@ impl ServerConfig { }) } } + +fn load_usize_env(name: &str, default: usize) -> usize { + match env::var(name) { + Ok(value) => value.parse().unwrap_or_else(|_| { + event!( + Level::WARN, + "{}='{}' is invalid, defaulting to {}", + name, + value, + default + ); + default + }), + Err(_) => default, + } +} + +fn load_f32_env(name: &str, default: f32) -> f32 { + match env::var(name) { + Ok(value) => value.parse().unwrap_or_else(|_| { + event!( + Level::WARN, + "{}='{}' is invalid, defaulting to {}", + name, + value, + default + ); + default + }), + Err(_) => default, + } +} diff --git a/crates/storage/src/in_memory.rs b/crates/storage/src/in_memory.rs index 3ed52c6..1e8ca60 100644 --- a/crates/storage/src/in_memory.rs +++ b/crates/storage/src/in_memory.rs @@ -126,21 +126,25 @@ impl StorageEngine for MemoryStorage { source, } })?; + let point_snapshot: Vec = { + let points = self + .points + .read() + .map_err(|_| StorageError::InMemoryLock {})?; + points.values().cloned().collect() + }; + let mut writer = BufWriter::new(file); - let points = self - .points - .read() - .map_err(|_| StorageError::InMemoryLock {})?; writer .write_all(INMEMORY_CHECKPOINT_MAGIC) .and_then(|_| writer.write_all(&INMEMORY_CHECKPOINT_VERSION.to_le_bytes())) - .and_then(|_| writer.write_all(&(points.len() as u64).to_le_bytes())) + .and_then(|_| writer.write_all(&(point_snapshot.len() as u64).to_le_bytes())) .map_err(|source| StorageError::InMemoryCheckpointIo { msg: "Couldn't write in-memory checkpoint header".to_string(), source, })?; - for point in points.values() { + for point in &point_snapshot { serialize_into(&mut writer, point).map_err(|source| StorageError::Serialization { id: point.id, source, diff --git a/crates/tui/src/app/database.rs b/crates/tui/src/app/database.rs index ee68e4d..db67df7 100644 --- a/crates/tui/src/app/database.rs +++ b/crates/tui/src/app/database.rs @@ -1,6 +1,6 @@ use api::{DbConfig, VectorDb, init_api}; use defs::Similarity; -use index::IndexType; +use index::{IndexType, hnsw::HnswConfig, kd_tree::KDTreeConfig}; use std::io; use std::path::PathBuf; use std::sync::Arc; @@ -49,6 +49,8 @@ impl DatabaseManager { data_path: path.clone(), dimension: 512, similarity: Similarity::Cosine, + hnsw_config: HnswConfig::default(), + kd_tree_config: KDTreeConfig::default(), }; match init_api(cfg) { @@ -82,6 +84,8 @@ impl DatabaseManager { data_path: path.clone(), dimension: 512, similarity: Similarity::Cosine, + hnsw_config: HnswConfig::default(), + kd_tree_config: KDTreeConfig::default(), }; match init_api(cfg) {