Commit ea2b5cf

feat: Add XAI SDK integration to TensorFlow models with LIT integration (#917)
Automatically adds feature attribution for TensorFlow 2 models in the LIT integration on Vertex Notebooks. Vertex Notebooks are detected by checking the same environment variable that the LIT library itself uses. Fixes b/210943910 🦕 go/local-explanations-lit-xai-notebook
1 parent 235fbf9 commit ea2b5cf

File tree

3 files changed: +227 −52 lines changed
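As context for the diffs below, here is a minimal usage sketch of the new attribution_method parameter. It is a sketch only, not code from the commit: the DataFrame, column names, and saved-model path are hypothetical, and the attribution explainer is attached automatically only when the LIT_PROXY_URL environment variable is present (as on Vertex Notebooks) and the Explainable AI SDK is installed.

import collections

import pandas as pd
from lit_nlp.api import types as lit_types

from google.cloud.aiplatform.explain.lit import set_up_and_open_lit

# Hypothetical two-feature regression dataset.
df = pd.DataFrame(
    {"feature_1": [1.0, 2.0], "feature_2": [3.0, 4.0], "label": [0.0, 1.0]}
)
column_types = collections.OrderedDict(
    [
        ("feature_1", lit_types.Scalar()),
        ("feature_2", lit_types.Scalar()),
        ("label", lit_types.RegressionScore()),
    ]
)
input_types = collections.OrderedDict(
    [("feature_1", lit_types.Scalar()), ("feature_2", lit_types.Scalar())]
)
output_types = collections.OrderedDict([("label", lit_types.RegressionScore())])

# attribution_method is the parameter this commit adds; it selects between
# 'sampled_shapley' (the default) and 'integrated_gradients'.
dataset, model = set_up_and_open_lit(
    df,
    column_types,
    "/tmp/saved_model_dir",  # hypothetical saved-model directory
    input_types,
    output_types,
    attribution_method="integrated_gradients",
)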

google/cloud/aiplatform/explain/lit.py

Lines changed: 127 additions & 40 deletions
@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple, Union
+import logging
+import os
+
+from typing import Dict, List, Optional, Tuple, Union
 
 try:
     from lit_nlp.api import dataset as lit_dataset
+    from lit_nlp.api import dtypes as lit_dtypes
     from lit_nlp.api import model as lit_model
     from lit_nlp.api import types as lit_types
     from lit_nlp import notebook
@@ -82,6 +86,7 @@ def __init__(
         model: str,
         input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
         output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
+        attribution_method: str = "sampled_shapley",
     ):
         """Construct a VertexLitModel.
         Args:
@@ -94,39 +99,33 @@ def __init__(
             output_types:
                 Required. An OrderedDict of string names matching the labels of the model
                 as the key, and the associated LitType of the label.
+            attribution_method:
+                Optional. A string to choose what attribution configuration to
+                set up the explainer with. Valid options are 'sampled_shapley'
+                or 'integrated_gradients'.
         """
-        self._loaded_model = tf.saved_model.load(model)
-        serving_default = self._loaded_model.signatures[
-            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
-        ]
-        _, self._kwargs_signature = serving_default.structured_input_signature
-        self._output_signature = serving_default.structured_outputs
-
-        if len(self._kwargs_signature) != 1:
-            raise ValueError("Please use a model with only one input tensor.")
-
-        if len(self._output_signature) != 1:
-            raise ValueError("Please use a model with only one output tensor.")
-
+        self._load_model(model)
         self._input_types = input_types
         self._output_types = output_types
+        self._input_tensor_name = next(iter(self._kwargs_signature))
+        self._attribution_explainer = None
+        if os.environ.get("LIT_PROXY_URL"):
+            self._set_up_attribution_explainer(model, attribution_method)
+
+    @property
+    def attribution_explainer(self,) -> Optional["AttributionExplainer"]:  # noqa: F821
+        """Gets the attribution explainer property if set."""
+        return self._attribution_explainer
 
     def predict_minibatch(
         self, inputs: List[lit_types.JsonDict]
     ) -> List[lit_types.JsonDict]:
-        """Returns predictions for a single batch of examples.
-        Args:
-            inputs:
-                sequence of inputs, following model.input_spec()
-        Returns:
-            list of outputs, following model.output_spec()
-        """
         instances = []
         for input in inputs:
             instance = [input[feature] for feature in self._input_types]
             instances.append(instance)
         prediction_input_dict = {
-            next(iter(self._kwargs_signature)): tf.convert_to_tensor(instances)
+            self._input_tensor_name: tf.convert_to_tensor(instances)
         }
         prediction_dict = self._loaded_model.signatures[
             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -140,6 +139,15 @@ def predict_minibatch(
                 for label, value in zip(self._output_types.keys(), prediction)
             }
         )
+        # Get feature attributions
+        if self.attribution_explainer:
+            attributions = self.attribution_explainer.explain(
+                [{self._input_tensor_name: i} for i in instances]
+            )
+            for i, attribution in enumerate(attributions):
+                outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
+                    attribution.feature_importance()
+                )
         return outputs
 
     def input_spec(self) -> lit_types.Spec:
@@ -148,7 +156,70 @@ def input_spec(self) -> lit_types.Spec:
 
     def output_spec(self) -> lit_types.Spec:
         """Return a spec describing model outputs."""
-        return self._output_types
+        output_spec_dict = dict(self._output_types)
+        if self.attribution_explainer:
+            output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
+                signed=True
+            )
+        return output_spec_dict
+
+    def _load_model(self, model: str):
+        """Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
+        Args:
+            model: Required. A string reference to a TensorFlow saved model directory.
+        Raises:
+            ValueError if the model has more than one input tensor or more than one output tensor.
+        """
+        self._loaded_model = tf.saved_model.load(model)
+        serving_default = self._loaded_model.signatures[
+            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+        ]
+        _, self._kwargs_signature = serving_default.structured_input_signature
+        self._output_signature = serving_default.structured_outputs
+
+        if len(self._kwargs_signature) != 1:
+            raise ValueError("Please use a model with only one input tensor.")
+
+        if len(self._output_signature) != 1:
+            raise ValueError("Please use a model with only one output tensor.")
+
+    def _set_up_attribution_explainer(
+        self, model: str, attribution_method: str = "integrated_gradients"
+    ):
+        """Populates the attribution explainer attribute of the class.
+        Args:
+            model: Required. A string reference to a TensorFlow saved model directory.
+            attribution_method:
+                Optional. A string to choose what attribution configuration to
+                set up the explainer with. Valid options are 'sampled_shapley'
+                or 'integrated_gradients'.
+        """
+        try:
+            import explainable_ai_sdk
+            from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
+        except ImportError:
+            logging.info(
+                "Skipping explanations because the Explainable AI SDK is not installed."
+                'Please install the SDK using "pip install explainable-ai-sdk"'
+            )
+            return
+
+        builder = SavedModelMetadataBuilder(model)
+        builder.get_metadata()
+        builder.set_numeric_metadata(
+            self._input_tensor_name,
+            index_feature_mapping=list(self._input_types.keys()),
+        )
+        builder.save_metadata(model)
+        if attribution_method == "integrated_gradients":
+            explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
+        else:
+            explainer_config = explainable_ai_sdk.SampledShapleyConfig()
+
+        self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
+            model, explainer_config
+        )
+        self._load_model(model)
 
 
 def create_lit_dataset(
@@ -172,22 +243,27 @@ def create_lit_model(
     model: str,
     input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
     output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
+    attribution_method: str = "sampled_shapley",
 ) -> lit_model.Model:
     """Creates a LIT Model object.
     Args:
         model:
-        Required. A string reference to a local TensorFlow saved model directory.
-        The model must have at most one input and one output tensor.
+            Required. A string reference to a local TensorFlow saved model directory.
+            The model must have at most one input and one output tensor.
         input_types:
-        Required. An OrderedDict of string names matching the features of the model
-        as the key, and the associated LitType of the feature.
+            Required. An OrderedDict of string names matching the features of the model
+            as the key, and the associated LitType of the feature.
         output_types:
-        Required. An OrderedDict of string names matching the labels of the model
-        as the key, and the associated LitType of the label.
+            Required. An OrderedDict of string names matching the labels of the model
+            as the key, and the associated LitType of the label.
+        attribution_method:
+            Optional. A string to choose what attribution configuration to
+            set up the explainer with. Valid options are 'sampled_shapley'
+            or 'integrated_gradients'.
     Returns:
         A LIT Model object that has the same functionality as the model provided.
     """
-    return _VertexLitModel(model, input_types, output_types)
+    return _VertexLitModel(model, input_types, output_types, attribution_method)
 
 
 def open_lit(
@@ -198,11 +274,11 @@ def open_lit(
     """Open LIT from the provided models and datasets.
     Args:
         models:
-        Required. A list of LIT models to open LIT with.
+            Required. A list of LIT models to open LIT with.
         input_types:
-        Required. A lit of LIT datasets to open LIT with.
+            Required. A list of LIT datasets to open LIT with.
         open_in_new_tab:
-        Optional. A boolean to choose if LIT open in a new tab or not.
+            Optional. A boolean to choose whether LIT opens in a new tab or not.
     Raises:
         ImportError if LIT is not installed.
     """
@@ -216,24 +292,31 @@ def set_up_and_open_lit(
     model: Union[str, lit_model.Model],
     input_types: Union[List[str], Dict[str, lit_types.LitType]],
     output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
+    attribution_method: str = "sampled_shapley",
     open_in_new_tab: bool = True,
 ) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
     """Creates a LIT dataset and model and opens LIT.
     Args:
-    dataset:
+        dataset:
            Required. A Pandas DataFrame that includes feature column names and data.
-    column_types:
+        column_types:
            Required. An OrderedDict of string names matching the columns of the dataset
            as the key, and the associated LitType of the column.
-    model:
+        model:
            Required. A string reference to a TensorFlow saved model directory.
            The model must have at most one input and one output tensor.
-    input_types:
+        input_types:
            Required. An OrderedDict of string names matching the features of the model
            as the key, and the associated LitType of the feature.
-    output_types:
+        output_types:
            Required. An OrderedDict of string names matching the labels of the model
            as the key, and the associated LitType of the label.
+        attribution_method:
+            Optional. A string to choose what attribution configuration to
+            set up the explainer with. Valid options are 'sampled_shapley'
+            or 'integrated_gradients'.
+        open_in_new_tab:
+            Optional. A boolean to choose whether LIT opens in a new tab or not.
     Returns:
         A Tuple of the LIT dataset and model created.
     Raises:
@@ -244,8 +327,12 @@ def set_up_and_open_lit(
     dataset = create_lit_dataset(dataset, column_types)
 
     if not isinstance(model, lit_model.Model):
-        model = create_lit_model(model, input_types, output_types)
+        model = create_lit_model(
+            model, input_types, output_types, attribution_method=attribution_method
+        )
 
-    open_lit({"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab)
+    open_lit(
+        {"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab,
+    )
 
     return dataset, model
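
To make the behavior of the new explainer path concrete, a small sketch follows. It is not part of the commit; the model directory and type specs are hypothetical placeholders for a single-input, single-output TensorFlow 2 saved model.

import collections

from lit_nlp.api import types as lit_types

from google.cloud.aiplatform.explain.lit import create_lit_model

input_types = collections.OrderedDict([("feature_1", lit_types.Scalar())])
output_types = collections.OrderedDict([("label", lit_types.RegressionScore())])

lit_model_obj = create_lit_model(
    "/tmp/saved_model_dir",  # hypothetical single-input, single-output model
    input_types,
    output_types,
    attribution_method="sampled_shapley",
)

# Outside a Vertex Notebook (no LIT_PROXY_URL), attribution_explainer stays
# None, so output_spec() is exactly output_types. On a Vertex Notebook with
# the Explainable AI SDK installed, output_spec() additionally contains a
# "feature_attribution" entry of lit_types.FeatureSalience(signed=True), and
# predict_minibatch() attaches per-feature saliences to every prediction.
print(lit_model_obj.output_spec())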

setup.py

Lines changed: 6 additions & 1 deletion
@@ -36,7 +36,12 @@
 tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.7.0"]
 metadata_extra_require = ["pandas >= 1.0.0"]
 xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
-lit_extra_require = ["tensorflow >= 2.3.0", "pandas >= 1.0.0", "lit-nlp >= 0.4.0"]
+lit_extra_require = [
+    "tensorflow >= 2.3.0",
+    "pandas >= 1.0.0",
+    "lit-nlp >= 0.4.0",
+    "explainable-ai-sdk >= 1.0.0",
+]
 profiler_extra_require = [
     "tensorboard-plugin-profile >= 2.4.0",
     "werkzeug >= 2.0.0",
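
With this setup.py change, installing the lit extra (for example, pip install "google-cloud-aiplatform[lit]") also pulls in explainable-ai-sdk >= 1.0.0, so the attribution path shown above works without a separate install step on Vertex Notebooks.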
