diff --git a/clickhouse_sqlalchemy/orm/query.py b/clickhouse_sqlalchemy/orm/query.py index c4c082d..dc55f57 100644 --- a/clickhouse_sqlalchemy/orm/query.py +++ b/clickhouse_sqlalchemy/orm/query.py @@ -1,3 +1,5 @@ +from functools import partial + from sqlalchemy import exc from sqlalchemy.sql.base import _generative from sqlalchemy.orm.query import Query as BaseQuery @@ -10,6 +12,20 @@ ) +def _compile_state_factory(orig_compile_state_factory, query, statement, + *args, **kwargs): + rv = orig_compile_state_factory(statement, *args, **kwargs) + new_stmt = rv.statement + new_stmt._with_cube = query._with_cube + new_stmt._with_rollup = query._with_rollup + new_stmt._with_totals = query._with_totals + new_stmt._final_clause = query._final + new_stmt._sample_clause = sample_clause(query._sample) + new_stmt._limit_by_clause = query._limit_by + new_stmt._array_join = query._array_join + return rv + + class Query(BaseQuery): _with_cube = False _with_rollup = False @@ -19,19 +35,13 @@ class Query(BaseQuery): _limit_by = None _array_join = None - def _compile_context(self, *args, **kwargs): - context = super(Query, self)._compile_context(*args, **kwargs) - query = context.query - - query._with_cube = self._with_cube - query._with_rollup = self._with_rollup - query._with_totals = self._with_totals - query._final_clause = self._final - query._sample_clause = sample_clause(self._sample) - query._limit_by_clause = self._limit_by - query._array_join = self._array_join + def _statement_20(self, *args, **kwargs): + statement = super(Query, self)._statement_20(*args, **kwargs) + statement._compile_state_factory = partial( + _compile_state_factory, statement._compile_state_factory, self + ) - return context + return statement @_generative def with_cube(self): diff --git a/tests/testcase.py b/tests/testcase.py index 8e86f63..67627c0 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -36,8 +36,7 @@ def _compile(self, clause, bind=None, literal_binds=False, if bind is None: bind = self.session.bind if isinstance(clause, Query): - context = clause._compile_context() - clause = context.query + clause = clause._statement_20() kw = {} compile_kwargs = {}