Content-Length: 373273 | pFad | https://github.com/googleapis/python-aiplatform/commit/d9ced106b57cb21f5dcde433f1779b6500aaf7b0

26 fix: LLM - Make tuning use the global staging bucket if specified · googleapis/python-aiplatform@d9ced10 · GitHub
Skip to content

Commit d9ced10

Browse files
Ark-kuncopybara-github
authored andcommitted
fix: LLM - Make tuning use the global staging bucket if specified
PiperOrigin-RevId: 574646855
1 parent 19dd980 commit d9ced10

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,6 +1723,53 @@ def test_tune_text_generation_model_evaluation_with_only_tensorboard(
17231723
].runtime_config.parameter_values
17241724
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name
17251725

1726+
@pytest.mark.parametrize(
1727+
"job_spec",
1728+
[_TEST_PIPELINE_SPEC_JSON],
1729+
)
1730+
@pytest.mark.parametrize(
1731+
"mock_request_urlopen",
1732+
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
1733+
indirect=True,
1734+
)
1735+
def test_tune_text_generation_model_staging_bucket(
1736+
self,
1737+
mock_pipeline_service_create,
1738+
mock_pipeline_job_get,
1739+
mock_pipeline_bucket_exists,
1740+
job_spec,
1741+
mock_load_yaml_and_json,
1742+
mock_gcs_from_string,
1743+
mock_gcs_upload,
1744+
mock_request_urlopen,
1745+
mock_get_tuned_model,
1746+
):
1747+
"""Tests that tune_model respects staging_bucket."""
1748+
TEST_STAGING_BUCKET = "gs://test_staging_bucket/path/"
1749+
aiplatform.init(staging_bucket=TEST_STAGING_BUCKET)
1750+
1751+
with mock.patch.object(
1752+
target=model_garden_service_client.ModelGardenServiceClient,
1753+
attribute="get_publisher_model",
1754+
return_value=gca_publisher_model.PublisherModel(
1755+
_TEXT_BISON_PUBLISHER_MODEL_DICT
1756+
),
1757+
):
1758+
model = language_models.TextGenerationModel.from_pretrained(
1759+
"text-bison@001"
1760+
)
1761+
1762+
model.tune_model(
1763+
training_data=_TEST_TEXT_BISON_TRAINING_DF,
1764+
tuning_job_location="europe-west4",
1765+
tuned_model_location="us-central1",
1766+
)
1767+
call_kwargs = mock_pipeline_service_create.call_args[1]
1768+
pipeline_arguments = call_kwargs[
1769+
"pipeline_job"
1770+
].runtime_config.parameter_values
1771+
assert pipeline_arguments["dataset_uri"].startswith(TEST_STAGING_BUCKET)
1772+
17261773
@pytest.mark.parametrize(
17271774
"job_spec",
17281775
[_TEST_PIPELINE_SPEC_JSON],

vertexai/language_models/_language_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2534,6 +2534,12 @@ def _cancel(self):
25342534

25352535

25362536
def _get_tuned_models_dir_uri(model_id: str) -> str:
2537+
if aiplatform_initializer.global_config.staging_bucket:
2538+
return (
2539+
aiplatform_initializer.global_config.staging_bucket
2540+
+ "/tuned_language_models/"
2541+
+ model_id
2542+
)
25372543
staging_gcs_bucket = (
25382544
gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist()
25392545
)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/googleapis/python-aiplatform/commit/d9ced106b57cb21f5dcde433f1779b6500aaf7b0

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy