16
16
#
17
17
18
18
from dataclasses import dataclass , field
19
- from typing import Dict , List , Optional , Sequence , Tuple
19
+ from typing import Dict , List , Optional , Sequence , Tuple , Union
20
20
21
21
from google .auth import credentials as auth_credentials
22
22
from google .cloud .aiplatform import base
@@ -148,6 +148,37 @@ def __post_init__(self):
148
148
)
149
149
150
150
151
+ @dataclass
152
+ class HybridQuery :
153
+ """
154
+ Hyrbid query. Could be used for dense-only or sparse-only or hybrid queries.
155
+
156
+ dense_embedding (List[float]):
157
+ Optional. The dense part of the hybrid queries.
158
+ sparse_embedding_values (List[float]):
159
+ Optional. The sparse values of the sparse part of the queries.
160
+
161
+ sparse_embedding_dimensions (List[int]):
162
+ Optional. The corresponding dimensions of the sparse values.
163
+ For example, values [1,2,3] with dimensions [4,5,6] means value 1 is of the
164
+ 4th dimension, value 2 is of the 4th dimension, and value 3 is of the 6th
165
+ dimension.
166
+
167
+ rrf_ranking_alpha (float):
168
+ Optional. This should not be specified for dense-only or sparse-only queries.
169
+ A value between 0 and 1 for ranking algorithm RRF, representing
170
+ the ratio for sparse v.s. dense embeddings returned in the query result.
171
+ If the alpha is 0, only sparse embeddings are being returned, and no dense
172
+ embedding is being returned. When alhpa is 1, only dense embeddings are being
173
+ returned, and no sparse embedding is being returned.
174
+ """
175
+
176
+ dense_embedding : List [float ] = None
177
+ sparse_embedding_values : List [float ] = None
178
+ sparse_embedding_dimensions : List [int ] = None
179
+ rrf_ranking_alpha : float = None
180
+
181
+
151
182
@dataclass
152
183
class MatchNeighbor :
153
184
"""The id and distance of a nearest neighbor match for a given query embedding.
@@ -157,7 +188,7 @@ class MatchNeighbor:
157
188
Required. The id of the neighbor.
158
189
distance (float):
159
190
Required. The distance to the query embedding.
160
- feature_vector (List( float) ):
191
+ feature_vector (List[ float] ):
161
192
Optional. The feature vector of the matching datapoint.
162
193
crowding_tag (Optional[str]):
163
194
Optional. Crowding tag of the datapoint, the
@@ -167,6 +198,14 @@ class MatchNeighbor:
167
198
Optional. The restricts of the matching datapoint.
168
199
numeric_restricts:
169
200
Optional. The numeric restricts of the matching datapoint.
201
+ sparse_embedding_values (List[float]):
202
+ Optional. The sparse values of the sparse part of the matching
203
+ datapoint.
204
+ sparse_embedding_dimensions (List[int]):
205
+ Optional. The corresponding dimensions of the sparse values.
206
+ For example, values [1,2,3] with dimensions [4,5,6] means value 1 is
207
+ of the 4th dimension, value 2 is of the 4th dimension, and value 3 is
208
+ of the 6th dimension.
170
209
171
210
"""
172
211
@@ -176,6 +215,8 @@ class MatchNeighbor:
176
215
crowding_tag : Optional [str ] = None
177
216
restricts : Optional [List [Namespace ]] = None
178
217
numeric_restricts : Optional [List [NumericNamespace ]] = None
218
+ sparse_embedding_values : Optional [List [float ]] = None
219
+ sparse_embedding_dimensions : Optional [List [int ]] = None
179
220
180
221
def from_index_datapoint (
181
222
self , index_datapoint : gca_index_v1beta1 .IndexDatapoint
@@ -207,22 +248,31 @@ def from_index_datapoint(
207
248
]
208
249
if index_datapoint .numeric_restricts is not None :
209
250
self .numeric_restricts = []
210
- for restrict in index_datapoint .numeric_restricts :
211
- numeric_namespace = None
212
- restrict_value_type = restrict ._pb .WhichOneof ("Value" )
213
- if restrict_value_type == "value_int" :
214
- numeric_namespace = NumericNamespace (
215
- name = restrict .namespace , value_int = restrict .value_int
216
- )
217
- elif restrict_value_type == "value_float" :
218
- numeric_namespace = NumericNamespace (
219
- name = restrict .namespace , value_float = restrict .value_float
220
- )
221
- elif restrict_value_type == "value_double" :
222
- numeric_namespace = NumericNamespace (
223
- name = restrict .namespace , value_double = restrict .value_double
224
- )
225
- self .numeric_restricts .append (numeric_namespace )
251
+ for restrict in index_datapoint .numeric_restricts :
252
+ numeric_namespace = None
253
+ restrict_value_type = restrict ._pb .WhichOneof ("Value" )
254
+ if restrict_value_type == "value_int" :
255
+ numeric_namespace = NumericNamespace (
256
+ name = restrict .namespace , value_int = restrict .value_int
257
+ )
258
+ elif restrict_value_type == "value_float" :
259
+ numeric_namespace = NumericNamespace (
260
+ name = restrict .namespace , value_float = restrict .value_float
261
+ )
262
+ elif restrict_value_type == "value_double" :
263
+ numeric_namespace = NumericNamespace (
264
+ name = restrict .namespace , value_double = restrict .value_double
265
+ )
266
+ self .numeric_restricts .append (numeric_namespace )
267
+ # sparse embeddings
268
+ if (
269
+ index_datapoint .sparse_embedding is not None
270
+ and index_datapoint .sparse_embedding .values is not None
271
+ ):
272
+ self .sparse_embedding_values = index_datapoint .sparse_embedding .values
273
+ self .sparse_embedding_dimensions = (
274
+ index_datapoint .sparse_embedding .dimensions
275
+ )
226
276
return self
227
277
228
278
def from_embedding (self , embedding : match_service_pb2 .Embedding ) -> "MatchNeighbor" :
@@ -250,22 +300,22 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb
250
300
]
251
301
if embedding .numeric_restricts :
252
302
self .numeric_restricts = []
253
- for restrict in embedding .numeric_restricts :
254
- numeric_namespace = None
255
- restrict_value_type = restrict .WhichOneof ("Value" )
256
- if restrict_value_type == "value_int" :
257
- numeric_namespace = NumericNamespace (
258
- name = restrict .name , value_int = restrict .value_int
259
- )
260
- elif restrict_value_type == "value_float" :
261
- numeric_namespace = NumericNamespace (
262
- name = restrict .name , value_float = restrict .value_float
263
- )
264
- elif restrict_value_type == "value_double" :
265
- numeric_namespace = NumericNamespace (
266
- name = restrict .name , value_double = restrict .value_double
267
- )
268
- self .numeric_restricts .append (numeric_namespace )
303
+ for restrict in embedding .numeric_restricts :
304
+ numeric_namespace = None
305
+ restrict_value_type = restrict .WhichOneof ("Value" )
306
+ if restrict_value_type == "value_int" :
307
+ numeric_namespace = NumericNamespace (
308
+ name = restrict .name , value_int = restrict .value_int
309
+ )
310
+ elif restrict_value_type == "value_float" :
311
+ numeric_namespace = NumericNamespace (
312
+ name = restrict .name , value_float = restrict .value_float
313
+ )
314
+ elif restrict_value_type == "value_double" :
315
+ numeric_namespace = NumericNamespace (
316
+ name = restrict .name , value_double = restrict .value_double
317
+ )
318
+ self .numeric_restricts .append (numeric_namespace )
269
319
return self
270
320
271
321
@@ -1322,7 +1372,7 @@ def find_neighbors(
1322
1372
self ,
1323
1373
* ,
1324
1374
deployed_index_id : str ,
1325
- queries : Optional [List [List [float ]]] = None ,
1375
+ queries : Optional [Union [ List [List [float ]], List [ HybridQuery ]]] = None ,
1326
1376
num_neighbors : int = 10 ,
1327
1377
filter : Optional [List [Namespace ]] = None ,
1328
1378
per_crowding_attribute_neighbor_count : Optional [int ] = None ,
@@ -1346,8 +1396,15 @@ def find_neighbors(
1346
1396
Args:
1347
1397
deployed_index_id (str):
1348
1398
Required. The ID of the DeployedIndex to match the queries against.
1349
- queries (List[List[float]]):
1350
- Required. A list of queries. Each query is a list of floats, representing a single embedding.
1399
+ queries (Union[List[List[float]], List[HybridQuery]]):
1400
+ Optional. A list of queries.
1401
+
1402
+ For regular dense-only queries, each query is a list of floats,
1403
+ representing a single embedding.
1404
+
1405
+ For hybrid queries, each query is a hybrid query of type
1406
+ aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery.
1407
+
1351
1408
num_neighbors (int):
1352
1409
Required. The number of nearest neighbors to be retrieved from database for
1353
1410
each query.
@@ -1381,7 +1438,7 @@ def find_neighbors(
1381
1438
Note that returning full datapoint will significantly increase the
1382
1439
latency and cost of the query.
1383
1440
1384
- numeric_filter (list [NumericNamespace]):
1441
+ numeric_filter (List [NumericNamespace]):
1385
1442
Optional. A list of NumericNamespaces for filtering the matching
1386
1443
results. For example:
1387
1444
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
@@ -1437,30 +1494,54 @@ def find_neighbors(
1437
1494
numeric_restrict .value_double = numeric_namespace .value_double
1438
1495
numeric_restricts .append (numeric_restrict )
1439
1496
# Queries
1440
- query_by_id = False if queries else True
1441
- queries = queries if queries else embedding_ids
1442
- if queries :
1443
- for query in queries :
1444
- find_neighbors_query = gca_match_service_v1beta1 .FindNeighborsRequest .Query (
1445
- neighbor_count = num_neighbors ,
1446
- per_crowding_attribute_neighbor_count = per_crowding_attribute_neighbor_count ,
1447
- approximate_neighbor_count = approx_num_neighbors ,
1448
- fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1449
- )
1450
- datapoint = gca_index_v1beta1 .IndexDatapoint (
1451
- datapoint_id = query if query_by_id else None ,
1452
- feature_vector = None if query_by_id else query ,
1453
- )
1454
- datapoint .restricts .extend (restricts )
1455
- datapoint .numeric_restricts .extend (numeric_restricts )
1456
- find_neighbors_query .datapoint = datapoint
1457
- find_neighbors_request .queries .append (find_neighbors_query )
1497
+ query_by_id = False
1498
+ query_is_hybrid = False
1499
+ if embedding_ids :
1500
+ query_by_id = True
1501
+ query_iterators : list [str ] = embedding_ids
1502
+ elif queries :
1503
+ query_is_hybrid = isinstance (queries [0 ], HybridQuery )
1504
+ query_iterators = queries
1458
1505
else :
1459
1506
raise ValueError (
1460
1507
"To find neighbors using matching engine,"
1461
- "please specify `queries` or `embedding_ids`"
1508
+ "please specify `queries` or `embedding_ids` or `hybrid_queries` "
1462
1509
)
1463
1510
1511
+ for query in query_iterators :
1512
+ find_neighbors_query = gca_match_service_v1beta1 .FindNeighborsRequest .Query (
1513
+ neighbor_count = num_neighbors ,
1514
+ per_crowding_attribute_neighbor_count = per_crowding_attribute_neighbor_count ,
1515
+ approximate_neighbor_count = approx_num_neighbors ,
1516
+ fraction_leaf_nodes_to_search_override = fraction_leaf_nodes_to_search_override ,
1517
+ )
1518
+ if query_by_id :
1519
+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1520
+ datapoint_id = query ,
1521
+ )
1522
+ elif query_is_hybrid :
1523
+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1524
+ feature_vector = query .dense_embedding ,
1525
+ sparse_embedding = gca_index_v1beta1 .IndexDatapoint .SparseEmbedding (
1526
+ values = query .sparse_embedding_values ,
1527
+ dimensions = query .sparse_embedding_dimensions ,
1528
+ ),
1529
+ )
1530
+ if query .rrf_ranking_alpha :
1531
+ find_neighbors_query .rrf = (
1532
+ gca_match_service_v1beta1 .FindNeighborsRequest .Query .RRF (
1533
+ alpha = query .rrf_ranking_alpha ,
1534
+ )
1535
+ )
1536
+ else :
1537
+ datapoint = gca_index_v1beta1 .IndexDatapoint (
1538
+ feature_vector = query ,
1539
+ )
1540
+ datapoint .restricts .extend (restricts )
1541
+ datapoint .numeric_restricts .extend (numeric_restricts )
1542
+ find_neighbors_query .datapoint = datapoint
1543
+ find_neighbors_request .queries .append (find_neighbors_query )
1544
+
1464
1545
response = self ._public_match_client .find_neighbors (find_neighbors_request )
1465
1546
1466
1547
# Wrap the results in MatchNeighbor objects and return
@@ -1543,7 +1624,6 @@ def read_index_datapoints(
1543
1624
read_index_datapoints_request
1544
1625
)
1545
1626
1546
- # Wrap the results and return
1547
1627
return response .datapoints
1548
1628
1549
1629
def _batch_get_embeddings (
0 commit comments