16
16
17
17
package com .google .cloud .vertexai ;
18
18
19
- import static com .google .common .base .Preconditions .checkArgument ;
20
- import static com .google .common .base .Preconditions .checkNotNull ;
21
-
22
19
import com .google .api .core .InternalApi ;
23
20
import com .google .api .gax .core .CredentialsProvider ;
24
21
import com .google .api .gax .core .FixedCredentialsProvider ;
31
28
import com .google .cloud .vertexai .api .LlmUtilityServiceSettings ;
32
29
import com .google .cloud .vertexai .api .PredictionServiceClient ;
33
30
import com .google .cloud .vertexai .api .PredictionServiceSettings ;
34
- import com .google .common .base .Strings ;
35
31
import java .io .IOException ;
36
32
import java .util .List ;
37
- import java .util .concurrent .locks .ReentrantLock ;
38
33
import java .util .logging .Level ;
39
34
import java .util .logging .Logger ;
40
35
@@ -61,8 +56,9 @@ public class VertexAI implements AutoCloseable {
61
56
private Transport transport = Transport .GRPC ;
62
57
// The clients will be instantiated lazily
63
58
private PredictionServiceClient predictionServiceClient = null ;
59
+ private PredictionServiceClient predictionServiceRestClient = null ;
64
60
private LlmUtilityServiceClient llmUtilityClient = null ;
65
- private final ReentrantLock lock = new ReentrantLock () ;
61
+ private LlmUtilityServiceClient llmUtilityRestClient = null ;
66
62
67
63
/**
68
64
* Construct a VertexAI instance.
@@ -197,35 +193,32 @@ public Credentials getCredentials() throws IOException {
197
193
198
194
/** Sets the value for {@link #getTransport()}. */
199
195
public void setTransport (Transport transport ) {
200
- checkNotNull (transport , "Transport can't be null." );
201
- if (this .transport == transport ) {
202
- return ;
203
- }
204
-
205
196
this .transport = transport ;
206
- resetClients ();
207
197
}
208
198
209
199
/** Sets the value for {@link #getApiEndpoint()}. */
210
200
public void setApiEndpoint (String apiEndpoint ) {
211
- checkArgument (!Strings .isNullOrEmpty (apiEndpoint ), "Api endpoint can't be null or empty." );
212
- if (this .apiEndpoint == apiEndpoint ) {
213
- return ;
214
- }
215
201
this .apiEndpoint = apiEndpoint ;
216
- resetClients ();
217
- }
218
202
219
- private void resetClients () {
220
203
if (this .predictionServiceClient != null ) {
221
204
this .predictionServiceClient .close ();
222
205
this .predictionServiceClient = null ;
223
206
}
224
207
208
+ if (this .predictionServiceRestClient != null ) {
209
+ this .predictionServiceRestClient .close ();
210
+ this .predictionServiceRestClient = null ;
211
+ }
212
+
225
213
if (this .llmUtilityClient != null ) {
226
214
this .llmUtilityClient .close ();
227
215
this .llmUtilityClient = null ;
228
216
}
217
+
218
+ if (this .llmUtilityRestClient != null ) {
219
+ this .llmUtilityRestClient .close ();
220
+ this .llmUtilityRestClient = null ;
221
+ }
229
222
}
230
223
231
224
/**
@@ -237,47 +230,78 @@ private void resetClients() {
237
230
*/
238
231
@ InternalApi
239
232
public PredictionServiceClient getPredictionServiceClient () throws IOException {
240
- if (predictionServiceClient != null ) {
241
- return predictionServiceClient ;
233
+ if (this .transport == Transport .GRPC ) {
234
+ return getPredictionServiceGrpcClient ();
235
+ } else {
236
+ return getPredictionServiceRestClient ();
242
237
}
243
- lock .lock ();
244
- try {
245
- if (predictionServiceClient == null ) {
246
- PredictionServiceSettings settings = getPredictionServiceSettings ();
247
- // Disable the warning message logged in getApplicationDefault
248
- Logger defaultCredentialsProviderLogger =
249
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
250
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
251
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
252
- predictionServiceClient = PredictionServiceClient .create (settings );
253
- defaultCredentialsProviderLogger .setLevel (previousLevel );
238
+ }
239
+
240
+ /**
241
+ * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242
+ * first prediction API call is made.
243
+ *
244
+ * @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245
+ * method calls that map to the API methods.
246
+ */
247
+ private PredictionServiceClient getPredictionServiceGrpcClient () throws IOException {
248
+ if (predictionServiceClient == null ) {
249
+ PredictionServiceSettings .Builder settingsBuilder = PredictionServiceSettings .newBuilder ();
250
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
251
+ if (this .credentialsProvider != null ) {
252
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
254
253
}
255
- return predictionServiceClient ;
256
- } finally {
257
- lock .unlock ();
254
+ HeaderProvider headerProvider =
255
+ FixedHeaderProvider .create (
256
+ "user-agent" ,
257
+ String .format (
258
+ "%s/%s" ,
259
+ Constants .USER_AGENT_HEADER ,
260
+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
261
+ settingsBuilder .setHeaderProvider (headerProvider );
262
+ // Disable the warning message logged in getApplicationDefault
263
+ Logger defaultCredentialsProviderLogger =
264
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
265
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
266
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
267
+ predictionServiceClient = PredictionServiceClient .create (settingsBuilder .build ());
268
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
258
269
}
270
+ return predictionServiceClient ;
259
271
}
260
272
261
- private PredictionServiceSettings getPredictionServiceSettings () throws IOException {
262
- PredictionServiceSettings .Builder builder ;
263
- if (transport == Transport .REST ) {
264
- builder = PredictionServiceSettings .newHttpJsonBuilder ();
265
- } else {
266
- builder = PredictionServiceSettings .newBuilder ();
267
- }
268
- builder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
269
- if (this .credentialsProvider != null ) {
270
- builder .setCredentialsProvider (this .credentialsProvider );
273
+ /**
274
+ * Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275
+ * first prediction API call is made.
276
+ *
277
+ * @return {@link PredictionServiceClient} that send REST requests to the backing service through
278
+ * method calls that map to the API methods.
279
+ */
280
+ private PredictionServiceClient getPredictionServiceRestClient () throws IOException {
281
+ if (predictionServiceRestClient == null ) {
282
+ PredictionServiceSettings .Builder settingsBuilder =
283
+ PredictionServiceSettings .newHttpJsonBuilder ();
284
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
285
+ if (this .credentialsProvider != null ) {
286
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
287
+ }
288
+ HeaderProvider headerProvider =
289
+ FixedHeaderProvider .create (
290
+ "user-agent" ,
291
+ String .format (
292
+ "%s/%s" ,
293
+ Constants .USER_AGENT_HEADER ,
294
+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
295
+ settingsBuilder .setHeaderProvider (headerProvider );
296
+ // Disable the warning message logged in getApplicationDefault
297
+ Logger defaultCredentialsProviderLogger =
298
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
299
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
300
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
301
+ predictionServiceRestClient = PredictionServiceClient .create (settingsBuilder .build ());
302
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
271
303
}
272
- HeaderProvider headerProvider =
273
- FixedHeaderProvider .create (
274
- "user-agent" ,
275
- String .format (
276
- "%s/%s" ,
277
- Constants .USER_AGENT_HEADER ,
278
- GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
279
- builder .setHeaderProvider (headerProvider );
280
- return builder .build ();
304
+ return predictionServiceRestClient ;
281
305
}
282
306
283
307
/**
@@ -289,47 +313,78 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
289
313
*/
290
314
@ InternalApi
291
315
public LlmUtilityServiceClient getLlmUtilityClient () throws IOException {
292
- if (llmUtilityClient != null ) {
293
- return llmUtilityClient ;
316
+ if (this .transport == Transport .GRPC ) {
317
+ return getLlmUtilityGrpcClient ();
318
+ } else {
319
+ return getLlmUtilityRestClient ();
294
320
}
295
- lock .lock ();
296
- try {
297
- if (llmUtilityClient == null ) {
298
- LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings ();
299
- // Disable the warning message logged in getApplicationDefault
300
- Logger defaultCredentialsProviderLogger =
301
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
302
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
303
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
304
- llmUtilityClient = LlmUtilityServiceClient .create (settings );
305
- defaultCredentialsProviderLogger .setLevel (previousLevel );
321
+ }
322
+
323
+ /**
324
+ * Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325
+ * first API call is made.
326
+ *
327
+ * @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328
+ * method calls that map to the API methods.
329
+ */
330
+ private LlmUtilityServiceClient getLlmUtilityGrpcClient () throws IOException {
331
+ if (llmUtilityClient == null ) {
332
+ LlmUtilityServiceSettings .Builder settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
333
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
334
+ if (this .credentialsProvider != null ) {
335
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
306
336
}
307
- return llmUtilityClient ;
308
- } finally {
309
- lock .unlock ();
337
+ HeaderProvider headerProvider =
338
+ FixedHeaderProvider .create (
339
+ "user-agent" ,
340
+ String .format (
341
+ "%s/%s" ,
342
+ Constants .USER_AGENT_HEADER ,
343
+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
344
+ settingsBuilder .setHeaderProvider (headerProvider );
345
+ // Disable the warning message logged in getApplicationDefault
346
+ Logger defaultCredentialsProviderLogger =
347
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
348
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
349
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
350
+ llmUtilityClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
351
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
310
352
}
353
+ return llmUtilityClient ;
311
354
}
312
355
313
- private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings () throws IOException {
314
- LlmUtilityServiceSettings .Builder settingsBuilder ;
315
- if (transport == Transport .REST ) {
316
- settingsBuilder = LlmUtilityServiceSettings .newHttpJsonBuilder ();
317
- } else {
318
- settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
319
- }
320
- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
321
- if (this .credentialsProvider != null ) {
322
- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
356
+ /**
357
+ * Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358
+ * first API call is made.
359
+ *
360
+ * @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361
+ * method calls that map to the API methods.
362
+ */
363
+ private LlmUtilityServiceClient getLlmUtilityRestClient () throws IOException {
364
+ if (llmUtilityRestClient == null ) {
365
+ LlmUtilityServiceSettings .Builder settingsBuilder =
366
+ LlmUtilityServiceSettings .newHttpJsonBuilder ();
367
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
368
+ if (this .credentialsProvider != null ) {
369
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
370
+ }
371
+ HeaderProvider headerProvider =
372
+ FixedHeaderProvider .create (
373
+ "user-agent" ,
374
+ String .format (
375
+ "%s/%s" ,
376
+ Constants .USER_AGENT_HEADER ,
377
+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
378
+ settingsBuilder .setHeaderProvider (headerProvider );
379
+ // Disable the warning message logged in getApplicationDefault
380
+ Logger defaultCredentialsProviderLogger =
381
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
382
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
383
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
384
+ llmUtilityRestClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
385
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
323
386
}
324
- HeaderProvider headerProvider =
325
- FixedHeaderProvider .create (
326
- "user-agent" ,
327
- String .format (
328
- "%s/%s" ,
329
- Constants .USER_AGENT_HEADER ,
330
- GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
331
- settingsBuilder .setHeaderProvider (headerProvider );
332
- return settingsBuilder .build ();
387
+ return llmUtilityRestClient ;
333
388
}
334
389
335
390
/** Closes the VertexAI instance together with all its instantiated clients. */
@@ -338,8 +393,14 @@ public void close() {
338
393
if (predictionServiceClient != null ) {
339
394
predictionServiceClient .close ();
340
395
}
396
+ if (predictionServiceRestClient != null ) {
397
+ predictionServiceRestClient .close ();
398
+ }
341
399
if (llmUtilityClient != null ) {
342
400
llmUtilityClient .close ();
343
401
}
402
+ if (llmUtilityRestClient != null ) {
403
+ llmUtilityRestClient .close ();
404
+ }
344
405
}
345
406
}
0 commit comments