From d6c768775ada5bb38b92ae000f67e37a73e89171 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 21 Sep 2020 11:56:33 +0300 Subject: [PATCH 1/7] feat: avoid adding dummy WHERE clause into UPDATE and DELETE queires --- google/cloud/spanner_dbapi/__init__.py | 3 + google/cloud/spanner_dbapi/parse_utils.py | 204 +++++++++++----------- tests/spanner_dbapi/test_parse_utils.py | 41 ++--- 3 files changed, 123 insertions(+), 125 deletions(-) diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py index 098b0bd786..ba3671b12b 100644 --- a/google/cloud/spanner_dbapi/__init__.py +++ b/google/cloud/spanner_dbapi/__init__.py @@ -71,6 +71,9 @@ def connect( If none are specified, the client will attempt to ascertain the credentials from the environment. + :type user_agent: :class:`str` + :param user_agent: (Optional) The user agent to be used with API requests. + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` :returns: Connection object associated with the given Cloud Spanner resource. diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 68d10867e5..8b5bf27578 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -8,6 +8,7 @@ import datetime import decimal +import os import re from functools import reduce @@ -53,6 +54,105 @@ RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) +SPANNER_RESERVED_KEYWORDS = { + "ALL", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASSERT_ROWS_MODIFIED", + "AT", + "BETWEEN", + "BY", + "CASE", + "CAST", + "COLLATE", + "CONTAINS", + "CREATE", + "CROSS", + "CUBE", + "CURRENT", + "DEFAULT", + "DEFINE", + "DESC", + "DISTINCT", + "DROP", + "ELSE", + "END", + "ENUM", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXISTS", + "EXTRACT", + "FALSE", + "FETCH", + "FOLLOWING", + "FOR", + "FROM", + "FULL", + "GROUP", + "GROUPING", + "GROUPS", + "HASH", + "HAVING", + "IF", + "IGNORE", + "IN", + "INNER", + "INTERSECT", + "INTERVAL", + "INTO", + "IS", + "JOIN", + "LATERAL", + "LEFT", + "LIKE", + "LIMIT", + "LOOKUP", + "MERGE", + "NATURAL", + "NEW", + "NO", + "NOT", + "NULL", + "NULLS", + "OF", + "ON", + "OR", + "ORDER", + "OUTER", + "OVER", + "PARTITION", + "PRECEDING", + "PROTO", + "RANGE", + "RECURSIVE", + "RESPECT", + "RIGHT", + "ROLLUP", + "ROWS", + "SELECT", + "SET", + "SOME", + "STRUCT", + "TABLESAMPLE", + "THEN", + "TO", + "TREAT", + "TRUE", + "UNBOUNDED", + "UNION", + "UNNEST", + "USING", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHIN", +} + def classify_stmt(query): """Determine SQL query type. @@ -402,107 +502,13 @@ def ensure_where_clause(sql): for token in sqlparse.parse(sql)[0] ): return sql - return sql + " WHERE 1=1" + if os.environ.get("RUNNING_SPANNER_BACKEND_TESTS") == "1": + return sql + " WHERE 1=1" -SPANNER_RESERVED_KEYWORDS = { - "ALL", - "AND", - "ANY", - "ARRAY", - "AS", - "ASC", - "ASSERT_ROWS_MODIFIED", - "AT", - "BETWEEN", - "BY", - "CASE", - "CAST", - "COLLATE", - "CONTAINS", - "CREATE", - "CROSS", - "CUBE", - "CURRENT", - "DEFAULT", - "DEFINE", - "DESC", - "DISTINCT", - "DROP", - "ELSE", - "END", - "ENUM", - "ESCAPE", - "EXCEPT", - "EXCLUDE", - "EXISTS", - "EXTRACT", - "FALSE", - "FETCH", - "FOLLOWING", - "FOR", - "FROM", - "FULL", - "GROUP", - "GROUPING", - "GROUPS", - "HASH", - "HAVING", - "IF", - "IGNORE", - "IN", - "INNER", - "INTERSECT", - "INTERVAL", - "INTO", - "IS", - "JOIN", - "LATERAL", - "LEFT", - "LIKE", - "LIMIT", - "LOOKUP", - "MERGE", - "NATURAL", - "NEW", - "NO", - "NOT", - "NULL", - "NULLS", - "OF", - "ON", - "OR", - "ORDER", - "OUTER", - "OVER", - "PARTITION", - "PRECEDING", - "PROTO", - "RANGE", - "RECURSIVE", - "RESPECT", - "RIGHT", - "ROLLUP", - "ROWS", - "SELECT", - "SET", - "SOME", - "STRUCT", - "TABLESAMPLE", - "THEN", - "TO", - "TREAT", - "TRUE", - "UNBOUNDED", - "UNION", - "UNNEST", - "USING", - "WHEN", - "WHERE", - "WINDOW", - "WITH", - "WITHIN", -} + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" + ) def escape_name(name): diff --git a/tests/spanner_dbapi/test_parse_utils.py b/tests/spanner_dbapi/test_parse_utils.py index 615b6a9069..1ced6f774c 100644 --- a/tests/spanner_dbapi/test_parse_utils.py +++ b/tests/spanner_dbapi/test_parse_utils.py @@ -445,34 +445,23 @@ def test_get_param_types(self): self.assertEqual(got_param_types, want_param_types) def test_ensure_where_clause(self): - cases = [ - ( - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - ), - ( - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", - ), - ( - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), - ] + cases = ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ) + err_cases = ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "DELETE * FROM TABLE", + ) + for sql in cases: + with self.subTest(sql=sql): + ensure_where_clause(sql) - for sql, want in cases: + for sql in err_cases: with self.subTest(sql=sql): - got = ensure_where_clause(sql) - self.assertEqual(got, want) + with self.assertRaises(ProgrammingError): + ensure_where_clause(sql) def test_escape_name(self): cases = [ From 285dda5aaccdbfe066a0f1ba21f1330db6290968 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 6 Nov 2020 12:39:32 +0300 Subject: [PATCH 2/7] add the WHERE clause ensuring code into Django specific code --- django_spanner/compiler.py | 4 +++- django_spanner/utils.py | 15 +++++++++++++++ google/cloud/spanner_dbapi/parse_utils.py | 8 ++++++++ tests/unit/spanner_dbapi/test_cursor.py | 2 +- tests/unit/spanner_dbapi/test_parse_utils.py | 3 +++ 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/django_spanner/compiler.py b/django_spanner/compiler.py index 202ef103dc..34fa886caf 100644 --- a/django_spanner/compiler.py +++ b/django_spanner/compiler.py @@ -12,7 +12,7 @@ SQLInsertCompiler as BaseSQLInsertCompiler, SQLUpdateCompiler as BaseSQLUpdateCompiler, ) -from django.db.utils import DatabaseError +from django.db.utils import DatabaseError, ensure_where_clause class SQLCompiler(BaseSQLCompiler): @@ -90,6 +90,8 @@ def get_combinator_sql(self, combinator, all): params = [] for part in args_parts: params.extend(part) + + result = ensure_where_clause(result) return result, params diff --git a/django_spanner/utils.py b/django_spanner/utils.py index 1136c33a87..28227a7ef1 100644 --- a/django_spanner/utils.py +++ b/django_spanner/utils.py @@ -1,4 +1,5 @@ import django +import sqlparse from django.core.exceptions import ImproperlyConfigured from django.utils.version import get_version_tuple @@ -18,3 +19,17 @@ def check_django_compatability(): A=django.VERSION[0], B=django.VERSION[1], C=__version__ ) ) + + +def ensure_where_clause(sql): + """ + Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. + Add a dummy WHERE clause if necessary. + """ + if any( + isinstance(token, sqlparse.sql.Where) + for token in sqlparse.parse(sql)[0] + ): + return sql + + return sql + " WHERE 1=1" diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 7951333dcc..a1367122f0 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -617,6 +617,9 @@ def ensure_where_clause(sql): """ Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. Add a dummy WHERE clause if necessary. + + :type sql: `str` + :param sql: SQL code to check. """ if any( isinstance(token, sqlparse.sql.Where) @@ -627,6 +630,11 @@ def ensure_where_clause(sql): if os.environ.get("RUNNING_SPANNER_BACKEND_TESTS") == "1": return sql + " WHERE 1=1" + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" + ) + + def escape_name(name): """ Apply backticks to the name that either contain '-' or diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..a73265e932 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -94,7 +94,7 @@ def test_do_execute_update(self): def run_helper(ret_value): transaction.execute_update.return_value = ret_value res = cursor._do_execute_update( - transaction=transaction, sql="sql", params=None, + transaction=transaction, sql="SELECT * WHERE true", params={}, ) return res diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index be2d26630f..d68e4118fd 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -407,6 +407,9 @@ def test_get_param_types_none(self): self.assertEqual(get_param_types(None), None) def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause + cases = ( "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", From 164669f9b276608bced5340acff3df44c61d76de Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 6 Nov 2020 12:46:20 +0300 Subject: [PATCH 3/7] erase testing WHERE clause --- google/cloud/spanner_dbapi/parse_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index a1367122f0..7cd9adb8c7 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -627,9 +627,6 @@ def ensure_where_clause(sql): ): return sql - if os.environ.get("RUNNING_SPANNER_BACKEND_TESTS") == "1": - return sql + " WHERE 1=1" - raise ProgrammingError( "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" ) From 7d7ec73ee6ed30a2556ec9534badf6bdebe2523b Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 6 Nov 2020 12:55:18 +0300 Subject: [PATCH 4/7] fix imports --- google/cloud/spanner_dbapi/parse_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 7cd9adb8c7..b8cda9d531 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -8,7 +8,6 @@ import datetime import decimal -import os import re from functools import reduce From fe89045eeece3851c5cb1dbaf2de89effb46ad31 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 16 Nov 2020 11:21:23 +0300 Subject: [PATCH 5/7] rename func --- django_spanner/compiler.py | 4 ++-- django_spanner/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/django_spanner/compiler.py b/django_spanner/compiler.py index 34fa886caf..106686d445 100644 --- a/django_spanner/compiler.py +++ b/django_spanner/compiler.py @@ -12,7 +12,7 @@ SQLInsertCompiler as BaseSQLInsertCompiler, SQLUpdateCompiler as BaseSQLUpdateCompiler, ) -from django.db.utils import DatabaseError, ensure_where_clause +from django.db.utils import DatabaseError, add_dummy_where class SQLCompiler(BaseSQLCompiler): @@ -91,7 +91,7 @@ def get_combinator_sql(self, combinator, all): for part in args_parts: params.extend(part) - result = ensure_where_clause(result) + result = add_dummy_where(result) return result, params diff --git a/django_spanner/utils.py b/django_spanner/utils.py index 28227a7ef1..bf9c3940d8 100644 --- a/django_spanner/utils.py +++ b/django_spanner/utils.py @@ -21,7 +21,7 @@ def check_django_compatability(): ) -def ensure_where_clause(sql): +def add_dummy_where(sql): """ Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. Add a dummy WHERE clause if necessary. From 3d045713a097294a353cd2844b5dc2f8afbef320 Mon Sep 17 00:00:00 2001 From: Chris Kleinknecht Date: Tue, 17 Nov 2020 17:10:39 -0800 Subject: [PATCH 6/7] Change ensure sig --- django_spanner/utils.py | 6 ++++++ google/cloud/spanner_dbapi/cursor.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/django_spanner/utils.py b/django_spanner/utils.py index bf9c3940d8..444afe053d 100644 --- a/django_spanner/utils.py +++ b/django_spanner/utils.py @@ -1,3 +1,9 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + import django import sqlparse from django.core.exceptions import ImproperlyConfigured diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6997752a42..e41f0f381a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -117,7 +117,7 @@ def close(self): self._is_closed = True def _do_execute_update(self, transaction, sql, params, param_types=None): - sql = parse_utils.ensure_where_clause(sql) + parse_utils.ensure_where_clause(sql) sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) result = transaction.execute_update( From dc060c15d4ee43b5b43bb9816599442cf2f1a7e7 Mon Sep 17 00:00:00 2001 From: Chris Kleinknecht Date: Tue, 17 Nov 2020 17:10:58 -0800 Subject: [PATCH 7/7] Add license boilerplate --- google/cloud/spanner_dbapi/parse_utils.py | 116 ++-------------------- 1 file changed, 7 insertions(+), 109 deletions(-) diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index b8cda9d531..0e69dbc0ca 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -164,105 +164,6 @@ RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) -SPANNER_RESERVED_KEYWORDS = { - "ALL", - "AND", - "ANY", - "ARRAY", - "AS", - "ASC", - "ASSERT_ROWS_MODIFIED", - "AT", - "BETWEEN", - "BY", - "CASE", - "CAST", - "COLLATE", - "CONTAINS", - "CREATE", - "CROSS", - "CUBE", - "CURRENT", - "DEFAULT", - "DEFINE", - "DESC", - "DISTINCT", - "DROP", - "ELSE", - "END", - "ENUM", - "ESCAPE", - "EXCEPT", - "EXCLUDE", - "EXISTS", - "EXTRACT", - "FALSE", - "FETCH", - "FOLLOWING", - "FOR", - "FROM", - "FULL", - "GROUP", - "GROUPING", - "GROUPS", - "HASH", - "HAVING", - "IF", - "IGNORE", - "IN", - "INNER", - "INTERSECT", - "INTERVAL", - "INTO", - "IS", - "JOIN", - "LATERAL", - "LEFT", - "LIKE", - "LIMIT", - "LOOKUP", - "MERGE", - "NATURAL", - "NEW", - "NO", - "NOT", - "NULL", - "NULLS", - "OF", - "ON", - "OR", - "ORDER", - "OUTER", - "OVER", - "PARTITION", - "PRECEDING", - "PROTO", - "RANGE", - "RECURSIVE", - "RESPECT", - "RIGHT", - "ROLLUP", - "ROWS", - "SELECT", - "SET", - "SOME", - "STRUCT", - "TABLESAMPLE", - "THEN", - "TO", - "TREAT", - "TRUE", - "UNBOUNDED", - "UNION", - "UNNEST", - "USING", - "WHEN", - "WHERE", - "WINDOW", - "WITH", - "WITHIN", -} - def classify_stmt(query): """Determine SQL query type. @@ -614,21 +515,18 @@ def get_param_types(params): def ensure_where_clause(sql): """ - Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Add a dummy WHERE clause if necessary. + Raise unless `sql` includes a WHERE clause. - :type sql: `str` - :param sql: SQL code to check. + :type sql: str + :param sql: SQL statement to check. """ - if any( + if not any( isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0] ): - return sql - - raise ProgrammingError( - "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" - ) + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause in UPDATE and DELETE statements" + ) def escape_name(name):