49
49
_TEST_RESERVED_IP_RANGES = test_constants .TrainingJobConstants ._TEST_RESERVED_IP_RANGES
50
50
_TEST_KEY_NAME = test_constants .TrainingJobConstants ._TEST_DEFAULT_ENCRYPTION_KEY_NAME
51
51
_TEST_SERVICE_ACCOUNT = test_constants .ProjectConstants ._TEST_SERVICE_ACCOUNT
52
+ _TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = (
53
+ test_constants .ProjectConstants ._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE
54
+ )
55
+
52
56
53
57
_TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_v1 .PersistentResource (
54
58
name = _TEST_PERSISTENT_RESOURCE_ID ,
@@ -298,7 +302,7 @@ def test_create_persistent_resource_with_kms_key(
298
302
)
299
303
300
304
@pytest .mark .parametrize ("sync" , [True , False ])
301
- def test_create_persistent_resource_with_service_account (
305
+ def test_create_persistent_resource_enable_custom_sa_true_with_sa (
302
306
self ,
303
307
create_persistent_resource_mock ,
304
308
get_persistent_resource_mock ,
@@ -309,6 +313,7 @@ def test_create_persistent_resource_with_service_account(
309
313
resource_pools = [
310
314
test_constants .PersistentResourceConstants ._TEST_RESOURCE_POOL ,
311
315
],
316
+ enable_custom_service_account = _TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE ,
312
317
service_account = _TEST_SERVICE_ACCOUNT ,
313
318
sync = sync ,
314
319
)
@@ -321,7 +326,8 @@ def test_create_persistent_resource_with_service_account(
321
326
)
322
327
323
328
service_account_spec = persistent_resource_v1 .ServiceAccountSpec (
324
- enable_custom_service_account = True , service_account = _TEST_SERVICE_ACCOUNT
329
+ enable_custom_service_account = _TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE ,
330
+ service_account = _TEST_SERVICE_ACCOUNT ,
325
331
)
326
332
expected_persistent_resource_arg .resource_runtime_spec = (
327
333
persistent_resource_v1 .ResourceRuntimeSpec (
@@ -341,6 +347,164 @@ def test_create_persistent_resource_with_service_account(
341
347
name = _TEST_PERSISTENT_RESOURCE_ID
342
348
)
343
349
350
+ @pytest .mark .parametrize ("sync" , [True , False ])
351
+ def test_create_persistent_resource_enable_custom_sa_true_no_sa (
352
+ self ,
353
+ create_persistent_resource_mock ,
354
+ get_persistent_resource_mock ,
355
+ sync ,
356
+ ):
357
+ my_test_resource = persistent_resource .PersistentResource .create (
358
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
359
+ resource_pools = [
360
+ test_constants .PersistentResourceConstants ._TEST_RESOURCE_POOL ,
361
+ ],
362
+ enable_custom_service_account = True ,
363
+ sync = sync ,
364
+ )
365
+
366
+ if not sync :
367
+ my_test_resource .wait ()
368
+
369
+ expected_persistent_resource_arg = _get_persistent_resource_proto (
370
+ name = _TEST_PERSISTENT_RESOURCE_ID ,
371
+ )
372
+ service_account_spec = persistent_resource_v1 .ServiceAccountSpec (
373
+ enable_custom_service_account = True ,
374
+ service_account = None ,
375
+ )
376
+ expected_persistent_resource_arg .resource_runtime_spec = (
377
+ persistent_resource_v1 .ResourceRuntimeSpec (
378
+ service_account_spec = service_account_spec
379
+ )
380
+ )
381
+
382
+ create_persistent_resource_mock .assert_called_once_with (
383
+ parent = _TEST_PARENT ,
384
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
385
+ persistent_resource = expected_persistent_resource_arg ,
386
+ timeout = None ,
387
+ )
388
+ get_persistent_resource_mock .assert_called_once ()
389
+ _ , mock_kwargs = get_persistent_resource_mock .call_args
390
+ assert mock_kwargs ["name" ] == _get_resource_name (
391
+ name = _TEST_PERSISTENT_RESOURCE_ID
392
+ )
393
+
394
+ @pytest .mark .parametrize ("sync" , [True , False ])
395
+ def test_create_persistent_resource_enable_custom_sa_false_raises_error (
396
+ self ,
397
+ create_persistent_resource_mock ,
398
+ get_persistent_resource_mock ,
399
+ sync ,
400
+ ):
401
+ with pytest .raises (ValueError ) as excinfo :
402
+ my_test_resource = persistent_resource .PersistentResource .create (
403
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
404
+ resource_pools = [
405
+ test_constants .PersistentResourceConstants ._TEST_RESOURCE_POOL ,
406
+ ],
407
+ enable_custom_service_account = False ,
408
+ service_account = _TEST_SERVICE_ACCOUNT ,
409
+ sync = sync ,
410
+ )
411
+ if not sync :
412
+ my_test_resource .wait ()
413
+
414
+ assert str (excinfo .value ) == (
415
+ "The parameter `enable_custom_service_account` was set to False, "
416
+ "but a value was provided for `service_account`. These two "
417
+ "settings are incompatible. If you want to use a custom "
418
+ "service account, set `enable_custom_service_account` to True."
419
+ )
420
+
421
+ create_persistent_resource_mock .assert_not_called ()
422
+ get_persistent_resource_mock .assert_not_called ()
423
+
424
+ @pytest .mark .parametrize ("sync" , [True , False ])
425
+ def test_create_persistent_resource_enable_custom_sa_none_with_sa (
426
+ self ,
427
+ create_persistent_resource_mock ,
428
+ get_persistent_resource_mock ,
429
+ sync ,
430
+ ):
431
+ my_test_resource = persistent_resource .PersistentResource .create (
432
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
433
+ resource_pools = [
434
+ test_constants .PersistentResourceConstants ._TEST_RESOURCE_POOL ,
435
+ ],
436
+ enable_custom_service_account = None ,
437
+ service_account = _TEST_SERVICE_ACCOUNT ,
438
+ sync = sync ,
439
+ )
440
+
441
+ if not sync :
442
+ my_test_resource .wait ()
443
+
444
+ expected_persistent_resource_arg = _get_persistent_resource_proto (
445
+ name = _TEST_PERSISTENT_RESOURCE_ID ,
446
+ )
447
+ service_account_spec = persistent_resource_v1 .ServiceAccountSpec (
448
+ enable_custom_service_account = True ,
449
+ service_account = _TEST_SERVICE_ACCOUNT ,
450
+ )
451
+ expected_persistent_resource_arg .resource_runtime_spec = (
452
+ persistent_resource_v1 .ResourceRuntimeSpec (
453
+ service_account_spec = service_account_spec
454
+ )
455
+ )
456
+
457
+ create_persistent_resource_mock .assert_called_once_with (
458
+ parent = _TEST_PARENT ,
459
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
460
+ persistent_resource = expected_persistent_resource_arg ,
461
+ timeout = None ,
462
+ )
463
+ get_persistent_resource_mock .assert_called_once ()
464
+ _ , mock_kwargs = get_persistent_resource_mock .call_args
465
+ assert mock_kwargs ["name" ] == _get_resource_name (
466
+ name = _TEST_PERSISTENT_RESOURCE_ID
467
+ )
468
+
469
+ @pytest .mark .parametrize ("sync" , [True , False ])
470
+ def test_create_persistent_resource_enable_custom_sa_none_no_sa (
471
+ self ,
472
+ create_persistent_resource_mock ,
473
+ get_persistent_resource_mock ,
474
+ sync ,
475
+ ):
476
+ my_test_resource = persistent_resource .PersistentResource .create (
477
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
478
+ resource_pools = [
479
+ test_constants .PersistentResourceConstants ._TEST_RESOURCE_POOL ,
480
+ ],
481
+ enable_custom_service_account = None ,
482
+ sync = sync ,
483
+ )
484
+
485
+ if not sync :
486
+ my_test_resource .wait ()
487
+
488
+ expected_persistent_resource_arg = _get_persistent_resource_proto (
489
+ name = _TEST_PERSISTENT_RESOURCE_ID ,
490
+ )
491
+
492
+ # Assert that resource_runtime_spec is NOT set
493
+ call_args = create_persistent_resource_mock .call_args .kwargs
494
+ assert "resource_runtime_spec" not in call_args ["persistent_resource" ]
495
+
496
+ create_persistent_resource_mock .assert_called_once_with (
497
+ parent = _TEST_PARENT ,
498
+ persistent_resource_id = _TEST_PERSISTENT_RESOURCE_ID ,
499
+ persistent_resource = expected_persistent_resource_arg ,
500
+ timeout = None ,
501
+ )
502
+ get_persistent_resource_mock .assert_called_once ()
503
+ _ , mock_kwargs = get_persistent_resource_mock .call_args
504
+ assert mock_kwargs ["name" ] == _get_resource_name (
505
+ name = _TEST_PERSISTENT_RESOURCE_ID
506
+ )
507
+
344
508
def test_list_persistent_resources (self , list_persistent_resources_mock ):
345
509
resource_list = persistent_resource .PersistentResource .list ()
346
510
0 commit comments