Content-Length: 658234 | pFad | https://github.com/googleapis/python-aiplatform/commit/e45ef96de9f008a5c5556bf119a75403085d8dcb

1A feat: Ray on Vertex enables XGBoost register model with custom versio… · googleapis/python-aiplatform@e45ef96 · GitHub
Skip to content

Commit e45ef96

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Ray on Vertex enables XGBoost register model with custom version using pre-built container
PiperOrigin-RevId: 619575247
1 parent b587a8d commit e45ef96

File tree

4 files changed

+97
-9
lines changed

4 files changed

+97
-9
lines changed

google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import pickle
2323
import ray
24+
import ray.cloudpickle as cpickle
2425
import tempfile
2526
from typing import Optional, TYPE_CHECKING
2627

@@ -117,7 +118,9 @@ def _get_estimator_from(
117118
118119
Raises:
119120
ValueError: Invalid Argument.
121+
RuntimeError: Model not found.
120122
"""
123+
121124
ray_version = ray.__version__
122125
if ray_version == "2.4.0":
123126
if not isinstance(checkpoint, ray_sklearn.SklearnCheckpoint):
@@ -133,8 +136,25 @@ def _get_estimator_from(
133136
)
134137
return checkpoint.get_estimator()
135138

136-
# get_model() signature changed in future versions
137139
try:
138-
return checkpoint.get_estimator()
140+
return checkpoint.get_model()
139141
except AttributeError:
140-
raise RuntimeError("Unsupported Ray version.")
142+
model_file_name = ray.train.sklearn.SklearnCheckpoint.MODEL_FILENAME
143+
144+
model_path = os.path.join(checkpoint.path, model_file_name)
145+
146+
if os.path.exists(model_path):
147+
with open(model_path, mode="rb") as f:
148+
obj = pickle.load(f)
149+
else:
150+
try:
151+
# Download from GCS to temp and then load_model
152+
with tempfile.TemporaryDirectory() as temp_dir:
153+
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
154+
with open(f"{temp_dir}/{model_file_name}", mode="rb") as f:
155+
obj = cpickle.load(f)
156+
except Exception as e:
157+
raise RuntimeError(
158+
f"{model_file_name} not found in this checkpoint due to: {e}."
159+
)
160+
return obj

google/cloud/aiplatform/preview/vertex_ray/predict/torch/register.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
# limitations under the License.
1717

1818
import logging
19-
from typing import Optional
19+
import os
2020
import ray
21+
from ray.air._internal.torch_utils import load_torch_model
22+
import tempfile
23+
from google.cloud.aiplatform.utils import gcs_utils
24+
from typing import Optional
25+
2126

2227
try:
2328
from ray.train import torch as ray_torch
@@ -51,6 +56,8 @@ def get_pytorch_model_from(
5156
5257
Raises:
5358
ValueError: Invalid Argument.
59+
ModuleNotFoundError: PyTorch isn't installed.
60+
RuntimeError: Model not found.
5461
"""
5562
ray_version = ray.__version__
5663
if ray_version == "2.4.0":
@@ -67,8 +74,33 @@ def get_pytorch_model_from(
6774
)
6875
return checkpoint.get_model(model=model)
6976

70-
# get_model() signature changed in future versions
7177
try:
7278
return checkpoint.get_model()
7379
except AttributeError:
74-
raise RuntimeError("Unsupported Ray version.")
80+
model_file_name = ray.train.torch.TorchCheckpoint.MODEL_FILENAME
81+
82+
model_path = os.path.join(checkpoint.path, model_file_name)
83+
84+
try:
85+
import torch
86+
87+
except ModuleNotFoundError as mnfe:
88+
raise ModuleNotFoundError("PyTorch isn't installed.") from mnfe
89+
90+
if os.path.exists(model_path):
91+
model_or_state_dict = torch.load(model_path, map_location="cpu")
92+
else:
93+
try:
94+
# Download from GCS to temp and then load_model
95+
with tempfile.TemporaryDirectory() as temp_dir:
96+
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
97+
model_or_state_dict = torch.load(
98+
f"{temp_dir}/{model_file_name}", map_location="cpu"
99+
)
100+
except Exception as e:
101+
raise RuntimeError(
102+
f"{model_file_name} not found in this checkpoint due to: {e}."
103+
)
104+
105+
model = load_torch_model(saved_model=model_or_state_dict, model_definition=model)
106+
return model

google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/register.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def register_xgboost(
4848
checkpoint: "ray_xgboost.XGBoostCheckpoint",
4949
artifact_uri: Optional[str] = None,
5050
display_name: Optional[str] = None,
51+
xgboost_version: Optional[str] = None,
5152
**kwargs,
5253
) -> aiplatform.Model:
5354
"""Uploads a Ray XGBoost Checkpoint as XGBoost Model to Model Registry.
@@ -75,6 +76,9 @@ def register_xgboost(
7576
display_name (str):
7677
Optional. The display name of the Model. The name can be up to 128
7778
characters long and can be consist of any UTF-8 characters.
79+
xgboost_version (str): Optional. The version of the XGBoost serving container.
80+
Supported versions: ["0.82", "0.90", "1.1", "1.2", "1.3", "1.4", "1.6", "1.7", "2.0"].
81+
If the version is not specified, the latest version is used.
7882
**kwargs:
7983
Any kwargs will be passed to aiplatform.Model registration.
8084
@@ -96,14 +100,16 @@ def register_xgboost(
96100

97101
model_dir = os.path.join(artifact_uri, display_model_name)
98102
file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)
103+
if xgboost_version is None:
104+
xgboost_version = constants._XGBOOST_VERSION
99105

100106
with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
101107
pickle.dump(model, temp_file)
102108
gcs_utils.upload_to_gcs(temp_file.name, file_path)
103109
return aiplatform.Model.upload_xgboost_model_file(
104110
model_file_path=temp_file.name,
105111
display_name=display_model_name,
106-
xgboost_version=constants._XGBOOST_VERSION,
112+
xgboost_version=xgboost_version,
107113
**kwargs,
108114
)
109115

@@ -121,6 +127,8 @@ def _get_xgboost_model_from(
121127
122128
Raises:
123129
ValueError: Invalid Argument.
130+
ModuleNotFoundError: XGBoost isn't installed.
131+
RuntimeError: Model not found.
124132
"""
125133
ray_version = ray.__version__
126134
if ray_version == "2.4.0":
@@ -137,8 +145,33 @@ def _get_xgboost_model_from(
137145
)
138146
return checkpoint.get_model()
139147

140-
# get_model() signature changed in future versions
141148
try:
149+
# This works for Ray v2.5
142150
return checkpoint.get_model()
143151
except AttributeError:
144-
raise RuntimeError("Unsupported Ray version.")
152+
# This works for Ray v2.9
153+
model_file_name = ray.train.xgboost.XGBoostCheckpoint.MODEL_FILENAME
154+
155+
model_path = os.path.join(checkpoint.path, model_file_name)
156+
157+
try:
158+
import xgboost
159+
160+
except ModuleNotFoundError as mnfe:
161+
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
162+
163+
booster = xgboost.Booster()
164+
if os.path.exists(model_path):
165+
booster.load_model(model_path)
166+
return booster
167+
168+
try:
169+
# Download from GCS to temp and then load_model
170+
with tempfile.TemporaryDirectory() as temp_dir:
171+
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
172+
booster.load_model(f"{temp_dir}/{model_file_name}")
173+
return booster
174+
except Exception as e:
175+
raise RuntimeError(
176+
f"{model_file_name} not found in this checkpoint due to: {e}."
177+
)

tests/unit/vertex_ray/test_ray_prediction.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def test_convert_checkpoint_to_sklearn_raise_exception(
288288
"ray.train.sklearn.SklearnCheckpoint .*"
289289
)
290290

291+
@tc.rovminversion
291292
def test_convert_checkpoint_to_sklearn_model_succeed(
292293
self, ray_sklearn_checkpoint
293294
) -> None:
@@ -302,6 +303,7 @@ def test_convert_checkpoint_to_sklearn_model_succeed(
302303
y_pred = estimator.predict([[10, 11]])
303304
assert y_pred[0] is not None
304305

306+
@tc.rovminversion
305307
def test_register_sklearn_succeed(
306308
self,
307309
ray_sklearn_checkpoint,
@@ -325,6 +327,7 @@ def test_register_sklearn_succeed(
325327
pickle_dump.assert_called_once()
326328
gcs_utils_upload_to_gcs.assert_called_once()
327329

330+
@tc.rovminversion
328331
def test_register_sklearn_initialized_succeed(
329332
self,
330333
ray_sklearn_checkpoint,

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/e45ef96de9f008a5c5556bf119a75403085d8dcb

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy