Skip to content

Commit 0801812

Browse files
feat: [vertexai] support ToolConfig in GenerativeModel (#10950)
PiperOrigin-RevId: 642059737 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent c79bfb5 commit 0801812

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.cloud.vertexai.api.GenerationConfig;
3131
import com.google.cloud.vertexai.api.SafetySetting;
3232
import com.google.cloud.vertexai.api.Tool;
33+
import com.google.cloud.vertexai.api.ToolConfig;
3334
import com.google.common.base.Strings;
3435
import com.google.common.collect.ImmutableList;
3536
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -46,6 +47,7 @@ public final class GenerativeModel {
4647
private final GenerationConfig generationConfig;
4748
private final ImmutableList<SafetySetting> safetySettings;
4849
private final ImmutableList<Tool> tools;
50+
private final Optional<ToolConfig> toolConfig;
4951
private final Optional<Content> systemInstruction;
5052

5153
/**
@@ -65,6 +67,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
6567
ImmutableList.of(),
6668
ImmutableList.of(),
6769
Optional.empty(),
70+
Optional.empty(),
6871
vertexAi);
6972
}
7073

@@ -79,6 +82,10 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
7982
* that will be used by default for generating response
8083
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
8184
* the model as auxiliary tools to generate content.
85+
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} instance that will be used
86+
* to specify the tool configuration.
87+
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} instance that will be
88+
* used by default for generating response.
8289
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
8390
* for the generative model
8491
*/
@@ -87,6 +94,7 @@ private GenerativeModel(
8794
GenerationConfig generationConfig,
8895
ImmutableList<SafetySetting> safetySettings,
8996
ImmutableList<Tool> tools,
97+
Optional<ToolConfig> toolConfig,
9098
Optional<Content> systemInstruction,
9199
VertexAI vertexAi) {
92100
checkArgument(
@@ -98,6 +106,8 @@ private GenerativeModel(
98106
checkNotNull(generationConfig, "GenerationConfig can't be null.");
99107
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
100108
checkNotNull(tools, "ImmutableList<Tool> can't be null.");
109+
checkNotNull(toolConfig, "Optional<ToolConfig> can't be null.");
110+
checkNotNull(systemInstruction, "Optional<Content> can't be null.");
101111

102112
this.resourceName = getResourceName(modelName, vertexAi);
103113
// reconcileModelName should be called after getResourceName.
@@ -106,6 +116,7 @@ private GenerativeModel(
106116
this.generationConfig = generationConfig;
107117
this.safetySettings = safetySettings;
108118
this.tools = tools;
119+
this.toolConfig = toolConfig;
109120
// We remove the role in the system instruction content because it's officially documented
110121
// to be used without role specified:
111122
// https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-system-instruction
@@ -128,6 +139,7 @@ public static class Builder {
128139
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
129140
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
130141
private ImmutableList<Tool> tools = ImmutableList.of();
142+
private Optional<ToolConfig> toolConfig = Optional.empty();
131143
private Optional<Content> systemInstruction = Optional.empty();
132144

133145
public GenerativeModel build() {
@@ -136,7 +148,13 @@ public GenerativeModel build() {
136148
"modelName is required. Please call setModelName() before building.");
137149
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
138150
return new GenerativeModel(
139-
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
151+
modelName,
152+
generationConfig,
153+
safetySettings,
154+
tools,
155+
toolConfig,
156+
systemInstruction,
157+
vertexAi);
140158
}
141159

142160
/**
@@ -204,6 +222,19 @@ public Builder setTools(List<Tool> tools) {
204222
return this;
205223
}
206224

225+
/**
226+
* Sets a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used by default to
227+
* interact with the generative model.
228+
*/
229+
@CanIgnoreReturnValue
230+
public Builder setToolConfig(ToolConfig toolConfig) {
231+
checkNotNull(
232+
toolConfig,
233+
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
234+
this.toolConfig = Optional.of(toolConfig);
235+
return this;
236+
}
237+
207238
/**
208239
* Sets a system instruction that will be used by default to interact with the generative model.
209240
*/
@@ -228,7 +259,13 @@ public Builder setSystemInstruction(Content systemInstruction) {
228259
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
229260
checkNotNull(generationConfig, "GenerationConfig can't be null.");
230261
return new GenerativeModel(
231-
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
262+
modelName,
263+
generationConfig,
264+
safetySettings,
265+
tools,
266+
toolConfig,
267+
systemInstruction,
268+
vertexAi);
232269
}
233270

234271
/**
@@ -247,6 +284,7 @@ public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
247284
generationConfig,
248285
ImmutableList.copyOf(safetySettings),
249286
tools,
287+
toolConfig,
250288
systemInstruction,
251289
vertexAi);
252290
}
@@ -265,6 +303,28 @@ public GenerativeModel withTools(List<Tool> tools) {
265303
generationConfig,
266304
safetySettings,
267305
ImmutableList.copyOf(tools),
306+
toolConfig,
307+
systemInstruction,
308+
vertexAi);
309+
}
310+
311+
/**
312+
* Creates a copy of the current model with updated tool config.
313+
*
314+
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
315+
* new model.
316+
* @return a new {@link GenerativeModel} instance with the specified tool config.
317+
*/
318+
public GenerativeModel withToolConfig(ToolConfig toolConfig) {
319+
checkNotNull(
320+
toolConfig,
321+
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
322+
return new GenerativeModel(
323+
modelName,
324+
generationConfig,
325+
safetySettings,
326+
tools,
327+
Optional.of(toolConfig),
268328
systemInstruction,
269329
vertexAi);
270330
}
@@ -286,6 +346,7 @@ public GenerativeModel withSystemInstruction(Content systemInstruction) {
286346
generationConfig,
287347
safetySettings,
288348
tools,
349+
toolConfig,
289350
Optional.of(systemInstruction),
290351
vertexAi);
291352
}
@@ -537,6 +598,10 @@ private GenerateContentRequest buildGenerateContentRequest(List<Content> content
537598
.addAllSafetySettings(safetySettings)
538599
.addAllTools(tools);
539600

601+
if (toolConfig.isPresent()) {
602+
requestBuilder.setToolConfig(toolConfig.get());
603+
}
604+
540605
if (systemInstruction.isPresent()) {
541606
requestBuilder.setSystemInstruction(systemInstruction.get());
542607
}
@@ -568,6 +633,13 @@ public ImmutableList<Tool> getTools() {
568633
return tools;
569634
}
570635

636+
/**
637+
* Returns the optional {@link com.google.cloud.vertexai.api.ToolConfig} of this generative model.
638+
*/
639+
public Optional<ToolConfig> getToolConfig() {
640+
return toolConfig;
641+
}
642+
571643
/** Returns the optional system instruction of this generative model. */
572644
public Optional<Content> getSystemInstruction() {
573645
return systemInstruction;

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import com.google.cloud.vertexai.api.Content;
3232
import com.google.cloud.vertexai.api.CountTokensRequest;
3333
import com.google.cloud.vertexai.api.CountTokensResponse;
34+
import com.google.cloud.vertexai.api.FunctionCallingConfig;
3435
import com.google.cloud.vertexai.api.FunctionDeclaration;
3536
import com.google.cloud.vertexai.api.GenerateContentRequest;
3637
import com.google.cloud.vertexai.api.GenerateContentResponse;
@@ -44,6 +45,7 @@
4445
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
4546
import com.google.cloud.vertexai.api.Schema;
4647
import com.google.cloud.vertexai.api.Tool;
48+
import com.google.cloud.vertexai.api.ToolConfig;
4749
import com.google.cloud.vertexai.api.Type;
4850
import com.google.cloud.vertexai.api.VertexAISearch;
4951
import java.util.ArrayList;
@@ -96,6 +98,13 @@ public final class GenerativeModelTest {
9698
.build())
9799
.addRequired("location")))
98100
.build();
101+
private static final ToolConfig DEFAULT_TOOL_CONFIG =
102+
ToolConfig.newBuilder()
103+
.setFunctionCallingConfig(
104+
FunctionCallingConfig.newBuilder()
105+
.setMode(FunctionCallingConfig.Mode.ANY)
106+
.addAllowedFunctionNames("getCurrentWeather"))
107+
.build();
99108
private static final Content DEFAULT_SYSTEM_INSTRUCTION =
100109
ContentMaker.fromString(
101110
"You're a helpful assistant that starts all its answers with: \"COOL\"");
@@ -404,6 +413,25 @@ public void generateContent_withDefaultTools_requestHasCorrectToolsAndText() thr
404413
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
405414
}
406415

416+
@Test
417+
public void generateContent_withDefaultToolConfig_requestHasCorrectToolConfigAndText()
418+
throws Exception {
419+
model =
420+
new GenerativeModel.Builder()
421+
.setModelName(MODEL_NAME)
422+
.setVertexAi(vertexAi)
423+
.setToolConfig(DEFAULT_TOOL_CONFIG)
424+
.build();
425+
426+
GenerateContentResponse unused = model.generateContent(TEXT);
427+
428+
ArgumentCaptor<GenerateContentRequest> request =
429+
ArgumentCaptor.forClass(GenerateContentRequest.class);
430+
verify(mockUnaryCallable).call(request.capture());
431+
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
432+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
433+
}
434+
407435
@Test
408436
public void
409437
generateContent_withDefaultSystemInstruction_requestHasCorrectSystemInstructionAndText()
@@ -433,6 +461,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
433461
.withGenerationConfig(GENERATION_CONFIG)
434462
.withSafetySettings(safetySettings)
435463
.withTools(tools)
464+
.withToolConfig(DEFAULT_TOOL_CONFIG)
436465
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
437466
.generateContent(TEXT);
438467

@@ -444,6 +473,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
444473
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
445474
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
446475
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
476+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
447477
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
448478
}
449479

@@ -546,6 +576,24 @@ public void generateContentStream_withDefaultTools_requestHasCorrectTools() thro
546576
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
547577
}
548578

579+
@Test
580+
public void generateContentStream_withDefaultToolConfig_requestHasCorrectToolConfig()
581+
throws Exception {
582+
model =
583+
new GenerativeModel.Builder()
584+
.setModelName(MODEL_NAME)
585+
.setVertexAi(vertexAi)
586+
.setToolConfig(DEFAULT_TOOL_CONFIG)
587+
.build();
588+
589+
ResponseStream unused = model.generateContentStream(TEXT);
590+
591+
ArgumentCaptor<GenerateContentRequest> request =
592+
ArgumentCaptor.forClass(GenerateContentRequest.class);
593+
verify(mockServerStreamCallable).call(request.capture());
594+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
595+
}
596+
549597
@Test
550598
public void
551599
generateContentStream_withDefaultSystemInstruction_requestHasCorrectSystemInstruction()
@@ -576,6 +624,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
576624
.withGenerationConfig(GENERATION_CONFIG)
577625
.withSafetySettings(safetySettings)
578626
.withTools(tools)
627+
.withToolConfig(DEFAULT_TOOL_CONFIG)
579628
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
580629
.generateContentStream(TEXT);
581630

@@ -587,6 +636,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
587636
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
588637
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
589638
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
639+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
590640
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
591641
}
592642

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy