Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: avoid adding dummy WHERE clause into UPDATE and DELETE queires #516

Merged
merged 12 commits into from
Nov 18, 2020
4 changes: 3 additions & 1 deletion django_spanner/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SQLInsertCompiler as BaseSQLInsertCompiler,
SQLUpdateCompiler as BaseSQLUpdateCompiler,
)
from django.db.utils import DatabaseError
from django.db.utils import DatabaseError, add_dummy_where


class SQLCompiler(BaseSQLCompiler):
Expand Down Expand Up @@ -90,6 +90,8 @@ def get_combinator_sql(self, combinator, all):
params = []
for part in args_parts:
params.extend(part)

result = add_dummy_where(result)
return result, params


Expand Down
15 changes: 15 additions & 0 deletions django_spanner/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import django
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
import sqlparse
from django.core.exceptions import ImproperlyConfigured
from django.utils.version import get_version_tuple

Expand All @@ -18,3 +19,17 @@ def check_django_compatability():
A=django.VERSION[0], B=django.VERSION[1], C=__version__
)
)


def add_dummy_where(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
"""
if any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql

return sql + " WHERE 1=1"
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
107 changes: 106 additions & 1 deletion google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,105 @@

RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL)

SPANNER_RESERVED_KEYWORDS = {
"ALL",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASSERT_ROWS_MODIFIED",
"AT",
"BETWEEN",
"BY",
"CASE",
"CAST",
"COLLATE",
"CONTAINS",
"CREATE",
"CROSS",
"CUBE",
"CURRENT",
"DEFAULT",
"DEFINE",
"DESC",
"DISTINCT",
"DROP",
"ELSE",
"END",
"ENUM",
"ESCAPE",
"EXCEPT",
"EXCLUDE",
"EXISTS",
"EXTRACT",
"FALSE",
"FETCH",
"FOLLOWING",
"FOR",
"FROM",
"FULL",
"GROUP",
"GROUPING",
"GROUPS",
"HASH",
"HAVING",
"IF",
"IGNORE",
"IN",
"INNER",
"INTERSECT",
"INTERVAL",
"INTO",
"IS",
"JOIN",
"LATERAL",
"LEFT",
"LIKE",
"LIMIT",
"LOOKUP",
"MERGE",
"NATURAL",
"NEW",
"NO",
"NOT",
"NULL",
"NULLS",
"OF",
"ON",
"OR",
"ORDER",
"OUTER",
"OVER",
"PARTITION",
"PRECEDING",
"PROTO",
"RANGE",
"RECURSIVE",
"RESPECT",
"RIGHT",
"ROLLUP",
"ROWS",
"SELECT",
"SET",
"SOME",
"STRUCT",
"TABLESAMPLE",
"THEN",
"TO",
"TREAT",
"TRUE",
"UNBOUNDED",
"UNION",
"UNNEST",
"USING",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"WITHIN",
}

IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

def classify_stmt(query):
"""Determine SQL query type.
Expand Down Expand Up @@ -517,13 +616,19 @@ def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

:type sql: `str`
:param sql: SQL code to check.
"""
if any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql
return sql + " WHERE 1=1"

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)


def escape_name(name):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_do_execute_update(self):
def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
res = cursor._do_execute_update(
transaction=transaction, sql="sql", params=None,
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return res

Expand Down
42 changes: 16 additions & 26 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,36 +407,26 @@ def test_get_param_types_none(self):
self.assertEqual(get_param_types(None), None)

def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
ensure_where_clause(sql)

def test_escape_name(self):
from google.cloud.spanner_dbapi.parse_utils import escape_name
Expand Down