From 598d98cbe883ebec10ba440a8967b9be4189e679 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Wed, 18 Oct 2023 16:04:26 +0200 Subject: [PATCH 1/8] Normalise all numbers, not only fractional in MS SQL --- data_diff/databases/mssql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index e2845ec8..67eb1b0c 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -16,6 +16,7 @@ from data_diff.databases.base import Mixin_Schema from data_diff.abcs.database_types import ( JSON, + NumericType, Timestamp, TimestampTZ, DbPath, @@ -50,7 +51,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return formatted_value - def normalize_number(self, value: str, coltype: FractionalType) -> str: + def normalize_number(self, value: str, coltype: NumericType) -> str: if coltype.precision == 0: return f"CAST(FLOOR({value}) AS VARCHAR)" From 34f84037beceb1327ef76625f338c306263efe79 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Wed, 18 Oct 2023 16:04:52 +0200 Subject: [PATCH 2/8] Restrict bisection threshold to pure integers, disable bisection only explicitly --- data_diff/hashdiff_tables.py | 7 ++++--- tests/test_database_types.py | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index c25bc06d..e1164bfc 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -71,7 +71,8 @@ class HashDiffer(TableDiffer): """ bisection_factor: int = DEFAULT_BISECTION_FACTOR - bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests + bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD + bisection_disabled: bool = False # i.e. 
always download the rows (used in tests) stats: dict = attrs.field(factory=dict) @@ -157,7 +158,7 @@ def _diff_segments( # default, data-diff will checksum the section first (when it's below # the threshold) and _then_ download it. if BENCHMARK: - if max_rows < self.bisection_threshold: + if self.bisection_disabled or max_rows < self.bisection_threshold: return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max_rows) (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) @@ -202,7 +203,7 @@ def _bisect_and_diff_segments( # If count is below the threshold, just download and compare the columns locally # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2: + if self.bisection_disabled or max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2: rows1, rows2 = self._threaded_call("get_values", [table1, table2]) json_cols = { i: colname diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 70fd01aa..e97ca484 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,8 +1,8 @@ +import sys import unittest import time import json import re -import math import uuid from datetime import datetime, timedelta, timezone import logging @@ -765,10 +765,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego # reasonable amount of rows each. These will then be downloaded in # parallel, using the existing implementation. 
dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2 - dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf + dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else sys.maxsize dl_threads = N_THREADS differ = HashDiffer( - bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads + bisection_factor=dl_factor, + bisection_threshold=dl_threshold, + bisection_disabled=True, + max_threadpool_size=dl_threads, ) start = time.monotonic() diff = list(differ.diff_tables(self.table, self.table2)) From 6c510f73e5797fc8308d217850900028b6701688 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Thu, 12 Oct 2023 23:14:47 +0200 Subject: [PATCH 3/8] Show unidentified types of diff lines in grey instead of failing --- data_diff/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 7d96fe9f..0701d1b3 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -519,7 +519,7 @@ def _data_diff( else: for op, values in diff_iter: - color = COLOR_SCHEME[op] + color = COLOR_SCHEME.get(op, "grey62") if json_output: jsonl = json.dumps([op, list(values)]) From 494a3da12edc0fe5c44f8f545f4e4221e7987a2b Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Thu, 12 Oct 2023 23:15:20 +0200 Subject: [PATCH 4/8] Fix the CaseWhen's "else" rendering --- data_diff/queries/ast_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7cf2319b..53e83e45 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -302,7 +302,7 @@ class CaseWhen(ExprNode): def type(self): then_types = {_expr_type(case.then) for case in self.cases} if self.else_expr: - then_types |= _expr_type(self.else_expr) + then_types |= {_expr_type(self.else_expr)} if len(then_types) > 1: raise QB_TypeError(f"Non-matching types in when: {then_types}") (t,) = 
then_types From 75108b97e8741375b1c2402ff57f441e73eed59f Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Thu, 12 Oct 2023 23:17:25 +0200 Subject: [PATCH 5/8] Show debug SQL statements ready for copy-pasting With the first line of the SQL on the first line, the file name (e.g. `base.py:123`) is copy-pasted into SQL and requires manual editing. With this fix, the whole statement will be ready for copy-pasting without any fixes and cleanups. The statements are often multi-line anyway, so it does not damage the logging. --- data_diff/databases/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index e4e215b7..78f601f5 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -182,7 +182,7 @@ def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: sql = self.compiler.database.dialect.compile(self.compiler, q) - logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) + logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql) try: try: res = callback(sql) if sql is not SKIP else SKIP @@ -876,7 +876,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): if sql_code is SKIP: return SKIP - logger.debug("Running SQL (%s): %s", self.name, sql_code) + logger.debug("Running SQL (%s):\n%s", self.name, sql_code) if self._interactive and isinstance(sql_ast, Select): explained_sql = self.compile(Explain(sql_ast)) From a8eb6f5f6e67131488e5dffd7cb1c8c47d6d9048 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Thu, 12 Oct 2023 23:18:20 +0200 Subject: [PATCH 6/8] Render column types in SQL as it is implied in `cast()` (but was never used?) 
--- data_diff/databases/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 78f601f5..84fdb8a4 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -267,6 +267,8 @@ def _compile(self, compiler: Compiler, elem) -> str: return "NULL" elif isinstance(elem, Compilable): return self.render_compilable(attrs.evolve(compiler, root=False), elem) + elif isinstance(elem, ColType): + return self.render_coltype(attrs.evolve(compiler, root=False), elem) elif isinstance(elem, str): return f"'{elem}'" elif isinstance(elem, (int, float)): @@ -359,6 +361,9 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str: raise RuntimeError(f"Cannot render AST of type {elem.__class__}") # return elem.compile(compiler.replace(root=False)) + def render_coltype(self, c: Compiler, elem: ColType) -> str: + return self.type_repr(elem) + def render_column(self, c: Compiler, elem: Column) -> str: if c._table_context: if len(c._table_context) > 1: From 515319670f4a8cc1b8a55c6ec421ec39676e53a8 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Thu, 12 Oct 2023 23:20:20 +0200 Subject: [PATCH 7/8] Annotate inferred column types with arbitrary runtime notes (metadata) Similar to pytest's marks or Python's Exception's notes, we can add arbitrary data classes and retrieve the latest note (marker). This makes the customization of column types easier without overriding the whole hierarchy of column types. 
--- data_diff/abcs/database_types.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 0eb5bb69..f3c6381a 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -1,6 +1,6 @@ import decimal from abc import ABC, abstractmethod -from typing import Tuple, Union +from typing import List, Optional, Tuple, Type, TypeVar, Union from datetime import datetime import attrs @@ -12,9 +12,24 @@ DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric] DbTime = datetime +N = TypeVar("N") -@attrs.define(frozen=True) + +@attrs.define(frozen=True, kw_only=True) class ColType: + # Arbitrary metadata added and fetched at runtime. + _notes: List[N] = attrs.field(factory=list, init=False, hash=False, eq=False) + + def add_note(self, note: N) -> None: + self._notes.append(note) + + def get_note(self, cls: Type[N]) -> Optional[N]: + """Get the latest added note of type ``cls`` or its descendants.""" + for note in reversed(self._notes): + if isinstance(note, cls): + return note + return None + @property def supported(self) -> bool: return True From d268ff753e1a75509fcec78d4873cb2e23d048ba Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Thu, 12 Oct 2023 23:23:43 +0200 Subject: [PATCH 8/8] Relax the checking of cross-type column matching on demand MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit And ensure that both columns have the same set of attributes and follow the same protocol before actually using it — i.e. do not simply imply it as given. 
--- data_diff/hashdiff_tables.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index e1164bfc..95814e23 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -3,7 +3,6 @@ import logging from collections import defaultdict from typing import Iterator -from operator import attrgetter import attrs @@ -83,7 +82,7 @@ def __attrs_post_init__(self): if self.bisection_factor < 2: raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") - def _validate_and_adjust_columns(self, table1, table2): + def _validate_and_adjust_columns(self, table1, table2, *, strict: bool = True): for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns): if c1 not in table1._schema: raise ValueError(f"Column '{c1}' not found in schema for table {table1}") @@ -93,11 +92,11 @@ def _validate_and_adjust_columns(self, table1, table2): # Update schemas to minimal mutual precision col1 = table1._schema[c1] col2 = table2._schema[c2] - if isinstance(col1, PrecisionType): - if not isinstance(col2, PrecisionType): + if isinstance(col1, PrecisionType) and isinstance(col2, PrecisionType): + if strict and not isinstance(col2, PrecisionType): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - lowest = min(col1, col2, key=attrgetter("precision")) + lowest = min(col1, col2, key=lambda col: col.precision) if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. 
Types={col1}, {col2}") @@ -105,11 +104,11 @@ def _validate_and_adjust_columns(self, table1, table2): table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds) table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds) - elif isinstance(col1, (NumericType, Boolean)): - if not isinstance(col2, (NumericType, Boolean)): + elif isinstance(col1, (NumericType, Boolean)) and isinstance(col2, (NumericType, Boolean)): + if strict and not isinstance(col2, (NumericType, Boolean)): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - lowest = min(col1, col2, key=attrgetter("precision")) + lowest = min(col1, col2, key=lambda col: col.precision) if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") @@ -120,11 +119,11 @@ def _validate_and_adjust_columns(self, table1, table2): table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision) elif isinstance(col1, ColType_UUID): - if not isinstance(col2, ColType_UUID): + if strict and not isinstance(col2, ColType_UUID): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") elif isinstance(col1, StringType): - if not isinstance(col2, StringType): + if strict and not isinstance(col2, StringType): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") for t in [table1, table2]: