From e9ea9ed3c1fb0b6626362250d84d356ba51031de Mon Sep 17 00:00:00 2001 From: Lars Wander Date: Wed, 14 Aug 2019 19:35:03 +0200 Subject: [PATCH] AutoML: Add a TablesClient for automl-tables specific behavior. (#8720) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Checking in staged client helper code Additional test, docs & proposed cleanup needs to happen on top of this. * update create_model to allow user to specify included or excluded col… (#16) * update create_model to allow user to specify included or excluded columns * made minor changes stylistically and with added ValueError outputs * Update doc gen & module structure. Add unit & system tests * added two new func: set time, get table address (#23) * added two new func: set time, get table address * changed indentation * Add system tests * Address linter & python2.7 import errors * Passes **kwargs through to client & implements missing methods * Support BQ as input/output in batch_predict * Address first round of feedback * Switch to pytest.raises, fix .rst formatting exception * Make list system tests more stringent --- automl/README.rst | 29 + automl/docs/gapic/v1beta1/tables.rst | 5 + automl/docs/index.rst | 1 + .../google/cloud/automl_v1beta1/__init__.py | 7 +- .../cloud/automl_v1beta1/tables/__init__.py | 0 .../automl_v1beta1/tables/tables_client.py | 2682 +++++++++++++++++ automl/synth.py | 16 + .../v1beta1/test_system_tables_client_v1.py | 265 ++ .../v1beta1/test_tables_client_v1beta1.py | 1241 ++++++++ 9 files changed, 4245 insertions(+), 1 deletion(-) create mode 100644 automl/docs/gapic/v1beta1/tables.rst create mode 100644 automl/google/cloud/automl_v1beta1/tables/__init__.py create mode 100644 automl/google/cloud/automl_v1beta1/tables/tables_client.py create mode 100644 automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py create mode 100644 automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py diff --git a/automl/README.rst b/automl/README.rst index d1387cbe7377..2c40f99d2521 100644 --- a/automl/README.rst +++ b/automl/README.rst @@ -104,3 +104,32 @@ Next Steps API to see other available methods on the client. - Read the `Product documentation`_ to learn more about the product and see How-to Guides. + +Making & Testing Local Changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to make changes to this library, here is how to set up your +development environment: + +1. Make sure you have `virtualenv`_ installed and activated as shown above. +2. Run the following one-time setup (it will be persisted in your virtualenv): + + .. code-block:: console + + pip install -r ../docs/requirements.txt + pip install -U nox mock pytest + +3. If you want to run all tests, you will need a billing-enabled + `GCP project`_, and a `service account`_ with access to the AutoML APIs. + Note: the first time the tests run in a new project it will take a _long_ + time, on the order of 2-3 hours. This is one-time setup that will be skipped + in future runs. + +.. _service account: https://cloud.google.com/iam/docs/creating-managing-service-accounts +.. _GCP project: https://cloud.google.com/resource-manager/docs/creating-managing-projects + +.. code-block:: console + + export PROJECT_ID= GOOGLE_APPLICATION_CREDENTIALS= + nox + diff --git a/automl/docs/gapic/v1beta1/tables.rst b/automl/docs/gapic/v1beta1/tables.rst new file mode 100644 index 000000000000..54ed6a203805 --- /dev/null +++ b/automl/docs/gapic/v1beta1/tables.rst @@ -0,0 +1,5 @@ +A tables-specific client for AutoML +=================================== + +.. automodule:: google.cloud.automl_v1beta1.tables.tables_client + :members: diff --git a/automl/docs/index.rst b/automl/docs/index.rst index cc1d290e2b55..01f577642cb1 100644 --- a/automl/docs/index.rst +++ b/automl/docs/index.rst @@ -8,6 +8,7 @@ Api Reference gapic/v1beta1/api gapic/v1beta1/types + gapic/v1beta1/tables Changelog diff --git a/automl/google/cloud/automl_v1beta1/__init__.py b/automl/google/cloud/automl_v1beta1/__init__.py index 2bc4b2a9f5a8..474b05550c81 100644 --- a/automl/google/cloud/automl_v1beta1/__init__.py +++ b/automl/google/cloud/automl_v1beta1/__init__.py @@ -21,6 +21,11 @@ from google.cloud.automl_v1beta1.gapic import auto_ml_client from google.cloud.automl_v1beta1.gapic import enums from google.cloud.automl_v1beta1.gapic import prediction_service_client +from google.cloud.automl_v1beta1.tables import tables_client + + +class TablesClient(tables_client.TablesClient): + __doc__ = tables_client.TablesClient.__doc__ class AutoMlClient(auto_ml_client.AutoMlClient): @@ -33,4 +38,4 @@ class PredictionServiceClient(prediction_service_client.PredictionServiceClient) enums = enums -__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient") +__all__ = ("enums", "types", "AutoMlClient", "PredictionServiceClient", "TablesClient") diff --git a/automl/google/cloud/automl_v1beta1/tables/__init__.py b/automl/google/cloud/automl_v1beta1/tables/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/automl/google/cloud/automl_v1beta1/tables/tables_client.py b/automl/google/cloud/automl_v1beta1/tables/tables_client.py new file mode 100644 index 000000000000..303818b1536a --- /dev/null +++ b/automl/google/cloud/automl_v1beta1/tables/tables_client.py @@ -0,0 +1,2682 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A tables helper for the google.cloud.automl_v1beta1 AutoML API""" + +import pkg_resources + +from google.api_core.gapic_v1 import client_info +from google.api_core import exceptions +from google.cloud.automl_v1beta1 import gapic +from google.cloud.automl_v1beta1.proto import data_types_pb2 + +_GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version + + +class TablesClient(object): + """ + AutoML Tables API helper. + + This is intended to simplify usage of the auto-generated python client, + in particular for the `AutoML Tables product + `_. + """ + + def __init__( + self, + project=None, + region="us-central1", + client=None, + prediction_client=None, + **kwargs + ): + """Constructor. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + + Args: + project (Optional[string]): The project all future calls will + default to. Most methods take `project` as an optional + parameter, and can override your choice of `project` supplied + here. + region (Optional[string]): The region all future calls will + default to. Most methods take `region` as an optional + parameter, and can override your choice of `region` supplied + here. Note, only `us-central1` is supported to-date. + transport (Union[~.AutoMlGrpcTransport, Callable[[~.Credentials, type], ~.AutoMlGrpcTransport]): + A transport instance, responsible for actually making the API + calls. The default transport uses the gRPC protocol. This + argument may also be a callable which returns a transport + instance. Callables will be sent the credentials as the first + argument and the default transport class as the second + argument. + channel (grpc.Channel): DEPRECATED. A ``Channel`` instance + through which to make calls. This argument is mutually exclusive + with ``credentials``; providing both will raise an exception. + credentials (google.auth.credentials.Credentials): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is mutually exclusive with providing a + transport instance to ``transport``; doing so will raise + an exception. + client_config (dict): DEPRECATED. A dictionary of call options for + each method. If not specified, the default configuration is used. + client_options (Union[dict, google.api_core.client_options.ClientOptions]): + Client options used to set user options on the client. API Endpoint + should be set through client_options. + """ + version = _GAPIC_LIBRARY_VERSION + user_agent = "automl-tables-wrapper/{}".format(version) + + client_info_ = kwargs.get("client_info") + if client_info_ is None: + client_info_ = client_info.ClientInfo( + user_agent=user_agent, gapic_version=version + ) + else: + client_info_.user_agent = user_agent + client_info_.gapic_version = version + + if client is None: + self.auto_ml_client = gapic.auto_ml_client.AutoMlClient( + client_info=client_info_, **kwargs + ) + else: + self.auto_ml_client = client + + if prediction_client is None: + self.prediction_client = gapic.prediction_service_client.PredictionServiceClient( + client_info=client_info_, **kwargs + ) + else: + self.prediction_client = prediction_client + + self.project = project + self.region = region + + def __location_path(self, project=None, region=None): + if project is None: + if self.project is None: + raise ValueError( + "Either initialize your client with a value " + "for 'project', or provide 'project' as a " + "parameter for this method." + ) + project = self.project + + if region is None: + if self.region is None: + raise ValueError( + "Either initialize your client with a value " + "for 'region', or provide 'region' as a " + "parameter for this method." + ) + region = self.region + + return self.auto_ml_client.location_path(project, region) + + # the returned metadata object doesn't allow for updating fields, so + # we need to manually copy user-updated fields over + def __update_metadata(self, metadata, k, v): + new_metadata = {} + new_metadata["ml_use_column_spec_id"] = metadata.ml_use_column_spec_id + new_metadata["weight_column_spec_id"] = metadata.weight_column_spec_id + new_metadata["target_column_spec_id"] = metadata.target_column_spec_id + new_metadata[k] = v + + return new_metadata + + def __dataset_from_args( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + if dataset is None and dataset_display_name is None and dataset_name is None: + raise ValueError( + "One of 'dataset', 'dataset_name' or " + "'dataset_display_name' must be set." + ) + # we prefer to make a live call here in the case that the + # dataset object is out-of-date + if dataset is not None: + dataset_name = dataset.name + + return self.get_dataset( + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + project=project, + region=region, + **kwargs + ) + + def __model_from_args( + self, + model=None, + model_display_name=None, + model_name=None, + project=None, + region=None, + **kwargs + ): + if model is None and model_display_name is None and model_name is None: + raise ValueError( + "One of 'model', 'model_name' or " "'model_display_name' must be set." + ) + # we prefer to make a live call here in the case that the + # model object is out-of-date + if model is not None: + model_name = model.name + + return self.get_model( + model_display_name=model_display_name, + model_name=model_name, + project=project, + region=region, + **kwargs + ) + + def __dataset_name_from_args( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + if dataset is None and dataset_display_name is None and dataset_name is None: + raise ValueError( + "One of 'dataset', 'dataset_name' or " + "'dataset_display_name' must be set." + ) + + if dataset_name is None: + if dataset is None: + dataset = self.get_dataset( + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + dataset_name = dataset.name + else: + # we do this to force a NotFound error when needed + self.get_dataset( + dataset_name=dataset_name, project=project, region=region, **kwargs + ) + return dataset_name + + def __table_spec_name_from_args( + self, + table_spec_index=0, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + table_specs = [ + t for t in self.list_table_specs(dataset_name=dataset_name, **kwargs) + ] + + table_spec_full_id = table_specs[table_spec_index].name + return table_spec_full_id + + def __model_name_from_args( + self, + model=None, + model_display_name=None, + model_name=None, + project=None, + region=None, + **kwargs + ): + if model is None and model_display_name is None and model_name is None: + raise ValueError( + "One of 'model', 'model_name' or " "'model_display_name' must be set." + ) + + if model_name is None: + if model is None: + model = self.get_model( + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + model_name = model.name + else: + # we do this to force a NotFound error when needed + self.get_model( + model_name=model_name, project=project, region=region, **kwargs + ) + return model_name + + def __column_spec_name_from_args( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + **kwargs + ): + column_specs = self.list_column_specs( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + project=project, + region=region, + **kwargs + ) + if column_spec_display_name is not None: + column_specs = {s.display_name: s for s in column_specs} + if column_specs.get(column_spec_display_name) is None: + raise exceptions.NotFound( + "No column with " + + "column_spec_display_name: '{}' found".format( + column_spec_display_name + ) + ) + column_spec_name = column_specs[column_spec_display_name].name + elif column_spec_name is not None: + column_specs = {s.name: s for s in column_specs} + if column_specs.get(column_spec_name) is None: + raise exceptions.NotFound( + "No column with " + + "column_spec_name: '{}' found".format(column_spec_name) + ) + else: + raise ValueError( + "Either supply 'column_spec_name' or " + "'column_spec_display_name' for the column to update" + ) + + return column_spec_name + + def __type_code_to_value_type(self, type_code, value): + if value is None: + return {"null_value": 0} + elif type_code == data_types_pb2.FLOAT64: + return {"number_value": value} + elif type_code == data_types_pb2.TIMESTAMP: + return {"string_value": value} + elif type_code == data_types_pb2.STRING: + return {"string_value": value} + elif type_code == data_types_pb2.ARRAY: + return {"list_value": value} + elif type_code == data_types_pb2.STRUCT: + return {"struct_value": value} + elif type_code == data_types_pb2.CATEGORY: + return {"string_value": value} + else: + raise ValueError("Unknown type_code: {}".format(type_code)) + + def list_datasets(self, project=None, region=None, **kwargs): + """List all datasets in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> ds = client.list_datasets() + >>> + >>> for d in ds: + ... # do something + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.api_core.page_iterator.PageIterator` instance. + An iterable of :class:`~google.cloud.automl_v1beta1.types.Dataset` + instances. You can also iterate over the pages of the response + using its `pages` property. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.list_datasets( + self.__location_path(project=project, region=region), **kwargs + ) + + def get_dataset( + self, + project=None, + region=None, + dataset_name=None, + dataset_display_name=None, + **kwargs + ): + """Gets a single dataset in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.get_dataset(dataset_display_name='my_dataset') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_name (Optional[string]): + This is the fully-qualified name generated by the AutoML API + for this dataset. This is not to be confused with the + human-assigned `dataset_display_name` that is provided when + creating a dataset. Either `dataset_name` or + `dataset_display_name` must be provided. + dataset_display_name (Optional[string]): + This is the name you provided for the dataset when first + creating it. Either `dataset_name` or `dataset_display_name` + must be provided. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance if + found, `None` otherwise. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + if dataset_name is None and dataset_display_name is None: + raise ValueError( + "One of 'dataset_name' or " "'dataset_display_name' must be set." + ) + + if dataset_name is not None: + return self.auto_ml_client.get_dataset(dataset_name, **kwargs) + + result = next( + ( + d + for d in self.list_datasets(project, region, **kwargs) + if d.display_name == dataset_display_name + ), + None, + ) + + if result is None: + raise exceptions.NotFound( + ("Dataset with display_name: '{}' " + "not found").format( + dataset_display_name + ) + ) + + return result + + def create_dataset( + self, dataset_display_name, metadata={}, project=None, region=None, **kwargs + ): + """Create a dataset. Keep in mind, importing data is a separate step. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.create_dataset(dataset_display_name='my_dataset') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (string): + A human-readable name to refer to this dataset by. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.create_dataset( + self.__location_path(project, region), + {"display_name": dataset_display_name, "tables_dataset_metadata": metadata}, + **kwargs + ) + + def delete_dataset( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + """Deletes a dataset. This does not delete any models trained on + this dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> op = client.delete_dataset(dataset_display_name='my_dataset') + >>> + >>> op.result() # blocks on delete request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to + delete. This must be supplied if `dataset` or `dataset_name` + are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + delete. This must be supplied if `dataset_display_name` or + `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to delete. This must be + supplied if `dataset_display_name` or `dataset_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + try: + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + # delete is idempotent + except exceptions.NotFound: + return None + + return self.auto_ml_client.delete_dataset(dataset_name, **kwargs) + + def import_data( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + gcs_input_uris=None, + bigquery_input_uri=None, + project=None, + region=None, + **kwargs + ): + """Imports data into a dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.create_dataset(dataset_display_name='my_dataset') + >>> + >>> client.import_data(dataset=d, + ... gcs_input_uris='gs://cloud-ml-tables-data/bank-marketing.csv') + ... + >>> def callback(operation_future): + ... result = operation_future.result() + ... + >>> response.add_done_callback(callback) + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to import + data into. This must be supplied if `dataset` or `dataset_name` + are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + import data into. This must be supplied if + `dataset_display_name` or `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to import data into. This must + be supplied if `dataset_display_name` or `dataset_name` are not + supplied. + gcs_input_uris (Optional[Union[string, Sequence[string]]]): + Either a single `gs://..` prefixed URI, or a list of URIs + referring to GCS-hosted CSV files containing the data to + import. This must be supplied if `bigquery_input_uri` is not. + bigquery_input_uri (Optional[string]): + A URI pointing to the BigQuery table containing the data to + import. This must be supplied if `gcs_input_uris` is not. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + request = {} + if gcs_input_uris is not None: + if type(gcs_input_uris) != list: + gcs_input_uris = [gcs_input_uris] + request = {"gcs_source": {"input_uris": gcs_input_uris}} + elif bigquery_input_uri is not None: + request = {"bigquery_source": {"input_uri": bigquery_input_uri}} + else: + raise ValueError( + "One of 'gcs_input_uris', or " "'bigquery_input_uri' must be set." + ) + + return self.auto_ml_client.import_data(dataset_name, request, **kwargs) + + def export_data( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + gcs_output_uri_prefix=None, + bigquery_output_uri=None, + project=None, + region=None, + **kwargs + ): + """Exports data from a dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.create_dataset(dataset_display_name='my_dataset') + >>> + >>> client.export_data(dataset=d, + ... gcs_output_uri_prefix='gs://cloud-ml-tables-data/bank-marketing.csv') + ... + >>> def callback(operation_future): + ... result = operation_future.result() + ... + >>> response.add_done_callback(callback) + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to export + data from. This must be supplied if `dataset` or `dataset_name` + are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + export data from. This must be supplied if + `dataset_display_name` or `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to export data from. This must + be supplied if `dataset_display_name` or `dataset_name` are not + supplied. + gcs_output_uri_prefix (Optional[Union[string, Sequence[string]]]): + A single `gs://..` prefixed URI to export to. This must be + supplied if `bigquery_output_uri` is not. + bigquery_output_uri (Optional[string]): + A URI pointing to the BigQuery table containing the data to + export. This must be supplied if `gcs_output_uri_prefix` is not. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + request = {} + if gcs_output_uri_prefix is not None: + request = {"gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix}} + elif bigquery_output_uri is not None: + request = {"bigquery_destination": {"output_uri": bigquery_output_uri}} + else: + raise ValueError( + "One of 'gcs_output_uri_prefix', or 'bigquery_output_uri' must be set." + ) + + return self.auto_ml_client.export_data(dataset_name, request, **kwargs) + + def get_table_spec(self, table_spec_name, project=None, region=None, **kwargs): + """Gets a single table spec in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.get_table_spec('my_table_spec') + >>> + + Args: + table_spec_name (string): + This is the fully-qualified name generated by the AutoML API + for this table spec. + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.TableSpec` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.get_table_spec(table_spec_name, **kwargs) + + def list_table_specs( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + """Lists table specs. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> for s in client.list_table_specs(dataset_display_name='my_dataset') + ... # process the spec + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to read + specs from. This must be supplied if `dataset` or + `dataset_name` are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to read + specs from. This must be supplied if `dataset_display_name` or + `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to read specs from. This must + be supplied if `dataset_display_name` or `dataset_name` are not + supplied. + + Returns: + A :class:`~google.api_core.page_iterator.PageIterator` instance. + An iterable of + :class:`~google.cloud.automl_v1beta1.types.TableSpec` instances. + You can also iterate over the pages of the response using its + `pages` property. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + return self.auto_ml_client.list_table_specs(dataset_name, **kwargs) + + def get_column_spec(self, column_spec_name, project=None, region=None, **kwargs): + """Gets a single column spec in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.get_column_spec('my_column_spec') + >>> + + Args: + column_spec_name (string): + This is the fully-qualified name generated by the AutoML API + for this column spec. + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.ColumnSpec` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.get_column_spec(column_spec_name, **kwargs) + + def list_column_specs( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + project=None, + region=None, + **kwargs + ): + """Lists column specs. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> for s in client.list_column_specs(dataset_display_name='my_dataset') + ... # process the spec + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose specs you want to + read. If not supplied, the client can determine this name from + a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` was provided, we use this index to + determine which table to read column specs from. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to read + specs from. If no `table_spec_name` is supplied, this will be + used together with `table_spec_index` to infer the name of + table to read specs from. This must be supplied if + `table_spec_name`, `dataset` or `dataset_name` are not + supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to read + specs from. If no `table_spec_name` is supplied, this will be + used together with `table_spec_index` to infer the name of + table to read specs from. This must be supplied if + `table_spec_name`, `dataset` or `dataset_display_name` are not + supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to read specs from. If no + `table_spec_name` is supplied, this will be used together with + `table_spec_index` to infer the name of table to read specs + from. This must be supplied if `table_spec_name`, + `dataset_name` or `dataset_display_name` are not supplied. + + Returns: + A :class:`~google.api_core.page_iterator.PageIterator` instance. + An iterable of + :class:`~google.cloud.automl_v1beta1.types.ColumnSpec` instances. + You can also iterate over the pages of the response using its + `pages` property. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + if table_spec_name is None: + table_specs = [ + t + for t in self.list_table_specs( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + project=project, + region=region, + **kwargs + ) + ] + + table_spec_name = table_specs[table_spec_index].name + + return self.auto_ml_client.list_column_specs(table_spec_name, **kwargs) + + def update_column_spec( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + type_code=None, + nullable=None, + project=None, + region=None, + **kwargs + ): + """Updates a column's specs. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.update_column_specs(dataset_display_name='my_dataset', + ... column_spec_display_name='Outcome', type_code='CATEGORY') + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + column_spec_name (Optional[string]): + The name AutoML-assigned name for the column you want to + update. + column_spec_display_name (Optional[string]): + The human-readable name of the column you want to update. If + this is supplied in place of `column_spec_name`, you also need + to provide either a way to lookup the source dataset (using one + of the `dataset*` kwargs), or the `table_spec_name` of the + table this column belongs to. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose specs you want to + update. If not supplied, the client can determine this name + from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` was provided, we use this index to + determine which table to update column specs on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + specs on. If no `table_spec_name` is supplied, this will be + used together with `table_spec_index` to infer the name of + table to update specs on. This must be supplied if + `table_spec_name`, `dataset` or `dataset_name` are not + supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update specs one. If no `table_spec_name` is supplied, this + will be used together with `table_spec_index` to infer the name + of table to update specs on. This must be supplied if + `table_spec_name`, `dataset` or `dataset_display_name` are not + supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update specs on. If no + `table_spec_name` is supplied, this will be used together with + `table_spec_index` to infer the name of table to update specs + on. This must be supplied if `table_spec_name`, `dataset_name` + or `dataset_display_name` are not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.ColumnSpec` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, + **kwargs + ) + + # type code must always be set + if type_code is None: + # this index is safe, we would have already thrown a NotFound + # had the column_spec_name not existed + type_code = { + s.name: s + for s in self.list_column_specs( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + project=project, + region=region, + **kwargs + ) + }[column_spec_name].data_type.type_code + + data_type = {} + if nullable is not None: + data_type["nullable"] = nullable + + data_type["type_code"] = type_code + + request = {"name": column_spec_name, "data_type": data_type} + + return self.auto_ml_client.update_column_spec(request, **kwargs) + + def set_target_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + **kwargs + ): + """Sets the target column for a given table. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.set_target_column(dataset_display_name='my_dataset', + ... column_spec_display_name='Income') + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + column_spec_name (Optional[string]): + The name AutoML-assigned name for the column you want to set as + the target column. + column_spec_display_name (Optional[string]): + The human-readable name of the column you want to set as the + target column. If this is supplied in place of + `column_spec_name`, you also need to provide either a way to + lookup the source dataset (using one of the `dataset*` kwargs), + or the `table_spec_name` of the table this column belongs to. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose target column you + want to set . If not supplied, the client can determine this + name from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` or `column_spec_name` was provided, we + use this index to determine which table to set the target + column on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the target column of. If no `table_spec_name` is supplied, this + will be used together with `table_spec_index` to infer the name + of table to update the target column of. This must be supplied + if `table_spec_name`, `dataset` or `dataset_name` are not + supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the target column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the target column of. This + must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the target column of. + If no `table_spec_name` is supplied, this will be used together + with `table_spec_index` to infer the name of table to update + the target column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, + **kwargs + ) + column_spec_id = column_spec_name.rsplit("/", 1)[-1] + + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata( + metadata, "target_column_spec_id", column_spec_id + ) + + request = {"name": dataset.name, "tables_dataset_metadata": metadata} + + return self.auto_ml_client.update_dataset(request, **kwargs) + + def set_time_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + **kwargs + ): + """Sets the time column which designates which data will be of type + timestamp and will be used for the timeseries data. + This column must be of type timestamp. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.set_time_column(dataset_display_name='my_dataset', + ... column_spec_name='Unix Time') + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + column_spec_name (Optional[string]): + The name AutoML-assigned name for the column you want to set as + the time column. + column_spec_display_name (Optional[string]): + The human-readable name of the column you want to set as the + time column. If this is supplied in place of + `column_spec_name`, you also need to provide either a way to + lookup the source dataset (using one of the `dataset*` kwargs), + or the `table_spec_name` of the table this column belongs to. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose time column + you want to set . If not supplied, the client can determine + this name from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` or `column_spec_name` was provided, we + use this index to determine which table to set the time + column on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the time column of. If no `table_spec_name` is supplied, + this will be used together with `table_spec_index` to infer the + name of table to update the time column of. This must be + supplied if `table_spec_name`, `dataset` or `dataset_name` are + not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the time column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the time column of. + This must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the time column + of. If no `table_spec_name` is supplied, this will be used + together with `table_spec_index` to infer the name of table to + update the time column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + Returns: + A :class:`~google.cloud.automl_v1beta1.types.TableSpec` instance. + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, + **kwargs + ) + column_spec_id = column_spec_name.rsplit("/", 1)[-1] + + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + table_spec_full_id = self.__table_spec_name_from_args( + dataset_name=dataset_name, **kwargs + ) + + my_table_spec = { + "name": table_spec_full_id, + "time_column_spec_id": column_spec_id, + } + + return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs) + + def clear_time_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + """Clears the time column which designates which data will be of type + timestamp and will be used for the timeseries data. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.set_time_column(dataset_display_name='my_dataset') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the time column of. If no `table_spec_name` is supplied, + this will be used together with `table_spec_index` to infer the + name of table to update the time column of. This must be + supplied if `table_spec_name`, `dataset` or `dataset_name` are + not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the time column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the time column of. + This must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the time column + of. If no `table_spec_name` is supplied, this will be used + together with `table_spec_index` to infer the name of table to + update the time column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.TableSpec` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + table_spec_full_id = self.__table_spec_name_from_args( + dataset_name=dataset_name, **kwargs + ) + + my_table_spec = {"name": table_spec_full_id, "time_column_spec_id": None} + + return self.auto_ml_client.update_table_spec(my_table_spec, **kwargs) + + def set_weight_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + **kwargs + ): + """Sets the weight column for a given table. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.set_weight_column(dataset_display_name='my_dataset', + ... column_spec_display_name='Income') + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + column_spec_name (Optional[string]): + The name AutoML-assigned name for the column you want to + set as the weight column. + column_spec_display_name (Optional[string]): + The human-readable name of the column you want to set as the + weight column. If this is supplied in place of + `column_spec_name`, you also need to provide either a way to + lookup the source dataset (using one of the `dataset*` kwargs), + or the `table_spec_name` of the table this column belongs to. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose weight column you + want to set . If not supplied, the client can determine this + name from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` or `column_spec_name` was provided, we + use this index to determine which table to set the weight + column on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the weight column of. If no `table_spec_name` is supplied, this + will be used together with `table_spec_index` to infer the name + of table to update the weight column of. This must be supplied + if `table_spec_name`, `dataset` or `dataset_name` are not + supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the weight column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the weight column of. This + must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the weight column of. + If no `table_spec_name` is supplied, this will be used together + with `table_spec_index` to infer the name of table to update + the weight column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, + **kwargs + ) + column_spec_id = column_spec_name.rsplit("/", 1)[-1] + + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata( + metadata, "weight_column_spec_id", column_spec_id + ) + + request = {"name": dataset.name, "tables_dataset_metadata": metadata} + + return self.auto_ml_client.update_dataset(request, **kwargs) + + def clear_weight_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + """Clears the weight column for a given dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.clear_weight_column(dataset_display_name='my_dataset') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the weight column of. If no `table_spec_name` is supplied, this + will be used together with `table_spec_index` to infer the name + of table to update the weight column of. This must be supplied + if `table_spec_name`, `dataset` or `dataset_name` are not + supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the weight column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the weight column of. This + must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the weight column of. + If no `table_spec_name` is supplied, this will be used together + with `table_spec_index` to infer the name of table to update + the weight column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata(metadata, "weight_column_spec_id", None) + + request = {"name": dataset.name, "tables_dataset_metadata": metadata} + + return self.auto_ml_client.update_dataset(request, **kwargs) + + def set_test_train_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + table_spec_name=None, + table_spec_index=0, + column_spec_name=None, + column_spec_display_name=None, + project=None, + region=None, + **kwargs + ): + """Sets the test/train (ml_use) column which designates which data + belongs to the test and train sets. This column must be categorical. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.set_test_train_column(dataset_display_name='my_dataset', + ... column_spec_display_name='TestSplit') + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + column_spec_name (Optional[string]): + The name AutoML-assigned name for the column you want to set as + the test/train column. + column_spec_display_name (Optional[string]): + The human-readable name of the column you want to set as the + test/train column. If this is supplied in place of + `column_spec_name`, you also need to provide either a way to + lookup the source dataset (using one of the `dataset*` kwargs), + or the `table_spec_name` of the table this column belongs to. + table_spec_name (Optional[string]): + The AutoML-assigned name for the table whose test/train column + you want to set . If not supplied, the client can determine + this name from a source `Dataset` object. + table_spec_index (Optional[int]): + If no `table_spec_name` or `column_spec_name` was provided, we + use this index to determine which table to set the test/train + column on. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the test/train column of. If no `table_spec_name` is supplied, + this will be used together with `table_spec_index` to infer the + name of table to update the test/train column of. This must be + supplied if `table_spec_name`, `dataset` or `dataset_name` are + not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the test/train column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the test/train column of. + This must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the test/train column + of. If no `table_spec_name` is supplied, this will be used + together with `table_spec_index` to infer the name of table to + update the test/train column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + column_spec_name = self.__column_spec_name_from_args( + dataset=dataset, + dataset_display_name=dataset_display_name, + dataset_name=dataset_name, + table_spec_name=table_spec_name, + table_spec_index=table_spec_index, + column_spec_name=column_spec_name, + column_spec_display_name=column_spec_display_name, + project=project, + region=region, + **kwargs + ) + column_spec_id = column_spec_name.rsplit("/", 1)[-1] + + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata( + metadata, "ml_use_column_spec_id", column_spec_id + ) + + request = {"name": dataset.name, "tables_dataset_metadata": metadata} + + return self.auto_ml_client.update_dataset(request, **kwargs) + + def clear_test_train_column( + self, + dataset=None, + dataset_display_name=None, + dataset_name=None, + project=None, + region=None, + **kwargs + ): + """Clears the test/train (ml_use) column which designates which data + belongs to the test and train sets. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.clear_test_train_column(dataset_display_name='my_dataset') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to update + the test/train column of. If no `table_spec_name` is supplied, + this will be used together with `table_spec_index` to infer the + name of table to update the test/train column of. This must be + supplied if `table_spec_name`, `dataset` or `dataset_name` are + not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to + update the test/train column of. If no `table_spec_name` is + supplied, this will be used together with `table_spec_index` to + infer the name of table to update the test/train column of. + This must be supplied if `table_spec_name`, `dataset` or + `dataset_display_name` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to update the test/train column + of. If no `table_spec_name` is supplied, this will be used + together with `table_spec_index` to infer the name of table to + update the test/train column of. This must be supplied if + `table_spec_name`, `dataset_name` or `dataset_display_name` are + not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Dataset` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + dataset = self.__dataset_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + metadata = dataset.tables_dataset_metadata + metadata = self.__update_metadata(metadata, "ml_use_column_spec_id", None) + + request = {"name": dataset.name, "tables_dataset_metadata": metadata} + + return self.auto_ml_client.update_dataset(request, **kwargs) + + def list_models(self, project=None, region=None, **kwargs): + """List all models in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> ms = client.list_models() + >>> + >>> for m in ms: + ... # do something + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.api_core.page_iterator.PageIterator` instance. + An iterable of :class:`~google.cloud.automl_v1beta1.types.Model` + instances. You can also iterate over the pages of the response + using its `pages` property. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.list_models( + self.__location_path(project=project, region=region), **kwargs + ) + + def list_model_evaluations( + self, + project=None, + region=None, + model=None, + model_display_name=None, + model_name=None, + **kwargs + ): + """List all model evaluations for a given model. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> ms = client.list_model_evaluations(model_display_name='my_model') + >>> + >>> for m in ms: + ... # do something + ... pass + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to list + evaluations for. This must be supplied if `model` or + `model_name` are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to list + evaluations for. This must be supplied if `model_display_name` + or `model` are not supplied. + model (Optional[model]): + The `model` instance you want to list evaluations for. This + must be supplied if `model_display_name` or `model_name` are + not supplied. + + Returns: + A :class:`~google.api_core.page_iterator.PageIterator` instance. + An iterable of + :class:`~google.cloud.automl_v1beta1.types.ModelEvaluation` + instances. You can also iterate over the pages of the response + using its `pages` property. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model_name = self.__model_name_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + + return self.auto_ml_client.list_model_evaluations(model_name, **kwargs) + + def create_model( + self, + model_display_name, + dataset=None, + dataset_display_name=None, + dataset_name=None, + train_budget_milli_node_hours=None, + optimization_objective=None, + project=None, + region=None, + model_metadata={}, + include_column_spec_names=None, + exclude_column_spec_names=None, + **kwargs + ): + """Create a model. This will train your model on the given dataset. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> m = client.create_model('my_model', dataset_display_name='my_dataset') + >>> + >>> m.result() # blocks on result + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (string): + A human-readable name to refer to this model by. + train_budget_milli_node_hours (int): + The amount of time (in thousandths of an hour) to spend + training. This value must be between 1,000 and 72,000 inclusive + (between 1 and 72 hours). + optimization_objective (string): + The metric AutoML tables should optimize for. + dataset_display_name (Optional[string]): + The human-readable name given to the dataset you want to train + your model on. This must be supplied if `dataset` or + `dataset_name` are not supplied. + dataset_name (Optional[string]): + The AutoML-assigned name given to the dataset you want to train + your model on. This must be supplied if `dataset_display_name` + or `dataset` are not supplied. + dataset (Optional[Dataset]): + The `Dataset` instance you want to train your model on. This + must be supplied if `dataset_display_name` or `dataset_name` + are not supplied. + model_metadata (Optional[Dict]): + Optional model metadata to supply to the client. + include_column_spec_names(Optional[string]): + The list of the names of the columns you want to include to train + your model on. + exclude_column_spec_names(Optional[string]): + The list of the names of the columns you want to exclude and + not train your model on. + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + if ( + train_budget_milli_node_hours is None + or train_budget_milli_node_hours < 1000 + or train_budget_milli_node_hours > 72000 + ): + raise ValueError( + "'train_budget_milli_node_hours' must be a " + "value between 1,000 and 72,000 inclusive" + ) + + if exclude_column_spec_names not in [ + None, + [], + ] and include_column_spec_names not in [None, []]: + raise ValueError( + "Cannot set both " + "'exclude_column_spec_names' and " + "'include_column_spec_names'" + ) + + dataset_name = self.__dataset_name_from_args( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + project=project, + region=region, + **kwargs + ) + + model_metadata["train_budget_milli_node_hours"] = train_budget_milli_node_hours + if optimization_objective is not None: + model_metadata["optimization_objective"] = optimization_objective + + dataset_id = dataset_name.rsplit("/", 1)[-1] + columns = [ + s + for s in self.list_column_specs( + dataset=dataset, + dataset_name=dataset_name, + dataset_display_name=dataset_display_name, + **kwargs + ) + ] + + final_columns = [] + if include_column_spec_names: + for c in columns: + if c.display_name in include_column_spec_names: + final_columns.append(c) + + model_metadata["input_feature_column_specs"] = final_columns + elif exclude_column_spec_names: + for a in columns: + if a.display_name not in exclude_column_spec_names: + final_columns.append(a) + + model_metadata["input_feature_column_specs"] = final_columns + + request = { + "display_name": model_display_name, + "dataset_id": dataset_id, + "tables_model_metadata": model_metadata, + } + + return self.auto_ml_client.create_model( + self.__location_path(project=project, region=region), request, **kwargs + ) + + def delete_model( + self, + model=None, + model_display_name=None, + model_name=None, + project=None, + region=None, + **kwargs + ): + """Deletes a model. Note this will not delete any datasets associated + with this model. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> op = client.delete_model(model_display_name='my_model') + >>> + >>> op.result() # blocks on delete request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to + delete. This must be supplied if `model` or `model_name` + are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to + delete. This must be supplied if `model_display_name` or + `model` are not supplied. + model (Optional[model]): + The `model` instance you want to delete. This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + try: + model_name = self.__model_name_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + # delete is idempotent + except exceptions.NotFound: + return None + + return self.auto_ml_client.delete_model(model_name, **kwargs) + + def get_model_evaluation( + self, model_evaluation_name, project=None, region=None, **kwargs + ): + """Gets a single evaluation model in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.get_model_evaluation('my_model_evaluation') + >>> + + Args: + model_evaluation_name (string): + This is the fully-qualified name generated by the AutoML API + for this model evaluation. + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.ModelEvaluation` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + return self.auto_ml_client.get_model_evaluation(model_evaluation_name, **kwargs) + + def get_model( + self, + project=None, + region=None, + model_name=None, + model_display_name=None, + **kwargs + ): + """Gets a single model in a particular project and region. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> d = client.get_model(model_display_name='my_model') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_name (Optional[string]): + This is the fully-qualified name generated by the AutoML API + for this model. This is not to be confused with the + human-assigned `model_display_name` that is provided when + creating a model. Either `model_name` or + `model_display_name` must be provided. + model_display_name (Optional[string]): + This is the name you provided for the model when first + creating it. Either `model_name` or `model_display_name` + must be provided. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.Model` instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + if model_name is None and model_display_name is None: + raise ValueError( + "One of 'model_name' or " "'model_display_name' must be set." + ) + + if model_name is not None: + return self.auto_ml_client.get_model(model_name, **kwargs) + + model = next( + ( + d + for d in self.list_models(project, region, **kwargs) + if d.display_name == model_display_name + ), + None, + ) + + if model is None: + raise exceptions.NotFound( + "No model with model_diplay_name: " + + "'{}' found".format(model_display_name) + ) + + return model + + # TODO(jonathanskim): allow deployment from just model ID + def deploy_model( + self, + model=None, + model_name=None, + model_display_name=None, + project=None, + region=None, + **kwargs + ): + """Deploys a model. This allows you make online predictions using the + model you've deployed. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> op = client.deploy_model(model_display_name='my_model') + >>> + >>> op.result() # blocks on deploy request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to + deploy. This must be supplied if `model` or `model_name` + are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to + deploy. This must be supplied if `model_display_name` or + `model` are not supplied. + model (Optional[model]): + The `model` instance you want to deploy. This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model_name = self.__model_name_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + + return self.auto_ml_client.deploy_model(model_name, **kwargs) + + def undeploy_model( + self, + model=None, + model_name=None, + model_display_name=None, + project=None, + region=None, + **kwargs + ): + """Undeploys a model. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> op = client.undeploy_model(model_display_name='my_model') + >>> + >>> op.result() # blocks on undeploy request + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + model_display_name (Optional[string]): + The human-readable name given to the model you want to + undeploy. This must be supplied if `model` or `model_name` + are not supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to + undeploy. This must be supplied if `model_display_name` or + `model` are not supplied. + model (Optional[model]): + The `model` instance you want to undeploy. This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model_name = self.__model_name_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + + return self.auto_ml_client.undeploy_model(model_name, **kwargs) + + ## TODO(lwander): support pandas DataFrame as input type + def predict( + self, + inputs, + model=None, + model_name=None, + model_display_name=None, + project=None, + region=None, + **kwargs + ): + """Makes a prediction on a deployed model. This will fail if the model + was not deployed. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.predict(inputs={'Age': 30, 'Income': 12, 'Category': 'A'} + ... model_display_name='my_model') + ... + >>> client.predict([30, 12, 'A'], model_display_name='my_model') + >>> + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + inputs (Union[List[string], Dict[string, string]]): + Either the sorted list of column values to predict with, or a + key-value map of column display name to value to predict with. + model_display_name (Optional[string]): + The human-readable name given to the model you want to predict + with. This must be supplied if `model` or `model_name` are not + supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to predict + with. This must be supplied if `model_display_name` or `model` + are not supplied. + model (Optional[model]): + The `model` instance you want to predict with . This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types.PredictResponse` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model = self.__model_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + + column_specs = model.tables_model_metadata.input_feature_column_specs + if type(inputs) == dict: + inputs = [inputs.get(c.display_name, None) for c in column_specs] + + if len(inputs) != len(column_specs): + raise ValueError( + ( + "Dimension mismatch, the number of provided " + "inputs ({}) does not match that of the model " + "({})" + ).format(len(inputs), len(column_specs)) + ) + + values = [] + for i, c in zip(inputs, column_specs): + value_type = self.__type_code_to_value_type(c.data_type.type_code, i) + values.append(value_type) + + request = {"row": {"values": values}} + + return self.prediction_client.predict(model.name, request, **kwargs) + + def batch_predict( + self, + bigquery_input_uri=None, + bigquery_output_uri=None, + gcs_input_uris=None, + gcs_output_uri_prefix=None, + model=None, + model_name=None, + model_display_name=None, + project=None, + region=None, + inputs=None, + **kwargs + ): + """Makes a batch prediction on a model. This does _not_ require the + model to be deployed. + + Example: + >>> from google.cloud import automl_v1beta1 + >>> + >>> from google.oauth2 import service_account + >>> + >>> client = automl_v1beta1.TablesClient( + ... credentials=service_account.Credentials.from_service_account_file('~/.gcp/account.json') + ... project='my-project', region='us-central1') + ... + >>> client.batch_predict( + ... gcs_input_uris='gs://inputs/input.csv', + ... gcs_output_uri_prefix='gs://outputs/', + ... model_display_name='my_model' + ... ).result() + ... + + Args: + project (Optional[string]): + If you have initialized the client with a value for `project` + it will be used if this parameter is not supplied. Keep in + mind, the service account this client was initialized with must + have access to this project. + region (Optional[string]): + If you have initialized the client with a value for `region` it + will be used if this parameter is not supplied. + gcs_input_uris (Optional(Union[List[string], string])) + Either a list of or a single GCS URI containing the data you + want to predict off of. + gcs_output_uri_prefix (Optional[string]) + The folder in GCS you want to write output to. + bigquery_input_uri (Optional[string]) + The BigQuery table to input data from. + bigquery_output_uri (Optional[string]) + The BigQuery table to output data to. + model_display_name (Optional[string]): + The human-readable name given to the model you want to predict + with. This must be supplied if `model` or `model_name` are not + supplied. + model_name (Optional[string]): + The AutoML-assigned name given to the model you want to predict + with. This must be supplied if `model_display_name` or `model` + are not supplied. + model (Optional[model]): + The `model` instance you want to predict with . This must be + supplied if `model_display_name` or `model_name` are not + supplied. + + Returns: + A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` + instance. + + Raises: + google.api_core.exceptions.GoogleAPICallError: If the request + failed for any reason. + google.api_core.exceptions.RetryError: If the request failed due + to a retryable error and retry attempts failed. + ValueError: If required parameters are missing. + """ + model_name = self.__model_name_from_args( + model=model, + model_name=model_name, + model_display_name=model_display_name, + project=project, + region=region, + **kwargs + ) + + input_request = None + if gcs_input_uris is not None: + if type(gcs_input_uris) != list: + gcs_input_uris = [gcs_input_uris] + input_request = {"gcs_source": {"input_uris": gcs_input_uris}} + elif bigquery_input_uri is not None: + input_request = {"bigquery_source": {"input_uri": bigquery_input_uri}} + else: + raise ValueError( + "One of 'gcs_input_uris'/'bigquery_input_uris' must" "be set" + ) + + output_request = None + if gcs_output_uri_prefix is not None: + output_request = { + "gcs_destination": {"output_uri_prefix": gcs_output_uri_prefix} + } + elif bigquery_output_uri is not None: + output_request = { + "bigquery_destination": {"output_uri": bigquery_output_uri} + } + else: + raise ValueError( + "One of 'gcs_output_uri_prefix'/'bigquery_output_uri' must be set" + ) + + return self.prediction_client.batch_predict( + model_name, input_request, output_request, **kwargs + ) diff --git a/automl/synth.py b/automl/synth.py index 4318ff31fc2b..ab93031fc583 100644 --- a/automl/synth.py +++ b/automl/synth.py @@ -33,6 +33,22 @@ s.move(library / f"tests/unit/gapic/{version}") s.move(library / f"docs/gapic/{version}") + s.replace( + f"google/cloud/automl_{version}/__init__.py", + f"from google.cloud.automl_v1beta1.gapic import prediction_service_client", + f"from google.cloud.automl_v1beta1.gapic import prediction_service_client" + f"from google.cloud.automl_v1beta1.tables import tables_client" + f"\n\n" + f"class TablesClient(tables_client.TablesClient):" + f" __doc__ = tables_client.TablesClient.__doc__" + ) + + s.replace( + f"google/cloud/automl_{version}/__init__.py", + f"__all__ = (\"enums\", \"types\", \"AutoMlClient\", \"PredictionServiceClient\")", + f"__all__ = (\"enums\", \"types\", \"AutoMlClient\", \"PredictionServiceClient\", \"TablesClient\")" + ) + s.move(library / f"docs/conf.py") # Use the highest version library to generate import alias. diff --git a/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py new file mode 100644 index 000000000000..2a763cdc24dc --- /dev/null +++ b/automl/tests/system/gapic/v1beta1/test_system_tables_client_v1.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pytest +import random +import string +import time + +from google.cloud import automl_v1beta1 +from google.api_core import exceptions +from google.cloud.automl_v1beta1.gapic import enums + +PROJECT = os.environ["PROJECT_ID"] +REGION = "us-central1" +MAX_WAIT_TIME_SECONDS = 30 +MAX_SLEEP_TIME_SECONDS = 5 +STATIC_DATASET = "test_dataset_do_not_delete" +STATIC_MODEL = "test_model_do_not_delete" + +ID = "{rand}_{time}".format( + rand="".join( + [random.choice(string.ascii_letters + string.digits) for n in range(4)] + ), + time=int(time.time()), +) + + +def _id(name): + return "{}_{}".format(name, ID) + + +class TestSystemTablesClient(object): + def cancel_and_wait(self, op): + op.cancel() + start = time.time() + sleep_time = 1 + while time.time() - start < MAX_WAIT_TIME_SECONDS: + if op.cancelled(): + return + time.sleep(sleep_time) + sleep_time = min(sleep_time * 2, MAX_SLEEP_TIME_SECONDS) + assert op.cancelled() + + def test_list_datasets(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + # will raise if not found + next( + iter( + [d for d in client.list_datasets(timeout=10) if d.name == dataset.name] + ) + ) + + def test_list_models(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_ready(client) + # will raise if not found + next(iter([m for m in client.list_models(timeout=10) if m.name == model.name])) + + def test_create_delete_dataset(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + display_name = _id("t_cr_dl") + dataset = client.create_dataset(display_name) + assert dataset is not None + assert ( + dataset.name == client.get_dataset(dataset_display_name=display_name).name + ) + client.delete_dataset(dataset=dataset) + + def test_import_data(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + display_name = _id("t_import") + dataset = client.create_dataset(display_name) + op = client.import_data( + dataset=dataset, + gcs_input_uris="gs://cloud-ml-tables-data/bank-marketing.csv", + ) + self.cancel_and_wait(op) + client.delete_dataset(dataset=dataset) + + def ensure_dataset_ready(self, client): + dataset = None + try: + dataset = client.get_dataset(dataset_display_name=STATIC_DATASET) + except exceptions.NotFound: + dataset = client.create_dataset(STATIC_DATASET) + + if dataset.example_count is None or dataset.example_count == 0: + op = client.import_data( + dataset=dataset, + gcs_input_uris="gs://cloud-ml-tables-data/bank-marketing.csv", + ) + op.result() + dataset = client.get_dataset(dataset_name=dataset.name) + + return dataset + + def test_list_column_specs(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + # will raise if not found + next( + iter( + [ + d + for d in client.list_column_specs(dataset=dataset) + if d.display_name == "Deposit" + ] + ) + ) + + def test_get_column_spec(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + name = [d for d in client.list_column_specs(dataset=dataset)][0].name + assert client.get_column_spec(name).name == name + + def test_list_table_specs(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + name = [d for d in client.list_table_specs(dataset=dataset)][0].name + assert client.get_table_spec(name).name == name + + def test_set_column_nullable(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.update_column_spec( + dataset=dataset, column_spec_display_name="POutcome", nullable=True + ) + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} + assert columns["POutcome"].data_type.nullable == True + + def test_set_target_column(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_target_column(dataset=dataset, column_spec_display_name="Age") + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} + dataset = client.get_dataset(dataset_name=dataset.name) + metadata = dataset.tables_dataset_metadata + assert columns["Age"].name.endswith( + "/{}".format(metadata.target_column_spec_id) + ) + + def test_set_weight_column(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_weight_column(dataset=dataset, column_spec_display_name="Duration") + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} + dataset = client.get_dataset(dataset_name=dataset.name) + metadata = dataset.tables_dataset_metadata + assert columns["Duration"].name.endswith( + "/{}".format(metadata.weight_column_spec_id) + ) + + def test_set_weight_and_target_column(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_weight_column(dataset=dataset, column_spec_display_name="Day") + client.set_target_column(dataset=dataset, column_spec_display_name="Campaign") + columns = {c.display_name: c for c in client.list_column_specs(dataset=dataset)} + dataset = client.get_dataset(dataset_name=dataset.name) + metadata = dataset.tables_dataset_metadata + assert columns["Day"].name.endswith( + "/{}".format(metadata.weight_column_spec_id) + ) + assert columns["Campaign"].name.endswith( + "/{}".format(metadata.target_column_spec_id) + ) + + def test_create_delete_model(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + dataset = self.ensure_dataset_ready(client) + client.set_target_column(dataset=dataset, column_spec_display_name="Deposit") + display_name = _id("t_cr_dl") + op = client.create_model( + display_name, dataset=dataset, train_budget_milli_node_hours=1000 + ) + self.cancel_and_wait(op) + client.delete_model(model_display_name=display_name) + + def test_list_model_evaluations(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_online(client) + # will raise if not found + next( + iter( + [ + m + for m in client.list_model_evaluations(model=model) + if m.display_name is not None + ] + ) + ) + + def test_get_model_evaluation(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_online(client) + name = [m for m in client.list_model_evaluations(model=model)][0].name + assert client.get_model_evaluation(model_evaluation_name=name).name == name + + def test_online_predict(self): + client = automl_v1beta1.TablesClient(project=PROJECT, region=REGION) + model = self.ensure_model_online(client) + result = client.predict( + inputs={ + "Age": 31, + "Balance": 200, + "Campaign": 2, + "Contact": "cellular", + "Day": 4, + "Default": "no", + "Duration": 12, + "Education": "primary", + "Housing": "yes", + "Job": "blue-collar", + "Loan": "no", + "MaritalStatus": "divorced", + "Month": "jul", + "PDays": 4, + "POutcome": None, + "Previous": 12, + }, + model=model, + ) + assert result is not None + + def ensure_model_online(self, client): + model = self.ensure_model_ready(client) + if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: + client.deploy_model(model=model).result() + + return client.get_model(model_name=model.name) + + def ensure_model_ready(self, client): + try: + return client.get_model(model_display_name=STATIC_MODEL) + except exceptions.NotFound: + pass + + dataset = self.ensure_dataset_ready(client) + client.set_target_column(dataset=dataset, column_spec_display_name="Deposit") + client.clear_weight_column(dataset=dataset) + client.clear_test_train_column(dataset=dataset) + client.update_column_spec( + dataset=dataset, column_spec_display_name="POutcome", nullable=True + ) + op = client.create_model( + STATIC_MODEL, dataset=dataset, train_budget_milli_node_hours=1000 + ) + op.result() + return client.get_model(model_display_name=STATIC_MODEL) diff --git a/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py new file mode 100644 index 000000000000..5176a31cce6d --- /dev/null +++ b/automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -0,0 +1,1241 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests.""" + +import mock +import pytest + +from google.cloud import automl_v1beta1 +from google.api_core import exceptions +from google.cloud.automl_v1beta1.proto import data_types_pb2 + +PROJECT = "project" +REGION = "region" +LOCATION_PATH = "projects/{}/locations/{}".format(PROJECT, REGION) + + +class TestTablesClient(object): + def tables_client(self, client_attrs={}, prediction_client_attrs={}): + client_mock = mock.Mock(**client_attrs) + prediction_client_mock = mock.Mock(**prediction_client_attrs) + return automl_v1beta1.TablesClient( + client=client_mock, + prediction_client=prediction_client_mock, + project=PROJECT, + region=REGION, + ) + + def test_list_datasets_empty(self): + client = self.tables_client( + { + "list_datasets.return_value": [], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + ds = client.list_datasets() + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.list_datasets.assert_called_with(LOCATION_PATH) + assert ds == [] + + def test_list_datasets_not_empty(self): + datasets = ["some_dataset"] + client = self.tables_client( + { + "list_datasets.return_value": datasets, + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + ds = client.list_datasets() + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.list_datasets.assert_called_with(LOCATION_PATH) + assert len(ds) == 1 + assert ds[0] == "some_dataset" + + def test_get_dataset_no_value(self): + dataset_actual = "dataset" + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + dataset = client.get_dataset() + client.auto_ml_client.get_dataset.assert_not_called() + + def test_get_dataset_name(self): + dataset_actual = "dataset" + client = self.tables_client({"get_dataset.return_value": dataset_actual}, {}) + dataset = client.get_dataset(dataset_name="my_dataset") + client.auto_ml_client.get_dataset.assert_called_with("my_dataset") + assert dataset == dataset_actual + + def test_get_no_dataset(self): + client = self.tables_client( + {"get_dataset.side_effect": exceptions.NotFound("err")}, {} + ) + with pytest.raises(exceptions.NotFound): + client.get_dataset(dataset_name="my_dataset") + client.auto_ml_client.get_dataset.assert_called_with("my_dataset") + + def test_get_dataset_from_empty_list(self): + client = self.tables_client({"list_datasets.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.get_dataset(dataset_display_name="my_dataset") + + def test_get_dataset_from_list_not_found(self): + client = self.tables_client( + {"list_datasets.return_value": [mock.Mock(display_name="not_it")]}, {} + ) + with pytest.raises(exceptions.NotFound): + client.get_dataset(dataset_display_name="my_dataset") + + def test_get_dataset_from_list(self): + client = self.tables_client( + { + "list_datasets.return_value": [ + mock.Mock(display_name="not_it"), + mock.Mock(display_name="my_dataset"), + ] + }, + {}, + ) + dataset = client.get_dataset(dataset_display_name="my_dataset") + assert dataset.display_name == "my_dataset" + + def test_create_dataset(self): + client = self.tables_client( + { + "location_path.return_value": LOCATION_PATH, + "create_dataset.return_value": mock.Mock(display_name="name"), + }, + {}, + ) + metadata = {"metadata": "values"} + dataset = client.create_dataset("name", metadata=metadata) + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.create_dataset.assert_called_with( + LOCATION_PATH, {"display_name": "name", "tables_dataset_metadata": metadata} + ) + assert dataset.display_name == "name" + + def test_delete_dataset(self): + dataset = mock.Mock() + dataset.configure_mock(name="name") + client = self.tables_client({"delete_dataset.return_value": None}, {}) + client.delete_dataset(dataset=dataset) + client.auto_ml_client.delete_dataset.assert_called_with("name") + + def test_delete_dataset_not_found(self): + client = self.tables_client({"list_datasets.return_value": []}, {}) + client.delete_dataset(dataset_display_name="not_found") + client.auto_ml_client.delete_dataset.assert_not_called() + + def test_delete_dataset_name(self): + client = self.tables_client({"delete_dataset.return_value": None}, {}) + client.delete_dataset(dataset_name="name") + client.auto_ml_client.delete_dataset.assert_called_with("name") + + def test_export_not_found(self): + client = self.tables_client({"list_datasets.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.export_data(dataset_display_name="name", gcs_input_uris="uri") + + client.auto_ml_client.export_data.assert_not_called() + + def test_export_gcs_uri(self): + client = self.tables_client({"export_data.return_value": None}, {}) + client.export_data(dataset_name="name", gcs_output_uri_prefix="uri") + client.auto_ml_client.export_data.assert_called_with( + "name", {"gcs_destination": {"output_uri_prefix": "uri"}} + ) + + def test_export_bq_uri(self): + client = self.tables_client({"export_data.return_value": None}, {}) + client.export_data(dataset_name="name", bigquery_output_uri="uri") + client.auto_ml_client.export_data.assert_called_with( + "name", {"bigquery_destination": {"output_uri": "uri"}} + ) + + def test_import_not_found(self): + client = self.tables_client({"list_datasets.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.import_data(dataset_display_name="name", gcs_input_uris="uri") + + client.auto_ml_client.import_data.assert_not_called() + + def test_import_gcs_uri(self): + client = self.tables_client({"import_data.return_value": None}, {}) + client.import_data(dataset_name="name", gcs_input_uris="uri") + client.auto_ml_client.import_data.assert_called_with( + "name", {"gcs_source": {"input_uris": ["uri"]}} + ) + + def test_import_gcs_uris(self): + client = self.tables_client({"import_data.return_value": None}, {}) + client.import_data(dataset_name="name", gcs_input_uris=["uri", "uri"]) + client.auto_ml_client.import_data.assert_called_with( + "name", {"gcs_source": {"input_uris": ["uri", "uri"]}} + ) + + def test_import_bq_uri(self): + client = self.tables_client({"import_data.return_value": None}, {}) + client.import_data(dataset_name="name", bigquery_input_uri="uri") + client.auto_ml_client.import_data.assert_called_with( + "name", {"bigquery_source": {"input_uri": "uri"}} + ) + + def test_list_table_specs(self): + client = self.tables_client({"list_table_specs.return_value": None}, {}) + client.list_table_specs(dataset_name="name") + client.auto_ml_client.list_table_specs.assert_called_with("name") + + def test_list_table_specs_not_found(self): + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("not found")}, {} + ) + with pytest.raises(exceptions.NotFound): + client.list_table_specs(dataset_name="name") + client.auto_ml_client.list_table_specs.assert_called_with("name") + + def test_get_table_spec(self): + client = self.tables_client({}, {}) + client.get_table_spec("name") + client.auto_ml_client.get_table_spec.assert_called_with("name") + + def test_get_column_spec(self): + client = self.tables_client({}, {}) + client.get_column_spec("name") + client.auto_ml_client.get_column_spec.assert_called_with("name") + + def test_list_column_specs(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [], + }, + {}, + ) + client.list_column_specs(dataset_name="name") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + + def test_update_column_spec_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + with pytest.raises(exceptions.NotFound): + client.update_column_spec(dataset_name="name", column_spec_name="column2") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_not_called() + + def test_update_column_spec_display_name_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + with pytest.raises(exceptions.NotFound): + client.update_column_spec( + dataset_name="name", column_spec_display_name="column2" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_not_called() + + def test_update_column_spec_name_no_args(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column/2", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec(dataset_name="name", column_spec_name="column/2") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( + {"name": "column/2", "data_type": {"type_code": "type_code"}} + ) + + def test_update_column_spec_no_args(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", column_spec_display_name="column" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( + {"name": "column", "data_type": {"type_code": "type_code"}} + ) + + def test_update_column_spec_nullable(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", column_spec_display_name="column", nullable=True + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( + { + "name": "column", + "data_type": {"type_code": "type_code", "nullable": True}, + } + ) + + def test_update_column_spec_type_code(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", + column_spec_display_name="column", + type_code="type_code2", + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( + {"name": "column", "data_type": {"type_code": "type_code2"}} + ) + + def test_update_column_spec_type_code_nullable(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", + nullable=True, + column_spec_display_name="column", + type_code="type_code2", + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( + { + "name": "column", + "data_type": {"type_code": "type_code2", "nullable": True}, + } + ) + + def test_update_column_spec_type_code_nullable_false(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + data_type_mock = mock.Mock(type_code="type_code") + column_spec_mock.configure_mock( + name="column", display_name="column", data_type=data_type_mock + ) + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.update_column_spec( + dataset_name="name", + nullable=False, + column_spec_display_name="column", + type_code="type_code2", + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_column_spec.assert_called_with( + { + "name": "column", + "data_type": {"type_code": "type_code2", "nullable": False}, + } + ) + + def test_set_target_column_table_not_found(self): + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} + ) + with pytest.raises(exceptions.NotFound): + client.set_target_column( + dataset_name="name", column_spec_display_name="column2" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_not_called() + client.auto_ml_client.update_dataset.assert_not_called() + + def test_set_target_column_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/1", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + with pytest.raises(exceptions.NotFound): + client.set_target_column( + dataset_name="name", column_spec_display_name="column2" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_not_called() + + def test_set_target_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/1", display_name="column") + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="2", + weight_column_spec_id="2", + ml_use_column_spec_id="3", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_target_column(dataset_name="name", column_spec_display_name="column") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, + } + ) + + def test_set_weight_column_table_not_found(self): + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} + ) + try: + client.set_weight_column( + dataset_name="name", column_spec_display_name="column2" + ) + except exceptions.NotFound: + pass + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_not_called() + client.auto_ml_client.update_dataset.assert_not_called() + + def test_set_weight_column_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/1", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + with pytest.raises(exceptions.NotFound): + client.set_weight_column( + dataset_name="name", column_spec_display_name="column2" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_not_called() + + def test_set_weight_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/2", display_name="column") + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="1", + ml_use_column_spec_id="3", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_weight_column(dataset_name="name", column_spec_display_name="column") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, + } + ) + + def test_clear_weight_column(self): + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="2", + ml_use_column_spec_id="3", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) + client.clear_weight_column(dataset_name="name") + client.auto_ml_client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": None, + "ml_use_column_spec_id": "3", + }, + } + ) + + def test_set_test_train_column_table_not_found(self): + client = self.tables_client( + {"list_table_specs.side_effect": exceptions.NotFound("err")}, {} + ) + with pytest.raises(exceptions.NotFound): + client.set_test_train_column( + dataset_name="name", column_spec_display_name="column2" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_not_called() + client.auto_ml_client.update_dataset.assert_not_called() + + def test_set_test_train_column_not_found(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/1", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + with pytest.raises(exceptions.NotFound): + client.set_test_train_column( + dataset_name="name", column_spec_display_name="column2" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_not_called() + + def test_set_test_train_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/3", display_name="column") + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="2", + ml_use_column_spec_id="2", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_test_train_column( + dataset_name="name", column_spec_display_name="column" + ) + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": "3", + }, + } + ) + + def test_clear_test_train_column(self): + dataset_mock = mock.Mock() + tables_dataset_metadata_mock = mock.Mock() + tables_dataset_metadata_mock.configure_mock( + target_column_spec_id="1", + weight_column_spec_id="2", + ml_use_column_spec_id="2", + ) + dataset_mock.configure_mock( + name="dataset", tables_dataset_metadata=tables_dataset_metadata_mock + ) + client = self.tables_client({"get_dataset.return_value": dataset_mock}, {}) + client.clear_test_train_column(dataset_name="name") + client.auto_ml_client.update_dataset.assert_called_with( + { + "name": "dataset", + "tables_dataset_metadata": { + "target_column_spec_id": "1", + "weight_column_spec_id": "2", + "ml_use_column_spec_id": None, + }, + } + ) + + def test_set_time_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/3", display_name="column") + dataset_mock = mock.Mock() + dataset_mock.configure_mock(name="dataset") + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + }, + {}, + ) + client.set_time_column(dataset_name="name", column_spec_display_name="column") + client.auto_ml_client.list_table_specs.assert_called_with("name") + client.auto_ml_client.list_column_specs.assert_called_with("table") + client.auto_ml_client.update_table_spec.assert_called_with( + {"name": "table", "time_column_spec_id": "3"} + ) + + def test_clear_time_column(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + dataset_mock = mock.Mock() + dataset_mock.configure_mock(name="dataset") + client = self.tables_client( + { + "get_dataset.return_value": dataset_mock, + "list_table_specs.return_value": [table_spec_mock], + }, + {}, + ) + client.clear_time_column(dataset_name="name") + client.auto_ml_client.update_table_spec.assert_called_with( + {"name": "table", "time_column_spec_id": None} + ) + + def test_get_model_evaluation(self): + client = self.tables_client({}, {}) + ds = client.get_model_evaluation(model_evaluation_name="x") + client.auto_ml_client.get_model_evaluation.assert_called_with("x") + + def test_list_model_evaluations_empty(self): + client = self.tables_client({"list_model_evaluations.return_value": []}, {}) + ds = client.list_model_evaluations(model_name="model") + client.auto_ml_client.list_model_evaluations.assert_called_with("model") + assert ds == [] + + def test_list_model_evaluations_not_empty(self): + evaluations = ["eval"] + client = self.tables_client( + { + "list_model_evaluations.return_value": evaluations, + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + ds = client.list_model_evaluations(model_name="model") + client.auto_ml_client.list_model_evaluations.assert_called_with("model") + assert len(ds) == 1 + assert ds[0] == "eval" + + def test_list_models_empty(self): + client = self.tables_client( + { + "list_models.return_value": [], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + ds = client.list_models() + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.list_models.assert_called_with(LOCATION_PATH) + assert ds == [] + + def test_list_models_not_empty(self): + models = ["some_model"] + client = self.tables_client( + { + "list_models.return_value": models, + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + ds = client.list_models() + client.auto_ml_client.location_path.assert_called_with(PROJECT, REGION) + client.auto_ml_client.list_models.assert_called_with(LOCATION_PATH) + assert len(ds) == 1 + assert ds[0] == "some_model" + + def test_get_model_name(self): + model_actual = "model" + client = self.tables_client({"get_model.return_value": model_actual}, {}) + model = client.get_model(model_name="my_model") + client.auto_ml_client.get_model.assert_called_with("my_model") + assert model == model_actual + + def test_get_no_model(self): + client = self.tables_client( + {"get_model.side_effect": exceptions.NotFound("err")}, {} + ) + with pytest.raises(exceptions.NotFound): + client.get_model(model_name="my_model") + client.auto_ml_client.get_model.assert_called_with("my_model") + + def test_get_model_from_empty_list(self): + client = self.tables_client({"list_models.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.get_model(model_display_name="my_model") + + def test_get_model_from_list_not_found(self): + client = self.tables_client( + {"list_models.return_value": [mock.Mock(display_name="not_it")]}, {} + ) + with pytest.raises(exceptions.NotFound): + client.get_model(model_display_name="my_model") + + def test_get_model_from_list(self): + client = self.tables_client( + { + "list_models.return_value": [ + mock.Mock(display_name="not_it"), + mock.Mock(display_name="my_model"), + ] + }, + {}, + ) + model = client.get_model(model_display_name="my_model") + assert model.display_name == "my_model" + + def test_delete_model(self): + model = mock.Mock() + model.configure_mock(name="name") + client = self.tables_client({"delete_model.return_value": None}, {}) + client.delete_model(model=model) + client.auto_ml_client.delete_model.assert_called_with("name") + + def test_delete_model_not_found(self): + client = self.tables_client({"list_models.return_value": []}, {}) + client.delete_model(model_display_name="not_found") + client.auto_ml_client.delete_model.assert_not_called() + + def test_delete_model_name(self): + client = self.tables_client({"delete_model.return_value": None}, {}) + client.delete_model(model_name="name") + client.auto_ml_client.delete_model.assert_called_with("name") + + def test_deploy_model_no_args(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.deploy_model() + client.auto_ml_client.deploy_model.assert_not_called() + + def test_deploy_model(self): + client = self.tables_client({}, {}) + client.deploy_model(model_name="name") + client.auto_ml_client.deploy_model.assert_called_with("name") + + def test_deploy_model_not_found(self): + client = self.tables_client({"list_models.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.deploy_model(model_display_name="name") + client.auto_ml_client.deploy_model.assert_not_called() + + def test_undeploy_model(self): + client = self.tables_client({}, {}) + client.undeploy_model(model_name="name") + client.auto_ml_client.undeploy_model.assert_called_with("name") + + def test_undeploy_model_not_found(self): + client = self.tables_client({"list_models.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.undeploy_model(model_display_name="name") + client.auto_ml_client.undeploy_model.assert_not_called() + + def test_create_model(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock = mock.Mock() + column_spec_mock.configure_mock(name="column/2", display_name="column") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [column_spec_mock], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + client.create_model( + "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1000 + ) + client.auto_ml_client.create_model.assert_called_with( + LOCATION_PATH, + { + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, + }, + ) + + def test_create_model_include_columns(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock1 = mock.Mock() + column_spec_mock1.configure_mock(name="column/1", display_name="column1") + column_spec_mock2 = mock.Mock() + column_spec_mock2.configure_mock(name="column/2", display_name="column2") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [ + column_spec_mock1, + column_spec_mock2, + ], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + client.create_model( + "my_model", + dataset_name="my_dataset", + include_column_spec_names=["column1"], + train_budget_milli_node_hours=1000, + ) + client.auto_ml_client.create_model.assert_called_with( + LOCATION_PATH, + { + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": { + "train_budget_milli_node_hours": 1000, + "input_feature_column_specs": [column_spec_mock1], + }, + }, + ) + + def test_create_model_exclude_columns(self): + table_spec_mock = mock.Mock() + # name is reserved in use of __init__, needs to be passed here + table_spec_mock.configure_mock(name="table") + column_spec_mock1 = mock.Mock() + column_spec_mock1.configure_mock(name="column/1", display_name="column1") + column_spec_mock2 = mock.Mock() + column_spec_mock2.configure_mock(name="column/2", display_name="column2") + client = self.tables_client( + { + "list_table_specs.return_value": [table_spec_mock], + "list_column_specs.return_value": [ + column_spec_mock1, + column_spec_mock2, + ], + "location_path.return_value": LOCATION_PATH, + }, + {}, + ) + client.create_model( + "my_model", + dataset_name="my_dataset", + exclude_column_spec_names=["column1"], + train_budget_milli_node_hours=1000, + ) + client.auto_ml_client.create_model.assert_called_with( + LOCATION_PATH, + { + "display_name": "my_model", + "dataset_id": "my_dataset", + "tables_model_metadata": { + "train_budget_milli_node_hours": 1000, + "input_feature_column_specs": [column_spec_mock2], + }, + }, + ) + + def test_create_model_invalid_hours_small(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.create_model( + "my_model", dataset_name="my_dataset", train_budget_milli_node_hours=1 + ) + client.auto_ml_client.create_model.assert_not_called() + + def test_create_model_invalid_hours_large(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.create_model( + "my_model", + dataset_name="my_dataset", + train_budget_milli_node_hours=1000000, + ) + client.auto_ml_client.create_model.assert_not_called() + + def test_create_model_invalid_no_dataset(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.create_model("my_model", train_budget_milli_node_hours=1000) + client.auto_ml_client.get_dataset.assert_not_called() + client.auto_ml_client.create_model.assert_not_called() + + def test_create_model_invalid_include_exclude(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.create_model( + "my_model", + dataset_name="my_dataset", + include_column_spec_names=["a"], + exclude_column_spec_names=["b"], + train_budget_milli_node_hours=1000, + ) + client.auto_ml_client.get_dataset.assert_not_called() + client.auto_ml_client.create_model.assert_not_called() + + def test_predict_from_array(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec = mock.Mock(display_name="a", data_type=data_type) + model_metadata = mock.Mock(input_feature_column_specs=[column_spec]) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict(["1"], model_name="my_model") + client.prediction_client.predict.assert_called_with( + "my_model", {"row": {"values": [{"string_value": "1"}]}} + ) + + def test_predict_from_dict(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec_a = mock.Mock(display_name="a", data_type=data_type) + column_spec_b = mock.Mock(display_name="b", data_type=data_type) + model_metadata = mock.Mock( + input_feature_column_specs=[column_spec_a, column_spec_b] + ) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict({"a": "1", "b": "2"}, model_name="my_model") + client.prediction_client.predict.assert_called_with( + "my_model", + {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}}, + ) + + def test_predict_from_dict_missing(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec_a = mock.Mock(display_name="a", data_type=data_type) + column_spec_b = mock.Mock(display_name="b", data_type=data_type) + model_metadata = mock.Mock( + input_feature_column_specs=[column_spec_a, column_spec_b] + ) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict({"a": "1"}, model_name="my_model") + client.prediction_client.predict.assert_called_with( + "my_model", {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}} + ) + + def test_predict_all_types(self): + float_type = mock.Mock(type_code=data_types_pb2.FLOAT64) + timestamp_type = mock.Mock(type_code=data_types_pb2.TIMESTAMP) + string_type = mock.Mock(type_code=data_types_pb2.STRING) + array_type = mock.Mock(type_code=data_types_pb2.ARRAY) + struct_type = mock.Mock(type_code=data_types_pb2.STRUCT) + category_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec_float = mock.Mock(display_name="float", data_type=float_type) + column_spec_timestamp = mock.Mock( + display_name="timestamp", data_type=timestamp_type + ) + column_spec_string = mock.Mock(display_name="string", data_type=string_type) + column_spec_array = mock.Mock(display_name="array", data_type=array_type) + column_spec_struct = mock.Mock(display_name="struct", data_type=struct_type) + column_spec_category = mock.Mock( + display_name="category", data_type=category_type + ) + column_spec_null = mock.Mock(display_name="null", data_type=category_type) + model_metadata = mock.Mock( + input_feature_column_specs=[ + column_spec_float, + column_spec_timestamp, + column_spec_string, + column_spec_array, + column_spec_struct, + column_spec_category, + column_spec_null, + ] + ) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + client.predict( + { + "float": 1.0, + "timestamp": "EST", + "string": "text", + "array": [1], + "struct": {"a": "b"}, + "category": "a", + "null": None, + }, + model_name="my_model", + ) + client.prediction_client.predict.assert_called_with( + "my_model", + { + "row": { + "values": [ + {"number_value": 1.0}, + {"string_value": "EST"}, + {"string_value": "text"}, + {"list_value": [1]}, + {"struct_value": {"a": "b"}}, + {"string_value": "a"}, + {"null_value": 0}, + ] + } + }, + ) + + def test_predict_from_array_missing(self): + data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) + column_spec = mock.Mock(display_name="a", data_type=data_type) + model_metadata = mock.Mock(input_feature_column_specs=[column_spec]) + model = mock.Mock() + model.configure_mock(tables_model_metadata=model_metadata, name="my_model") + client = self.tables_client({"get_model.return_value": model}, {}) + with pytest.raises(ValueError): + client.predict([], model_name="my_model") + client.prediction_client.predict.assert_not_called() + + def test_batch_predict_gcs(self): + client = self.tables_client({}, {}) + client.batch_predict( + model_name="my_model", + gcs_input_uris="gs://input", + gcs_output_uri_prefix="gs://output", + ) + client.prediction_client.batch_predict.assert_called_with( + "my_model", + {"gcs_source": {"input_uris": ["gs://input"]}}, + {"gcs_destination": {"output_uri_prefix": "gs://output"}}, + ) + + def test_batch_predict_bigquery(self): + client = self.tables_client({}, {}) + client.batch_predict( + model_name="my_model", + bigquery_input_uri="bq://input", + bigquery_output_uri="bq://output", + ) + client.prediction_client.batch_predict.assert_called_with( + "my_model", + {"bigquery_source": {"input_uri": "bq://input"}}, + {"bigquery_destination": {"output_uri": "bq://output"}}, + ) + + def test_batch_predict_mixed(self): + client = self.tables_client({}, {}) + client.batch_predict( + model_name="my_model", + gcs_input_uris="gs://input", + bigquery_output_uri="bq://output", + ) + client.prediction_client.batch_predict.assert_called_with( + "my_model", + {"gcs_source": {"input_uris": ["gs://input"]}}, + {"bigquery_destination": {"output_uri": "bq://output"}}, + ) + + def test_batch_predict_missing_input_gcs_uri(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.batch_predict( + model_name="my_model", + gcs_input_uris=None, + gcs_output_uri_prefix="gs://output", + ) + client.prediction_client.batch_predict.assert_not_called() + + def test_batch_predict_missing_input_bigquery_uri(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.batch_predict( + model_name="my_model", + bigquery_input_uri=None, + gcs_output_uri_prefix="gs://output", + ) + client.prediction_client.batch_predict.assert_not_called() + + def test_batch_predict_missing_output_gcs_uri(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.batch_predict( + model_name="my_model", + gcs_input_uris="gs://input", + gcs_output_uri_prefix=None, + ) + client.prediction_client.batch_predict.assert_not_called() + + def test_batch_predict_missing_output_bigquery_uri(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.batch_predict( + model_name="my_model", + gcs_input_uris="gs://input", + bigquery_output_uri=None, + ) + client.prediction_client.batch_predict.assert_not_called() + + def test_batch_predict_missing_model(self): + client = self.tables_client({"list_models.return_value": []}, {}) + with pytest.raises(exceptions.NotFound): + client.batch_predict( + model_display_name="my_model", + gcs_input_uris="gs://input", + gcs_output_uri_prefix="gs://output", + ) + client.prediction_client.batch_predict.assert_not_called() + + def test_batch_predict_no_model(self): + client = self.tables_client({}, {}) + with pytest.raises(ValueError): + client.batch_predict( + gcs_input_uris="gs://input", gcs_output_uri_prefix="gs://output" + ) + client.auto_ml_client.list_models.assert_not_called() + client.prediction_client.batch_predict.assert_not_called()