Refactor approach to using database for join
zacdezgeo committed Nov 6, 2024
1 parent 232046f commit 525fab3
Showing 2 changed files with 62 additions and 71 deletions.
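In outline, the refactor drops the read-merge-rewrite cycle (load the full PostgreSQL table into Arrow, align it row by row with the Parquet data, then re-ingest the merged table) and instead lets PostgreSQL do the join: the Parquet file is bulk-loaded into a temporary table, any new columns are added to the main table, and a single UPDATE ... FROM joined on hex_id fills them in. A minimal sketch of that flow, with illustrative table names and column types assumed DOUBLE PRECISION; the diff below derives both from the database itself:

import adbc_driver_postgresql.dbapi as pg
import pyarrow.parquet as pq

MAIN_TABLE = "space2stats"        # hypothetical names for illustration; the module
TEMP_TABLE = "space2stats_temp"   # uses TABLE_NAME and f"{TABLE_NAME}_temp"

def join_ingest(connection_string: str, parquet_file: str) -> None:
    table = pq.read_table(parquet_file)  # hex_id plus the columns to add
    new_cols = [c for c in table.column_names if c != "hex_id"]
    with pg.connect(connection_string) as conn:
        with conn.cursor() as cur:
            # 1. Bulk-load the Parquet data into a scratch table.
            cur.adbc_ingest(TEMP_TABLE, table, mode="replace")
            # 2. Add each new column to the main table (types assumed numeric
            #    here; the commit reads them from information_schema instead).
            for col in new_cols:
                cur.execute(
                    f"ALTER TABLE {MAIN_TABLE} ADD COLUMN IF NOT EXISTS {col} DOUBLE PRECISION"
                )
            # 3. Fill the new columns with one join on the shared hex_id key.
            set_clause = ", ".join(f"{c} = temp.{c}" for c in new_cols)
            cur.execute(
                f"UPDATE {MAIN_TABLE} AS main SET {set_clause} "
                f"FROM {TEMP_TABLE} AS temp WHERE main.hex_id = temp.hex_id"
            )
            # 4. Drop the scratch table.
            cur.execute(f"DROP TABLE {TEMP_TABLE}")
        conn.commit()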
126 changes: 58 additions & 68 deletions space2stats_api/src/space2stats_ingest/main.py
@@ -3,7 +3,6 @@
 
 import adbc_driver_postgresql.dbapi as pg
 import boto3
-import numpy as np
 import pyarrow as pa
 import pyarrow.parquet as pq
 from pystac import Item, STACValidationError
@@ -41,11 +40,18 @@ def validate_stac_item(stac_item_path: str) -> bool:
 
 
 def verify_columns(parquet_file: str, stac_item_path: str) -> bool:
-    """Verifies that the Parquet file columns match the STAC item metadata columns."""
+    """Verifies that the Parquet file columns match the STAC item metadata columns
+    and ensures that 'hex_id' column is present."""
     parquet_table = read_parquet_file(parquet_file)
     parquet_columns = set(parquet_table.column_names)
     stac_fields = get_stac_fields_from_item(stac_item_path)
 
+    # Check if 'hex_id' is present in the Parquet columns
+    # We are not verifying the hex level as new hex ids will cause error on SQL Update
+    if "hex_id" not in parquet_columns:
+        raise ValueError("The 'hex_id' column is missing from the Parquet file.")
+
+    # Verify that columns in the Parquet file match the STAC item metadata columns
     if parquet_columns != stac_fields:
         extra_in_parquet = parquet_columns - stac_fields
         extra_in_stac = stac_fields - parquet_columns
@@ -55,48 +61,6 @@ def verify_columns(parquet_file: str, stac_item_path: str) -> bool:
     return True
 
 
-def read_table_from_db(connection_string: str, table_name: str) -> pa.Table:
-    """Reads a PostgreSQL table into an Arrow table, ordered by hex_id."""
-    with pg.connect(connection_string) as conn:
-        with conn.cursor() as cur:
-            # Check if the table exists
-            cur.execute(f"SELECT to_regclass('{table_name}');")
-            if cur.fetchone()[0] is None:
-                raise ValueError(
-                    f"Table '{table_name}' does not exist in the database."
-                )
-
-            # Fetch the table data ordered by hex_id
-            query = f"SELECT * FROM {table_name} ORDER BY hex_id"
-            cur.execute(query)
-
-            return cur.fetch_arrow_table()
-
-
-def validate_table_alignment(
-    db_table: pa.Table, parquet_table: pa.Table, sample_size: int = 1000
-):
-    """Ensures both tables have similar 'hex_id' values based on a random sample."""
-    if db_table.num_rows != parquet_table.num_rows:
-        raise ValueError(
-            "Row counts do not match between the database and Parquet table."
-        )
-
-    # Determine the sample indices
-    total_rows = db_table.num_rows
-    sample_size = min(sample_size, total_rows)  # Ensure sample size is within bounds
-    sample_indices = np.random.choice(total_rows, size=sample_size, replace=False)
-
-    # Compare hex_id values at the sampled indices
-    db_sample = db_table["hex_id"].take(sample_indices)
-    parquet_sample = parquet_table["hex_id"].take(sample_indices)
-
-    if not pa.compute.all(pa.compute.equal(db_sample, parquet_sample)).as_py():
-        raise ValueError(
-            "hex_id columns do not match between database and Parquet tables for the sampled rows."
-        )
-
-
 def merge_tables(db_table: pa.Table, parquet_table: pa.Table) -> pa.Table:
     """Adds columns from the Parquet table to the database table in memory."""
     for column in parquet_table.column_names:
@@ -144,34 +108,60 @@ def load_parquet_to_db(
             conn.commit()
         return
 
-    # Load the existing table and new table if the table already exists
-    db_table = read_table_from_db(connection_string, TABLE_NAME)
-    print("Read db table")
-    parquet_table = read_parquet_file(parquet_file).sort_by("hex_id")
-    print("read parquet table")
-
-    # Validate alignment of the two tables using a sample
-    print("Validating alignment")
-    validate_table_alignment(db_table, parquet_table)
-
-    # Merge tables in memory
-    print("Merge tables")
-    merged_table = merge_tables(db_table, parquet_table)
-
-    # Write merged data back to the database in batches
+    parquet_table = read_parquet_file(parquet_file)
+    temp_table = f"{TABLE_NAME}_temp"
     with pg.connect(connection_string) as conn, tqdm(
-        total=merged_table.num_rows, desc="Ingesting Merged Data", unit="rows"
+        total=parquet_table.num_rows, desc="Ingesting Temporary Table", unit="rows"
     ) as pbar:
         with conn.cursor() as cur:
-            cur.execute(f"DROP TABLE IF EXISTS {TABLE_NAME}")
-            cur.adbc_ingest(TABLE_NAME, merged_table.slice(0, 0), mode="replace")
+            cur.adbc_ingest(temp_table, parquet_table.slice(0, 0), mode="replace")
 
-            for batch in merged_table.to_batches(max_chunksize=chunksize):
-                cur.adbc_ingest(TABLE_NAME, batch, mode="append")
+            for batch in parquet_table.to_batches(max_chunksize=chunksize):
+                cur.adbc_ingest(temp_table, batch, mode="append")
                 pbar.update(batch.num_rows)
 
-            # Recreate index on hex_id
-            cur.execute(
-                f"CREATE INDEX idx_{TABLE_NAME}_hex_id ON {TABLE_NAME} (hex_id)"
-            )
         conn.commit()
+
+    # Fetch columns to add to dataset
+    with pg.connect(connection_string) as conn:
+        with conn.cursor() as cur:
+            cur.execute(f"""
+                SELECT column_name, data_type
+                FROM information_schema.columns
+                WHERE table_name = '{temp_table}'
+                AND column_name NOT IN (
+                    SELECT column_name FROM information_schema.columns WHERE table_name = '{TABLE_NAME}'
+                )
+            """)
+            new_columns = cur.fetchall()
+
+            for column, column_type in new_columns:
+                cur.execute(
+                    f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}"
+                )
+
+            conn.commit()
+
+    print(f"Adding new columns: {new_columns}...")
+
+    # Update TABLE_NAME with data from temp_table based on matching hex_id
+    print("Adding columns to dataset... All or nothing operation may take some time.")
+    with pg.connect(connection_string) as conn:
+        with conn.cursor() as cur:
+            update_columns = [f"{column} = temp.{column}" for column, _ in new_columns]
+
+            set_clause = ", ".join(update_columns)
+
+            cur.execute(f"""
+                UPDATE {TABLE_NAME} AS main
+                SET {set_clause}
+                FROM {temp_table} AS temp
+                WHERE main.hex_id = temp.hex_id
+            """)
+
+            conn.commit()
+
+    with pg.connect(connection_string) as conn:
+        with conn.cursor() as cur:
+            cur.execute(f"DROP TABLE {temp_table}")
+            conn.commit()
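Because the new columns are populated by one UPDATE statement in a single transaction, the join is all-or-nothing: any failure rolls the whole update back rather than leaving the table half-filled, which is what the "All or nothing operation" message refers to. A hypothetical spot check after ingestion (the DSN and the sum_pop_2020 column name are placeholders, not part of the commit): rows whose new column stayed NULL are hex_ids present in the main table but absent from the Parquet file.

import adbc_driver_postgresql.dbapi as pg

# Hypothetical post-ingest check; sum_pop_2020 stands in for any newly added column.
with pg.connect("postgresql://localhost/space2stats") as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM space2stats WHERE sum_pop_2020 IS NULL")
        unmatched = cur.fetchone()[0]
        print(f"{unmatched} rows received no value for sum_pop_2020")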
7 changes: 4 additions & 3 deletions space2stats_api/src/tests/test_ingest_cli.py
@@ -100,9 +100,10 @@ def test_load_command_column_mismatch(tmpdir, clean_database):
     collection_file = tmpdir.join("collection.json")
     item_file = tmpdir.join("space2stats_population_2020.json")
 
-    create_mock_parquet_file(parquet_file, [("different_column", pa.float64())])
-
-    create_stac_item(item_file, [("mock_column", "float64")])
+    create_mock_parquet_file(
+        parquet_file, [("hex_id", pa.string()), ("different_column", pa.float64())]
+    )
+    create_stac_item(item_file, [("hex_id", "string"), ("mock_column", "float64")])
 
     create_stac_collection(collection_file, item_file)
     create_stac_catalog(catalog_file, collection_file)
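Both fixtures now include hex_id, so this test still exercises the column-name mismatch instead of tripping the new hex_id guard first. A minimal companion test for that guard could look like the sketch below, assuming the suite's existing create_mock_parquet_file and create_stac_item helpers; verify_columns raises before any database work, so no connection is needed.

import pyarrow as pa
import pytest

from space2stats_ingest.main import verify_columns

def test_verify_columns_missing_hex_id(tmpdir):
    # Hypothetical test: a Parquet file without hex_id should be rejected outright.
    parquet_file = tmpdir.join("no_hex.parquet")
    item_file = tmpdir.join("item.json")
    create_mock_parquet_file(parquet_file, [("mock_column", pa.float64())])
    create_stac_item(item_file, [("mock_column", "float64")])
    with pytest.raises(ValueError, match="hex_id"):
        verify_columns(str(parquet_file), str(item_file))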
