-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
…oogleCloudPlatform/python-docs-samples#2614) * automl: add vision object detection samples for atuoml ga * Update tests * update test resource file used * Consistently use double quotes * Move test imports to top of file * license year 2020 * Use centralized testing project for automl, improve comment with links to docs Co-authored-by: Leah E. Cole <6719667+leahecole@users.noreply.github.com>
- Loading branch information
1 parent
1bd8847
commit ad7c5c6
Showing
9 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
43 changes: 43 additions & 0 deletions
43
packages/google-cloud-automl/samples/snippets/vision_object_detection_create_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def create_dataset(project_id, display_name): | ||
"""Create a dataset.""" | ||
# [START automl_vision_object_detection_create_dataset] | ||
from google.cloud import automl | ||
|
||
# TODO(developer): Uncomment and set the following variables | ||
# project_id = "YOUR_PROJECT_ID" | ||
# display_name = "your_datasets_display_name" | ||
|
||
client = automl.AutoMlClient() | ||
|
||
# A resource that represents Google Cloud Platform location. | ||
project_location = client.location_path(project_id, "us-central1") | ||
metadata = automl.types.ImageObjectDetectionDatasetMetadata() | ||
dataset = automl.types.Dataset( | ||
display_name=display_name, | ||
image_object_detection_dataset_metadata=metadata, | ||
) | ||
|
||
# Create a dataset with the dataset metadata in the region. | ||
response = client.create_dataset(project_location, dataset) | ||
|
||
created_dataset = response.result() | ||
|
||
# Display the dataset information | ||
print("Dataset name: {}".format(created_dataset.name)) | ||
print("Dataset id: {}".format(created_dataset.name.split("/")[-1])) | ||
# [END automl_vision_object_detection_create_dataset] |
44 changes: 44 additions & 0 deletions
44
packages/google-cloud-automl/samples/snippets/vision_object_detection_create_dataset_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import datetime | ||
import os | ||
|
||
from google.cloud import automl | ||
import pytest | ||
|
||
import vision_object_detection_create_dataset | ||
|
||
|
||
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] | ||
|
||
|
||
@pytest.mark.slow | ||
def test_vision_object_detection_create_dataset(capsys): | ||
# create dataset | ||
dataset_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S") | ||
vision_object_detection_create_dataset.create_dataset( | ||
PROJECT_ID, dataset_name | ||
) | ||
out, _ = capsys.readouterr() | ||
assert "Dataset id: " in out | ||
|
||
# Delete the created dataset | ||
dataset_id = out.splitlines()[1].split()[2] | ||
client = automl.AutoMlClient() | ||
dataset_full_id = client.dataset_path( | ||
PROJECT_ID, "us-central1", dataset_id | ||
) | ||
response = client.delete_dataset(dataset_full_id) | ||
response.result() |
48 changes: 48 additions & 0 deletions
48
packages/google-cloud-automl/samples/snippets/vision_object_detection_create_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def create_model(project_id, dataset_id, display_name): | ||
"""Create a model.""" | ||
# [START automl_vision_object_detection_create_model] | ||
from google.cloud import automl | ||
|
||
# TODO(developer): Uncomment and set the following variables | ||
# project_id = "YOUR_PROJECT_ID" | ||
# dataset_id = "YOUR_DATASET_ID" | ||
# display_name = "your_models_display_name" | ||
|
||
client = automl.AutoMlClient() | ||
|
||
# A resource that represents Google Cloud Platform location. | ||
project_location = client.location_path(project_id, "us-central1") | ||
# Leave model unset to use the default base model provided by Google | ||
# train_budget_milli_node_hours: The actual train_cost will be equal or | ||
# less than this value. | ||
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageobjectdetectionmodelmetadata | ||
metadata = automl.types.ImageObjectDetectionModelMetadata( | ||
train_budget_milli_node_hours=24000 | ||
) | ||
model = automl.types.Model( | ||
display_name=display_name, | ||
dataset_id=dataset_id, | ||
image_object_detection_model_metadata=metadata, | ||
) | ||
|
||
# Create a model with the model metadata in the region. | ||
response = client.create_model(project_location, model) | ||
|
||
print("Training operation name: {}".format(response.operation.name)) | ||
print("Training started...") | ||
# [END automl_vision_object_detection_create_model] |
37 changes: 37 additions & 0 deletions
37
packages/google-cloud-automl/samples/snippets/vision_object_detection_create_model_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
from google.cloud import automl | ||
import pytest | ||
|
||
import vision_object_detection_create_model | ||
|
||
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] | ||
DATASET_ID = os.environ["OBJECT_DETECTION_DATASET_ID"] | ||
|
||
|
||
@pytest.mark.slow | ||
def test_vision_object_detection_create_model(capsys): | ||
vision_object_detection_create_model.create_model( | ||
PROJECT_ID, DATASET_ID, "object_test_create_model" | ||
) | ||
out, _ = capsys.readouterr() | ||
assert "Training started" in out | ||
|
||
# Cancel the operation | ||
operation_id = out.split("Training operation name: ")[1].split("\n")[0] | ||
client = automl.AutoMlClient() | ||
client.transport._operations_client.cancel_operation(operation_id) |
40 changes: 40 additions & 0 deletions
40
...s/google-cloud-automl/samples/snippets/vision_object_detection_deploy_model_node_count.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def deploy_model(project_id, model_id): | ||
"""Deploy a model with a specified node count.""" | ||
# [START automl_vision_object_detection_deploy_model_node_count] | ||
from google.cloud import automl | ||
|
||
# TODO(developer): Uncomment and set the following variables | ||
# project_id = "YOUR_PROJECT_ID" | ||
# model_id = "YOUR_MODEL_ID" | ||
|
||
client = automl.AutoMlClient() | ||
# Get the full path of the model. | ||
model_full_id = client.model_path(project_id, "us-central1", model_id) | ||
|
||
# node count determines the number of nodes to deploy the model on. | ||
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageobjectdetectionmodeldeploymentmetadata | ||
metadata = automl.types.ImageObjectDetectionModelDeploymentMetadata( | ||
node_count=2 | ||
) | ||
response = client.deploy_model( | ||
model_full_id, | ||
image_object_detection_model_deployment_metadata=metadata, | ||
) | ||
|
||
print("Model deployment finished. {}".format(response.result())) | ||
# [END automl_vision_object_detection_deploy_model_node_count] |
37 changes: 37 additions & 0 deletions
37
...gle-cloud-automl/samples/snippets/vision_object_detection_deploy_model_node_count_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
import pytest | ||
|
||
import vision_object_detection_deploy_model_node_count | ||
|
||
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] | ||
MODEL_ID = "0000000000000000000000" | ||
|
||
|
||
@pytest.mark.slow | ||
def test_object_detection_deploy_model_with_node_count(capsys): | ||
# As model deployment can take a long time, instead try to deploy a | ||
# nonexistent model and confirm that the model was not found, but other | ||
# elements of the request were valid. | ||
try: | ||
vision_object_detection_deploy_model_node_count.deploy_model( | ||
PROJECT_ID, MODEL_ID | ||
) | ||
out, _ = capsys.readouterr() | ||
assert "The model does not exist" in out | ||
except Exception as e: | ||
assert "The model does not exist" in e.message |
58 changes: 58 additions & 0 deletions
58
packages/google-cloud-automl/samples/snippets/vision_object_detection_predict.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
def predict(project_id, model_id, file_path): | ||
"""Predict.""" | ||
# [START automl_vision_object_detection_predict] | ||
from google.cloud import automl | ||
|
||
# TODO(developer): Uncomment and set the following variables | ||
# project_id = "YOUR_PROJECT_ID" | ||
# model_id = "YOUR_MODEL_ID" | ||
# file_path = "path_to_local_file.jpg" | ||
|
||
prediction_client = automl.PredictionServiceClient() | ||
|
||
# Get the full path of the model. | ||
model_full_id = prediction_client.model_path( | ||
project_id, "us-central1", model_id | ||
) | ||
|
||
# Read the file. | ||
with open(file_path, "rb") as content_file: | ||
content = content_file.read() | ||
|
||
image = automl.types.Image(image_bytes=content) | ||
payload = automl.types.ExamplePayload(image=image) | ||
|
||
# params is additional domain-specific parameters. | ||
# score_threshold is used to filter the result | ||
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#predictrequest | ||
params = {"score_threshold": "0.8"} | ||
|
||
response = prediction_client.predict(model_full_id, payload, params) | ||
print("Prediction results:") | ||
for result in response.payload: | ||
print("Predicted class name: {}".format(result.display_name)) | ||
print( | ||
"Predicted class score: {}".format( | ||
result.image_object_detection.score | ||
) | ||
) | ||
bounding_box = result.image_object_detection.bounding_box | ||
print("Normalized Vertices:") | ||
for vertex in bounding_box.normalized_vertices: | ||
print("\tX: {}, Y: {}".format(vertex.x, vertex.y)) | ||
# [END automl_vision_object_detection_predict] |
43 changes: 43 additions & 0 deletions
43
packages/google-cloud-automl/samples/snippets/vision_object_detection_predict_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
from google.cloud import automl | ||
import pytest | ||
|
||
import vision_object_detection_predict | ||
|
||
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"] | ||
MODEL_ID = os.environ["OBJECT_DETECTION_MODEL_ID"] | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def verify_model_state(): | ||
client = automl.AutoMlClient() | ||
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID) | ||
|
||
model = client.get_model(model_full_id) | ||
if model.deployment_state == automl.enums.Model.DeploymentState.UNDEPLOYED: | ||
# Deploy model if it is not deployed | ||
response = client.deploy_model(model_full_id) | ||
response.result() | ||
|
||
|
||
def test_vision_object_detection_predict(capsys, verify_model_state): | ||
verify_model_state | ||
file_path = "resources/salad.jpg" | ||
vision_object_detection_predict.predict(PROJECT_ID, MODEL_ID, file_path) | ||
out, _ = capsys.readouterr() | ||
assert "Predicted class name:" in out |