Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

data-diff now uses database A's now instead of cli's now. #306

Merged
merged 2 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from datetime import datetime
import sys
import time
import json
Expand All @@ -15,8 +16,9 @@
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
from .table_segment import TableSegment
from .sqeleton.schema import create_schema
from .sqeleton.queries.api import current_timestamp
from .databases import connect
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
from .parse_time import parse_time_before, UNITS_STR, ParseError
from .config import apply_config_from_file
from .tracking import disable_tracking
from . import __version__
Expand Down Expand Up @@ -299,17 +301,6 @@ def _main(

start = time.monotonic()

try:
options = dict(
min_update=max_age and parse_time_before_now(max_age),
max_update=min_age and parse_time_before_now(min_age),
case_sensitive=case_sensitive,
where=where,
)
except ParseError as e:
logging.error(f"Error while parsing age expression: {e}")
return

if database1 is None or database2 is None:
logging.error(
f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information."
Expand All @@ -326,6 +317,20 @@ def _main(
logging.error(e)
return


now: datetime = db1.query(current_timestamp(), datetime)
now = now.replace(tzinfo=None)
try:
options = dict(
min_update=max_age and parse_time_before(now, max_age),
max_update=min_age and parse_time_before(now, min_age),
case_sensitive=case_sensitive,
where=where,
)
except ParseError as e:
logging.error(f"Error while parsing age expression: {e}")
return

dbs = db1, db2

if interactive:
Expand Down
23 changes: 21 additions & 2 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from data_diff.sqeleton.databases import Connect
from data_diff.sqeleton.databases import Connect, Database
import logging

from .postgresql import PostgreSQL
from .mysql import MySQL
Expand Down Expand Up @@ -29,4 +30,22 @@
"vertica": Vertica,
}

connect = Connect(DATABASE_BY_SCHEME)

class Connect_SetUTC(Connect):
    """Connector that tries to put every newly created database session into UTC.

    Connects through a URL or connection dict exactly like ``Connect``; the only
    addition is the post-connection hook below.
    """

    def _connection_created(self, db):
        """Run the base hook, then attempt to switch the session timezone to UTC.

        When the dialect raises ``NotImplementedError`` the timezone is left as-is
        and a debug message is logged. The database object returned by the base
        hook is returned in either case.
        """
        db = super()._connection_created(db)
        try:
            set_utc_sql = db.dialect.set_timezone_to_utc()
            db.query(set_utc_sql)
        except NotImplementedError:
            message = f"Database '{db}' does not allow setting timezone. We recommend making sure it's set to 'UTC'."
            logging.debug(message)
        return db


# Module-level connector built from DATABASE_BY_SCHEME; being a Connect_SetUTC,
# it tries to set each newly created connection's timezone to UTC.
connect = Connect_SetUTC(DATABASE_BY_SCHEME)
1 change: 1 addition & 0 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def sample(table_expr):

def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
db = c.database
c = c.replace(root=False) # we're compiling fragments, not full queries
if isinstance(db, BigQuery):
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
elif isinstance(db, Presto):
Expand Down
4 changes: 2 additions & 2 deletions data_diff/parse_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def parse_time_delta(t: str):
return timedelta(**time_dict)


def parse_time_before_now(t: str):
    """Return ``datetime.now()`` minus the duration described by *t*.

    *t* is a textual duration that is handed to ``parse_time_delta``.
    """
    current = datetime.now()
    return current - parse_time_delta(t)
def parse_time_before(time: datetime, delta: str):
    """Return the moment that lies *delta* before *time*.

    *delta* is a textual duration that is handed to ``parse_time_delta``.
    """
    offset = parse_time_delta(delta)
    return time - offset
8 changes: 8 additions & 0 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def to_string(self, s: str) -> str:
def random(self) -> str:
    """Provide SQL for generating a random number between 0 and 1."""

@abstractmethod
def current_timestamp(self) -> str:
    """Provide SQL that evaluates to the current timestamp ("now")."""

@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
"Provide SQL fragment for limit and offset inside a select"
Expand All @@ -199,6 +203,10 @@ def timestamp_value(self, t: datetime) -> str:
"Provide SQL for the given timestamp value"
...

@abstractmethod
def set_timezone_to_utc(self) -> str:
    """Provide the SQL statement that sets the session timezone to UTC."""

@abstractmethod
def parse_type(
self,
Expand Down
5 changes: 4 additions & 1 deletion data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def timestamp_value(self, t: DbTime) -> str:
return f"'{t.isoformat()}'"

def random(self) -> str:
# NOTE(review): two consecutive return statements follow. As written, the first
# one wins and the second ("random()") is unreachable. This looks like the old
# and new lines of a diff flattened together — confirm which casing is intended.
return "RANDOM()"
return "random()"

def current_timestamp(self) -> str:
    """Return the default SQL expression for the current timestamp."""
    expression = "current_timestamp()"
    return expression

def explain_as_text(self, query: str) -> str:
    """Prefix *query* with ``EXPLAIN`` and return the resulting SQL text."""
    explained = f"EXPLAIN {query}"
    return explained
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def type_repr(self, t) -> str:
except KeyError:
return super().type_repr(t)

def set_timezone_to_utc(self) -> str:
    """Always raise ``NotImplementedError``: this dialect supplies no set-timezone statement."""
    raise NotImplementedError()


class BigQuery(Database):
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
# # return f"'{t}'"
# return f"'{str(t)[:19]}'"

def set_timezone_to_utc(self) -> str:
    """Always raise ``NotImplementedError``: this dialect supplies no set-timezone statement."""
    raise NotImplementedError()


class Clickhouse(ThreadedDatabase):
dialect = Dialect()
Expand Down
18 changes: 14 additions & 4 deletions data_diff/sqeleton/databases/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def match_path(self, dsn):


class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""

def __init__(self, database_by_scheme: Dict[str, Database]):
self.database_by_scheme = database_by_scheme
self.match_uri_path = {
Expand Down Expand Up @@ -172,9 +174,11 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
kw = {k: v for k, v in kw.items() if v is not None}

if issubclass(cls, ThreadedDatabase):
return cls(thread_count=thread_count, **kw)
db = cls(thread_count=thread_count, **kw)
else:
db = cls(**kw)

return cls(**kw)
return self._connection_created(db)

def connect_with_dict(self, d, thread_count):
d = dict(d)
Expand All @@ -186,9 +190,15 @@ def connect_with_dict(self, d, thread_count):

cls = matcher.database_cls
if issubclass(cls, ThreadedDatabase):
return cls(thread_count=thread_count, **d)
db = cls(thread_count=thread_count, **d)
else:
db = cls(**d)

return self._connection_created(db)

return cls(**d)
def _connection_created(self, db):
"Nop function to be overridden by subclasses."
return db

def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
"""Connect to a database using the given database configuration.
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to weird precision issues
return max(super()._convert_db_precision_to_digits(p) - 2, 0)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the time zone to UTC."""
    statement = "SET TIME ZONE 'UTC'"
    return statement


class Databricks(ThreadedDatabase):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the GLOBAL ``TimeZone`` setting to UTC."""
    statement = "SET GLOBAL TimeZone='UTC'"
    return statement


class DuckDB(Database):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def type_repr(self, t) -> str:
def explain_as_text(self, query: str) -> str:
    """Prefix *query* with ``EXPLAIN FORMAT=TREE`` and return the resulting SQL text."""
    explained = f"EXPLAIN FORMAT=TREE {query}"
    return explained

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the session ``time_zone`` variable to +00:00."""
    statement = "SET @@session.time_zone='+00:00'"
    return statement


class MySQL(ThreadedDatabase):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)

def set_timezone_to_utc(self) -> str:
    """Return the ``ALTER SESSION`` statement that sets ``TIME_ZONE`` to UTC."""
    statement = "ALTER SESSION SET TIME_ZONE = 'UTC'"
    return statement


class Oracle(ThreadedDatabase):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to weird precision issues in PostgreSQL
return super()._convert_db_precision_to_digits(p) - 2

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the time zone to UTC."""
    statement = "SET TIME ZONE 'UTC'"
    return statement


class PostgreSQL(ThreadedDatabase):
dialect = PostgresqlDialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the time zone to the +00:00 offset."""
    statement = "SET TIME ZONE '+00:00'"
    return statement


class Presto(Database):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def to_string(self, s: str):
def table_information(self) -> Compilable:
    """Return a table reference built from the path ``INFORMATION_SCHEMA`` / ``TABLES``."""
    info_tables = table("INFORMATION_SCHEMA", "TABLES")
    return info_tables

def set_timezone_to_utc(self) -> str:
    """Return the ``ALTER SESSION`` statement that sets ``TIMEZONE`` to UTC."""
    statement = "ALTER SESSION SET TIMEZONE = 'UTC'"
    return statement


class Snowflake(Database):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the time zone to UTC."""
    statement = "SET TIME ZONE TO 'UTC'"
    return statement


class Vertica(ThreadedDatabase):
dialect = Dialect()
Expand Down
4 changes: 4 additions & 0 deletions data_diff/sqeleton/queries/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,8 @@ def insert_rows_in_batches(db, table: TablePath, rows, *, columns=None, batch_si
db.query(table.insert_rows(batch, columns=columns))


def current_timestamp():
    """Build and return a new ``CurrentTimestamp`` expression node."""
    node = CurrentTimestamp()
    return node


commit = Commit()
22 changes: 14 additions & 8 deletions data_diff/sqeleton/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..abcs import Compilable
from ..schema import Schema

from .compiler import Compiler, cv_params
from .compiler import Compiler, cv_params, Root
from .base import SKIP, CompileError, DbPath, args_as_tuple


Expand Down Expand Up @@ -47,7 +47,7 @@ def cast_to(self, to):


@dataclass
class Code(ExprNode):
class Code(ExprNode, Root):
code: str

def compile(self, c: Compiler) -> str:
Expand Down Expand Up @@ -434,7 +434,7 @@ def compile(self, c: Compiler) -> str:


@dataclass
class Join(ExprNode, ITable):
class Join(ExprNode, ITable, Root):
source_tables: Sequence[ITable]
op: str = None
on_exprs: Sequence[Expr] = None
Expand Down Expand Up @@ -499,7 +499,7 @@ def compile(self, parent_c: Compiler) -> str:


@dataclass
class GroupBy(ExprNode, ITable):
class GroupBy(ExprNode, ITable, Root):
table: ITable
keys: Sequence[Expr] = None # IKey?
values: Sequence[Expr] = None
Expand Down Expand Up @@ -540,7 +540,7 @@ def compile(self, c: Compiler) -> str:


@dataclass
class TableOp(ExprNode, ITable):
class TableOp(ExprNode, ITable, Root):
op: str
table1: ITable
table2: ITable
Expand Down Expand Up @@ -571,7 +571,7 @@ def compile(self, parent_c: Compiler) -> str:


@dataclass
class Select(ExprNode, ITable):
class Select(ExprNode, ITable, Root):
table: Expr = None
columns: Sequence[Expr] = None
where_exprs: Sequence[Expr] = None
Expand Down Expand Up @@ -771,7 +771,7 @@ def compile_for_insert(self, c: Compiler):


@dataclass
class Explain(ExprNode):
class Explain(ExprNode, Root):
select: Select

type = str
Expand All @@ -780,10 +780,16 @@ def compile(self, c: Compiler) -> str:
return c.dialect.explain_as_text(c.compile(self.select))


class CurrentTimestamp(ExprNode):
    """Expression node that compiles to the dialect's current-timestamp SQL."""

    # NOTE(review): presumably the Python type the query result is read as —
    # confirm against how ExprNode.type is consumed.
    type = datetime

    def compile(self, c: Compiler) -> str:
        """Return the current-timestamp SQL supplied by the compiler's dialect."""
        dialect = c.dialect
        return dialect.current_timestamp()

# DDL


class Statement(Compilable):
class Statement(Compilable, Root):
type = None


Expand Down
7 changes: 7 additions & 0 deletions data_diff/sqeleton/queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

cv_params = contextvars.ContextVar("params")

class Root:
    """Marker base class for nodes that may be used as root statements in SQL.

    For example, a SELECT qualifies as a root statement, whereas RANDOM() does not.
    """


@dataclass
class Compiler(AbstractCompiler):
Expand All @@ -33,6 +36,10 @@ def compile(self, elem, params=None) -> str:
if params:
cv_params.set(params)

if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
from .ast_classes import Select
elem = Select(columns=[elem])

res = self._compile(elem)
if self.root and self._subqueries:
subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items())
Expand Down
Loading