Skip to content

Commit 4704add

Browse files
jaycee-licopybara-github
authored andcommitted
feat: [vertexai] add fluent API in ChatSession
PiperOrigin-RevId: 617901539
1 parent a8aa591 commit 4704add

File tree

5 files changed

+417
-370
lines changed

5 files changed

+417
-370
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

Lines changed: 150 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616

1717
package com.google.cloud.vertexai;
1818

19-
import static com.google.common.base.Preconditions.checkArgument;
20-
import static com.google.common.base.Preconditions.checkNotNull;
21-
2219
import com.google.api.core.InternalApi;
2320
import com.google.api.gax.core.CredentialsProvider;
2421
import com.google.api.gax.core.FixedCredentialsProvider;
@@ -31,10 +28,8 @@
3128
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
3229
import com.google.cloud.vertexai.api.PredictionServiceClient;
3330
import com.google.cloud.vertexai.api.PredictionServiceSettings;
34-
import com.google.common.base.Strings;
3531
import java.io.IOException;
3632
import java.util.List;
37-
import java.util.concurrent.locks.ReentrantLock;
3833
import java.util.logging.Level;
3934
import java.util.logging.Logger;
4035

@@ -61,8 +56,9 @@ public class VertexAI implements AutoCloseable {
6156
private Transport transport = Transport.GRPC;
6257
// The clients will be instantiated lazily
6358
private PredictionServiceClient predictionServiceClient = null;
59+
private PredictionServiceClient predictionServiceRestClient = null;
6460
private LlmUtilityServiceClient llmUtilityClient = null;
65-
private final ReentrantLock lock = new ReentrantLock();
61+
private LlmUtilityServiceClient llmUtilityRestClient = null;
6662

6763
/**
6864
* Construct a VertexAI instance.
@@ -197,35 +193,32 @@ public Credentials getCredentials() throws IOException {
197193

198194
/** Sets the value for {@link #getTransport()}. */
199195
public void setTransport(Transport transport) {
200-
checkNotNull(transport, "Transport can't be null.");
201-
if (this.transport == transport) {
202-
return;
203-
}
204-
205196
this.transport = transport;
206-
resetClients();
207197
}
208198

209199
/** Sets the value for {@link #getApiEndpoint()}. */
210200
public void setApiEndpoint(String apiEndpoint) {
211-
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
212-
if (this.apiEndpoint == apiEndpoint) {
213-
return;
214-
}
215201
this.apiEndpoint = apiEndpoint;
216-
resetClients();
217-
}
218202

219-
private void resetClients() {
220203
if (this.predictionServiceClient != null) {
221204
this.predictionServiceClient.close();
222205
this.predictionServiceClient = null;
223206
}
224207

208+
if (this.predictionServiceRestClient != null) {
209+
this.predictionServiceRestClient.close();
210+
this.predictionServiceRestClient = null;
211+
}
212+
225213
if (this.llmUtilityClient != null) {
226214
this.llmUtilityClient.close();
227215
this.llmUtilityClient = null;
228216
}
217+
218+
if (this.llmUtilityRestClient != null) {
219+
this.llmUtilityRestClient.close();
220+
this.llmUtilityRestClient = null;
221+
}
229222
}
230223

231224
/**
@@ -237,47 +230,78 @@ private void resetClients() {
237230
*/
238231
@InternalApi
239232
public PredictionServiceClient getPredictionServiceClient() throws IOException {
240-
if (predictionServiceClient != null) {
241-
return predictionServiceClient;
233+
if (this.transport == Transport.GRPC) {
234+
return getPredictionServiceGrpcClient();
235+
} else {
236+
return getPredictionServiceRestClient();
242237
}
243-
lock.lock();
244-
try {
245-
if (predictionServiceClient == null) {
246-
PredictionServiceSettings settings = getPredictionServiceSettings();
247-
// Disable the warning message logged in getApplicationDefault
248-
Logger defaultCredentialsProviderLogger =
249-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
250-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
251-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
252-
predictionServiceClient = PredictionServiceClient.create(settings);
253-
defaultCredentialsProviderLogger.setLevel(previousLevel);
238+
}
239+
240+
/**
241+
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242+
* first prediction API call is made.
243+
*
244+
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245+
* method calls that map to the API methods.
246+
*/
247+
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
248+
if (predictionServiceClient == null) {
249+
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
250+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
251+
if (this.credentialsProvider != null) {
252+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
254253
}
255-
return predictionServiceClient;
256-
} finally {
257-
lock.unlock();
254+
HeaderProvider headerProvider =
255+
FixedHeaderProvider.create(
256+
"user-agent",
257+
String.format(
258+
"%s/%s",
259+
Constants.USER_AGENT_HEADER,
260+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
261+
settingsBuilder.setHeaderProvider(headerProvider);
262+
// Disable the warning message logged in getApplicationDefault
263+
Logger defaultCredentialsProviderLogger =
264+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
265+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
266+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
267+
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
268+
defaultCredentialsProviderLogger.setLevel(previousLevel);
258269
}
270+
return predictionServiceClient;
259271
}
260272

261-
private PredictionServiceSettings getPredictionServiceSettings() throws IOException {
262-
PredictionServiceSettings.Builder builder;
263-
if (transport == Transport.REST) {
264-
builder = PredictionServiceSettings.newHttpJsonBuilder();
265-
} else {
266-
builder = PredictionServiceSettings.newBuilder();
267-
}
268-
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
269-
if (this.credentialsProvider != null) {
270-
builder.setCredentialsProvider(this.credentialsProvider);
273+
/**
274+
* Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275+
* first prediction API call is made.
276+
*
277+
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
278+
* method calls that map to the API methods.
279+
*/
280+
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
281+
if (predictionServiceRestClient == null) {
282+
PredictionServiceSettings.Builder settingsBuilder =
283+
PredictionServiceSettings.newHttpJsonBuilder();
284+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
285+
if (this.credentialsProvider != null) {
286+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
287+
}
288+
HeaderProvider headerProvider =
289+
FixedHeaderProvider.create(
290+
"user-agent",
291+
String.format(
292+
"%s/%s",
293+
Constants.USER_AGENT_HEADER,
294+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
295+
settingsBuilder.setHeaderProvider(headerProvider);
296+
// Disable the warning message logged in getApplicationDefault
297+
Logger defaultCredentialsProviderLogger =
298+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
299+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
300+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
301+
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
302+
defaultCredentialsProviderLogger.setLevel(previousLevel);
271303
}
272-
HeaderProvider headerProvider =
273-
FixedHeaderProvider.create(
274-
"user-agent",
275-
String.format(
276-
"%s/%s",
277-
Constants.USER_AGENT_HEADER,
278-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
279-
builder.setHeaderProvider(headerProvider);
280-
return builder.build();
304+
return predictionServiceRestClient;
281305
}
282306

283307
/**
@@ -289,47 +313,78 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
289313
*/
290314
@InternalApi
291315
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
292-
if (llmUtilityClient != null) {
293-
return llmUtilityClient;
316+
if (this.transport == Transport.GRPC) {
317+
return getLlmUtilityGrpcClient();
318+
} else {
319+
return getLlmUtilityRestClient();
294320
}
295-
lock.lock();
296-
try {
297-
if (llmUtilityClient == null) {
298-
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
299-
// Disable the warning message logged in getApplicationDefault
300-
Logger defaultCredentialsProviderLogger =
301-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
302-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
303-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
304-
llmUtilityClient = LlmUtilityServiceClient.create(settings);
305-
defaultCredentialsProviderLogger.setLevel(previousLevel);
321+
}
322+
323+
/**
324+
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325+
* first API call is made.
326+
*
327+
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328+
* method calls that map to the API methods.
329+
*/
330+
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
331+
if (llmUtilityClient == null) {
332+
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
333+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
334+
if (this.credentialsProvider != null) {
335+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
306336
}
307-
return llmUtilityClient;
308-
} finally {
309-
lock.unlock();
337+
HeaderProvider headerProvider =
338+
FixedHeaderProvider.create(
339+
"user-agent",
340+
String.format(
341+
"%s/%s",
342+
Constants.USER_AGENT_HEADER,
343+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
344+
settingsBuilder.setHeaderProvider(headerProvider);
345+
// Disable the warning message logged in getApplicationDefault
346+
Logger defaultCredentialsProviderLogger =
347+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
348+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
349+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
350+
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
351+
defaultCredentialsProviderLogger.setLevel(previousLevel);
310352
}
353+
return llmUtilityClient;
311354
}
312355

313-
private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
314-
LlmUtilityServiceSettings.Builder settingsBuilder;
315-
if (transport == Transport.REST) {
316-
settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder();
317-
} else {
318-
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
319-
}
320-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
321-
if (this.credentialsProvider != null) {
322-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
356+
/**
357+
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358+
* first API call is made.
359+
*
360+
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361+
* method calls that map to the API methods.
362+
*/
363+
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
364+
if (llmUtilityRestClient == null) {
365+
LlmUtilityServiceSettings.Builder settingsBuilder =
366+
LlmUtilityServiceSettings.newHttpJsonBuilder();
367+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
368+
if (this.credentialsProvider != null) {
369+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
370+
}
371+
HeaderProvider headerProvider =
372+
FixedHeaderProvider.create(
373+
"user-agent",
374+
String.format(
375+
"%s/%s",
376+
Constants.USER_AGENT_HEADER,
377+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
378+
settingsBuilder.setHeaderProvider(headerProvider);
379+
// Disable the warning message logged in getApplicationDefault
380+
Logger defaultCredentialsProviderLogger =
381+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
382+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
383+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
384+
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
385+
defaultCredentialsProviderLogger.setLevel(previousLevel);
323386
}
324-
HeaderProvider headerProvider =
325-
FixedHeaderProvider.create(
326-
"user-agent",
327-
String.format(
328-
"%s/%s",
329-
Constants.USER_AGENT_HEADER,
330-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
331-
settingsBuilder.setHeaderProvider(headerProvider);
332-
return settingsBuilder.build();
387+
return llmUtilityRestClient;
333388
}
334389

335390
/** Closes the VertexAI instance together with all its instantiated clients. */
@@ -338,8 +393,14 @@ public void close() {
338393
if (predictionServiceClient != null) {
339394
predictionServiceClient.close();
340395
}
396+
if (predictionServiceRestClient != null) {
397+
predictionServiceRestClient.close();
398+
}
341399
if (llmUtilityClient != null) {
342400
llmUtilityClient.close();
343401
}
402+
if (llmUtilityRestClient != null) {
403+
llmUtilityRestClient.close();
404+
}
344405
}
345406
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy