diff --git a/data_diff/__init__.py b/data_diff/__init__.py index b43807d3..8b763527 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -2,7 +2,7 @@ from .tracking import disable_tracking from .databases import connect -from .sqeleton.databases import DbKey, DbTime, DbPath +from .sqeleton.abcs import DbKey, DbTime, DbPath from .diff_tables import Algorithm from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from .joindiff_tables import JoinDiffer diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 22a1150c..ee5eb954 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -14,7 +14,7 @@ from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from .table_segment import TableSegment -from .sqeleton.databases import create_schema +from .sqeleton.schema import create_schema from .databases import connect from .parse_time import parse_time_before_now, UNITS_STR, ParseError from .config import apply_config_from_file @@ -54,10 +54,10 @@ def diff_schemas(table1, table2, schema1, schema2, columns): diffs = [] if c not in schema1: - cols = ', '.join(schema1) + cols = ", ".join(schema1) raise ValueError(f"Column '{c}' not found in table 1, named '{table1}'. Columns: {cols}") if c not in schema2: - cols = ', '.join(schema1) + cols = ", ".join(schema1) raise ValueError(f"Column '{c}' not found in table 2, named '{table2}'. Columns: {cols}") col1 = schema1[c] @@ -216,7 +216,6 @@ def main(conf, run, **kw): raise - def _main( database1, table1, diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 704cc6d5..5b7ff5ce 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,4 +1,4 @@ -from data_diff.sqeleton.databases import AbstractMixin_MD5, AbstractMixin_NormalizeValue +from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue class DatadiffDialect(AbstractMixin_MD5, AbstractMixin_NormalizeValue): diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 58687c42..1876fa54 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -18,7 +18,7 @@ from .thread_utils import ThreadedYielder from .table_segment import TableSegment from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled -from .sqeleton.databases import IKey +from .sqeleton.abcs import IKey logger = getLogger(__name__) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 653b9c74..0395506c 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -11,7 +11,7 @@ from .utils import safezip from .thread_utils import ThreadedYielder -from .sqeleton.databases import ColType_UUID, NumericType, PrecisionType, StringType +from .sqeleton.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType from .table_segment import TableSegment from .diff_tables import TableDiffer diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index d3b77d97..51a94509 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -31,8 +31,8 @@ def update_from_children(self, child_infos): self.is_diff = any(c.is_diff for c in child_infos) self.rowcounts = { - 1: sum(c.rowcounts[1] for c in child_infos), - 2: sum(c.rowcounts[2] for c in child_infos), + 1: sum(c.rowcounts[1] for c in child_infos if c.rowcounts), + 2: sum(c.rowcounts[2] for c in child_infos if c.rowcounts), } diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index f3ce3f57..158b7347 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -10,10 +10,11 @@ from runtype import dataclass -from .sqeleton.databases import Database, DbPath, NumericType, MySQL, BigQuery, Presto, Oracle, Snowflake +from .sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake +from .sqeleton.abcs.database_types import DbPath, NumericType from .sqeleton.queries import table, sum_, min_, max_, avg from .sqeleton.queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable -from .sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath +from .sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code from .sqeleton.queries.compiler import Compiler from .sqeleton.queries.extras import NormalizeAsString @@ -332,7 +333,7 @@ def exclusive_rows(expr): c = Compiler(db) name = c.new_unique_table_name("temp_table") exclusive_rows = table(name, schema=expr.source_table.schema) - yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit)) + yield Code(create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit))) count = yield exclusive_rows.count() self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0] diff --git a/data_diff/sqeleton/abcs/__init__.py b/data_diff/sqeleton/abcs/__init__.py new file mode 100644 index 00000000..b1245819 --- /dev/null +++ b/data_diff/sqeleton/abcs/__init__.py @@ -0,0 +1,2 @@ +from .database_types import AbstractDatabase, AbstractDialect, DbKey, DbPath, DbTime, IKey +from .compiler import AbstractCompiler, Compilable diff --git a/data_diff/sqeleton/abcs/compiler.py b/data_diff/sqeleton/abcs/compiler.py new file mode 100644 index 00000000..9c734e3c --- /dev/null +++ b/data_diff/sqeleton/abcs/compiler.py @@ -0,0 +1,14 @@ +from typing import Any, Dict +from abc import ABC, abstractmethod + + +class AbstractCompiler(ABC): + @abstractmethod + def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: + ... + + +class Compilable(ABC): + @abstractmethod + def compile(self, c: AbstractCompiler) -> str: + ... diff --git a/data_diff/sqeleton/databases/database_types.py b/data_diff/sqeleton/abcs/database_types.py similarity index 62% rename from data_diff/sqeleton/databases/database_types.py rename to data_diff/sqeleton/abcs/database_types.py index 98ebf8b9..c6a945a9 100644 --- a/data_diff/sqeleton/databases/database_types.py +++ b/data_diff/sqeleton/abcs/database_types.py @@ -1,4 +1,3 @@ -import logging import decimal from abc import ABC, abstractmethod from typing import Sequence, Optional, Tuple, Union, Dict, List @@ -6,15 +5,13 @@ from runtype import dataclass -from ..utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict, ArithAlphanumeric, ArithUUID +from ..utils import ArithAlphanumeric, ArithUUID DbPath = Tuple[str, ...] DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric] DbTime = datetime -logger = logging.getLogger("databases") - class ColType: supported = True @@ -214,94 +211,6 @@ def parse_type( "Parse type info as returned by the database" -class AbstractMixin_NormalizeValue(ABC): - @abstractmethod - def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized timestamp. - - The returned expression must accept any SQL datetime/timestamp, and return a string. - - Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` - - Precision of dates should be rounded up/down according to coltype.rounds - """ - ... - - @abstractmethod - def normalize_number(self, value: str, coltype: FractionalType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized number. - - The returned expression must accept any SQL int/numeric/float, and return a string. - - Floats/Decimals are expected in the format - "I.P" - - Where I is the integer part of the number (as many digits as necessary), - and must be at least one digit (0). - P is the fractional digits, the amount of which is specified with - coltype.precision. Trailing zeroes may be necessary. - If P is 0, the dot is omitted. - - Note: We use 'precision' differently than most databases. For decimals, - it's the same as ``numeric_scale``, and for floats, who use binary precision, - it can be calculated as ``log10(2**numeric_precision)``. - """ - ... - - @abstractmethod - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - """Creates an SQL expression, that converts 'value' to a normalized uuid. - - i.e. just makes sure there is no trailing whitespace. - """ - ... - - def normalize_boolean(self, value: str, coltype: Boolean) -> str: - """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" - return self.to_string(value) - - def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: - """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" - if isinstance(coltype, String_UUID): - return f"TRIM({value})" - return self.to_string(value) - - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - """Creates an SQL expression, that converts 'value' to a normalized representation. - - The returned expression must accept any SQL value, and return a string. - - The default implementation dispatches to a method according to `coltype`: - - :: - - TemporalType -> normalize_timestamp() - FractionalType -> normalize_number() - *else* -> to_string() - - (`Integer` falls in the *else* category) - - """ - if isinstance(coltype, TemporalType): - return self.normalize_timestamp(value, coltype) - elif isinstance(coltype, FractionalType): - return self.normalize_number(value, coltype) - elif isinstance(coltype, ColType_UUID): - return self.normalize_uuid(value, coltype) - elif isinstance(coltype, Boolean): - return self.normalize_boolean(value, coltype) - return self.to_string(value) - - -class AbstractMixin_MD5(ABC): - """Dialect-dependent query expressions, that are specific to data-diff""" - - @abstractmethod - def md5_as_int(self, s: str) -> str: - "Provide SQL for computing md5 and returning an int" - ... - - class AbstractDatabase: @property @abstractmethod @@ -374,18 +283,3 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: @abstractmethod def is_autocommit(self) -> bool: "Return whether the database autocommits changes. When false, COMMIT statements are skipped." - - -Schema = CaseAwareMapping - - -def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: - logger.debug(f"[{db.name}] Schema = {schema}") - - if case_sensitive: - return CaseSensitiveDict(schema) - - if len({k.lower() for k in schema}) < len(schema): - logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') - logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).") - return CaseInsensitiveDict(schema) diff --git a/data_diff/sqeleton/abcs/mixins.py b/data_diff/sqeleton/abcs/mixins.py new file mode 100644 index 00000000..774dfa30 --- /dev/null +++ b/data_diff/sqeleton/abcs/mixins.py @@ -0,0 +1,105 @@ +from abc import ABC, abstractmethod +from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID +from .compiler import Compilable + + +class AbstractMixin_NormalizeValue(ABC): + @abstractmethod + def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized timestamp. + + The returned expression must accept any SQL datetime/timestamp, and return a string. + + Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF`` + + Precision of dates should be rounded up/down according to coltype.rounds + """ + ... + + @abstractmethod + def normalize_number(self, value: str, coltype: FractionalType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized number. + + The returned expression must accept any SQL int/numeric/float, and return a string. + + Floats/Decimals are expected in the format + "I.P" + + Where I is the integer part of the number (as many digits as necessary), + and must be at least one digit (0). + P is the fractional digits, the amount of which is specified with + coltype.precision. Trailing zeroes may be necessary. + If P is 0, the dot is omitted. + + Note: We use 'precision' differently than most databases. For decimals, + it's the same as ``numeric_scale``, and for floats, who use binary precision, + it can be calculated as ``log10(2**numeric_precision)``. + """ + ... + + @abstractmethod + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + """Creates an SQL expression, that converts 'value' to a normalized uuid. + + i.e. just makes sure there is no trailing whitespace. + """ + ... + + def normalize_boolean(self, value: str, coltype: Boolean) -> str: + """Creates an SQL expression, that converts 'value' to either '0' or '1'.""" + return self.to_string(value) + + def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: + """Creates an SQL expression, that strips uuids of artifacts like whitespace.""" + if isinstance(coltype, String_UUID): + return f"TRIM({value})" + return self.to_string(value) + + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized representation. + + The returned expression must accept any SQL value, and return a string. + + The default implementation dispatches to a method according to `coltype`: + + :: + + TemporalType -> normalize_timestamp() + FractionalType -> normalize_number() + *else* -> to_string() + + (`Integer` falls in the *else* category) + + """ + if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, FractionalType): + return self.normalize_number(value, coltype) + elif isinstance(coltype, ColType_UUID): + return self.normalize_uuid(value, coltype) + elif isinstance(coltype, Boolean): + return self.normalize_boolean(value, coltype) + return self.to_string(value) + + +class AbstractMixin_MD5(ABC): + """Dialect-dependent query expressions, that are specific to data-diff""" + + @abstractmethod + def md5_as_int(self, s: str) -> str: + "Provide SQL for computing md5 and returning an int" + ... + + +class AbstractMixin_Schema(ABC): + """Methods for querying the database schema + + TODO: Move AbstractDatabase.query_table_schema() and friends over here + """ + + @abstractmethod + def list_tables(self, like: Compilable = None) -> Compilable: + """Query to select the list of tables in the schema. + + If 'like' is specified, the value is applied to the table name, using the 'like' operator. + """ diff --git a/data_diff/sqeleton/databases/__init__.py b/data_diff/sqeleton/databases/__init__.py index 5b52863e..00076f06 100644 --- a/data_diff/sqeleton/databases/__init__.py +++ b/data_diff/sqeleton/databases/__init__.py @@ -1,22 +1,5 @@ -from .database_types import ( - AbstractDatabase, - AbstractDialect, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, - DbKey, - DbTime, - DbPath, - create_schema, - IKey, - ColType_UUID, - NumericType, - PrecisionType, - StringType, - ColType, - Native_UUID, - Schema, -) from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database +from ..abcs import DbPath, DbKey from .postgresql import PostgreSQL from .mysql import MySQL diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index 11c0ccb0..de1bcb13 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -12,7 +12,7 @@ from ..utils import is_uuid, safezip from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code -from .database_types import ( +from ..abcs.database_types import ( AbstractDatabase, AbstractDialect, ColType, @@ -250,9 +250,12 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = list): self.query(i) return self.query(sql_ast[-1], res_type) else: - sql_code = compiler.compile(sql_ast) - if sql_code is SKIP: - return SKIP + if isinstance(sql_ast, str): + sql_code = sql_ast + else: + sql_code = compiler.compile(sql_ast) + if sql_code is SKIP: + return SKIP logger.debug("Running SQL (%s): %s", self.name, sql_code) @@ -350,8 +353,8 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe if not text_columns: return - fields = [self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID()) for c in text_columns] - samples_by_row = self.query(table(*table_path).select(*fields).where(where or SKIP).limit(sample_size), list) + fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns] + samples_by_row = self.query(table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list) if not samples_by_row: raise ValueError(f"Table {table_path} is empty.") diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py index 988597fe..07520bf2 100644 --- a/data_diff/sqeleton/databases/bigquery.py +++ b/data_diff/sqeleton/databases/bigquery.py @@ -1,5 +1,5 @@ from typing import List, Union -from .database_types import ( +from ..abcs.database_types import ( Timestamp, Datetime, Integer, @@ -10,9 +10,8 @@ FractionalType, TemporalType, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import BaseDialect, Database, import_helper, parse_table_name, ConnectError, apply_query from .base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter diff --git a/data_diff/sqeleton/databases/clickhouse.py b/data_diff/sqeleton/databases/clickhouse.py index 901fb07b..8dc0ac4a 100644 --- a/data_diff/sqeleton/databases/clickhouse.py +++ b/data_diff/sqeleton/databases/clickhouse.py @@ -10,7 +10,7 @@ ConnectError, DbTime, ) -from .database_types import ( +from ..abcs.database_types import ( ColType, Decimal, Float, @@ -20,9 +20,8 @@ TemporalType, Text, Timestamp, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue @import_helper("clickhouse") diff --git a/data_diff/sqeleton/databases/databricks.py b/data_diff/sqeleton/databases/databricks.py index 450ec0e7..5bebb0e3 100644 --- a/data_diff/sqeleton/databases/databricks.py +++ b/data_diff/sqeleton/databases/databricks.py @@ -2,7 +2,7 @@ from typing import Dict, Sequence import logging -from .database_types import ( +from ..abcs.database_types import ( Integer, Float, Decimal, @@ -13,9 +13,8 @@ DbPath, ColType, UnknownColType, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, BaseDialect, ThreadedDatabase, import_helper, parse_table_name diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py index 979d5f3d..1fefcf1d 100644 --- a/data_diff/sqeleton/databases/duckdb.py +++ b/data_diff/sqeleton/databases/duckdb.py @@ -1,7 +1,7 @@ from typing import Union from ..utils import match_regexps -from .database_types import ( +from ..abcs.database_types import ( Timestamp, TimestampTZ, DbPath, @@ -14,9 +14,8 @@ Text, FractionalType, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import ( Database, BaseDialect, diff --git a/data_diff/sqeleton/databases/mysql.py b/data_diff/sqeleton/databases/mysql.py index 6f3f37a6..5fe13072 100644 --- a/data_diff/sqeleton/databases/mysql.py +++ b/data_diff/sqeleton/databases/mysql.py @@ -1,4 +1,4 @@ -from .database_types import ( +from ..abcs.database_types import ( Datetime, Timestamp, Float, @@ -9,9 +9,8 @@ FractionalType, ColType_UUID, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import ( ThreadedDatabase, import_helper, diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py index 74941941..c7003f05 100644 --- a/data_diff/sqeleton/databases/oracle.py +++ b/data_diff/sqeleton/databases/oracle.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional from ..utils import match_regexps -from .database_types import ( +from ..abcs.database_types import ( Decimal, Float, Text, @@ -13,9 +13,8 @@ Timestamp, TimestampTZ, FractionalType, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError from .base import TIMESTAMP_PRECISION_POS diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py index dc24320a..d313fe73 100644 --- a/data_diff/sqeleton/databases/postgresql.py +++ b/data_diff/sqeleton/databases/postgresql.py @@ -1,4 +1,4 @@ -from .database_types import ( +from ..abcs.database_types import ( Timestamp, TimestampTZ, Float, @@ -9,9 +9,8 @@ Text, FractionalType, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import ( BaseDialect, ThreadedDatabase, diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py index 117ce6e1..da3d0404 100644 --- a/data_diff/sqeleton/databases/presto.py +++ b/data_diff/sqeleton/databases/presto.py @@ -3,7 +3,7 @@ from ..utils import match_regexps -from .database_types import ( +from ..abcs.database_types import ( Timestamp, TimestampTZ, Integer, @@ -17,9 +17,8 @@ ColType_UUID, TemporalType, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import BaseDialect, Database, import_helper, ThreadLocalInterpreter from .base import ( MD5_HEXDIGITS, diff --git a/data_diff/sqeleton/databases/redshift.py b/data_diff/sqeleton/databases/redshift.py index 42adaaaf..05b58520 100644 --- a/data_diff/sqeleton/databases/redshift.py +++ b/data_diff/sqeleton/databases/redshift.py @@ -1,5 +1,6 @@ from typing import List -from .database_types import Float, TemporalType, FractionalType, DbPath, AbstractMixin_NormalizeValue, AbstractMixin_MD5 +from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .postgresql import ( PostgreSQL, MD5_HEXDIGITS, diff --git a/data_diff/sqeleton/databases/snowflake.py b/data_diff/sqeleton/databases/snowflake.py index a22bfe84..a9b88cd2 100644 --- a/data_diff/sqeleton/databases/snowflake.py +++ b/data_diff/sqeleton/databases/snowflake.py @@ -1,7 +1,7 @@ from typing import Union, List import logging -from .database_types import ( +from ..abcs.database_types import ( Timestamp, TimestampTZ, Decimal, @@ -11,9 +11,8 @@ TemporalType, DbPath, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from .base import BaseDialect, ConnectError, Database, import_helper, CHECKSUM_MASK, ThreadLocalInterpreter diff --git a/data_diff/sqeleton/databases/trino.py b/data_diff/sqeleton/databases/trino.py index 5327f928..c8077fd2 100644 --- a/data_diff/sqeleton/databases/trino.py +++ b/data_diff/sqeleton/databases/trino.py @@ -1,4 +1,4 @@ -from .database_types import TemporalType, ColType_UUID +from ..abcs.database_types import TemporalType, ColType_UUID from . import presto from .base import import_helper from .base import TIMESTAMP_PRECISION_POS diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py index 38470fbf..86da1b78 100644 --- a/data_diff/sqeleton/databases/vertica.py +++ b/data_diff/sqeleton/databases/vertica.py @@ -12,7 +12,7 @@ ThreadedDatabase, import_helper, ) -from .database_types import ( +from ..abcs.database_types import ( Decimal, Float, FractionalType, @@ -22,10 +22,9 @@ Timestamp, TimestampTZ, Boolean, - AbstractMixin_MD5, - AbstractMixin_NormalizeValue, ColType_UUID, ) +from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue @import_helper("vertica") diff --git a/data_diff/sqeleton/queries/__init__.py b/data_diff/sqeleton/queries/__init__.py index 39f7a796..7f3e02bf 100644 --- a/data_diff/sqeleton/queries/__init__.py +++ b/data_diff/sqeleton/queries/__init__.py @@ -1,4 +1,4 @@ from .compiler import Compiler from .api import this, join, outerjoin, table, SKIP, sum_, avg, min_, max_, cte, commit, when, coalesce -from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code +from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py index 5934db79..ef55ce4d 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/sqeleton/queries/ast_classes.py @@ -5,9 +5,11 @@ from runtype import dataclass from ..utils import join_iter, ArithString +from ..abcs import Compilable +from ..schema import Schema -from .compiler import Compilable, Compiler, cv_params -from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple +from .compiler import Compiler, cv_params +from .base import SKIP, CompileError, DbPath, args_as_tuple class SqeletonError(Exception): @@ -43,6 +45,7 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None] + @dataclass class Code(ExprNode): code: str @@ -50,6 +53,7 @@ class Code(ExprNode): def compile(self, c: Compiler) -> str: return self.code + def _expr_type(e: Expr) -> type: if isinstance(e, ExprNode): return e.type @@ -173,7 +177,7 @@ class Concat(ExprNode): def compile(self, c: Compiler) -> str: # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL - items = [f"coalesce({c.compile(c.dialect.to_string(c.compile(expr)))}, '')" for expr in self.exprs] + items = [f"coalesce({c.compile(Code(c.dialect.to_string(c.compile(expr))))}, '')" for expr in self.exprs] assert items if len(items) == 1: return items[0] @@ -520,7 +524,7 @@ def compile(self, c: Compiler) -> str: return c.compile( self.table.replace( columns=columns, - group_by_exprs=keys, # XXX pass Expr instances, not strings (Code) + group_by_exprs=[Code(k) for k in keys], having_exprs=self.having_exprs, ) ) diff --git a/data_diff/sqeleton/queries/base.py b/data_diff/sqeleton/queries/base.py index f4d4906e..006b4c76 100644 --- a/data_diff/sqeleton/queries/base.py +++ b/data_diff/sqeleton/queries/base.py @@ -1,6 +1,7 @@ from typing import Generator -from ..databases.database_types import DbPath, DbKey, Schema +from ..abcs import DbPath, DbKey +from ..schema import Schema class _SKIP: diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/sqeleton/queries/compiler.py index 52d8debc..f896d153 100644 --- a/data_diff/sqeleton/queries/compiler.py +++ b/data_diff/sqeleton/queries/compiler.py @@ -1,12 +1,11 @@ import random -from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, Sequence, List from runtype import dataclass from ..utils import ArithString -from ..databases import AbstractDatabase, AbstractDialect, DbPath +from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable import contextvars @@ -14,7 +13,7 @@ @dataclass -class Compiler: +class Compiler(AbstractCompiler): database: AbstractDatabase params: dict = {} in_select: bool = False # Compilation runtime flag @@ -47,7 +46,7 @@ def _compile(self, elem) -> str: elif isinstance(elem, Compilable): return elem.compile(self.replace(root=False)) elif isinstance(elem, str): - return elem + return f"'{elem}'" elif isinstance(elem, int): return str(elem) elif isinstance(elem, datetime): @@ -71,9 +70,3 @@ def add_table_context(self, *tables: Sequence, **kw): def quote(self, s: str): return self.dialect.quote(s) - - -class Compilable(ABC): - @abstractmethod - def compile(self, c: Compiler) -> str: - ... diff --git a/data_diff/sqeleton/queries/extras.py b/data_diff/sqeleton/queries/extras.py index b20dbda5..1014c372 100644 --- a/data_diff/sqeleton/queries/extras.py +++ b/data_diff/sqeleton/queries/extras.py @@ -3,10 +3,10 @@ from typing import Callable, Sequence from runtype import dataclass -from ..databases import ColType, Native_UUID +from ..abcs.database_types import ColType, Native_UUID from .compiler import Compiler -from .ast_classes import Expr, ExprNode, Concat +from .ast_classes import Expr, ExprNode, Concat, Code @dataclass @@ -51,7 +51,7 @@ class Checksum(ExprNode): def compile(self, c: Compiler): if len(self.exprs) > 1: - exprs = [f"coalesce({c.compile(expr)}, '')" for expr in self.exprs] + exprs = [Code(f"coalesce({c.compile(expr)}, '')") for expr in self.exprs] # exprs = [c.compile(e) for e in exprs] expr = Concat(exprs, "|") else: diff --git a/data_diff/sqeleton/schema.py b/data_diff/sqeleton/schema.py new file mode 100644 index 00000000..ddf7e786 --- /dev/null +++ b/data_diff/sqeleton/schema.py @@ -0,0 +1,20 @@ +import logging + +from .utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict +from .abcs import AbstractDatabase, DbPath + +logger = logging.getLogger("schema") + +Schema = CaseAwareMapping + + +def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping: + logger.debug(f"[{db.name}] Schema = {schema}") + + if case_sensitive: + return CaseSensitiveDict(schema) + + if len({k.lower() for k in schema}) < len(schema): + logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}') + logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).") + return CaseInsensitiveDict(schema) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index f6176fee..2ec0c186 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -5,8 +5,10 @@ from runtype import dataclass from .sqeleton.utils import ArithString, split_space -from .sqeleton.databases import Database, DbPath, DbKey, DbTime, Schema, create_schema -from .sqeleton.queries import Count, Checksum, SKIP, table, this, Expr, min_, max_ +from .sqeleton.databases import Database +from .sqeleton.abcs import DbPath, DbKey, DbTime +from .sqeleton.schema import Schema, create_schema +from .sqeleton.queries import Count, Checksum, SKIP, table, this, Expr, min_, max_, Code from .sqeleton.queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString logger = logging.getLogger("table_segment") @@ -98,7 +100,7 @@ def source_table(self): return table(*self.table_path, schema=self._schema) def make_select(self): - return self.source_table.where(*self._make_key_range(), *self._make_update_range(), self.where or SKIP) + return self.source_table.where(*self._make_key_range(), *self._make_update_range(), Code(self.where) if self.where else SKIP) def get_values(self) -> list: "Download all the relevant values of the segment from the database" diff --git a/data_diff/tracking.py b/data_diff/tracking.py index 38c7b054..ef426b33 100644 --- a/data_diff/tracking.py +++ b/data_diff/tracking.py @@ -104,10 +104,7 @@ def send_event_json(event_json): if not g_tracking_enabled: raise RuntimeError("Won't send; tracking is disabled!") - headers = { - 'Content-Type': 'application/json', - 'Authorization': 'Basic MkhndE00SGNxOUJtZWlDcU5ZaHo3Tzl0a2pNOg==' - } + headers = {"Content-Type": "application/json", "Authorization": "Basic MkhndE00SGNxOUJtZWlDcU5ZaHo3Tzl0a2pNOg=="} data = json.dumps(event_json).encode() try: req = urllib.request.Request(TRACK_URL, data=data, headers=headers) diff --git a/tests/sqeleton/test_query.py b/tests/sqeleton/test_query.py index 3c5c996c..dc2dc3e6 100644 --- a/tests/sqeleton/test_query.py +++ b/tests/sqeleton/test_query.py @@ -1,12 +1,8 @@ from datetime import datetime from typing import List, Optional import unittest -from data_diff.sqeleton.databases.database_types import ( - AbstractDatabase, - AbstractDialect, - CaseInsensitiveDict, - CaseSensitiveDict, -) +from data_diff.sqeleton.abcs import AbstractDatabase, AbstractDialect +from data_diff.sqeleton.utils import CaseInsensitiveDict, CaseSensitiveDict from data_diff.sqeleton.queries import this, table, Compiler, outerjoin, cte, when, coalesce from data_diff.sqeleton.queries.ast_classes import Random @@ -77,8 +73,9 @@ def test_basic(self): t = table("point").where(this.x == 1, this.y == 2) assert c.compile(t) == "SELECT * FROM point WHERE (x = 1) AND (y = 2)" - t = table("point").select("x", "y") - assert c.compile(t) == "SELECT x, y FROM point" + t = table("person").where(this.name == "Albert") + self.assertEqual( c.compile(t), "SELECT * FROM person WHERE (name = 'Albert')" ) + def test_outerjoin(self): c = Compiler(MockDatabase()) @@ -196,8 +193,8 @@ def test_select_distinct(self): def test_table_ops(self): c = Compiler(MockDatabase()) - a = table("a").select("x") - b = table("b").select("y") + a = table("a").select(this.x) + b = table("b").select(this.y) q = c.compile(a.union(b)) assert q == "SELECT x FROM a UNION SELECT y FROM b" diff --git a/tests/sqeleton/test_sql.py b/tests/sqeleton/test_sql.py index cdd85102..b92f6e90 100644 --- a/tests/sqeleton/test_sql.py +++ b/tests/sqeleton/test_sql.py @@ -3,7 +3,7 @@ from ..common import TEST_MYSQL_CONN_STRING from data_diff.sqeleton import connect -from data_diff.sqeleton.queries import Compiler, Count, Explain, Select, table, In, BinOp +from data_diff.sqeleton.queries import Compiler, Count, Explain, Select, table, In, BinOp, Code class TestSQL(unittest.TestCase): @@ -12,7 +12,7 @@ def setUp(self): self.compiler = Compiler(self.mysql) def test_compile_string(self): - self.assertEqual("SELECT 1", self.compiler.compile("SELECT 1")) + self.assertEqual("SELECT 1", self.compiler.compile(Code("SELECT 1"))) def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) @@ -27,7 +27,7 @@ def test_compile_select(self): self.compiler.compile( Select( table("marine_mammals", "walrus"), - ["name"], + [Code("name")], ) ), ) @@ -63,8 +63,8 @@ def test_compare(self): self.compiler.compile( Select( table("marine_mammals", "walrus"), - ["name"], - [BinOp("<=", ["id", "1000"]), BinOp(">", ["id", "1"])], + [Code("name")], + [BinOp("<=", [Code("id"), Code("1000")]), BinOp(">", [Code("id"), Code("1")])], ) ), ) @@ -73,21 +73,21 @@ def test_in(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select(table("marine_mammals", "walrus"), ["name"], [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Code("name")], [In(Code("id"), [1, 2, 3])])), ) def test_count(self): expected_sql = "SELECT count(*) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count()], [In(Code("id"), [1, 2, 3])])), ) def test_count_with_column(self): expected_sql = "SELECT count(id) FROM `marine_mammals`.`walrus` WHERE (id IN (1, 2, 3))" self.assertEqual( expected_sql, - self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])), + self.compiler.compile(Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])])), ) def test_explain(self): @@ -95,6 +95,6 @@ def test_explain(self): self.assertEqual( expected_sql, self.compiler.compile( - Explain(Select(table("marine_mammals", "walrus"), [Count("id")], [In("id", [1, 2, 3])])) + Explain(Select(table("marine_mammals", "walrus"), [Count(Code("id"))], [In(Code("id"), [1, 2, 3])])) ), ) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 0b02fa60..2333ee94 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -588,7 +588,6 @@ def _insert_to_table(conn, table_path, values, type): elif isinstance(conn, db.BigQuery) and type == "datetime": values = [(i, Code(f"cast(timestamp '{sample}' as datetime)")) for i, sample in values] - insert_rows_in_batches(conn, tbl, values, columns=["id", "col"]) conn.query(commit) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index d4f54add..67ca22e4 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -733,7 +733,7 @@ class TestInfoTree(unittest.TestCase): def test_info_tree_root(self): try: self.db = get_conn(db.DuckDB) - except KeyError: # ddb not defined + except KeyError: # ddb not defined self.db = get_conn(db.MySQL) table_suffix = random_table_suffix()