
Fix a few things here & there #740

Merged: 8 commits, Oct 18, 2023
2 changes: 1 addition & 1 deletion data_diff/__main__.py
@@ -519,7 +519,7 @@ def _data_diff(

else:
for op, values in diff_iter:
color = COLOR_SCHEME[op]
color = COLOR_SCHEME.get(op, "grey62")

if json_output:
jsonl = json.dumps([op, list(values)])
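A minimal sketch of the new fallback behaviour in the diff printer; the mapping below is an illustrative stand-in, not the real COLOR_SCHEME:

```python
# Illustrative subset of an op -> color mapping.
COLOR_SCHEME = {"+": "green", "-": "red"}

# .get() with a default means an unexpected op falls back to a neutral color
# instead of raising KeyError, as the old COLOR_SCHEME[op] lookup did.
color = COLOR_SCHEME.get("?", "grey62")  # -> "grey62"
```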
19 changes: 17 additions & 2 deletions data_diff/abcs/database_types.py
@@ -1,6 +1,6 @@
import decimal
from abc import ABC, abstractmethod
from typing import Tuple, Union
from typing import List, Optional, Tuple, Type, TypeVar, Union
from datetime import datetime

import attrs
@@ -12,9 +12,24 @@
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
DbTime = datetime

N = TypeVar("N")

@attrs.define(frozen=True)

@attrs.define(frozen=True, kw_only=True)
class ColType:
# Arbitrary metadata added and fetched at runtime.
_notes: List[N] = attrs.field(factory=list, init=False, hash=False, eq=False)

def add_note(self, note: N) -> None:
self._notes.append(note)

def get_note(self, cls: Type[N]) -> Optional[N]:
"""Get the latest added note of type ``cls`` or its descendants."""
for note in reversed(self._notes):
if isinstance(note, cls):
return note
return None

@property
def supported(self) -> bool:
return True
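A short usage sketch of the new note mechanism on ColType; the note class below is hypothetical, defined only for the example, and a bare ColType is assumed to be instantiable:

```python
import attrs

from data_diff.abcs.database_types import ColType


@attrs.define(frozen=True)
class PrecisionNote:  # hypothetical note type, not part of this PR
    reduced_from: int


col = ColType()
col.add_note(PrecisionNote(reduced_from=9))
col.add_note(PrecisionNote(reduced_from=6))

# get_note() returns the most recently added note that is an instance of the
# requested class (or a subclass); unrelated classes yield None.
assert col.get_note(PrecisionNote).reduced_from == 6
assert col.get_note(str) is None
```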
9 changes: 7 additions & 2 deletions data_diff/databases/base.py
@@ -182,7 +182,7 @@ def apply_queries(self, callback: Callable[[str], Any]):
q: Expr = next(self.gen)
while True:
sql = self.compiler.database.dialect.compile(self.compiler, q)
logger.debug("Running SQL (%s-TL): %s", self.compiler.database.name, sql)
logger.debug("Running SQL (%s-TL):\n%s", self.compiler.database.name, sql)
try:
try:
res = callback(sql) if sql is not SKIP else SKIP
@@ -267,6 +267,8 @@ def _compile(self, compiler: Compiler, elem) -> str:
return "NULL"
elif isinstance(elem, Compilable):
return self.render_compilable(attrs.evolve(compiler, root=False), elem)
elif isinstance(elem, ColType):
return self.render_coltype(attrs.evolve(compiler, root=False), elem)
elif isinstance(elem, str):
return f"'{elem}'"
elif isinstance(elem, (int, float)):
@@ -359,6 +361,9 @@ def render_compilable(self, c: Compiler, elem: Compilable) -> str:
raise RuntimeError(f"Cannot render AST of type {elem.__class__}")
# return elem.compile(compiler.replace(root=False))

def render_coltype(self, c: Compiler, elem: ColType) -> str:
return self.type_repr(elem)

def render_column(self, c: Compiler, elem: Column) -> str:
if c._table_context:
if len(c._table_context) > 1:
@@ -876,7 +881,7 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
if sql_code is SKIP:
return SKIP

logger.debug("Running SQL (%s): %s", self.name, sql_code)
logger.debug("Running SQL (%s):\n%s", self.name, sql_code)

if self._interactive and isinstance(sql_ast, Select):
explained_sql = self.compile(Explain(sql_ast))
3 changes: 2 additions & 1 deletion data_diff/databases/mssql.py
@@ -16,6 +16,7 @@
from data_diff.databases.base import Mixin_Schema
from data_diff.abcs.database_types import (
JSON,
NumericType,
Timestamp,
TimestampTZ,
DbPath,
@@ -50,7 +51,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:

return formatted_value

def normalize_number(self, value: str, coltype: FractionalType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
if coltype.precision == 0:
return f"CAST(FLOOR({value}) AS VARCHAR)"

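For reference, a hedged sketch of what the precision-0 branch shown above produces; the column expression is made up for illustration:

```python
# Assuming `value` is the compiled column expression and `coltype` is an
# integer-like NumericType (precision == 0), MsSQL normalization emits roughly:
#   normalize_number("t.amount", coltype)  ->  "CAST(FLOOR(t.amount) AS VARCHAR)"
# The PR only widens the accepted annotation from FractionalType to NumericType;
# the generated SQL is unchanged.
```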
26 changes: 13 additions & 13 deletions data_diff/hashdiff_tables.py
@@ -3,7 +3,6 @@
import logging
from collections import defaultdict
from typing import Iterator
from operator import attrgetter

import attrs

@@ -71,7 +70,8 @@ class HashDiffer(TableDiffer):
"""

bisection_factor: int = DEFAULT_BISECTION_FACTOR
bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests
bisection_threshold: int = DEFAULT_BISECTION_THRESHOLD
bisection_disabled: bool = False # i.e. always download the rows (used in tests)

stats: dict = attrs.field(factory=dict)

@@ -82,7 +82,7 @@ def __attrs_post_init__(self):
if self.bisection_factor < 2:
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")

def _validate_and_adjust_columns(self, table1, table2):
def _validate_and_adjust_columns(self, table1, table2, *, strict: bool = True):
for c1, c2 in safezip(table1.relevant_columns, table2.relevant_columns):
if c1 not in table1._schema:
raise ValueError(f"Column '{c1}' not found in schema for table {table1}")
@@ -92,23 +92,23 @@ def _validate_and_adjust_columns(self, table1, table2):
# Update schemas to minimal mutual precision
col1 = table1._schema[c1]
col2 = table2._schema[c2]
if isinstance(col1, PrecisionType):
if not isinstance(col2, PrecisionType):
if isinstance(col1, PrecisionType) and isinstance(col2, PrecisionType):
if strict and not isinstance(col2, PrecisionType):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

lowest = min(col1, col2, key=attrgetter("precision"))
lowest = min(col1, col2, key=lambda col: col.precision)

if col1.precision != col2.precision:
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")

table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds)
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds)

elif isinstance(col1, (NumericType, Boolean)):
if not isinstance(col2, (NumericType, Boolean)):
elif isinstance(col1, (NumericType, Boolean)) and isinstance(col2, (NumericType, Boolean)):
if strict and not isinstance(col2, (NumericType, Boolean)):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

lowest = min(col1, col2, key=attrgetter("precision"))
lowest = min(col1, col2, key=lambda col: col.precision)

if col1.precision != col2.precision:
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
@@ -119,11 +119,11 @@ def _validate_and_adjust_columns(self, table1, table2):
table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision)

elif isinstance(col1, ColType_UUID):
if not isinstance(col2, ColType_UUID):
if strict and not isinstance(col2, ColType_UUID):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

elif isinstance(col1, StringType):
if not isinstance(col2, StringType):
if strict and not isinstance(col2, StringType):
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")

for t in [table1, table2]:
@@ -157,7 +157,7 @@ def _diff_segments(
# default, data-diff will checksum the section first (when it's below
# the threshold) and _then_ download it.
if BENCHMARK:
if max_rows < self.bisection_threshold:
if self.bisection_disabled or max_rows < self.bisection_threshold:
return self._bisect_and_diff_segments(ti, table1, table2, info_tree, level=level, max_rows=max_rows)

(count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])
@@ -202,7 +202,7 @@ def _bisect_and_diff_segments(

# If count is below the threshold, just download and compare the columns locally
# This saves time, as bisection speed is limited by ping and query performance.
if max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
if self.bisection_disabled or max_rows < self.bisection_threshold or max_space_size < self.bisection_factor * 2:
rows1, rows2 = self._threaded_call("get_values", [table1, table2])
json_cols = {
i: colname
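A minimal usage sketch of the new flag; table1 and table2 are assumed to be already-prepared TableSegment objects:

```python
from data_diff.hashdiff_tables import HashDiffer

# bisection_disabled=True skips the checksum phase and downloads each segment's
# rows for a local comparison, which the tests previously emulated by passing
# bisection_threshold=math.inf (no longer possible now that the field is an int).
differ = HashDiffer(bisection_factor=2, bisection_disabled=True)
diff = list(differ.diff_tables(table1, table2))
```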
2 changes: 1 addition & 1 deletion data_diff/queries/ast_classes.py
@@ -302,7 +302,7 @@ class CaseWhen(ExprNode):
def type(self):
then_types = {_expr_type(case.then) for case in self.cases}
if self.else_expr:
then_types |= _expr_type(self.else_expr)
then_types |= {_expr_type(self.else_expr)}
if len(then_types) > 1:
raise QB_TypeError(f"Non-matching types in when: {then_types}")
(t,) = then_types
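A stand-alone illustration of the bug fixed here, using plain Python sets; int and str stand in for whatever _expr_type returns:

```python
then_types = {int}

# Old code: `then_types |= int` raises TypeError, because the in-place set
# union operator requires another set on the right-hand side.
# New code wraps the single element in a set before taking the union:
then_types |= {str}
assert then_types == {int, str}
```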
9 changes: 6 additions & 3 deletions tests/test_database_types.py
@@ -1,8 +1,8 @@
import sys
import unittest
import time
import json
import re
import math
import uuid
from datetime import datetime, timedelta, timezone
import logging
@@ -765,10 +765,13 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
# reasonable amount of rows each. These will then be downloaded in
# parallel, using the existing implementation.
dl_factor = max(int(N_SAMPLES / 100_000), 2) if BENCHMARK else 2
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else math.inf
dl_threshold = int(N_SAMPLES / dl_factor) + 1 if BENCHMARK else sys.maxsize
dl_threads = N_THREADS
differ = HashDiffer(
bisection_threshold=dl_threshold, bisection_factor=dl_factor, max_threadpool_size=dl_threads
bisection_factor=dl_factor,
bisection_threshold=dl_threshold,
bisection_disabled=True,
max_threadpool_size=dl_threads,
)
start = time.monotonic()
diff = list(differ.diff_tables(self.table, self.table2))