Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BigQuery: Add more stats to Models API, such as optimization_strategy (via synth). #8344

Merged
merged 1 commit into from
Jun 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion bigquery/google/cloud/bigquery_v2/gapic/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,30 @@ class ModelType(enum.IntEnum):
Attributes:
MODEL_TYPE_UNSPECIFIED (int)
LINEAR_REGRESSION (int): Linear regression model.
LOGISTIC_REGRESSION (int): Logistic regression model.
LOGISTIC_REGRESSION (int): Logistic regression based classification model.
KMEANS (int): [Beta] K-means clustering model.
TENSORFLOW (int): [Beta] An imported TensorFlow model.
"""

MODEL_TYPE_UNSPECIFIED = 0
LINEAR_REGRESSION = 1
LOGISTIC_REGRESSION = 2
KMEANS = 3
TENSORFLOW = 6

class OptimizationStrategy(enum.IntEnum):
"""
Indicates the optimization strategy used for training.

Attributes:
OPTIMIZATION_STRATEGY_UNSPECIFIED (int)
BATCH_GRADIENT_DESCENT (int): Uses an iterative batch gradient descent algorithm.
NORMAL_EQUATION (int): Uses a normal equation to solve linear regression problem.
"""

OPTIMIZATION_STRATEGY_UNSPECIFIED = 0
BATCH_GRADIENT_DESCENT = 1
NORMAL_EQUATION = 2


class StandardSqlDataType(object):
Expand Down
98 changes: 70 additions & 28 deletions bigquery/google/cloud/bigquery_v2/proto/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,19 @@ import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
import "google/api/annotations.proto";
import "google/api/client.proto";

option go_package = "google.golang.org/genproto/googleapis/cloud/bigquery/v2;bigquery";
option java_outer_classname = "ModelProto";
option java_package = "com.google.cloud.bigquery.v2";


service ModelService {
option (google.api.default_host) = "bigquery.googleapis.com";
option (google.api.oauth_scopes) =
"https://www.googleapis.com/auth/bigquery,"
"https://www.googleapis.com/auth/cloud-platform,"
"https://www.googleapis.com/auth/cloud-platform.read-only";

// Gets the specified model resource by model ID.
rpc GetModel(GetModelRequest) returns (Model) {
}
Expand Down Expand Up @@ -67,11 +73,12 @@ message Model {
google.protobuf.DoubleValue r_squared = 5;
}

// Aggregate metrics for classification models. For multi-class models,
// the metrics are either macro-averaged: metrics are calculated for each
// label and then an unweighted average is taken of those values or
// micro-averaged: the metric is calculated globally by counting the total
// number of correctly predicted rows.
// Aggregate metrics for classification/classifier models. For multi-class
// models, the metrics are either macro-averaged or micro-averaged. When
// macro-averaged, the metrics are calculated for each label and then an
// unweighted average is taken of those values. When micro-averaged, the
// metric is calculated globally by counting the total number of correctly
// predicted rows.
message AggregateClassificationMetrics {
// Precision is the fraction of actual positive predictions that had
// positive actual labels. For multiclass this is a macro-averaged
Expand Down Expand Up @@ -104,7 +111,7 @@ message Model {
google.protobuf.DoubleValue roc_auc = 7;
}

// Evaluation metrics for binary classification models.
// Evaluation metrics for binary classification/classifier models.
message BinaryClassificationMetrics {
// Confusion matrix for binary classification models.
message BinaryConfusionMatrix {
Expand All @@ -123,21 +130,35 @@ message Model {
// Number of false samples predicted as false.
google.protobuf.Int64Value false_negatives = 5;

// Aggregate precision.
// The fraction of actual positive predictions that had positive actual
// labels.
google.protobuf.DoubleValue precision = 6;

// Aggregate recall.
// The fraction of actual positive labels that were given a positive
// prediction.
google.protobuf.DoubleValue recall = 7;

// The equally weighted average of recall and precision.
google.protobuf.DoubleValue f1_score = 8;

// The fraction of predictions given the correct label.
google.protobuf.DoubleValue accuracy = 9;
}

// Aggregate classification metrics.
AggregateClassificationMetrics aggregate_classification_metrics = 1;

// Binary confusion matrix at multiple thresholds.
repeated BinaryConfusionMatrix binary_confusion_matrix_list = 2;

// Label representing the positive class.
string positive_label = 3;

// Label representing the negative class.
string negative_label = 4;
}

// Evaluation metrics for multi-class classification models.
// Evaluation metrics for multi-class classification/classifier models.
message MultiClassClassificationMetrics {
// Confusion matrix for multi-class classification models.
message ConfusionMatrix {
Expand Down Expand Up @@ -185,18 +206,18 @@ message Model {
google.protobuf.DoubleValue mean_squared_distance = 2;
}

// Evaluation metrics of a model. These are either computed on all
// training data or just the eval data based on whether eval data was used
// during training.
// Evaluation metrics of a model. These are either computed on all training
// data or just the eval data based on whether eval data was used during
// training. These are not present for imported models.
message EvaluationMetrics {
oneof metrics {
// Populated for regression models.
RegressionMetrics regression_metrics = 1;

// Populated for binary classification models.
// Populated for binary classification/classifier models.
BinaryClassificationMetrics binary_classification_metrics = 2;

// Populated for multi-class classification models.
// Populated for multi-class classification/classifier models.
MultiClassClassificationMetrics multi_class_classification_metrics = 3;

// [Beta] Populated for clustering models.
Expand All @@ -207,13 +228,14 @@ message Model {
// Information about a single training query run for the model.
message TrainingRun {
message TrainingOptions {
// The maximum number of iterations in training.
// The maximum number of iterations in training. Used only for iterative
// training algorithms.
int64 max_iterations = 1;

// Type of loss function used during training run.
LossType loss_type = 2;

// Learning rate in training.
// Learning rate in training. Used only for iterative training algorithms.
double learn_rate = 3;

// L1 regularization coefficient.
Expand All @@ -223,14 +245,16 @@ message Model {
google.protobuf.DoubleValue l2_regularization = 5;

// When early_stop is true, stops training when accuracy improvement is
// less than 'min_relative_progress'.
// less than 'min_relative_progress'. Used only for iterative training
// algorithms.
google.protobuf.DoubleValue min_relative_progress = 6;

// Whether to train a model from the last checkpoint.
google.protobuf.BoolValue warm_start = 7;

// Whether to stop early when the loss doesn't improve significantly
// any more (compared to min_relative_progress).
// any more (compared to min_relative_progress). Used only for iterative
// training algorithms.
google.protobuf.BoolValue early_stop = 8;

// Name of input label columns in training data.
Expand All @@ -257,21 +281,29 @@ message Model {
// https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data-type-properties
string data_split_column = 12;

// The strategy to determine learning rate.
// The strategy to determine learn rate for the current iteration.
LearnRateStrategy learn_rate_strategy = 13;

// Specifies the initial learning rate for line search to start at.
// Specifies the initial learning rate for the line search learn rate
// strategy.
double initial_learn_rate = 16;

// Weights associated with each label class, for rebalancing the
// training data.
// training data. Only applicable for classification models.
map<string, double> label_class_weights = 17;

// [Beta] Distance type for clustering models.
DistanceType distance_type = 20;

// [Beta] Number of clusters for clustering models.
int64 num_clusters = 21;

// [Beta] Google Cloud Storage URI from which the model was imported. Only
// applicable for imported models.
string model_uri = 22;

// Optimization strategy for training linear regression models.
OptimizationStrategy optimization_strategy = 23;
}

// Information about a single iteration of the training run.
Expand Down Expand Up @@ -330,11 +362,14 @@ message Model {
// Linear regression model.
LINEAR_REGRESSION = 1;

// Logistic regression model.
// Logistic regression based classification model.
LOGISTIC_REGRESSION = 2;

// [Beta] K-means clustering model.
KMEANS = 3;

// [Beta] An imported TensorFlow model.
TENSORFLOW = 6;
}

// Loss metric to evaluate model training performance.
Expand Down Expand Up @@ -391,6 +426,17 @@ message Model {
CONSTANT = 2;
}

// Indicates the optimization strategy used for training.
enum OptimizationStrategy {
OPTIMIZATION_STRATEGY_UNSPECIFIED = 0;

// Uses an iterative batch gradient descent algorithm.
BATCH_GRADIENT_DESCENT = 1;

// Uses a normal equation to solve linear regression problem.
NORMAL_EQUATION = 2;
}

// Output only. A hash of this resource.
string etag = 1;

Expand All @@ -406,11 +452,9 @@ message Model {
int64 last_modified_time = 6;

// [Optional] A user-friendly description of this model.
// @mutable bigquery.models.patch
string description = 12;

// [Optional] A descriptive name for this model.
// @mutable bigquery.models.patch
string friendly_name = 14;

// [Optional] The labels associated with this model. You can use these to
Expand All @@ -419,15 +463,13 @@ message Model {
// characters, underscores and dashes. International characters are allowed.
// Label values are optional. Label keys must start with a letter and each
// label in the list must have a different key.
// @mutable bigquery.models.patch
map<string, string> labels = 15;

// [Optional] The time when this model expires, in milliseconds since the
// epoch. If not present, the model will persist indefinitely. Expired models
// will be deleted and their storage reclaimed. The defaultTableExpirationMs
// property of the encapsulating dataset can be used to set a default
// expirationTime on newly created models.
// @mutable bigquery.models.patch
int64 expiration_time = 16;

// Output only. The geographic location where the model resides. This value
Expand All @@ -445,7 +487,7 @@ message Model {
repeated StandardSqlField feature_columns = 10;

// Output only. Label columns that were used to train this model.
// The output of the model will have a β€œpredicted_” prefix to these columns.
// The output of the model will have a "predicted_" prefix to these columns.
repeated StandardSqlField label_columns = 11;
}

Expand Down
Loading