Create upstream_tasks parameter for dependencies independent of data transfers #585

Merged Sep 8, 2022 (28 commits). Changes shown from 19 commits.
43 changes: 43 additions & 0 deletions docs/getting-started/GETTING_STARTED.md
@@ -392,3 +392,46 @@
or

In all scenarios, even if the user gives a non-temporary table, only temporary
tables will actually be deleted.

## Tying Astro SDK decorators to traditional Airflow Operators
Collaborator commented:
Is the getting started tutorial the best place for this documentation?

I feel it could be best placed within our Sphinx documentation as a reference page. I believe the idea of the tutorial was to walk users through a first example of the Python SDK without introducing too many possibilities. Could you confirm this, @mikeshwe ?

Collaborator Author commented:

@tatiana this is one of the most common questions we get asked (how can I use astro-sdk with traditional airflow operators) so it seemed like the getting started page was a good place to do it. Glad to discuss other options though.

Member commented:

I think this is fair. 👍


There are two scenarios for tying traditional Airflow operators to Astro SDK tasks:

1. Operators that pass data that Astro SDK functions can pick up
2. Operators that don't pass any data but need to run upstream of an Astro SDK task

### Scenario 1: Operators that pass data to Astro SDK tasks

When a traditional operator returns XCom-based data, you can pass those values
directly into the Astro SDK function using the operator's `.output` attribute (or simply
the return value for tasks created with the TaskFlow API):
```python
@task
def get_num_rows():
    return 5


@aql.transform
def get_rows(table: Table, name: str, num_rows: int):
    return "SELECT * FROM {{table}} WHERE name={{name}} LIMIT {{num_rows}}"


with dag:
    name_from_env = BashOperator(...)
    get_rows(table=Table(), name=name_from_env.output, num_rows=get_num_rows())
```
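Under the hood, `.output` is an `XComArg`: passing it into the decorated function both pulls the operator's return value at runtime and registers the traditional task as an upstream dependency, so no explicit `set_upstream` call is needed.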

### Scenario 2: Operators that don't pass data to Astro SDK tasks

When tying traditional tasks to Astro SDK decorators, you might find that the upstream
operators don't pass any data at all. In these cases you can use the `upstream_tasks`
parameter to set up dependencies between traditional Airflow tasks and Astro SDK tasks:

```python
@aql.transform
def get_rows(table: Table, num_rows: int):
    return "SELECT * FROM {{table}} LIMIT {{num_rows}}"


with dag:
    bash_command = BashOperator(...)
    get_rows(table=Table(), num_rows=5, upstream_tasks=[bash_command])
```
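The two scenarios can also be combined in one DAG. Below is a minimal sketch of that combination (assuming an existing `dag` object as in the snippets above; the `imdb_movies` table name and `sqlite_default` connection id are illustrative, not part of this change):

```python
from airflow.decorators import task
from airflow.operators.bash import BashOperator

from astro import sql as aql
from astro.sql.table import Table


@aql.transform
def top_rows(table: Table, num_rows: int):
    return "SELECT * FROM {{table}} LIMIT {{num_rows}}"


@task
def get_num_rows():
    return 5


with dag:
    # Scenario 2: this task passes no data; it is a scheduling-only dependency
    wait_for_load = BashOperator(task_id="wait_for_load", bash_command="sleep 1")

    # Scenario 1: get_num_rows() feeds its return value into the SDK task,
    # while upstream_tasks wires in the BashOperator as a pure dependency
    top_rows(
        table=Table(name="imdb_movies", conn_id="sqlite_default"),
        num_rows=get_num_rows(),
        upstream_tasks=[wait_for_load],
    )
```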
4 changes: 2 additions & 2 deletions src/astro/sql/operators/append.py
@@ -1,15 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from airflow.decorators.base import get_unique_task_id
-from airflow.models.baseoperator import BaseOperator

from astro.databases import create_database
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.table import Table

APPEND_COLUMN_TYPE = Optional[Union[List[str], Tuple[str], Dict[str, str]]]


-class AppendOperator(BaseOperator):
+class AppendOperator(AstroSQLBaseOperator):
"""
Append the source table rows into a destination table.

9 changes: 8 additions & 1 deletion src/astro/sql/operators/base_decorator.py
@@ -8,11 +8,12 @@

from astro.databases import create_database
from astro.databases.base import BaseDatabase
+from astro.sql.operators.upstream_task_mixin import UpstreamTaskMixin
from astro.sql.table import Table
from astro.utils.table import find_first_table


-class BaseSQLDecoratedOperator(DecoratedOperator):
+class BaseSQLDecoratedOperator(UpstreamTaskMixin, DecoratedOperator):
"""Handles all decorator classes that can return a SQL function"""

database_impl: BaseDatabase
@@ -32,12 +33,18 @@ def __init__(
        self.output_table: Table = self.op_kwargs.pop("output_table", Table())
        self.handler = self.op_kwargs.pop("handler", handler)
        self.conn_id = self.op_kwargs.pop("conn_id", conn_id)

        self.sql = sql
        self.parameters = parameters or {}
        self.database = self.op_kwargs.pop("database", database)
        self.schema = self.op_kwargs.pop("schema", schema)
        self.op_args: Dict[str, Union[Table, pd.DataFrame]] = {}

+        # We purposely do NOT render upstream_tasks; otherwise we could have a case where a user
+        # has 10 dataframes as upstream tasks and it crashes the worker
+        upstream_tasks = self.op_kwargs.pop("upstream_tasks", [])
        super().__init__(
+            upstream_tasks=upstream_tasks,
            **kwargs,
        )

9 changes: 9 additions & 0 deletions src/astro/sql/operators/base_operator.py
@@ -0,0 +1,9 @@
from abc import ABC

from airflow.models.baseoperator import BaseOperator

from astro.sql.operators.upstream_task_mixin import UpstreamTaskMixin


class AstroSQLBaseOperator(UpstreamTaskMixin, BaseOperator, ABC):
    pass
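The ordering here matters: with the mixin first, Python's MRO runs `UpstreamTaskMixin.__init__` before `BaseOperator.__init__`, so `upstream_tasks` is popped off `kwargs` before Airflow can reject it as an unknown argument. A standalone sketch of this cooperative-`__init__` pattern (class names are hypothetical, for illustration only):

```python
class ParamEatingMixin:
    """Consumes an extra kwarg before the base class sees it."""

    def __init__(self, **kwargs):
        self.extra = kwargs.pop("extra", None)
        super().__init__(**kwargs)  # forward the remaining kwargs along the MRO


class StrictBase:
    """Stands in for BaseOperator, which rejects unknown kwargs."""

    def __init__(self, **kwargs):
        if kwargs:
            raise TypeError(f"unexpected kwargs: {kwargs}")


class Combined(ParamEatingMixin, StrictBase):  # mixin must come first
    pass


Combined(extra=42)  # works: the mixin consumes `extra` before StrictBase runs
# class Reversed(StrictBase, ParamEatingMixin) would fail: StrictBase.__init__
# runs first and raises on the unexpected `extra` kwarg
```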
3 changes: 2 additions & 1 deletion src/astro/sql/operators/cleanup.py
@@ -12,6 +12,7 @@

from astro.databases import create_database
from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.operators.dataframe import DataframeOperator
from astro.sql.operators.load_file import LoadFile
from astro.sql.table import Table
@@ -21,7 +22,7 @@ def filter_for_temp_tables(task_outputs: List[Any]) -> List[Table]:
    return [t for t in task_outputs if isinstance(t, Table) and t.temp]


-class CleanupOperator(BaseOperator):
+class CleanupOperator(AstroSQLBaseOperator):
"""
Clean up temporary tables at the end of a DAG run. Temporary tables are the ones that are
generated by the SDK (where you do not pass a name arg to Table) or the ones that has the name
7 changes: 6 additions & 1 deletion src/astro/sql/operators/dataframe.py
@@ -8,6 +8,7 @@
from astro.constants import ColumnCapitalization
from astro.databases import create_database
from astro.exceptions import IllegalLoadToDatabaseException
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.table import Table
from astro.utils.dataframe import convert_columns_names_capitalization
from astro.utils.table import find_first_table
@@ -73,7 +74,7 @@ def load_op_kwarg_table_into_dataframe(
}


-class DataframeOperator(DecoratedOperator):
+class DataframeOperator(AstroSQLBaseOperator, DecoratedOperator):
    def __init__(
        self,
        conn_id: Optional[str] = None,
@@ -111,7 +112,11 @@ def __init__(
        self.op_args = self.kwargs.get("op_args", ())  # type: ignore
        self.columns_names_capitalization = columns_names_capitalization

+        # We purposely do NOT render upstream_tasks; otherwise we could have a case where a user
+        # has 10 dataframes as upstream tasks and it crashes the worker
+        upstream_tasks = self.op_kwargs.pop("upstream_tasks", [])
        super().__init__(
+            upstream_tasks=upstream_tasks,
            **kwargs,
        )

5 changes: 3 additions & 2 deletions src/astro/sql/operators/drop.py
@@ -1,13 +1,14 @@
from typing import Dict

from airflow.decorators.base import get_unique_task_id
-from airflow.models import BaseOperator
+from airflow.models.baseoperator import BaseOperator

from astro.databases import create_database
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.table import Table


-class DropTableOperator(BaseOperator):
+class DropTableOperator(AstroSQLBaseOperator, BaseOperator):
"""Airflow Operator for dropping SQL tables."""

template_fields = ("table",)
5 changes: 3 additions & 2 deletions src/astro/sql/operators/export_file.py
@@ -1,17 +1,18 @@
from typing import Any, Optional, Union

import pandas as pd
-from airflow.models import BaseOperator
+from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom_arg import XComArg

from astro.constants import ExportExistsStrategy
from astro.databases import create_database
from astro.files import File
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.table import Table
from astro.utils.task_id_helper import get_task_id


-class ExportFile(BaseOperator):
+class ExportFile(AstroSQLBaseOperator, BaseOperator):
"""Write SQL table to csv/parquet on local/S3/GCS.

:param input_data: Table to convert to file
5 changes: 3 additions & 2 deletions src/astro/sql/operators/load_file.py
@@ -2,18 +2,19 @@

import pandas as pd
from airflow.configuration import conf
-from airflow.models import BaseOperator
+from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom_arg import XComArg

from astro.constants import DEFAULT_CHUNK_SIZE, ColumnCapitalization, LoadExistStrategy
from astro.databases import BaseDatabase, create_database
from astro.exceptions import IllegalLoadToDatabaseException
from astro.files import File, check_if_connection_exists, resolve_file_path_pattern
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.table import Table
from astro.utils.task_id_helper import get_task_id


-class LoadFile(BaseOperator):
+class LoadFile(AstroSQLBaseOperator, BaseOperator):
"""Load S3/local file into either a database or a pandas dataframe

:param input_file: File path and conn_id for object stores
4 changes: 2 additions & 2 deletions src/astro/sql/operators/merge.py
@@ -5,12 +5,13 @@

from astro.constants import MergeConflictStrategy
from astro.databases import create_database
+from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.sql.table import Table

MERGE_COLUMN_TYPE = Union[List[str], Tuple[str], Dict[str, str]]


-class MergeOperator(BaseOperator):
+class MergeOperator(AstroSQLBaseOperator, BaseOperator):
"""
Merge the source table rows into a destination table.

@@ -49,7 +50,6 @@ def __init__(
        self.columns = columns or {}
        self.if_conflicts = if_conflicts
        task_id = task_id or get_unique_task_id("_merge")
-
        super().__init__(task_id=task_id, **kwargs)

    def execute(self, context: dict) -> Table:
21 changes: 21 additions & 0 deletions src/astro/sql/operators/upstream_task_mixin.py
@@ -0,0 +1,21 @@
from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.xcom_arg import XComArg


class UpstreamTaskMixin:
    def __init__(self, **kwargs):
        upstream_tasks = kwargs.pop("upstream_tasks", [])

        super().__init__(**kwargs)

        for task in upstream_tasks:
            if isinstance(task, XComArg):
                self.set_upstream(task.operator)
            elif isinstance(task, BaseOperator):
                self.set_upstream(task)
            else:
                raise AirflowException(
                    "Cannot upstream a non-task, please only use XcomArg or operators for this"
                    " parameter"
                )
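Both plain operators and `XComArg`s are accepted as upstream tasks; anything else raises `AirflowException`. A short usage sketch under the same assumptions as the tutorial snippets (an existing `dag` object; the `homes` table and `sqlite_default` connection are illustrative):

```python
from airflow.decorators import task
from airflow.operators.bash import BashOperator

from astro import sql as aql
from astro.sql.table import Table


@aql.transform
def sample_rows(table: Table):
    return "SELECT * FROM {{table}} LIMIT 10"


@task
def make_flag():
    return True


with dag:
    prepare = BashOperator(task_id="prepare", bash_command="echo ready")
    flag = make_flag()  # an XComArg

    # Both entries become scheduling-only upstream dependencies;
    # neither value is rendered or loaded into the SDK task
    sample_rows(
        table=Table(name="homes", conn_id="sqlite_default"),
        upstream_tasks=[prepare, flag],
    )
    # upstream_tasks=["not_a_task"] would raise AirflowException
```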
69 changes: 69 additions & 0 deletions tests/sql/operators/test_upstream_tasks.py
@@ -0,0 +1,69 @@
import pathlib

import pytest

from astro import sql as aql
from astro.constants import Database
from astro.files import File
from astro.sql.table import Table
from tests.sql.operators import utils as test_utils

cwd = pathlib.Path(__file__).parent


@pytest.mark.parametrize(
    "database_table_fixture",
    [
        {"database": Database.SNOWFLAKE},
        {"database": Database.BIGQUERY},
        {"database": Database.POSTGRES},
        {"database": Database.SQLITE},
    ],
    indirect=True,
    ids=["snowflake", "bigquery", "postgresql", "sqlite"],
)
def test_raw_sql_chained_queries(database_table_fixture, sample_dag):
    import pandas

    db, test_table = database_table_fixture

    @aql.run_raw_sql(conn_id=db.conn_id)
    def raw_sql_no_deps(new_table: Table, t_table: Table):
        """Test without any data dependencies, purely using upstream_tasks."""
        return "CREATE TABLE {{new_table}} AS SELECT * FROM {{t_table}}"

    @aql.dataframe
    def validate(df1: pandas.DataFrame, df2: pandas.DataFrame):
        df1 = df1.sort_values(by=df1.columns.tolist()).reset_index(drop=True)
        df2 = df2.sort_values(by=df2.columns.tolist()).reset_index(drop=True)
        assert df1.equals(df2)

    with sample_dag:
        homes_file = aql.load_file(
            input_file=File(path=str(cwd) + "/../../data/homes.csv"),
            output_table=test_table,
        )
        generated_tables = []
        last_task = homes_file
        for _ in range(5):
            n_table = test_table.create_similar_table()
            n_task = raw_sql_no_deps(
                new_table=n_table, t_table=test_table, upstream_tasks=[last_task]
            )
            generated_tables.append(n_table)
            last_task = n_task

        validated = validate(
            df1=test_table, df2=generated_tables[-1], upstream_tasks=[last_task]
        )
        for table in generated_tables:
            aql.drop_table(table, upstream_tasks=[validated])

    test_utils.run_dag(sample_dag)
    all_tasks = sample_dag.tasks
    for t in all_tasks[1:]:
        assert len(t.upstream_task_ids) == 1
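The closing loop checks that every task after the initial `load_file` has exactly one upstream task id: the whole chain is wired by `upstream_tasks` alone, since none of the chained `run_raw_sql` calls consume another task's output.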
1 change: 0 additions & 1 deletion tests/sql/operators/transform/test_transform.py
@@ -156,7 +156,6 @@ def top_five_animations(input_table: Table) -> str:
"""

    with sample_dag:
-
        target_table = Table(name="test_is_{{ ds_nodash }}", conn_id="sqlite_default")

        top_five_animations(input_table=imdb_table, output_table=target_table)