-
Notifications
You must be signed in to change notification settings - Fork 277
/
config.proto
565 lines (529 loc) · 25.5 KB
/
config.proto
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package tensorflow_model_analysis;
import "google/protobuf/wrappers.proto";
// Model specification.
//
// Describes a single model taking part in an evaluation: the alias it is
// reported under, its type and signature, and the keys used to locate labels,
// predictions and example weights.
message ModelSpec {
// Name used to distinguish different models when multiple instances are being
// evaluated. Note that this name is not necessarily the name of the model as
// seen by a trainer, etc. This name is more of an alias for both a model name
// and a particular version and/or format. For example, common names to use
// here might be "candidate" or "baseline" when referring to different
// versions of the same model that are being evaluated for the purpose of
// model validation. Note also that if only a single ModelSpec is used in the
// config, then no model_name will be set in any metrics keys that are output
// regardless of whether a name was provided here or not.
string name = 2;
// The type of the model that is being evaluated. Supported types include
// "tf_keras", "tf_estimator", "tf_lite", "tf_js", and "tf_generic". If unset,
// automatically detects whether the model_type is "tf_keras", "tf_estimator",
// or "tf_generic" based on whether the model loads as a keras model followed
// by whether or not the signature_name is set to "eval".
string model_type = 12;
// Optional name of signature to use for inference (e.g. "serving_default").
// For estimator based EvalSavedModels, this must be set to "eval". If not
// set, then the default depends on the model_type. For "tf_keras" models the
// model itself will be used for inference. For models that support signatures
// ("tf_generic", etc) "predict" (if it exists) or "serving_default" will be
// assumed. For models that don't use signatures ("tf_lite", etc) this setting
// will be ignored.
string signature_name = 3;
// Optional names of preprocessing functions to run in the order that they
// should be invoked. Preprocessing functions are used to transform the
// features into the form required for inference and metrics evaluation. The
// output from preprocessing can also be used for slicing. Preprocessing
// functions can be saved as signatures or as attributes on the saved model.
// If no names are provided, the names "transformed_features" and
// "transformed_labels" will be searched for.
//
// The output of a preprocessing function will override the feature with the
// same name for label or example weight extraction purposes. If a
// preprocessing function outputs a non-dict value, then it will be stored as
// a feature under the preprocessing function name itself. For example, if a
// function called "transformed_labels" outputs a single array value then it
// will be associated with the feature name "transformed_labels". This name
// can be used when setting the "label_key" or in slicing configs.
repeated string preprocessing_function_names = 13;
// Label key (single-output model). The key can identify either a transformed
// feature (see preprocessing_function_names) or a raw input feature. Use one
// of label_key or label_keys.
string label_key = 5;
// Label keys (multi-output model) keyed by output_name. If all the outputs
// for a multi-output model use the same key, then a single key may also be
// used. Use one of label_key or label_keys.
map<string, string> label_keys = 6; // oneof not allowed with maps
// Optional prediction key (single-output model). The prediction key is used
// to distinguish between different values when the output from the predict
// call is a dict instead of a single tensor as is the case with
// tf.Estimators. If not set and the prediction is a dict, the keys
// 'logistic', 'predictions', or 'probabilities' are tried (in that order).
// The prediction key is also used in cases where the predictions are
// pre-calculated and stored alongside the features. In this case the
// prediction key refers to a key in the features dictionary. Use one of
// prediction_key or prediction_keys.
//
// Note that prediction_key is NOT the same as the output_name used in the
// MetricsSpec. The output_name refers to the name of an output for a
// multi-output model (for tf.Estimator's this is called the "head" whereas
// for keras the term output is used). Some outputs (typically tf.Estimator)
// are themselves made up of a dict of multiple tensors (e.g. 'classes',
// 'probabilities', etc). The prediction_key specifies which key in the output
// contains the prediction values (i.e. 'probabilities', etc). For example,
// a tf.Estimator model might output the following:
//
//   {
//     'head1': {
//       'classes': classes_tensor,
//       'class_ids': class_ids_tensor,
//       'logits': logits_tensor,
//       'probabilities': probabilities_tensor
//     },
//     'head2': {
//       'classes': classes_tensor,
//       'class_ids': class_ids_tensor,
//       'logits': logits_tensor,
//       'probabilities': probabilities_tensor
//     }
//   }
//
// Here 'head1' or 'head2' would be the output_name, whereas 'probabilities'
// would be the prediction_key.
//
// If a model is not being used (i.e. the predictions are stored in the
// inputs), then a prediction_key is required. However, if a model is being
// used to compute the predictions a prediction_key typically is not needed
// since all the default tf.Estimator values for the keys are handled by TFMA
// and keras does not return multiple tensors for a given output.
string prediction_key = 7;
// Optional prediction keys (multi-output model) keyed by output_name. Use one
// of prediction_key or prediction_keys. See comment under prediction_key on
// the difference between output_name and prediction_key.
map<string, string> prediction_keys = 8; // oneof not allowed with maps
// Optional example weight key (single-output model). The example_weight_key
// can identify either a transformed feature (see
// preprocessing_function_names) or raw input feature. Use one of
// example_weight_key or example_weight_keys.
string example_weight_key = 9;
// Optional example weight keys (multi-output model) keyed by output_name. If
// all the outputs for a multi-output model use the same key, then a single
// key may also be used. Use one of example_weight_key or example_weight_keys.
map<string, string> example_weight_keys = 10; // oneof not allowed with maps
// True if baseline model (otherwise candidate). Only one baseline is allowed
// per evaluation run.
bool is_baseline = 11;
// Options for padding prediction and label arrays before feeding to metrics.
//
// Predictions and labels may not have the same length (for example, the
// model may pad the predictions so that a batch of predictions is aligned,
// while labels are extracted from the input and are not padded) while metrics
// may require them to be of the same length. TFMA can pad the shorter one
// with the configured values.
PaddingOptions padding_options = 14;
// Field numbers 1 and 4 are reserved and must not be reused (presumably they
// belonged to fields that have since been removed).
reserved 1, 4;
}
// Slicing specification.
//
// A slice can be described either with feature_keys and/or feature_values, or
// alternatively with a SQL expression in slice_keys_sql.
message SlicingSpec {
// Feature keys to slice on.
//
// Note that the feature key can be either a transformed feature key (see
// ModelSpec.preprocessing_function_names) or a raw feature key parsed
// directly from the inputs. If a transformed feature key and raw feature key
// use the same name, the transformed feature will take precedence. Note also
// that while transformed features are associated with the models that
// processed them, when it comes to slicing all the unique values across all
// models will be used.
repeated string feature_keys = 1;
// Feature values to slice on keyed by associated feature keys.
//
// The same caveats that apply to feature_keys with respect to feature
// transformations and raw features apply to feature_values as well (see
// feature_keys for more information).
//
// Note that strings representing ints and floats will be automatically
// converted to ints and floats respectively and will be compared against both
// the string versions and int or float versions of the associated features.
map<string, string> feature_values = 2;
// This config is an alternative to the config above.
// It must have the pattern:
//   "SELECT STRUCT({feature_name} [AS {slice_key}])
//    [FROM example.feature_name [, example.feature_name, ... ]
//    [WHERE ... ]]"
//
// The "example.feature_name" inside the FROM statement is used to flatten the
// repeated fields. For non-repeated fields, you can directly write the config
// as follows:
//   "SELECT STRUCT(non_repeated_feature_a, non_repeated_feature_b)".
//
// When executing, this SQL expression will be further wrapped as:
//   "SELECT ARRAY({slice_keys_sql}) as slices FROM Examples as example".
//
// The resulting output of the query will have the same number of rows as the
// input dataset. Each row will only have one column named "slices". Each row
// is a list. Each element in the list will be a list of ('key', 'value')
// tuples representing a slice. For example, a single row could be:
//   [[('gender', 'male'), ('country', 'USA')], [('zip_code', '123456')]]
//
// In the user's SQL statement, "example" is a keyword that binds to each
// input "row". The semantics of this variable will depend on the decoding of
// the input data to the Arrow representation (e.g., for tf.Example, each key
// is decoded to a separate column). Thus, structured data can be readily
// accessed by iterating/unnesting the fields of the "example" variable.
//
// Example 1:
//   slice_keys_sql="SELECT STRUCT(gender) FROM example.gender"
//   - This is equivalent to the config: feature_keys=[gender]
//   - the slice key and value will be: (gender, {gender_value})
//
// Example 2:
//   slice_keys_sql =
//     "SELECT STRUCT(gender, country)
//      FROM example.gender, example.country
//      WHERE country = 'USA'"
//   - This is equivalent to the config:
//       feature_keys=[gender], feature_values={country:'USA'}
//   - the slice key and value will be:
//       (gender_x_country, {gender_value}_x_USA)
//
// Example 3 (background positive subgroup negative):
//   slice_keys_sql=
//     "SELECT STRUCT('male' as bpsn)
//      FROM example
//      WHERE ('male' not in UNNEST(example.gender) and 1 in
//             UNNEST(example.label)) or
//            ('male' in UNNEST(example.gender) and 0 in
//             UNNEST(example.label))"
//   - the slice key and value will be: (bpsn, male)
string slice_keys_sql = 3;
}
// Cross slicing specification.
//
// Pairs one baseline slice with a list of other slices. See
// EvalConfig.cross_slicing_specs, which describes each spec as a pair of
// slices whose associated outputs should be compared.
message CrossSlicingSpec {
// Slice acting as the baseline side of the comparison.
// NOTE(review): role inferred from the field name - confirm.
SlicingSpec baseline_spec = 1;
// Slices on the other side of the comparison.
// NOTE(review): presumably each entry is compared against baseline_spec
// individually - confirm against the implementation.
repeated SlicingSpec slicing_specs = 2;
}
// Options for aggregating multi-class / multi-label outputs.
//
// When used the associated MetricSpec metrics must be binary classification
// metrics (NOT multi-class classification metrics).
message AggregationOptions {
// Averaging mode. The modes are members of a oneof, so at most one of them
// can be set at a time.
oneof type {
// Compute aggregate metrics by treating all examples as equal (i.e.
// flatten the prediction/label pairs across all classes and perform the
// computation as if they were separate examples in a binary classification
// problem). Micro is typically used with multi-class outputs.
bool micro_average = 1;
// Computes aggregate metrics by treating all classes as equal (i.e.
// compute binary classification metrics separately for each of the classes
// and then take the average). This approach is good for the case where each
// class is equally important and/or class labels distribution is balanced.
// Macro is typically used with multi-label outputs.
//
// If macro averaging is enabled without using top_k_list, class_weights
// must be configured in order to identify which classes the average will be
// computed for.
bool macro_average = 2;
// Compute aggregate metrics using macro averaging but weight the classes
// during aggregation by the ratio of positive labels for each class.
//
// If weighted macro averaging is enabled without using top_k_list,
// class_weights must be configured in order to identify which classes the
// average will be computed for.
bool weighted_macro_average = 3;
}
// Weights to apply to classes during aggregation (only supported if
// top_k_list is not used). Each key corresponds to a class ID. For micro
// aggregation the weights will be applied to each prediction/label pair. For
// macro aggregation the weights will be applied to the overall metric
// computed for each class prior to aggregation.
//
// If class_weights are configured, but some keys are not provided then their
// weights will be assumed to be 0.0. This allows the class_weights to be used
// to filter the classes used for aggregation.
//
// Note that for macro_average and weighted_macro_average when the top_k_list
// is not used, the class_weights are required. Also note that when used with
// weighted_macro_average, weights will be applied in two forms (from the
// ratio of positive labels and from the values provided here) which may or
// may not be desired (i.e. setting all the weights to 1.0 is the most common
// configuration for weighted_macro_average).
map<int32, float> class_weights = 4;
// Performs aggregation based on the classes with the top k predicted values
// for each value of top k provided. If not set then all classes are used.
// Note that unlike the top k used with binarization this truncates the list
// of classes to only the top k values (i.e. it does not set non-top k to
// -inf).
RepeatedInt32Value top_k_list = 5;
}
// Options for binarizing multi-class / multi-label outputs.
//
// When used the associated MetricSpec metrics must be binary classification
// metrics (NOT multi-class classification metrics).
message BinarizationOptions {
// Creates binary classification metrics based on one-vs-rest for each
// value of class_id provided.
RepeatedInt32Value class_ids = 4;
// Creates binary classification metrics based on the kth predicted value
// for each value of k provided.
RepeatedInt32Value k_list = 5;
// Creates binary classification metrics based on the top k predicted values
// for each value of top k provided. How this is computed is up to each metric
// implementation. However, the default implementation is such that for a
// given top k setting, the input prediction arrays will be updated to set the
// non-top k predictions to -inf before flattening the resulting array into a
// single binarized value. This makes top k well suited to calculations such
// as precision@k or recall@k, but may not be well suited for other binary
// classification metrics unless special handling is provided. Note that
// precision@k and recall@k can also be configured directly as multi-class
// classification metrics by setting top_k on the metric itself.
RepeatedInt32Value top_k_list = 6;
// Field numbers 1, 2 and 3 are reserved and must not be reused (presumably
// they belonged to fields that have since been removed).
reserved 1, 2, 3;
}
// Options for padding prediction and label arrays before feeding to metrics.
//
// Predictions and labels may not have the same length (for example, the
// model may pad the predictions so that a batch of predictions is aligned,
// while labels are extracted from the input and are not padded) while metrics
// may require them to be of the same length. TFMA can pad the shorter one
// with the configured values.
message PaddingOptions {
// Value used to pad labels, given either as an integer or as a float.
// If neither member of the oneof is set, 0 will be used.
oneof label_padding {
int64 label_int_padding = 1;
float label_float_padding = 2;
}
// Value used to pad predictions, given either as an integer or as a float.
// If neither member of the oneof is set, 0 will be used.
oneof prediction_padding {
int64 prediction_int_padding = 3;
float prediction_float_padding = 4;
}
}
// Determines how to compute the deltas.
// This might be used in the following ways:
// - For metrics where a higher value is better (e.g. AUC),
//   set the direction to HIGHER_IS_BETTER, and set the threshold to be
//   the "slack", so we fail if say auc_new - auc_old < slack.
// - For metrics where a lower value is better (e.g. loss),
//   set the direction to LOWER_IS_BETTER, and set the threshold to be the
//   "slack", so we fail if say loss_new - loss_old > slack.
//
// NOTE(review): the values are not prefixed with the enum name. Enum values
// are scoped as siblings of the enum type, so no other top-level enum in this
// package can declare a value named UNKNOWN. Renaming the values would break
// generated code and text/JSON formats, so they are left as is.
enum MetricDirection {
// Zero value, i.e. what an unset direction field reads as.
UNKNOWN = 0;
// A lower metric value is better (e.g. loss).
LOWER_IS_BETTER = 1;
// A higher metric value is better (e.g. AUC).
HIGHER_IS_BETTER = 2;
}
// Generic change threshold message.
//
// Bounds the change (delta) of a metric between a new and an old value. How
// the delta is computed and compared depends on `direction`; see the comments
// on MetricDirection.
message GenericChangeThreshold {
// Let delta be determined as described in the comments on MetricDirection.
// If delta > absolute, fail the validation.
google.protobuf.DoubleValue absolute = 1;
// Let delta be determined as described in the comments on MetricDirection.
// If delta / X_old > relative, fail the validation. X_old denotes the old
// metric value (cf. auc_old / loss_old in the MetricDirection comments).
google.protobuf.DoubleValue relative = 2;
// Whether a higher or a lower metric value counts as better.
MetricDirection direction = 3;
}
// Generic value threshold message.
// Fail the validation if the value does not lie in [lower_bound,
// upper_bound], both boundaries inclusive.
message GenericValueThreshold {
// Lower bound. Assumed to be -Infinity if not set. A wrapper type is used so
// that an unset bound can be told apart from an explicit 0.0.
google.protobuf.DoubleValue lower_bound = 1;
// Upper bound. Assumed to be +Infinity if not set. A wrapper type is used so
// that an unset bound can be told apart from an explicit 0.0.
google.protobuf.DoubleValue upper_bound = 2;
}
// Thresholds a metric is validated against.
//
// Holds an optional bound on the metric value itself and an optional bound on
// its change. Each field lives in its own single-member oneof, so the two can
// be set independently of each other.
message MetricThreshold {
// Validation of the metric value itself.
oneof validate_absolute {
// Inclusive [lower_bound, upper_bound] range the value must lie in.
GenericValueThreshold value_threshold = 1;
}
// Validation of the change in the metric.
oneof validate_relative {
// Bound on the delta of the metric; see GenericChangeThreshold.
GenericChangeThreshold change_threshold = 2;
}
}
// A metric threshold together with the slices it applies to.
message PerSliceMetricThreshold {
// A list of slicing specs to apply threshold to. An empty SlicingSpec
// represents the overall slice.
//
// NOTE: These are only references to slice definitions not new definitions.
// Slices must have been defined using EvalConfig.slicing_specs.
//
// See EvalConfig.slicing_specs for examples.
repeated SlicingSpec slicing_specs = 1;
// Threshold to apply to the slices listed in slicing_specs.
MetricThreshold threshold = 2;
}
// List of per-slice metric thresholds.
//
// Used as the value type of MetricsSpec.per_slice_thresholds: map values
// cannot be `repeated`, so the list is wrapped in a message.
message PerSliceMetricThresholds {
// The per-slice thresholds.
repeated PerSliceMetricThreshold thresholds = 1;
}
// Cross slice metric threshold.
//
// A metric threshold together with the cross slicing specs it applies to.
message CrossSliceMetricThreshold {
// A list of cross slicing specs to apply threshold to.
repeated CrossSlicingSpec cross_slicing_specs = 1;
// Threshold to apply to the cross slices listed in cross_slicing_specs.
MetricThreshold threshold = 2;
}
// List of cross-slice metric thresholds.
//
// Used as the value type of MetricsSpec.cross_slice_thresholds: map values
// cannot be `repeated`, so the list is wrapped in a message.
message CrossSliceMetricThresholds {
// The cross-slice thresholds.
repeated CrossSliceMetricThreshold thresholds = 1;
}
// Metric configuration.
//
// Identifies a metric class, the settings to construct it with, and optional
// validation thresholds for the resulting metric.
message MetricConfig {
// Name of a class derived from either tf.keras.metrics.Metric or
// tfma.metrics.Metric.
string class_name = 1;
// Optional name of module associated with class_name. If not set then class
// will be searched for under tfma.metrics followed by tf.keras.metrics.
string module = 2;
// Optional JSON encoded config settings associated with the class.
//
// The config settings will be passed as **kwarg values to the __init__ method
// for the class. For ease of use the leading and trailing '{' and '}'
// brackets may be omitted.
//
// Example: '"name": "my_metric", "thresholds": [0.5]'
string config = 3;
// Optional threshold for model validation on all slices.
MetricThreshold threshold = 4;
// Optional thresholds for model validation using specific slices.
repeated PerSliceMetricThreshold per_slice_thresholds = 5;
// Optional thresholds for model validation across slices.
repeated CrossSliceMetricThreshold cross_slice_thresholds = 6;
}
// Metrics specification.
//
// Groups a list of metric configurations with the models and outputs they are
// computed for, plus optional binarization, aggregation and thresholds.
message MetricsSpec {
// List of metric configurations.
repeated MetricConfig metrics = 1;
// Names of models (as defined by model_specs) the metrics should be
// calculated for. If this list is empty then all the names defined in
// the model_specs will be assumed else these metrics will only be
// computed for the model names provided.
repeated string model_names = 2;
// Optional names of outputs the metrics should be calculated for (for
// multi-output models). See comment under the ModelSpec.prediction_key on the
// difference between output_name and prediction_key.
repeated string output_names = 3;
// Optional weights to use when aggregating across outputs. Output aggregation
// will only be performed when weights are configured and only between outputs
// that have a weight set. For example, assume metrics contains 'auc' and the
// following output information was configured:
//   output_names = ['output_1', 'output_2', 'output_3']
//   output_weights = {'output_1': 1.0, 'output_2': 1.0}
// An 'auc' metric will be computed for each output along with an overall auc
// metric calculated as (1.0*(auc output_1) + 1.0*(auc output_2)) / (1.0+1.0).
map<string, float> output_weights = 10;
// Optional binarization options for converting multi-class / multi-label
// model outputs into outputs suitable for binary classification metrics.
BinarizationOptions binarize = 4;
// Optional aggregation options for computing overall aggregate metrics for
// multi-class / multi-label model outputs. Aggregation options are computed
// separately from binarization options so both can be set safely at the same
// time.
AggregationOptions aggregate = 6;
// Optional query key for query/ranking based metrics.
string query_key = 5;
// Thresholds defined here are intended to be used for metrics that were
// saved with the model and computed by default without requiring a metric
// config. All other thresholds should be defined in the MetricConfig
// associated with the metric.
//
// Optional thresholds for model validation on all slices (keyed by the
// associated metric name - e.g. 'auc', etc).
map<string, MetricThreshold> thresholds = 7;
// Optional thresholds for model validation using specific slices (keyed by
// the associated metric name - e.g. 'auc', etc).
map<string, PerSliceMetricThresholds> per_slice_thresholds = 8;
// Optional thresholds for model validation across slices (keyed by the
// associated metric name - e.g. 'auc', etc).
map<string, CrossSliceMetricThresholds> cross_slice_thresholds = 9;
}
// Additional configuration options.
//
// Scalar settings use wrapper types (BoolValue, Int32Value) so that an unset
// option can be told apart from an explicit false / 0.
message Options {
// True to include metrics saved with the model(s) (where possible) when
// calculating metrics. Any metrics defined in metrics_specs will override the
// metrics defined in the model if there are overlapping names.
google.protobuf.BoolValue include_default_metrics = 1;
// True to calculate confidence intervals.
google.protobuf.BoolValue compute_confidence_intervals = 2;
// Settings for the confidence interval computation, such as the method to
// use (see ConfidenceIntervalOptions).
// NOTE(review): presumably only consulted when compute_confidence_intervals
// is true - confirm.
ConfidenceIntervalOptions confidence_intervals = 9;
// Int value to omit slices with example count < min_slice_size.
google.protobuf.Int32Value min_slice_size = 3;
// List of outputs that should not be written (e.g. 'metrics', 'plots',
// 'analysis', 'eval_config.json').
RepeatedStringValue disabled_outputs = 7;
// Field numbers 4, 5, 6 and 8 are reserved and must not be reused (presumably
// they belonged to fields that have since been removed).
reserved 4, 5, 6, 8;
}
// Options for computing confidence intervals.
message ConfidenceIntervalOptions {
// Methods available for computing confidence intervals.
enum ConfidenceIntervalMethod {
// Zero value, i.e. what an unset method field reads as.
UNKNOWN_CONFIDENCE_INTERVAL_METHOD = 0;
// Selects the Poisson bootstrap method.
POISSON_BOOTSTRAP = 1;
// Selects the jackknife method.
JACKKNIFE = 2;
}
// The confidence interval method to use for all metrics.
ConfidenceIntervalMethod method = 1;
}
// Tensorflow model analysis config settings.
message EvalConfig {
// Model specifications for models used. Only one baseline is permitted.
repeated ModelSpec model_specs = 2;
// A list of specs where each spec represents a way to slice the data. An
// empty config means slice on overall data.
//
// Example usages:
//   - slicing_specs: {}
//       Slice consisting of overall data.
//   - slicing_specs: { feature_keys: ["country"] }
//       Slices for all values in feature "country". For example, we might get
//       slices "country:us", "country:jp", etc.
//   - slicing_specs: { feature_values: [{key: "country", value: "us"}] }
//       Slice consisting of "country:us".
//   - slicing_specs: { feature_keys: ["country", "city"] }
//       Slices for all values in feature "country" crossed with
//       all values in feature "city" (note this may be expensive).
//   - slicing_specs: { feature_keys: ["country"]
//                      feature_values: [{key: "age", value: "20"}] }
//       Slices for all values in feature "country" crossed with value
//       "age:20".
repeated SlicingSpec slicing_specs = 4;
// A list of cross slicing specs where each spec represents a pair of slices
// whose associated outputs should be compared. By default slices will be
// created for both slicing_specs and baseline_spec if they do not already
// exist in slicing_specs.
repeated CrossSlicingSpec cross_slicing_specs = 8;
// Metrics specifications.
repeated MetricsSpec metrics_specs = 5;
// Additional configuration options.
Options options = 6;
// Field numbers 1, 3 and 7 are reserved and must not be reused (presumably
// they belonged to fields that have since been removed).
reserved 1, 3, 7;
}
// Repeated string value. Used to allow a default if no values are given.
//
// A bare repeated field cannot distinguish "unset" from "empty"; wrapping it
// in a message gives the containing field explicit presence.
message RepeatedStringValue {
// The wrapped string values.
repeated string values = 1;
}
// Repeated int32 value. Used to allow a default if no values are given.
//
// A bare repeated field cannot distinguish "unset" from "empty"; wrapping it
// in a message gives the containing field explicit presence.
message RepeatedInt32Value {
// The wrapped int32 values.
repeated int32 values = 1;
}
// Config and version.
//
// EvalRun is declared to be structurally compatible with this message, so
// the field numbers and types used here must stay in sync with EvalRun.
message EvalConfigAndVersion {
// The evaluation config.
EvalConfig eval_config = 1;
// Version string stored with the config.
// NOTE(review): presumably the TFMA version that wrote the config - confirm.
string version = 2;
}
// Evaluation run containing config, version and input parameters. This should
// be structurally compatible with EvalConfigAndVersion such that a saved
// EvalRun can be read as an EvalConfigAndVersion.
message EvalRun {
// The evaluation config. Same field number and type as
// EvalConfigAndVersion.eval_config.
EvalConfig eval_config = 1;
// Version string stored with the run. Same field number and type as
// EvalConfigAndVersion.version.
string version = 2;
// Location of data used with evaluation run.
string data_location = 3;
// File format used with evaluation run.
string file_format = 4;
// Locations of models used with evaluation run.
// NOTE(review): presumably keyed by model name (ModelSpec.name) - confirm.
map<string, string> model_locations = 5;
}