Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add overwrite_method to postgresql.to_sql #2820

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions awswrangler/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,22 @@ def _validate_connection(con: "pg8000.Connection") -> None:
)


def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str) -> None:
def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
schema_str = f"{_identifier(schema)}." if schema else ""
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)}"
cascade_str = "CASCADE" if cascade else "RESTRICT"
sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)} {cascade_str}"
_logger.debug("Drop table query:\n%s", sql)
cursor.execute(sql)


def _truncate_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None:
schema_str = f"{_identifier(schema)}." if schema else ""
cascade_str = "CASCADE" if cascade else "RESTRICT"
sql = f"TRUNCATE TABLE {schema_str}{_identifier(table)} {cascade_str}"
_logger.debug("Truncate table query:\n%s", sql)
cursor.execute(sql)


def _does_table_exist(cursor: "pg8000.Cursor", schema: str | None, table: str) -> bool:
schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else ""
cursor.execute(
Expand All @@ -66,12 +75,21 @@ def _create_table(
table: str,
schema: str,
mode: str,
overwrite_method: _ToSqlOverwriteModeLiteral,
index: bool,
dtype: dict[str, str] | None,
varchar_lengths: dict[str, int] | None,
) -> None:
if mode == "overwrite":
_drop_table(cursor=cursor, schema=schema, table=table)
if overwrite_method in ["drop", "cascade"]:
_drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade"))
elif overwrite_method in ["truncate", "truncate cascade"]:
if _does_table_exist(cursor=cursor, schema=schema, table=table):
_truncate_table(
cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "truncate cascade")
)
else:
raise exceptions.InvalidArgumentValue(f"Invalid overwrite_method: {overwrite_method}")
elif _does_table_exist(cursor=cursor, schema=schema, table=table):
return
postgresql_types: dict[str, str] = _data_types.database_types_from_pandas(
Expand Down Expand Up @@ -485,6 +503,7 @@ def read_sql_table(


_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "truncate cascade"]


@_utils.check_optional_dependency(pg8000, "pg8000")
Expand All @@ -495,6 +514,7 @@ def to_sql(
table: str,
schema: str,
mode: _ToSqlModeLiteral = "append",
overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
index: bool = False,
dtype: dict[str, str] | None = None,
varchar_lengths: dict[str, int] | None = None,
Expand Down Expand Up @@ -522,6 +542,13 @@ def to_sql(
overwrite: Drops table and recreates.
upsert: Perform an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and
sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode.
overwrite_method : str
Drop, cascade, truncate, or truncate cascade. Only applicable in overwrite mode.

"drop" - ``DROP ... RESTRICT`` - drops the table. Fails if there are any views that depend on it.
"cascade" - ``DROP ... CASCADE`` - drops the table, and all views that depend on it.
"truncate" - ``TRUNCATE ... RESTRICT`` - truncates the table. Fails if any of the tables have foreign-key references from tables that are not listed in the command.
"truncate cascade" - ``TRUNCATE ... CASCADE`` - truncates the table, and all tables that have foreign-key references to any of the named tables.
index : bool
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
Expand Down Expand Up @@ -583,6 +610,7 @@ def to_sql(
table=table,
schema=schema,
mode=mode,
overwrite_method=overwrite_method,
index=index,
dtype=dtype,
varchar_lengths=varchar_lengths,
Expand Down
46 changes: 45 additions & 1 deletion tests/unit/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,51 @@ def test_read_sql_query_simple(databases_parameters):

def test_to_sql_simple(postgresql_table, postgresql_con):
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
wr.postgresql.to_sql(df, postgresql_con, postgresql_table, "public", "overwrite", True)
wr.postgresql.to_sql(
df=df,
con=postgresql_con,
table=postgresql_table,
schema="public",
mode="overwrite",
index=True,
)


@pytest.mark.parametrize("overwrite_method", ["drop", "cascade", "truncate", "truncate cascade"])
def test_to_sql_overwrite(postgresql_table, postgresql_con, overwrite_method):
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
wr.postgresql.to_sql(
df=df,
con=postgresql_con,
table=postgresql_table,
schema="public",
mode="overwrite",
overwrite_method=overwrite_method,
)
df = pd.DataFrame({"c0": [4, 5, 6], "c1": ["xoo", "yoo", "zoo"]})
wr.postgresql.to_sql(
df=df,
con=postgresql_con,
table=postgresql_table,
schema="public",
mode="overwrite",
overwrite_method=overwrite_method,
)
df = wr.postgresql.read_sql_table(table=postgresql_table, schema="public", con=postgresql_con)
assert df.shape == (3, 2)


def test_unknown_overwrite_method_error(postgresql_table, postgresql_con):
df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]})
with pytest.raises(wr.exceptions.InvalidArgumentValue):
wr.postgresql.to_sql(
df=df,
con=postgresql_con,
table=postgresql_table,
schema="public",
mode="overwrite",
overwrite_method="unknown",
)


def test_sql_types(postgresql_table, postgresql_con):
Expand Down
Loading