Content-Length: 791812 | pFad | https://github.com/googleapis/python-aiplatform/commit/c52e3e4ea63e43346b439c3eaf6b264c83bf1c25

8A feat: Add compatibility for RagRetrievalConfig in rag_store and rag_r… · googleapis/python-aiplatform@c52e3e4 · GitHub
Skip to content

Commit c52e3e4

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add compatibility for RagRetrievalConfig in rag_store and rag_retrieval
feat: Add deprecation warnings for use of similarity_top_k, vector_search_alpha, and vector_distance_threshold in retrieval_query, use RagRetrievalConfig instead. PiperOrigin-RevId: 700462404
1 parent 34ed530 commit c52e3e4

File tree

7 files changed

+432
-48
lines changed

7 files changed

+432
-48
lines changed

tests/unit/vertex_rag/test_rag_constants_preview.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020

2121
from vertexai.preview.rag import (
2222
EmbeddingModelConfig,
23+
Filter,
24+
HybridSearch,
2325
Pinecone,
2426
RagCorpus,
2527
RagFile,
2628
RagResource,
29+
RagRetrievalConfig,
2730
SharePointSource,
2831
SharePointSources,
2932
SlackChannelsSource,
@@ -529,3 +532,12 @@
529532
rag_corpus="213lkj-1/23jkl/",
530533
rag_file_ids=[TEST_RAG_FILE_ID],
531534
)
535+
TEST_RAG_RETRIEVAL_CONFIG = RagRetrievalConfig(
536+
top_k=2,
537+
filter=Filter(vector_distance_threshold=0.5),
538+
)
539+
TEST_RAG_RETRIEVAL_CONFIG_ALPHA = RagRetrievalConfig(
540+
top_k=2,
541+
filter=Filter(vector_distance_threshold=0.5),
542+
hybrid_search=HybridSearch(alpha=0.5),
543+
)

tests/unit/vertex_rag/test_rag_retrieval_preview.py

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,24 +73,58 @@ def teardown_method(self):
7373

7474
@pytest.mark.usefixtures("retrieve_contexts_mock")
7575
def test_retrieval_query_rag_resources_success(self):
76+
with pytest.warns(DeprecationWarning):
77+
response = rag.retrieval_query(
78+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
79+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
80+
similarity_top_k=2,
81+
vector_distance_threshold=0.5,
82+
vector_search_alpha=0.5,
83+
)
84+
retrieve_contexts_eq(
85+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
86+
)
87+
88+
@pytest.mark.usefixtures("retrieve_contexts_mock")
89+
def test_retrieval_query_rag_resources_config_success(self):
90+
response = rag.retrieval_query(
91+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
92+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
93+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
94+
)
95+
retrieve_contexts_eq(
96+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
97+
)
98+
99+
@pytest.mark.usefixtures("retrieve_contexts_mock")
100+
def test_retrieval_query_rag_resources_default_config_success(self):
76101
response = rag.retrieval_query(
77102
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
78103
text=test_rag_constants_preview.TEST_QUERY_TEXT,
79-
similarity_top_k=2,
80-
vector_distance_threshold=0.5,
81-
vector_search_alpha=0.5,
82104
)
83105
retrieve_contexts_eq(
84106
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
85107
)
86108

87109
@pytest.mark.usefixtures("retrieve_contexts_mock")
88110
def test_retrieval_query_rag_corpora_success(self):
111+
with pytest.warns(DeprecationWarning):
112+
response = rag.retrieval_query(
113+
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
114+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
115+
similarity_top_k=2,
116+
vector_distance_threshold=0.5,
117+
)
118+
retrieve_contexts_eq(
119+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
120+
)
121+
122+
@pytest.mark.usefixtures("retrieve_contexts_mock")
123+
def test_retrieval_query_rag_corpora_config_success(self):
89124
response = rag.retrieval_query(
90125
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
91126
text=test_rag_constants_preview.TEST_QUERY_TEXT,
92-
similarity_top_k=2,
93-
vector_distance_threshold=0.5,
127+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
94128
)
95129
retrieve_contexts_eq(
96130
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
@@ -107,6 +141,16 @@ def test_retrieval_query_failure(self):
107141
)
108142
e.match("Failed in retrieving contexts due to")
109143

144+
@pytest.mark.usefixtures("rag_client_mock_exception")
145+
def test_retrieval_query_config_failure(self):
146+
with pytest.raises(RuntimeError) as e:
147+
rag.retrieval_query(
148+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
149+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
150+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
151+
)
152+
e.match("Failed in retrieving contexts due to")
153+
110154
def test_retrieval_query_invalid_name(self):
111155
with pytest.raises(ValueError) as e:
112156
rag.retrieval_query(
@@ -119,6 +163,17 @@ def test_retrieval_query_invalid_name(self):
119163
)
120164
e.match("Invalid RagCorpus name")
121165

166+
def test_retrieval_query_invalid_name_config(self):
167+
with pytest.raises(ValueError) as e:
168+
rag.retrieval_query(
169+
rag_resources=[
170+
test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME
171+
],
172+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
173+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
174+
)
175+
e.match("Invalid RagCorpus name")
176+
122177
def test_retrieval_query_multiple_rag_corpora(self):
123178
with pytest.raises(ValueError) as e:
124179
rag.retrieval_query(
@@ -132,6 +187,18 @@ def test_retrieval_query_multiple_rag_corpora(self):
132187
)
133188
e.match("Currently only support 1 RagCorpus")
134189

190+
def test_retrieval_query_multiple_rag_corpora_config(self):
191+
with pytest.raises(ValueError) as e:
192+
rag.retrieval_query(
193+
rag_corpora=[
194+
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
195+
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
196+
],
197+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
198+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
199+
)
200+
e.match("Currently only support 1 RagCorpus")
201+
135202
def test_retrieval_query_multiple_rag_resources(self):
136203
with pytest.raises(ValueError) as e:
137204
rag.retrieval_query(
@@ -144,3 +211,15 @@ def test_retrieval_query_multiple_rag_resources(self):
144211
vector_distance_threshold=0.5,
145212
)
146213
e.match("Currently only support 1 RagResource")
214+
215+
def test_retrieval_query_multiple_rag_resources_config(self):
216+
with pytest.raises(ValueError) as e:
217+
rag.retrieval_query(
218+
rag_resources=[
219+
test_rag_constants_preview.TEST_RAG_RESOURCE,
220+
test_rag_constants_preview.TEST_RAG_RESOURCE,
221+
],
222+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
223+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
224+
)
225+
e.match("Currently only support 1 RagResource")

tests/unit/vertex_rag/test_rag_store_preview.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,32 @@
2121

2222

2323
@pytest.mark.usefixtures("google_auth_mock")
24-
class TestRagStoreValidations:
24+
class TestRagStore:
25+
def test_retrieval_tool_success(self):
26+
with pytest.warns(DeprecationWarning):
27+
Tool.from_retrieval(
28+
retrieval=rag.Retrieval(
29+
source=rag.VertexRagStore(
30+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
31+
similarity_top_k=3,
32+
vector_distance_threshold=0.4,
33+
),
34+
)
35+
)
36+
37+
def test_retrieval_tool_config_success(self):
38+
with pytest.warns(DeprecationWarning):
39+
Tool.from_retrieval(
40+
retrieval=rag.Retrieval(
41+
source=rag.VertexRagStore(
42+
rag_corpora=[
43+
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
44+
],
45+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
46+
),
47+
)
48+
)
49+
2550
def test_retrieval_tool_invalid_name(self):
2651
with pytest.raises(ValueError) as e:
2752
Tool.from_retrieval(
@@ -37,6 +62,20 @@ def test_retrieval_tool_invalid_name(self):
3762
)
3863
e.match("Invalid RagCorpus name")
3964

65+
def test_retrieval_tool_invalid_name_config(self):
66+
with pytest.raises(ValueError) as e:
67+
Tool.from_retrieval(
68+
retrieval=rag.Retrieval(
69+
source=rag.VertexRagStore(
70+
rag_resources=[
71+
test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME
72+
],
73+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
74+
),
75+
)
76+
)
77+
e.match("Invalid RagCorpus name")
78+
4079
def test_retrieval_tool_multiple_rag_corpora(self):
4180
with pytest.raises(ValueError) as e:
4281
Tool.from_retrieval(
@@ -53,6 +92,21 @@ def test_retrieval_tool_multiple_rag_corpora(self):
5392
)
5493
e.match("Currently only support 1 RagCorpus")
5594

95+
def test_retrieval_tool_multiple_rag_corpora_config(self):
96+
with pytest.raises(ValueError) as e:
97+
Tool.from_retrieval(
98+
retrieval=rag.Retrieval(
99+
source=rag.VertexRagStore(
100+
rag_corpora=[
101+
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
102+
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
103+
],
104+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
105+
),
106+
)
107+
)
108+
e.match("Currently only support 1 RagCorpus")
109+
56110
def test_retrieval_tool_multiple_rag_resources(self):
57111
with pytest.raises(ValueError) as e:
58112
Tool.from_retrieval(
@@ -68,3 +122,18 @@ def test_retrieval_tool_multiple_rag_resources(self):
68122
)
69123
)
70124
e.match("Currently only support 1 RagResource")
125+
126+
def test_retrieval_tool_multiple_rag_resources_config(self):
127+
with pytest.raises(ValueError) as e:
128+
Tool.from_retrieval(
129+
retrieval=rag.Retrieval(
130+
source=rag.VertexRagStore(
131+
rag_resources=[
132+
test_rag_constants_preview.TEST_RAG_RESOURCE,
133+
test_rag_constants_preview.TEST_RAG_RESOURCE,
134+
],
135+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
136+
),
137+
)
138+
)
139+
e.match("Currently only support 1 RagResource")

vertexai/preview/rag/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,16 @@
3939
)
4040
from vertexai.preview.rag.utils.resources import (
4141
EmbeddingModelConfig,
42+
Filter,
43+
HybridSearch,
4244
JiraQuery,
4345
JiraSource,
4446
Pinecone,
4547
RagCorpus,
4648
RagFile,
4749
RagManagedDb,
4850
RagResource,
51+
RagRetrievalConfig,
4952
SharePointSource,
5053
SharePointSources,
5154
SlackChannel,
@@ -58,13 +61,16 @@
5861

5962
__all__ = (
6063
"EmbeddingModelConfig",
64+
"Filter",
65+
"HybridSearch",
6166
"JiraQuery",
6267
"JiraSource",
6368
"Pinecone",
6469
"RagCorpus",
6570
"RagFile",
6671
"RagManagedDb",
6772
"RagResource",
73+
"RagRetrievalConfig",
6874
"Retrieval",
6975
"SharePointSource",
7076
"SharePointSources",

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/c52e3e4ea63e43346b439c3eaf6b264c83bf1c25

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy