Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #260 from datafold/oct20_queries
Browse files Browse the repository at this point in the history
Various small fixes and refactors
  • Loading branch information
erezsh authored Oct 20, 2022
2 parents 1fc52c2 + 601d2bb commit b237fd8
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 76 deletions.
1 change: 1 addition & 0 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def _constant_value(self, v):
elif isinstance(v, str):
return f"'{v}'"
elif isinstance(v, datetime):
# TODO use self.timestamp_value
return f"timestamp '{v}'"
elif isinstance(v, UUID):
return f"'{v}'"
Expand Down
9 changes: 5 additions & 4 deletions data_diff/databases/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ class AbstractDialect(ABC):

@abstractmethod
def quote(self, s: str):
"Quote SQL name (implementation specific)"
"Quote SQL name"
...

@abstractmethod
def concat(self, l: List[str]) -> str:
"Provide SQL for concatenating a bunch of column into a string"
"Provide SQL for concatenating a bunch of columns into a string"
...

@abstractmethod
Expand All @@ -162,12 +162,13 @@ def is_distinct_from(self, a: str, b: str) -> str:

@abstractmethod
def to_string(self, s: str) -> str:
# TODO rewrite using cast_to(x, str)
"Provide SQL for casting a column to string"
...

@abstractmethod
def random(self) -> str:
"Provide SQL for generating a random number"
"Provide SQL for generating a random number betweein 0..1"

@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
Expand All @@ -176,7 +177,7 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None

@abstractmethod
def explain_as_text(self, query: str) -> str:
"Provide SQL for explaining a query, returned in as table(varchar)"
"Provide SQL for explaining a query, returned as table(varchar)"
...

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,6 @@ def type_repr(self, t) -> str:
return super().type_repr(t)

def constant_values(self, rows) -> str:
return " UNION ALL ".join("SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows)
return " UNION ALL ".join(
"SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows
)
4 changes: 2 additions & 2 deletions data_diff/queries/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def or_(*exprs: Expr):
exprs = args_as_tuple(exprs)
if len(exprs) == 1:
return exprs[0]
return BinOp("OR", exprs)
return BinBoolOp("OR", exprs)


def and_(*exprs: Expr):
exprs = args_as_tuple(exprs)
if len(exprs) == 1:
return exprs[0]
return BinOp("AND", exprs)
return BinBoolOp("AND", exprs)


def sum_(expr: Expr):
Expand Down
67 changes: 37 additions & 30 deletions data_diff/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def cast_to(self, to):
Expr = Union[ExprNode, str, bool, int, datetime, ArithString, None]


def get_type(e: Expr) -> type:
def _expr_type(e: Expr) -> type:
if isinstance(e, ExprNode):
return e.type
return type(e)
Expand All @@ -48,7 +48,7 @@ def compile(self, c: Compiler) -> str:

@property
def type(self):
return get_type(self.expr)
return _expr_type(self.expr)


def _drop_skips(exprs):
Expand Down Expand Up @@ -156,6 +156,8 @@ class Count(ExprNode):
expr: Expr = "*"
distinct: bool = False

type = int

def compile(self, c: Compiler) -> str:
expr = c.compile(self.expr)
if self.distinct:
Expand All @@ -174,12 +176,6 @@ def compile(self, c: Compiler) -> str:
return f"{self.name}({args})"


def _expr_type(e: Expr):
if isinstance(e, ExprNode):
return e.type
return type(e)


@dataclass
class CaseWhen(ExprNode):
cases: Sequence[Tuple[Expr, Expr]]
Expand Down Expand Up @@ -226,6 +222,9 @@ def __le__(self, other):
def __or__(self, other):
return BinBoolOp("OR", [self, other])

def __and__(self, other):
return BinBoolOp("AND", [self, other])

def is_distinct_from(self, other):
return IsDistinctFrom(self, other)

Expand Down Expand Up @@ -254,7 +253,7 @@ def compile(self, c: Compiler) -> str:

@property
def type(self):
types = {get_type(i) for i in self.args}
types = {_expr_type(i) for i in self.args}
if len(types) > 1:
raise TypeError(f"Expected all args to have the same type, got {types}")
(t,) = types
Expand Down Expand Up @@ -298,6 +297,16 @@ class TablePath(ExprNode, ITable):
path: DbPath
schema: Optional[Schema] = field(default=None, repr=False)

@property
def source_table(self):
return self

def compile(self, c: Compiler) -> str:
path = self.path # c.database._normalize_table_path(self.name)
return ".".join(map(c.quote, path))

# Statement shorthands

def create(self, source_table: ITable = None, *, if_not_exists=False):
if source_table is None and not self.schema:
raise ValueError("Either schema or source table needed to create table")
Expand All @@ -323,14 +332,6 @@ def insert_expr(self, expr: Expr):
expr = expr.select()
return InsertToTable(self, expr)

@property
def source_table(self):
return self

def compile(self, c: Compiler) -> str:
path = self.path # c.database._normalize_table_path(self.name)
return ".".join(map(c.quote, path))


@dataclass
class TableAlias(ExprNode, ITable):
Expand Down Expand Up @@ -386,7 +387,7 @@ def compile(self, parent_c: Compiler) -> str:
tables = [
t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in self.source_tables
]
c = parent_c.add_table_context(*tables).replace(in_join=True, in_select=False)
c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
op = " JOIN " if self.op is None else f" {self.op} JOIN "
joined = op.join(c.compile(t) for t in tables)

Expand All @@ -408,7 +409,7 @@ def compile(self, parent_c: Compiler) -> str:

class GroupBy(ITable):
def having(self):
pass
raise NotImplementedError()


@dataclass
Expand Down Expand Up @@ -546,26 +547,26 @@ class _ResolveColumn(ExprNode, LazyOps):
resolve_name: str
resolved: Expr = None

def resolve(self, expr):
assert self.resolved is None
def resolve(self, expr: Expr):
if self.resolved is not None:
raise RuntimeError("Already resolved!")
self.resolved = expr

def compile(self, c: Compiler) -> str:
def _get_resolved(self) -> Expr:
if self.resolved is None:
raise RuntimeError(f"Column not resolved: {self.resolve_name}")
return self.resolved.compile(c)
return self.resolved

def compile(self, c: Compiler) -> str:
return self._get_resolved().compile(c)

@property
def type(self):
if self.resolved is None:
raise RuntimeError(f"Column not resolved: {self.resolve_name}")
return self.resolved.type
return self._get_resolved().type

@property
def name(self):
if self.resolved is None:
raise RuntimeError(f"Column not resolved: {self.name}")
return self.resolved.name
return self._get_resolved().name


class This:
Expand All @@ -583,6 +584,8 @@ class In(ExprNode):
expr: Expr
list: Sequence[Expr]

type = bool

def compile(self, c: Compiler):
elems = ", ".join(map(c.compile, self.list))
return f"({c.compile(self.expr)} IN ({elems}))"
Expand All @@ -599,6 +602,8 @@ def compile(self, c: Compiler) -> str:

@dataclass
class Random(ExprNode):
type = float

def compile(self, c: Compiler) -> str:
return c.database.random()

Expand All @@ -618,6 +623,8 @@ def compile_for_insert(self, c: Compiler):
class Explain(ExprNode):
select: Select

type = str

def compile(self, c: Compiler) -> str:
return c.database.explain_as_text(c.compile(self.select))

Expand All @@ -640,7 +647,7 @@ def compile(self, c: Compiler) -> str:
if self.source_table:
return f"CREATE TABLE {ne}{c.compile(self.path)} AS {c.compile(self.source_table)}"

schema = ", ".join(f"{c.database.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
schema = ", ".join(f"{c.quote(k)} {c.database.type_repr(v)}" for k, v in self.path.schema.items())
return f"CREATE TABLE {ne}{c.compile(self.path)}({schema})"


Expand Down
10 changes: 5 additions & 5 deletions data_diff/queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class Compiler:

_counter: List = [0]

def quote(self, s: str):
return self.database.quote(s)

def compile(self, elem) -> str:
res = self._compile(elem)
if self.root and self._subqueries:
Expand Down Expand Up @@ -57,8 +54,11 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
self._counter[0] += 1
return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")

def add_table_context(self, *tables: Sequence):
return self.replace(_table_context=self._table_context + list(tables))
def add_table_context(self, *tables: Sequence, **kw):
return self.replace(_table_context=self._table_context + list(tables), **kw)

def quote(self, s: str):
return self.database.quote(s)


class Compilable(ABC):
Expand Down
49 changes: 15 additions & 34 deletions data_diff/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import re
import math
from typing import Iterable, Tuple, Union, Any, Sequence, Dict
from typing import TypeVar, Generic
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict
from typing import TypeVar
from abc import ABC, abstractmethod
from urllib.parse import urlparse
from uuid import UUID
Expand Down Expand Up @@ -204,58 +204,39 @@ def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
V = TypeVar("V")


class CaseAwareMapping(ABC, Generic[V]):
class CaseAwareMapping(MutableMapping[str, V]):
@abstractmethod
def get_key(self, key: str) -> str:
...

@abstractmethod
def __getitem__(self, key: str) -> V:
...

@abstractmethod
def __setitem__(self, key: str, value: V):
...

@abstractmethod
def __contains__(self, key: str) -> bool:
...

def __repr__(self):
return repr(dict(self.items()))

@abstractmethod
def items(self) -> Iterable[Tuple[str, V]]:
...


class CaseInsensitiveDict(CaseAwareMapping):
def __init__(self, initial):
self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}

def get_key(self, key: str) -> str:
return self._dict[key.lower()][0]

def __getitem__(self, key: str) -> V:
return self._dict[key.lower()][1]

def __iter__(self) -> Iterator[V]:
return iter(self._dict)

def __len__(self) -> int:
return len(self._dict)

def __setitem__(self, key: str, value):
k = key.lower()
if k in self._dict:
key = self._dict[k][0]
self._dict[k] = key, value

def __contains__(self, key):
return key.lower() in self._dict

def keys(self) -> Iterable[str]:
return self._dict.keys()
def __delitem__(self, key: str):
del self._dict[key.lower()]

def items(self) -> Iterable[Tuple[str, V]]:
return ((k, v[1]) for k, v in self._dict.items())
def get_key(self, key: str) -> str:
return self._dict[key.lower()][0]

def __len__(self):
return len(self._dict)
def __repr__(self) -> str:
return repr(dict(self.items()))


class CaseSensitiveDict(dict, CaseAwareMapping):
Expand Down

0 comments on commit b237fd8

Please sign in to comment.