diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 3bfc9250..c1db0901 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,4 +1,5 @@ from copy import deepcopy +from datetime import datetime import sys import time import json @@ -15,8 +16,9 @@ from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer from .table_segment import TableSegment from .sqeleton.schema import create_schema +from .sqeleton.queries.api import current_timestamp from .databases import connect -from .parse_time import parse_time_before_now, UNITS_STR, ParseError +from .parse_time import parse_time_before, UNITS_STR, ParseError from .config import apply_config_from_file from .tracking import disable_tracking from . import __version__ @@ -299,17 +301,6 @@ def _main( start = time.monotonic() - try: - options = dict( - min_update=max_age and parse_time_before_now(max_age), - max_update=min_age and parse_time_before_now(min_age), - case_sensitive=case_sensitive, - where=where, - ) - except ParseError as e: - logging.error(f"Error while parsing age expression: {e}") - return - if database1 is None or database2 is None: logging.error( f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information." @@ -326,6 +317,20 @@ def _main( logging.error(e) return + + now: datetime = db1.query(current_timestamp(), datetime) + now = now.replace(tzinfo=None) + try: + options = dict( + min_update=max_age and parse_time_before(now, max_age), + max_update=min_age and parse_time_before(now, min_age), + case_sensitive=case_sensitive, + where=where, + ) + except ParseError as e: + logging.error(f"Error while parsing age expression: {e}") + return + dbs = db1, db2 if interactive: diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index fb669971..bbb9790f 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -1,4 +1,5 @@ -from data_diff.sqeleton.databases import Connect +from data_diff.sqeleton.databases import Connect, Database +import logging from .postgresql import PostgreSQL from .mysql import MySQL @@ -29,4 +30,22 @@ "vertica": Vertica, } -connect = Connect(DATABASE_BY_SCHEME) + +class Connect_SetUTC(Connect): + """Provides methods for connecting to a supported database using a URL or connection dict. + + Ensures all sessions use UTC Timezone, if possible. + """ + + def _connection_created(self, db): + db = super()._connection_created(db) + try: + db.query(db.dialect.set_timezone_to_utc()) + except NotImplementedError: + logging.debug( + f"Database '{db}' does not allow setting timezone. We recommend making sure it's set to 'UTC'." + ) + return db + + +connect = Connect_SetUTC(DATABASE_BY_SCHEME) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index d3938fd5..1327f9d3 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -62,6 +62,7 @@ def sample(table_expr): def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database + c = c.replace(root=False) # we're compiling fragments, not full queries if isinstance(db, BigQuery): return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}" elif isinstance(db, Presto): diff --git a/data_diff/parse_time.py b/data_diff/parse_time.py index c24f4015..39924798 100644 --- a/data_diff/parse_time.py +++ b/data_diff/parse_time.py @@ -70,5 +70,5 @@ def parse_time_delta(t: str): return timedelta(**time_dict) -def parse_time_before_now(t: str): - return datetime.now() - parse_time_delta(t) +def parse_time_before(time: datetime, delta: str): + return time - parse_time_delta(delta) diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/sqeleton/abcs/database_types.py index 431aeb42..7cc280d1 100644 --- a/data_diff/sqeleton/abcs/database_types.py +++ b/data_diff/sqeleton/abcs/database_types.py @@ -184,6 +184,10 @@ def to_string(self, s: str) -> str: def random(self) -> str: "Provide SQL for generating a random number betweein 0..1" + @abstractmethod + def current_timestamp(self) -> str: + "Provide SQL for returning the current timestamp, aka now" + @abstractmethod def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): "Provide SQL fragment for limit and offset inside a select" @@ -199,6 +203,10 @@ def timestamp_value(self, t: datetime) -> str: "Provide SQL for the given timestamp value" ... + @abstractmethod + def set_timezone_to_utc(self) -> str: + "Provide SQL for setting the session timezone to UTC" + @abstractmethod def parse_type( self, diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index ac521be2..bccd7723 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -142,7 +142,10 @@ def timestamp_value(self, t: DbTime) -> str: return f"'{t.isoformat()}'" def random(self) -> str: - return "RANDOM()" + return "random()" + + def current_timestamp(self) -> str: + return "current_timestamp()" def explain_as_text(self, query: str) -> str: return f"EXPLAIN {query}" diff --git a/data_diff/sqeleton/databases/bigquery.py b/data_diff/sqeleton/databases/bigquery.py index a8e715cb..64e056b2 100644 --- a/data_diff/sqeleton/databases/bigquery.py +++ b/data_diff/sqeleton/databases/bigquery.py @@ -101,6 +101,9 @@ def type_repr(self, t) -> str: except KeyError: return super().type_repr(t) + def set_timezone_to_utc(self) -> str: + raise NotImplementedError() + class BigQuery(Database): CONNECT_URI_HELP = "bigquery:///" diff --git a/data_diff/sqeleton/databases/clickhouse.py b/data_diff/sqeleton/databases/clickhouse.py index 8dc0ac4a..f1e1cc73 100644 --- a/data_diff/sqeleton/databases/clickhouse.py +++ b/data_diff/sqeleton/databases/clickhouse.py @@ -150,6 +150,9 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]: # # return f"'{t}'" # return f"'{str(t)[:19]}'" + def set_timezone_to_utc(self) -> str: + raise NotImplementedError() + class Clickhouse(ThreadedDatabase): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/connect.py b/data_diff/sqeleton/databases/connect.py index 5535814a..f11d9144 100644 --- a/data_diff/sqeleton/databases/connect.py +++ b/data_diff/sqeleton/databases/connect.py @@ -93,6 +93,8 @@ def match_path(self, dsn): class Connect: + """Provides methods for connecting to a supported database using a URL or connection dict.""" + def __init__(self, database_by_scheme: Dict[str, Database]): self.database_by_scheme = database_by_scheme self.match_uri_path = { @@ -172,9 +174,11 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa kw = {k: v for k, v in kw.items() if v is not None} if issubclass(cls, ThreadedDatabase): - return cls(thread_count=thread_count, **kw) + db = cls(thread_count=thread_count, **kw) + else: + db = cls(**kw) - return cls(**kw) + return self._connection_created(db) def connect_with_dict(self, d, thread_count): d = dict(d) @@ -186,9 +190,15 @@ def connect_with_dict(self, d, thread_count): cls = matcher.database_cls if issubclass(cls, ThreadedDatabase): - return cls(thread_count=thread_count, **d) + db = cls(thread_count=thread_count, **d) + else: + db = cls(**d) + + return self._connection_created(db) - return cls(**d) + def _connection_created(self, db): + "Nop function to be overridden by subclasses." + return db def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database: """Connect to a database using the given database configuration. diff --git a/data_diff/sqeleton/databases/databricks.py b/data_diff/sqeleton/databases/databricks.py index ef3055b2..6c819903 100644 --- a/data_diff/sqeleton/databases/databricks.py +++ b/data_diff/sqeleton/databases/databricks.py @@ -82,6 +82,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 2 due to wierd precision issues return max(super()._convert_db_precision_to_digits(p) - 2, 0) + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE 'UTC'" + class Databricks(ThreadedDatabase): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/duckdb.py b/data_diff/sqeleton/databases/duckdb.py index d477266b..f2e425c5 100644 --- a/data_diff/sqeleton/databases/duckdb.py +++ b/data_diff/sqeleton/databases/duckdb.py @@ -108,6 +108,9 @@ def parse_type( return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + def set_timezone_to_utc(self) -> str: + return "SET GLOBAL TimeZone='UTC'" + class DuckDB(Database): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/mysql.py b/data_diff/sqeleton/databases/mysql.py index df1b77cd..b1e695b1 100644 --- a/data_diff/sqeleton/databases/mysql.py +++ b/data_diff/sqeleton/databases/mysql.py @@ -94,6 +94,9 @@ def type_repr(self, t) -> str: def explain_as_text(self, query: str) -> str: return f"EXPLAIN FORMAT=TREE {query}" + def set_timezone_to_utc(self) -> str: + return "SET @@session.time_zone='+00:00'" + class MySQL(ThreadedDatabase): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/oracle.py b/data_diff/sqeleton/databases/oracle.py index 704341a4..d9df0808 100644 --- a/data_diff/sqeleton/databases/oracle.py +++ b/data_diff/sqeleton/databases/oracle.py @@ -149,6 +149,9 @@ def parse_type( return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale) + def set_timezone_to_utc(self) -> str: + return "ALTER SESSION SET TIME_ZONE = 'UTC'" + class Oracle(ThreadedDatabase): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py index a402b01a..ab67d6b6 100644 --- a/data_diff/sqeleton/databases/postgresql.py +++ b/data_diff/sqeleton/databases/postgresql.py @@ -87,6 +87,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 2 due to wierd precision issues in PostgreSQL return super()._convert_db_precision_to_digits(p) - 2 + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE 'UTC'" + class PostgreSQL(ThreadedDatabase): dialect = PostgresqlDialect() diff --git a/data_diff/sqeleton/databases/presto.py b/data_diff/sqeleton/databases/presto.py index 76af62cf..71ad6ba5 100644 --- a/data_diff/sqeleton/databases/presto.py +++ b/data_diff/sqeleton/databases/presto.py @@ -134,6 +134,9 @@ def parse_type( return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE '+00:00'" + class Presto(Database): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/snowflake.py b/data_diff/sqeleton/databases/snowflake.py index 4ea1fb92..092c7f68 100644 --- a/data_diff/sqeleton/databases/snowflake.py +++ b/data_diff/sqeleton/databases/snowflake.py @@ -93,6 +93,9 @@ def to_string(self, s: str): def table_information(self) -> Compilable: return table("INFORMATION_SCHEMA", "TABLES") + def set_timezone_to_utc(self) -> str: + return "ALTER SESSION SET TIMEZONE = 'UTC'" + class Snowflake(Database): dialect = Dialect() diff --git a/data_diff/sqeleton/databases/vertica.py b/data_diff/sqeleton/databases/vertica.py index c01e9544..453ff035 100644 --- a/data_diff/sqeleton/databases/vertica.py +++ b/data_diff/sqeleton/databases/vertica.py @@ -141,6 +141,9 @@ def parse_type( return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision) + def set_timezone_to_utc(self) -> str: + return "SET TIME ZONE TO 'UTC'" + class Vertica(ThreadedDatabase): dialect = Dialect() diff --git a/data_diff/sqeleton/queries/api.py b/data_diff/sqeleton/queries/api.py index 7b77bd0f..63291883 100644 --- a/data_diff/sqeleton/queries/api.py +++ b/data_diff/sqeleton/queries/api.py @@ -95,4 +95,8 @@ def insert_rows_in_batches(db, table: TablePath, rows, *, columns=None, batch_si db.query(table.insert_rows(batch, columns=columns)) +def current_timestamp(): + return CurrentTimestamp() + + commit = Commit() diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py index ef55ce4d..a9724834 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/sqeleton/queries/ast_classes.py @@ -8,7 +8,7 @@ from ..abcs import Compilable from ..schema import Schema -from .compiler import Compiler, cv_params +from .compiler import Compiler, cv_params, Root from .base import SKIP, CompileError, DbPath, args_as_tuple @@ -47,7 +47,7 @@ def cast_to(self, to): @dataclass -class Code(ExprNode): +class Code(ExprNode, Root): code: str def compile(self, c: Compiler) -> str: @@ -434,7 +434,7 @@ def compile(self, c: Compiler) -> str: @dataclass -class Join(ExprNode, ITable): +class Join(ExprNode, ITable, Root): source_tables: Sequence[ITable] op: str = None on_exprs: Sequence[Expr] = None @@ -499,7 +499,7 @@ def compile(self, parent_c: Compiler) -> str: @dataclass -class GroupBy(ExprNode, ITable): +class GroupBy(ExprNode, ITable, Root): table: ITable keys: Sequence[Expr] = None # IKey? values: Sequence[Expr] = None @@ -540,7 +540,7 @@ def compile(self, c: Compiler) -> str: @dataclass -class TableOp(ExprNode, ITable): +class TableOp(ExprNode, ITable, Root): op: str table1: ITable table2: ITable @@ -571,7 +571,7 @@ def compile(self, parent_c: Compiler) -> str: @dataclass -class Select(ExprNode, ITable): +class Select(ExprNode, ITable, Root): table: Expr = None columns: Sequence[Expr] = None where_exprs: Sequence[Expr] = None @@ -771,7 +771,7 @@ def compile_for_insert(self, c: Compiler): @dataclass -class Explain(ExprNode): +class Explain(ExprNode, Root): select: Select type = str @@ -780,10 +780,16 @@ def compile(self, c: Compiler) -> str: return c.dialect.explain_as_text(c.compile(self.select)) +class CurrentTimestamp(ExprNode): + type = datetime + + def compile(self, c: Compiler) -> str: + return c.dialect.current_timestamp() + # DDL -class Statement(Compilable): +class Statement(Compilable, Root): type = None diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/sqeleton/queries/compiler.py index f896d153..73c3971a 100644 --- a/data_diff/sqeleton/queries/compiler.py +++ b/data_diff/sqeleton/queries/compiler.py @@ -11,6 +11,9 @@ cv_params = contextvars.ContextVar("params") +class Root: + "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" + @dataclass class Compiler(AbstractCompiler): @@ -33,6 +36,10 @@ def compile(self, elem, params=None) -> str: if params: cv_params.set(params) + if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root): + from .ast_classes import Select + elem = Select(columns=[elem]) + res = self._compile(elem) if self.root and self._subqueries: subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items()) diff --git a/tests/common.py b/tests/common.py index c5f345c9..d44303e7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -13,6 +13,7 @@ from data_diff import tracking from data_diff import connect from data_diff.sqeleton.queries import table +from data_diff.table_segment import TableSegment from data_diff.sqeleton.databases import Database from data_diff.query_utils import drop_table @@ -86,10 +87,13 @@ def get_git_revision_short_hash() -> str: _database_instances = {} -def get_conn(cls: type) -> Database: - if cls not in _database_instances: - _database_instances[cls] = connect(CONN_STRINGS[cls], N_THREADS) - return _database_instances[cls] +def get_conn(cls: type, shared: bool =True) -> Database: + if shared: + if cls not in _database_instances: + _database_instances[cls] = get_conn(cls, shared=False) + return _database_instances[cls] + + return connect(CONN_STRINGS[cls], N_THREADS) def _print_used_dbs(): @@ -134,11 +138,12 @@ class DiffTestCase(unittest.TestCase): db_cls = None src_schema = None dst_schema = None + shared_connection = True def setUp(self): assert self.db_cls, self.db_cls - self.connection = get_conn(self.db_cls) + self.connection = get_conn(self.db_cls, self.shared_connection) table_suffix = random_table_suffix() self.table_src_name = f"src{table_suffix}" @@ -150,11 +155,11 @@ def setUp(self): drop_table(self.connection, self.table_src_path) drop_table(self.connection, self.table_dst_path) + self.src_table = table(self.table_src_path, schema=self.src_schema) + self.dst_table = table(self.table_dst_path, schema=self.dst_schema) if self.src_schema: - self.src_table = table(self.table_src_path, schema=self.src_schema) self.connection.query(self.src_table.create()) if self.dst_schema: - self.dst_table = table(self.table_dst_path, schema=self.dst_schema) self.connection.query(self.dst_table.create()) return super().setUp() @@ -175,3 +180,8 @@ def _test_per_database(cls): return _parameterized_class_per_conn(databases)(cls) return _test_per_database + +def table_segment(database, table_path, key_columns, *args, **kw): + if isinstance(key_columns, str): + key_columns = (key_columns,) + return TableSegment(database, table_path, key_columns, *args, **kw) diff --git a/tests/sqeleton/test_query.py b/tests/sqeleton/test_query.py index bb4a0286..cf52293e 100644 --- a/tests/sqeleton/test_query.py +++ b/tests/sqeleton/test_query.py @@ -33,6 +33,9 @@ def is_distinct_from(self, a: str, b: str) -> str: def random(self) -> str: return "random()" + def current_timestamp(self) -> str: + return "now()" + def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None): x = offset and f"OFFSET {offset}", limit and f"LIMIT {limit}" return " ".join(filter(None, x)) @@ -43,6 +46,9 @@ def explain_as_text(self, query: str) -> str: def timestamp_value(self, t: datetime) -> str: return f"timestamp '{t}'" + def set_timezone_to_utc(self) -> str: + return "set timezone 'UTC'" + parse_type = NotImplemented diff --git a/tests/sqeleton/test_sql.py b/tests/sqeleton/test_sql.py index b178e36d..7708408d 100644 --- a/tests/sqeleton/test_sql.py +++ b/tests/sqeleton/test_sql.py @@ -18,7 +18,7 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): - self.assertEqual("`marine_mammals`.`walrus`", self.compiler.compile(table("marine_mammals", "walrus"))) + self.assertEqual("`marine_mammals`.`walrus`", self.compiler.replace(root=False).compile(table("marine_mammals", "walrus"))) def test_compile_select(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" diff --git a/tests/test_cli.py b/tests/test_cli.py index 93eca852..8f9eb97a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,7 @@ import arrow import subprocess import sys -from datetime import datetime +from datetime import datetime, timedelta from data_diff.databases import MySQL from data_diff.sqeleton.queries import table, commit @@ -41,23 +41,22 @@ def setUp(self) -> None: src_table = table(table_src_name, schema={"id": int, "datetime": datetime, "text_comment": str}) self.conn.query(src_table.create()) self.conn.query("SET @@session.time_zone='+00:00'") - db_time = self.conn.query("select now()", datetime) - self.now = now = arrow.get(db_time) + now = self.conn.query("select now()", datetime) rows = [ (now, "now"), - (self.now.shift(seconds=-10), "a"), - (self.now.shift(seconds=-7), "b"), - (self.now.shift(seconds=-6), "c"), + (now - timedelta(seconds=10), "a"), + (now - timedelta(seconds=7), "b"), + (now - timedelta(seconds=6), "c"), ] - self.conn.query(src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows))) + self.conn.query(src_table.insert_rows((i, ts, s) for i, (ts, s) in enumerate(rows))) _commit(self.conn) self.conn.query(self.table_dst.create(self.table_src)) _commit(self.conn) - self.conn.query(src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago")) + self.conn.query(src_table.insert_row(len(rows), now - timedelta(seconds=3), "3 seconds ago")) _commit(self.conn) def tearDown(self) -> None: diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 5c787221..0bd9d5bf 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -41,14 +41,6 @@ def init_conns(): return CONNS = {cls: get_conn(cls) for cls in CONN_STRINGS} - if db.MySQL in CONNS: - CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'") - if db.PostgreSQL: - CONNS[db.PostgreSQL].query("SET TIME ZONE 'UTC';") - if db.DuckDB in CONNS: - CONNS[db.DuckDB].query("SET GLOBAL TimeZone='UTC'") - if db.Oracle in CONNS: - CONNS[db.Oracle].query("ALTER SESSION SET TIME_ZONE = 'UTC'") DATABASE_TYPES = { diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index 85e8ad5f..60066f24 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -13,7 +13,7 @@ from data_diff.table_segment import TableSegment, split_space from data_diff import databases as db -from .common import str_to_checksum, test_each_database_in_list, DiffTestCase, get_conn, random_table_suffix +from .common import str_to_checksum, test_each_database_in_list, DiffTestCase, table_segment TEST_DATABASES = { @@ -31,11 +31,6 @@ test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) -def _table_segment(database, table_path, key_columns, *args, **kw): - if isinstance(key_columns, str): - key_columns = (key_columns,) - return TableSegment(database, table_path, key_columns, *args, **kw) - class TestUtils(unittest.TestCase): def test_split_space(self): @@ -75,17 +70,17 @@ def setUp(self): ) def test_init(self): - a = _table_segment( + a = table_segment( self.connection, self.table_src_path, "id", "datetime", max_update=self.now.datetime, case_sensitive=False ) self.assertRaises( - ValueError, _table_segment, self.connection, self.table_src_path, "id", max_update=self.now.datetime + ValueError, table_segment, self.connection, self.table_src_path, "id", max_update=self.now.datetime ) def test_basic(self): differ = HashDiffer(bisection_factor=10, bisection_threshold=100) - a = _table_segment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) - b = _table_segment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) + a = table_segment(self.connection, self.table_src_path, "id", "datetime", case_sensitive=False) + b = table_segment(self.connection, self.table_dst_path, "id", "datetime", case_sensitive=False) assert a.count() == 6 assert b.count() == 5 @@ -95,10 +90,10 @@ def test_basic(self): def test_offset(self): differ = HashDiffer(bisection_factor=2, bisection_threshold=10) sec1 = self.now.shift(seconds=-3).datetime - a = _table_segment( + a = table_segment( self.connection, self.table_src_path, "id", "datetime", max_update=sec1, case_sensitive=False ) - b = _table_segment( + b = table_segment( self.connection, self.table_dst_path, "id", "datetime", max_update=sec1, case_sensitive=False ) assert a.count() == 4, a.count() @@ -107,10 +102,10 @@ def test_offset(self): assert not list(differ.diff_tables(a, a)) self.assertEqual(len(list(differ.diff_tables(a, b))), 1) - a = _table_segment( + a = table_segment( self.connection, self.table_src_path, "id", "datetime", min_update=sec1, case_sensitive=False ) - b = _table_segment( + b = table_segment( self.connection, self.table_dst_path, "id", "datetime", min_update=sec1, case_sensitive=False ) assert a.count() == 2 @@ -119,7 +114,7 @@ def test_offset(self): day1 = self.now.shift(days=-1).datetime - a = _table_segment( + a = table_segment( self.connection, self.table_src_path, "id", @@ -128,7 +123,7 @@ def test_offset(self): max_update=sec1, case_sensitive=False, ) - b = _table_segment( + b = table_segment( self.connection, self.table_dst_path, "id", @@ -151,8 +146,8 @@ class TestDiffTables(DiffTestCase): def setUp(self): super().setUp() - self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) self.differ = HashDiffer(bisection_factor=3, bisection_threshold=4) @@ -355,8 +350,8 @@ def test_diff_column_names(self): ] ) - table1 = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - table2 = _table_segment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) + table1 = table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + table2 = table_segment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False) differ = HashDiffer(bisection_factor=2) diff = list(differ.diff_tables(table1, table2)) @@ -383,10 +378,10 @@ def setUp(self): ] ) - self.a = _table_segment( + self.a = table_segment( self.connection, self.table_src_path, "id", extra_columns=("text_comment",), case_sensitive=False ).with_schema() - self.b = _table_segment( + self.b = table_segment( self.connection, self.table_dst_path, "id", extra_columns=("text_comment",), case_sensitive=False ).with_schema() @@ -439,8 +434,8 @@ def setUp(self): for query in queries: self.connection.query(query, None) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_alphanum_keys(self): @@ -450,8 +445,8 @@ def test_alphanum_keys(self): self.connection.query([self.src_table.insert_row("@@@", "<-- this bad value should not break us"), commit]) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) @@ -485,8 +480,8 @@ def setUp(self): self.connection.query(queries) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_varying_alphanum_keys(self): # Test the class itself @@ -507,8 +502,8 @@ def test_varying_alphanum_keys(self): commit, ) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) @@ -517,8 +512,8 @@ def test_varying_alphanum_keys(self): class TestTableSegment(DiffTestCase): def setUp(self) -> None: super().setUp() - self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) - self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) + self.table = table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False) + self.table2 = table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False) def test_table_segment(self): early = datetime(2021, 1, 1, 0, 0) @@ -571,8 +566,8 @@ def setUp(self): ] ) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_column_with_nulls(self): differ = HashDiffer(bisection_factor=2) @@ -599,8 +594,8 @@ def setUp(self): ] ) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) def test_uuid_columns_with_nulls(self): """ @@ -659,10 +654,10 @@ def setUp(self): ] ) - self.a = _table_segment( + self.a = table_segment( self.connection, self.table_src_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) - self.b = _table_segment( + self.b = table_segment( self.connection, self.table_dst_path, "id", extra_columns=("c1", "c2"), case_sensitive=False ) @@ -704,22 +699,19 @@ def setUp(self): self.null_uuid = uuid.uuid1(1) - self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)] + diffs = [(uuid.uuid1(i), str(i)) for i in range(100)] + self.connection.query([self.src_table.insert_rows(diffs), commit]) - self.a = _table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) - self.b = _table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) + self.a = table_segment(self.connection, self.table_src_path, "id", "text_comment", case_sensitive=False) + self.b = table_segment(self.connection, self.table_dst_path, "id", "text_comment", case_sensitive=False) - def test_right_table_empty(self): - self.connection.query([self.src_table.insert_rows(self.diffs), commit]) + self.differ = HashDiffer(bisection_factor=2) - differ = HashDiffer(bisection_factor=2) - self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) + def test_right_table_empty(self): + self.assertRaises(ValueError, list, self.differ.diff_tables(self.a, self.b)) def test_left_table_empty(self): - self.connection.query([self.dst_table.insert_rows(self.diffs), commit]) - - differ = HashDiffer(bisection_factor=2) - self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) + self.assertRaises(ValueError, list, self.differ.diff_tables(self.a, self.b)) class TestInfoTree(DiffTestCase): @@ -776,10 +768,10 @@ def setUp(self): self.connection.query([self.src_table.insert_rows(src_values), self.dst_table.insert_rows(dst_values), commit]) - self.a = _table_segment( + self.a = table_segment( self.connection, self.table_src_path, "id", extra_columns=("data",), case_sensitive=False ) - self.b = _table_segment( + self.b = table_segment( self.connection, self.table_dst_path, "id", extra_columns=("data",), case_sensitive=False )