Use truncate when syncing a table entirely
VincentAntoine committed Jul 30, 2024
1 parent 2245fb3 commit 8c0914d
Showing 5 changed files with 62 additions and 103 deletions.
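The change in a nutshell: wherever a table is wiped and fully re-synced, row-by-row DELETE statements are replaced by a single TRUNCATE. On PostgreSQL, TRUNCATE reclaims the table's pages in one statement rather than logging one deletion per row, and naming several tables in the same statement keeps foreign keys between them satisfied. A minimal before/after sketch, assuming SQLAlchemy Table objects like the ones this pipeline reflects (the helper names are illustrative, not part of the commit):

from typing import List

from sqlalchemy import Table, text
from sqlalchemy.engine import Connection


def wipe_with_delete(table: Table, connection: Connection) -> None:
    # Old approach: row-by-row DELETE, which logs one WAL record per row
    # and fires any ON DELETE triggers for each of them.
    connection.execute(table.delete())


def wipe_with_truncate(tables: List[Table], connection: Connection) -> None:
    # New approach: one TRUNCATE covering the whole group of tables, so
    # foreign keys between the truncated tables remain satisfied.
    names = ", ".join(f'"{t.schema}"."{t.name}"' for t in tables)
    connection.execute(text(f"TRUNCATE {names}"))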
13 changes: 8 additions & 5 deletions datascience/src/pipeline/flows/init_pno_types.py
@@ -12,7 +12,7 @@
 from src.pipeline.generic_tasks import load
 from src.pipeline.shared_tasks.control_flow import check_flow_not_running
 from src.pipeline.shared_tasks.infrastructure import get_table
-from src.pipeline.utils import delete
+from src.pipeline.utils import truncate


 @task(checkpoint=False)
@@ -47,16 +47,19 @@ def load_pno_types_and_rules(
     e = create_engine("monitorfish_remote")

     with e.begin() as con:
-        delete(table=pno_type_rules_table, connection=con, logger=logger)
-        delete(table=pno_types_table, connection=con, logger=logger)
+        truncate(
+            tables=[pno_type_rules_table, pno_types_table],
+            connection=con,
+            logger=logger,
+        )

         load(
             pno_types,
             table_name="pno_types",
             schema="public",
             connection=con,
             logger=prefect.context.get("logger"),
-            how="replace",
+            how="append",
             end_ddls=[
                 DDL(
                     "SELECT setval("
@@ -74,7 +74,7 @@
             schema="public",
             connection=con,
             logger=prefect.context.get("logger"),
-            how="replace",
+            how="append",
             pg_array_columns=[
                 "species",
                 "fao_areas",
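Note on the two how="replace" → how="append" switches above: the tables are now emptied once, explicitly, by truncate at the start of the transaction, so the subsequent load calls must only append; leaving how="replace" would wipe each table a second time. The end_ddls sequence reset is kept (its statement is truncated in this excerpt); a typical PostgreSQL reset of that shape, with illustrative table and column names rather than the repository's exact statement, would be:

from sqlalchemy import DDL

# Illustrative only: after the reload, point the pno_types id sequence
# at the current max so later inserts do not collide with loaded rows.
reset_pno_types_sequence = DDL(
    "SELECT setval("
    "pg_get_serial_sequence('public.pno_types', 'id'), "
    "(SELECT COALESCE(MAX(id), 1) FROM public.pno_types))"
)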
68 changes: 30 additions & 38 deletions datascience/src/pipeline/flows/missions.py
@@ -4,11 +4,13 @@
 import prefect
 from prefect import Flow, Parameter, case, task
 from prefect.executors import LocalDaskExecutor
-from sqlalchemy import DDL
+from sqlalchemy import Table

 from src.db_config import create_engine
 from src.pipeline.generic_tasks import extract, load
 from src.pipeline.shared_tasks.control_flow import check_flow_not_running
 from src.pipeline.shared_tasks.infrastructure import get_table
+from src.pipeline.utils import truncate


 @task(checkpoint=False)
@@ -100,53 +102,37 @@ def filter_missions_control_units(

 @task(checkpoint=False)
 def load_missions_and_missions_control_units(
-    missions: pd.DataFrame, missions_control_units: pd.DataFrame, loading_mode: str
+    missions: pd.DataFrame,
+    missions_control_units: pd.DataFrame,
+    analytics_missions_table: Table,
+    analytics_missions_control_units_table: Table,
 ):
-    # In "replace" loading mode, we want to replace all `missions`, so we use `replace`
-    # loading mode.
-
-    # In "upsert" loading mode, we want to replace only the missions whose `id` is
-    # present in the DataFrame. So we use `id` as the identifier and `upsert` loading
-    # mode.
-
-    assert loading_mode in ("replace", "upsert")
-    missions_id_column = "id" if loading_mode == "upsert" else None
-
+    """
+    Truncates tables and populates them with data from input DataFrames.
+    """
     e = create_engine("monitorfish_remote")

     with e.begin() as connection:
+        truncate(
+            tables=[analytics_missions_table, analytics_missions_control_units_table],
+            connection=connection,
+            logger=prefect.context.get("logger"),
+        )
+
         load(
             missions,
-            table_name="analytics_missions",
-            schema="public",
+            table_name=analytics_missions_table.name,
+            schema=analytics_missions_table.schema,
             connection=connection,
             logger=prefect.context.get("logger"),
             pg_array_columns=["mission_types"],
-            how=loading_mode,
-            table_id_column=missions_id_column,
-            df_id_column=missions_id_column,
-            init_ddls=[
-                DDL(
-                    "ALTER TABLE public.analytics_missions_control_units "
-                    "ADD CONSTRAINT "
-                    "analytics_missions_control_units_mission_id_cascade_fkey "
-                    "FOREIGN KEY (mission_id) "
-                    "REFERENCES public.analytics_missions (id) "
-                    "ON DELETE CASCADE;"
-                ),
-            ],
-            end_ddls=[
-                DDL(
-                    "ALTER TABLE public.analytics_missions_control_units "
-                    "DROP CONSTRAINT "
-                    "analytics_missions_control_units_mission_id_cascade_fkey;"
-                ),
-            ],
+            how="append",
         )

         load(
             missions_control_units,
-            table_name="analytics_missions_control_units",
-            schema="public",
+            table_name=analytics_missions_control_units_table.name,
+            schema=analytics_missions_control_units_table.schema,
             connection=connection,
             logger=prefect.context.get("logger"),
             how="append",
@@ -157,12 +143,15 @@ def load_missions_and_missions_control_units(
     flow_not_running = check_flow_not_running()
     with case(flow_not_running, True):
         # Parameters
-        loading_mode = Parameter("loading_mode")
         number_of_months = Parameter("number_of_months")

         # Extract
         missions = extract_missions(number_of_months=number_of_months)
         missions_control_units = extract_missions_control_units()
+        analytics_missions_table = get_table("analytics_missions")
+        analytics_missions_control_units_table = get_table(
+            "analytics_missions_control_units"
+        )

         # Transform
         missions_control_units = filter_missions_control_units(
@@ -171,7 +160,10 @@

         # Load
         load_missions_and_missions_control_units(
-            missions, missions_control_units, loading_mode=loading_mode
+            missions,
+            missions_control_units,
+            analytics_missions_table,
+            analytics_missions_control_units_table,
         )

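Two related simplifications in this file. First, the temporary ON DELETE CASCADE foreign key that the old code added and dropped around each load disappears: PostgreSQL refuses to truncate a parent table alone while a foreign key references it, but truncating both tables in a single statement, as utils.truncate now does, satisfies the constraint. Second, the loading_mode parameter is gone; the flow always performs a full truncate-and-reload, so a run now looks like this sketch (mirroring the updated test further down):

from src.pipeline.flows.missions import flow

flow.schedule = None
state = flow.run(number_of_months=12)  # `loading_mode` no longer exists
assert state.is_successful()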
6 changes: 3 additions & 3 deletions datascience/src/pipeline/generic_tasks.py
@@ -114,7 +114,7 @@ def load(
        logger (logging.Logger): logger instance
        how (str): one of
-            - 'replace' to delete all rows in the table before loading
+            - 'replace' to truncate the table before loading
            - 'append' to append the data to rows already in the table
            - 'upsert' to append the rows to the table, replacing the rows whose id is
              already
@@ -233,8 +233,8 @@ def load_with_connection(

     table = get_table(table_name, schema, connection, logger)
     if how == "replace":
-        # Delete all rows from table
-        utils.delete(table, connection, logger)
+        # Truncate table
+        utils.truncate([table], connection, logger)

     elif how == "upsert":
         # Delete rows that are in the DataFrame from the table
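For reference, the three loading modes of the generic load task after this change, in a hedged usage sketch (the table name is hypothetical; the upsert keyword arguments come from the call sites removed above):

import logging

import pandas as pd

from src.db_config import create_engine
from src.pipeline.generic_tasks import load

logger = logging.getLogger(__name__)
df = pd.DataFrame({"id": [1, 2], "value": ["a", "b"]})

with create_engine("monitorfish_remote").begin() as connection:
    # 'replace': truncate the table, then insert df (previously DELETE-then-insert)
    load(df, table_name="some_table", schema="public",
         connection=connection, logger=logger, how="replace")

    # 'append': insert df on top of whatever rows are already there
    load(df, table_name="some_table", schema="public",
         connection=connection, logger=logger, how="append")

    # 'upsert': delete the rows whose id appears in df, then insert df
    load(df, table_name="some_table", schema="public",
         connection=connection, logger=logger, how="upsert",
         table_id_column="id", df_id_column="id")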
31 changes: 15 additions & 16 deletions datascience/src/pipeline/utils.py
@@ -5,11 +5,11 @@
 import shutil
 import sys
 from io import StringIO
-from typing import Sequence, Union
+from typing import List, Sequence, Union

 import geoalchemy2
 import sqlalchemy
-from sqlalchemy import MetaData, Table, func, select
+from sqlalchemy import MetaData, Table, func, select, text
 from sqlalchemy.exc import InvalidRequestError

 # ***************************** Database operations utils *****************************
@@ -53,23 +53,22 @@ def get_table(
     return table


-def delete(
-    table: sqlalchemy.Table,
+def truncate(
+    tables: List[sqlalchemy.Table],
     connection: sqlalchemy.engine.base.Connection,
     logger: logging.Logger,
 ):
-    """Deletes all rows from a table.
-    Useful to wipe a table before re-inserting fresh data in ETL jobs."""
-    count_statement = select(func.count()).select_from(table)
-    n = connection.execute(count_statement).fetchall()[0][0]
-    if logger:
-        logger.info(f"Found existing table {table.name} with {n} rows.")
-        logger.info(f"Deleting table {table.name}...")
-    connection.execute(table.delete())
-    count_statement = select(func.count()).select_from(table)
-    n = connection.execute(count_statement).fetchall()[0][0]
-    if logger:
-        logger.info(f"Rows after deletion: {n}.")
+    """Truncate tables.
+    Useful to wipe tables before re-inserting fresh data in ETL jobs."""
+    for table in tables:
+        count_statement = select(func.count()).select_from(table)
+        n = connection.execute(count_statement).fetchall()[0][0]
+        logger.info(f"Table {table.name} has {n} rows.")

+    tables_list = ", ".join([f'"{table.schema}"."{table.name}"' for table in tables])
+    logger.info(f"Truncating tables {tables_list}...")
+
+    connection.execute(text(f"TRUNCATE {tables_list}"))


 def delete_rows(
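A usage sketch of the new helper, with illustrative table names. One behavioural difference worth flagging: the old delete guarded its log calls with if logger:, whereas truncate calls logger.info unconditionally, so callers can no longer pass logger=None.

import logging

from src.db_config import create_engine
from src.pipeline.utils import get_table, truncate

logger = logging.getLogger(__name__)
e = create_engine("monitorfish_remote")

with e.begin() as connection:
    parent = get_table("analytics_missions", "public", connection, logger)
    child = get_table("analytics_missions_control_units", "public", connection, logger)
    # Both tables go into a single TRUNCATE, so the foreign key from
    # child to parent does not block the operation.
    truncate(tables=[child, parent], connection=connection, logger=logger)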
47 changes: 6 additions & 41 deletions datascience/tests/test_pipeline/test_flows/test_missions.py
@@ -2,7 +2,6 @@
 from unittest.mock import patch

 import pandas as pd
-import pytest
 import sqlalchemy
 from prefect import task

@@ -100,8 +99,7 @@ def test_extract_missions_control_units(mock_extract):
     assert isinstance(query, sqlalchemy.sql.elements.TextClause)


-@pytest.mark.parametrize("loading_mode", ["replace", "upsert"])
-def test_flow(reset_test_data, loading_mode):
+def test_flow(reset_test_data):
     missions_query = "SELECT * FROM analytics_missions ORDER BY id"
     missions_control_units_query = (
         "SELECT * FROM analytics_missions_control_units ORDER BY id"
@@ -113,7 +111,7 @@ def test_flow(reset_test_data, loading_mode):
     )

     flow.schedule = None
-    state = flow.run(loading_mode=loading_mode, number_of_months=12)
+    state = flow.run(number_of_months=12)
     assert state.is_successful()

     extracted_missions = state.result[flow.get_tasks("mock_extract_missions")[0]].result
@@ -157,40 +155,7 @@ def test_flow(reset_test_data, loading_mode):
         filtered_missions_control_units.mission_id
     ) == extracted_missions_control_unit_ids.intersection(extracted_mission_ids)

-    if loading_mode == "upsert":
-        assert len(loaded_missions) == 27
-        assert set(loaded_missions.id) == extracted_mission_ids.union(
-            initial_mission_ids
-        )
-        assert set(
-            loaded_missions_control_units.mission_id
-        ) == initial_mission_ids.union({112})
-
-        # Check data is updated for missions already present initially
-        assert (
-            initial_missions.loc[initial_missions.id == 1, "facade"] == "NAMO"
-        ).all()
-        assert (
-            extracted_missions.loc[extracted_missions.id == 1, "facade"] == "Facade 1"
-        ).all()
-        assert (
-            loaded_missions.loc[loaded_missions.id == 1, "facade"] == "Facade 1"
-        ).all()
-        assert set(
-            initial_missions_control_units.loc[
-                initial_missions_control_units.mission_id == 1, "control_unit_id"
-            ]
-        ) == {5}
-        assert set(
-            loaded_missions_control_units.loc[
-                loaded_missions_control_units.mission_id == 1, "control_unit_id"
-            ]
-        ) == {7, 8}
-
-    else:
-        pd.testing.assert_frame_equal(
-            extracted_missions, loaded_missions, check_like=True
-        )
-        pd.testing.assert_frame_equal(
-            filtered_missions_control_units, loaded_missions_control_units
-        )
+    pd.testing.assert_frame_equal(extracted_missions, loaded_missions, check_like=True)
+    pd.testing.assert_frame_equal(
+        filtered_missions_control_units, loaded_missions_control_units
+    )
