Content-Length: 779675 | pFad | https://github.com/googleapis/python-aiplatform/commit/9d3561738d577129cb222417bf208166825d8043

50 feat: Add hybrid search for public find_neighbors() call. · googleapis/python-aiplatform@9d35617 · GitHub
Skip to content

Commit 9d35617

Browse files
lingyinwcopybara-github
authored and committed
feat: Add hybrid search for public find_neighbors() call.
PiperOrigin-RevId: 640750317
1 parent c118557 commit 9d35617

File tree

2 files changed

+225
-59
lines changed

2 files changed

+225
-59
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

Lines changed: 138 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
from dataclasses import dataclass, field
19-
from typing import Dict, List, Optional, Sequence, Tuple
19+
from typing import Dict, List, Optional, Sequence, Tuple, Union
2020

2121
from google.auth import credentials as auth_credentials
2222
from google.cloud.aiplatform import base
@@ -148,6 +148,37 @@ def __post_init__(self):
148148
)
149149

150150

151+
@dataclass
152+
class HybridQuery:
153+
"""
154+
Hybrid query. Could be used for dense-only or sparse-only or hybrid queries.
155+
156+
dense_embedding (List[float]):
157+
Optional. The dense part of the hybrid queries.
158+
sparse_embedding_values (List[float]):
159+
Optional. The sparse values of the sparse part of the queries.
160+
161+
sparse_embedding_dimensions (List[int]):
162+
Optional. The corresponding dimensions of the sparse values.
163+
For example, values [1,2,3] with dimensions [4,5,6] means value 1 is of the
164+
4th dimension, value 2 is of the 5th dimension, and value 3 is of the 6th
165+
dimension.
166+
167+
rrf_ranking_alpha (float):
168+
Optional. This should not be specified for dense-only or sparse-only queries.
169+
A value between 0 and 1 for ranking algorithm RRF, representing
170+
the ratio for sparse v.s. dense embeddings returned in the query result.
171+
If the alpha is 0, only sparse embeddings are being returned, and no dense
172+
embedding is being returned. When alpha is 1, only dense embeddings are being
173+
returned, and no sparse embedding is being returned.
174+
"""
175+
176+
dense_embedding: List[float] = None
177+
sparse_embedding_values: List[float] = None
178+
sparse_embedding_dimensions: List[int] = None
179+
rrf_ranking_alpha: float = None
180+
181+
151182
@dataclass
152183
class MatchNeighbor:
153184
"""The id and distance of a nearest neighbor match for a given query embedding.
@@ -157,7 +188,7 @@ class MatchNeighbor:
157188
Required. The id of the neighbor.
158189
distance (float):
159190
Required. The distance to the query embedding.
160-
feature_vector (List(float)):
191+
feature_vector (List[float]):
161192
Optional. The feature vector of the matching datapoint.
162193
crowding_tag (Optional[str]):
163194
Optional. Crowding tag of the datapoint, the
@@ -167,6 +198,14 @@ class MatchNeighbor:
167198
Optional. The restricts of the matching datapoint.
168199
numeric_restricts:
169200
Optional. The numeric restricts of the matching datapoint.
201+
sparse_embedding_values (List[float]):
202+
Optional. The sparse values of the sparse part of the matching
203+
datapoint.
204+
sparse_embedding_dimensions (List[int]):
205+
Optional. The corresponding dimensions of the sparse values.
206+
For example, values [1,2,3] with dimensions [4,5,6] means value 1 is
207+
of the 4th dimension, value 2 is of the 5th dimension, and value 3 is
208+
of the 6th dimension.
170209
171210
"""
172211

@@ -176,6 +215,8 @@ class MatchNeighbor:
176215
crowding_tag: Optional[str] = None
177216
restricts: Optional[List[Namespace]] = None
178217
numeric_restricts: Optional[List[NumericNamespace]] = None
218+
sparse_embedding_values: Optional[List[float]] = None
219+
sparse_embedding_dimensions: Optional[List[int]] = None
179220

180221
def from_index_datapoint(
181222
self, index_datapoint: gca_index_v1beta1.IndexDatapoint
@@ -207,22 +248,31 @@ def from_index_datapoint(
207248
]
208249
if index_datapoint.numeric_restricts is not None:
209250
self.numeric_restricts = []
210-
for restrict in index_datapoint.numeric_restricts:
211-
numeric_namespace = None
212-
restrict_value_type = restrict._pb.WhichOneof("Value")
213-
if restrict_value_type == "value_int":
214-
numeric_namespace = NumericNamespace(
215-
name=restrict.namespace, value_int=restrict.value_int
216-
)
217-
elif restrict_value_type == "value_float":
218-
numeric_namespace = NumericNamespace(
219-
name=restrict.namespace, value_float=restrict.value_float
220-
)
221-
elif restrict_value_type == "value_double":
222-
numeric_namespace = NumericNamespace(
223-
name=restrict.namespace, value_double=restrict.value_double
224-
)
225-
self.numeric_restricts.append(numeric_namespace)
251+
for restrict in index_datapoint.numeric_restricts:
252+
numeric_namespace = None
253+
restrict_value_type = restrict._pb.WhichOneof("Value")
254+
if restrict_value_type == "value_int":
255+
numeric_namespace = NumericNamespace(
256+
name=restrict.namespace, value_int=restrict.value_int
257+
)
258+
elif restrict_value_type == "value_float":
259+
numeric_namespace = NumericNamespace(
260+
name=restrict.namespace, value_float=restrict.value_float
261+
)
262+
elif restrict_value_type == "value_double":
263+
numeric_namespace = NumericNamespace(
264+
name=restrict.namespace, value_double=restrict.value_double
265+
)
266+
self.numeric_restricts.append(numeric_namespace)
267+
# sparse embeddings
268+
if (
269+
index_datapoint.sparse_embedding is not None
270+
and index_datapoint.sparse_embedding.values is not None
271+
):
272+
self.sparse_embedding_values = index_datapoint.sparse_embedding.values
273+
self.sparse_embedding_dimensions = (
274+
index_datapoint.sparse_embedding.dimensions
275+
)
226276
return self
227277

228278
def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor":
@@ -250,22 +300,22 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
250300
]
251301
if embedding.numeric_restricts:
252302
self.numeric_restricts = []
253-
for restrict in embedding.numeric_restricts:
254-
numeric_namespace = None
255-
restrict_value_type = restrict.WhichOneof("Value")
256-
if restrict_value_type == "value_int":
257-
numeric_namespace = NumericNamespace(
258-
name=restrict.name, value_int=restrict.value_int
259-
)
260-
elif restrict_value_type == "value_float":
261-
numeric_namespace = NumericNamespace(
262-
name=restrict.name, value_float=restrict.value_float
263-
)
264-
elif restrict_value_type == "value_double":
265-
numeric_namespace = NumericNamespace(
266-
name=restrict.name, value_double=restrict.value_double
267-
)
268-
self.numeric_restricts.append(numeric_namespace)
303+
for restrict in embedding.numeric_restricts:
304+
numeric_namespace = None
305+
restrict_value_type = restrict.WhichOneof("Value")
306+
if restrict_value_type == "value_int":
307+
numeric_namespace = NumericNamespace(
308+
name=restrict.name, value_int=restrict.value_int
309+
)
310+
elif restrict_value_type == "value_float":
311+
numeric_namespace = NumericNamespace(
312+
name=restrict.name, value_float=restrict.value_float
313+
)
314+
elif restrict_value_type == "value_double":
315+
numeric_namespace = NumericNamespace(
316+
name=restrict.name, value_double=restrict.value_double
317+
)
318+
self.numeric_restricts.append(numeric_namespace)
269319
return self
270320

271321

@@ -1322,7 +1372,7 @@ def find_neighbors(
13221372
self,
13231373
*,
13241374
deployed_index_id: str,
1325-
queries: Optional[List[List[float]]] = None,
1375+
queries: Optional[Union[List[List[float]], List[HybridQuery]]] = None,
13261376
num_neighbors: int = 10,
13271377
filter: Optional[List[Namespace]] = None,
13281378
per_crowding_attribute_neighbor_count: Optional[int] = None,
@@ -1346,8 +1396,15 @@ def find_neighbors(
13461396
Args:
13471397
deployed_index_id (str):
13481398
Required. The ID of the DeployedIndex to match the queries against.
1349-
queries (List[List[float]]):
1350-
Required. A list of queries. Each query is a list of floats, representing a single embedding.
1399+
queries (Union[List[List[float]], List[HybridQuery]]):
1400+
Optional. A list of queries.
1401+
1402+
For regular dense-only queries, each query is a list of floats,
1403+
representing a single embedding.
1404+
1405+
For hybrid queries, each query is a hybrid query of type
1406+
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
1407+
13511408
num_neighbors (int):
13521409
Required. The number of nearest neighbors to be retrieved from database for
13531410
each query.
@@ -1381,7 +1438,7 @@ def find_neighbors(
13811438
Note that returning full datapoint will significantly increase the
13821439
latency and cost of the query.
13831440
1384-
numeric_filter (list[NumericNamespace]):
1441+
numeric_filter (List[NumericNamespace]):
13851442
Optional. A list of NumericNamespaces for filtering the matching
13861443
results. For example:
13871444
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
@@ -1437,30 +1494,54 @@ def find_neighbors(
14371494
numeric_restrict.value_double = numeric_namespace.value_double
14381495
numeric_restricts.append(numeric_restrict)
14391496
# Queries
1440-
query_by_id = False if queries else True
1441-
queries = queries if queries else embedding_ids
1442-
if queries:
1443-
for query in queries:
1444-
find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query(
1445-
neighbor_count=num_neighbors,
1446-
per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,
1447-
approximate_neighbor_count=approx_num_neighbors,
1448-
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1449-
)
1450-
datapoint = gca_index_v1beta1.IndexDatapoint(
1451-
datapoint_id=query if query_by_id else None,
1452-
feature_vector=None if query_by_id else query,
1453-
)
1454-
datapoint.restricts.extend(restricts)
1455-
datapoint.numeric_restricts.extend(numeric_restricts)
1456-
find_neighbors_query.datapoint = datapoint
1457-
find_neighbors_request.queries.append(find_neighbors_query)
1497+
query_by_id = False
1498+
query_is_hybrid = False
1499+
if embedding_ids:
1500+
query_by_id = True
1501+
query_iterators: list[str] = embedding_ids
1502+
elif queries:
1503+
query_is_hybrid = isinstance(queries[0], HybridQuery)
1504+
query_iterators = queries
14581505
else:
14591506
raise ValueError(
14601507
"To find neighbors using matching engine,"
1461-
"please specify `queries` or `embedding_ids`"
1508+
"please specify `queries` or `embedding_ids` or `hybrid_queries`"
14621509
)
14631510

1511+
for query in query_iterators:
1512+
find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query(
1513+
neighbor_count=num_neighbors,
1514+
per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,
1515+
approximate_neighbor_count=approx_num_neighbors,
1516+
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
1517+
)
1518+
if query_by_id:
1519+
datapoint = gca_index_v1beta1.IndexDatapoint(
1520+
datapoint_id=query,
1521+
)
1522+
elif query_is_hybrid:
1523+
datapoint = gca_index_v1beta1.IndexDatapoint(
1524+
feature_vector=query.dense_embedding,
1525+
sparse_embedding=gca_index_v1beta1.IndexDatapoint.SparseEmbedding(
1526+
values=query.sparse_embedding_values,
1527+
dimensions=query.sparse_embedding_dimensions,
1528+
),
1529+
)
1530+
if query.rrf_ranking_alpha:
1531+
find_neighbors_query.rrf = (
1532+
gca_match_service_v1beta1.FindNeighborsRequest.Query.RRF(
1533+
alpha=query.rrf_ranking_alpha,
1534+
)
1535+
)
1536+
else:
1537+
datapoint = gca_index_v1beta1.IndexDatapoint(
1538+
feature_vector=query,
1539+
)
1540+
datapoint.restricts.extend(restricts)
1541+
datapoint.numeric_restricts.extend(numeric_restricts)
1542+
find_neighbors_query.datapoint = datapoint
1543+
find_neighbors_request.queries.append(find_neighbors_query)
1544+
14641545
response = self._public_match_client.find_neighbors(find_neighbors_request)
14651546

14661547
# Wrap the results in MatchNeighbor objects and return
@@ -1543,7 +1624,6 @@ def read_index_datapoints(
15431624
read_index_datapoints_request
15441625
)
15451626

1546-
# Wrap the results and return
15471627
return response.datapoints
15481628

15491629
def _batch_get_embeddings(

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/9d3561738d577129cb222417bf208166825d8043

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy