Skip to content

Commit db34b85

Browse files
authored
fix: Honoring the model's supported_deployment_resources_types (#865)
Honoring the model's `supported_deployment_resources_types` See https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.Model.FIELDS.repeated.google.cloud.aiplatform.v1.Model.DeploymentResourcesType.google.cloud.aiplatform.v1.Model.supported_deployment_resources_types Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) Fixes #773 🦕
1 parent 8a0626d commit db34b85

File tree

4 files changed

+105
-8
lines changed

4 files changed

+105
-8
lines changed

google/cloud/aiplatform/models.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151
from google.protobuf import field_mask_pb2, json_format
5252

53+
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
54+
5355
_LOGGER = base.Logger(__name__)
5456

5557

@@ -798,7 +800,7 @@ def _deploy(
798800
self._deploy_call(
799801
self.api_client,
800802
self.resource_name,
801-
model.resource_name,
803+
model,
802804
self._gca_resource.traffic_split,
803805
deployed_model_display_name=deployed_model_display_name,
804806
traffic_percentage=traffic_percentage,
@@ -823,7 +825,7 @@ def _deploy_call(
823825
cls,
824826
api_client: endpoint_service_client.EndpointServiceClient,
825827
endpoint_resource_name: str,
826-
model_resource_name: str,
828+
model: "Model",
827829
endpoint_resource_traffic_split: Optional[proto.MapField] = None,
828830
deployed_model_display_name: Optional[str] = None,
829831
traffic_percentage: Optional[int] = 0,
@@ -845,8 +847,8 @@ def _deploy_call(
845847
Required. endpoint_service_client.EndpointServiceClient to make call.
846848
endpoint_resource_name (str):
847849
Required. Endpoint resource name to deploy model to.
848-
model_resource_name (str):
849-
Required. Model resource name of Model to deploy.
850+
model (aiplatform.Model):
851+
Required. Model to be deployed.
850852
endpoint_resource_traffic_split (proto.MapField):
851853
Optional. Endpoint current resource traffic split.
852854
deployed_model_display_name (str):
@@ -913,6 +915,7 @@ def _deploy_call(
913915
is not 0 or 100.
914916
ValueError: If only `explanation_metadata` or `explanation_parameters`
915917
is specified.
918+
ValueError: If model does not support deployment.
916919
"""
917920

918921
max_replica_count = max(min_replica_count, max_replica_count)
@@ -923,12 +926,40 @@ def _deploy_call(
923926
)
924927

925928
deployed_model = gca_endpoint_compat.DeployedModel(
926-
model=model_resource_name,
929+
model=model.resource_name,
927930
display_name=deployed_model_display_name,
928931
service_account=service_account,
929932
)
930933

931-
if machine_type:
934+
supports_automatic_resources = (
935+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
936+
in model.supported_deployment_resources_types
937+
)
938+
supports_dedicated_resources = (
939+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
940+
in model.supported_deployment_resources_types
941+
)
942+
provided_custom_machine_spec = (
943+
machine_type or accelerator_type or accelerator_count
944+
)
945+
946+
# If the model supports both automatic and dedicated deployment resources,
947+
# decide based on the presence of machine spec customizations
948+
use_dedicated_resources = supports_dedicated_resources and (
949+
not supports_automatic_resources or provided_custom_machine_spec
950+
)
951+
952+
if provided_custom_machine_spec and not use_dedicated_resources:
953+
_LOGGER.info(
954+
"Model does not support dedicated deployment resources. "
955+
"The machine_type, accelerator_type and accelerator_count parameters are ignored."
956+
)
957+
958+
if use_dedicated_resources and not machine_type:
959+
machine_type = _DEFAULT_MACHINE_TYPE
960+
_LOGGER.info(f"Using default machine_type: {machine_type}")
961+
962+
if use_dedicated_resources:
932963
machine_spec = gca_machine_resources_compat.MachineSpec(
933964
machine_type=machine_type
934965
)
@@ -944,11 +975,16 @@ def _deploy_call(
944975
max_replica_count=max_replica_count,
945976
)
946977

947-
else:
978+
elif supports_automatic_resources:
948979
deployed_model.automatic_resources = gca_machine_resources_compat.AutomaticResources(
949980
min_replica_count=min_replica_count,
950981
max_replica_count=max_replica_count,
951982
)
983+
else:
984+
raise ValueError(
985+
"Model does not support deployment. "
986+
"See https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.Model.FIELDS.repeated.google.cloud.aiplatform.v1.Model.DeploymentResourcesType.google.cloud.aiplatform.v1.Model.supported_deployment_resources_types"
987+
)
952988

953989
# Service will throw error if both metadata and parameters are not provided
954990
if explanation_metadata and explanation_parameters:
@@ -2115,7 +2151,7 @@ def _deploy(
21152151
Endpoint._deploy_call(
21162152
endpoint.api_client,
21172153
endpoint.resource_name,
2118-
self.resource_name,
2154+
self,
21192155
endpoint._gca_resource.traffic_split,
21202156
deployed_model_display_name=deployed_model_display_name,
21212157
traffic_percentage=traffic_percentage,

tests/unit/aiplatform/test_endpoints.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,9 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
611611
def test_deploy(self, deploy_model_mock, sync):
612612
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
613613
test_model = models.Model(_TEST_ID)
614+
test_model._gca_resource.supported_deployment_resources_types.append(
615+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
616+
)
614617
test_endpoint.deploy(test_model, sync=sync)
615618

616619
if not sync:
@@ -636,6 +639,9 @@ def test_deploy(self, deploy_model_mock, sync):
636639
def test_deploy_with_display_name(self, deploy_model_mock, sync):
637640
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
638641
test_model = models.Model(_TEST_ID)
642+
test_model._gca_resource.supported_deployment_resources_types.append(
643+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
644+
)
639645
test_endpoint.deploy(
640646
model=test_model, deployed_model_display_name=_TEST_DISPLAY_NAME, sync=sync
641647
)
@@ -664,6 +670,9 @@ def test_deploy_raise_error_traffic_80(self, sync):
664670
with pytest.raises(ValueError):
665671
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
666672
test_model = models.Model(_TEST_ID)
673+
test_model._gca_resource.supported_deployment_resources_types.append(
674+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
675+
)
667676
test_endpoint.deploy(model=test_model, traffic_percentage=80, sync=sync)
668677

669678
if not sync:
@@ -675,6 +684,9 @@ def test_deploy_raise_error_traffic_120(self, sync):
675684
with pytest.raises(ValueError):
676685
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
677686
test_model = models.Model(_TEST_ID)
687+
test_model._gca_resource.supported_deployment_resources_types.append(
688+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
689+
)
678690
test_endpoint.deploy(model=test_model, traffic_percentage=120, sync=sync)
679691

680692
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@@ -683,6 +695,9 @@ def test_deploy_raise_error_traffic_negative(self, sync):
683695
with pytest.raises(ValueError):
684696
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
685697
test_model = models.Model(_TEST_ID)
698+
test_model._gca_resource.supported_deployment_resources_types.append(
699+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
700+
)
686701
test_endpoint.deploy(model=test_model, traffic_percentage=-18, sync=sync)
687702

688703
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@@ -691,6 +706,9 @@ def test_deploy_raise_error_min_replica(self, sync):
691706
with pytest.raises(ValueError):
692707
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
693708
test_model = models.Model(_TEST_ID)
709+
test_model._gca_resource.supported_deployment_resources_types.append(
710+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
711+
)
694712
test_endpoint.deploy(model=test_model, min_replica_count=-1, sync=sync)
695713

696714
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@@ -699,6 +717,9 @@ def test_deploy_raise_error_max_replica(self, sync):
699717
with pytest.raises(ValueError):
700718
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
701719
test_model = models.Model(_TEST_ID)
720+
test_model._gca_resource.supported_deployment_resources_types.append(
721+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
722+
)
702723
test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync)
703724

704725
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
@@ -707,6 +728,9 @@ def test_deploy_raise_error_traffic_split(self, sync):
707728
with pytest.raises(ValueError):
708729
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
709730
test_model = models.Model(_TEST_ID)
731+
test_model._gca_resource.supported_deployment_resources_types.append(
732+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
733+
)
710734
test_endpoint.deploy(model=test_model, traffic_split={"a": 99}, sync=sync)
711735

712736
@pytest.mark.usefixtures("get_model_mock")
@@ -723,6 +747,9 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
723747

724748
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
725749
test_model = models.Model(_TEST_ID)
750+
test_model._gca_resource.supported_deployment_resources_types.append(
751+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
752+
)
726753
test_endpoint.deploy(model=test_model, traffic_percentage=70, sync=sync)
727754
if not sync:
728755
test_endpoint.wait()
@@ -755,6 +782,9 @@ def test_deploy_with_traffic_split(self, deploy_model_mock, sync):
755782

756783
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
757784
test_model = models.Model(_TEST_ID)
785+
test_model._gca_resource.supported_deployment_resources_types.append(
786+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
787+
)
758788
test_endpoint.deploy(
759789
model=test_model, traffic_split={"model1": 30, "0": 70}, sync=sync
760790
)
@@ -781,6 +811,9 @@ def test_deploy_with_traffic_split(self, deploy_model_mock, sync):
781811
def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync):
782812
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
783813
test_model = models.Model(_TEST_ID)
814+
test_model._gca_resource.supported_deployment_resources_types.append(
815+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
816+
)
784817
test_endpoint.deploy(
785818
model=test_model,
786819
machine_type=_TEST_MACHINE_TYPE,
@@ -821,6 +854,9 @@ def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync):
821854
def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync):
822855
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
823856
test_model = models.Model(_TEST_ID)
857+
test_model._gca_resource.supported_deployment_resources_types.append(
858+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
859+
)
824860
test_endpoint.deploy(
825861
model=test_model,
826862
machine_type=_TEST_MACHINE_TYPE,
@@ -865,6 +901,9 @@ def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, syn
865901
def test_deploy_with_min_replica_count(self, deploy_model_mock, sync):
866902
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
867903
test_model = models.Model(_TEST_ID)
904+
test_model._gca_resource.supported_deployment_resources_types.append(
905+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
906+
)
868907
test_endpoint.deploy(model=test_model, min_replica_count=2, sync=sync)
869908

870909
if not sync:
@@ -889,6 +928,9 @@ def test_deploy_with_min_replica_count(self, deploy_model_mock, sync):
889928
def test_deploy_with_max_replica_count(self, deploy_model_mock, sync):
890929
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
891930
test_model = models.Model(_TEST_ID)
931+
test_model._gca_resource.supported_deployment_resources_types.append(
932+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
933+
)
892934
test_endpoint.deploy(model=test_model, max_replica_count=2, sync=sync)
893935
if not sync:
894936
test_endpoint.wait()

tests/unit/aiplatform/test_models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,10 @@ def test_upload_uploads_and_gets_model_with_custom_location(
825825
def test_deploy(self, deploy_model_mock, sync):
826826

827827
test_model = models.Model(_TEST_ID)
828+
test_model._gca_resource.supported_deployment_resources_types.append(
829+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
830+
)
831+
828832
test_endpoint = models.Endpoint(_TEST_ID)
829833

830834
assert test_model.deploy(test_endpoint, sync=sync,) == test_endpoint
@@ -854,6 +858,9 @@ def test_deploy(self, deploy_model_mock, sync):
854858
def test_deploy_no_endpoint(self, deploy_model_mock, sync):
855859

856860
test_model = models.Model(_TEST_ID)
861+
test_model._gca_resource.supported_deployment_resources_types.append(
862+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
863+
)
857864
test_endpoint = test_model.deploy(sync=sync)
858865

859866
if not sync:
@@ -881,6 +888,9 @@ def test_deploy_no_endpoint(self, deploy_model_mock, sync):
881888
def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):
882889

883890
test_model = models.Model(_TEST_ID)
891+
test_model._gca_resource.supported_deployment_resources_types.append(
892+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
893+
)
884894
test_endpoint = test_model.deploy(
885895
machine_type=_TEST_MACHINE_TYPE,
886896
accelerator_type=_TEST_ACCELERATOR_TYPE,
@@ -919,6 +929,9 @@ def test_deploy_no_endpoint_dedicated_resources(self, deploy_model_mock, sync):
919929
@pytest.mark.parametrize("sync", [True, False])
920930
def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync):
921931
test_model = models.Model(_TEST_ID)
932+
test_model._gca_resource.supported_deployment_resources_types.append(
933+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
934+
)
922935
test_endpoint = test_model.deploy(
923936
machine_type=_TEST_MACHINE_TYPE,
924937
accelerator_type=_TEST_ACCELERATOR_TYPE,
@@ -961,6 +974,9 @@ def test_deploy_no_endpoint_with_explanations(self, deploy_model_mock, sync):
961974
def test_deploy_raises_with_impartial_explanation_spec(self):
962975

963976
test_model = models.Model(_TEST_ID)
977+
test_model._gca_resource.supported_deployment_resources_types.append(
978+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
979+
)
964980

965981
with pytest.raises(ValueError) as e:
966982
test_model.deploy(

tests/unit/aiplatform/test_training_jobs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,9 @@ def mock_model_service_get():
764764
model_service_client.ModelServiceClient, "get_model"
765765
) as mock_get_model:
766766
mock_get_model.return_value = gca_model.Model(name=_TEST_MODEL_NAME)
767+
mock_get_model.return_value.supported_deployment_resources_types.append(
768+
aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
769+
)
767770
yield mock_get_model
768771

769772

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy