Skip to content

Commit

Permalink
fix: add primary keys when upserting in Postgres (#2819)
Browse files: browse the repository at this point in the history
  • Loading branch information
AntonMantulo authored May 15, 2024
1 parent 8ea7427 commit b3f215e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 42 deletions.
4 changes: 4 additions & 0 deletions awswrangler/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def _create_table(
index: bool,
dtype: dict[str, str] | None,
varchar_lengths: dict[str, int] | None,
unique_keys: list[str] | None = None,
) -> None:
if mode == "overwrite":
if overwrite_method in ["drop", "cascade"]:
Expand All @@ -101,6 +102,8 @@ def _create_table(
converter_func=_data_types.pyarrow2postgresql,
)
cols_str: str = "".join([f"{_identifier(k)} {v},\n" for k, v in postgresql_types.items()])[:-2]
if unique_keys:
cols_str += f",\nUNIQUE ({', '.join([_identifier(k) for k in unique_keys])})"
sql = f"CREATE TABLE IF NOT EXISTS {_identifier(schema)}.{_identifier(table)} (\n{cols_str})"
_logger.debug("Create table query:\n%s", sql)
cursor.execute(sql)
Expand Down Expand Up @@ -619,6 +622,7 @@ def to_sql(
index=index,
dtype=dtype,
varchar_lengths=varchar_lengths,
unique_keys=upsert_conflict_columns or insert_conflict_columns,
)
if index:
df.reset_index(level=df.index.names, inplace=True)
Expand Down
42 changes: 0 additions & 42 deletions tests/unit/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,6 @@ def test_dfs_are_equal_for_different_chunksizes(postgresql_table, postgresql_con


def test_upsert(postgresql_table, postgresql_con):
create_table_sql = (
f"CREATE TABLE public.{postgresql_table} "
"(c0 varchar NULL PRIMARY KEY,"
"c1 int NULL DEFAULT 42,"
"c2 int NOT NULL);"
)
with postgresql_con.cursor() as cursor:
cursor.execute(create_table_sql)
postgresql_con.commit()

df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})

with pytest.raises(wr.exceptions.InvalidArgumentValue):
Expand Down Expand Up @@ -369,17 +359,6 @@ def test_upsert(postgresql_table, postgresql_con):


def test_upsert_multiple_conflict_columns(postgresql_table, postgresql_con):
create_table_sql = (
f"CREATE TABLE public.{postgresql_table} "
"(c0 varchar NULL PRIMARY KEY,"
"c1 int NOT NULL,"
"c2 int NOT NULL,"
"UNIQUE (c1, c2));"
)
with postgresql_con.cursor() as cursor:
cursor.execute(create_table_sql)
postgresql_con.commit()

df = pd.DataFrame({"c0": ["foo", "bar"], "c1": [1, 2], "c2": [3, 4]})
upsert_conflict_columns = ["c1", "c2"]

Expand Down Expand Up @@ -437,16 +416,6 @@ def test_upsert_multiple_conflict_columns(postgresql_table, postgresql_con):


def test_insert_ignore_duplicate_columns(postgresql_table, postgresql_con):
create_table_sql = (
f"CREATE TABLE public.{postgresql_table} "
"(c0 varchar NULL PRIMARY KEY,"
"c1 int NULL DEFAULT 42,"
"c2 int NOT NULL);"
)
with postgresql_con.cursor() as cursor:
cursor.execute(create_table_sql)
postgresql_con.commit()

df = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})

wr.postgresql.to_sql(
Expand Down Expand Up @@ -501,17 +470,6 @@ def test_insert_ignore_duplicate_columns(postgresql_table, postgresql_con):


def test_insert_ignore_duplicate_multiple_columns(postgresql_table, postgresql_con):
create_table_sql = (
f"CREATE TABLE public.{postgresql_table} "
"(c0 varchar NULL PRIMARY KEY,"
"c1 int NOT NULL,"
"c2 int NOT NULL,"
"UNIQUE (c1, c2));"
)
with postgresql_con.cursor() as cursor:
cursor.execute(create_table_sql)
postgresql_con.commit()

df = pd.DataFrame({"c0": ["foo", "bar"], "c1": [1, 2], "c2": [3, 4]})
insert_conflict_columns = ["c1", "c2"]

Expand Down

0 comments on commit b3f215e

Please sign in to comment.