
Commit ea2b5cf

feat: Add XAI SDK integration to TensorFlow models with LIT integration (#917)
Automatically adds feature attributions for TensorFlow 2 models in the LIT integration when running on Vertex Notebooks. Vertex Notebooks are detected by checking the same environment variable that the LIT library itself uses for this purpose. Fixes b/210943910 🦕 go/local-explanations-lit-xai-notebook
1 parent 235fbf9 commit ea2b5cf
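
For orientation, here is a minimal usage sketch of the API this commit extends. It is not part of the commit: the DataFrame, feature and label names, LitType choices, and saved-model path are illustrative only; the call shape and the new attribution_method argument follow the signatures shown in the diff below.

import collections

import pandas as pd
from lit_nlp.api import types as lit_types

from google.cloud.aiplatform.explain.lit import set_up_and_open_lit

# Toy tabular data; on a Vertex Notebook (LIT_PROXY_URL set) the model wrapper
# also attaches Sampled Shapley feature attributions via the Explainable AI SDK.
df = pd.DataFrame(
    {"feature_1": [1.0, 2.0], "feature_2": [3.0, 4.0], "label": [0.0, 1.0]}
)
column_types = collections.OrderedDict(
    [
        ("feature_1", lit_types.Scalar()),
        ("feature_2", lit_types.Scalar()),
        ("label", lit_types.RegressionScore()),
    ]
)
input_types = collections.OrderedDict(
    [("feature_1", lit_types.Scalar()), ("feature_2", lit_types.Scalar())]
)
output_types = collections.OrderedDict([("label", lit_types.RegressionScore())])

dataset, model = set_up_and_open_lit(
    df,
    column_types,
    "path/to/saved_model_dir",  # hypothetical single-input, single-output TF2 SavedModel
    input_types,
    output_types,
    attribution_method="sampled_shapley",
)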

3 files changed: +227 -52 lines


google/cloud/aiplatform/explain/lit.py

Lines changed: 127 additions & 40 deletions
@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple, Union
+import logging
+import os
+
+from typing import Dict, List, Optional, Tuple, Union
 
 try:
     from lit_nlp.api import dataset as lit_dataset
+    from lit_nlp.api import dtypes as lit_dtypes
     from lit_nlp.api import model as lit_model
     from lit_nlp.api import types as lit_types
     from lit_nlp import notebook
@@ -82,6 +86,7 @@ def __init__(
         model: str,
         input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
         output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
+        attribution_method: str = "sampled_shapley",
     ):
         """Construct a VertexLitModel.
         Args:
@@ -94,39 +99,33 @@ def __init__(
             output_types:
                 Required. An OrderedDict of string names matching the labels of the model
                 as the key, and the associated LitType of the label.
+            attribution_method:
+                Optional. A string to choose what attribution configuration to
+                set up the explainer with. Valid options are 'sampled_shapley'
+                or 'integrated_gradients'.
         """
-        self._loaded_model = tf.saved_model.load(model)
-        serving_default = self._loaded_model.signatures[
-            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
-        ]
-        _, self._kwargs_signature = serving_default.structured_input_signature
-        self._output_signature = serving_default.structured_outputs
-
-        if len(self._kwargs_signature) != 1:
-            raise ValueError("Please use a model with only one input tensor.")
-
-        if len(self._output_signature) != 1:
-            raise ValueError("Please use a model with only one output tensor.")
-
+        self._load_model(model)
         self._input_types = input_types
         self._output_types = output_types
+        self._input_tensor_name = next(iter(self._kwargs_signature))
+        self._attribution_explainer = None
+        if os.environ.get("LIT_PROXY_URL"):
+            self._set_up_attribution_explainer(model, attribution_method)
+
+    @property
+    def attribution_explainer(self,) -> Optional["AttributionExplainer"]:  # noqa: F821
+        """Gets the attribution explainer property if set."""
+        return self._attribution_explainer
 
     def predict_minibatch(
         self, inputs: List[lit_types.JsonDict]
     ) -> List[lit_types.JsonDict]:
-        """Returns predictions for a single batch of examples.
-        Args:
-            inputs:
-                sequence of inputs, following model.input_spec()
-        Returns:
-            list of outputs, following model.output_spec()
-        """
         instances = []
         for input in inputs:
             instance = [input[feature] for feature in self._input_types]
             instances.append(instance)
         prediction_input_dict = {
-            next(iter(self._kwargs_signature)): tf.convert_to_tensor(instances)
+            self._input_tensor_name: tf.convert_to_tensor(instances)
         }
         prediction_dict = self._loaded_model.signatures[
             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -140,6 +139,15 @@ def predict_minibatch(
                     for label, value in zip(self._output_types.keys(), prediction)
                 }
             )
+        # Get feature attributions
+        if self.attribution_explainer:
+            attributions = self.attribution_explainer.explain(
+                [{self._input_tensor_name: i} for i in instances]
+            )
+            for i, attribution in enumerate(attributions):
+                outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
+                    attribution.feature_importance()
+                )
         return outputs
 
     def input_spec(self) -> lit_types.Spec:
@@ -148,7 +156,70 @@ def input_spec(self) -> lit_types.Spec:
 
     def output_spec(self) -> lit_types.Spec:
         """Return a spec describing model outputs."""
-        return self._output_types
+        output_spec_dict = dict(self._output_types)
+        if self.attribution_explainer:
+            output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
+                signed=True
+            )
+        return output_spec_dict
+
+    def _load_model(self, model: str):
+        """Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
+        Args:
+            model: Required. A string reference to a TensorFlow saved model directory.
+        Raises:
+            ValueError if the model has more than one input tensor or more than one output tensor.
+        """
+        self._loaded_model = tf.saved_model.load(model)
+        serving_default = self._loaded_model.signatures[
+            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+        ]
+        _, self._kwargs_signature = serving_default.structured_input_signature
+        self._output_signature = serving_default.structured_outputs
+
+        if len(self._kwargs_signature) != 1:
+            raise ValueError("Please use a model with only one input tensor.")
+
+        if len(self._output_signature) != 1:
+            raise ValueError("Please use a model with only one output tensor.")
+
+    def _set_up_attribution_explainer(
+        self, model: str, attribution_method: str = "integrated_gradients"
+    ):
+        """Populates the attribution explainer attribute of the class.
+        Args:
+            model: Required. A string reference to a TensorFlow saved model directory.
+            attribution_method:
+                Optional. A string to choose what attribution configuration to
+                set up the explainer with. Valid options are 'sampled_shapley'
+                or 'integrated_gradients'.
+        """
+        try:
+            import explainable_ai_sdk
+            from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
+        except ImportError:
+            logging.info(
+                "Skipping explanations because the Explainable AI SDK is not installed."
+                'Please install the SDK using "pip install explainable-ai-sdk"'
+            )
+            return
+
+        builder = SavedModelMetadataBuilder(model)
+        builder.get_metadata()
+        builder.set_numeric_metadata(
+            self._input_tensor_name,
+            index_feature_mapping=list(self._input_types.keys()),
+        )
+        builder.save_metadata(model)
+        if attribution_method == "integrated_gradients":
+            explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
+        else:
+            explainer_config = explainable_ai_sdk.SampledShapleyConfig()
+
+        self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
+            model, explainer_config
+        )
+        self._load_model(model)
 
 
 def create_lit_dataset(
@@ -172,22 +243,27 @@ def create_lit_model(
     model: str,
     input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
     output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
+    attribution_method: str = "sampled_shapley",
 ) -> lit_model.Model:
     """Creates a LIT Model object.
     Args:
         model:
-            Required. A string reference to a local TensorFlow saved model directory.
-            The model must have at most one input and one output tensor.
+            Required. A string reference to a local TensorFlow saved model directory.
+            The model must have at most one input and one output tensor.
         input_types:
-            Required. An OrderedDict of string names matching the features of the model
-            as the key, and the associated LitType of the feature.
+            Required. An OrderedDict of string names matching the features of the model
+            as the key, and the associated LitType of the feature.
         output_types:
-            Required. An OrderedDict of string names matching the labels of the model
-            as the key, and the associated LitType of the label.
+            Required. An OrderedDict of string names matching the labels of the model
+            as the key, and the associated LitType of the label.
+        attribution_method:
+            Optional. A string to choose what attribution configuration to
+            set up the explainer with. Valid options are 'sampled_shapley'
+            or 'integrated_gradients'.
     Returns:
         A LIT Model object that has the same functionality as the model provided.
     """
-    return _VertexLitModel(model, input_types, output_types)
+    return _VertexLitModel(model, input_types, output_types, attribution_method)
 
 
 def open_lit(
@@ -198,11 +274,11 @@ def open_lit(
     """Open LIT from the provided models and datasets.
     Args:
         models:
-            Required. A list of LIT models to open LIT with.
+            Required. A list of LIT models to open LIT with.
        input_types:
-            Required. A lit of LIT datasets to open LIT with.
+            Required. A lit of LIT datasets to open LIT with.
        open_in_new_tab:
-            Optional. A boolean to choose if LIT open in a new tab or not.
+            Optional. A boolean to choose if LIT open in a new tab or not.
    Raises:
        ImportError if LIT is not installed.
    """
@@ -216,24 +292,31 @@ def set_up_and_open_lit(
     model: Union[str, lit_model.Model],
     input_types: Union[List[str], Dict[str, lit_types.LitType]],
     output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
+    attribution_method: str = "sampled_shapley",
     open_in_new_tab: bool = True,
 ) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
     """Creates a LIT dataset and model and opens LIT.
     Args:
-        dataset:
+        dataset:
             Required. A Pandas DataFrame that includes feature column names and data.
-        column_types:
+        column_types:
             Required. An OrderedDict of string names matching the columns of the dataset
             as the key, and the associated LitType of the column.
-        model:
+        model:
             Required. A string reference to a TensorFlow saved model directory.
             The model must have at most one input and one output tensor.
-        input_types:
+        input_types:
             Required. An OrderedDict of string names matching the features of the model
             as the key, and the associated LitType of the feature.
-        output_types:
+        output_types:
             Required. An OrderedDict of string names matching the labels of the model
             as the key, and the associated LitType of the label.
+        attribution_method:
+            Optional. A string to choose what attribution configuration to
+            set up the explainer with. Valid options are 'sampled_shapley'
+            or 'integrated_gradients'.
+        open_in_new_tab:
+            Optional. A boolean to choose if LIT open in a new tab or not.
     Returns:
         A Tuple of the LIT dataset and model created.
     Raises:
@@ -244,8 +327,12 @@ def set_up_and_open_lit(
     dataset = create_lit_dataset(dataset, column_types)
 
     if not isinstance(model, lit_model.Model):
-        model = create_lit_model(model, input_types, output_types)
+        model = create_lit_model(
+            model, input_types, output_types, attribution_method=attribution_method
+        )
 
-    open_lit({"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab)
+    open_lit(
+        {"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab,
+    )
 
     return dataset, model
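
The Vertex Notebook detection that the commit message refers to is the os.environ.get("LIT_PROXY_URL") check in __init__ above. A standalone restatement of that check as a hypothetical helper (the function name is illustrative and not part of the commit):

import os

def running_on_vertex_notebook() -> bool:
    # Attributions are only set up when the LIT proxy URL is present;
    # this is the same environment variable the LIT library checks on Vertex Notebooks.
    return bool(os.environ.get("LIT_PROXY_URL"))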

setup.py

Lines changed: 6 additions & 1 deletion
@@ -36,7 +36,12 @@
 tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.7.0"]
 metadata_extra_require = ["pandas >= 1.0.0"]
 xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
-lit_extra_require = ["tensorflow >= 2.3.0", "pandas >= 1.0.0", "lit-nlp >= 0.4.0"]
+lit_extra_require = [
+    "tensorflow >= 2.3.0",
+    "pandas >= 1.0.0",
+    "lit-nlp >= 0.4.0",
+    "explainable-ai-sdk >= 1.0.0",
+]
 profiler_extra_require = [
     "tensorboard-plugin-profile >= 2.4.0",
     "werkzeug >= 2.0.0",
