Skip to content

Commit e0c6227

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support custom service account for Ray cluster creation and Ray Client connection
PiperOrigin-RevId: 631998839
1 parent cc8bc96 commit e0c6227

File tree

7 files changed

+248
-19
lines changed

7 files changed

+248
-19
lines changed

google/cloud/aiplatform/preview/vertex_ray/client_builder.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,21 @@ def __init__(self, address: Optional[str]) -> None:
9898
public_address = self.response.resource_runtime.access_uris.get(
9999
"RAY_CLIENT_ENDPOINT"
100100
)
101+
service_account = (
102+
self.response.resource_runtime_spec.service_account_spec.service_account
103+
)
104+
101105
if public_address is None:
102106
address = private_address
107+
if service_account:
108+
raise ValueError(
109+
"[Ray on Vertex AI]: Ray Cluster ",
110+
address,
111+
" failed to start Head node properly because custom service"
112+
" account isn't supported in peered VPC network. Use public"
113+
" endpoint instead (createa a cluster withought specifying"
114+
" VPC network).",
115+
)
103116
else:
104117
address = public_address
105118

@@ -110,17 +123,7 @@ def __init__(self, address: Optional[str]) -> None:
110123
persistent_resource_id,
111124
" Head node is not reachable. Please ensure that a valid VPC network has been specified.",
112125
)
113-
# Handling service_account
114-
service_account = (
115-
self.response.resource_runtime_spec.service_account_spec.service_account
116-
)
117126

118-
if service_account:
119-
raise ValueError(
120-
"[Ray on Vertex AI]: Ray Cluster ",
121-
address,
122-
" failed to start Head node properly because custom service account isn't supported.",
123-
)
124127
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
125128
cluster = _gapic_utils.persistent_resource_to_cluster(
126129
persistent_resource=self.response

google/cloud/aiplatform/preview/vertex_ray/cluster_init.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
RayMetricSpec,
3333
ResourcePool,
3434
ResourceRuntimeSpec,
35+
ServiceAccountSpec,
3536
)
3637

3738
from google.cloud.aiplatform.preview.vertex_ray.util import (
@@ -48,6 +49,7 @@ def create_ray_cluster(
4849
python_version: Optional[str] = "3.10",
4950
ray_version: Optional[str] = "2.9",
5051
network: Optional[str] = None,
52+
service_account: Optional[str] = None,
5153
cluster_name: Optional[str] = None,
5254
worker_node_types: Optional[List[resources.Resources]] = None,
5355
custom_images: Optional[resources.NodeImages] = None,
@@ -78,7 +80,9 @@ def create_ray_cluster(
7880
7981
cluster_resource_name = vertex_ray.create_ray_cluster(
8082
head_node_type=head_node_type,
81-
network="projects/my-project-number/global/networks/my-vpc-name",
83+
network="projects/my-project-number/global/networks/my-vpc-name", # Optional
84+
service_account="my-service-account@my-project-number.iam.gserviceaccount.com", # Optional
85+
cluster_name="my-cluster-name", # Optional
8286
worker_node_types=worker_node_types,
8387
ray_version="2.9",
8488
)
@@ -100,6 +104,8 @@ def create_ray_cluster(
100104
Vertex API service. For Ray Job API, VPC network is not required
101105
because Ray Cluster connection can be accessed through dashboard
102106
address.
107+
service_account: Service account to be used for running Ray programs on
108+
the cluster.
103109
cluster_name: This value may be up to 63 characters, and valid
104110
characters are `[a-z0-9_-]`. The first character cannot be a number
105111
or hyphen.
@@ -254,7 +260,17 @@ def create_ray_cluster(
254260
ray_spec = RaySpec(
255261
resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec
256262
)
257-
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
263+
if service_account:
264+
service_account_spec = ServiceAccountSpec(
265+
enable_custom_service_account=True,
266+
service_account=service_account,
267+
)
268+
resource_runtime_spec = ResourceRuntimeSpec(
269+
ray_spec=ray_spec,
270+
service_account_spec=service_account_spec,
271+
)
272+
else:
273+
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
258274
persistent_resource = PersistentResource(
259275
resource_pools=resource_pools,
260276
network=network,

google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,10 @@ def persistent_resource_to_cluster(
166166
head_image_uri = (
167167
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[head_id]
168168
)
169-
169+
if persistent_resource.resource_runtime_spec.service_account_spec.service_account:
170+
cluster.service_account = (
171+
persistent_resource.resource_runtime_spec.service_account_spec.service_account
172+
)
170173
if not head_image_uri:
171174
head_image_uri = persistent_resource.resource_runtime_spec.ray_spec.image_uri
172175

google/cloud/aiplatform/preview/vertex_ray/util/resources.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Resources:
4141
us-docker.pkg.dev/my-project/ray-gpu.2-9.py310-tf:latest).
4242
"""
4343

44-
machine_type: Optional[str] = "n1-standard-8"
44+
machine_type: Optional[str] = "n1-standard-16"
4545
node_count: Optional[int] = 1
4646
accelerator_type: Optional[str] = None
4747
accelerator_count: Optional[int] = 0
@@ -81,6 +81,8 @@ class Cluster:
8181
managed in the Vertex API service. For Ray Job API, VPC network is
8282
not required because cluster connection can be accessed through
8383
dashboard address.
84+
service_account: Service account to be used for running Ray programs on
85+
the cluster.
8486
state: Describes the cluster state (defined in PersistentResource.State).
8587
python_version: Python version for the ray cluster (e.g. "3.10").
8688
ray_version: Ray version for the ray cluster (e.g. "2.4").
@@ -102,6 +104,7 @@ class Cluster:
102104

103105
cluster_resource_name: str = None
104106
network: str = None
107+
service_account: str = None
105108
state: PersistentResource.State = None
106109
python_version: str = None
107110
ray_version: str = None

tests/unit/vertex_ray/test_cluster_init.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,34 @@ def get_persistent_resource_1_pool_custom_image_mock():
9292
yield get_persistent_resource_1_pool_custom_image_mock
9393

9494

95+
@pytest.fixture
96+
def create_persistent_resource_1_pool_byosa_mock():
97+
with mock.patch.object(
98+
PersistentResourceServiceClient,
99+
"create_persistent_resource",
100+
) as create_persistent_resource_1_pool_byosa_mock:
101+
create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation)
102+
create_persistent_resource_lro_mock.result.return_value = (
103+
tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA
104+
)
105+
create_persistent_resource_1_pool_byosa_mock.return_value = (
106+
create_persistent_resource_lro_mock
107+
)
108+
yield create_persistent_resource_1_pool_byosa_mock
109+
110+
111+
@pytest.fixture
112+
def get_persistent_resource_1_pool_byosa_mock():
113+
with mock.patch.object(
114+
PersistentResourceServiceClient,
115+
"get_persistent_resource",
116+
) as get_persistent_resource_1_pool_byosa_mock:
117+
get_persistent_resource_1_pool_byosa_mock.return_value = (
118+
tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA
119+
)
120+
yield get_persistent_resource_1_pool_byosa_mock
121+
122+
95123
@pytest.fixture
96124
def create_persistent_resource_2_pools_mock():
97125
with mock.patch.object(
@@ -426,6 +454,30 @@ def test_create_ray_cluster_initialized_success(
426454
]
427455
)
428456

457+
@pytest.mark.usefixtures("get_persistent_resource_1_pool_byosa_mock")
458+
def test_create_ray_cluster_byosa_success(
459+
self, create_persistent_resource_1_pool_byosa_mock
460+
):
461+
"""If head and worker nodes are duplicate, merge to head pool."""
462+
cluster_name = vertex_ray.create_ray_cluster(
463+
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_1_POOL,
464+
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_1_POOL,
465+
service_account=tc.ProjectConstants.TEST_SERVICE_ACCOUNT,
466+
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
467+
)
468+
469+
assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
470+
471+
request = persistent_resource_service.CreatePersistentResourceRequest(
472+
parent=tc.ProjectConstants.TEST_PARENT,
473+
persistent_resource=tc.ClusterConstants.TEST_REQUEST_RUNNING_1_POOL_BYOSA,
474+
persistent_resource_id=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
475+
)
476+
477+
create_persistent_resource_1_pool_byosa_mock.assert_called_with(
478+
request,
479+
)
480+
429481
def test_create_ray_cluster_head_multinode_error(self):
430482
with pytest.raises(ValueError) as e:
431483
vertex_ray.create_ray_cluster(
@@ -508,6 +560,16 @@ def test_get_ray_cluster_with_custom_image_success(
508560
get_persistent_resource_2_pools_custom_image_mock.assert_called_once()
509561
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE)
510562

563+
def test_get_ray_cluster_byosa_success(
564+
self, get_persistent_resource_1_pool_byosa_mock
565+
):
566+
cluster = vertex_ray.get_ray_cluster(
567+
cluster_resource_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS
568+
)
569+
570+
get_persistent_resource_1_pool_byosa_mock.assert_called_once()
571+
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_BYOSA)
572+
511573
@pytest.mark.usefixtures("get_persistent_resource_exception_mock")
512574
def test_get_ray_cluster_error(self):
513575
with pytest.raises(ValueError) as e:

tests/unit/vertex_ray/test_constants.py

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import dataclasses
19+
import sys
1920

2021
from google.cloud.aiplatform.preview.vertex_ray.util.resources import Cluster
2122
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
@@ -28,10 +29,10 @@
2829
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2930
PersistentResource,
3031
)
31-
from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec
3232
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3333
RayMetricSpec,
3434
)
35+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec
3536
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3637
ResourcePool,
3738
)
@@ -41,9 +42,11 @@
4142
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
4243
ResourceRuntimeSpec,
4344
)
44-
45+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
46+
ServiceAccountSpec,
47+
)
4548
import pytest
46-
import sys
49+
4750

4851
rovminversion = pytest.mark.skipif(
4952
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
@@ -67,6 +70,7 @@ class ProjectConstants:
6770
TEST_MODEL_ID = (
6871
f"projects/{TEST_GCP_PROJECT_NUMBER}/locations/{TEST_GCP_REGION}/models/456"
6972
)
73+
TEST_SERVICE_ACCOUNT = "service-account@project.iam.gserviceaccount.com"
7074

7175

7276
@dataclasses.dataclass(frozen=True)
@@ -79,6 +83,9 @@ class ClusterConstants:
7983
TEST_VERTEX_RAY_DASHBOARD_ADDRESS = (
8084
"48b400ad90b8dd3c-dot-us-central1.aiplatform-training.googleusercontent.com"
8185
)
86+
TEST_VERTEX_RAY_CLIENT_ENDPOINT = (
87+
"88888.us-central1-1234567.staging-ray.vertexai.goog:443"
88+
)
8289
TEST_VERTEX_RAY_PR_ID = "user-persistent-resource-1234567890"
8390
TEST_VERTEX_RAY_PR_ADDRESS = (
8491
f"{ProjectConstants.TEST_PARENT}/persistentResources/" + TEST_VERTEX_RAY_PR_ID
@@ -106,7 +113,7 @@ class ClusterConstants:
106113
TEST_RESOURCE_POOL_0 = ResourcePool(
107114
id="head-node",
108115
machine_spec=MachineSpec(
109-
machine_type="n1-standard-8",
116+
machine_type="n1-standard-16",
110117
accelerator_type="NVIDIA_TESLA_P100",
111118
accelerator_count=1,
112119
),
@@ -147,6 +154,20 @@ class ClusterConstants:
147154
),
148155
network=ProjectConstants.TEST_VPC_NETWORK,
149156
)
157+
TEST_REQUEST_RUNNING_1_POOL_BYOSA = PersistentResource(
158+
resource_pools=[TEST_RESOURCE_POOL_0],
159+
resource_runtime_spec=ResourceRuntimeSpec(
160+
ray_spec=RaySpec(
161+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
162+
ray_metric_spec=RayMetricSpec(disabled=False),
163+
),
164+
service_account_spec=ServiceAccountSpec(
165+
enable_custom_service_account=True,
166+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
167+
),
168+
),
169+
network=None,
170+
)
150171
# Get response has generated name, and URIs
151172
TEST_RESPONSE_RUNNING_1_POOL = PersistentResource(
152173
name=TEST_VERTEX_RAY_PR_ADDRESS,
@@ -185,6 +206,50 @@ class ClusterConstants:
185206
),
186207
state="RUNNING",
187208
)
209+
TEST_RESPONSE_RUNNING_1_POOL_BYOSA = PersistentResource(
210+
name=TEST_VERTEX_RAY_PR_ADDRESS,
211+
resource_pools=[TEST_RESOURCE_POOL_0],
212+
resource_runtime_spec=ResourceRuntimeSpec(
213+
ray_spec=RaySpec(
214+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
215+
ray_metric_spec=RayMetricSpec(disabled=False),
216+
),
217+
service_account_spec=ServiceAccountSpec(
218+
enable_custom_service_account=True,
219+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
220+
),
221+
),
222+
network=None,
223+
resource_runtime=ResourceRuntime(
224+
access_uris={
225+
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
226+
"RAY_CLIENT_ENDPOINT": TEST_VERTEX_RAY_CLIENT_ENDPOINT,
227+
}
228+
),
229+
state="RUNNING",
230+
)
231+
TEST_RESPONSE_1_POOL_BYOSA_PRIVATE = PersistentResource(
232+
name=TEST_VERTEX_RAY_PR_ADDRESS,
233+
resource_pools=[TEST_RESOURCE_POOL_0],
234+
resource_runtime_spec=ResourceRuntimeSpec(
235+
ray_spec=RaySpec(
236+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
237+
ray_metric_spec=RayMetricSpec(disabled=False),
238+
),
239+
service_account_spec=ServiceAccountSpec(
240+
enable_custom_service_account=True,
241+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
242+
),
243+
),
244+
network=ProjectConstants.TEST_VPC_NETWORK,
245+
resource_runtime=ResourceRuntime(
246+
access_uris={
247+
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
248+
"RAY_CLIENT_ENDPOINT": TEST_VERTEX_RAY_CLIENT_ENDPOINT,
249+
}
250+
),
251+
state="RUNNING",
252+
)
188253
# 2_POOL: worker_node_types and head_node_type have different MachineSpecs
189254
TEST_HEAD_NODE_TYPE_2_POOLS = Resources()
190255
TEST_WORKER_NODE_TYPES_2_POOLS = [
@@ -208,7 +273,7 @@ class ClusterConstants:
208273
TEST_RESOURCE_POOL_1 = ResourcePool(
209274
id="head-node",
210275
machine_spec=MachineSpec(
211-
machine_type="n1-standard-8",
276+
machine_type="n1-standard-16",
212277
),
213278
disk_spec=DiskSpec(
214279
boot_disk_type="pd-ssd",
@@ -302,6 +367,7 @@ class ClusterConstants:
302367
python_version="3.10",
303368
ray_version="2.9",
304369
network=ProjectConstants.TEST_VPC_NETWORK,
370+
service_account=None,
305371
state="RUNNING",
306372
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
307373
worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL,
@@ -312,6 +378,7 @@ class ClusterConstants:
312378
python_version="3.10",
313379
ray_version="2.9",
314380
network=ProjectConstants.TEST_VPC_NETWORK,
381+
service_account=None,
315382
state="RUNNING",
316383
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS,
317384
worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS,
@@ -320,11 +387,23 @@ class ClusterConstants:
320387
TEST_CLUSTER_CUSTOM_IMAGE = Cluster(
321388
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
322389
network=ProjectConstants.TEST_VPC_NETWORK,
390+
service_account=None,
323391
state="RUNNING",
324392
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
325393
worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE,
326394
dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
327395
)
396+
TEST_CLUSTER_BYOSA = Cluster(
397+
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
398+
python_version="3.10",
399+
ray_version="2.9",
400+
network="",
401+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
402+
state="RUNNING",
403+
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
404+
worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL,
405+
dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
406+
)
328407
TEST_BEARER_TOKEN = "test-bearer-token"
329408
TEST_HEADERS = {
330409
"Content-Type": "application/json",

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy