diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index 0617aea74..e0adf740c 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -42,13 +42,22 @@ def _validate_connection(con: "pg8000.Connection") -> None: ) -def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str) -> None: +def _drop_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None: schema_str = f"{_identifier(schema)}." if schema else "" - sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)}" + cascade_str = "CASCADE" if cascade else "RESTRICT" + sql = f"DROP TABLE IF EXISTS {schema_str}{_identifier(table)} {cascade_str}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) +def _truncate_table(cursor: "pg8000.Cursor", schema: str | None, table: str, cascade: bool) -> None: + schema_str = f"{_identifier(schema)}." if schema else "" + cascade_str = "CASCADE" if cascade else "RESTRICT" + sql = f"TRUNCATE TABLE {schema_str}{_identifier(table)} {cascade_str}" + _logger.debug("Truncate table query:\n%s", sql) + cursor.execute(sql) + + def _does_table_exist(cursor: "pg8000.Cursor", schema: str | None, table: str) -> bool: schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else "" cursor.execute( @@ -66,12 +75,21 @@ def _create_table( table: str, schema: str, mode: str, + overwrite_method: _ToSqlOverwriteModeLiteral, index: bool, dtype: dict[str, str] | None, varchar_lengths: dict[str, int] | None, ) -> None: if mode == "overwrite": - _drop_table(cursor=cursor, schema=schema, table=table) + if overwrite_method in ["drop", "cascade"]: + _drop_table(cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "cascade")) + elif overwrite_method in ["truncate", "truncate cascade"]: + if _does_table_exist(cursor=cursor, schema=schema, table=table): + _truncate_table( + cursor=cursor, schema=schema, table=table, cascade=bool(overwrite_method == "truncate cascade") + ) + 
else: + raise exceptions.InvalidArgumentValue(f"Invalid overwrite_method: {overwrite_method}") elif _does_table_exist(cursor=cursor, schema=schema, table=table): return postgresql_types: dict[str, str] = _data_types.database_types_from_pandas( @@ -485,6 +503,7 @@ def read_sql_table( _ToSqlModeLiteral = Literal["append", "overwrite", "upsert"] +_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "truncate cascade"] @_utils.check_optional_dependency(pg8000, "pg8000") @@ -495,6 +514,7 @@ def to_sql( table: str, schema: str, mode: _ToSqlModeLiteral = "append", + overwrite_method: _ToSqlOverwriteModeLiteral = "drop", index: bool = False, dtype: dict[str, str] | None = None, varchar_lengths: dict[str, int] | None = None, @@ -522,6 +542,13 @@ def to_sql( overwrite: Drops table and recreates. upsert: Perform an upsert which checks for conflicts on columns given by `upsert_conflict_columns` and sets the new values on conflicts. Note that `upsert_conflict_columns` is required for this mode. + overwrite_method : str + Drop, cascade, truncate, or truncate cascade. Only applicable in overwrite mode. + + "drop" - ``DROP ... RESTRICT`` - drops the table. Fails if any objects, such as views, depend on it. + "cascade" - ``DROP ... CASCADE`` - drops the table, and all objects, such as views, that depend on it. + "truncate" - ``TRUNCATE ... RESTRICT`` - truncates the table. Fails if any other table has a foreign-key reference to it. + "truncate cascade" - ``TRUNCATE ... CASCADE`` - truncates the table, and all tables that have foreign-key references to it. index : bool True to store the DataFrame index as a column in the table, otherwise False to ignore it.
@@ -583,6 +610,7 @@ def to_sql( table=table, schema=schema, mode=mode, + overwrite_method=overwrite_method, index=index, dtype=dtype, varchar_lengths=varchar_lengths, diff --git a/tests/unit/test_postgresql.py b/tests/unit/test_postgresql.py index 878d341ac..56b14279e 100644 --- a/tests/unit/test_postgresql.py +++ b/tests/unit/test_postgresql.py @@ -49,7 +49,51 @@ def test_read_sql_query_simple(databases_parameters): def test_to_sql_simple(postgresql_table, postgresql_con): df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) - wr.postgresql.to_sql(df, postgresql_con, postgresql_table, "public", "overwrite", True) + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + table=postgresql_table, + schema="public", + mode="overwrite", + index=True, + ) + + +@pytest.mark.parametrize("overwrite_method", ["drop", "cascade", "truncate", "truncate cascade"]) +def test_to_sql_overwrite(postgresql_table, postgresql_con, overwrite_method): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + table=postgresql_table, + schema="public", + mode="overwrite", + overwrite_method=overwrite_method, + ) + df = pd.DataFrame({"c0": [4, 5, 6], "c1": ["xoo", "yoo", "zoo"]}) + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + table=postgresql_table, + schema="public", + mode="overwrite", + overwrite_method=overwrite_method, + ) + df = wr.postgresql.read_sql_table(table=postgresql_table, schema="public", con=postgresql_con) + assert df.shape == (3, 2) + + +def test_unknown_overwrite_method_error(postgresql_table, postgresql_con): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) + with pytest.raises(wr.exceptions.InvalidArgumentValue): + wr.postgresql.to_sql( + df=df, + con=postgresql_con, + table=postgresql_table, + schema="public", + mode="overwrite", + overwrite_method="unknown", + ) def test_sql_types(postgresql_table, postgresql_con):