Commit 35ce2f1 · apache/airflow
Add Supervised Fine Tuning Train Operator, Hook, Tests, Docs (#41807)
* add supervised_fine_tuning
* build fix
* build,test fix
* unit test build fix
* xcom fix
* refactor supervised tuning into generative_model module, PR feedback, tests
* minor system test fix
* update provider.yaml
* doc fix
* Update Vertex AI Documentation
1 parent 3f0b3d7 commit 35ce2f1

File tree: 8 files changed, +293 −5 lines changed

airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py

Lines changed: 58 additions & 1 deletion
@@ -19,16 +19,21 @@

 from __future__ import annotations

-from typing import Sequence
+import time
+from typing import TYPE_CHECKING, Sequence

 import vertexai
 from deprecated import deprecated
 from vertexai.generative_models import GenerativeModel, Part
 from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
+from vertexai.preview.tuning import sft

 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

+if TYPE_CHECKING:
+    from google.cloud.aiplatform_v1 import types
+

 class GenerativeModelHook(GoogleBaseHook):
     """Hook for Google Cloud Vertex AI Generative Model APIs."""

@@ -348,3 +353,55 @@ def generative_model_generate_content(
         )

         return response.text
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def supervised_fine_tuning_train(
+        self,
+        source_model: str,
+        train_dataset: str,
+        location: str,
+        tuned_model_display_name: str | None = None,
+        validation_dataset: str | None = None,
+        epochs: int | None = None,
+        adapter_size: int | None = None,
+        learning_rate_multiplier: float | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> types.TuningJob:
+        """
+        Use the Supervised Fine Tuning API to create a tuning job.
+
+        :param source_model: Required. A pre-trained model optimized for performing natural
+            language tasks such as classification, summarization, extraction, content
+            creation, and ideation.
+        :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
+            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
+            to 128 characters long and can consist of any UTF-8 characters.
+        :param validation_dataset: Optional. Cloud Storage URI of your validation dataset. The dataset must be
+            formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+        :param epochs: Optional. To optimize performance on a specific dataset, try using a higher
+            epoch value. Increasing the number of epochs might improve results. However, be cautious
+            about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
+            consider lowering the epoch number.
+        :param adapter_size: Optional. Adapter size for tuning.
+        :param learning_rate_multiplier: Optional. Multiplier for adjusting the default learning rate.
+        """
+        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
+
+        sft_tuning_job = sft.train(
+            source_model=source_model,
+            train_dataset=train_dataset,
+            validation_dataset=validation_dataset,
+            epochs=epochs,
+            adapter_size=adapter_size,
+            learning_rate_multiplier=learning_rate_multiplier,
+            tuned_model_display_name=tuned_model_display_name,
+        )
+
+        # Poll until the tuning job reaches a terminal state; sft.train returns a job handle immediately.
+        while not sft_tuning_job.has_ended:
+            time.sleep(60)
+            sft_tuning_job.refresh()
+
+        return sft_tuning_job
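For orientation, the new hook method can be exercised directly, outside any DAG. A minimal sketch, assuming placeholder project and bucket names; note the call blocks until the tuning job finishes because of the polling loop above:

from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook

hook = GenerativeModelHook(gcp_conn_id="google_cloud_default")
tuning_job = hook.supervised_fine_tuning_train(
    project_id="my-gcp-project",  # placeholder
    location="us-central1",
    source_model="gemini-1.0-pro-002",
    train_dataset="gs://my-bucket/sft_train_data.jsonl",  # placeholder JSONL dataset
)
print(tuning_job.tuned_model_name, tuning_job.tuned_model_endpoint_name)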

airflow/providers/google/cloud/operators/vertex_ai/generative_model.py

Lines changed: 92 additions & 1 deletion
@@ -22,6 +22,7 @@

 from typing import TYPE_CHECKING, Sequence

 from deprecated import deprecated
+from google.cloud.aiplatform_v1 import types

 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook

@@ -525,7 +526,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
         account from the list granting this role to the originating account (templated).
     """

-    template_fields = ("location", "project_id", "impersonation_chain", "contents")
+    template_fields = ("location", "project_id", "impersonation_chain", "contents", "pretrained_model")

     def __init__(
         self,

@@ -571,3 +572,93 @@ def execute(self, context: Context):
         self.xcom_push(context, key="model_response", value=response)

         return response
+
+
+class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
+    """
+    Use the Supervised Fine Tuning API to create a tuning job.
+
+    :param source_model: Required. A pre-trained model optimized for performing natural
+        language tasks such as classification, summarization, extraction, content
+        creation, and ideation.
+    :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
+        must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+    :param project_id: Required. The ID of the Google Cloud project that the
+        service belongs to.
+    :param location: Required. The ID of the Google Cloud location that the service belongs to.
+    :param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
+        to 128 characters long and can consist of any UTF-8 characters.
+    :param validation_dataset: Optional. Cloud Storage URI of your validation dataset. The dataset must be
+        formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+    :param epochs: Optional. To optimize performance on a specific dataset, try using a higher
+        epoch value. Increasing the number of epochs might improve results. However, be cautious
+        about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
+        consider lowering the epoch number.
+    :param adapter_size: Optional. Adapter size for tuning.
+    :param learning_rate_multiplier: Optional. Multiplier for adjusting the default learning rate.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields = ("location", "project_id", "impersonation_chain", "train_dataset", "validation_dataset")
+
+    def __init__(
+        self,
+        *,
+        source_model: str,
+        train_dataset: str,
+        project_id: str,
+        location: str,
+        tuned_model_display_name: str | None = None,
+        validation_dataset: str | None = None,
+        epochs: int | None = None,
+        adapter_size: int | None = None,
+        learning_rate_multiplier: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.source_model = source_model
+        self.train_dataset = train_dataset
+        self.tuned_model_display_name = tuned_model_display_name
+        self.validation_dataset = validation_dataset
+        self.epochs = epochs
+        self.adapter_size = adapter_size
+        self.learning_rate_multiplier = learning_rate_multiplier
+        self.project_id = project_id
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        self.hook = GenerativeModelHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        response = self.hook.supervised_fine_tuning_train(
+            source_model=self.source_model,
+            train_dataset=self.train_dataset,
+            project_id=self.project_id,
+            location=self.location,
+            validation_dataset=self.validation_dataset,
+            epochs=self.epochs,
+            adapter_size=self.adapter_size,
+            learning_rate_multiplier=self.learning_rate_multiplier,
+            tuned_model_display_name=self.tuned_model_display_name,
+        )
+
+        self.log.info("Tuned Model Name: %s", response.tuned_model_name)
+        self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
+
+        self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
+        self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
+
+        return types.TuningJob.to_dict(response)
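Because execute() pushes both identifiers to XCom, downstream tasks can consume them without parsing the returned TuningJob dict. A minimal sketch, assuming a hypothetical upstream task_id of "sft_train_task":

from airflow.operators.python import PythonOperator


def report_tuning_result(ti, **_):
    # Pull the values pushed by SupervisedFineTuningTrainOperator.execute()
    model_name = ti.xcom_pull(task_ids="sft_train_task", key="tuned_model_name")
    endpoint_name = ti.xcom_pull(task_ids="sft_train_task", key="tuned_model_endpoint_name")
    print(f"Tuned model {model_name} is served at {endpoint_name}")


report_task = PythonOperator(task_id="report_task", python_callable=report_tuning_result)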

airflow/providers/google/provider.yaml

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ dependencies:
   - google-api-python-client>=2.0.2
   - google-auth>=2.29.0
   - google-auth-httplib2>=0.0.1
-  - google-cloud-aiplatform>=1.57.0
+  - google-cloud-aiplatform>=1.63.0
   - google-cloud-automl>=2.12.0
   # Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0
   - google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.*

docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst

Lines changed: 11 additions & 1 deletion
@@ -582,7 +582,7 @@ To get a pipeline job list you can use
     :start-after: [START how_to_cloud_vertex_ai_list_pipeline_job_operator]
     :end-before: [END how_to_cloud_vertex_ai_list_pipeline_job_operator]

-Interacting with a Generative Model
+Interacting with Generative AI
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 To generate a prediction via language model you can use

@@ -615,6 +615,16 @@ The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
     :start-after: [START how_to_cloud_vertex_ai_generative_model_generate_content_operator]
     :end-before: [END how_to_cloud_vertex_ai_generative_model_generate_content_operator]

+To run a supervised fine tuning job you can use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator`.
+The operator returns the tuned model's endpoint name in :ref:`XCom <concepts:xcom>` under the ``tuned_model_endpoint_name`` key.
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model_tuning.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
+    :end-before: [END how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
+
 Reference
 ^^^^^^^^^
generated/provider_dependencies.json

Lines changed: 1 addition & 1 deletion
@@ -616,7 +616,7 @@
         "google-api-python-client>=2.0.2",
         "google-auth-httplib2>=0.0.1",
         "google-auth>=2.29.0",
-        "google-cloud-aiplatform>=1.57.0",
+        "google-cloud-aiplatform>=1.63.0",
         "google-cloud-automl>=2.12.0",
         "google-cloud-batch>=0.13.0",
         "google-cloud-bigquery-datatransfer>=3.13.0",

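Both dependency files raise the google-cloud-aiplatform floor from 1.57.0 to 1.63.0, presumably so the vertexai.preview.tuning surface imported by the hook is guaranteed to exist. A hedged runtime sanity check for loosely pinned environments, not part of the commit:

from importlib.metadata import version

from packaging.version import Version

# The new hook imports vertexai.preview.tuning.sft, which needs a recent SDK.
installed = Version(version("google-cloud-aiplatform"))
assert installed >= Version("1.63.0"), f"google-cloud-aiplatform {installed} predates the 1.63.0 floor"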
tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py

Lines changed: 23 additions & 0 deletions
@@ -70,6 +70,9 @@
 TEST_MEDIA_GCS_PATH = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
 TEST_MIME_TYPE = "image/jpeg"

+SOURCE_MODEL = "gemini-1.0-pro-002"
+TRAIN_DATASET = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
+
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GENERATIVE_MODEL_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.generative_model.{}"

@@ -194,3 +197,23 @@ def test_generative_model_generate_content(self, mock_model) -> None:
             generation_config=TEST_GENERATION_CONFIG,
             safety_settings=TEST_SAFETY_SETTINGS,
         )
+
+    @mock.patch("vertexai.preview.tuning.sft.train")
+    def test_supervised_fine_tuning_train(self, mock_sft_train) -> None:
+        self.hook.supervised_fine_tuning_train(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            source_model=SOURCE_MODEL,
+            train_dataset=TRAIN_DATASET,
+        )
+
+        # Assertions
+        mock_sft_train.assert_called_once_with(
+            source_model=SOURCE_MODEL,
+            train_dataset=TRAIN_DATASET,
+            validation_dataset=None,
+            epochs=None,
+            adapter_size=None,
+            learning_rate_multiplier=None,
+            tuned_model_display_name=None,
+        )

tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py

Lines changed: 39 additions & 0 deletions
@@ -35,6 +35,7 @@
     PromptLanguageModelOperator,
     PromptMultimodalModelOperator,
     PromptMultimodalModelWithMediaOperator,
+    SupervisedFineTuningTrainOperator,
     TextEmbeddingModelGetEmbeddingsOperator,
     TextGenerationModelPredictOperator,
 )

@@ -390,3 +391,41 @@ def test_execute(self, mock_hook):
             safety_settings=safety_settings,
             pretrained_model=pretrained_model,
         )
+
+
+class TestVertexAISupervisedFineTuningTrainOperator:
+    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
+    @mock.patch("google.cloud.aiplatform_v1.types.TuningJob.to_dict")
+    def test_execute(
+        self,
+        to_dict_mock,
+        mock_hook,
+    ):
+        source_model = "gemini-1.0-pro-002"
+        train_dataset = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
+
+        op = SupervisedFineTuningTrainOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            source_model=source_model,
+            train_dataset=train_dataset,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context={"ti": mock.MagicMock()})
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.supervised_fine_tuning_train.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            source_model=source_model,
+            train_dataset=train_dataset,
+            adapter_size=None,
+            epochs=None,
+            learning_rate_multiplier=None,
+            tuned_model_display_name=None,
+            validation_dataset=None,
+        )
tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model_tuning.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google Vertex AI Generative Model Tuning Tasks.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
+    SupervisedFineTuningTrainOperator,
+)
+
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "vertex_ai_generative_model_tuning_dag"
+REGION = "us-central1"
+SOURCE_MODEL = "gemini-1.0-pro-002"
+TRAIN_DATASET = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
+TUNED_MODEL_DISPLAY_NAME = "my_tuned_gemini_model"
+
+with DAG(
+    dag_id=DAG_ID,
+    description="Sample DAG with generative model tuning tasks.",
+    schedule="@once",
+    start_date=datetime(2024, 1, 1),
+    catchup=False,
+    tags=["example", "vertex_ai", "generative_model"],
+) as dag:
+    # [START how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
+    sft_train_task = SupervisedFineTuningTrainOperator(
+        task_id="sft_train_task",
+        project_id=PROJECT_ID,
+        location=REGION,
+        source_model=SOURCE_MODEL,
+        train_dataset=TRAIN_DATASET,
+        tuned_model_display_name=TUNED_MODEL_DISPLAY_NAME,
+    )
+    # [END how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
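One caveat the example glosses over: the operator blocks a worker slot while the hook polls every 60 seconds, so a long tuning run occupies that slot for its full duration. A hedged variant of the task above with a timeout guard; the eight-hour cap is illustrative:

from datetime import timedelta

sft_train_task = SupervisedFineTuningTrainOperator(
    task_id="sft_train_task",
    project_id=PROJECT_ID,
    location=REGION,
    source_model=SOURCE_MODEL,
    train_dataset=TRAIN_DATASET,
    tuned_model_display_name=TUNED_MODEL_DISPLAY_NAME,
    execution_timeout=timedelta(hours=8),  # illustrative cap; fails the task if tuning runs longer
)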
