diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py
index 1dd7cb03bafbf..29846b320ba03 100644
--- a/airflow/providers/google/cloud/hooks/automl.py
+++ b/airflow/providers/google/cloud/hooks/automl.py
@@ -15,166 +15,511 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
+"""
+This module contains a Google AutoML hook.
-import warnings
+.. spelling:word-list::
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+    PredictResponse
+"""
+from __future__ import annotations
-class CloudAutoMLHook:
- """
- Former Google Cloud AutoML hook.
+from functools import cached_property
+from typing import TYPE_CHECKING, Sequence
+
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.automl_v1beta1 import (
+ AutoMlClient,
+ BatchPredictInputConfig,
+ BatchPredictOutputConfig,
+ Dataset,
+ ExamplePayload,
+ InputConfig,
+ Model,
+ PredictionServiceClient,
+ PredictResponse,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+if TYPE_CHECKING:
+ from google.api_core.operation import Operation
+ from google.api_core.retry import Retry
+ from google.cloud.automl_v1beta1.services.auto_ml.pagers import (
+ ListDatasetsPager,
+ )
+ from google.protobuf.field_mask_pb2 import FieldMask
- Deprecated as AutoML API becomes unusable starting March 31, 2024:
- https://cloud.google.com/automl/docs
- """
- deprecation_warning = (
- "CloudAutoMLHook has been deprecated, as AutoML API becomes unusable starting "
- "March 31, 2024, and will be removed in future release. Please use an equivalent "
- " Vertex AI hook available in"
- "airflow.providers.google.cloud.hooks.vertex_ai instead."
- )
+class CloudAutoMLHook(GoogleBaseHook):
+ """
+ Google Cloud AutoML hook.
- method_exception = "This method cannot be used as AutoML API becomes unusable."
+ All the methods in the hook where project_id is used must be called with
+ keyword arguments rather than positional.
+ """
- def __init__(self, **_) -> None:
- warnings.warn(self.deprecation_warning, AirflowProviderDeprecationWarning)
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ if kwargs.get("delegate_to") is not None:
+            raise RuntimeError(
+                "The `delegate_to` parameter has been deprecated and finally removed in this version"
+                " of the Google provider. You MUST convert it to `impersonation_chain`."
+            )
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ )
+ self._client: AutoMlClient | None = None
@staticmethod
def extract_object_id(obj: dict) -> str:
"""Return unique id of the object."""
- warnings.warn(
- "'extract_object_id' method is deprecated and will be removed in future release.",
- AirflowProviderDeprecationWarning,
- )
return obj["name"].rpartition("/")[-1]
- def get_conn(self):
- """
- Retrieve connection to AutoML (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def wait_for_operation(self, **_):
- """
- Wait for long-lasting operation to complete (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def prediction_client(self, **_):
- """
- Create a PredictionServiceClient (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def create_model(self, **_):
- """
- Create a model_id and returns a Model in the `response` field when it completes (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def batch_predict(self, **_):
- """
- Perform a batch prediction (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def predict(self, **_):
- """
- Perform an online prediction (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def create_dataset(self, **_):
- """
- Create a dataset (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def import_data(self, **_):
- """
- Import data (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def list_column_specs(self, **_):
- """
- List column specs (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def get_model(self, **_):
- """
- Get a model (deprecated).
-
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
-
- def delete_model(self, **_):
- """
- Delete a model (deprecated).
+ def get_conn(self) -> AutoMlClient:
+ """
+ Retrieve connection to AutoML.
+
+ :return: Google Cloud AutoML client object.
+ """
+ if self._client is None:
+ self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+ return self._client
+
+    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
+        """Wait for long-running operation to complete."""
+ try:
+ return operation.result(timeout=timeout)
+ except Exception:
+ error = operation.exception(timeout=timeout)
+ raise AirflowException(error)
+
+ @cached_property
+ def prediction_client(self) -> PredictionServiceClient:
+ """
+        Create a PredictionServiceClient.
+
+ :return: Google Cloud AutoML PredictionServiceClient client object.
+ """
+ return PredictionServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_model(
+ self,
+ model: dict | Model,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ retry: Retry | _MethodDefault = DEFAULT,
+ ) -> Operation:
+ """
+        Create a model and return a Model in the `response` field when it completes.
+
+ When you create a model, several model evaluations are created for it:
+ a global evaluation, and one evaluation for each annotation spec.
+
+        :param model: The model to create. If a dict is provided, it must be of the same form
+            as the protobuf message `google.cloud.automl_v1beta1.types.Model`.
+        :param project_id: ID of the Google Cloud project where the model will be created.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests
+ will not be retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete.
+ Note that if `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}"
+ return client.create_model(
+ request={"parent": parent, "model": model},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
+ @GoogleBaseHook.fallback_to_default_project_id
+ def batch_predict(
+ self,
+ model_id: str,
+ input_config: dict | BatchPredictInputConfig,
+ output_config: dict | BatchPredictOutputConfig,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ params: dict[str, str] | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Perform a batch prediction and return a long-running operation object.
+
+        Unlike the online `Predict`, the batch prediction result won't be immediately
+        available in the response. Instead, a long-running operation object is returned.
+
+        :param model_id: Name of the model requested to serve the batch prediction.
+ :param input_config: Required. The input configuration for batch prediction.
+ If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictInputConfig`
+ :param output_config: Required. The Configuration specifying where output predictions should be
+ written. If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig`
+ :param params: Additional domain-specific parameters for the predictions, any string must be up to
+ 25000 characters long.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
+ """
+ client = self.prediction_client
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.batch_predict(
+ request={
+ "name": name,
+ "input_config": input_config,
+ "output_config": output_config,
+ "params": params,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def predict(
+ self,
+ model_id: str,
+ payload: dict | ExamplePayload,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ params: dict[str, str] | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> PredictResponse:
+ """
+        Perform an online prediction and return the prediction result in the response.
+
+        :param model_id: Name of the model requested to serve the prediction.
+        :param payload: Required. Payload to perform a prediction on. The payload must match the problem type
+            that the model was trained to solve. If a dict is provided, it must be of
+            the same form as the protobuf message `google.cloud.automl_v1beta1.types.ExamplePayload`.
+        :param params: Additional domain-specific parameters, any string must be up to 25000 characters long.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
+ """
+ client = self.prediction_client
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.predict(
+ request={"name": name, "payload": payload, "params": params},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_dataset(
+ self,
+ dataset: dict | Dataset,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Dataset:
+ """
+ Create a dataset.
+
+ :param dataset: The dataset to create. If a dict is provided, it must be of the
+ same form as the protobuf message Dataset.
+        :param project_id: ID of the Google Cloud project where the dataset will be created.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}"
+ result = client.create_dataset(
+ request={"parent": parent, "dataset": dataset},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def import_data(
+ self,
+ dataset_id: str,
+ location: str,
+ input_config: dict | InputConfig,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+ Import data into a dataset. For Tables this method can only be called on an empty Dataset.
+
+ :param dataset_id: Name of the AutoML dataset.
+ :param input_config: The desired input location and its domain specific semantics, if any.
+ If a dict is provided, it must be of the same form as the protobuf message InputConfig.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+ result = client.import_data(
+ request={"name": name, "input_config": input_config},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
- def update_dataset(self, **_):
+ def list_column_specs(self, **kwargs) -> None:
"""
- Update a model (deprecated).
+ List column specs in a table spec (Deprecated).
:raises: AirflowException
"""
- raise AirflowException(self.method_exception)
-
- def deploy_model(self, **_):
- """
- Deploy a model (deprecated).
+        raise AirflowException(
+            "This method is deprecated as the corresponding API is no longer available. See: "
+            "https://cloud.google.com/automl/docs/reference/rest/v1beta1/projects.locations.datasets.tableSpecs.columnSpecs/list"
+        )
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
+ @GoogleBaseHook.fallback_to_default_project_id
+ def get_model(
+ self,
+ model_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Model:
+ """
+        Get an AutoML model.
+
+ :param model_id: Name of the model.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types.Model` instance.
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.get_model(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_model(
+ self,
+ model_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Delete an AutoML model.
+
+ :param model_id: Name of the model.
+        :param project_id: ID of the Google Cloud project where the model is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.delete_model(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ def update_dataset(
+ self,
+ dataset: dict | Dataset,
+ update_mask: dict | FieldMask | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Dataset:
+ """
+ Update a dataset.
+
+ :param dataset: The dataset which replaces the resource on the server.
+ If a dict is provided, it must be of the same form as the protobuf message Dataset.
+ :param update_mask: The update mask applies to the resource. If a dict is provided, it must
+ be of the same form as the protobuf message FieldMask.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
+ """
+ client = self.get_conn()
+ result = client.update_dataset(
+ request={"dataset": dataset, "update_mask": update_mask},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
- def list_table_specs(self, **_):
+ def deploy_model(self, **kwargs) -> None:
"""
- List table specs (deprecated).
+        Deploy a model (Deprecated).
:raises: AirflowException
"""
- raise AirflowException(self.method_exception)
+        raise AirflowException(
+            "This method is deprecated as the corresponding API is no longer available. See: "
+            "https://cloud.google.com/automl/docs/reference/rest/v1beta1/projects.locations.models/deploy"
+        )
- def list_datasets(self, **_):
+ def list_table_specs(self, **kwargs) -> None:
"""
- List datasets (deprecated).
+        List table specs in a dataset (Deprecated).
:raises: AirflowException
"""
- raise AirflowException(self.method_exception)
-
- def delete_dataset(self, **_):
- """
- Delete a dataset (deprecated).
+        raise AirflowException(
+            "This method is deprecated as the corresponding API is no longer available. See: "
+            "https://cloud.google.com/automl/docs/reference/rest/v1beta1/projects.locations.datasets.tableSpecs/list"
+        )
- :raises: AirflowException
- """
- raise AirflowException(self.method_exception)
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_datasets(
+ self,
+ location: str,
+ project_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListDatasetsPager:
+ """
+ List datasets in a project.
+
+        :param project_id: ID of the Google Cloud project where the datasets are located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: A `ListDatasetsPager` instance. By default, this
+            is an iterable of `google.cloud.automl_v1beta1.types.Dataset` instances.
+            This object can also be configured to iterate over the pages
+            of the response.
+ """
+ client = self.get_conn()
+ parent = f"projects/{project_id}/locations/{location}"
+ result = client.list_datasets(
+ request={"parent": parent},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_dataset(
+ self,
+ dataset_id: str,
+ location: str,
+ project_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+ Delete a dataset and all of its contents.
+
+ :param dataset_id: ID of dataset to be deleted.
+        :param project_id: ID of the Google Cloud project where the dataset is located.
+            If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
+ """
+ client = self.get_conn()
+ name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+ result = client.delete_dataset(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
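The restored hook follows the usual GoogleBaseHook pattern: `get_conn()` lazily builds an `AutoMlClient`, mutating calls return long-running operations, and `wait_for_operation` blocks on them. A minimal sketch of driving the hook directly; the connection ID, project, location, and GCS path below are placeholders, not values taken from this change:

```python
# Sketch only: connection ID, project, location, and bucket are placeholders.
from google.cloud.automl_v1beta1 import Dataset

from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

hook = CloudAutoMLHook(gcp_conn_id="google_cloud_default")

# create_dataset returns a Dataset proto; extract_object_id pulls the trailing
# ID out of the fully qualified resource name ("projects/.../datasets/<id>").
dataset = hook.create_dataset(
    dataset={
        "display_name": "test_translation_dataset",
        "translation_dataset_metadata": {
            "source_language_code": "en",
            "target_language_code": "es",
        },
    },
    location="us-central1",
    project_id="my-project",
)
dataset_id = hook.extract_object_id(Dataset.to_dict(dataset))

# import_data returns a long-running operation; wait_for_operation blocks until
# it finishes and wraps any failure in AirflowException.
operation = hook.import_data(
    dataset_id=dataset_id,
    location="us-central1",
    project_id="my-project",
    input_config={"gcs_source": {"input_uris": ["gs://my-bucket/en-es.csv"]}},
)
hook.wait_for_operation(operation=operation, timeout=300)
```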
diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py
index b57601c64906c..79561d5b48132 100644
--- a/airflow/providers/google/cloud/links/automl.py
+++ b/airflow/providers/google/cloud/links/automl.py
@@ -19,28 +19,13 @@
from __future__ import annotations
-import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
-from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.links.base import BaseGoogleLink
if TYPE_CHECKING:
from airflow.utils.context import Context
-
-def __getattr__(name: str) -> Any:
- warnings.warn(
- (
- "AutoML links module have been deprecated and will be removed in the next MAJOR release."
- " Please use equivalent Vertex AI links instead"
- ),
- AirflowProviderDeprecationWarning,
- stacklevel=2,
- )
- return getattr(__name__, name)
-
-
AUTOML_BASE_LINK = "https://console.cloud.google.com/automl-tables"
AUTOML_DATASET_LINK = (
AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/schemav2?project={project_id}"
diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py
index ca32994193fbb..1d0784e73a4a9 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -15,16 +15,1019 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""This module is deprecated. Please use `airflow.providers.google.cloud.vertex_ai.auto_ml` instead."""
+"""This module contains Google AutoML operators."""
from __future__ import annotations
+import ast
import warnings
+from typing import TYPE_CHECKING, Sequence, Tuple
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
+from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.vertex_ai.auto_ml` instead.",
- AirflowProviderDeprecationWarning,
- stacklevel=2,
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.links.automl import (
+ AutoMLDatasetLink,
+ AutoMLDatasetListLink,
+ AutoMLModelLink,
+ AutoMLModelPredictLink,
+ AutoMLModelTrainLink,
)
+from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+
+if TYPE_CHECKING:
+ from google.api_core.retry import Retry
+
+ from airflow.utils.context import Context
+
+MetaData = Sequence[Tuple[str, str]]
+
+
+class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
+ """
+ Creates Google Cloud AutoML model.
+
+    AutoMLTrainModelOperator for tables, video intelligence, vision, and natural language is
+    deprecated and can only be used for translation.
+    All the functionality of the legacy domains is available in the Vertex AI AutoML operators,
+    which should be used instead.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLTrainModelOperator`
+
+ :param model: Model definition.
+    :param project_id: ID of the Google Cloud project where the model will be created.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+
+    :raises: AirflowException: if the model type is legacy.
+ """
+
+ template_fields: Sequence[str] = (
+ "model",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (
+ AutoMLModelTrainLink(),
+ AutoMLModelLink(),
+ )
+
+ def __init__(
+ self,
+ *,
+ model: dict,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.model = model
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+        # Only AutoML Translation is supported; raise for legacy model types.
+ if "translation_model_metadata" not in self.model:
+            raise AirflowException(
+                "AutoMLTrainModelOperator for tables, video intelligence, vision and natural language"
+                " is deprecated and can only be used for translation. "
+                "All the functionality of the legacy domains is available on the Vertex AI platform. "
+                "Please use the equivalent Vertex AI AutoML operators."
+            )
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Creating model %s...", self.model["display_name"])
+ operation = hook.create_model(
+ model=self.model,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id)
+ operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ result = Model.to_dict(operation_result)
+ model_id = hook.extract_object_id(result)
+ self.log.info("Model is created, model_id: %s", model_id)
+
+ self.xcom_push(context, key="model_id", value=model_id)
+ if project_id:
+ AutoMLModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.model["dataset_id"] or "-",
+ model_id=model_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLPredictOperator(GoogleCloudBaseOperator):
+ """
+ Runs prediction operation on Google Cloud AutoML.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLPredictOperator`
+
+ :param model_id: Name of the model requested to serve the batch prediction.
+    :param payload: Payload on which to perform the prediction. It must match the
+        problem type that the model was trained to solve.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param operation_params: Additional domain-specific parameters for the predictions.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLModelPredictLink(),)
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ payload: dict,
+ operation_params: dict[str, str] | None = None,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.operation_params = operation_params # type: ignore
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.payload = payload
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ result = hook.predict(
+ model_id=self.model_id,
+ payload=self.payload,
+ location=self.location,
+ project_id=self.project_id,
+ params=self.operation_params,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelPredictLink.persist(
+ context=context,
+ task_instance=self,
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return PredictResponse.to_dict(result)
+
+
+class AutoMLBatchPredictOperator(GoogleCloudBaseOperator):
+ """
+ Perform a batch prediction on Google Cloud AutoML.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLBatchPredictOperator`
+
+    :param model_id: Name of the model requested to serve the batch prediction.
+ :param input_config: Required. The input configuration for batch prediction.
+ If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictInputConfig`
+ :param output_config: Required. The Configuration specifying where output predictions should be
+ written. If a dict is provided, it must be of the same form as the protobuf message
+ `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig`
+ :param prediction_params: Additional domain-specific parameters for the predictions,
+ any string must be up to 25000 characters long.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "input_config",
+ "output_config",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLModelPredictLink(),)
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ input_config: dict,
+ output_config: dict,
+ location: str,
+ project_id: str | None = None,
+ prediction_params: dict[str, str] | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.location = location
+ self.project_id = project_id
+ self.prediction_params = prediction_params
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.input_config = input_config
+ self.output_config = output_config
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+        self.log.info("Requesting batch prediction.")
+ operation = hook.batch_predict(
+ model_id=self.model_id,
+ input_config=self.input_config,
+ output_config=self.output_config,
+ project_id=self.project_id,
+ location=self.location,
+ params=self.prediction_params,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
+ result = BatchPredictResult.to_dict(operation_result)
+ self.log.info("Batch prediction is ready.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelPredictLink.persist(
+ context=context,
+ task_instance=self,
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator):
+ """
+ Creates a Google Cloud AutoML dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLCreateDatasetOperator`
+
+ :param dataset: The dataset to create. If a dict is provided, it must be of the
+ same form as the protobuf message Dataset.
+    :param project_id: ID of the Google Cloud project where the dataset will be created.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset: dict,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset = dataset
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Creating dataset %s...", self.dataset)
+ result = hook.create_dataset(
+ dataset=self.dataset,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result = Dataset.to_dict(result)
+ dataset_id = hook.extract_object_id(result)
+        self.log.info("Dataset creation completed. Dataset id: %s", dataset_id)
+
+ self.xcom_push(context, key="dataset_id", value=dataset_id)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=dataset_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class AutoMLImportDataOperator(GoogleCloudBaseOperator):
+ """
+ Imports data to a Google Cloud AutoML dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLImportDataOperator`
+
+    :param dataset_id: ID of the dataset to import data into.
+ :param input_config: The desired input location and its domain specific semantics, if any.
+ If a dict is provided, it must be of the same form as the protobuf message InputConfig.
+    :param project_id: ID of the Google Cloud project where the dataset is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "input_config",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str,
+ location: str,
+ input_config: dict,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset_id = dataset_id
+ self.input_config = input_config
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Importing data to dataset...")
+ operation = hook.import_data(
+ dataset_id=self.dataset_id,
+ input_config=self.input_config,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation(timeout=self.timeout, operation=operation)
+        self.log.info("Import completed.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
+
+
+class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator):
+ """
+ Lists column specs in a table (Deprecated).
+
+ :raises: AirflowException
+ """
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+        raise AirflowException(
+            "AutoMLTablesListColumnSpecsOperator is deprecated as the corresponding API is no longer"
+            " available. See: "
+            "https://cloud.google.com/automl/docs/reference/rest/v1beta1/projects.locations.datasets.tableSpecs.columnSpecs/list"
+        )
+
+
+class AutoMLUpdateDatasetOperator(GoogleCloudBaseOperator):
+ """
+ Updates a dataset.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLUpdateDatasetOperator`
+
+ :param dataset: The dataset which replaces the resource on the server.
+ If a dict is provided, it must be of the same form as the protobuf message Dataset.
+ :param update_mask: The update mask applies to the resource. If a dict is provided, it must
+ be of the same form as the protobuf message FieldMask.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset",
+ "update_mask",
+ "location",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset: dict,
+ location: str,
+ update_mask: dict | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.dataset = dataset
+ self.update_mask = update_mask
+ self.location = location
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Updating AutoML dataset %s.", self.dataset["name"])
+ result = hook.update_dataset(
+ dataset=self.dataset,
+ update_mask=self.update_mask,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Dataset updated.")
+ project_id = hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=hook.extract_object_id(self.dataset),
+ project_id=project_id,
+ )
+ return Dataset.to_dict(result)
+
+
+class AutoMLTablesUpdateDatasetOperator(AutoMLUpdateDatasetOperator):
+ """
+ Updates a dataset (Deprecated).
+
+ This operator has been renamed to AutoMLUpdateDatasetOperator.
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset",
+ "update_mask",
+ "location",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetLink(),)
+
+ def __init__(
+ self,
+ *,
+ dataset: dict,
+ location: str,
+ update_mask: dict | None = None,
+ **kwargs,
+ ) -> None:
+ warnings.warn(
+ "This operator is deprecated and has been renamed to AutoMLUpdateDatasetOperator",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(dataset=dataset, update_mask=update_mask, location=location, **kwargs)
+
+
+class AutoMLGetModelOperator(GoogleCloudBaseOperator):
+ """
+ Get Google Cloud AutoML model.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLGetModelOperator`
+
+    :param model_id: Name of the model to be retrieved.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ result = hook.get_model(
+ model_id=self.model_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ model = Model.to_dict(result)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=model["dataset_id"],
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return model
+
+
+class AutoMLDeleteModelOperator(GoogleCloudBaseOperator):
+ """
+ Delete Google Cloud AutoML model.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLDeleteModelOperator`
+
+    :param model_id: Name of the model to be deleted.
+    :param project_id: ID of the Google Cloud project where the model is located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.model_id = model_id
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ operation = hook.delete_model(
+ model_id=self.model_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation(timeout=self.timeout, operation=operation)
+        self.log.info("Deletion completed.")
+
+
+class AutoMLDeployModelOperator(GoogleCloudBaseOperator):
+ """Deploys a model (Deprecated)."""
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+        raise AirflowException(
+            "AutoMLDeployModelOperator is deprecated as the corresponding API is no longer available."
+            " See: https://cloud.google.com/automl/docs/reference/rest/v1beta1/projects.locations.models/deploy"
+        )
+
+
+class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator):
+ """Lists table specs in a dataset (Deprecated)."""
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+        raise AirflowException(
+            "AutoMLTablesListTableSpecsOperator is deprecated as the corresponding API is"
+            " no longer available. See: "
+            "https://cloud.google.com/automl/docs/reference/rest/v1beta1/projects.locations.datasets.tableSpecs/list"
+        )
+
+
+class AutoMLListDatasetOperator(GoogleCloudBaseOperator):
+ """
+    Lists AutoML datasets in a project.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLListDatasetOperator`
+
+    :param project_id: ID of the Google Cloud project where the datasets are located.
+        If None, the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (AutoMLDatasetListLink(),)
+
+ def __init__(
+ self,
+ *,
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Requesting datasets")
+ page_iterator = hook.list_datasets(
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result = [Dataset.to_dict(dataset) for dataset in page_iterator]
+ self.log.info("Datasets obtained.")
+
+ self.xcom_push(
+ context,
+ key="dataset_id_list",
+ value=[hook.extract_object_id(d) for d in result],
+ )
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id)
+ return result
+
+
+class AutoMLDeleteDatasetOperator(GoogleCloudBaseOperator):
+ """
+ Deletes a dataset and all of its contents.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AutoMLDeleteDatasetOperator`
+
+ :param dataset_id: The ID of a dataset to be deleted. Accepts a single dataset ID, a list of
+ dataset IDs, or a comma-separated string of dataset IDs.
+ :param project_id: ID of the Google Cloud project where the dataset is located. If set to None,
+ the default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "location",
+ "project_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ dataset_id: str | list[str],
+ location: str,
+ project_id: str | None = None,
+ metadata: MetaData = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.dataset_id = dataset_id
+ self.location = location
+ self.project_id = project_id
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ @staticmethod
+ def _parse_dataset_id(dataset_id: str | list[str]) -> list[str]:
+ if not isinstance(dataset_id, str):
+ return dataset_id
+ try:
+ # Handle a rendered template that contains a Python literal, e.g. "['id1', 'id2']".
+ return ast.literal_eval(dataset_id)
+ except (SyntaxError, ValueError):
+ # Fall back to a plain comma-separated string, e.g. "id1,id2".
+ return dataset_id.split(",")
+
+ def execute(self, context: Context):
+ hook = CloudAutoMLHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ dataset_id_list = self._parse_dataset_id(self.dataset_id)
+ for dataset_id in dataset_id_list:
+ self.log.info("Deleting dataset %s", dataset_id)
+ hook.delete_dataset(
+ dataset_id=dataset_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Dataset deleted.")
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index b70ff2705a0d7..748cfbb39cc28 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -108,6 +108,7 @@ dependencies:
- google-auth>=1.0.0
- google-auth-httplib2>=0.0.1
- google-cloud-aiplatform>=1.42.1
+ - google-cloud-automl>=2.12.0
- google-cloud-bigquery-datatransfer>=3.13.0
- google-cloud-bigtable>=2.17.0
- google-cloud-build>=3.22.0
@@ -201,6 +202,8 @@ integrations:
tags: [gmp]
- integration-name: Google AutoML
external-doc-url: https://cloud.google.com/automl/
+ how-to-guide:
+ - /docs/apache-airflow-providers-google/operators/cloud/automl.rst
logo: /integration-logos/gcp/Cloud-AutoML.png
tags: [gcp]
- integration-name: Google BigQuery Data Transfer Service
@@ -529,6 +532,9 @@ operators:
- integration-name: Google Cloud Common
python-modules:
- airflow.providers.google.cloud.operators.cloud_base
+ - integration-name: Google AutoML
+ python-modules:
+ - airflow.providers.google.cloud.operators.automl
- integration-name: Google BigQuery
python-modules:
- airflow.providers.google.cloud.operators.bigquery
@@ -1229,6 +1235,11 @@ extra-links:
- airflow.providers.google.cloud.links.cloud_build.CloudBuildListLink
- airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggersListLink
- airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggerDetailsLink
+ - airflow.providers.google.cloud.links.automl.AutoMLDatasetLink
+ - airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelTrainLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink
- airflow.providers.google.cloud.links.life_sciences.LifeSciencesLink
- airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsDetailsLink
- airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsListLink
diff --git a/docker_tests/test_prod_image.py b/docker_tests/test_prod_image.py
index b4e59052ee309..ab35c63bffa53 100644
--- a/docker_tests/test_prod_image.py
+++ b/docker_tests/test_prod_image.py
@@ -131,6 +131,7 @@ def test_pip_dependencies_conflict(self, default_docker_image):
"googleapiclient",
"google.auth",
"google_auth_httplib2",
+ "google.cloud.automl",
"google.cloud.bigquery_datatransfer",
"google.cloud.bigtable",
"google.cloud.container",
diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
new file mode 100644
index 0000000000000..3211340344c2e
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
@@ -0,0 +1,91 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Google Cloud AutoML Operators
+=======================================
+
+The `Google Cloud AutoML <https://cloud.google.com/automl/>`__
+makes the power of machine learning available to you even if you have limited knowledge
+of machine learning. You can use AutoML to build on Google's machine learning capabilities
+to create your own custom machine learning models that are tailored to your business needs,
+and then integrate those models into your applications and websites.
+
+As of March 31, 2024, GCP has shut down most of the AutoML domains (AutoML Tables, AutoML Video Intelligence,
+AutoML Vision, and AutoML Natural Language) in favor of equivalent domains of Vertex AI, and retained only AutoML Translation.
+All the functionality of the legacy features is available through the Vertex AI operators - please refer to
+:doc:`/operators/cloud/vertex_ai` for more information.
+Please avoid using the AutoML operators for domains other than AutoML Translation.
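+
+For example, ``AutoMLTrainModelOperator`` now only accepts a model specification for the
+translation domain, i.e. one that carries ``translation_model_metadata``. A minimal sketch of
+such a specification (the display name and dataset id below are hypothetical placeholders):
+
+.. code-block:: python
+
+    MODEL = {
+        "display_name": "my_translation_model",  # hypothetical model name
+        "dataset_id": "<translation-dataset-id>",  # id of an existing translation dataset
+        "translation_model_metadata": {},
+    }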
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+.. _howto/operator:CloudAutoMLDocuments:
+.. _howto/operator:AutoMLCreateDatasetOperator:
+.. _howto/operator:AutoMLImportDataOperator:
+.. _howto/operator:AutoMLUpdateDatasetOperator:
+
+Creating Datasets
+^^^^^^^^^^^^^^^^^
+
+To create a Google AutoML dataset you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`.
+The operator returns dataset id in :ref:`XCom <concepts:xcom>` under ``dataset_id`` key.
+After creating a dataset you can use it to import some data using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`.
+To update a dataset you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLUpdateDatasetOperator`.
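+
+A minimal sketch of this flow, assuming a hypothetical translation ``DATASET`` specification,
+a GCS ``IMPORT_INPUT_CONFIG`` and a ``GCP_AUTOML_LOCATION`` region defined elsewhere in the
+DAG file:
+
+.. code-block:: python
+
+    from airflow.models.xcom_arg import XComArg
+    from airflow.providers.google.cloud.operators.automl import (
+        AutoMLCreateDatasetOperator,
+        AutoMLImportDataOperator,
+    )
+
+    create_dataset = AutoMLCreateDatasetOperator(
+        task_id="create_dataset",
+        dataset=DATASET,
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    import_dataset = AutoMLImportDataOperator(
+        task_id="import_dataset",
+        # pulled from the XCom pushed by create_dataset; also sets the task dependency
+        dataset_id=XComArg(create_dataset, key="dataset_id"),
+        location=GCP_AUTOML_LOCATION,
+        input_config=IMPORT_INPUT_CONFIG,
+    )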
+
+
+.. _howto/operator:AutoMLTrainModelOperator:
+.. _howto/operator:AutoMLGetModelOperator:
+.. _howto/operator:AutoMLDeleteModelOperator:
+
+Operations On Models
+^^^^^^^^^^^^^^^^^^^^
+
+To create a Google AutoML model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator`.
+The operator will wait for the operation to complete. Additionally, the operator
+returns the id of the model in :ref:`XCom <concepts:xcom>` under the ``model_id`` key.
+
+To get an existing model one can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`.
+
+If you wish to delete a model you can use
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`.
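+
+A minimal sketch combining these operators, assuming a translation ``MODEL`` specification and
+a ``GCP_AUTOML_LOCATION`` region defined elsewhere in the DAG file:
+
+.. code-block:: python
+
+    from airflow.models.xcom_arg import XComArg
+    from airflow.providers.google.cloud.operators.automl import (
+        AutoMLDeleteModelOperator,
+        AutoMLGetModelOperator,
+        AutoMLTrainModelOperator,
+    )
+
+    create_model = AutoMLTrainModelOperator(
+        task_id="create_model",
+        model=MODEL,
+        location=GCP_AUTOML_LOCATION,
+    )
+    # id of the trained model, pushed to XCom by create_model
+    model_id = XComArg(create_model, key="model_id")
+
+    get_model = AutoMLGetModelOperator(
+        task_id="get_model",
+        model_id=model_id,
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    delete_model = AutoMLDeleteModelOperator(
+        task_id="delete_model",
+        model_id=model_id,
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    get_model >> delete_model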
+
+.. _howto/operator:AutoMLPredictOperator:
+.. _howto/operator:AutoMLBatchPredictOperator:
+.. _howto/operator:AutoMLListDatasetOperator:
+.. _howto/operator:AutoMLDeleteDatasetOperator:
+
+Listing And Deleting Datasets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can get a list of AutoML datasets using
+:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns a list
+of dataset ids in :ref:`XCom <concepts:xcom>` under the ``dataset_id_list`` key.
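+
+A short sketch of chaining the two, assuming a ``GCP_AUTOML_LOCATION`` region defined
+elsewhere. ``AutoMLDeleteDatasetOperator`` accepts a single dataset id, a list of ids, or a
+comma-separated string of ids, so the rendered XCom list below is parsed correctly:
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.automl import (
+        AutoMLDeleteDatasetOperator,
+        AutoMLListDatasetOperator,
+    )
+
+    list_datasets = AutoMLListDatasetOperator(
+        task_id="list_datasets",
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    delete_datasets = AutoMLDeleteDatasetOperator(
+        task_id="delete_datasets",
+        # renders to the string form of the id list pushed by list_datasets
+        dataset_id="{{ task_instance.xcom_pull('list_datasets', key='dataset_id_list') }}",
+        location=GCP_AUTOML_LOCATION,
+    )
+
+    list_datasets >> delete_datasets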
+
+Reference
+^^^^^^^^^
+
+For further information, look at:
+
+* `Client Library Documentation <https://googleapis.dev/python/automl/latest/index.html>`__
+* `Product Documentation <https://cloud.google.com/automl/docs>`__
diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
index 24c51430d351c..5409656e033e0 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -15,10 +15,10 @@
specific language governing permissions and limitations
under the License.
-Google Cloud VertexAI Operators
+Google Cloud Vertex AI Operators
=======================================
-The `Google Cloud VertexAI <https://cloud.google.com/vertex-ai>`__
+The `Google Cloud Vertex AI <https://cloud.google.com/vertex-ai>`__
brings AutoML and AI Platform together into a unified API, client library, and user
interface. AutoML lets you train models on image, tabular, text, and video datasets
without writing code, while training in AI Platform lets you run custom training code.
@@ -29,7 +29,7 @@ request predictions with Vertex AI.
Creating Datasets
^^^^^^^^^^^^^^^^^
-To create a Google VertexAI dataset you can use
+To create a Google Vertex AI dataset you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`.
The operator returns dataset id in :ref:`XCom <concepts:xcom>` under ``dataset_id`` key.
@@ -177,13 +177,45 @@ If you wish to delete a Custom Training Job you can use
Creating an AutoML Training Jobs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Before running a Vertex AI Operator for AutoML training jobs, please ensure that your data is correctly stored in Vertex AI
+datasets. To create a dataset and import data into it, please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
+and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
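+
+A minimal sketch of this preparation step, assuming hypothetical ``DATASET`` and
+``IMPORT_CONFIGS`` definitions plus ``REGION`` and ``PROJECT_ID`` placeholders (see the
+dataset operators section above for the full parameter list):
+
+.. code-block:: python
+
+    from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+        CreateDatasetOperator,
+        ImportDataOperator,
+    )
+
+    create_dataset = CreateDatasetOperator(
+        task_id="create_dataset",
+        dataset=DATASET,
+        region=REGION,
+        project_id=PROJECT_ID,
+    )
+
+    import_data = ImportDataOperator(
+        task_id="import_data",
+        # dataset id from the XCom pushed by create_dataset
+        dataset_id=create_dataset.output["dataset_id"],
+        import_configs=IMPORT_CONFIGS,
+        region=REGION,
+        project_id=PROJECT_ID,
+    )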
-To create a Google Vertex AI Auto ML training jobs you have five operators
+To create Google Vertex AI AutoML training jobs you can use one of the following operators:
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLForecastingTrainingJobOperator`
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`
+
+You can find an example of how to use ``CreateAutoMLImageTrainingJobOperator`` for AutoML image classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_vision_classification.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_cloud_create_image_classification_training_job_operator]
+ :end-before: [END howto_cloud_create_image_classification_training_job_operator]
+
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
+
+You can find an example of how to use ``CreateAutoMLTextTrainingJobOperator`` for AutoML text classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_cloud_create_text_classification_training_job_operator]
+ :end-before: [END howto_cloud_create_text_classification_training_job_operator]
+
:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`
+
+You can find an example of how to use ``CreateAutoMLVideoTrainingJobOperator`` for Auto ML Video Intelligence classification here:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_video_classification.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_cloud_create_video_classification_training_job_operator]
+ :end-before: [END howto_cloud_create_video_classification_training_job_operator]
+
Each of them will wait for the operation to complete. The results of each operator will be a model
which was trained by user using these operators.
@@ -278,7 +310,7 @@ If you wish to delete a Auto ML Training Job you can use
Creating a Batch Prediction Jobs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-To create a Google VertexAI Batch Prediction Job you can use
+To create a Google Vertex AI Batch Prediction Job you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.CreateBatchPredictionJobOperator`.
The operator returns batch prediction job id in :ref:`XCom <concepts:xcom>` under ``batch_prediction_job_id`` key.
@@ -319,7 +351,7 @@ To get a batch prediction job list you can use
Creating an Endpoint Service
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-To create a Google VertexAI endpoint you can use
+To create a Google Vertex AI endpoint you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.CreateEndpointOperator`.
The operator returns endpoint id in :ref:`XCom <concepts:xcom>` under ``endpoint_id`` key.
@@ -368,7 +400,7 @@ To get an endpoint list you can use
Creating a Hyperparameter Tuning Jobs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-To create a Google VertexAI hyperparameter tuning job you can use
+To create a Google Vertex AI hyperparameter tuning job you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.hyperparameter_tuning_job.CreateHyperparameterTuningJobOperator`.
The operator returns hyperparameter tuning job id in :ref:`XCom <concepts:xcom>` under ``hyperparameter_tuning_job_id`` key.
@@ -417,7 +449,7 @@ To get a hyperparameter tuning job list you can use
Creating a Model Service
^^^^^^^^^^^^^^^^^^^^^^^^
-To upload a Google VertexAI model you can use
+To upload a Google Vertex AI model you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.model_service.UploadModelOperator`.
The operator returns model id in :ref:`XCom <concepts:xcom>` under ``model_id`` key.
@@ -511,7 +543,7 @@ To delete specific version of model you can use
Running a Pipeline Jobs
^^^^^^^^^^^^^^^^^^^^^^^
-To run a Google VertexAI Pipeline Job you can use
+To run a Google Vertex AI Pipeline Job you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.RunPipelineJobOperator`.
The operator returns pipeline job id in :ref:`XCom <concepts:xcom>` under ``pipeline_job_id`` key.
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index d88b975a87ebc..54a02cc5d598b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -526,6 +526,7 @@
"google-auth-httplib2>=0.0.1",
"google-auth>=1.0.0",
"google-cloud-aiplatform>=1.42.1",
+ "google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
"google-cloud-bigtable>=2.17.0",
diff --git a/newsfragments/38635.significant.rst b/newsfragments/38635.significant.rst
new file mode 100644
index 0000000000000..839c3910f449f
--- /dev/null
+++ b/newsfragments/38635.significant.rst
@@ -0,0 +1,7 @@
+Support for AutoML operators and hooks has been limited only to AutoML Translation.
+
+As of March 31, 2024, GCP has shut down most of the AutoML domains (AutoML Tables, AutoML Video Intelligence,
+AutoML Vision, and AutoML Natural Language) in favor of equivalent domains of Vertex AI AutoML, and retained only AutoML
+Translation. For that reason, AutoML operators and hooks that utilize obsolete APIs have been deprecated and will
+raise ``AirflowException`` upon initialization. Using ``AutoMLTrainModelOperator`` for AutoML domains
+other than translation will raise an exception as well.
diff --git a/scripts/in_container/run_provider_yaml_files_check.py b/scripts/in_container/run_provider_yaml_files_check.py
index c343bb2397726..b1608c25bcedd 100755
--- a/scripts/in_container/run_provider_yaml_files_check.py
+++ b/scripts/in_container/run_provider_yaml_files_check.py
@@ -50,17 +50,10 @@
"airflow.providers.apache.hdfs.hooks.hdfs",
"airflow.providers.cncf.kubernetes.triggers.kubernetes_pod",
"airflow.providers.cncf.kubernetes.operators.kubernetes_pod",
- "airflow.providers.google.cloud.operators.automl",
]
KNOWN_DEPRECATED_CLASSES = [
- "airflow.providers.google.cloud.links.automl.AutoMLDatasetLink",
- "airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink",
- "airflow.providers.google.cloud.links.automl.AutoMLModelLink",
- "airflow.providers.google.cloud.links.automl.AutoMLModelListLink",
- "airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink",
"airflow.providers.google.cloud.links.dataproc.DataprocLink",
- "airflow.providers.google.cloud.hooks.automl.CloudAutoMLHook",
]
try:
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 773a2fc1c4206..b6a275f594ffa 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -372,6 +372,10 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
CLASS_DIRS = ProjectStructureTest.CLASS_DIRS | {"operators/vertex_ai"}
DEPRECATED_CLASSES = {
+ "airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator",
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service"
".CloudDataTransferServiceS3ToGCSOperator",
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service"
@@ -416,6 +420,11 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
}
MISSING_EXAMPLES_FOR_CLASSES = {
+ "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLUpdateDatasetOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
"airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator",
"airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator",
@@ -429,6 +438,8 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
}
ASSETS_NOT_REQUIRED = {
+ "airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator",
+ "airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator",
"airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator",
"airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteDatasetOperator",
"airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator",
diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml
index d271e6077a25b..6d27f4f5388d6 100644
--- a/tests/deprecations_ignore.yml
+++ b/tests/deprecations_ignore.yml
@@ -779,6 +779,7 @@
- tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py::TestCustomJobWithoutDefaultProjectIdHook::test_get_pipeline_job
- tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py::TestCustomJobWithoutDefaultProjectIdHook::test_list_pipeline_jobs
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryCreateExternalTableOperator::test_execute_with_csv_format
+- tests/providers/google/cloud/operators/test_automl.py::TestAutoMLTrainModelOperator::test_execute
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryCreateExternalTableOperator::test_execute_with_parquet_format
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryOperator::test_bigquery_operator_defaults
- tests/providers/google/cloud/operators/test_bigquery.py::TestBigQueryOperator::test_bigquery_operator_extra_link_when_missing_job_id
diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py
index 0f97b91e5a78d..9b7245a67184f 100644
--- a/tests/providers/google/cloud/hooks/test_automl.py
+++ b/tests/providers/google/cloud/hooks/test_automl.py
@@ -17,85 +17,212 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
import pytest
+from google.api_core.gapic_v1.method import DEFAULT
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.common.consts import CLIENT_INFO
+from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id
+
+CREDENTIALS = "test-creds"
+TASK_ID = "test-automl-hook"
+GCP_PROJECT_ID = "test-project"
+GCP_LOCATION = "test-location"
+MODEL_NAME = "test_model"
+MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
+DATASET_ID = "TBL123456789"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "dataset_id": DATASET_ID,
+ "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+}
+
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
+
+INPUT_CONFIG = {"input": "value"}
+OUTPUT_CONFIG = {"output": "value"}
+PAYLOAD = {"test": "payload"}
+DATASET = {"dataset_id": "data"}
+MASK = {"field": "mask"}
class TestAutoMLHook:
- def setup_method(self):
- self.hook = CloudAutoMLHook()
-
- def test_init(self):
- with pytest.warns(AirflowProviderDeprecationWarning):
- CloudAutoMLHook()
-
- def test_extract_object_id(self):
- with pytest.warns(AirflowProviderDeprecationWarning, match="'extract_object_id'"):
- object_id = CloudAutoMLHook.extract_object_id(obj={"name": "x/y"})
- assert object_id == "y"
-
- def test_get_conn(self):
- with pytest.raises(AirflowException):
- self.hook.get_conn()
-
- def test_wait_for_operation(self):
- with pytest.raises(AirflowException):
- self.hook.wait_for_operation()
-
- def test_prediction_client(self):
- with pytest.raises(AirflowException):
- self.hook.prediction_client()
-
- def test_create_model(self):
- with pytest.raises(AirflowException):
- self.hook.create_model()
-
- def test_batch_predict(self):
- with pytest.raises(AirflowException):
- self.hook.batch_predict()
-
- def test_predict(self):
- with pytest.raises(AirflowException):
- self.hook.predict()
-
- def test_create_dataset(self):
- with pytest.raises(AirflowException):
- self.hook.create_dataset()
+ def test_delegate_to_runtime_error(self):
+ with pytest.raises(RuntimeError):
+ CloudAutoMLHook(gcp_conn_id="GCP_CONN_ID", delegate_to="delegate_to")
- def test_import_data(self):
- with pytest.raises(AirflowException):
- self.hook.import_data()
-
- def test_list_column_specs(self):
- with pytest.raises(AirflowException):
- self.hook.list_column_specs()
-
- def test_get_model(self):
- with pytest.raises(AirflowException):
- self.hook.get_model()
-
- def test_delete_model(self):
- with pytest.raises(AirflowException):
- self.hook.delete_model()
-
- def test_update_dataset(self):
- with pytest.raises(AirflowException):
- self.hook.update_dataset()
-
- def test_deploy_model(self):
- with pytest.raises(AirflowException):
- self.hook.deploy_model()
+ def setup_method(self):
+ with mock.patch(
+ "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_no_default_project_id,
+ ):
+ self.hook = CloudAutoMLHook()
+ self.hook.get_credentials = mock.MagicMock(return_value=CREDENTIALS) # type: ignore
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient")
+ def test_get_conn(self, mock_automl_client):
+ self.hook.get_conn()
+ mock_automl_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO)
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient")
+ def test_prediction_client(self, mock_prediction_client):
+ client = self.hook.prediction_client # noqa: F841
+ mock_prediction_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO)
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model")
+ def test_create_model(self, mock_create_model):
+ self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_create_model.assert_called_once_with(
+ request=dict(parent=LOCATION_PATH, model=MODEL), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
+ def test_batch_predict(self, mock_batch_predict):
+ self.hook.batch_predict(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ output_config=OUTPUT_CONFIG,
+ )
+
+ mock_batch_predict.assert_called_once_with(
+ request=dict(
+ name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None
+ ),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
+ def test_predict(self, mock_predict):
+ self.hook.predict(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ payload=PAYLOAD,
+ )
+
+ mock_predict.assert_called_once_with(
+ request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
+ def test_create_dataset(self, mock_create_dataset):
+ self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_create_dataset.assert_called_once_with(
+ request=dict(parent=LOCATION_PATH, dataset=DATASET),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
+ def test_import_dataset(self, mock_import_data):
+ self.hook.import_data(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ )
+
+ mock_import_data.assert_called_once_with(
+ request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG),
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
+ def test_list_column_specs(self, mock_list_column_specs):
+ table_spec = "table_spec_id"
+ filter_ = "filter"
+ page_size = 42
+
+ with pytest.raises(AirflowException):
+ self.hook.list_column_specs(
+ dataset_id=DATASET_ID,
+ table_spec_id=table_spec,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ field_mask=MASK,
+ filter_=filter_,
+ page_size=page_size,
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
+ def test_get_model(self, mock_get_model):
+ self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_get_model.assert_called_once_with(
+ request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
+ def test_delete_model(self, mock_delete_model):
+ self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_delete_model.assert_called_once_with(
+ request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
+ def test_update_dataset(self, mock_update_dataset):
+ self.hook.update_dataset(
+ dataset=DATASET,
+ update_mask=MASK,
+ )
+
+ mock_update_dataset.assert_called_once_with(
+ request=dict(dataset=DATASET, update_mask=MASK), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ def test_deploy_model(self):
+ with pytest.raises(AirflowException):
+ self.hook.deploy_model(
+ model_id=MODEL_ID,
+ image_detection_metadata={},
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
def test_list_table_specs(self):
- with pytest.raises(AirflowException):
- self.hook.list_table_specs()
-
- def test_list_datasets(self):
- with pytest.raises(AirflowException):
- self.hook.list_datasets()
-
- def test_delete_dataset(self):
- with pytest.raises(AirflowException):
- self.hook.delete_dataset()
+ filter_ = "filter"
+ page_size = 42
+ with pytest.raises(AirflowException):
+ self.hook.list_table_specs(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ filter_=filter_,
+ page_size=page_size,
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
+ def test_list_datasets(self, mock_list_datasets):
+ self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_list_datasets.assert_called_once_with(
+ request=dict(parent=LOCATION_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
+
+ @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
+ def test_delete_dataset(self, mock_delete_dataset):
+ self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
+
+ mock_delete_dataset.assert_called_once_with(
+ request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, metadata=()
+ )
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 1caf680cd512c..7559c64fac5bc 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -17,13 +17,596 @@
# under the License.
from __future__ import annotations
-from importlib import import_module
+import copy
+from unittest import mock
import pytest
+from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLBatchPredictOperator,
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLDeleteModelOperator,
+ AutoMLDeployModelOperator,
+ AutoMLGetModelOperator,
+ AutoMLImportDataOperator,
+ AutoMLListDatasetOperator,
+ AutoMLPredictOperator,
+ AutoMLTablesListColumnSpecsOperator,
+ AutoMLTablesListTableSpecsOperator,
+ AutoMLTablesUpdateDatasetOperator,
+ AutoMLTrainModelOperator,
+ AutoMLUpdateDatasetOperator,
+)
+from airflow.utils import timezone
+CREDENTIALS = "test-creds"
+TASK_ID = "test-automl-hook"
+GCP_PROJECT_ID = "test-project"
+GCP_LOCATION = "test-location"
+MODEL_NAME = "test_model"
+MODEL_ID = "TBL9195602771183665152"
+DATASET_ID = "TBL123456789"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "dataset_id": DATASET_ID,
+ "translation_model_metadata": {"base_model": "some_base_model"},
+}
-def test_deprecated_module():
- with pytest.warns(AirflowProviderDeprecationWarning):
- import_module("airflow.providers.google.cloud.operators.automl")
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
+
+INPUT_CONFIG = {"input": "value"}
+OUTPUT_CONFIG = {"output": "value"}
+PAYLOAD = {"test": "payload"}
+DATASET = {"dataset_id": "data"}
+MASK = {"field": "mask"}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+class TestAutoMLTrainModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.wait_for_operation.return_value = Model()
+ op = AutoMLTrainModelOperator(
+ model=MODEL,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.create_model.assert_called_once_with(
+ model=MODEL,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ metadata=(),
+ )
+
+ def test_execute_deprecated_model_types(self):
+ invalid_model = MODEL.copy()
+ invalid_model.pop("translation_model_metadata")
+ op = AutoMLTrainModelOperator(
+ model=invalid_model,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ with pytest.raises(AirflowException):
+ op.execute(context=mock.MagicMock())
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLTrainModelOperator,
+ # Templated fields
+ model="{{ 'model' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation_chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLTrainModelOperator = ti.task
+ assert task.model == "model"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation_chain"
+
+
+class TestAutoMLBatchPredictOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
+ mock_hook.return_value.extract_object_id = extract_object_id
+ mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
+
+ op = AutoMLBatchPredictOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ output_config=OUTPUT_CONFIG,
+ task_id=TASK_ID,
+ prediction_params={},
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.batch_predict.assert_called_once_with(
+ input_config=INPUT_CONFIG,
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ output_config=OUTPUT_CONFIG,
+ params={},
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLBatchPredictOperator,
+ # Templated fields
+ model_id="{{ 'model' }}",
+ input_config="{{ 'input-config' }}",
+ output_config="{{ 'output-config' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLBatchPredictOperator = ti.task
+ assert task.model_id == "model"
+ assert task.input_config == "input-config"
+ assert task.output_config == "output-config"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLPredictOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.predict.return_value = PredictResponse()
+
+ op = AutoMLPredictOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ payload=PAYLOAD,
+ task_id=TASK_ID,
+ operation_params={"TEST_KEY": "TEST_VALUE"},
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.predict.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ params={"TEST_KEY": "TEST_VALUE"},
+ payload=PAYLOAD,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLPredictOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ payload={},
+ )
+ ti.render_templates()
+ task: AutoMLPredictOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.project_id == "project-id"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLCreateImportOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
+ mock_hook.return_value.extract_object_id = extract_object_id
+
+ op = AutoMLCreateDatasetOperator(
+ dataset=DATASET,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.create_dataset.assert_called_once_with(
+ dataset=DATASET,
+ location=GCP_LOCATION,
+ metadata=(),
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLCreateDatasetOperator,
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLCreateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.project_id == "project-id"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLListColumnsSpecsOperator:
+ def test_deprecation(self):
+ with pytest.raises(AirflowException):
+ AutoMLTablesListColumnSpecsOperator(
+ dataset_id=DATASET_ID,
+ table_spec_id="table_spec",
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ field_mask=MASK,
+ filter_="filter_",
+ page_size=10,
+ task_id=TASK_ID,
+ )
+
+
+class TestAutoMLUpdateDatasetOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
+
+ dataset = copy.deepcopy(DATASET)
+ dataset["name"] = DATASET_ID
+
+ op = AutoMLUpdateDatasetOperator(
+ dataset=dataset,
+ update_mask=MASK,
+ location=GCP_LOCATION,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.update_dataset.assert_called_once_with(
+ dataset=dataset,
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ update_mask=MASK,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLUpdateDatasetOperator,
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ update_mask="{{ 'update-mask' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLUpdateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.update_mask == "update-mask"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLTablesUpdateDatasetOperator:
+ def test_deprecation(self):
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ AutoMLTablesUpdateDatasetOperator(
+ dataset={},
+ update_mask=MASK,
+ location=GCP_LOCATION,
+ task_id=TASK_ID,
+ )
+
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
+
+ dataset = copy.deepcopy(DATASET)
+ dataset["name"] = DATASET_ID
+
+ op = AutoMLUpdateDatasetOperator(
+ dataset=dataset,
+ update_mask=MASK,
+ location=GCP_LOCATION,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.update_dataset.assert_called_once_with(
+ dataset=dataset,
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ update_mask=MASK,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLUpdateDatasetOperator,
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ update_mask="{{ 'update-mask' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLUpdateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.update_mask == "update-mask"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLGetModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
+ mock_hook.return_value.extract_object_id = extract_object_id
+
+ op = AutoMLGetModelOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.get_model.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLGetModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLGetModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDeleteModelOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLDeleteModelOperator(
+ model_id=MODEL_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=None)
+ mock_hook.return_value.delete_model.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ model_id=MODEL_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLDeleteModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLDeleteModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDeployModelOperator:
+ def test_deprecation(self):
+ with pytest.raises(AirflowException):
+ AutoMLDeployModelOperator(
+ model_id=MODEL_ID,
+ image_detection_metadata={},
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+
+
+class TestAutoMLDatasetImportOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLImportDataOperator(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ input_config=INPUT_CONFIG,
+ task_id=TASK_ID,
+ )
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.import_data.assert_called_once_with(
+ input_config=INPUT_CONFIG,
+ location=GCP_LOCATION,
+ metadata=(),
+ dataset_id=DATASET_ID,
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLImportDataOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ input_config="{{ 'input-config' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLImportDataOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.input_config == "input-config"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLTablesListTableSpecsOperator:
+ def test_deprecation(self):
+ with pytest.raises(AirflowException):
+ AutoMLTablesListTableSpecsOperator(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ filter_="filter",
+ page_size=10,
+ task_id=TASK_ID,
+ )
+
+
+class TestAutoMLDatasetListOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
+ op.execute(context=mock.MagicMock())
+ mock_hook.return_value.list_datasets.assert_called_once_with(
+ location=GCP_LOCATION,
+ metadata=(),
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLListDatasetOperator,
+ # Templated fields
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLListDatasetOperator = ti.task
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
+
+class TestAutoMLDatasetDeleteOperator:
+ @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
+ def test_execute(self, mock_hook):
+ op = AutoMLDeleteDatasetOperator(
+ dataset_id=DATASET_ID,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ task_id=TASK_ID,
+ )
+ op.execute(context=None)
+ mock_hook.return_value.delete_dataset.assert_called_once_with(
+ location=GCP_LOCATION,
+ dataset_id=DATASET_ID,
+ metadata=(),
+ project_id=GCP_PROJECT_ID,
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):
+ ti = create_task_instance_of_operator(
+ AutoMLDeleteDatasetOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()
+ task: AutoMLDeleteDatasetOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
index 7305123cb0164..9ef04db81882b 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
@@ -28,6 +28,7 @@
from google.protobuf.struct_pb2 import Value
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
@@ -70,6 +71,7 @@
"gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
},
]
+extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Text Classification
with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
index 916fc25877c9c..8f8564f62c209 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
@@ -28,6 +28,7 @@
from google.protobuf.struct_pb2 import Value
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
@@ -69,11 +70,7 @@
},
]
-
-def extract_object_id(obj: dict) -> str:
- """Returns unique id of the object."""
- return obj["name"].rpartition("/")[-1]
-
+extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Entities Extraction
with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
index 0e641e1b05feb..94f349c6c3702 100644
--- a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
+++ b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
@@ -28,6 +28,7 @@
from google.protobuf.struct_pb2 import Value
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.gcs import (
GCSCreateBucketOperator,
GCSDeleteBucketOperator,
@@ -70,11 +71,7 @@
},
]
-
-def extract_object_id(obj: dict) -> str:
- """Returns unique id of the object."""
- return obj["name"].rpartition("/")[-1]
-
+extract_object_id = CloudAutoMLHook.extract_object_id
# Example DAG for AutoML Natural Language Text Sentiment
with DAG(
diff --git a/tests/system/providers/google/cloud/automl/example_automl_translation.py b/tests/system/providers/google/cloud/automl/example_automl_translation.py
new file mode 100644
index 0000000000000..ba36f556c427d
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_translation.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google AutoML services.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from typing import cast
+
+# The storage module cannot be imported yet https://github.com/googleapis/python-storage/issues/393
+from google.cloud import storage # type: ignore[attr-defined]
+
+from airflow.decorators import task
+from airflow.models.dag import DAG
+from airflow.models.xcom_arg import XComArg
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLDeleteModelOperator,
+ AutoMLImportDataOperator,
+ AutoMLTrainModelOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "example_automl_translate"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+GCP_AUTOML_LOCATION = "us-central1"
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+
+
+MODEL_NAME = "translate_test_model"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "translation_model_metadata": {},
+}
+
+DATASET_NAME = f"ds_translate_{ENV_ID}".replace("-", "_")
+DATASET = {
+ "display_name": DATASET_NAME,
+ "translation_dataset_metadata": {
+ "source_language_code": "en",
+ "target_language_code": "es",
+ },
+}
+
+CSV_FILE_NAME = "en-es.csv"
+TSV_FILE_NAME = "en-es.tsv"
+GCS_FILE_PATH = f"automl/datasets/translate/{CSV_FILE_NAME}"
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/{CSV_FILE_NAME}"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+# Example DAG for AutoML Translation
+with DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ user_defined_macros={"extract_object_id": extract_object_id},
+ tags=["example", "automl", "translate"],
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ @task
+ def upload_csv_file_to_gcs():
+ # download file into memory
+ storage_client = storage.Client()
+ bucket = storage_client.bucket(RESOURCE_DATA_BUCKET)
+ blob = bucket.blob(GCS_FILE_PATH)
+ contents = blob.download_as_string().decode()
+
+ # update memory content
+ updated_contents = contents.replace("template-bucket", DATA_SAMPLE_GCS_BUCKET_NAME)
+
+ # upload updated content to bucket
+ destination_bucket = storage_client.bucket(DATA_SAMPLE_GCS_BUCKET_NAME)
+ destination_blob = destination_bucket.blob(f"automl/{CSV_FILE_NAME}")
+ destination_blob.upload_from_string(updated_contents)
+
+ upload_csv_file_to_gcs_task = upload_csv_file_to_gcs()
+
+ copy_dataset_file = GCSToGCSOperator(
+ task_id="copy_dataset_file",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object=f"automl/datasets/translate/{TSV_FILE_NAME}",
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object=f"automl/{TSV_FILE_NAME}",
+ )
+
+ create_dataset = AutoMLCreateDatasetOperator(
+ task_id="create_dataset", dataset=DATASET, location=GCP_AUTOML_LOCATION
+ )
+
+ dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
+
+ import_dataset = AutoMLImportDataOperator(
+ task_id="import_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ input_config=IMPORT_INPUT_CONFIG,
+ )
+
+ MODEL["dataset_id"] = dataset_id
+
+ create_model = AutoMLTrainModelOperator(task_id="create_model", model=MODEL, location=GCP_AUTOML_LOCATION)
+ model_id = cast(str, XComArg(create_model, key="model_id"))
+
+ delete_model = AutoMLDeleteModelOperator(
+ task_id="delete_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ delete_dataset = AutoMLDeleteDatasetOperator(
+ task_id="delete_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> upload_csv_file_to_gcs_task >> copy_dataset_file]
+ # TEST BODY
+ >> create_dataset
+ >> import_dataset
+ >> create_model
+ # TEST TEARDOWN
+ >> delete_dataset
+ >> delete_model
+ >> delete_bucket
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)