From 525fab36c5214da49cd06314484b7e8ad07ba3ba Mon Sep 17 00:00:00 2001 From: Zachary Deziel Date: Wed, 6 Nov 2024 10:51:34 -0800 Subject: [PATCH] Refactor approach to using database for join --- .../src/space2stats_ingest/main.py | 126 ++++++++---------- space2stats_api/src/tests/test_ingest_cli.py | 7 +- 2 files changed, 62 insertions(+), 71 deletions(-) diff --git a/space2stats_api/src/space2stats_ingest/main.py b/space2stats_api/src/space2stats_ingest/main.py index 21a905a..007f5fe 100644 --- a/space2stats_api/src/space2stats_ingest/main.py +++ b/space2stats_api/src/space2stats_ingest/main.py @@ -3,7 +3,6 @@ import adbc_driver_postgresql.dbapi as pg import boto3 -import numpy as np import pyarrow as pa import pyarrow.parquet as pq from pystac import Item, STACValidationError @@ -41,11 +40,18 @@ def validate_stac_item(stac_item_path: str) -> bool: def verify_columns(parquet_file: str, stac_item_path: str) -> bool: - """Verifies that the Parquet file columns match the STAC item metadata columns.""" + """Verifies that the Parquet file columns match the STAC item metadata columns + and ensures that 'hex_id' column is present.""" parquet_table = read_parquet_file(parquet_file) parquet_columns = set(parquet_table.column_names) stac_fields = get_stac_fields_from_item(stac_item_path) + # Check if 'hex_id' is present in the Parquet columns + # We are not verifying the hex level as new hex ids will cause error on SQL Update + if "hex_id" not in parquet_columns: + raise ValueError("The 'hex_id' column is missing from the Parquet file.") + + # Verify that columns in the Parquet file match the STAC item metadata columns if parquet_columns != stac_fields: extra_in_parquet = parquet_columns - stac_fields extra_in_stac = stac_fields - parquet_columns @@ -55,48 +61,6 @@ def verify_columns(parquet_file: str, stac_item_path: str) -> bool: return True -def read_table_from_db(connection_string: str, table_name: str) -> pa.Table: - """Reads a PostgreSQL table into an Arrow table, 
ordered by hex_id.""" - with pg.connect(connection_string) as conn: - with conn.cursor() as cur: - # Check if the table exists - cur.execute(f"SELECT to_regclass('{table_name}');") - if cur.fetchone()[0] is None: - raise ValueError( - f"Table '{table_name}' does not exist in the database." - ) - - # Fetch the table data ordered by hex_id - query = f"SELECT * FROM {table_name} ORDER BY hex_id" - cur.execute(query) - - return cur.fetch_arrow_table() - - -def validate_table_alignment( - db_table: pa.Table, parquet_table: pa.Table, sample_size: int = 1000 -): - """Ensures both tables have similar 'hex_id' values based on a random sample.""" - if db_table.num_rows != parquet_table.num_rows: - raise ValueError( - "Row counts do not match between the database and Parquet table." - ) - - # Determine the sample indices - total_rows = db_table.num_rows - sample_size = min(sample_size, total_rows) # Ensure sample size is within bounds - sample_indices = np.random.choice(total_rows, size=sample_size, replace=False) - - # Compare hex_id values at the sampled indices - db_sample = db_table["hex_id"].take(sample_indices) - parquet_sample = parquet_table["hex_id"].take(sample_indices) - - if not pa.compute.all(pa.compute.equal(db_sample, parquet_sample)).as_py(): - raise ValueError( - "hex_id columns do not match between database and Parquet tables for the sampled rows." 
- ) - - def merge_tables(db_table: pa.Table, parquet_table: pa.Table) -> pa.Table: """Adds columns from the Parquet table to the database table in memory.""" for column in parquet_table.column_names: @@ -144,34 +108,60 @@ def load_parquet_to_db( conn.commit() return - # Load the existing table and new table if the table already exists - db_table = read_table_from_db(connection_string, TABLE_NAME) - print("Read db table") - parquet_table = read_parquet_file(parquet_file).sort_by("hex_id") - print("read parquet table") - - # Validate alignment of the two tables using a sample - print("Validating alignment") - validate_table_alignment(db_table, parquet_table) - - # Merge tables in memory - print("Merge tables") - merged_table = merge_tables(db_table, parquet_table) - - # Write merged data back to the database in batches + parquet_table = read_parquet_file(parquet_file) + temp_table = f"{TABLE_NAME}_temp" with pg.connect(connection_string) as conn, tqdm( - total=merged_table.num_rows, desc="Ingesting Merged Data", unit="rows" + total=parquet_table.num_rows, desc="Ingesting Temporary Table", unit="rows" ) as pbar: with conn.cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {TABLE_NAME}") - cur.adbc_ingest(TABLE_NAME, merged_table.slice(0, 0), mode="replace") + cur.adbc_ingest(temp_table, parquet_table.slice(0, 0), mode="replace") - for batch in merged_table.to_batches(max_chunksize=chunksize): - cur.adbc_ingest(TABLE_NAME, batch, mode="append") + for batch in parquet_table.to_batches(max_chunksize=chunksize): + cur.adbc_ingest(temp_table, batch, mode="append") pbar.update(batch.num_rows) - # Recreate index on hex_id - cur.execute( - f"CREATE INDEX idx_{TABLE_NAME}_hex_id ON {TABLE_NAME} (hex_id)" - ) + conn.commit() + + # Fetch columns to add to dataset + with pg.connect(connection_string) as conn: + with conn.cursor() as cur: + cur.execute(f""" + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '{temp_table}' + AND column_name 
NOT IN ( + SELECT column_name FROM information_schema.columns WHERE table_name = '{TABLE_NAME}' + ) + """) + new_columns = cur.fetchall() + + for column, column_type in new_columns: + cur.execute( + f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}" + ) + + conn.commit() + + print(f"Adding new columns: {new_columns}...") + + # Update TABLE_NAME with data from temp_table based on matching hex_id + print("Adding columns to dataset... All or nothing operation may take some time.") + with pg.connect(connection_string) as conn: + with conn.cursor() as cur: + update_columns = [f"{column} = temp.{column}" for column, _ in new_columns] + + set_clause = ", ".join(update_columns) + + cur.execute(f""" + UPDATE {TABLE_NAME} AS main + SET {set_clause} + FROM {temp_table} AS temp + WHERE main.hex_id = temp.hex_id + """) + + conn.commit() + + with pg.connect(connection_string) as conn: + with conn.cursor() as cur: + cur.execute(f"DROP TABLE {temp_table}") conn.commit() diff --git a/space2stats_api/src/tests/test_ingest_cli.py b/space2stats_api/src/tests/test_ingest_cli.py index b0b98ac..2e146e3 100644 --- a/space2stats_api/src/tests/test_ingest_cli.py +++ b/space2stats_api/src/tests/test_ingest_cli.py @@ -100,9 +100,10 @@ def test_load_command_column_mismatch(tmpdir, clean_database): collection_file = tmpdir.join("collection.json") item_file = tmpdir.join("space2stats_population_2020.json") - create_mock_parquet_file(parquet_file, [("different_column", pa.float64())]) - - create_stac_item(item_file, [("mock_column", "float64")]) + create_mock_parquet_file( + parquet_file, [("hex_id", pa.string()), ("different_column", pa.float64())] + ) + create_stac_item(item_file, [("hex_id", "string"), ("mock_column", "float64")]) create_stac_collection(collection_file, item_file) create_stac_catalog(catalog_file, collection_file)