Skip to content

Commit ce65eab

Browse files
lingyinw authored and copybara-github committed
feat: Add support for hybrid queries for private endpoint in Matching Engine Index Endpoint.
PiperOrigin-RevId: 644459987
1 parent 536f1d5 commit ce65eab

File tree

2 files changed

+108
-15
lines changed

2 files changed

+108
-15
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

Lines changed: 40 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,9 @@ class MatchNeighbor:
187187
id (str):
188188
Required. The id of the neighbor.
189189
distance (float):
190-
Required. The distance to the query embedding.
190+
Optional. The distance between the neighbor and the dense embedding query.
191+
sparse_distance (float):
192+
Optional. The distance between the neighbor and the query sparse_embedding.
191193
feature_vector (List[float]):
192194
Optional. The feature vector of the matching datapoint.
193195
crowding_tag (Optional[str]):
@@ -210,7 +212,8 @@ class MatchNeighbor:
210212
"""
211213

212214
id: str
213-
distance: float
215+
distance: Optional[float] = None
216+
sparse_distance: Optional[float] = None
214217
feature_vector: Optional[List[float]] = None
215218
crowding_tag: Optional[str] = None
216219
restricts: Optional[List[Namespace]] = None
@@ -316,6 +319,9 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
316319
name=restrict.name, value_double=restrict.value_double
317320
)
318321
self.numeric_restricts.append(numeric_namespace)
322+
if embedding.sparse_embedding:
323+
self.sparse_embedding_values = embedding.sparse_embedding.float_val
324+
self.sparse_embedding_dimensions = embedding.sparse_embedding.dimension
319325
return self
320326

321327

@@ -1548,7 +1554,11 @@ def find_neighbors(
15481554
return [
15491555
[
15501556
MatchNeighbor(
1551-
id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
1557+
id=neighbor.datapoint.datapoint_id,
1558+
distance=neighbor.distance,
1559+
sparse_distance=neighbor.sparse_distance
1560+
if neighbor.sparse_distance
1561+
else None,
15521562
).from_index_datapoint(index_datapoint=neighbor.datapoint)
15531563
for neighbor in embedding_neighbors.neighbors
15541564
]
@@ -1662,7 +1672,7 @@ def _batch_get_embeddings(
16621672
def match(
16631673
self,
16641674
deployed_index_id: str,
1665-
queries: List[List[float]] = None,
1675+
queries: Union[List[List[float]], List[HybridQuery]] = None,
16661676
num_neighbors: int = 1,
16671677
filter: Optional[List[Namespace]] = None,
16681678
per_crowding_attribute_num_neighbors: Optional[int] = None,
@@ -1677,8 +1687,14 @@ def match(
16771687
Args:
16781688
deployed_index_id (str):
16791689
Required. The ID of the DeployedIndex to match the queries against.
1680-
queries (List[List[float]]):
1681-
Optional. A list of queries. Each query is a list of floats, representing a single embedding.
1690+
queries (Union[List[List[float]], List[HybridQuery]]):
1691+
Optional. A list of queries.
1692+
1693+
For regular dense-only queries, each query is a list of floats,
1694+
representing a single embedding.
1695+
1696+
For hybrid queries, each query is a hybrid query of type
1697+
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
16821698
num_neighbors (int):
16831699
Required. The number of nearest neighbors to be retrieved from database for
16841700
each query.
@@ -1759,16 +1775,28 @@ def match(
17591775

17601776
requests = []
17611777
if queries:
1778+
query_is_hybrid = isinstance(queries[0], HybridQuery)
17621779
for query in queries:
17631780
request = match_service_pb2.MatchRequest(
17641781
deployed_index_id=deployed_index_id,
1765-
float_val=query,
1782+
float_val=query.dense_embedding if query_is_hybrid else query,
17661783
num_neighbors=num_neighbors,
17671784
restricts=restricts,
17681785
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
17691786
approx_num_neighbors=approx_num_neighbors,
17701787
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
17711788
numeric_restricts=numeric_restricts,
1789+
sparse_embedding=match_service_pb2.SparseEmbedding(
1790+
float_val=query.sparse_embedding_values,
1791+
dimension=query.sparse_embedding_dimensions,
1792+
)
1793+
if query_is_hybrid
1794+
else None,
1795+
rrf=match_service_pb2.MatchRequest.RRF(
1796+
alpha=query.rrf_ranking_alpha,
1797+
)
1798+
if query_is_hybrid and query.rrf_ranking_alpha
1799+
else None,
17721800
)
17731801
requests.append(request)
17741802
else:
@@ -1789,7 +1817,11 @@ def match(
17891817
match_neighbors_id_map = {}
17901818
for neighbor in resp.neighbor:
17911819
match_neighbors_id_map[neighbor.id] = MatchNeighbor(
1792-
id=neighbor.id, distance=neighbor.distance
1820+
id=neighbor.id,
1821+
distance=neighbor.distance,
1822+
sparse_distance=neighbor.sparse_distance
1823+
if neighbor.sparse_distance
1824+
else None,
17931825
)
17941826
for embedding in resp.embeddings:
17951827
if embedding.id in match_neighbors_id_map:

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

Lines changed: 68 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1137,7 +1137,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
11371137
index_endpoint_match_queries_mock.assert_called_with(batch_request)
11381138

11391139
@pytest.mark.usefixtures("get_index_endpoint_mock")
1140-
def test_private_service_access_service_access_index_endpoint_match_queries(
1140+
def test_private_service_access_hybrid_search_match_queries(
11411141
self, index_endpoint_match_queries_mock
11421142
):
11431143
aiplatform.init(project=_TEST_PROJECT)
@@ -1146,7 +1146,72 @@ def test_private_service_access_service_access_index_endpoint_match_queries(
11461146
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
11471147
)
11481148

1149-
response = my_index_endpoint.match(
1149+
my_index_endpoint.match(
1150+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1151+
num_neighbors=_TEST_NUM_NEIGHBOURS,
1152+
filter=_TEST_FILTER,
1153+
queries=_TEST_HYBRID_QUERIES,
1154+
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1155+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
1156+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1157+
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
1158+
numeric_filter=_TEST_NUMERIC_FILTER,
1159+
)
1160+
1161+
batch_request = match_service_pb2.BatchMatchRequest(
1162+
requests=[
1163+
match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
1164+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1165+
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
1166+
requests=[
1167+
match_service_pb2.MatchRequest(
1168+
num_neighbors=_TEST_NUM_NEIGHBOURS,
1169+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
1170+
float_val=_TEST_HYBRID_QUERIES[i].dense_embedding,
1171+
restricts=[
1172+
match_service_pb2.Namespace(
1173+
name="class",
1174+
allow_tokens=["token_1"],
1175+
deny_tokens=["token_2"],
1176+
)
1177+
],
1178+
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
1179+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
1180+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
1181+
numeric_restricts=_TEST_NUMERIC_NAMESPACE,
1182+
sparse_embedding=match_service_pb2.SparseEmbedding(
1183+
float_val=_TEST_HYBRID_QUERIES[
1184+
i
1185+
].sparse_embedding_values,
1186+
dimension=_TEST_HYBRID_QUERIES[
1187+
i
1188+
].sparse_embedding_dimensions,
1189+
),
1190+
rrf=match_service_pb2.MatchRequest.RRF(
1191+
alpha=_TEST_HYBRID_QUERIES[i].rrf_ranking_alpha,
1192+
)
1193+
if _TEST_HYBRID_QUERIES[i].rrf_ranking_alpha
1194+
else None,
1195+
)
1196+
for i in range(len(_TEST_HYBRID_QUERIES))
1197+
],
1198+
)
1199+
]
1200+
)
1201+
1202+
index_endpoint_match_queries_mock.assert_called_with(batch_request)
1203+
1204+
@pytest.mark.usefixtures("get_index_endpoint_mock")
1205+
def test_private_service_access_index_endpoint_match_queries(
1206+
self, index_endpoint_match_queries_mock
1207+
):
1208+
aiplatform.init(project=_TEST_PROJECT)
1209+
1210+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
1211+
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
1212+
)
1213+
1214+
my_index_endpoint.match(
11501215
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
11511216
num_neighbors=_TEST_NUM_NEIGHBOURS,
11521217
filter=_TEST_FILTER,
@@ -1188,8 +1253,6 @@ def test_private_service_access_service_access_index_endpoint_match_queries(
11881253

11891254
index_endpoint_match_queries_mock.assert_called_with(batch_request)
11901255

1191-
assert response == _TEST_PRIVATE_MATCH_NEIGHBOR_RESPONSE
1192-
11931256
@pytest.mark.usefixtures("get_index_endpoint_mock")
11941257
def test_index_private_service_access_endpoint_find_neighbor_queries(
11951258
self, index_endpoint_match_queries_mock
@@ -1200,7 +1263,7 @@ def test_index_private_service_access_endpoint_find_neighbor_queries(
12001263
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
12011264
)
12021265

1203-
response = my_private_index_endpoint.find_neighbors(
1266+
my_private_index_endpoint.find_neighbors(
12041267
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
12051268
queries=_TEST_QUERIES,
12061269
num_neighbors=_TEST_NUM_NEIGHBOURS,
@@ -1240,8 +1303,6 @@ def test_index_private_service_access_endpoint_find_neighbor_queries(
12401303
)
12411304
index_endpoint_match_queries_mock.assert_called_with(batch_match_request)
12421305

1243-
assert response == _TEST_PRIVATE_MATCH_NEIGHBOR_RESPONSE
1244-
12451306
@pytest.mark.usefixtures("get_index_endpoint_mock")
12461307
def test_index_private_service_connect_endpoint_match_queries(
12471308
self, index_endpoint_match_queries_mock

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad - The Proxy pFad © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy