@@ -1723,6 +1723,53 @@ def test_tune_text_generation_model_evaluation_with_only_tensorboard(
1723
1723
].runtime_config .parameter_values
1724
1724
assert pipeline_arguments ["tensorboard_resource_id" ] == tensorboard_name
1725
1725
1726
+ @pytest .mark .parametrize (
1727
+ "job_spec" ,
1728
+ [_TEST_PIPELINE_SPEC_JSON ],
1729
+ )
1730
+ @pytest .mark .parametrize (
1731
+ "mock_request_urlopen" ,
1732
+ ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" ],
1733
+ indirect = True ,
1734
+ )
1735
+ def test_tune_text_generation_model_staging_bucket (
1736
+ self ,
1737
+ mock_pipeline_service_create ,
1738
+ mock_pipeline_job_get ,
1739
+ mock_pipeline_bucket_exists ,
1740
+ job_spec ,
1741
+ mock_load_yaml_and_json ,
1742
+ mock_gcs_from_string ,
1743
+ mock_gcs_upload ,
1744
+ mock_request_urlopen ,
1745
+ mock_get_tuned_model ,
1746
+ ):
1747
+ """Tests that tune_model respects staging_bucket."""
1748
+ TEST_STAGING_BUCKET = "gs://test_staging_bucket/path/"
1749
+ aiplatform .init (staging_bucket = TEST_STAGING_BUCKET )
1750
+
1751
+ with mock .patch .object (
1752
+ target = model_garden_service_client .ModelGardenServiceClient ,
1753
+ attribute = "get_publisher_model" ,
1754
+ return_value = gca_publisher_model .PublisherModel (
1755
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
1756
+ ),
1757
+ ):
1758
+ model = language_models .TextGenerationModel .from_pretrained (
1759
+ "text-bison@001"
1760
+ )
1761
+
1762
+ model .tune_model (
1763
+ training_data = _TEST_TEXT_BISON_TRAINING_DF ,
1764
+ tuning_job_location = "europe-west4" ,
1765
+ tuned_model_location = "us-central1" ,
1766
+ )
1767
+ call_kwargs = mock_pipeline_service_create .call_args [1 ]
1768
+ pipeline_arguments = call_kwargs [
1769
+ "pipeline_job"
1770
+ ].runtime_config .parameter_values
1771
+ assert pipeline_arguments ["dataset_uri" ].startswith (TEST_STAGING_BUCKET )
1772
+
1726
1773
@pytest .mark .parametrize (
1727
1774
"job_spec" ,
1728
1775
[_TEST_PIPELINE_SPEC_JSON ],
0 commit comments