14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
- from typing import Dict , List , Tuple , Union
17
+ import logging
18
+ import os
19
+
20
+ from typing import Dict , List , Optional , Tuple , Union
18
21
19
22
try :
20
23
from lit_nlp .api import dataset as lit_dataset
24
+ from lit_nlp .api import dtypes as lit_dtypes
21
25
from lit_nlp .api import model as lit_model
22
26
from lit_nlp .api import types as lit_types
23
27
from lit_nlp import notebook
@@ -82,6 +86,7 @@ def __init__(
82
86
model : str ,
83
87
input_types : "OrderedDict[str, lit_types.LitType]" , # noqa: F821
84
88
output_types : "OrderedDict[str, lit_types.LitType]" , # noqa: F821
89
+ attribution_method : str = "sampled_shapley" ,
85
90
):
86
91
"""Construct a VertexLitModel.
87
92
Args:
@@ -94,39 +99,33 @@ def __init__(
94
99
output_types:
95
100
Required. An OrderedDict of string names matching the labels of the model
96
101
as the key, and the associated LitType of the label.
102
+ attribution_method:
103
+ Optional. A string to choose what attribution configuration to
104
+ set up the explainer with. Valid options are 'sampled_shapley'
105
+ or 'integrated_gradients'.
97
106
"""
98
- self ._loaded_model = tf .saved_model .load (model )
99
- serving_default = self ._loaded_model .signatures [
100
- tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY
101
- ]
102
- _ , self ._kwargs_signature = serving_default .structured_input_signature
103
- self ._output_signature = serving_default .structured_outputs
104
-
105
- if len (self ._kwargs_signature ) != 1 :
106
- raise ValueError ("Please use a model with only one input tensor." )
107
-
108
- if len (self ._output_signature ) != 1 :
109
- raise ValueError ("Please use a model with only one output tensor." )
110
-
107
+ self ._load_model (model )
111
108
self ._input_types = input_types
112
109
self ._output_types = output_types
110
+ self ._input_tensor_name = next (iter (self ._kwargs_signature ))
111
+ self ._attribution_explainer = None
112
+ if os .environ .get ("LIT_PROXY_URL" ):
113
+ self ._set_up_attribution_explainer (model , attribution_method )
114
+
115
+ @property
116
+ def attribution_explainer (self ,) -> Optional ["AttributionExplainer" ]: # noqa: F821
117
+ """Returns the stored attribution explainer, or None when no explainer has been set up."""
118
+ return self ._attribution_explainer
113
119
114
120
def predict_minibatch (
115
121
self , inputs : List [lit_types .JsonDict ]
116
122
) -> List [lit_types .JsonDict ]:
117
- """Returns predictions for a single batch of examples.
118
- Args:
119
- inputs:
120
- sequence of inputs, following model.input_spec()
121
- Returns:
122
- list of outputs, following model.output_spec()
123
- """
124
123
instances = []
125
124
for input in inputs :
126
125
instance = [input [feature ] for feature in self ._input_types ]
127
126
instances .append (instance )
128
127
prediction_input_dict = {
129
- next ( iter ( self ._kwargs_signature )) : tf .convert_to_tensor (instances )
128
+ self ._input_tensor_name : tf .convert_to_tensor (instances )
130
129
}
131
130
prediction_dict = self ._loaded_model .signatures [
132
131
tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -140,6 +139,15 @@ def predict_minibatch(
140
139
for label , value in zip (self ._output_types .keys (), prediction )
141
140
}
142
141
)
142
+ # Get feature attributions
143
+ if self .attribution_explainer :
144
+ attributions = self .attribution_explainer .explain (
145
+ [{self ._input_tensor_name : i } for i in instances ]
146
+ )
147
+ for i , attribution in enumerate (attributions ):
148
+ outputs [i ]["feature_attribution" ] = lit_dtypes .FeatureSalience (
149
+ attribution .feature_importance ()
150
+ )
143
151
return outputs
144
152
145
153
def input_spec (self ) -> lit_types .Spec :
@@ -148,7 +156,70 @@ def input_spec(self) -> lit_types.Spec:
148
156
149
157
def output_spec (self ) -> lit_types .Spec :
150
158
"""Return a spec describing model outputs.

The configured output types are copied into a new dict. When an attribution
explainer is set, a signed FeatureSalience entry is added under the
"feature_attribution" key.
"""
151
- return self ._output_types
159
+ output_spec_dict = dict (self ._output_types )
160
+ if self .attribution_explainer :
161
+ output_spec_dict ["feature_attribution" ] = lit_types .FeatureSalience (
162
+ signed = True
163
+ )
164
+ return output_spec_dict
165
+
166
+ def _load_model (self , model : str ):
167
+ """Loads a TensorFlow saved model and stores the keyword-input and output signatures of its default serving signature on the instance.
168
+ Args:
169
+ model: Required. A string reference to a TensorFlow saved model directory.
170
+ Raises:
171
+ ValueError if the default serving signature does not have exactly one input tensor and exactly one output tensor.
172
+ """
173
+ self ._loaded_model = tf .saved_model .load (model )
174
+ serving_default = self ._loaded_model .signatures [
175
+ tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY
176
+ ]
177
+ _ , self ._kwargs_signature = serving_default .structured_input_signature
178
+ self ._output_signature = serving_default .structured_outputs
179
+
180
+ if len (self ._kwargs_signature ) != 1 :
181
+ raise ValueError ("Please use a model with only one input tensor." )
182
+
183
+ if len (self ._output_signature ) != 1 :
184
+ raise ValueError ("Please use a model with only one output tensor." )
185
+
186
+ def _set_up_attribution_explainer (
187
+ self , model : str , attribution_method : str = "integrated_gradients"
188
+ ):
189
+ """Populates the attribution explainer attribute of the class.
190
+ Args:
191
+ model: Required. A string reference to a TensorFlow saved model directory.
192
+ attribution_method:
193
+ Optional. A string to choose what attribution configuration to
194
+ set up the explainer with. Valid options are 'sampled_shapley'
195
+ or 'integrated_gradients'.
196
+ """
197
+ try :
198
+ import explainable_ai_sdk
199
+ from explainable_ai_sdk .metadata .tf .v2 import SavedModelMetadataBuilder
200
+ except ImportError :
201
+ logging .info (
202
+ "Skipping explanations because the Explainable AI SDK is not installed."
203
+ 'Please install the SDK using "pip install explainable-ai-sdk"'
204
+ )
205
+ return
206
+
207
+ builder = SavedModelMetadataBuilder (model )
208
+ builder .get_metadata ()
209
+ builder .set_numeric_metadata (
210
+ self ._input_tensor_name ,
211
+ index_feature_mapping = list (self ._input_types .keys ()),
212
+ )
213
+ builder .save_metadata (model )
214
+ if attribution_method == "integrated_gradients" :
215
+ explainer_config = explainable_ai_sdk .IntegratedGradientsConfig ()
216
+ else :
217
+ explainer_config = explainable_ai_sdk .SampledShapleyConfig ()
218
+
219
+ self ._attribution_explainer = explainable_ai_sdk .load_model_from_local_path (
220
+ model , explainer_config
221
+ )
222
+ self ._load_model (model )
152
223
153
224
154
225
def create_lit_dataset (
@@ -172,22 +243,27 @@ def create_lit_model(
172
243
model : str ,
173
244
input_types : "OrderedDict[str, lit_types.LitType]" , # noqa: F821
174
245
output_types : "OrderedDict[str, lit_types.LitType]" , # noqa: F821
246
+ attribution_method : str = "sampled_shapley" ,
175
247
) -> lit_model .Model :
176
248
"""Creates a LIT Model object.
177
249
Args:
178
250
model:
179
- Required. A string reference to a local TensorFlow saved model directory.
180
- The model must have at most one input and one output tensor.
251
+ Required. A string reference to a local TensorFlow saved model directory.
252
+ The model must have exactly one input tensor and one output tensor.
181
253
input_types:
182
- Required. An OrderedDict of string names matching the features of the model
183
- as the key, and the associated LitType of the feature.
254
+ Required. An OrderedDict of string names matching the features of the model
255
+ as the key, and the associated LitType of the feature.
184
256
output_types:
185
- Required. An OrderedDict of string names matching the labels of the model
186
- as the key, and the associated LitType of the label.
257
+ Required. An OrderedDict of string names matching the labels of the model
258
+ as the key, and the associated LitType of the label.
259
+ attribution_method:
260
+ Optional. A string to choose what attribution configuration to
261
+ set up the explainer with. Valid options are 'sampled_shapley'
262
+ or 'integrated_gradients'.
187
263
Returns:
188
264
A LIT Model object that has the same functionality as the model provided.
189
265
"""
190
- return _VertexLitModel (model , input_types , output_types )
266
+ return _VertexLitModel (model , input_types , output_types , attribution_method )
191
267
192
268
193
269
def open_lit (
@@ -198,11 +274,11 @@ def open_lit(
198
274
"""Open LIT from the provided models and datasets.
199
275
Args:
200
276
models:
201
- Required. A list of LIT models to open LIT with.
277
+ Required. A list of LIT models to open LIT with.
202
278
input_types:
203
- Required. A lit of LIT datasets to open LIT with.
279
+ Required. A list of LIT datasets to open LIT with.
204
280
open_in_new_tab:
205
- Optional. A boolean to choose if LIT open in a new tab or not.
281
+ Optional. A boolean to choose if LIT open in a new tab or not.
206
282
Raises:
207
283
ImportError if LIT is not installed.
208
284
"""
@@ -216,24 +292,31 @@ def set_up_and_open_lit(
216
292
model : Union [str , lit_model .Model ],
217
293
input_types : Union [List [str ], Dict [str , lit_types .LitType ]],
218
294
output_types : Union [str , List [str ], Dict [str , lit_types .LitType ]],
295
+ attribution_method : str = "sampled_shapley" ,
219
296
open_in_new_tab : bool = True ,
220
297
) -> Tuple [lit_dataset .Dataset , lit_model .Model ]:
221
298
"""Creates a LIT dataset and model and opens LIT.
222
299
Args:
223
- dataset:
300
+ dataset:
224
301
Required. A Pandas DataFrame that includes feature column names and data.
225
- column_types:
302
+ column_types:
226
303
Required. An OrderedDict of string names matching the columns of the dataset
227
304
as the key, and the associated LitType of the column.
228
- model:
305
+ model:
229
306
Required. A string reference to a TensorFlow saved model directory.
230
307
The model must have at most one input and one output tensor.
231
- input_types:
308
+ input_types:
232
309
Required. An OrderedDict of string names matching the features of the model
233
310
as the key, and the associated LitType of the feature.
234
- output_types:
311
+ output_types:
235
312
Required. An OrderedDict of string names matching the labels of the model
236
313
as the key, and the associated LitType of the label.
314
+ attribution_method:
315
+ Optional. A string to choose what attribution configuration to
316
+ set up the explainer with. Valid options are 'sampled_shapley'
317
+ or 'integrated_gradients'.
318
+ open_in_new_tab:
319
+ Optional. A boolean to choose if LIT open in a new tab or not.
237
320
Returns:
238
321
A Tuple of the LIT dataset and model created.
239
322
Raises:
@@ -244,8 +327,12 @@ def set_up_and_open_lit(
244
327
dataset = create_lit_dataset (dataset , column_types )
245
328
246
329
if not isinstance (model , lit_model .Model ):
247
- model = create_lit_model (model , input_types , output_types )
330
+ model = create_lit_model (
331
+ model , input_types , output_types , attribution_method = attribution_method
332
+ )
248
333
249
- open_lit ({"model" : model }, {"dataset" : dataset }, open_in_new_tab = open_in_new_tab )
334
+ open_lit (
335
+ {"model" : model }, {"dataset" : dataset }, open_in_new_tab = open_in_new_tab ,
336
+ )
250
337
251
338
return dataset , model
0 commit comments