diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 7d96fe9f..0701d1b3 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -519,7 +519,7 @@ def _data_diff( else: for op, values in diff_iter: - color = COLOR_SCHEME[op] + color = COLOR_SCHEME.get(op, "grey62") if json_output: jsonl = json.dumps([op, list(values)]) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 0eb5bb69..f3c6381a 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -1,6 +1,6 @@ import decimal from abc import ABC, abstractmethod -from typing import Tuple, Union +from typing import List, Optional, Tuple, Type, TypeVar, Union from datetime import datetime import attrs @@ -12,9 +12,24 @@ DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric] DbTime = datetime +N = TypeVar("N") -@attrs.define(frozen=True) + +@attrs.define(frozen=True, kw_only=True) class ColType: + # Arbitrary metadata added and fetched at runtime. + _notes: List[N] = attrs.field(factory=list, init=False, hash=False, eq=False) + + def add_note(self, note: N) -> None: + self._notes.append(note) + + def get_note(self, cls: Type[N]) -> Optional[N]: + """Get the latest added note of type ``cls`` or its descendants.""" + for note in reversed(self._notes): + if isinstance(note, cls): + return note + return None + @property def supported(self) -> bool: return True diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index e4e215b7..84fdb8a4 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -182,7 +182,7 @@ def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: sql = self.compiler.database.dialect.compile(self.compiler, q) - logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql) + logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql) try: try: res = callback(sql) if sql is not SKIP else SKIP @@ -267,6 +267,8 @@ def 
_compile(self, compiler: Compiler, elem) -> str: return "NULL" elif isinstance(elem, Compilable): return self.render_compilable(attrs.evolve(compiler, root=False), elem) + elif isinstance(elem, ColType): + return self.render_coltype(attrs.evolve(compiler, root=False), elem) elif isinstance(elem, str): return f"'{elem}'" elif isinstance(elem, (int, float)): @@ -359,6 +361,9 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str: raise RuntimeError(f"Cannot render AST of type {elem.__class__}") # return elem.compile(compiler.replace(root=False)) + def render_coltype(self, c: Compiler, elem: ColType) -> str: + return self.type_repr(elem) + def render_column(self, c: Compiler, elem: Column) -> str: if c._table_context: if len(c._table_context) > 1: @@ -876,7 +881,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None): if sql_code is SKIP: return SKIP - logger.debug("Running SQL (%s): %s", self.name, sql_code) + logger.debug("Running SQL (%s):\n%s", self.name, sql_code) if self._interactive and isinstance(sql_ast, Select): explained_sql = self.compile(Explain(sql_ast)) diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index e2845ec8..67eb1b0c 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -16,6 +16,7 @@ from data_diff.databases.base import Mixin_Schema from data_diff.abcs.database_types import ( JSON, + NumericType, Timestamp, TimestampTZ, DbPath, @@ -50,7 +51,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return formatted_value - def normalize_number(self, value: str, coltype: FractionalType) -> str: + def normalize_number(self, value: str, coltype: NumericType) -> str: if coltype.precision == 0: return f"CAST(FLOOR({value}) AS VARCHAR)" diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index c25bc06d..95814e23 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -3,7 +3,6 @@ import logging from 
collections import defaultdict from typing import Iterator -from operator import attrgetter import attrs @@ -71,7 +70,8 @@ class HashDiffer(TableDiffer): """ bisection_factor: int = DEFAULT_BISECTION_FACTOR - bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests + bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD + bisection_disabled: bool = False # i.e. always download the rows (used in tests) stats: dict = attrs.field(factory=dict) @@ -82,7 +82,7 @@ def __attrs_post_init__(self): if self.bisection_factor < 2: raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)") - def _validate_and_adjust_columns(self, table1, table2): + def _validate_and_adjust_columns(self, table1, table2, *, strict: bool = True): for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns): if c1 not in table1._schema: raise ValueError(f"Column '{c1}' not found in schema for table {table1}") @@ -92,11 +92,11 @@ def _validate_and_adjust_columns(self, table1, table2): # Update schemas to minimal mutual precision col1 = table1._schema[c1] col2 = table2._schema[c2] - if isinstance(col1, PrecisionType): - if not isinstance(col2, PrecisionType): + if isinstance(col1, PrecisionType) and (strict or isinstance(col2, PrecisionType)): + if not isinstance(col2, PrecisionType): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - lowest = min(col1, col2, key=attrgetter("precision")) + lowest = min(col1, col2, key=lambda col: col.precision) if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. 
Types={col1}, {col2}") @@ -104,11 +104,11 @@ def _validate_and_adjust_columns(self, table1, table2): table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds) table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds) - elif isinstance(col1, (NumericType, Boolean)): - if not isinstance(col2, (NumericType, Boolean)): + elif isinstance(col1, (NumericType, Boolean)) and (strict or isinstance(col2, (NumericType, Boolean))): + if not isinstance(col2, (NumericType, Boolean)): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") - lowest = min(col1, col2, key=attrgetter("precision")) + lowest = min(col1, col2, key=lambda col: col.precision) if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") @@ -119,11 +119,11 @@ table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision) elif isinstance(col1, ColType_UUID): - if not isinstance(col2, ColType_UUID): + if strict and not isinstance(col2, ColType_UUID): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") elif isinstance(col1, StringType): - if not isinstance(col2, StringType): + if strict and not isinstance(col2, StringType): raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}") for t in [table1, table2]: @@ -157,7 +157,7 @@ def _diff_segments( # default, data-diff will checksum the section first (when it's below # the threshold) and _then_ download it. 
if BENCHMARK: - if max_rows < self.bisection_threshold: + if self.bisection_disabled or max_rows < self.bisection_threshold: return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max_rows) (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) @@ -202,7 +202,7 @@ def _bisect_and_diff_segments( # If count is below the threshold, just download and compare the columns locally # This saves time, as bisection speed is limited by ping and query performance. - if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2: + if self.bisection_disabled or max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2: rows1, rows2 = self._threaded_call("get_values", [table1, table2]) json_cols = { i: colname diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 7cf2319b..53e83e45 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -302,7 +302,7 @@ class CaseWhen(ExprNode): def type(self): then_types = {_expr_type(case.then) for case in self.cases} if self.else_expr: - then_types |= _expr_type(self.else_expr) + then_types |= {_expr_type(self.else_expr)} if len(then_types) > 1: raise QB_TypeError(f"Non-matching types in when: {then_types}") (t,) = then_types diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 70fd01aa..e97ca484 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,8 +1,8 @@ +import sys import unittest import time import json import re -import math import uuid from datetime import datetime, timedelta, timezone import logging @@ -765,10 +765,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego # reasonable amount of rows each. These will then be downloaded in # parallel, using the existing implementation. 
dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2 - dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf + dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else sys.maxsize dl_threads = N_THREADS differ = HashDiffer( - bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads + bisection_factor=dl_factor, + bisection_threshold=dl_threshold, + bisection_disabled=not BENCHMARK, + max_threadpool_size=dl_threads, ) start = time.monotonic() diff = list(differ.diff_tables(self.table, self.table2))