Skip to content

Commit 36a56b9

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support reserved_ip_ranges for VPC network in Ray on Vertex cluster
chore: Update ray prediction tests for forward compatibility PiperOrigin-RevId: 670628417
1 parent 4a528c6 commit 36a56b9

File tree

7 files changed

+45
-5
lines changed

7 files changed

+45
-5
lines changed

google/cloud/aiplatform/vertex_ray/cluster_init.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def create_ray_cluster(
6161
enable_metrics_collection: Optional[bool] = True,
6262
enable_logging: Optional[bool] = True,
6363
psc_interface_config: Optional[resources.PscIConfig] = None,
64+
reserved_ip_ranges: Optional[List[str]] = None,
6465
labels: Optional[Dict[str, str]] = None,
6566
) -> str:
6667
"""Create a ray cluster on the Vertex AI.
@@ -126,6 +127,11 @@ def create_ray_cluster(
126127
enable_metrics_collection: Enable Ray metrics collection for visualization.
127128
enable_logging: Enable exporting Ray logs to Cloud Logging.
128129
psc_interface_config: PSC-I config.
130+
reserved_ip_ranges: A list of names for the reserved IP ranges under
131+
the VPC network that can be used for this cluster. If set, we will
132+
deploy the cluster within the provided IP ranges. Otherwise, the
133+
cluster is deployed to any IP ranges under the provided VPC network.
134+
Example: ["vertex-ai-ip-range"].
129135
labels:
130136
The labels with user-defined metadata to organize Ray cluster.
131137
@@ -325,6 +331,7 @@ def create_ray_cluster(
325331
labels=labels,
326332
resource_runtime_spec=resource_runtime_spec,
327333
psc_interface_config=gapic_psc_interface_config,
334+
reserved_ip_ranges=reserved_ip_ranges,
328335
)
329336

330337
location = initializer.global_config.location

google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@
4343
import xgboost
4444

4545
except ModuleNotFoundError as mnfe:
46-
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
46+
if ray.__version__ == "2.9.3":
47+
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
48+
else:
49+
xgboost = None
4750

4851

4952
def register_xgboost(

google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def persistent_resource_to_cluster(
150150
cluster = Cluster(
151151
cluster_resource_name=persistent_resource.name,
152152
network=persistent_resource.network,
153+
reserved_ip_ranges=persistent_resource.reserved_ip_ranges,
153154
state=persistent_resource.state.name,
154155
labels=persistent_resource.labels,
155156
dashboard_address=dashboard_address,

google/cloud/aiplatform/vertex_ray/util/resources.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ class Cluster:
117117
managed in the Vertex API service. For Ray Job API, VPC network is
118118
not required because cluster connection can be accessed through
119119
dashboard address.
120+
reserved_ip_ranges: A list of names for the reserved IP ranges under
121+
the VPC network that can be used for this cluster. If set, we will
122+
deploy the cluster within the provided IP ranges. Otherwise, the
123+
cluster is deployed to any IP ranges under the provided VPC network.
124+
Example: ["vertex-ai-ip-range"].
120125
service_account: Service account to be used for running Ray programs on
121126
the cluster.
122127
state: Describes the cluster state (defined in PersistentResource.State).
@@ -140,6 +145,7 @@ class Cluster:
140145

141146
cluster_resource_name: str = None
142147
network: str = None
148+
reserved_ip_ranges: List[str] = None
143149
service_account: str = None
144150
state: PersistentResource.State = None
145151
python_version: str = None

tests/unit/vertex_ray/test_cluster_init.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def test_create_ray_cluster_2_pools_custom_images_success(
384384
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
385385
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE,
386386
network=tc.ProjectConstants.TEST_VPC_NETWORK,
387+
reserved_ip_ranges=["vertex-dedicated-range"],
387388
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
388389
)
389390

tests/unit/vertex_ray/test_constants.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,17 @@
5151
from google.cloud.aiplatform_v1beta1.types.service_networking import (
5252
PscInterfaceConfig,
5353
)
54+
import ray
5455
import pytest
5556

5657

5758
rovminversion = pytest.mark.skipif(
5859
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
5960
)
61+
# TODO(b/363340317)
62+
xgbversion = pytest.mark.skipif(
63+
ray.__version__ != "2.9.3", reason="Requires xgboost 1.7 or higher"
64+
)
6065

6166

6267
@dataclasses.dataclass(frozen=True)
@@ -347,6 +352,7 @@ class ClusterConstants:
347352
),
348353
psc_interface_config=None,
349354
network=ProjectConstants.TEST_VPC_NETWORK,
355+
reserved_ip_ranges=["vertex-dedicated-range"],
350356
)
351357
# Responses
352358
TEST_RESOURCE_POOL_2.replica_count = 1
@@ -366,6 +372,7 @@ class ClusterConstants:
366372
network_attachment=TEST_PSC_NETWORK_ATTACHMENT
367373
),
368374
network=None,
375+
reserved_ip_ranges=None,
369376
resource_runtime=ResourceRuntime(
370377
access_uris={
371378
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -386,6 +393,7 @@ class ClusterConstants:
386393
),
387394
),
388395
network=ProjectConstants.TEST_VPC_NETWORK,
396+
reserved_ip_ranges=["vertex-dedicated-range"],
389397
resource_runtime=ResourceRuntime(
390398
access_uris={
391399
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -399,6 +407,7 @@ class ClusterConstants:
399407
python_version="3.10",
400408
ray_version="2.9",
401409
network=ProjectConstants.TEST_VPC_NETWORK,
410+
reserved_ip_ranges=None,
402411
service_account=None,
403412
state="RUNNING",
404413
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
@@ -412,6 +421,7 @@ class ClusterConstants:
412421
python_version="3.10",
413422
ray_version="2.9",
414423
network="",
424+
reserved_ip_ranges="",
415425
service_account=None,
416426
state="RUNNING",
417427
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS,
@@ -424,6 +434,7 @@ class ClusterConstants:
424434
TEST_CLUSTER_CUSTOM_IMAGE = Cluster(
425435
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
426436
network=ProjectConstants.TEST_VPC_NETWORK,
437+
reserved_ip_ranges=["vertex-dedicated-range"],
427438
service_account=None,
428439
state="RUNNING",
429440
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
@@ -438,6 +449,7 @@ class ClusterConstants:
438449
python_version="3.10",
439450
ray_version="2.9",
440451
network="",
452+
reserved_ip_ranges="",
441453
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
442454
state="RUNNING",
443455
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,

tests/unit/vertex_ray/test_ray_prediction.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import numpy as np
4242
import pytest
4343
import ray
44-
from ray.train import xgboost as ray_xgboost
4544
import tensorflow as tf
4645
import torch
4746
import xgboost
@@ -90,9 +89,14 @@ def ray_sklearn_checkpoint():
9089

9190
@pytest.fixture()
9291
def ray_xgboost_checkpoint():
93-
model = test_prediction_utils.get_xgboost_model()
94-
checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
95-
return checkpoint
92+
if ray.__version__ == "2.9.3":
93+
from ray.train import xgboost as ray_xgboost
94+
95+
model = test_prediction_utils.get_xgboost_model()
96+
checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
97+
return checkpoint
98+
else:
99+
return None
96100

97101

98102
@pytest.fixture()
@@ -374,6 +378,7 @@ def test_register_sklearnartifact_uri_not_gcs_uri_raise_error(
374378
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")
375379

376380
# XGBoost Tests
381+
@tc.xgbversion
377382
@tc.rovminversion
378383
def test_convert_checkpoint_to_xgboost_raise_exception(
379384
self, ray_checkpoint_from_dict
@@ -392,6 +397,7 @@ def test_convert_checkpoint_to_xgboost_raise_exception(
392397
"ray.train.xgboost.XGBoostCheckpoint .*"
393398
)
394399

400+
@tc.xgbversion
395401
def test_convert_checkpoint_to_xgboost_model_succeed(
396402
self, ray_xgboost_checkpoint
397403
) -> None:
@@ -406,6 +412,7 @@ def test_convert_checkpoint_to_xgboost_model_succeed(
406412
y_pred = model.predict(xgboost.DMatrix(np.array([[1, 2]])))
407413
assert y_pred[0] is not None
408414

415+
@tc.xgbversion
409416
def test_register_xgboost_succeed(
410417
self,
411418
ray_xgboost_checkpoint,
@@ -429,6 +436,7 @@ def test_register_xgboost_succeed(
429436
pickle_dump.assert_called_once()
430437
gcs_utils_upload_to_gcs.assert_called_once()
431438

439+
@tc.xgbversion
432440
def test_register_xgboost_initialized_succeed(
433441
self,
434442
ray_xgboost_checkpoint,
@@ -455,6 +463,7 @@ def test_register_xgboost_initialized_succeed(
455463
pickle_dump.assert_called_once()
456464
gcs_utils_upload_to_gcs.assert_called_once()
457465

466+
@tc.xgbversion
458467
def test_register_xgboostartifact_uri_is_none_raise_error(
459468
self, ray_xgboost_checkpoint
460469
) -> None:
@@ -467,6 +476,7 @@ def test_register_xgboostartifact_uri_is_none_raise_error(
467476
)
468477
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")
469478

479+
@tc.xgbversion
470480
def test_register_xgboostartifact_uri_not_gcs_uri_raise_error(
471481
self, ray_xgboost_checkpoint
472482
) -> None:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy