31
31
from googleapiclient .discovery import Resource , build
32
32
33
33
from airflow .exceptions import AirflowException , AirflowNotFoundException
34
+ from airflow .providers .google .cloud .utils .datafusion import DataFusionPipelineType
34
35
from airflow .providers .google .common .hooks .base_google import (
35
36
PROVIDE_PROJECT_ID ,
36
37
GoogleBaseAsyncHook ,
@@ -105,6 +106,7 @@ def wait_for_pipeline_state(
105
106
pipeline_name : str ,
106
107
pipeline_id : str ,
107
108
instance_url : str ,
109
+ pipeline_type : DataFusionPipelineType = DataFusionPipelineType .BATCH ,
108
110
namespace : str = "default" ,
109
111
success_states : list [str ] | None = None ,
110
112
failure_states : list [str ] | None = None ,
@@ -120,6 +122,7 @@ def wait_for_pipeline_state(
120
122
workflow = self .get_pipeline_workflow (
121
123
pipeline_name = pipeline_name ,
122
124
pipeline_id = pipeline_id ,
125
+ pipeline_type = pipeline_type ,
123
126
instance_url = instance_url ,
124
127
namespace = namespace ,
125
128
)
@@ -432,13 +435,14 @@ def get_pipeline_workflow(
432
435
pipeline_name : str ,
433
436
instance_url : str ,
434
437
pipeline_id : str ,
438
+ pipeline_type : DataFusionPipelineType = DataFusionPipelineType .BATCH ,
435
439
namespace : str = "default" ,
436
440
) -> Any :
437
441
url = os .path .join (
438
442
self ._base_url (instance_url , namespace ),
439
443
quote (pipeline_name ),
440
- "workflows " ,
441
- "DataPipelineWorkflow" ,
444
+ f" { self . cdap_program_type ( pipeline_type = pipeline_type ) } s " ,
445
+ self . cdap_program_id ( pipeline_type = pipeline_type ) ,
442
446
"runs" ,
443
447
quote (pipeline_id ),
444
448
)
@@ -453,13 +457,15 @@ def start_pipeline(
453
457
self ,
454
458
pipeline_name : str ,
455
459
instance_url : str ,
460
+ pipeline_type : DataFusionPipelineType = DataFusionPipelineType .BATCH ,
456
461
namespace : str = "default" ,
457
462
runtime_args : dict [str , Any ] | None = None ,
458
463
) -> str :
459
464
"""
460
465
Starts a Cloud Data Fusion pipeline. Works for both batch and stream pipelines.
461
466
462
467
:param pipeline_name: Your pipeline name.
468
+ :param pipeline_type: Optional pipeline type (BATCH by default).
463
469
:param instance_url: Endpoint on which the REST APIs is accessible for the instance.
464
470
:param runtime_args: Optional runtime JSON args to be passed to the pipeline
465
471
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
@@ -480,9 +486,9 @@ def start_pipeline(
480
486
body = [
481
487
{
482
488
"appId" : pipeline_name ,
483
- "programType" : "workflow" ,
484
- "programId" : "DataPipelineWorkflow" ,
485
489
"runtimeargs" : runtime_args ,
490
+ "programType" : self .cdap_program_type (pipeline_type = pipeline_type ),
491
+ "programId" : self .cdap_program_id (pipeline_type = pipeline_type ),
486
492
}
487
493
]
488
494
response = self ._cdap_request (url = url , method = "POST" , body = body )
@@ -514,6 +520,30 @@ def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str =
514
520
response , f"Stopping a pipeline failed with code { response .status } "
515
521
)
516
522
523
@staticmethod
def cdap_program_type(pipeline_type: DataFusionPipelineType) -> str:
    """Map a Data Fusion pipeline type to the matching CDAP program type.

    :param pipeline_type: Pipeline type.
    :return: ``"workflow"`` for ``BATCH``, ``"spark"`` for ``STREAM``, and an
        empty string for any other value.
    """
    type_by_pipeline = {
        DataFusionPipelineType.BATCH: "workflow",
        DataFusionPipelineType.STREAM: "spark",
    }
    try:
        return type_by_pipeline[pipeline_type]
    except KeyError:
        # Unrecognised pipeline types resolve to an empty string instead of raising.
        return ""
534
+
535
@staticmethod
def cdap_program_id(pipeline_type: DataFusionPipelineType) -> str:
    """Map a Data Fusion pipeline type to the matching CDAP program id.

    :param pipeline_type: Pipeline type.
    :return: ``"DataPipelineWorkflow"`` for ``BATCH``,
        ``"DataStreamsSparkStreaming"`` for ``STREAM``, and an empty string
        for any other value.
    """
    id_by_pipeline = {
        DataFusionPipelineType.BATCH: "DataPipelineWorkflow",
        DataFusionPipelineType.STREAM: "DataStreamsSparkStreaming",
    }
    try:
        return id_by_pipeline[pipeline_type]
    except KeyError:
        # Unrecognised pipeline types resolve to an empty string instead of raising.
        return ""
546
+
517
547
518
548
class DataFusionAsyncHook (GoogleBaseAsyncHook ):
519
549
"""Class to get asynchronous hook for DataFusion."""
@@ -561,10 +591,13 @@ async def get_pipeline(
561
591
pipeline_name : str ,
562
592
pipeline_id : str ,
563
593
session ,
594
+ pipeline_type : DataFusionPipelineType = DataFusionPipelineType .BATCH ,
564
595
):
596
+ program_type = self .sync_hook_class .cdap_program_type (pipeline_type = pipeline_type )
597
+ program_id = self .sync_hook_class .cdap_program_id (pipeline_type = pipeline_type )
565
598
base_url_link = self ._base_url (instance_url , namespace )
566
599
url = urljoin (
567
- base_url_link , f"{ quote (pipeline_name )} /workflows/DataPipelineWorkflow /runs/{ quote (pipeline_id )} "
600
+ base_url_link , f"{ quote (pipeline_name )} /{ program_type } s/ { program_id } /runs/{ quote (pipeline_id )} "
568
601
)
569
602
return await self ._get_link (url = url , session = session )
570
603
@@ -573,6 +606,7 @@ async def get_pipeline_status(
573
606
pipeline_name : str ,
574
607
instance_url : str ,
575
608
pipeline_id : str ,
609
+ pipeline_type : DataFusionPipelineType = DataFusionPipelineType .BATCH ,
576
610
namespace : str = "default" ,
577
611
success_states : list [str ] | None = None ,
578
612
) -> str :
@@ -581,7 +615,8 @@ async def get_pipeline_status(
581
615
582
616
:param pipeline_name: Your pipeline name.
583
617
:param instance_url: Endpoint on which the REST APIs is accessible for the instance.
584
- :param pipeline_id: Unique pipeline ID associated with specific pipeline
618
+ :param pipeline_id: Unique pipeline ID associated with specific pipeline.
619
+ :param pipeline_type: Optional pipeline type (by default batch).
585
620
:param namespace: if your pipeline belongs to a Basic edition instance, the namespace ID
586
621
is always default. If your pipeline belongs to an Enterprise edition instance, you
587
622
can create a namespace.
@@ -596,6 +631,7 @@ async def get_pipeline_status(
596
631
namespace = namespace ,
597
632
pipeline_name = pipeline_name ,
598
633
pipeline_id = pipeline_id ,
634
+ pipeline_type = pipeline_type ,
599
635
session = session ,
600
636
)
601
637
pipeline = await pipeline .json (content_type = None )
0 commit comments