Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #306 from datafold/issue284
Browse files Browse the repository at this point in the history
data-diff now uses database A's now instead of cli's now.
  • Loading branch information
erezsh committed Nov 25, 2022
2 parents 9289570 + 79a419c commit 01abf3a
Show file tree
Hide file tree
Showing 26 changed files with 196 additions and 104 deletions.
29 changes: 17 additions & 12 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from datetime import datetime
import sys
import time
import json
Expand All @@ -15,8 +16,9 @@
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
from .table_segment import TableSegment
from .sqeleton.schema import create_schema
from .sqeleton.queries.api import current_timestamp
from .databases import connect
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
from .parse_time import parse_time_before, UNITS_STR, ParseError
from .config import apply_config_from_file
from .tracking import disable_tracking
from . import __version__
Expand Down Expand Up @@ -299,17 +301,6 @@ def _main(

start = time.monotonic()

try:
options = dict(
min_update=max_age and parse_time_before_now(max_age),
max_update=min_age and parse_time_before_now(min_age),
case_sensitive=case_sensitive,
where=where,
)
except ParseError as e:
logging.error(f"Error while parsing age expression: {e}")
return

if database1 is None or database2 is None:
logging.error(
f"Error: Databases not specified. Got {database1} and {database2}. Use --help for more information."
Expand All @@ -326,6 +317,20 @@ def _main(
logging.error(e)
return


now: datetime = db1.query(current_timestamp(), datetime)
now = now.replace(tzinfo=None)
try:
options = dict(
min_update=max_age and parse_time_before(now, max_age),
max_update=min_age and parse_time_before(now, min_age),
case_sensitive=case_sensitive,
where=where,
)
except ParseError as e:
logging.error(f"Error while parsing age expression: {e}")
return

dbs = db1, db2

if interactive:
Expand Down
23 changes: 21 additions & 2 deletions data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from data_diff.sqeleton.databases import Connect
from data_diff.sqeleton.databases import Connect, Database
import logging

from .postgresql import PostgreSQL
from .mysql import MySQL
Expand Down Expand Up @@ -29,4 +30,22 @@
"vertica": Vertica,
}

connect = Connect(DATABASE_BY_SCHEME)

class Connect_SetUTC(Connect):
    """Connect to a supported database from a URL or connection dict, then try to switch the session to UTC.

    A NotImplementedError raised while switching is logged at debug level and
    otherwise ignored, so the connection is still returned.
    """

    def _connection_created(self, db):
        """Run the dialect's set-timezone-to-UTC statement on a new connection.

        Returns the connection handed back by the parent hook, whether or not
        the time zone could be changed.
        """
        db = super()._connection_created(db)
        try:
            set_utc_statement = db.dialect.set_timezone_to_utc()
            db.query(set_utc_statement)
        except NotImplementedError:
            # Best effort only: the session keeps whatever time zone it already has.
            logging.debug(
                f"Database '{db}' does not allow setting timezone. We recommend making sure it's set to 'UTC'."
            )
        return db


connect = Connect_SetUTC(DATABASE_BY_SCHEME)
1 change: 1 addition & 0 deletions data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def sample(table_expr):

def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str:
db = c.database
c = c.replace(root=False) # we're compiling fragments, not full queries
if isinstance(db, BigQuery):
return f"create table {c.compile(path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.compile(expr)}"
elif isinstance(db, Presto):
Expand Down
4 changes: 2 additions & 2 deletions data_diff/parse_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def parse_time_delta(t: str):
return timedelta(**time_dict)


def parse_time_before_now(t: str):
return datetime.now() - parse_time_delta(t)
def parse_time_before(time: datetime, delta: str):
    """Return the point in time that lies `delta` before `time`.

    `delta` is handed to parse_time_delta and the resulting value is
    subtracted from `time`.
    """
    offset = parse_time_delta(delta)
    return time - offset
8 changes: 8 additions & 0 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def to_string(self, s: str) -> str:
def random(self) -> str:
    "Provide SQL for generating a random number between 0..1"

@abstractmethod
def current_timestamp(self) -> str:
"Provide SQL for returning the current timestamp, aka now"

@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
"Provide SQL fragment for limit and offset inside a select"
Expand All @@ -199,6 +203,10 @@ def timestamp_value(self, t: datetime) -> str:
"Provide SQL for the given timestamp value"
...

@abstractmethod
def set_timezone_to_utc(self) -> str:
"Provide SQL for setting the session timezone to UTC"

@abstractmethod
def parse_type(
self,
Expand Down
5 changes: 4 additions & 1 deletion data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def timestamp_value(self, t: DbTime) -> str:
return f"'{t.isoformat()}'"

def random(self) -> str:
return "RANDOM()"
return "random()"

def current_timestamp(self) -> str:
    """Return the SQL expression that evaluates to the current timestamp."""
    return "current_timestamp()"

def explain_as_text(self, query: str) -> str:
return f"EXPLAIN {query}"
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def type_repr(self, t) -> str:
except KeyError:
return super().type_repr(t)

def set_timezone_to_utc(self) -> str:
    """Always raise NotImplementedError: this dialect supplies no set-UTC statement."""
    raise NotImplementedError()


class BigQuery(Database):
CONNECT_URI_HELP = "bigquery://<project>/<dataset>"
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
# # return f"'{t}'"
# return f"'{str(t)[:19]}'"

def set_timezone_to_utc(self) -> str:
    """Always raise NotImplementedError: this dialect supplies no set-UTC statement."""
    raise NotImplementedError()


class Clickhouse(ThreadedDatabase):
dialect = Dialect()
Expand Down
18 changes: 14 additions & 4 deletions data_diff/sqeleton/databases/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def match_path(self, dsn):


class Connect:
"""Provides methods for connecting to a supported database using a URL or connection dict."""

def __init__(self, database_by_scheme: Dict[str, Database]):
self.database_by_scheme = database_by_scheme
self.match_uri_path = {
Expand Down Expand Up @@ -172,9 +174,11 @@ def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Databa
kw = {k: v for k, v in kw.items() if v is not None}

if issubclass(cls, ThreadedDatabase):
return cls(thread_count=thread_count, **kw)
db = cls(thread_count=thread_count, **kw)
else:
db = cls(**kw)

return cls(**kw)
return self._connection_created(db)

def connect_with_dict(self, d, thread_count):
d = dict(d)
Expand All @@ -186,9 +190,15 @@ def connect_with_dict(self, d, thread_count):

cls = matcher.database_cls
if issubclass(cls, ThreadedDatabase):
return cls(thread_count=thread_count, **d)
db = cls(thread_count=thread_count, **d)
else:
db = cls(**d)

return self._connection_created(db)

return cls(**d)
def _connection_created(self, db):
"Nop function to be overridden by subclasses."
return db

def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
"""Connect to a database using the given database configuration.
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to weird precision issues
return max(super()._convert_db_precision_to_digits(p) - 2, 0)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the session time zone to UTC."""
    return "SET TIME ZONE 'UTC'"


class Databricks(ThreadedDatabase):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the TimeZone setting to UTC (issued with SET GLOBAL)."""
    return "SET GLOBAL TimeZone='UTC'"


class DuckDB(Database):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def type_repr(self, t) -> str:
def explain_as_text(self, query: str) -> str:
return f"EXPLAIN FORMAT=TREE {query}"

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the session time_zone variable to the +00:00 offset."""
    return "SET @@session.time_zone='+00:00'"


class MySQL(ThreadedDatabase):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)

def set_timezone_to_utc(self) -> str:
    """Return the ALTER SESSION statement that sets TIME_ZONE to UTC."""
    return "ALTER SESSION SET TIME_ZONE = 'UTC'"


class Oracle(ThreadedDatabase):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to weird precision issues in PostgreSQL
return super()._convert_db_precision_to_digits(p) - 2

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the session time zone to UTC."""
    return "SET TIME ZONE 'UTC'"


class PostgreSQL(ThreadedDatabase):
dialect = PostgresqlDialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the session time zone to the +00:00 offset."""
    return "SET TIME ZONE '+00:00'"


class Presto(Database):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def to_string(self, s: str):
def table_information(self) -> Compilable:
return table("INFORMATION_SCHEMA", "TABLES")

def set_timezone_to_utc(self) -> str:
    """Return the ALTER SESSION statement that sets TIMEZONE to UTC."""
    return "ALTER SESSION SET TIMEZONE = 'UTC'"


class Snowflake(Database):
dialect = Dialect()
Expand Down
3 changes: 3 additions & 0 deletions data_diff/sqeleton/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ def parse_type(

return super().parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

def set_timezone_to_utc(self) -> str:
    """Return the SQL statement that sets the session time zone to UTC."""
    return "SET TIME ZONE TO 'UTC'"


class Vertica(ThreadedDatabase):
dialect = Dialect()
Expand Down
4 changes: 4 additions & 0 deletions data_diff/sqeleton/queries/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,8 @@ def insert_rows_in_batches(db, table: TablePath, rows, *, columns=None, batch_si
db.query(table.insert_rows(batch, columns=columns))


def current_timestamp():
    """Build and return a new CurrentTimestamp query node."""
    node = CurrentTimestamp()
    return node


commit = Commit()
22 changes: 14 additions & 8 deletions data_diff/sqeleton/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..abcs import Compilable
from ..schema import Schema

from .compiler import Compiler, cv_params
from .compiler import Compiler, cv_params, Root
from .base import SKIP, CompileError, DbPath, args_as_tuple


Expand Down Expand Up @@ -47,7 +47,7 @@ def cast_to(self, to):


@dataclass
class Code(ExprNode):
class Code(ExprNode, Root):
code: str

def compile(self, c: Compiler) -> str:
Expand Down Expand Up @@ -434,7 +434,7 @@ def compile(self, c: Compiler) -> str:


@dataclass
class Join(ExprNode, ITable):
class Join(ExprNode, ITable, Root):
source_tables: Sequence[ITable]
op: str = None
on_exprs: Sequence[Expr] = None
Expand Down Expand Up @@ -499,7 +499,7 @@ def compile(self, parent_c: Compiler) -> str:


@dataclass
class GroupBy(ExprNode, ITable):
class GroupBy(ExprNode, ITable, Root):
table: ITable
keys: Sequence[Expr] = None # IKey?
values: Sequence[Expr] = None
Expand Down Expand Up @@ -540,7 +540,7 @@ def compile(self, c: Compiler) -> str:


@dataclass
class TableOp(ExprNode, ITable):
class TableOp(ExprNode, ITable, Root):
op: str
table1: ITable
table2: ITable
Expand Down Expand Up @@ -571,7 +571,7 @@ def compile(self, parent_c: Compiler) -> str:


@dataclass
class Select(ExprNode, ITable):
class Select(ExprNode, ITable, Root):
table: Expr = None
columns: Sequence[Expr] = None
where_exprs: Sequence[Expr] = None
Expand Down Expand Up @@ -771,7 +771,7 @@ def compile_for_insert(self, c: Compiler):


@dataclass
class Explain(ExprNode):
class Explain(ExprNode, Root):
select: Select

type = str
Expand All @@ -780,10 +780,16 @@ def compile(self, c: Compiler) -> str:
return c.dialect.explain_as_text(c.compile(self.select))


class CurrentTimestamp(ExprNode):
    """Expression node that compiles to the dialect's current-timestamp SQL."""

    # presumably the Python type this expression evaluates to — confirm against ExprNode
    type = datetime

    def compile(self, c: Compiler) -> str:
        """Return the current-timestamp SQL supplied by the compiler's dialect."""
        dialect = c.dialect
        return dialect.current_timestamp()

# DDL


class Statement(Compilable):
class Statement(Compilable, Root):
type = None


Expand Down
7 changes: 7 additions & 0 deletions data_diff/sqeleton/queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

cv_params = contextvars.ContextVar("params")

class Root:
    """Marker base class: nodes inheriting from Root can be used as root statements in SQL.

    For example, a SELECT can stand alone as a statement, while RANDOM() cannot.
    """


@dataclass
class Compiler(AbstractCompiler):
Expand All @@ -33,6 +36,10 @@ def compile(self, elem, params=None) -> str:
if params:
cv_params.set(params)

if self.root and isinstance(elem, Compilable) and not isinstance(elem, Root):
from .ast_classes import Select
elem = Select(columns=[elem])

res = self._compile(elem)
if self.root and self._subqueries:
subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items())
Expand Down
Loading

0 comments on commit 01abf3a

Please sign in to comment.