Content-Length: 978683 | pFad | https://github.com/apache/airflow/commit/20df60de24e1dbeab2dcf5b989b69080d1b3ed34

21 Dataflow Operators - use project and location from job in on_kill met… · apache/airflow@20df60d · GitHub
Skip to content

Commit 20df60d

Browse files
author
Łukasz Wyszomirski
authored
Dataflow Operators - use project and location from job in on_kill method. (#18699)
The reason we need this is that we can have a situation where project_id is set to None but we define it in the dataflow_default_options. The job will start normally without error, but when we decide to mark a running task to a different state, we will get an error that the job does not exist.
1 parent 6103b26 commit 20df60d

File tree

4 files changed

+97
-45
lines changed

4 files changed

+97
-45
lines changed

airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ def start_template_dataflow(
623623
project_id: str,
624624
append_job_name: bool = True,
625625
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
626+
on_new_job_callback: Optional[Callable[[dict], None]] = None,
626627
location: str = DEFAULT_DATAFLOW_LOCATION,
627628
environment: Optional[dict] = None,
628629
) -> dict:
@@ -648,8 +649,10 @@ def start_template_dataflow(
648649
If set to None or missing, the default project_id from the Google Cloud connection is used.
649650
:param append_job_name: True if unique suffix has to be appended to job name.
650651
:type append_job_name: bool
651-
:param on_new_job_id_callback: Callback called when the job ID is known.
652+
:param on_new_job_id_callback: (Deprecated) Callback called when the Job is known.
652653
:type on_new_job_id_callback: callable
654+
:param on_new_job_callback: Callback called when the Job is known.
655+
:type on_new_job_callback: callable
653656
:param location: Job location.
654657
:type location: str
655658
:type environment: Optional, Map of job runtime environment options.
@@ -713,15 +716,24 @@ def start_template_dataflow(
713716
)
714717
response = request.execute(num_retries=self.num_retries)
715718

716-
job_id = response["job"]["id"]
719+
job = response["job"]
720+
717721
if on_new_job_id_callback:
718-
on_new_job_id_callback(job_id)
722+
warnings.warn(
723+
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
724+
DeprecationWarning,
725+
stacklevel=3,
726+
)
727+
on_new_job_id_callback(job.get("id"))
728+
729+
if on_new_job_callback:
730+
on_new_job_callback(job)
719731

720732
jobs_controller = _DataflowJobsController(
721733
dataflow=self.get_conn(),
722734
project_number=project_id,
723735
name=name,
724-
job_id=job_id,
736+
job_id=job["id"],
725737
location=location,
726738
poll_sleep=self.poll_sleep,
727739
num_retries=self.num_retries,
@@ -739,6 +751,7 @@ def start_flex_template(
739751
location: str,
740752
project_id: str,
741753
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
754+
on_new_job_callback: Optional[Callable[[dict], None]] = None,
742755
):
743756
"""
744757
Starts flex templates with the Dataflow pipeline.
@@ -750,7 +763,8 @@ def start_flex_template(
750763
:param project_id: The ID of the GCP project that owns the job.
751764
If set to ``None`` or missing, the default project_id from the GCP connection is used.
752765
:type project_id: Optional[str]
753-
:param on_new_job_id_callback: A callback that is called when a Job ID is detected.
766+
:param on_new_job_id_callback: (Deprecated) A callback that is called when a Job ID is detected.
767+
:param on_new_job_callback: A callback that is called when a Job is detected.
754768
:return: the Job
755769
"""
756770
service = self.get_conn()
@@ -761,15 +775,23 @@ def start_flex_template(
761775
.launch(projectId=project_id, body=body, location=location)
762776
)
763777
response = request.execute(num_retries=self.num_retries)
764-
job_id = response["job"]["id"]
778+
job = response["job"]
765779

766780
if on_new_job_id_callback:
767-
on_new_job_id_callback(job_id)
781+
warnings.warn(
782+
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
783+
DeprecationWarning,
784+
stacklevel=3,
785+
)
786+
on_new_job_id_callback(job.get("id"))
787+
788+
if on_new_job_callback:
789+
on_new_job_callback(job)
768790

769791
jobs_controller = _DataflowJobsController(
770792
dataflow=self.get_conn(),
771793
project_number=project_id,
772-
job_id=job_id,
794+
job_id=job.get("id"),
773795
location=location,
774796
poll_sleep=self.poll_sleep,
775797
num_retries=self.num_retries,
@@ -973,6 +995,7 @@ def start_sql_job(
973995
project_id: str,
974996
location: str = DEFAULT_DATAFLOW_LOCATION,
975997
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
998+
on_new_job_callback: Optional[Callable[[dict], None]] = None,
976999
):
9771000
"""
9781001
Starts Dataflow SQL query.
@@ -991,8 +1014,10 @@ def start_sql_job(
9911014
:param project_id: The ID of the GCP project that owns the job.
9921015
If set to ``None`` or missing, the default project_id from the GCP connection is used.
9931016
:type project_id: Optional[str]
994-
:param on_new_job_id_callback: Callback called when the job ID is known.
1017+
:param on_new_job_id_callback: (Deprecated) Callback called when the job ID is known.
9951018
:type on_new_job_id_callback: callable
1019+
:param on_new_job_callback: Callback called when the job is known.
1020+
:type on_new_job_callback: callable
9961021
:return: the new job object
9971022
"""
9981023
cmd = [
@@ -1018,8 +1043,6 @@ def start_sql_job(
10181043
job_id = proc.stdout.decode().strip()
10191044

10201045
self.log.info("Created job ID: %s", job_id)
1021-
if on_new_job_id_callback:
1022-
on_new_job_id_callback(job_id)
10231046

10241047
jobs_controller = _DataflowJobsController(
10251048
dataflow=self.get_conn(),
@@ -1031,8 +1054,20 @@ def start_sql_job(
10311054
drain_pipeline=self.drain_pipeline,
10321055
wait_until_finished=self.wait_until_finished,
10331056
)
1034-
jobs_controller.wait_for_done()
1057+
job = jobs_controller.get_jobs(refresh=True)[0]
10351058

1059+
if on_new_job_id_callback:
1060+
warnings.warn(
1061+
"on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
1062+
DeprecationWarning,
1063+
stacklevel=3,
1064+
)
1065+
on_new_job_id_callback(job.get("id"))
1066+
1067+
if on_new_job_callback:
1068+
on_new_job_callback(job)
1069+
1070+
jobs_controller.wait_for_done()
10361071
return jobs_controller.get_jobs(refresh=True)[0]
10371072

10381073
@GoogleBaseHook.fallback_to_default_project_id

airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def __init__(
657657
self.gcp_conn_id = gcp_conn_id
658658
self.delegate_to = delegate_to
659659
self.poll_sleep = poll_sleep
660-
self.job_id = None
660+
self.job = None
661661
self.hook: Optional[DataflowHook] = None
662662
self.impersonation_chain = impersonation_chain
663663
self.environment = environment
@@ -674,8 +674,8 @@ def execute(self, context) -> dict:
674674
wait_until_finished=self.wait_until_finished,
675675
)
676676

677-
def set_current_job_id(job_id):
678-
self.job_id = job_id
677+
def set_current_job(current_job):
678+
self.job = current_job
679679

680680
options = self.dataflow_default_options
681681
options.update(self.options)
@@ -684,7 +684,7 @@ def set_current_job_id(job_id):
684684
variables=options,
685685
parameters=self.parameters,
686686
dataflow_template=self.template,
687-
on_new_job_id_callback=set_current_job_id,
687+
on_new_job_callback=set_current_job,
688688
project_id=self.project_id,
689689
location=self.location,
690690
environment=self.environment,
@@ -694,8 +694,12 @@ def set_current_job_id(job_id):
694694

695695
def on_kill(self) -> None:
696696
self.log.info("On kill.")
697-
if self.job_id:
698-
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
697+
if self.job:
698+
self.hook.cancel_job(
699+
job_id=self.job.get("id"),
700+
project_id=self.job.get("projectId"),
701+
location=self.job.get("location"),
702+
)
699703

700704

701705
class DataflowStartFlexTemplateOperator(BaseOperator):
@@ -787,7 +791,7 @@ def __init__(
787791
self.drain_pipeline = drain_pipeline
788792
self.cancel_timeout = cancel_timeout
789793
self.wait_until_finished = wait_until_finished
790-
self.job_id = None
794+
self.job = None
791795
self.hook: Optional[DataflowHook] = None
792796

793797
def execute(self, context):
@@ -799,22 +803,26 @@ def execute(self, context):
799803
wait_until_finished=self.wait_until_finished,
800804
)
801805

802-
def set_current_job_id(job_id):
803-
self.job_id = job_id
806+
def set_current_job(current_job):
807+
self.job = current_job
804808

805809
job = self.hook.start_flex_template(
806810
body=self.body,
807811
location=self.location,
808812
project_id=self.project_id,
809-
on_new_job_id_callback=set_current_job_id,
813+
on_new_job_callback=set_current_job,
810814
)
811815

812816
return job
813817

814818
def on_kill(self) -> None:
815819
self.log.info("On kill.")
816-
if self.job_id:
817-
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
820+
if self.job:
821+
self.hook.cancel_job(
822+
job_id=self.job.get("id"),
823+
project_id=self.job.get("projectId"),
824+
location=self.job.get("location"),
825+
)
818826

819827

820828
class DataflowStartSqlJobOperator(BaseOperator):
@@ -890,7 +898,7 @@ def __init__(
890898
self.gcp_conn_id = gcp_conn_id
891899
self.delegate_to = delegate_to
892900
self.drain_pipeline = drain_pipeline
893-
self.job_id = None
901+
self.job = None
894902
self.hook: Optional[DataflowHook] = None
895903

896904
def execute(self, context):
@@ -900,24 +908,28 @@ def execute(self, context):
900908
drain_pipeline=self.drain_pipeline,
901909
)
902910

903-
def set_current_job_id(job_id):
904-
self.job_id = job_id
911+
def set_current_job(current_job):
912+
self.job = current_job
905913

906914
job = self.hook.start_sql_job(
907915
job_name=self.job_name,
908916
query=self.query,
909917
options=self.options,
910918
location=self.location,
911919
project_id=self.project_id,
912-
on_new_job_id_callback=set_current_job_id,
920+
on_new_job_callback=set_current_job,
913921
)
914922

915923
return job
916924

917925
def on_kill(self) -> None:
918926
self.log.info("On kill.")
919-
if self.job_id:
920-
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
927+
if self.job:
928+
self.hook.cancel_job(
929+
job_id=self.job.get("id"),
930+
project_id=self.job.get("projectId"),
931+
location=self.job.get("location"),
932+
)
921933

922934

923935
class DataflowCreatePythonJobOperator(BaseOperator):

tests/providers/google/cloud/hooks/test_dataflow.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,19 +1016,21 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl
10161016
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
10171017
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
10181018
def test_start_flex_template(self, mock_conn, mock_controller):
1019+
expected_job = {"id": TEST_JOB_ID}
1020+
10191021
mock_locations = mock_conn.return_value.projects.return_value.locations
10201022
launch_method = mock_locations.return_value.flexTemplates.return_value.launch
1021-
launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
1023+
launch_method.return_value.execute.return_value = {"job": expected_job}
10221024
mock_controller.return_value.get_jobs.return_value = [{"id": TEST_JOB_ID}]
10231025

1024-
on_new_job_id_callback = mock.MagicMock()
1026+
on_new_job_callback = mock.MagicMock()
10251027
result = self.dataflow_hook.start_flex_template(
10261028
body={"launchParameter": TEST_FLEX_PARAMETERS},
10271029
location=TEST_LOCATION,
10281030
project_id=TEST_PROJECT_ID,
1029-
on_new_job_id_callback=on_new_job_id_callback,
1031+
on_new_job_callback=on_new_job_callback,
10301032
)
1031-
on_new_job_id_callback.assert_called_once_with(TEST_JOB_ID)
1033+
on_new_job_callback.assert_called_once_with(expected_job)
10321034
launch_method.assert_called_once_with(
10331035
projectId='test-project-id',
10341036
body={'launchParameter': TEST_FLEX_PARAMETERS},
@@ -1080,14 +1082,15 @@ def test_start_sql_job_failed_to_run(
10801082
mock_run.return_value = mock.MagicMock(
10811083
stdout=f"{TEST_JOB_ID}\n".encode(), stderr=f"{TEST_JOB_ID}\n".encode(), returncode=0
10821084
)
1083-
on_new_job_id_callback = mock.MagicMock()
1085+
on_new_job_callback = mock.MagicMock()
1086+
10841087
result = self.dataflow_hook.start_sql_job(
10851088
job_name=TEST_SQL_JOB_NAME,
10861089
query=TEST_SQL_QUERY,
10871090
options=TEST_SQL_OPTIONS,
10881091
location=TEST_LOCATION,
10891092
project_id=TEST_PROJECT,
1090-
on_new_job_id_callback=on_new_job_id_callback,
1093+
on_new_job_callback=on_new_job_callback,
10911094
)
10921095
mock_run.assert_called_once_with(
10931096
[
@@ -1135,7 +1138,7 @@ def test_start_sql_job(self, mock_run, mock_provide_authorized_gcloud, mock_get_
11351138
options=TEST_SQL_OPTIONS,
11361139
location=TEST_LOCATION,
11371140
project_id=TEST_PROJECT,
1138-
on_new_job_id_callback=mock.MagicMock(),
1141+
on_new_job_callback=mock.MagicMock(),
11391142
)
11401143

11411144

tests/providers/google/cloud/operators/test_dataflow.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
bigquery.table.test-project.beam_samples.beam_table
9090
GROUP BY sales_region;
9191
"""
92-
TEST_SQL_JOB_ID = 'test-job-id'
92+
TEST_SQL_JOB = {'id': 'test-job-id'}
9393

9494

9595
class TestDataflowPythonOperator(unittest.TestCase):
@@ -410,7 +410,7 @@ def test_exec(self, dataflow_mock):
410410
variables=expected_options,
411411
parameters=PARAMETERS,
412412
dataflow_template=TEMPLATE,
413-
on_new_job_id_callback=mock.ANY,
413+
on_new_job_callback=mock.ANY,
414414
project_id=None,
415415
location=TEST_LOCATION,
416416
environment={'maxWorkers': 2},
@@ -432,7 +432,7 @@ def test_execute(self, mock_dataflow):
432432
body={"launchParameter": TEST_FLEX_PARAMETERS},
433433
location=TEST_LOCATION,
434434
project_id=TEST_PROJECT,
435-
on_new_job_id_callback=mock.ANY,
435+
on_new_job_callback=mock.ANY,
436436
)
437437

438438
def test_on_kill(self):
@@ -444,10 +444,10 @@ def test_on_kill(self):
444444
project_id=TEST_PROJECT,
445445
)
446446
start_flex_template.hook = mock.MagicMock()
447-
start_flex_template.job_id = JOB_ID
447+
start_flex_template.job = {"id": JOB_ID, "projectId": TEST_PROJECT, "location": TEST_LOCATION}
448448
start_flex_template.on_kill()
449449
start_flex_template.hook.cancel_job.assert_called_once_with(
450-
job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT
450+
job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT, location=TEST_LOCATION
451451
)
452452

453453

@@ -473,8 +473,10 @@ def test_execute(self, mock_hook):
473473
options=TEST_SQL_OPTIONS,
474474
location=TEST_LOCATION,
475475
project_id=None,
476-
on_new_job_id_callback=mock.ANY,
476+
on_new_job_callback=mock.ANY,
477477
)
478-
start_sql.job_id = TEST_SQL_JOB_ID
478+
start_sql.job = TEST_SQL_JOB
479479
start_sql.on_kill()
480-
mock_hook.return_value.cancel_job.assert_called_once_with(job_id='test-job-id', project_id=None)
480+
mock_hook.return_value.cancel_job.assert_called_once_with(
481+
job_id='test-job-id', project_id=None, location=None
482+
)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/apache/airflow/commit/20df60de24e1dbeab2dcf5b989b69080d1b3ed34

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy