Content-Length: 1042045 | pFad | https://github.com/googleapis/python-aiplatform/commit/bf79bdf643c60e72b15f414dba964e9da2eb2d7f

84 feat: Add `enable_custom_service_account` parameter (must be set to `… · googleapis/python-aiplatform@bf79bdf · GitHub
Skip to content

Commit bf79bdf

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add enable_custom_service_account parameter (must be set to True for successful Persistent Resource). The service_account parameter is retained for backward compatibility.
PiperOrigin-RevId: 755016847
1 parent 6de9de1 commit bf79bdf

File tree

6 files changed

+408
-7
lines changed

6 files changed

+408
-7
lines changed

google/cloud/aiplatform/persistent_resource.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def create(
171171
labels: Optional[Dict[str, str]] = None,
172172
network: Optional[str] = None,
173173
kms_key_name: Optional[str] = None,
174+
enable_custom_service_account: Optional[bool] = None,
174175
service_account: Optional[str] = None,
175176
reserved_ip_ranges: List[str] = None,
176177
sync: Optional[bool] = True, # pylint: disable=unused-argument
@@ -234,6 +235,16 @@ def create(
234235
PersistentResource. If set, this PersistentResource and all
235236
sub-resources of this PersistentResource will be secured by
236237
this key.
238+
enable_custom_service_account (bool):
239+
Optional. When set to True, allows the `service_account`
240+
parameter to specify a custom service account for workloads on this
241+
PersistentResource. Defaults to None (False behavior).
242+
243+
If True, the service account provided in the `service_account` parameter
244+
will be used for workloads (runtimes, jobs), provided the user has the
245+
``iam.serviceAccounts.actAs`` permission. If False, the
246+
`service_account` parameter is ignored, and the PersistentResource
247+
will use the default service account.
237248
service_account (str):
238249
Optional. Default service account that this
239250
PersistentResource's workloads run as. The workloads
@@ -295,7 +306,31 @@ def create(
295306
gca_encryption_spec_compat.EncryptionSpec(kms_key_name=kms_key_name)
296307
)
297308

298-
if service_account:
309+
# Raise ValueError if enable_custom_service_account is False but
310+
# service_account is provided
311+
if (
312+
enable_custom_service_account is False and service_account is not None
313+
): # pylint: disable=g-bool-id-comparison
314+
raise ValueError(
315+
"The parameter `enable_custom_service_account` was set to False, "
316+
"but a value was provided for `service_account`. These two "
317+
"settings are incompatible. If you want to use a custom "
318+
"service account, set `enable_custom_service_account` to True."
319+
)
320+
321+
elif enable_custom_service_account:
322+
service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
323+
enable_custom_service_account=True,
324+
# Set service_account if it is provided, otherwise set to None
325+
service_account=service_account if service_account else None,
326+
)
327+
gca_persistent_resource.resource_runtime_spec = (
328+
gca_persistent_resource_compat.ResourceRuntimeSpec(
329+
service_account_spec=service_account_spec
330+
)
331+
)
332+
elif service_account:
333+
# Handle the deprecated case where only service_account is provided
299334
service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
300335
enable_custom_service_account=True, service_account=service_account
301336
)

google/cloud/aiplatform/preview/persistent_resource.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def create(
177177
labels: Optional[Dict[str, str]] = None,
178178
network: Optional[str] = None,
179179
kms_key_name: Optional[str] = None,
180+
enable_custom_service_account: Optional[bool] = None,
180181
service_account: Optional[str] = None,
181182
reserved_ip_ranges: List[str] = None,
182183
sync: Optional[bool] = True, # pylint: disable=unused-argument
@@ -240,6 +241,16 @@ def create(
240241
PersistentResource. If set, this PersistentResource and all
241242
sub-resources of this PersistentResource will be secured by
242243
this key.
244+
enable_custom_service_account (bool):
245+
Optional. When set to True, allows the `service_account`
246+
parameter to specify a custom service account for workloads on this
247+
PersistentResource. Defaults to None (False behavior).
248+
249+
If True, the service account provided in the `service_account` parameter
250+
will be used for workloads (runtimes, jobs), provided the user has the
251+
``iam.serviceAccounts.actAs`` permission. If False, the
252+
`service_account` parameter is ignored, and the PersistentResource
253+
will use the default service account.
243254
service_account (str):
244255
Optional. Default service account that this
245256
PersistentResource's workloads run as. The workloads
@@ -301,7 +312,29 @@ def create(
301312
gca_encryption_spec_compat.EncryptionSpec(kms_key_name=kms_key_name)
302313
)
303314

304-
if service_account:
315+
if (
316+
enable_custom_service_account is False and service_account is not None
317+
): # pylint: disable=g-bool-id-comparison
318+
raise ValueError(
319+
"The parameter `enable_custom_service_account` was set to False, "
320+
"but a value was provided for `service_account`. These two "
321+
"settings are incompatible. If you want to use a custom "
322+
"service account, set `enable_custom_service_account` to True."
323+
)
324+
325+
elif enable_custom_service_account:
326+
service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
327+
enable_custom_service_account=True,
328+
# Set service_account if it is provided, otherwise set to None
329+
service_account=service_account if service_account else None,
330+
)
331+
gca_persistent_resource.resource_runtime_spec = (
332+
gca_persistent_resource_compat.ResourceRuntimeSpec(
333+
service_account_spec=service_account_spec
334+
)
335+
)
336+
elif service_account:
337+
# Handle the deprecated case where only service_account is provided
305338
service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
306339
enable_custom_service_account=True, service_account=service_account
307340
)

tests/system/aiplatform/test_persistent_resource.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@
2727
persistent_resource_v1 as gca_persistent_resource,
2828
)
2929
from tests.system.aiplatform import e2e_base
30+
from google.cloud.aiplatform.tests.unit.aiplatform import constants as test_constants
3031
import pytest
3132

3233

3334
_TEST_MACHINE_TYPE = "n1-standard-4"
3435
_TEST_INITIAL_REPLICA_COUNT = 2
36+
_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = (
37+
test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE
38+
)
3539

3640

3741
@pytest.mark.usefixtures("tear_down_resources")
@@ -59,7 +63,9 @@ def test_create_persistent_resource(self, shared_state):
5963
]
6064

6165
test_resource = persistent_resource.PersistentResource.create(
62-
persistent_resource_id=resource_id, resource_pools=resource_pools
66+
persistent_resource_id=resource_id,
67+
resource_pools=resource_pools,
68+
enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
6369
)
6470

6571
shared_state["resources"] = [test_resource]

tests/unit/aiplatform/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class ProjectConstants:
5959
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
6060
_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com"
6161
_TEST_LABELS = {"my_key": "my_value"}
62+
_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = True
6263

6364

6465
@dataclasses.dataclass(frozen=True)

tests/unit/aiplatform/test_persistent_resource.py

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@
4949
_TEST_RESERVED_IP_RANGES = test_constants.TrainingJobConstants._TEST_RESERVED_IP_RANGES
5050
_TEST_KEY_NAME = test_constants.TrainingJobConstants._TEST_DEFAULT_ENCRYPTION_KEY_NAME
5151
_TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT
52+
_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = (
53+
test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE
54+
)
55+
5256

5357
_TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_v1.PersistentResource(
5458
name=_TEST_PERSISTENT_RESOURCE_ID,
@@ -298,7 +302,7 @@ def test_create_persistent_resource_with_kms_key(
298302
)
299303

300304
@pytest.mark.parametrize("sync", [True, False])
301-
def test_create_persistent_resource_with_service_account(
305+
def test_create_persistent_resource_enable_custom_sa_true_with_sa(
302306
self,
303307
create_persistent_resource_mock,
304308
get_persistent_resource_mock,
@@ -309,6 +313,7 @@ def test_create_persistent_resource_with_service_account(
309313
resource_pools=[
310314
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
311315
],
316+
enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
312317
service_account=_TEST_SERVICE_ACCOUNT,
313318
sync=sync,
314319
)
@@ -321,7 +326,8 @@ def test_create_persistent_resource_with_service_account(
321326
)
322327

323328
service_account_spec = persistent_resource_v1.ServiceAccountSpec(
324-
enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT
329+
enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
330+
service_account=_TEST_SERVICE_ACCOUNT,
325331
)
326332
expected_persistent_resource_arg.resource_runtime_spec = (
327333
persistent_resource_v1.ResourceRuntimeSpec(
@@ -341,6 +347,164 @@ def test_create_persistent_resource_with_service_account(
341347
name=_TEST_PERSISTENT_RESOURCE_ID
342348
)
343349

350+
@pytest.mark.parametrize("sync", [True, False])
351+
def test_create_persistent_resource_enable_custom_sa_true_no_sa(
352+
self,
353+
create_persistent_resource_mock,
354+
get_persistent_resource_mock,
355+
sync,
356+
):
357+
my_test_resource = persistent_resource.PersistentResource.create(
358+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
359+
resource_pools=[
360+
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
361+
],
362+
enable_custom_service_account=True,
363+
sync=sync,
364+
)
365+
366+
if not sync:
367+
my_test_resource.wait()
368+
369+
expected_persistent_resource_arg = _get_persistent_resource_proto(
370+
name=_TEST_PERSISTENT_RESOURCE_ID,
371+
)
372+
service_account_spec = persistent_resource_v1.ServiceAccountSpec(
373+
enable_custom_service_account=True,
374+
service_account=None,
375+
)
376+
expected_persistent_resource_arg.resource_runtime_spec = (
377+
persistent_resource_v1.ResourceRuntimeSpec(
378+
service_account_spec=service_account_spec
379+
)
380+
)
381+
382+
create_persistent_resource_mock.assert_called_once_with(
383+
parent=_TEST_PARENT,
384+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
385+
persistent_resource=expected_persistent_resource_arg,
386+
timeout=None,
387+
)
388+
get_persistent_resource_mock.assert_called_once()
389+
_, mock_kwargs = get_persistent_resource_mock.call_args
390+
assert mock_kwargs["name"] == _get_resource_name(
391+
name=_TEST_PERSISTENT_RESOURCE_ID
392+
)
393+
394+
@pytest.mark.parametrize("sync", [True, False])
395+
def test_create_persistent_resource_enable_custom_sa_false_raises_error(
396+
self,
397+
create_persistent_resource_mock,
398+
get_persistent_resource_mock,
399+
sync,
400+
):
401+
with pytest.raises(ValueError) as excinfo:
402+
my_test_resource = persistent_resource.PersistentResource.create(
403+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
404+
resource_pools=[
405+
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
406+
],
407+
enable_custom_service_account=False,
408+
service_account=_TEST_SERVICE_ACCOUNT,
409+
sync=sync,
410+
)
411+
if not sync:
412+
my_test_resource.wait()
413+
414+
assert str(excinfo.value) == (
415+
"The parameter `enable_custom_service_account` was set to False, "
416+
"but a value was provided for `service_account`. These two "
417+
"settings are incompatible. If you want to use a custom "
418+
"service account, set `enable_custom_service_account` to True."
419+
)
420+
421+
create_persistent_resource_mock.assert_not_called()
422+
get_persistent_resource_mock.assert_not_called()
423+
424+
@pytest.mark.parametrize("sync", [True, False])
425+
def test_create_persistent_resource_enable_custom_sa_none_with_sa(
426+
self,
427+
create_persistent_resource_mock,
428+
get_persistent_resource_mock,
429+
sync,
430+
):
431+
my_test_resource = persistent_resource.PersistentResource.create(
432+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
433+
resource_pools=[
434+
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
435+
],
436+
enable_custom_service_account=None,
437+
service_account=_TEST_SERVICE_ACCOUNT,
438+
sync=sync,
439+
)
440+
441+
if not sync:
442+
my_test_resource.wait()
443+
444+
expected_persistent_resource_arg = _get_persistent_resource_proto(
445+
name=_TEST_PERSISTENT_RESOURCE_ID,
446+
)
447+
service_account_spec = persistent_resource_v1.ServiceAccountSpec(
448+
enable_custom_service_account=True,
449+
service_account=_TEST_SERVICE_ACCOUNT,
450+
)
451+
expected_persistent_resource_arg.resource_runtime_spec = (
452+
persistent_resource_v1.ResourceRuntimeSpec(
453+
service_account_spec=service_account_spec
454+
)
455+
)
456+
457+
create_persistent_resource_mock.assert_called_once_with(
458+
parent=_TEST_PARENT,
459+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
460+
persistent_resource=expected_persistent_resource_arg,
461+
timeout=None,
462+
)
463+
get_persistent_resource_mock.assert_called_once()
464+
_, mock_kwargs = get_persistent_resource_mock.call_args
465+
assert mock_kwargs["name"] == _get_resource_name(
466+
name=_TEST_PERSISTENT_RESOURCE_ID
467+
)
468+
469+
@pytest.mark.parametrize("sync", [True, False])
470+
def test_create_persistent_resource_enable_custom_sa_none_no_sa(
471+
self,
472+
create_persistent_resource_mock,
473+
get_persistent_resource_mock,
474+
sync,
475+
):
476+
my_test_resource = persistent_resource.PersistentResource.create(
477+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
478+
resource_pools=[
479+
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
480+
],
481+
enable_custom_service_account=None,
482+
sync=sync,
483+
)
484+
485+
if not sync:
486+
my_test_resource.wait()
487+
488+
expected_persistent_resource_arg = _get_persistent_resource_proto(
489+
name=_TEST_PERSISTENT_RESOURCE_ID,
490+
)
491+
492+
# Assert that resource_runtime_spec is NOT set
493+
call_args = create_persistent_resource_mock.call_args.kwargs
494+
assert "resource_runtime_spec" not in call_args["persistent_resource"]
495+
496+
create_persistent_resource_mock.assert_called_once_with(
497+
parent=_TEST_PARENT,
498+
persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
499+
persistent_resource=expected_persistent_resource_arg,
500+
timeout=None,
501+
)
502+
get_persistent_resource_mock.assert_called_once()
503+
_, mock_kwargs = get_persistent_resource_mock.call_args
504+
assert mock_kwargs["name"] == _get_resource_name(
505+
name=_TEST_PERSISTENT_RESOURCE_ID
506+
)
507+
344508
def test_list_persistent_resources(self, list_persistent_resources_mock):
345509
resource_list = persistent_resource.PersistentResource.list()
346510

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/bf79bdf643c60e72b15f414dba964e9da2eb2d7f

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy