Content-Length: 892176 | pFad | https://www.github.com/googleapis/python-aiplatform/commit/c95d1cebec0a3e2bf6a25a76700d46a42e65376c

B32 feat: Added explain tabular samples (#348) · googleapis/python-aiplatform@c95d1ce · GitHub
Skip to content

Commit c95d1ce

Browse files
authored
feat: Added explain tabular samples (#348)
* Added tabular explanation sample * Cleaned up mocks * Ran linter * Fixed mock and added explanation printing * Added more verbose explanations * Fixed endpoint fixture * Fixed linting issues
1 parent 9245d30 commit c95d1ce

8 files changed

+267
-9
lines changed

samples/model-builder/conftest.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,9 @@ def mock_batch_predict_model(mock_model):
237237

238238

239239
@pytest.fixture
240-
def mock_upload_model():
241-
with patch.object(aiplatform.models.Model, "upload") as mock:
240+
def mock_upload_model(mock_model):
241+
with patch.object(aiplatform.Model, "upload") as mock:
242+
mock.return_value = mock_model
242243
yield mock
243244

244245

@@ -277,7 +278,7 @@ def mock_endpoint():
277278

278279
@pytest.fixture
279280
def mock_create_endpoint():
280-
with patch.object(aiplatform.Endpoint, "create") as mock:
281+
with patch.object(aiplatform.models.Endpoint, "create") as mock:
281282
yield mock
282283

283284

@@ -286,3 +287,10 @@ def mock_get_endpoint(mock_endpoint):
286287
with patch.object(aiplatform, "Endpoint") as mock_get_endpoint:
287288
mock_get_endpoint.return_value = mock_endpoint
288289
yield mock_get_endpoint
290+
291+
292+
@pytest.fixture
293+
def mock_endpoint_explain(mock_endpoint):
294+
with patch.object(mock_endpoint, "explain") as mock_endpoint_explain:
295+
mock_get_endpoint.return_value = mock_endpoint
296+
yield mock_endpoint_explain

samples/model-builder/create_and_import_dataset_video_sample_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from google.cloud.aiplatform import schema
1717

1818
import create_and_import_dataset_video_sample
19-
2019
import test_constants as constants
2120

2221

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_explain_tabular_sample]
21+
def explain_tabular_sample(
22+
project: str, location: str, endpoint_id: str, instance_dict: Dict
23+
):
24+
25+
aiplatform.init(project=project, location=location)
26+
27+
endpoint = aiplatform.Endpoint(endpoint_id)
28+
29+
response = endpoint.explain(instances=[instance_dict], parameters={})
30+
31+
for explanation in response.explanations:
32+
print(" explanation")
33+
# Feature attributions.
34+
attributions = explanation.attributions
35+
for attribution in attributions:
36+
print(" attribution")
37+
print(" baseline_output_value:", attribution.baseline_output_value)
38+
print(" instance_output_value:", attribution.instance_output_value)
39+
print(" output_display_name:", attribution.output_display_name)
40+
print(" approximation_error:", attribution.approximation_error)
41+
print(" output_name:", attribution.output_name)
42+
output_index = attribution.output_index
43+
for output_index in output_index:
44+
print(" output_index:", output_index)
45+
46+
for prediction in response.predictions:
47+
print(prediction)
48+
49+
50+
# [END aiplatform_sdk_explain_tabular_sample]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import explain_tabular_sample
17+
import test_constants as constants
18+
19+
20+
def test_explain_tabular_sample(
21+
mock_sdk_init, mock_endpoint, mock_get_endpoint, mock_endpoint_explain
22+
):
23+
24+
explain_tabular_sample.explain_tabular_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
endpoint_id=constants.ENDPOINT_NAME,
28+
instance_dict=constants.PREDICTION_TABULAR_INSTANCE,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
35+
mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)
36+
37+
mock_endpoint_explain.assert_called_once_with(
38+
instances=[constants.PREDICTION_TABULAR_INSTANCE], parameters={}
39+
)

samples/model-builder/import_data_video_classification_sample_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import pytest
1818

1919
import import_data_video_classification_sample
20-
2120
import test_constants as constants
2221

2322

samples/model-builder/test_constants.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,31 @@
131131
inputs={
132132
"features": {
133133
"input_tensor_name": "dense_input",
134-
"encoding": "BAG_OF_FEATURES",
134+
# Input is tabular data
135135
"modality": "numeric",
136-
"index_feature_mapping": ["abc", "def", "ghj"],
136+
# Assign feature names to the inputs for explanation
137+
"encoding": "BAG_OF_FEATURES",
138+
"index_feature_mapping": [
139+
"crim",
140+
"zn",
141+
"indus",
142+
"chas",
143+
"nox",
144+
"rm",
145+
"age",
146+
"dis",
147+
"rad",
148+
"tax",
149+
"ptratio",
150+
"b",
151+
"lstat",
152+
],
137153
}
138154
},
139-
outputs={"medv": {"output_tensor_name": "dense_2"}},
155+
outputs={"prediction": {"output_tensor_name": "dense_2"}},
140156
)
141157
EXPLANATION_PARAMETERS = aiplatform.explain.ExplanationParameters(
142-
{"sampled_shapley_attribution": {"path_count": 10}}
158+
{"xrai_attribution": {"step_count": 1}}
143159
)
144160

145161
# Endpoint constants
@@ -148,4 +164,16 @@
148164
TRAFFIC_SPLIT = {"a": 99, "b": 1}
149165
MIN_REPLICA_COUNT = 1
150166
MAX_REPLICA_COUNT = 1
167+
ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"
168+
ACCELERATOR_COUNT = 2
151169
ENDPOINT_DEPLOY_METADATA = ()
170+
PREDICTION_TABULAR_INSTANCE = {
171+
"longitude": "-124.35",
172+
"latitude": "40.54",
173+
"housing_median_age": "52.0",
174+
"total_rooms": "1820.0",
175+
"total_bedrooms": "300.0",
176+
"population": "806",
177+
"households": "270.0",
178+
"median_income": "3.014700",
179+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, Optional, Sequence
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_upload_model_explain_tabular_managed_container_sample]
21+
def upload_model_explain_tabular_managed_container_sample(
22+
project,
23+
location,
24+
model_display_name: str,
25+
serving_container_image_uri: str,
26+
artifact_uri: Optional[str] = None,
27+
serving_container_predict_route: Optional[str] = None,
28+
serving_container_health_route: Optional[str] = None,
29+
description: Optional[str] = None,
30+
serving_container_command: Optional[Sequence[str]] = None,
31+
serving_container_args: Optional[Sequence[str]] = None,
32+
serving_container_environment_variables: Optional[Dict[str, str]] = None,
33+
serving_container_ports: Optional[Sequence[int]] = None,
34+
instance_schema_uri: Optional[str] = None,
35+
parameters_schema_uri: Optional[str] = None,
36+
prediction_schema_uri: Optional[str] = None,
37+
explanation_metadata: Optional[aiplatform.explain.ExplanationMetadata] = None,
38+
explanation_parameters: Optional[aiplatform.explain.ExplanationParameters] = None,
39+
sync: bool = True,
40+
):
41+
42+
aiplatform.init(project=project, location=location)
43+
44+
model = aiplatform.Model.upload(
45+
display_name=model_display_name,
46+
serving_container_image_uri=serving_container_image_uri,
47+
artifact_uri=artifact_uri,
48+
serving_container_predict_route=serving_container_predict_route,
49+
serving_container_health_route=serving_container_health_route,
50+
description=description,
51+
serving_container_command=serving_container_command,
52+
serving_container_args=serving_container_args,
53+
serving_container_environment_variables=serving_container_environment_variables,
54+
serving_container_ports=serving_container_ports,
55+
instance_schema_uri=instance_schema_uri,
56+
parameters_schema_uri=parameters_schema_uri,
57+
prediction_schema_uri=prediction_schema_uri,
58+
explanation_metadata=explanation_metadata,
59+
explanation_parameters=explanation_parameters,
60+
sync=sync,
61+
)
62+
63+
model.wait()
64+
65+
print(model.display_name)
66+
print(model.resource_name)
67+
return model
68+
69+
70+
# [END aiplatform_sdk_upload_model_explain_tabular_managed_container_sample]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import test_constants as constants
16+
17+
import upload_model_explain_tabular_managed_container_sample
18+
19+
20+
def test_upload_model_explain_tabular_managed_container_sample(
21+
mock_sdk_init, mock_model, mock_init_model, mock_upload_model
22+
):
23+
24+
upload_model_explain_tabular_managed_container_sample.upload_model_explain_tabular_managed_container_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
model_display_name=constants.MODEL_NAME,
28+
serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI,
29+
artifact_uri=constants.MODEL_ARTIFACT_URI,
30+
serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE,
31+
serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE,
32+
description=constants.DESCRIPTION,
33+
serving_container_command=constants.SERVING_CONTAINER_COMMAND,
34+
serving_container_args=constants.SERVING_CONTAINER_ARGS,
35+
serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
36+
serving_container_ports=constants.SERVING_CONTAINER_PORTS,
37+
instance_schema_uri=constants.INSTANCE_SCHEMA_URI,
38+
parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI,
39+
prediction_schema_uri=constants.PREDICTION_SCHEMA_URI,
40+
explanation_metadata=constants.EXPLANATION_METADATA,
41+
explanation_parameters=constants.EXPLANATION_PARAMETERS,
42+
)
43+
44+
mock_sdk_init.assert_called_once_with(
45+
project=constants.PROJECT, location=constants.LOCATION
46+
)
47+
48+
mock_upload_model.assert_called_once_with(
49+
display_name=constants.MODEL_NAME,
50+
serving_container_image_uri=constants.SERVING_CONTAINER_IMAGE_URI,
51+
artifact_uri=constants.MODEL_ARTIFACT_URI,
52+
serving_container_predict_route=constants.SERVING_CONTAINER_PREDICT_ROUTE,
53+
serving_container_health_route=constants.SERVING_CONTAINER_HEALTH_ROUTE,
54+
description=constants.DESCRIPTION,
55+
serving_container_command=constants.SERVING_CONTAINER_COMMAND,
56+
serving_container_args=constants.SERVING_CONTAINER_ARGS,
57+
serving_container_environment_variables=constants.SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
58+
serving_container_ports=constants.SERVING_CONTAINER_PORTS,
59+
instance_schema_uri=constants.INSTANCE_SCHEMA_URI,
60+
parameters_schema_uri=constants.PARAMETERS_SCHEMA_URI,
61+
prediction_schema_uri=constants.PREDICTION_SCHEMA_URI,
62+
explanation_metadata=constants.EXPLANATION_METADATA,
63+
explanation_parameters=constants.EXPLANATION_PARAMETERS,
64+
sync=True,
65+
)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://www.github.com/googleapis/python-aiplatform/commit/c95d1cebec0a3e2bf6a25a76700d46a42e65376c

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy