diff --git a/src/astro/databases/base.py b/src/astro/databases/base.py index a08940cd2..af8e1e0fc 100644 --- a/src/astro/databases/base.py +++ b/src/astro/databases/base.py @@ -24,6 +24,15 @@ from astro.sql.table import Metadata, Table +class DatabaseCustomError(ValueError, AttributeError): + """ + Inappropriate argument value (of correct type) or attribute + not found while running a query. + """ + + pass + + class BaseDatabase(ABC): """ Base class to represent all the Database interactions. @@ -48,7 +57,7 @@ class BaseDatabase(ABC): # illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0] illegal_column_name_chars: List[str] = [] illegal_column_name_chars_replacement: List[str] = [] - NATIVE_LOAD_EXCEPTIONS: Any = (ValueError, AttributeError) + NATIVE_LOAD_EXCEPTIONS: Any = DatabaseCustomError def __init__(self, conn_id: str): self.conn_id = conn_id diff --git a/src/astro/databases/google/bigquery.py b/src/astro/databases/google/bigquery.py index aa482aaa1..1ae9c9690 100644 --- a/src/astro/databases/google/bigquery.py +++ b/src/astro/databases/google/bigquery.py @@ -1,4 +1,5 @@ """Google BigQuery table implementation.""" +import logging import time from typing import Any, Dict, List, Optional, Tuple @@ -45,7 +46,7 @@ LoadExistStrategy, MergeConflictStrategy, ) -from astro.databases.base import BaseDatabase +from astro.databases.base import BaseDatabase, DatabaseCustomError from astro.files import File from astro.sql.table import Metadata, Table @@ -73,8 +74,6 @@ class BigqueryDatabase(BaseDatabase): illegal_column_name_chars: List[str] = ["."] illegal_column_name_chars_replacement: List[str] = ["_"] NATIVE_LOAD_EXCEPTIONS: Any = ( - ValueError, - AttributeError, GoogleNotFound, ClientError, GoogleAPIError, @@ -89,7 +88,7 @@ class BigqueryDatabase(BaseDatabase): Unknown, ServiceUnavailable, InvalidResponse, - OSError, + DatabaseCustomError, ) def __init__(self, conn_id: str = DEFAULT_CONN_ID): @@ -244,7 
+243,7 @@ def load_file_to_table_natively( **kwargs, ) else: - raise ValueError( + raise DatabaseCustomError( f"No transfer performed since there is no optimised path " f"for {source_file.location.location_type} to bigquery." ) @@ -334,8 +333,11 @@ def get_project_id(self, target_table) -> str: """ try: return str(self.hook.project_id) - except AttributeError: - raise ValueError(f"conn_id {target_table.conn_id} has no project id") + except AttributeError as exe: + logging.warning(exe) + raise DatabaseCustomError( + f"conn_id {target_table.conn_id} has no project id" + ) from exe def load_local_file_to_table( self, @@ -438,7 +440,7 @@ def run(self): time.sleep(self.poll_duration) if run_info.state != TransferState.SUCCEEDED: - raise ValueError(run_info.error_status) + raise DatabaseCustomError(run_info.error_status) finally: # delete transfer config created. self.delete_transfer_config(transfer_config_id) diff --git a/src/astro/databases/snowflake.py b/src/astro/databases/snowflake.py index 9b6bc7f36..cdedd286d 100644 --- a/src/astro/databases/snowflake.py +++ b/src/astro/databases/snowflake.py @@ -31,7 +31,7 @@ LoadExistStrategy, MergeConflictStrategy, ) -from astro.databases.base import BaseDatabase +from astro.databases.base import BaseDatabase, DatabaseCustomError from astro.files import File from astro.sql.table import Metadata, Table @@ -159,8 +159,7 @@ class SnowflakeDatabase(BaseDatabase): """ NATIVE_LOAD_EXCEPTIONS: Any = ( - ValueError, - AttributeError, + DatabaseCustomError, ProgrammingError, DatabaseError, OperationalError, @@ -235,7 +234,7 @@ def _create_stage_auth_sub_statement( auth = f"storage_integration = {storage_integration};" else: if file.location.location_type == FileLocation.GS: - raise ValueError( + raise DatabaseCustomError( "In order to create an stage for GCS, `storage_integration` is required." 
) elif file.location.location_type == FileLocation.S3: @@ -243,7 +242,7 @@ def _create_stage_auth_sub_statement( if aws.access_key and aws.secret_key: auth = f"credentials=(aws_key_id='{aws.access_key}' aws_secret_key='{aws.secret_key}');" else: - raise ValueError( + raise DatabaseCustomError( "In order to create an stage for S3, one of the following is required: " "* `storage_integration`" "* AWS_KEY_ID and SECRET_KEY_ID" @@ -419,7 +418,11 @@ def load_file_to_table_natively( sql_statement = ( f"COPY INTO {table_name} FROM @{stage.qualified_name}/{file_path}" ) - self.hook.run(sql_statement) + try: + self.hook.run(sql_statement) + except (ValueError, AttributeError) as exe: + logging.warning(exe) + raise DatabaseCustomError from exe self.drop_stage(stage) def load_pandas_dataframe_to_table( @@ -591,7 +594,7 @@ def _build_merge_sql( values_to_check.extend(target_cols) for v in values_to_check: if not is_valid_snow_identifier(v): - raise ValueError( + raise DatabaseCustomError( f"The identifier {v} is invalid. 
Please check to prevent SQL injection" ) if if_conflicts == "update": diff --git a/tests/databases/test_bigquery.py b/tests/databases/test_bigquery.py index 1789e3858..2e4deaf79 100644 --- a/tests/databases/test_bigquery.py +++ b/tests/databases/test_bigquery.py @@ -15,6 +15,7 @@ from astro.constants import Database from astro.databases import create_database +from astro.databases.base import DatabaseCustomError from astro.databases.google.bigquery import BigqueryDatabase, S3ToBigqueryDataTransfer from astro.exceptions import NonExistentTableException from astro.files import File @@ -243,7 +244,7 @@ def test_load_file_to_table_natively_for_fallback( mock_load_file, database_table_fixture ): """Test loading on files to bigquery natively for fallback.""" - mock_load_file.side_effect = AttributeError + mock_load_file.side_effect = DatabaseCustomError database, target_table = database_table_fixture filepath = str(pathlib.Path(CWD.parent, "data/sample.csv")) response = database.load_file_to_table_natively_with_fallback( @@ -252,6 +253,69 @@ def test_load_file_to_table_natively_for_fallback( assert response is None +@pytest.mark.integration +@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.BIGQUERY, + "table": Table(metadata=Metadata(schema=SCHEMA)), + }, + ], + indirect=True, + ids=["bigquery"], +) +def test_load_file_to_table_natively_for_fallback_wrong_file_location( + database_table_fixture, +): + """ + Test loading on files to bigquery natively for fallback without fallback + gracefully for wrong file location. 
+ """ + database, target_table = database_table_fixture + filepath = "https://www.data.com/data/sample.json" + + response = database.load_file_to_table_natively_with_fallback( + source_file=File(filepath), + target_table=target_table, + enable_native_fallback=False, + ) + assert response is None + + +@pytest.mark.integration +@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.BIGQUERY, + "table": Table(metadata=Metadata(schema=SCHEMA)), + }, + ], + indirect=True, + ids=["bigquery"], +) +@mock.patch("astro.databases.google.bigquery.BigqueryDatabase.hook") +def test_get_project_id_raise_exception( + mock_hook, + database_table_fixture, +): + """ + Test that get_project_id raises DatabaseCustomError when looking up + the hook's project id fails with an AttributeError. + """ + + class CustomAttributeError: + def __str__(self): + raise AttributeError + + mock_hook.project_id = CustomAttributeError() + database, target_table = database_table_fixture + + with pytest.raises(DatabaseCustomError): + database.get_project_id(target_table=target_table) + + @pytest.mark.parametrize( "database_table_fixture", [ diff --git a/tests/databases/test_snowflake.py b/tests/databases/test_snowflake.py index 44d85375a..34bd9fcfc 100644 --- a/tests/databases/test_snowflake.py +++ b/tests/databases/test_snowflake.py @@ -1,6 +1,7 @@ """Tests specific to the Sqlite Database implementation.""" import os import pathlib +from unittest import mock from unittest.mock import patch import pandas as pd @@ -10,6 +11,7 @@ from astro.constants import Database, FileLocation, FileType from astro.databases import create_database +from astro.databases.base import DatabaseCustomError from astro.databases.snowflake import SnowflakeDatabase, SnowflakeStage from astro.exceptions import NonExistentTableException from astro.files import File @@ -184,6 +186,68 @@ def test_load_file_to_table(database_table_fixture): test_utils.assert_dataframes_are_equal(df, expected) +@pytest.mark.integration 
+@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.SNOWFLAKE, + "table": Table(metadata=Metadata(schema=SCHEMA)), + }, + ], + indirect=True, + ids=["snowflake"], +) +@mock.patch("astro.databases.snowflake.SnowflakeDatabase.hook") +@mock.patch("astro.databases.snowflake.SnowflakeDatabase.create_stage") +def test_load_file_to_table_natively_for_fallback( + mock_stage, mock_hook, database_table_fixture +): + """Test loading files to snowflake natively for fallback.""" + mock_hook.run.side_effect = ValueError + mock_stage.return_value = SnowflakeStage( + name="mock_stage", + url="gcs://bucket/prefix", + metadata=Metadata(database="SNOWFLAKE_DATABASE", schema="SNOWFLAKE_SCHEMA"), + ) + database, target_table = database_table_fixture + filepath = str(pathlib.Path(CWD.parent, "data/sample.csv")) + response = database.load_file_to_table_natively_with_fallback( + source_file=File(filepath), + target_table=target_table, + enable_native_fallback=False, + ) + assert response is None + + +@pytest.mark.integration +@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.SNOWFLAKE, + "table": Table(name="test_table", metadata=Metadata(schema=SCHEMA)), + }, + ], + indirect=True, + ids=["snowflake"], +) +@mock.patch("astro.databases.snowflake.is_valid_snow_identifier") +def test_build_merge_sql(mock_is_valid_snow_identifier, database_table_fixture): + """Test build merge SQL for DatabaseCustomError""" + mock_is_valid_snow_identifier.return_value = False + database, target_table = database_table_fixture + with pytest.raises(DatabaseCustomError): + database._build_merge_sql( + source_table=Table( + name="source_test_table", metadata=Metadata(schema=SCHEMA) + ), + target_table=target_table, + source_to_target_columns_map={"list": "val"}, + target_conflict_columns=["target"], + ) + + @pytest.mark.parametrize( "database_table_fixture", [