Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Nov22 sqeleton refactor #308

Merged
merged 4 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data_diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .tracking import disable_tracking
from .databases import connect
from .sqeleton.databases import DbKey, DbTime, DbPath
from .sqeleton.abcs import DbKey, DbTime, DbPath
from .diff_tables import Algorithm
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
from .joindiff_tables import JoinDiffer
Expand Down
7 changes: 3 additions & 4 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
from .table_segment import TableSegment
from .sqeleton.databases import create_schema
from .sqeleton.schema import create_schema
from .databases import connect
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
from .config import apply_config_from_file
Expand Down Expand Up @@ -54,10 +54,10 @@ def diff_schemas(table1, table2, schema1, schema2, columns):
diffs = []

if c not in schema1:
cols = ', '.join(schema1)
cols = ", ".join(schema1)
raise ValueError(f"Column '{c}' not found in table 1, named '{table1}'. Columns: {cols}")
if c not in schema2:
cols = ', '.join(schema1)
cols = ", ".join(schema1)
raise ValueError(f"Column '{c}' not found in table 2, named '{table2}'. Columns: {cols}")

col1 = schema1[c]
Expand Down Expand Up @@ -216,7 +216,6 @@ def main(conf, run, **kw):
raise



def _main(
database1,
table1,
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from data_diff.sqeleton.databases import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue


class DatadiffDialect(AbstractMixin_MD5, AbstractMixin_NormalizeValue):
Expand Down
2 changes: 1 addition & 1 deletion data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .thread_utils import ThreadedYielder
from .table_segment import TableSegment
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
from .sqeleton.databases import IKey
from .sqeleton.abcs import IKey

logger = getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion data_diff/hashdiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .utils import safezip
from .thread_utils import ThreadedYielder
from .sqeleton.databases import ColType_UUID, NumericType, PrecisionType, StringType
from .sqeleton.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType
from .table_segment import TableSegment

from .diff_tables import TableDiffer
Expand Down
4 changes: 2 additions & 2 deletions data_diff/info_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def update_from_children(self, child_infos):
self.is_diff = any(c.is_diff for c in child_infos)

self.rowcounts = {
1: sum(c.rowcounts[1] for c in child_infos),
2: sum(c.rowcounts[2] for c in child_infos),
1: sum(c.rowcounts[1] for c in child_infos if c.rowcounts),
2: sum(c.rowcounts[2] for c in child_infos if c.rowcounts),
}


Expand Down
7 changes: 4 additions & 3 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

from runtype import dataclass

from .sqeleton.databases import Database, DbPath, NumericType, MySQL, BigQuery, Presto, Oracle, Snowflake
from .sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake
from .sqeleton.abcs.database_types import DbPath, NumericType
from .sqeleton.queries import table, sum_, min_, max_, avg
from .sqeleton.queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable
from .sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath
from .sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code
from .sqeleton.queries.compiler import Compiler
from .sqeleton.queries.extras import NormalizeAsString

Expand Down Expand Up @@ -332,7 +333,7 @@ def exclusive_rows(expr):
c = Compiler(db)
name = c.new_unique_table_name("temp_table")
exclusive_rows = table(name, schema=expr.source_table.schema)
yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit))
yield Code(create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit)))

count = yield exclusive_rows.count()
self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0]
Expand Down
2 changes: 2 additions & 0 deletions data_diff/sqeleton/abcs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .database_types import AbstractDatabase, AbstractDialect, DbKey, DbPath, DbTime, IKey
from .compiler import AbstractCompiler, Compilable
14 changes: 14 additions & 0 deletions data_diff/sqeleton/abcs/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any, Dict
from abc import ABC, abstractmethod


class AbstractCompiler(ABC):
    """Interface of an object that can turn a query element into SQL text."""

    @abstractmethod
    def compile(self, elem: Any, params: Dict[str, Any] = None) -> str:
        """Compile `elem` into a string.

        `params` is an optional mapping of names to values; how it is used
        is left to the implementing class.
        """


class Compilable(ABC):
    """Interface of an element that knows how to render itself to a string."""

    @abstractmethod
    def compile(self, c: AbstractCompiler) -> str:
        """Return the string form of this element, given the compiler `c`."""
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import logging
import decimal
from abc import ABC, abstractmethod
from typing import Sequence, Optional, Tuple, Union, Dict, List
from datetime import datetime

from runtype import dataclass

from ..utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict, ArithAlphanumeric, ArithUUID
from ..utils import ArithAlphanumeric, ArithUUID


DbPath = Tuple[str, ...]
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
DbTime = datetime

logger = logging.getLogger("databases")


class ColType:
supported = True
Expand Down Expand Up @@ -214,94 +211,6 @@ def parse_type(
"Parse type info as returned by the database"


class AbstractMixin_NormalizeValue(ABC):
    """Mixin that builds SQL expressions normalizing column values to strings.

    Every method receives an SQL expression (`value`) and returns a new SQL
    expression. The concrete methods call ``self.to_string()``, which this
    mixin does not define; the class it is mixed into must provide it.
    """

    @abstractmethod
    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized timestamp.

        The returned expression must accept any SQL datetime/timestamp, and return a string.

        Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF``

        Precision of dates should be rounded up/down according to coltype.rounds
        """
        ...

    @abstractmethod
    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized number.

        The returned expression must accept any SQL int/numeric/float, and return a string.

        Floats/Decimals are expected in the format
        "I.P"

        Where I is the integer part of the number (as many digits as necessary),
        and must be at least one digit (0).
        P is the fractional digits, the amount of which is specified with
        coltype.precision. Trailing zeroes may be necessary.
        If P is 0, the dot is omitted.

        Note: We use 'precision' differently than most databases. For decimals,
        it's the same as ``numeric_scale``, and for floats, who use binary precision,
        it can be calculated as ``log10(2**numeric_precision)``.
        """
        ...

    def normalize_boolean(self, value: str, coltype: Boolean) -> str:
        """Creates an SQL expression, that converts 'value' to either '0' or '1'.

        The default implementation only applies ``to_string()`` to the value.
        """
        return self.to_string(value)

    # Concrete default, not an abstract method: subclasses may override it,
    # but are not required to.
    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized uuid.

        String-typed uuids (`String_UUID`) are wrapped in ``TRIM(...)`` to strip
        artifacts like whitespace; any other uuid type goes through ``to_string()``.
        """
        if isinstance(coltype, String_UUID):
            return f"TRIM({value})"
        return self.to_string(value)

    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized representation.

        The returned expression must accept any SQL value, and return a string.

        The default implementation dispatches to a method according to `coltype`.
        The checks run in this order, and the first match wins:

        ::

            TemporalType    -> normalize_timestamp()
            FractionalType  -> normalize_number()
            ColType_UUID    -> normalize_uuid()
            Boolean         -> normalize_boolean()
            *else*          -> to_string()

        (`Integer` falls in the *else* category)

        """
        if isinstance(coltype, TemporalType):
            return self.normalize_timestamp(value, coltype)
        elif isinstance(coltype, FractionalType):
            return self.normalize_number(value, coltype)
        elif isinstance(coltype, ColType_UUID):
            return self.normalize_uuid(value, coltype)
        elif isinstance(coltype, Boolean):
            return self.normalize_boolean(value, coltype)
        return self.to_string(value)


class AbstractMixin_MD5(ABC):
    """Mixin for the dialect-dependent md5 query expression used by data-diff."""

    @abstractmethod
    def md5_as_int(self, s: str) -> str:
        """Return SQL that computes the md5 of `s` and yields it as an int."""


class AbstractDatabase:
@property
@abstractmethod
Expand Down Expand Up @@ -374,18 +283,3 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
@abstractmethod
def is_autocommit(self) -> bool:
"Return whether the database autocommits changes. When false, COMMIT statements are skipped."


Schema = CaseAwareMapping


def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
    """Wrap `schema` in a case-sensitive or case-insensitive mapping.

    The schema is logged at debug level. When `case_sensitive` is false and
    two or more column names differ only by letter case, two warnings are
    logged before the case-insensitive mapping is returned.
    """
    logger.debug(f"[{db.name}] Schema = {schema}")

    if not case_sensitive:
        # Fewer distinct lowercased names than columns means a case-only clash.
        lowered_names = {k.lower() for k in schema}
        if len(lowered_names) < len(schema):
            logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
            logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).")
        return CaseInsensitiveDict(schema)

    return CaseSensitiveDict(schema)
105 changes: 105 additions & 0 deletions data_diff/sqeleton/abcs/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID
from .compiler import Compilable


class AbstractMixin_NormalizeValue(ABC):
    """Mixin that builds SQL expressions normalizing column values to strings.

    Every method receives an SQL expression (`value`) and returns a new SQL
    expression. The concrete methods call ``self.to_string()``, which this
    mixin does not define; the class it is mixed into must provide it.
    """

    @abstractmethod
    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized timestamp.

        The returned expression must accept any SQL datetime/timestamp, and return a string.

        Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF``

        Precision of dates should be rounded up/down according to coltype.rounds
        """
        ...

    @abstractmethod
    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized number.

        The returned expression must accept any SQL int/numeric/float, and return a string.

        Floats/Decimals are expected in the format
        "I.P"

        Where I is the integer part of the number (as many digits as necessary),
        and must be at least one digit (0).
        P is the fractional digits, the amount of which is specified with
        coltype.precision. Trailing zeroes may be necessary.
        If P is 0, the dot is omitted.

        Note: We use 'precision' differently than most databases. For decimals,
        it's the same as ``numeric_scale``, and for floats, who use binary precision,
        it can be calculated as ``log10(2**numeric_precision)``.
        """
        ...

    def normalize_boolean(self, value: str, coltype: Boolean) -> str:
        """Creates an SQL expression, that converts 'value' to either '0' or '1'.

        The default implementation only applies ``to_string()`` to the value.
        """
        return self.to_string(value)

    # Concrete default, not an abstract method: subclasses may override it,
    # but are not required to.
    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized uuid.

        String-typed uuids (`String_UUID`) are wrapped in ``TRIM(...)`` to strip
        artifacts like whitespace; any other uuid type goes through ``to_string()``.
        """
        if isinstance(coltype, String_UUID):
            return f"TRIM({value})"
        return self.to_string(value)

    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
        """Creates an SQL expression, that converts 'value' to a normalized representation.

        The returned expression must accept any SQL value, and return a string.

        The default implementation dispatches to a method according to `coltype`.
        The checks run in this order, and the first match wins:

        ::

            TemporalType    -> normalize_timestamp()
            FractionalType  -> normalize_number()
            ColType_UUID    -> normalize_uuid()
            Boolean         -> normalize_boolean()
            *else*          -> to_string()

        (`Integer` falls in the *else* category)

        """
        if isinstance(coltype, TemporalType):
            return self.normalize_timestamp(value, coltype)
        elif isinstance(coltype, FractionalType):
            return self.normalize_number(value, coltype)
        elif isinstance(coltype, ColType_UUID):
            return self.normalize_uuid(value, coltype)
        elif isinstance(coltype, Boolean):
            return self.normalize_boolean(value, coltype)
        return self.to_string(value)


class AbstractMixin_MD5(ABC):
    """Mixin for the dialect-dependent md5 query expression used by data-diff."""

    @abstractmethod
    def md5_as_int(self, s: str) -> str:
        """Return SQL that computes the md5 of `s` and yields it as an int."""


class AbstractMixin_Schema(ABC):
    """Mixin declaring the queries used to inspect the database schema.

    TODO: Move AbstractDatabase.query_table_schema() and friends over here
    """

    @abstractmethod
    def list_tables(self, like: Compilable = None) -> Compilable:
        """Build the query that selects the list of tables in the schema.

        When `like` is given, it is applied to the table name
        using the 'like' operator.
        """
19 changes: 1 addition & 18 deletions data_diff/sqeleton/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,5 @@
from .database_types import (
AbstractDatabase,
AbstractDialect,
AbstractMixin_MD5,
AbstractMixin_NormalizeValue,
DbKey,
DbTime,
DbPath,
create_schema,
IKey,
ColType_UUID,
NumericType,
PrecisionType,
StringType,
ColType,
Native_UUID,
Schema,
)
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError, BaseDialect, Database
from ..abcs import DbPath, DbKey

from .postgresql import PostgreSQL
from .mysql import MySQL
Expand Down
Loading