Content-Length: 628413 | pFad | https://github.com/googleapis/python-aiplatform/commit/a0d4ff20ceb1c48806d1711fdb2691dc34f9f1db

6E feat: GenAI - Added the model Distillation feature (private preview) · googleapis/python-aiplatform@a0d4ff2 · GitHub
Skip to content

Commit a0d4ff2

Browse files
Ark-kun authored and copybara-github committed
feat: GenAI - Added the model Distillation feature (private preview)
``` from google.cloud.aiplatform.private_preview import distillation job = distillation.train( student_model="gemma-1.1-2b-it", teacher_model="gemini-1.5-flash-001", training_dataset="gs://some-bucket/some_dataset.jsonl", # Optional: validation_dataset="gs://some-bucket/some_dataset.jsonl", epoch_count=5, learning_rate_multiplier=1.0, ) ``` PiperOrigin-RevId: 666992707
1 parent d59a052 commit a0d4ff2

File tree

3 files changed

+148
-6
lines changed

3 files changed

+148
-6
lines changed

tests/unit/vertexai/test_tuning.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
sft as preview_supervised_tuning,
3939
)
4040
from vertexai.tuning import sft as supervised_tuning
41+
from vertexai.tuning import _distillation
42+
from google.cloud import storage
4143

4244
import pytest
4345

@@ -79,11 +81,12 @@ def create_tuning_job(
7981
def _progress_tuning_job(self, name: str):
8082
tuning_job: gca_tuning_job.TuningJob = self._tuning_jobs[name]
8183
current_time = datetime.datetime.now(datetime.timezone.utc)
84+
training_dataset_uri = (
85+
tuning_job.supervised_tuning_spec.training_dataset_uri
86+
or tuning_job.distillation_spec.training_dataset_uri
87+
)
8288
if tuning_job.state == job_state.JobState.JOB_STATE_PENDING:
83-
if (
84-
"invalid_dataset"
85-
in tuning_job.supervised_tuning_spec.training_dataset_uri
86-
):
89+
if "invalid_dataset" in training_dataset_uri:
8790
tuning_job.state = job_state.JobState.JOB_STATE_FAILED
8891
tuning_job.error = status_pb2.Status(
8992
code=400, message="Invalid dataset."
@@ -162,6 +165,7 @@ def setup_method(self):
162165
vertexai.init(
163166
project=_TEST_PROJECT,
164167
location=_TEST_LOCATION,
168+
staging_bucket="gs://test-bucket",
165169
)
166170

167171
def teardown_method(self):
@@ -233,3 +237,48 @@ def test_genai_tuning_service_encryption_spec(
233237
train_dataset="gs://some-bucket/some_dataset.jsonl",
234238
)
235239
assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"
240+
241+
    # Patch the tuning-job client with the local fake so no real API calls are
    # made; the fake advances job state on each refresh() (PENDING -> PENDING
    # -> RUNNING -> SUCCEEDED).
    @mock.patch.object(
        target=tuning.TuningJob,
        attribute="client_class",
        new=MockTuningJobClientWithOverride,
    )
    # Pretend the staging bucket already exists so distill_model's
    # pipeline-artifacts bucket check does not touch GCS.
    @mock.patch.object(
        target=storage.Bucket,
        attribute="exists",
        new=lambda _: True,
    )
    def test_genai_tuning_service_distillation_distill_model(self):
        """Smoke-tests distillation job creation and its state lifecycle."""
        distillation_train = _distillation.distill_model

        tuning_job = distillation_train(
            student_model="gemma",
            teacher_model="gemini-1.0-pro-001",
            training_dataset="gs://some-bucket/some_dataset.jsonl",
            # Optional:
            validation_dataset="gs://some-bucket/some_dataset.jsonl",
            epoch_count=300,
            learning_rate_multiplier=1.0,
        )
        # Freshly created job starts out PENDING.
        assert tuning_job.state == job_state.JobState.JOB_STATE_PENDING
        assert not tuning_job.has_ended
        assert not tuning_job.has_succeeded

        # Refreshing the job: first refresh keeps it PENDING
        # (the fake client advances state one step per refresh).
        tuning_job.refresh()
        assert tuning_job.state == job_state.JobState.JOB_STATE_PENDING
        assert not tuning_job.has_ended
        assert not tuning_job.has_succeeded

        # Refreshing the job: now RUNNING.
        tuning_job.refresh()
        assert tuning_job.state == job_state.JobState.JOB_STATE_RUNNING
        assert not tuning_job.has_ended
        assert not tuning_job.has_succeeded

        # Refreshing the job: terminal SUCCEEDED state with a tuned model.
        tuning_job.refresh()
        assert tuning_job.state == job_state.JobState.JOB_STATE_SUCCEEDED
        assert tuning_job.has_ended
        assert tuning_job.has_succeeded
        assert tuning_job.tuned_model_name

vertexai/tuning/_distillation.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access
16+
"""Classes for model tuning based on distillation."""
17+
18+
from typing import Optional
19+
20+
from google.cloud.aiplatform.utils import gcs_utils
21+
from google.cloud.aiplatform_v1beta1.types import tuning_job as gca_tuning_job_types
22+
23+
from vertexai import generative_models
24+
from vertexai.tuning import _tuning
25+
26+
27+
def distill_model(
    *,
    student_model: str,
    teacher_model: str,
    training_dataset: str,
    validation_dataset: Optional[str] = None,
    epoch_count: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    tuned_model_display_name: Optional[str] = None,
) -> "DistillationJob":
    """Tunes a model using distillation.

    Args:
        student_model:
            Student model name for distillation, e.g., "gemma-1.1-2b-it".
            A `generative_models.GenerativeModel` instance is also accepted;
            its prediction resource name is used.
        teacher_model:
            Teacher model name for distillation, e.g., "gemini-1.5-flash-001".
        training_dataset: Cloud Storage path to file containing training
            dataset for distillation. The dataset should be in JSONL format.
        validation_dataset: Cloud Storage path to file containing validation
            dataset for distillation. The dataset should be in JSONL format.
        epoch_count: Number of training epochs for this tuning job.
        learning_rate_multiplier: Learning rate multiplier for tuning.
        tuned_model_display_name: The display name of the
            [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
            be up to 128 characters long and can consist of any UTF-8
            characters.

    Returns:
        A `DistillationJob` object.
    """
    if isinstance(student_model, generative_models.GenerativeModel):
        student_model = student_model._prediction_resource_name

    # The service expects bare model IDs; strip any
    # "publishers/google/models/<id>" (or other resource-path) prefix.
    student_model = student_model.rpartition("/")[-1]
    teacher_model = teacher_model.rpartition("/")[-1]

    # Distillation runs a pipeline under the hood; make sure a bucket for
    # pipeline artifacts exists (creates one in the staging location if not).
    pipeline_root = (
        gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist()
    )

    distillation_spec = gca_tuning_job_types.DistillationSpec(
        student_model=student_model,
        base_teacher_model=teacher_model,
        training_dataset_uri=training_dataset,
        validation_dataset_uri=validation_dataset,
        hyper_parameters=gca_tuning_job_types.DistillationHyperParameters(
            epoch_count=epoch_count,
            learning_rate_multiplier=learning_rate_multiplier,
        ),
        pipeline_root_directory=pipeline_root,
    )

    return DistillationJob._create(  # pylint: disable=protected-access
        # Distillation has no single base model; the student/teacher pair is
        # carried in the spec instead.
        base_model=None,
        tuning_spec=distillation_spec,
        tuned_model_display_name=tuned_model_display_name,
    )
85+
86+
87+
class DistillationJob(_tuning.TuningJob):
    """Tuning job for model distillation; created via `distill_model`."""

    pass

vertexai/tuning/_tuning.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,10 @@ def _create(
128128
cls,
129129
*,
130130
base_model: str,
131-
tuning_spec: Union[gca_tuning_job_types.SupervisedTuningSpec],
131+
tuning_spec: Union[
132+
gca_tuning_job_types.SupervisedTuningSpec,
133+
gca_tuning_job_types.DistillationSpec,
134+
],
132135
tuned_model_display_name: Optional[str] = None,
133136
description: Optional[str] = None,
134137
labels: Optional[Dict[str, str]] = None,
@@ -145,7 +148,7 @@ def _create(
145148
146149
This field is a member of `oneof`_ ``source_model``.
147150
tuning_spec: Tuning Spec for Fine Tuning.
148-
Supported types: SupervisedTuningSpec.
151+
Supported types: SupervisedTuningSpec, DistillationSpec.
149152
tuned_model_display_name: The display name of the
150153
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can
151154
be up to 128 characters long and can consist of any UTF-8
@@ -192,6 +195,8 @@ def _create(
192195

193196
if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
194197
gca_tuning_job.supervised_tuning_spec = tuning_spec
198+
elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
199+
gca_tuning_job.distillation_spec = tuning_spec
195200
else:
196201
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")
197202

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/a0d4ff20ceb1c48806d1711fdb2691dc34f9f1db

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy