Content-Length: 776003 | pFad | http://github.com/googleapis/python-aiplatform/commit/99313e0baacd61d7d00d6576a22b151c1d8e1a49

0D feat: Adds the temporal fusion transformer (TFT) forecasting job · googleapis/python-aiplatform@99313e0 · GitHub
Skip to content

Commit 99313e0

Browse files
Mlawrence95copybara-github
authored andcommitted
feat: Adds the temporal fusion transformer (TFT) forecasting job
COPYBARA_INTEGRATE_REVIEW=#1817 from mikelawrence-google:mikealawrence-add-tft-model-support dde8ac0 PiperOrigin-RevId: 494251134
1 parent 43468bd commit 99313e0

File tree

5 files changed

+36
-86
lines changed

5 files changed

+36
-86
lines changed

google/cloud/aiplatform/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
AutoMLTabularTrainingJob,
6969
AutoMLForecastingTrainingJob,
7070
SequenceToSequencePlusForecastingTrainingJob,
71+
TemporalFusionTransformerForecastingTrainingJob,
7172
AutoMLImageTrainingJob,
7273
AutoMLTextTrainingJob,
7374
AutoMLVideoTrainingJob,
@@ -162,6 +163,7 @@
162163
"TensorboardRun",
163164
"TensorboardTimeSeries",
164165
"TextDataset",
166+
"TemporalFusionTransformerForecastingTrainingJob",
165167
"TimeSeriesDataset",
166168
"VideoDataset",
167169
)

google/cloud/aiplatform/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class definition:
2424
automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml"
2525
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
2626
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
27+
tft_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml"
2728
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
2829
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
2930
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"

google/cloud/aiplatform/training_jobs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5204,19 +5204,31 @@ class column_data_types:
52045204

52055205

52065206
class AutoMLForecastingTrainingJob(_ForecastingTrainingJob):
5207+
"""Class to train AutoML forecasting models."""
5208+
52075209
_model_type = "AutoML"
52085210
_training_task_definition = schema.training_job.definition.automl_forecasting
52095211
_supported_training_schemas = (schema.training_job.definition.automl_forecasting,)
52105212

52115213

52125214
class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob):
5215+
"""Class to train Sequence to Sequence (Seq2Seq) forecasting models."""
5216+
52135217
_model_type = "Seq2Seq"
52145218
_training_task_definition = schema.training_job.definition.seq2seq_plus_forecasting
52155219
_supported_training_schemas = (
52165220
schema.training_job.definition.seq2seq_plus_forecasting,
52175221
)
52185222

52195223

5224+
class TemporalFusionTransformerForecastingTrainingJob(_ForecastingTrainingJob):
5225+
"""Class to train Temporal Fusion Transformer (TFT) forecasting models."""
5226+
5227+
_model_type = "TFT"
5228+
_training_task_definition = schema.training_job.definition.tft_forecasting
5229+
_supported_training_schemas = (schema.training_job.definition.tft_forecasting,)
5230+
5231+
52205232
class AutoMLImageTrainingJob(_TrainingJob):
52215233
_supported_training_schemas = (
52225234
schema.training_job.definition.automl_image_classification,

tests/system/aiplatform/test_e2e_forecasting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
4040
"training_job",
4141
[
4242
training_jobs.AutoMLForecastingTrainingJob,
43+
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
4344
pytest.param(
44-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
45-
marks=pytest.mark.skip(reason="Seq2Seq not yet released."),
45+
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
46+
marks=pytest.mark.skip(reason="TFT not yet released."),
4647
),
4748
],
4849
)

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

Lines changed: 18 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@
183183
_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
184184
_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"
185185

186+
_FORECASTING_JOB_MODEL_TYPES = [
187+
training_jobs.AutoMLForecastingTrainingJob,
188+
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
189+
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
190+
]
191+
186192

187193
@pytest.fixture
188194
def mock_pipeline_service_create():
@@ -293,13 +299,7 @@ def teardown_method(self):
293299
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
294300
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
295301
@pytest.mark.parametrize("sync", [True, False])
296-
@pytest.mark.parametrize(
297-
"training_job",
298-
[
299-
training_jobs.AutoMLForecastingTrainingJob,
300-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
301-
],
302-
)
302+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
303303
def test_run_call_pipeline_service_create(
304304
self,
305305
mock_pipeline_service_create,
@@ -401,13 +401,7 @@ def test_run_call_pipeline_service_create(
401401
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
402402
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
403403
@pytest.mark.parametrize("sync", [True, False])
404-
@pytest.mark.parametrize(
405-
"training_job",
406-
[
407-
training_jobs.AutoMLForecastingTrainingJob,
408-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
409-
],
410-
)
404+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
411405
def test_run_call_pipeline_service_create_with_timeout(
412406
self,
413407
mock_pipeline_service_create,
@@ -496,13 +490,7 @@ def test_run_call_pipeline_service_create_with_timeout(
496490
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
497491
@pytest.mark.usefixtures("mock_pipeline_service_get")
498492
@pytest.mark.parametrize("sync", [True, False])
499-
@pytest.mark.parametrize(
500-
"training_job",
501-
[
502-
training_jobs.AutoMLForecastingTrainingJob,
503-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
504-
],
505-
)
493+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
506494
def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
507495
self,
508496
mock_pipeline_service_create,
@@ -584,13 +572,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
584572
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
585573
@pytest.mark.usefixtures("mock_pipeline_service_get")
586574
@pytest.mark.parametrize("sync", [True, False])
587-
@pytest.mark.parametrize(
588-
"training_job",
589-
[
590-
training_jobs.AutoMLForecastingTrainingJob,
591-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
592-
],
593-
)
575+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
594576
def test_run_call_pipeline_if_set_additional_experiments(
595577
self,
596578
mock_pipeline_service_create,
@@ -675,13 +657,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
675657
"mock_model_service_get",
676658
)
677659
@pytest.mark.parametrize("sync", [True, False])
678-
@pytest.mark.parametrize(
679-
"training_job",
680-
[
681-
training_jobs.AutoMLForecastingTrainingJob,
682-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
683-
],
684-
)
660+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
685661
def test_run_called_twice_raises(
686662
self,
687663
mock_dataset_time_series,
@@ -762,13 +738,7 @@ def test_run_called_twice_raises(
762738
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
763739
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
764740
@pytest.mark.parametrize("sync", [True, False])
765-
@pytest.mark.parametrize(
766-
"training_job",
767-
[
768-
training_jobs.AutoMLForecastingTrainingJob,
769-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
770-
],
771-
)
741+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
772742
def test_run_raises_if_pipeline_fails(
773743
self,
774744
mock_pipeline_service_create_and_get_with_fail,
@@ -823,13 +793,7 @@ def test_run_raises_if_pipeline_fails(
823793
with pytest.raises(RuntimeError):
824794
job.get_model()
825795

826-
@pytest.mark.parametrize(
827-
"training_job",
828-
[
829-
training_jobs.AutoMLForecastingTrainingJob,
830-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
831-
],
832-
)
796+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
833797
def test_raises_before_run_is_called(
834798
self,
835799
mock_pipeline_service_create,
@@ -855,13 +819,7 @@ def test_raises_before_run_is_called(
855819
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
856820
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
857821
@pytest.mark.parametrize("sync", [True, False])
858-
@pytest.mark.parametrize(
859-
"training_job",
860-
[
861-
training_jobs.AutoMLForecastingTrainingJob,
862-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
863-
],
864-
)
822+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
865823
def test_splits_fraction(
866824
self,
867825
mock_pipeline_service_create,
@@ -960,13 +918,7 @@ def test_splits_fraction(
960918
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
961919
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
962920
@pytest.mark.parametrize("sync", [True, False])
963-
@pytest.mark.parametrize(
964-
"training_job",
965-
[
966-
training_jobs.AutoMLForecastingTrainingJob,
967-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
968-
],
969-
)
921+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
970922
def test_splits_timestamp(
971923
self,
972924
mock_pipeline_service_create,
@@ -1067,13 +1019,7 @@ def test_splits_timestamp(
10671019
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
10681020
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
10691021
@pytest.mark.parametrize("sync", [True, False])
1070-
@pytest.mark.parametrize(
1071-
"training_job",
1072-
[
1073-
training_jobs.AutoMLForecastingTrainingJob,
1074-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
1075-
],
1076-
)
1022+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
10771023
def test_splits_predefined(
10781024
self,
10791025
mock_pipeline_service_create,
@@ -1168,13 +1114,7 @@ def test_splits_predefined(
11681114
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
11691115
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
11701116
@pytest.mark.parametrize("sync", [True, False])
1171-
@pytest.mark.parametrize(
1172-
"training_job",
1173-
[
1174-
training_jobs.AutoMLForecastingTrainingJob,
1175-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
1176-
],
1177-
)
1117+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
11781118
def test_splits_default(
11791119
self,
11801120
mock_pipeline_service_create,
@@ -1264,13 +1204,7 @@ def test_splits_default(
12641204
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
12651205
@pytest.mark.usefixtures("mock_pipeline_service_get")
12661206
@pytest.mark.parametrize("sync", [True, False])
1267-
@pytest.mark.parametrize(
1268-
"training_job",
1269-
[
1270-
training_jobs.AutoMLForecastingTrainingJob,
1271-
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
1272-
],
1273-
)
1207+
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
12741208
def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference(
12751209
self,
12761210
mock_pipeline_service_create,

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/googleapis/python-aiplatform/commit/99313e0baacd61d7d00d6576a22b151c1d8e1a49

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy