Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Queries: Implemented GROUP BY and HAVING
Browse files Browse the repository at this point in the history
  • Loading branch information
erezsh committed Nov 11, 2022
1 parent 27dbe7c commit c9fe4e0
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 39 deletions.
120 changes: 81 additions & 39 deletions data_diff/sqeleton/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,13 @@ def join(self, target):
return Join(self, target)

def group_by(self, *, keys=None, values=None):
# TODO
assert keys or values
raise NotImplementedError()
keys = _drop_skips(keys)
resolve_names(self.source_table, keys)

values = _drop_skips(values)
resolve_names(self.source_table, values)

return GroupBy(self, keys, values)

def with_schema(self):
# TODO
Expand Down Expand Up @@ -166,38 +170,6 @@ def compile(self, c: Compiler) -> str:
return f"count({expr})"


@dataclass
class Func(ExprNode):
name: str
args: Sequence[Expr]

def compile(self, c: Compiler) -> str:
args = ", ".join(c.compile(e) for e in self.args)
return f"{self.name}({args})"


@dataclass
class CaseWhen(ExprNode):
cases: Sequence[Tuple[Expr, Expr]]
else_: Expr = None

def compile(self, c: Compiler) -> str:
assert self.cases
when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases)
else_ = (" ELSE " + c.compile(self.else_)) if self.else_ is not None else ""
return f"CASE {when_thens}{else_} END"

@property
def type(self):
when_types = {_expr_type(w) for _c, w in self.cases}
if self.else_:
when_types |= _expr_type(self.else_)
if len(when_types) > 1:
raise RuntimeError(f"Non-matching types in when: {when_types}")
(t,) = when_types
return t


class LazyOps:
def __add__(self, other):
return BinOp("+", [self, other])
Expand Down Expand Up @@ -235,6 +207,39 @@ def sum(self):
return Func("SUM", [self])


@dataclass
class Func(ExprNode, LazyOps):
name: str
args: Sequence[Expr]

def compile(self, c: Compiler) -> str:
args = ", ".join(c.compile(e) for e in self.args)
return f"{self.name}({args})"


@dataclass
class CaseWhen(ExprNode):
cases: Sequence[Tuple[Expr, Expr]]
else_: Expr = None

def compile(self, c: Compiler) -> str:
assert self.cases
when_thens = " ".join(f"WHEN {c.compile(when)} THEN {c.compile(then)}" for when, then in self.cases)
else_ = (" ELSE " + c.compile(self.else_)) if self.else_ is not None else ""
return f"CASE {when_thens}{else_} END"

@property
def type(self):
when_types = {_expr_type(w) for _c, w in self.cases}
if self.else_:
when_types |= _expr_type(self.else_)
if len(when_types) > 1:
raise RuntimeError(f"Non-matching types in when: {when_types}")
(t,) = when_types
return t



@dataclass(eq=False, order=False)
class IsDistinctFrom(ExprNode, LazyOps):
a: Expr
Expand Down Expand Up @@ -410,9 +415,41 @@ def compile(self, parent_c: Compiler) -> str:
return select


class GroupBy(ITable):
def having(self):
raise NotImplementedError()
@dataclass
class GroupBy(ExprNode, ITable):
table: ITable
keys: Sequence[Expr] = None # IKey?
values: Sequence[Expr] = None
having_exprs: Sequence[Expr] = None

def __post_init__(self):
assert self.keys or self.values

def having(self, *exprs):
exprs = args_as_tuple(exprs)
exprs = _drop_skips(exprs)
if not exprs:
return self

resolve_names(self.table, exprs)
return self.replace(having_exprs=(self.having_exprs or []) + exprs)

def compile(self, c: Compiler) -> str:
keys = [str(i+1) for i in range(len(self.keys))]
columns = (self.keys or []) + (self.values or [])
if isinstance(self.table, Select) and self.table.columns is None and self.table.group_by_exprs is None:
return c.compile(self.table.replace(
columns=columns,
group_by_exprs=keys, # XXX pass Expr instances, not strings (Code)
having_exprs=self.having_exprs
))

keys_str = ", ".join(keys)
columns_str = ", ".join(c.compile(x) for x in columns)
having_str = " HAVING " + " AND ".join(map(c.compile, self.having_exprs)) if self.having_exprs is not None else ''
return f'SELECT {columns_str} FROM {c.replace(in_select=True).compile(self.table)} GROUP BY {keys_str}{having_str}'




@dataclass
Expand Down Expand Up @@ -452,6 +489,7 @@ class Select(ExprNode, ITable):
where_exprs: Sequence[Expr] = None
order_by_exprs: Sequence[Expr] = None
group_by_exprs: Sequence[Expr] = None
having_exprs: Sequence[Expr] = None
limit_expr: int = None
distinct: bool = False

Expand Down Expand Up @@ -482,6 +520,10 @@ def compile(self, parent_c: Compiler) -> str:
if self.group_by_exprs:
select += " GROUP BY " + ", ".join(map(c.compile, self.group_by_exprs))

if self.having_exprs:
assert self.group_by_exprs
select += " HAVING " + " AND ".join(map(c.compile, self.having_exprs))

if self.order_by_exprs:
select += " ORDER BY " + ", ".join(map(c.compile, self.order_by_exprs))

Expand Down Expand Up @@ -555,7 +597,7 @@ def _named_exprs_as_aliases(named_exprs):
def resolve_names(source_table, exprs):
i = 0
for expr in exprs:
# Iterate recursively and update _ResolveColumn with the right expression
# Iterate recursively and update _ResolveColumn instances with the right expression
if isinstance(expr, ExprNode):
for v in expr._dfs_values():
if isinstance(v, _ResolveColumn):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,27 @@ def test_ops(self):

q = c.compile(t.select(this.b.like(this.c)))
self.assertEqual(q, "SELECT (b LIKE c) FROM a")

def test_group_by(self):
c = Compiler(MockDatabase())
t = table("a")

q = c.compile(t.group_by(keys=[this.b], values=[this.c]))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1")

q = c.compile(t.where(this.b > 1).group_by(keys=[this.b], values=[this.c]))
self.assertEqual(q, "SELECT b, c FROM a WHERE (b > 1) GROUP BY 1")

q = c.compile(t.select(this.b).group_by(keys=[this.b], values=[]))
self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp1 GROUP BY 1")

# Having
q = c.compile(t.group_by(keys=[this.b], values=[this.c]).having(this.b > 1))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (b > 1)")

q = c.compile(t.select(this.b).group_by(keys=[this.b], values=[]).having(this.b > 1))
self.assertEqual(q, "SELECT b FROM (SELECT b FROM a) tmp2 GROUP BY 1 HAVING (b > 1)")

# Having sum
q = c.compile(t.group_by(keys=[this.b], values=[this.c]).having(this.b.sum() > 1))
self.assertEqual(q, "SELECT b, c FROM a GROUP BY 1 HAVING (SUM(b) > 1)")

0 comments on commit c9fe4e0

Please sign in to comment.