Content-Length: 847118 | pFad | http://github.com/apache/airflow/commit/a8e451981572fa09a96660992e68e046c4baa75f

FE Fix Vertex AI Custom Job training issue (#25367) · apache/airflow@a8e4519 · GitHub
Skip to content

Commit a8e4519

Browse files
authored
Fix Vertex AI Custom Job training issue (#25367)
1 parent 4dc1778 commit a8e4519

File tree

3 files changed

+58
-28
lines changed

3 files changed

+58
-28
lines changed

airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ def extract_model_id(obj: Dict) -> str:
246246
"""Returns unique id of the Model."""
247247
return obj["name"].rpartition("/")[-1]
248248

249+
@staticmethod
250+
def extract_training_id(resource_name: str) -> str:
251+
"""Returns unique id of the Training pipeline."""
252+
return resource_name.rpartition("/")[-1]
253+
249254
def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None):
250255
"""Waits for long-lasting operation to complete."""
251256
try:
@@ -299,7 +304,7 @@ def _run_job(
299304
timestamp_split_column_name: Optional[str] = None,
300305
tensorboard: Optional[str] = None,
301306
sync=True,
302-
) -> models.Model:
307+
) -> Tuple[Optional[models.Model], str]:
303308
"""Run Job for training pipeline"""
304309
model = job.run(
305310
dataset=dataset,
@@ -329,11 +334,17 @@ def _run_job(
329334
tensorboard=tensorboard,
330335
sync=sync,
331336
)
337+
training_id = self.extract_training_id(job.resource_name)
332338
if model:
333339
model.wait()
334-
return model
335340
else:
336-
raise AirflowException("Training did not produce a Managed Model returning None.")
341+
self.log.warning(
342+
"Training did not produce a Managed Model returning None. Training Pipeline is not "
343+
"configured to upload a Model. Create the Training Pipeline with "
344+
"model_serving_container_image_uri and model_display_name passed in. "
345+
"Ensure that your training script saves to model to os.environ['AIP_MODEL_DIR']."
346+
)
347+
return model, training_id
337348

338349
@GoogleBaseHook.fallback_to_default_project_id
339350
def cancel_pipeline_job(
@@ -618,7 +629,7 @@ def create_custom_container_training_job(
618629
timestamp_split_column_name: Optional[str] = None,
619630
tensorboard: Optional[str] = None,
620631
sync=True,
621-
) -> models.Model:
632+
) -> Tuple[Optional[models.Model], str]:
622633
"""
623634
Create Custom Container Training Job
624635
@@ -890,7 +901,7 @@ def create_custom_container_training_job(
890901
if not self._job:
891902
raise AirflowException("CustomJob was not created")
892903

893-
model = self._run_job(
904+
model, training_id = self._run_job(
894905
job=self._job,
895906
dataset=dataset,
896907
annotation_schema_uri=annotation_schema_uri,
@@ -920,7 +931,7 @@ def create_custom_container_training_job(
920931
sync=sync,
921932
)
922933

923-
return model
934+
return model, training_id
924935

925936
@GoogleBaseHook.fallback_to_default_project_id
926937
def create_custom_python_package_training_job(
@@ -980,7 +991,7 @@ def create_custom_python_package_training_job(
980991
timestamp_split_column_name: Optional[str] = None,
981992
tensorboard: Optional[str] = None,
982993
sync=True,
983-
) -> models.Model:
994+
) -> Tuple[Optional[models.Model], str]:
984995
"""
985996
Create Custom Python Package Training Job
986997
@@ -1252,7 +1263,7 @@ def create_custom_python_package_training_job(
12521263
if not self._job:
12531264
raise AirflowException("CustomJob was not created")
12541265

1255-
model = self._run_job(
1266+
model, training_id = self._run_job(
12561267
job=self._job,
12571268
dataset=dataset,
12581269
annotation_schema_uri=annotation_schema_uri,
@@ -1282,7 +1293,7 @@ def create_custom_python_package_training_job(
12821293
sync=sync,
12831294
)
12841295

1285-
return model
1296+
return model, training_id
12861297

12871298
@GoogleBaseHook.fallback_to_default_project_id
12881299
def create_custom_training_job(
@@ -1342,7 +1353,7 @@ def create_custom_training_job(
13421353
timestamp_split_column_name: Optional[str] = None,
13431354
tensorboard: Optional[str] = None,
13441355
sync=True,
1345-
) -> models.Model:
1356+
) -> Tuple[Optional[models.Model], str]:
13461357
"""
13471358
Create Custom Training Job
13481359
@@ -1614,7 +1625,7 @@ def create_custom_training_job(
16141625
if not self._job:
16151626
raise AirflowException("CustomJob was not created")
16161627

1617-
model = self._run_job(
1628+
model, training_id = self._run_job(
16181629
job=self._job,
16191630
dataset=dataset,
16201631
annotation_schema_uri=annotation_schema_uri,
@@ -1644,7 +1655,7 @@ def create_custom_training_job(
16441655
sync=sync,
16451656
)
16461657

1647-
return model
1658+
return model, training_id
16481659

16491660
@GoogleBaseHook.fallback_to_default_project_id
16501661
def delete_pipeline_job(

airflow/providers/google/cloud/operators/vertex_ai/custom_job.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929

3030
from airflow.models import BaseOperator
3131
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
32-
from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink
32+
from airflow.providers.google.cloud.links.vertex_ai import (
33+
VertexAIModelLink,
34+
VertexAITrainingLink,
35+
VertexAITrainingPipelinesLink,
36+
)
3337

3438
if TYPE_CHECKING:
3539
from airflow.utils.context import Context
@@ -411,7 +415,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
411415
'command',
412416
'impersonation_chain',
413417
]
414-
operator_extra_links = (VertexAIModelLink(),)
418+
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
415419

416420
def __init__(
417421
self,
@@ -428,7 +432,7 @@ def execute(self, context: "Context"):
428432
delegate_to=self.delegate_to,
429433
impersonation_chain=self.impersonation_chain,
430434
)
431-
model = self.hook.create_custom_container_training_job(
435+
model, training_id = self.hook.create_custom_container_training_job(
432436
project_id=self.project_id,
433437
region=self.region,
434438
display_name=self.display_name,
@@ -478,9 +482,13 @@ def execute(self, context: "Context"):
478482
sync=True,
479483
)
480484

481-
result = Model.to_dict(model)
482-
model_id = self.hook.extract_model_id(result)
483-
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
485+
if model:
486+
result = Model.to_dict(model)
487+
model_id = self.hook.extract_model_id(result)
488+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
489+
else:
490+
result = model # type: ignore
491+
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
484492
return result
485493

486494
def on_kill(self) -> None:
@@ -755,7 +763,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
755763
'region',
756764
'impersonation_chain',
757765
]
758-
operator_extra_links = (VertexAIModelLink(),)
766+
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
759767

760768
def __init__(
761769
self,
@@ -774,7 +782,7 @@ def execute(self, context: "Context"):
774782
delegate_to=self.delegate_to,
775783
impersonation_chain=self.impersonation_chain,
776784
)
777-
model = self.hook.create_custom_python_package_training_job(
785+
model, training_id = self.hook.create_custom_python_package_training_job(
778786
project_id=self.project_id,
779787
region=self.region,
780788
display_name=self.display_name,
@@ -825,9 +833,13 @@ def execute(self, context: "Context"):
825833
sync=True,
826834
)
827835

828-
result = Model.to_dict(model)
829-
model_id = self.hook.extract_model_id(result)
830-
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
836+
if model:
837+
result = Model.to_dict(model)
838+
model_id = self.hook.extract_model_id(result)
839+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
840+
else:
841+
result = model # type: ignore
842+
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
831843
return result
832844

833845
def on_kill(self) -> None:
@@ -1104,7 +1116,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
11041116
'requirements',
11051117
'impersonation_chain',
11061118
]
1107-
operator_extra_links = (VertexAIModelLink(),)
1119+
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())
11081120

11091121
def __init__(
11101122
self,
@@ -1123,7 +1135,7 @@ def execute(self, context: "Context"):
11231135
delegate_to=self.delegate_to,
11241136
impersonation_chain=self.impersonation_chain,
11251137
)
1126-
model = self.hook.create_custom_training_job(
1138+
model, training_id = self.hook.create_custom_training_job(
11271139
project_id=self.project_id,
11281140
region=self.region,
11291141
display_name=self.display_name,
@@ -1174,9 +1186,13 @@ def execute(self, context: "Context"):
11741186
sync=True,
11751187
)
11761188

1177-
result = Model.to_dict(model)
1178-
model_id = self.hook.extract_model_id(result)
1179-
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1189+
if model:
1190+
result = Model.to_dict(model)
1191+
model_id = self.hook.extract_model_id(result)
1192+
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
1193+
else:
1194+
result = model # type: ignore
1195+
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
11801196
return result
11811197

11821198
def on_kill(self) -> None:

tests/providers/google/cloud/operators/test_vertex_ai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@
170170
class TestVertexAICreateCustomContainerTrainingJobOperator:
171171
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
172172
def test_execute(self, mock_hook):
173+
mock_hook.return_value.create_custom_container_training_job.return_value = (None, 'training_id')
173174
op = CreateCustomContainerTrainingJobOperator(
174175
task_id=TASK_ID,
175176
gcp_conn_id=GCP_CONN_ID,
@@ -250,6 +251,7 @@ def test_execute(self, mock_hook):
250251
class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
251252
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
252253
def test_execute(self, mock_hook):
254+
mock_hook.return_value.create_custom_python_package_training_job.return_value = (None, 'training_id')
253255
op = CreateCustomPythonPackageTrainingJobOperator(
254256
task_id=TASK_ID,
255257
gcp_conn_id=GCP_CONN_ID,
@@ -332,6 +334,7 @@ def test_execute(self, mock_hook):
332334
class TestVertexAICreateCustomTrainingJobOperator:
333335
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
334336
def test_execute(self, mock_hook):
337+
mock_hook.return_value.create_custom_training_job.return_value = (None, 'training_id')
335338
op = CreateCustomTrainingJobOperator(
336339
task_id=TASK_ID,
337340
gcp_conn_id=GCP_CONN_ID,

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/apache/airflow/commit/a8e451981572fa09a96660992e68e046c4baa75f

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy