diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py b/airflow/providers/google/cloud/example_dags/example_dataproc.py
index a2f1c82bfe36..b6e1070df72b 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataproc.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -38,7 +38,8 @@
 OUTPUT_PATH = "gs://{}/{}/".format(BUCKET, OUTPUT_FOLDER)
 PYSPARK_MAIN = os.environ.get("PYSPARK_MAIN", "hello_world.py")
 PYSPARK_URI = "gs://{}/{}".format(BUCKET, PYSPARK_MAIN)
-
+SPARKR_MAIN = os.environ.get("SPARKR_MAIN", "hello_world.R")
+SPARKR_URI = "gs://{}/{}".format(BUCKET, SPARKR_MAIN)
 
 # Cluster definition
 CLUSTER = {
@@ -104,6 +105,12 @@
     "pyspark_job": {"main_python_file_uri": PYSPARK_URI},
 }
 
+SPARKR_JOB = {
+    "reference": {"project_id": PROJECT_ID},
+    "placement": {"cluster_name": CLUSTER_NAME},
+    "spark_r_job": {"main_r_file_uri": SPARKR_URI},
+}
+
 HIVE_JOB = {
     "reference": {"project_id": PROJECT_ID},
     "placement": {"cluster_name": CLUSTER_NAME},
@@ -157,6 +164,10 @@
         task_id="pyspark_task", job=PYSPARK_JOB, location=REGION, project_id=PROJECT_ID
     )
 
+    sparkr_task = DataprocSubmitJobOperator(
+        task_id="sparkr_task", job=SPARKR_JOB, location=REGION, project_id=PROJECT_ID
+    )
+
     hive_task = DataprocSubmitJobOperator(
         task_id="hive_task", job=HIVE_JOB, location=REGION, project_id=PROJECT_ID
     )
@@ -178,4 +189,5 @@
     scale_cluster >> spark_sql_task >> delete_cluster
     scale_cluster >> spark_task >> delete_cluster
     scale_cluster >> pyspark_task >> delete_cluster
+    scale_cluster >> sparkr_task >> delete_cluster
     scale_cluster >> hadoop_task >> delete_cluster
diff --git a/tests/providers/google/cloud/operators/test_dataproc_system.py b/tests/providers/google/cloud/operators/test_dataproc_system.py
index e6779a6baa9d..863ad5b5ce09 100644
--- a/tests/providers/google/cloud/operators/test_dataproc_system.py
+++ b/tests/providers/google/cloud/operators/test_dataproc_system.py
@@ -25,6 +25,8 @@
 BUCKET = os.environ.get("GCP_DATAPROC_BUCKET", "dataproc-system-tests")
 PYSPARK_MAIN = os.environ.get("PYSPARK_MAIN", "hello_world.py")
 PYSPARK_URI = "gs://{}/{}".format(BUCKET, PYSPARK_MAIN)
+SPARKR_MAIN = os.environ.get("SPARKR_MAIN", "hello_world.R")
+SPARKR_URI = "gs://{}/{}".format(BUCKET, SPARKR_MAIN)
 
 pyspark_file = """
 #!/usr/bin/python
@@ -35,16 +37,32 @@
 print(words)
 """
 
+sparkr_file = """
+#!/usr/bin/r
+if (nchar(Sys.getenv("SPARK_HOME")) < 1) {
+Sys.setenv(SPARK_HOME = "/home/spark")
+}
+library(SparkR, lib.loc = c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib")))
+sparkR.session()
+# Create the SparkDataFrame
+df <- as.DataFrame(faithful)
+head(summarize(groupBy(df, df$waiting), count = n(df$waiting)))
+"""
+
 
 @pytest.mark.backend("mysql", "postgres")
 @pytest.mark.credential_file(GCP_DATAPROC_KEY)
 class DataprocExampleDagsTest(GoogleSystemTest):
-
     @provide_gcp_context(GCP_DATAPROC_KEY)
     def setUp(self):
         super().setUp()
         self.create_gcs_bucket(BUCKET)
-        self.upload_content_to_gcs(lines=pyspark_file, bucket=PYSPARK_URI, filename=PYSPARK_MAIN)
+        self.upload_content_to_gcs(
+            lines=pyspark_file, bucket=PYSPARK_URI, filename=PYSPARK_MAIN
+        )
+        self.upload_content_to_gcs(
+            lines=sparkr_file, bucket=SPARKR_URI, filename=SPARKR_MAIN
+        )
 
     @provide_gcp_context(GCP_DATAPROC_KEY)
     def tearDown(self):