diff --git a/samples/snippets/batch_predict.py b/samples/snippets/batch_predict.py new file mode 100644 index 00000000..efe484f4 --- /dev/null +++ b/samples/snippets/batch_predict.py @@ -0,0 +1,52 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def batch_predict(project_id, model_id, input_uri, output_uri): + """Batch predict""" + # [START automl_batch_predict] + from google.cloud import automl + + # TODO(developer): Uncomment and set the following variables + # project_id = "YOUR_PROJECT_ID" + # model_id = "YOUR_MODEL_ID" + # input_uri = "gs://YOUR_BUCKET_ID/path/to/your/input/csv_or_jsonl" + # output_uri = "gs://YOUR_BUCKET_ID/path/to/save/results/" + + prediction_client = automl.PredictionServiceClient() + + # Get the full path of the model. + model_full_id = prediction_client.model_path( + project_id, "us-central1", model_id + ) + + gcs_source = automl.types.GcsSource(input_uris=[input_uri]) + + input_config = automl.types.BatchPredictInputConfig(gcs_source=gcs_source) + gcs_destination = automl.types.GcsDestination(output_uri_prefix=output_uri) + output_config = automl.types.BatchPredictOutputConfig( + gcs_destination=gcs_destination + ) + + response = prediction_client.batch_predict( + model_full_id, input_config, output_config + ) + + print("Waiting for operation to complete...") + print( + "Batch Prediction results saved to Cloud Storage bucket. {}".format( + response.result() + ) + ) + # [END automl_batch_predict] diff --git a/samples/snippets/batch_predict_test.py b/samples/snippets/batch_predict_test.py new file mode 100644 index 00000000..0f9417f6 --- /dev/null +++ b/samples/snippets/batch_predict_test.py @@ -0,0 +1,47 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific ladnguage governing permissions and +# limitations under the License. + +import datetime +import os + +import batch_predict + +PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] +BUCKET_ID = "{}-lcm".format(PROJECT_ID) +MODEL_ID = "TEN0000000000000000000" +PREFIX = "TEST_EXPORT_OUTPUT_" + datetime.datetime.now().strftime( + "%Y%m%d%H%M%S" +) + + +def test_batch_predict(capsys): + # As batch prediction can take a long time. Try to batch predict on a model + # and confirm that the model was not found, but other elements of the + # request were valid. + try: + input_uri = "gs://{}/entity-extraction/input.jsonl".format(BUCKET_ID) + output_uri = "gs://{}/{}/".format(BUCKET_ID, PREFIX) + batch_predict.batch_predict( + PROJECT_ID, MODEL_ID, input_uri, output_uri + ) + out, _ = capsys.readouterr() + assert ( + "The model is either not found or not supported for prediction yet" + in out + ) + except Exception as e: + assert ( + "The model is either not found or not supported for prediction yet" + in e.message + ) diff --git a/samples/snippets/delete_dataset_test.py b/samples/snippets/delete_dataset_test.py index 8a1057a6..6d204dde 100644 --- a/samples/snippets/delete_dataset_test.py +++ b/samples/snippets/delete_dataset_test.py @@ -25,7 +25,7 @@ @pytest.fixture(scope="function") -def create_dataset(): +def dataset_id(): client = automl.AutoMlClient() project_location = client.location_path(PROJECT_ID, "us-central1") display_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S") @@ -39,8 +39,8 @@ def create_dataset(): yield dataset_id -def test_delete_dataset(capsys, create_dataset): +def test_delete_dataset(capsys, dataset_id): # delete dataset - delete_dataset.delete_dataset(PROJECT_ID, create_dataset) + delete_dataset.delete_dataset(PROJECT_ID, dataset_id) out, _ = capsys.readouterr() assert "Dataset deleted." in out diff --git a/samples/snippets/get_model_evaluation_test.py b/samples/snippets/get_model_evaluation_test.py index 40a88a82..f3fe1b2b 100644 --- a/samples/snippets/get_model_evaluation_test.py +++ b/samples/snippets/get_model_evaluation_test.py @@ -24,7 +24,7 @@ @pytest.fixture(scope="function") -def get_evaluation_id(): +def model_evaluation_id(): client = automl.AutoMlClient() model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID) evaluation = None @@ -37,9 +37,9 @@ def get_evaluation_id(): yield model_evaluation_id -def test_get_model_evaluation(capsys, get_evaluation_id): +def test_get_model_evaluation(capsys, model_evaluation_id): get_model_evaluation.get_model_evaluation( - PROJECT_ID, MODEL_ID, get_evaluation_id + PROJECT_ID, MODEL_ID, model_evaluation_id ) out, _ = capsys.readouterr() assert "Model evaluation name: " in out diff --git a/samples/snippets/get_operation_status.py b/samples/snippets/get_operation_status.py new file mode 100644 index 00000000..4e5c90f8 --- /dev/null +++ b/samples/snippets/get_operation_status.py @@ -0,0 +1,34 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def get_operation_status(operation_full_id): + """Get operation status.""" + # [START automl_get_operation_status] + from google.cloud import automl + + # TODO(developer): Uncomment and set the following variables + # operation_full_id = \ + # "projects/[projectId]/locations/us-central1/operations/[operationId]" + + client = automl.AutoMlClient() + # Get the latest state of a long-running operation. + response = client.transport._operations_client.get_operation( + operation_full_id + ) + + print("Name: {}".format(response.name)) + print("Operation details:") + print(response) + # [END automl_get_operation_status] diff --git a/samples/snippets/get_operation_status_test.py b/samples/snippets/get_operation_status_test.py new file mode 100644 index 00000000..c08095fc --- /dev/null +++ b/samples/snippets/get_operation_status_test.py @@ -0,0 +1,40 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.cloud import automl +import pytest + +import get_operation_status + +PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] + + +@pytest.fixture(scope="function") +def operation_id(): + client = automl.AutoMlClient() + project_location = client.location_path(PROJECT_ID, "us-central1") + generator = client.transport._operations_client.list_operations( + project_location, filter_="" + ).pages + page = next(generator) + operation = page.next() + yield operation.name + + +def test_get_operation_status(capsys, operation_id): + get_operation_status.get_operation_status(operation_id) + out, _ = capsys.readouterr() + assert "Operation details" in out diff --git a/samples/snippets/import_dataset_test.py b/samples/snippets/import_dataset_test.py index 2064abbc..35d23edc 100644 --- a/samples/snippets/import_dataset_test.py +++ b/samples/snippets/import_dataset_test.py @@ -12,49 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime import os -from google.cloud import automl -import pytest - import import_dataset PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] BUCKET_ID = "{}-lcm".format(PROJECT_ID) - - -@pytest.fixture(scope="function") -def create_dataset(): - client = automl.AutoMlClient() - project_location = client.location_path(PROJECT_ID, "us-central1") - display_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S") - metadata = automl.types.TextSentimentDatasetMetadata( - sentiment_max=4 - ) - dataset = automl.types.Dataset( - display_name=display_name, text_sentiment_dataset_metadata=metadata - ) - response = client.create_dataset(project_location, dataset) - dataset_id = response.result().name.split("/")[-1] - - yield dataset_id - - -@pytest.mark.slow -def test_import_dataset(capsys, create_dataset): - data = ( - "gs://{}/sentiment-analysis/dataset.csv".format(BUCKET_ID) - ) - dataset_id = create_dataset - import_dataset.import_dataset(PROJECT_ID, dataset_id, data) - out, _ = capsys.readouterr() - assert "Data imported." in out - - # delete created dataset - client = automl.AutoMlClient() - dataset_full_id = client.dataset_path( - PROJECT_ID, "us-central1", dataset_id - ) - response = client.delete_dataset(dataset_full_id) - response.result() +DATASET_ID = "TEN0000000000000000000" + + +def test_import_dataset(capsys): + # As importing a dataset can take a long time and only four operations can + # be run on a dataset at once. Try to import into a nonexistent dataset and + # confirm that the dataset was not found, but other elements of the request + # were valid. + try: + data = "gs://{}/sentiment-analysis/dataset.csv".format(BUCKET_ID) + import_dataset.import_dataset(PROJECT_ID, DATASET_ID, data) + out, _ = capsys.readouterr() + assert ( + "The Dataset doesn't exist or is inaccessible for use with AutoMl." + in out + ) + except Exception as e: + assert ( + "The Dataset doesn't exist or is inaccessible for use with AutoMl." + in e.message + ) diff --git a/samples/snippets/language_sentiment_analysis_predict_test.py b/samples/snippets/language_sentiment_analysis_predict_test.py index 63a37264..bfd35649 100644 --- a/samples/snippets/language_sentiment_analysis_predict_test.py +++ b/samples/snippets/language_sentiment_analysis_predict_test.py @@ -23,8 +23,9 @@ MODEL_ID = os.environ["SENTIMENT_ANALYSIS_MODEL_ID"] -@pytest.fixture(scope="function") -def verify_model_state(): +@pytest.fixture(scope="function", autouse=True) +def setup(): + # Verify the model is deployed before trying to predict client = automl.AutoMlClient() model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID) @@ -35,8 +36,7 @@ def verify_model_state(): response.result() -def test_sentiment_analysis_predict(capsys, verify_model_state): - verify_model_state +def test_sentiment_analysis_predict(capsys): text = "Hopefully this Claritin kicks in soon" language_sentiment_analysis_predict.predict(PROJECT_ID, MODEL_ID, text) out, _ = capsys.readouterr() diff --git a/samples/snippets/list_operation_status.py b/samples/snippets/list_operation_status.py new file mode 100644 index 00000000..45534fda --- /dev/null +++ b/samples/snippets/list_operation_status.py @@ -0,0 +1,37 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def list_operation_status(project_id): + """List operation status.""" + # [START automl_list_operation_status] + from google.cloud import automl + + # TODO(developer): Uncomment and set the following variables + # project_id = "YOUR_PROJECT_ID" + + client = automl.AutoMlClient() + # A resource that represents Google Cloud Platform location. + project_location = client.location_path(project_id, "us-central1") + # List all the operations names available in the region. + response = client.transport._operations_client.list_operations( + project_location, "" + ) + + print("List of operations:") + for operation in response: + print("Name: {}".format(operation.name)) + print("Operation details:") + print(operation) + # [END automl_list_operation_status] diff --git a/samples/snippets/list_operation_status_test.py b/samples/snippets/list_operation_status_test.py new file mode 100644 index 00000000..ff6a0973 --- /dev/null +++ b/samples/snippets/list_operation_status_test.py @@ -0,0 +1,28 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +import list_operation_status + +PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] + + +@pytest.mark.slow +def test_list_operation_status(capsys): + list_operation_status.list_operation_status(PROJECT_ID) + out, _ = capsys.readouterr() + assert "Operation details" in out diff --git a/samples/snippets/translate_predict_test.py b/samples/snippets/translate_predict_test.py index 7dbdb4ba..cd31d98b 100644 --- a/samples/snippets/translate_predict_test.py +++ b/samples/snippets/translate_predict_test.py @@ -23,8 +23,9 @@ MODEL_ID = os.environ["TRANSLATION_MODEL_ID"] -@pytest.fixture(scope="function") -def verify_model_state(): +@pytest.fixture(scope="function", autouse=True) +def setup(): + # Verify the model is deployed before trying to predict client = automl.AutoMlClient() model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID) @@ -35,8 +36,7 @@ def verify_model_state(): response.result() -def test_translate_predict(capsys, verify_model_state): - verify_model_state +def test_translate_predict(capsys): translate_predict.predict(PROJECT_ID, MODEL_ID, "resources/input.txt") out, _ = capsys.readouterr() assert "Translated content: " in out diff --git a/samples/snippets/vision_classification_predict_test.py b/samples/snippets/vision_classification_predict_test.py index 9df9c911..bc91796a 100644 --- a/samples/snippets/vision_classification_predict_test.py +++ b/samples/snippets/vision_classification_predict_test.py @@ -23,8 +23,9 @@ MODEL_ID = os.environ["VISION_CLASSIFICATION_MODEL_ID"] -@pytest.fixture(scope="function") -def verify_model_state(): +@pytest.fixture(scope="function", autouse=True) +def setup(): + # Verify the model is deployed before tyring to predict client = automl.AutoMlClient() model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID) @@ -35,8 +36,7 @@ def verify_model_state(): response.result() -def test_vision_classification_predict(capsys, verify_model_state): - verify_model_state +def test_vision_classification_predict(capsys): file_path = "resources/test.png" vision_classification_predict.predict(PROJECT_ID, MODEL_ID, file_path) out, _ = capsys.readouterr()