183
183
_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
184
184
_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"
185
185
186
+ _FORECASTING_JOB_MODEL_TYPES = [
187
+ training_jobs .AutoMLForecastingTrainingJob ,
188
+ training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
189
+ training_jobs .TemporalFusionTransformerForecastingTrainingJob ,
190
+ ]
191
+
186
192
187
193
@pytest .fixture
188
194
def mock_pipeline_service_create ():
@@ -293,13 +299,7 @@ def teardown_method(self):
293
299
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
294
300
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
295
301
@pytest .mark .parametrize ("sync" , [True , False ])
296
- @pytest .mark .parametrize (
297
- "training_job" ,
298
- [
299
- training_jobs .AutoMLForecastingTrainingJob ,
300
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
301
- ],
302
- )
302
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
303
303
def test_run_call_pipeline_service_create (
304
304
self ,
305
305
mock_pipeline_service_create ,
@@ -401,13 +401,7 @@ def test_run_call_pipeline_service_create(
401
401
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
402
402
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
403
403
@pytest .mark .parametrize ("sync" , [True , False ])
404
- @pytest .mark .parametrize (
405
- "training_job" ,
406
- [
407
- training_jobs .AutoMLForecastingTrainingJob ,
408
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
409
- ],
410
- )
404
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
411
405
def test_run_call_pipeline_service_create_with_timeout (
412
406
self ,
413
407
mock_pipeline_service_create ,
@@ -496,13 +490,7 @@ def test_run_call_pipeline_service_create_with_timeout(
496
490
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
497
491
@pytest .mark .usefixtures ("mock_pipeline_service_get" )
498
492
@pytest .mark .parametrize ("sync" , [True , False ])
499
- @pytest .mark .parametrize (
500
- "training_job" ,
501
- [
502
- training_jobs .AutoMLForecastingTrainingJob ,
503
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
504
- ],
505
- )
493
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
506
494
def test_run_call_pipeline_if_no_model_display_name_nor_model_labels (
507
495
self ,
508
496
mock_pipeline_service_create ,
@@ -584,13 +572,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
584
572
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
585
573
@pytest .mark .usefixtures ("mock_pipeline_service_get" )
586
574
@pytest .mark .parametrize ("sync" , [True , False ])
587
- @pytest .mark .parametrize (
588
- "training_job" ,
589
- [
590
- training_jobs .AutoMLForecastingTrainingJob ,
591
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
592
- ],
593
- )
575
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
594
576
def test_run_call_pipeline_if_set_additional_experiments (
595
577
self ,
596
578
mock_pipeline_service_create ,
@@ -675,13 +657,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
675
657
"mock_model_service_get" ,
676
658
)
677
659
@pytest .mark .parametrize ("sync" , [True , False ])
678
- @pytest .mark .parametrize (
679
- "training_job" ,
680
- [
681
- training_jobs .AutoMLForecastingTrainingJob ,
682
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
683
- ],
684
- )
660
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
685
661
def test_run_called_twice_raises (
686
662
self ,
687
663
mock_dataset_time_series ,
@@ -762,13 +738,7 @@ def test_run_called_twice_raises(
762
738
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
763
739
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
764
740
@pytest .mark .parametrize ("sync" , [True , False ])
765
- @pytest .mark .parametrize (
766
- "training_job" ,
767
- [
768
- training_jobs .AutoMLForecastingTrainingJob ,
769
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
770
- ],
771
- )
741
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
772
742
def test_run_raises_if_pipeline_fails (
773
743
self ,
774
744
mock_pipeline_service_create_and_get_with_fail ,
@@ -823,13 +793,7 @@ def test_run_raises_if_pipeline_fails(
823
793
with pytest .raises (RuntimeError ):
824
794
job .get_model ()
825
795
826
- @pytest .mark .parametrize (
827
- "training_job" ,
828
- [
829
- training_jobs .AutoMLForecastingTrainingJob ,
830
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
831
- ],
832
- )
796
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
833
797
def test_raises_before_run_is_called (
834
798
self ,
835
799
mock_pipeline_service_create ,
@@ -855,13 +819,7 @@ def test_raises_before_run_is_called(
855
819
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
856
820
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
857
821
@pytest .mark .parametrize ("sync" , [True , False ])
858
- @pytest .mark .parametrize (
859
- "training_job" ,
860
- [
861
- training_jobs .AutoMLForecastingTrainingJob ,
862
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
863
- ],
864
- )
822
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
865
823
def test_splits_fraction (
866
824
self ,
867
825
mock_pipeline_service_create ,
@@ -960,13 +918,7 @@ def test_splits_fraction(
960
918
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
961
919
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
962
920
@pytest .mark .parametrize ("sync" , [True , False ])
963
- @pytest .mark .parametrize (
964
- "training_job" ,
965
- [
966
- training_jobs .AutoMLForecastingTrainingJob ,
967
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
968
- ],
969
- )
921
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
970
922
def test_splits_timestamp (
971
923
self ,
972
924
mock_pipeline_service_create ,
@@ -1067,13 +1019,7 @@ def test_splits_timestamp(
1067
1019
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
1068
1020
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
1069
1021
@pytest .mark .parametrize ("sync" , [True , False ])
1070
- @pytest .mark .parametrize (
1071
- "training_job" ,
1072
- [
1073
- training_jobs .AutoMLForecastingTrainingJob ,
1074
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
1075
- ],
1076
- )
1022
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
1077
1023
def test_splits_predefined (
1078
1024
self ,
1079
1025
mock_pipeline_service_create ,
@@ -1168,13 +1114,7 @@ def test_splits_predefined(
1168
1114
@mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
1169
1115
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
1170
1116
@pytest .mark .parametrize ("sync" , [True , False ])
1171
- @pytest .mark .parametrize (
1172
- "training_job" ,
1173
- [
1174
- training_jobs .AutoMLForecastingTrainingJob ,
1175
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
1176
- ],
1177
- )
1117
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
1178
1118
def test_splits_default (
1179
1119
self ,
1180
1120
mock_pipeline_service_create ,
@@ -1264,13 +1204,7 @@ def test_splits_default(
1264
1204
@mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
1265
1205
@pytest .mark .usefixtures ("mock_pipeline_service_get" )
1266
1206
@pytest .mark .parametrize ("sync" , [True , False ])
1267
- @pytest .mark .parametrize (
1268
- "training_job" ,
1269
- [
1270
- training_jobs .AutoMLForecastingTrainingJob ,
1271
- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
1272
- ],
1273
- )
1207
+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
1274
1208
def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference (
1275
1209
self ,
1276
1210
mock_pipeline_service_create ,
0 commit comments