From c5d19b591c2b0915e76e5a459bd87c2b9b8461d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Fri, 2 Apr 2021 15:00:42 +0700 Subject: [PATCH] feat: support sqlalchemy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Đặng Minh Dũng --- integration_tests/__init__.py | 11 + setup.py | 28 +- tests/__init__.py | 0 tests/integration/test_dbapi_integration.py | 2 +- tests/sqlalchemy/__init__.py | 11 + tests/sqlalchemy/conftest.py | 46 +++ tests/sqlalchemy/test_datatype_parse.py | 183 ++++++++++++ tests/sqlalchemy/test_datatype_split.py | 95 +++++++ tests/sqlalchemy/test_dialect.py | 63 +++++ trino/__init__.py | 9 - trino/auth.py | 56 +++- trino/dbapi.py | 40 ++- trino/sqlalchemy/__init__.py | 14 + trino/sqlalchemy/compiler.py | 143 ++++++++++ trino/sqlalchemy/datatype.py | 200 +++++++++++++ trino/sqlalchemy/dialect.py | 298 ++++++++++++++++++++ trino/sqlalchemy/error.py | 24 ++ trino/transaction.py | 28 +- 18 files changed, 1196 insertions(+), 55 deletions(-) create mode 100644 integration_tests/__init__.py create mode 100644 tests/__init__.py create mode 100644 tests/sqlalchemy/__init__.py create mode 100644 tests/sqlalchemy/conftest.py create mode 100644 tests/sqlalchemy/test_datatype_parse.py create mode 100644 tests/sqlalchemy/test_datatype_split.py create mode 100644 tests/sqlalchemy/test_dialect.py create mode 100644 trino/sqlalchemy/__init__.py create mode 100644 trino/sqlalchemy/compiler.py create mode 100644 trino/sqlalchemy/datatype.py create mode 100644 trino/sqlalchemy/dialect.py create mode 100644 trino/sqlalchemy/error.py diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py new file mode 100644 index 00000000..4d9a9249 --- /dev/null +++ b/integration_tests/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/setup.py b/setup.py index db35d01d..25a0492c 100755 --- a/setup.py +++ b/setup.py @@ -14,24 +14,23 @@ import ast import re -from setuptools import setup import textwrap +from setuptools import setup _version_re = re.compile(r"__version__\s+=\s+(.*)") - with open("trino/__init__.py", "rb") as f: trino_version = _version_re.search(f.read().decode("utf-8")) assert trino_version is not None version = str(ast.literal_eval(trino_version.group(1))) - kerberos_require = ["requests_kerberos"] +sqlalchemy_require = ["sqlalchemy~=1.3"] -all_require = kerberos_require + [] +all_require = kerberos_require + sqlalchemy_require -tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click"] +tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click", "assertpy"] setup( name="trino", @@ -44,19 +43,17 @@ description="Client for the Trino distributed SQL Engine", long_description=textwrap.dedent( """ - Client for Trino (https://trino.io), a distributed SQL engine for - interactive and batch big data processing. Provides a low-level client and - a DBAPI 2.0 implementation. - """ + Client for Trino (https://trino.io), a distributed SQL engine for + interactive and batch big data processing. Provides a low-level client and + a DBAPI 2.0 implementation. + """ ), license="Apache 2.0", classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", - "Operating System :: MacOS :: MacOS X", - "Operating System :: POSIX", - "Operating System :: Microsoft :: Windows", + "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", @@ -66,6 +63,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Database", "Topic :: Database :: Front-Ends", ], python_requires='>=3.6', @@ -73,6 +71,12 @@ extras_require={ "all": all_require, "kerberos": kerberos_require, + "sqlalchemy": sqlalchemy_require, "tests": tests_require, }, + entry_points={ + "sqlalchemy.dialects": [ + "trino = trino.sqlalchemy.dialect:TrinoDialect", + ] + }, ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 365c0043..d16f8a53 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -15,7 +15,7 @@ import pytest import pytz -import trino +import trino.dbapi from trino.exceptions import TrinoQueryError from trino.transaction import IsolationLevel diff --git a/tests/sqlalchemy/__init__.py b/tests/sqlalchemy/__init__.py new file mode 100644 index 00000000..4d9a9249 --- /dev/null +++ b/tests/sqlalchemy/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..0644f741 --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,46 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from assertpy import add_extension, assert_that +from sqlalchemy.sql.sqltypes import ARRAY + +from trino.sqlalchemy.datatype import MAP, ROW, SQLType + + +def assert_sqltype(this: SQLType, that: SQLType): + if isinstance(this, type): + this = this() + if isinstance(that, type): + that = that() + assert_that(type(this)).is_same_as(type(that)) + if isinstance(this, ARRAY): + assert_sqltype(this.item_type, that.item_type) + if this.dimensions is None or this.dimensions == 1: + # ARRAY(dimensions=None) == ARRAY(dimensions=1) + assert_that(that.dimensions).is_in(None, 1) + else: + assert_that(this.dimensions).is_equal_to(this.dimensions) + elif isinstance(this, MAP): + assert_sqltype(this.key_type, that.key_type) + assert_sqltype(this.value_type, that.value_type) + elif isinstance(this, ROW): + assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types)) + for (this_attr, that_attr) in zip(this.attr_types, that.attr_types): + assert_that(this_attr[0]).is_equal_to(that_attr[0]) + assert_sqltype(this_attr[1], that_attr[1]) + else: + assert_that(str(this)).is_equal_to(str(that)) + + +@add_extension +def is_sqltype(self, that): + this = self.val + assert_sqltype(this, that) diff --git a/tests/sqlalchemy/test_datatype_parse.py b/tests/sqlalchemy/test_datatype_parse.py new file mode 100644 index 00000000..58f1cb88 --- /dev/null +++ b/tests/sqlalchemy/test_datatype_parse.py @@ -0,0 +1,183 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from assertpy import assert_that +from sqlalchemy.sql.sqltypes import ( + CHAR, VARCHAR, + ARRAY, + INTEGER, DECIMAL, + DATE, TIME, TIMESTAMP +) +from sqlalchemy.sql.type_api import TypeEngine + +from trino.sqlalchemy import datatype +from trino.sqlalchemy.datatype import MAP, ROW + + +@pytest.mark.parametrize( + 'type_str, sql_type', + datatype._type_map.items(), + ids=datatype._type_map.keys() +) +def test_parse_simple_type(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + if not isinstance(actual_type, type): + actual_type = type(actual_type) + assert_that(actual_type).is_equal_to(sql_type) + + +parse_cases_testcases = { + 'char(10)': CHAR(10), + 'Char(10)': CHAR(10), + 'char': CHAR(), + 'cHaR': CHAR(), + 'VARCHAR(10)': VARCHAR(10), + 'varCHAR(10)': VARCHAR(10), + 'VARchar(10)': VARCHAR(10), + 'VARCHAR': VARCHAR(), + 'VaRchAr': VARCHAR(), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_cases_testcases.items(), + ids=parse_cases_testcases.keys() +) +def test_parse_cases(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_type_options_testcases = { + 'CHAR(10)': CHAR(10), + 'VARCHAR(10)': VARCHAR(10), + 'DECIMAL(20)': DECIMAL(20), + 'DECIMAL(20, 3)': DECIMAL(20, 3), + # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107) +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_type_options_testcases.items(), + ids=parse_type_options_testcases.keys() +) +def test_parse_type_options(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_array_testcases = { + 'array(integer)': ARRAY(INTEGER()), + 'array(varchar(10))': ARRAY(VARCHAR(10)), + 'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)), + 'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2), + 'array(map(char, integer))': ARRAY(MAP(CHAR(), INTEGER())), + 'array(row(a integer, b varchar))': ARRAY(ROW([("a", INTEGER()), ("b", VARCHAR())])), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_array_testcases.items(), + ids=parse_array_testcases.keys() +) +def test_parse_array(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_map_testcases = { + 'map(char, integer)': MAP(CHAR(), INTEGER()), + 'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)), + 'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)), + 'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))), + 'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))), + 'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_map_testcases.items(), + ids=parse_map_testcases.keys() +) +def test_parse_map(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_row_testcases = { + 'row(a integer, b varchar)': + ROW(attr_types=[ + ("a", INTEGER()), + ("b", VARCHAR()), + ]), + 'row(a varchar(20), b decimal(20,3))': + ROW(attr_types=[ + ("a", VARCHAR(20)), + ("b", DECIMAL(20, 3)), + ]), + 'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))': + ROW(attr_types=[ + ("x", ARRAY(VARCHAR(10))), + ("y", ARRAY(VARCHAR(10), dimensions=2)), + ("z", DECIMAL(20, 3)), + ]), + 'row(min timestamp(6) with time zone, max timestamp(6) with time zone)': + ROW(attr_types=[ + ("min", TIMESTAMP(timezone=True)), + ("max", TIMESTAMP(timezone=True)), + ]), + 'row("first name" varchar, "last name" varchar)': + ROW(attr_types=[ + ("first name", VARCHAR()), + ("last name", VARCHAR()), + ]), + 'row("foo,bar" varchar, "foo(bar)" varchar, "foo\\"bar" varchar)': + ROW(attr_types=[ + (r'foo,bar', VARCHAR()), + (r'foo(bar)', VARCHAR()), + (r'foo"bar', VARCHAR()), + ]), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_row_testcases.items(), + ids=parse_row_testcases.keys() +) +def test_parse_row(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_datetime_testcases = { + # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107) + 'date': DATE(), + 'time': TIME(), + 'time with time zone': TIME(timezone=True), + 'timestamp': TIMESTAMP(), + 'timestamp with time zone': TIMESTAMP(timezone=True), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_datetime_testcases.items(), + ids=parse_datetime_testcases.keys() +) +def test_parse_datetime(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) diff --git a/tests/sqlalchemy/test_datatype_split.py b/tests/sqlalchemy/test_datatype_split.py new file mode 100644 index 00000000..f6049038 --- /dev/null +++ b/tests/sqlalchemy/test_datatype_split.py @@ -0,0 +1,95 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import pytest +from assertpy import assert_that + +from trino.sqlalchemy import datatype + +split_string_testcases = { + '10': ['10'], + '10,3': ['10', '3'], + '"a,b",c': ['"a,b"', 'c'], + '"a,b","c,d"': ['"a,b"', '"c,d"'], + r'"a,\"b\",c",d': [r'"a,\"b\",c"', 'd'], + r'"foo(bar,\"baz\")",quiz': [r'"foo(bar,\"baz\")"', 'quiz'], + 'varchar': ['varchar'], + 'varchar,int': ['varchar', 'int'], + 'varchar,int,float': ['varchar', 'int', 'float'], + 'array(varchar)': ['array(varchar)'], + 'array(varchar),int': ['array(varchar)', 'int'], + 'array(varchar(20))': ['array(varchar(20))'], + 'array(varchar(20)),int': ['array(varchar(20))', 'int'], + 'array(varchar(20)),array(varchar(20))': ['array(varchar(20))', 'array(varchar(20))'], + 'map(varchar, integer),int': ['map(varchar, integer)', 'int'], + 'map(varchar(20), integer),int': ['map(varchar(20), integer)', 'int'], + 'map(varchar(20), varchar(20)),int': ['map(varchar(20), varchar(20))', 'int'], + 'map(varchar(20), varchar(20)),array(varchar)': ['map(varchar(20), varchar(20))', 'array(varchar)'], + 'row(first_name varchar(20), last_name varchar(20)),int': + ['row(first_name varchar(20), last_name varchar(20))', 'int'], + 'row("first name" varchar(20), "last name" varchar(20)),int': + ['row("first name" varchar(20), "last name" varchar(20))', 'int'], +} + + +@pytest.mark.parametrize( + 'input_string, output_strings', + split_string_testcases.items(), + ids=split_string_testcases.keys() +) +def test_split_string(input_string: str, output_strings: List[str]): + actual = list(datatype.aware_split(input_string)) + assert_that(actual).is_equal_to(output_strings) + + +split_delimiter_testcases = [ + ('first,second', ',', ['first', 'second']), + ('first second', ' ', ['first', 'second']), + ('first|second', '|', ['first', 'second']), + ('first,second third', ',', ['first', 'second third']), + ('first,second third', ' ', ['first,second', 'third']), +] + + +@pytest.mark.parametrize( + 'input_string, delimiter, output_strings', + split_delimiter_testcases, +) +def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]): + actual = list(datatype.aware_split(input_string, delimiter=delimiter)) + assert_that(actual).is_equal_to(output_strings) + + +split_maxsplit_testcases = [ + ('one,two,three', -1, ['one', 'two', 'three']), + ('one,two,three', 0, ['one,two,three']), + ('one,two,three', 1, ['one', 'two,three']), + ('one,two,three', 2, ['one', 'two', 'three']), + ('one,two,three', 3, ['one', 'two', 'three']), + ('one,two,three', 10, ['one', 'two', 'three']), + + (',one,two,three', 0, [',one,two,three']), + (',one,two,three', 1, ['', 'one,two,three']), + + ('one,two,three,', 2, ['one', 'two', 'three,']), + ('one,two,three,', 3, ['one', 'two', 'three', '']), +] + + +@pytest.mark.parametrize( + 'input_string, maxsplit, output_strings', + split_maxsplit_testcases, +) +def test_split_maxsplit(input_string: str, maxsplit: int, output_strings: List[str]): + actual = list(datatype.aware_split(input_string, maxsplit=maxsplit)) + assert_that(actual).is_equal_to(output_strings) diff --git a/tests/sqlalchemy/test_dialect.py b/tests/sqlalchemy/test_dialect.py new file mode 100644 index 00000000..fa192c6d --- /dev/null +++ b/tests/sqlalchemy/test_dialect.py @@ -0,0 +1,63 @@ +from typing import List, Any, Dict +from unittest import mock + +import pytest +from assertpy import assert_that +from sqlalchemy.engine import make_url +from sqlalchemy.engine.url import URL + +from trino.auth import BasicAuthentication +from trino.dbapi import Connection +from trino.sqlalchemy.dialect import TrinoDialect +from trino.transaction import IsolationLevel + + +class TestTrinoDialect: + def setup(self): + self.dialect = TrinoDialect() + + # TODO: Test more authentication methods and URL params (https://github.com/trinodb/trino-python-client/issues/106) + @pytest.mark.parametrize( + 'url, expected_args, expected_kwargs', + [ + (make_url('trino://user@localhost'), + list(), dict(host='localhost', catalog='system', user='user')), + + (make_url('trino://user@localhost:8080'), + list(), dict(host='localhost', port=8080, catalog='system', user='user')), + + (make_url('trino://user:pass@localhost:8080'), + list(), dict(host='localhost', port=8080, catalog='system', user='user', + auth=BasicAuthentication('user', 'pass'), http_scheme='https')), + ], + ) + def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]): + actual_args, actual_kwargs = self.dialect.create_connect_args(url) + + assert_that(actual_args).is_equal_to(expected_args) + assert_that(actual_kwargs).is_equal_to(expected_kwargs) + + def test_create_connect_args_missing_user_when_specify_password(self): + url = make_url('trino://:pass@localhost') + assert_that(self.dialect.create_connect_args).raises(ValueError) \ + .when_called_with(url) \ + .is_equal_to('Username is required when specify password in connection URL') + + def test_create_connect_args_wrong_db_format(self): + url = make_url('trino://abc@localhost/catalog/schema/foobar') + assert_that(self.dialect.create_connect_args).raises(ValueError) \ + .when_called_with(url) \ + .is_equal_to('Unexpected database format catalog/schema/foobar') + + def test_get_default_isolation_level(self): + isolation_level = self.dialect.get_default_isolation_level(mock.Mock()) + assert_that(isolation_level).is_equal_to('AUTOCOMMIT') + + def test_isolation_level(self): + dbapi_conn = Connection(host="localhost") + + self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE") + assert_that(dbapi_conn._isolation_level).is_equal_to(IsolationLevel.SERIALIZABLE) + + isolation_level = self.dialect.get_isolation_level(dbapi_conn) + assert_that(isolation_level).is_equal_to("SERIALIZABLE") diff --git a/trino/__init__.py b/trino/__init__.py index 6c3d5877..5317d4ae 100644 --- a/trino/__init__.py +++ b/trino/__init__.py @@ -10,13 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import auth -from . import dbapi -from . import client -from . import constants -from . import exceptions -from . import logging - -__all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging'] - __version__ = "0.306.0" diff --git a/trino/auth.py b/trino/auth.py index 2156b229..9bdf03a4 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -36,16 +36,16 @@ def get_exceptions(self): class KerberosAuthentication(Authentication): def __init__( - self, - config: Optional[str] = None, - service_name: str = None, - mutual_authentication: bool = False, - force_preemptive: bool = False, - hostname_override: Optional[str] = None, - sanitize_mutual_error_response: bool = True, - principal: Optional[str] = None, - delegate: bool = False, - ca_bundle: Optional[str] = None, + self, + config: Optional[str] = None, + service_name: str = None, + mutual_authentication: bool = False, + force_preemptive: bool = False, + hostname_override: Optional[str] = None, + sanitize_mutual_error_response: bool = True, + principal: Optional[str] = None, + delegate: bool = False, + ca_bundle: Optional[str] = None, ) -> None: self._config = config self._service_name = service_name @@ -87,6 +87,19 @@ def get_exceptions(self): except ImportError: raise RuntimeError("unable to import requests_kerberos") + def __eq__(self, other): + if not isinstance(other, KerberosAuthentication): + return False + return (self._config == other._config + and self._service_name == other._service_name + and self._mutual_authentication == other._mutual_authentication + and self._force_preemptive == other._force_preemptive + and self._hostname_override == other._hostname_override + and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response + and self._principal == other._principal + and self._delegate == other._delegate + and self._ca_bundle == other._ca_bundle) + class BasicAuthentication(Authentication): def __init__(self, username, password): @@ -105,11 +118,17 @@ def set_http_session(self, http_session): def get_exceptions(self): return () + def __eq__(self, other): + if not isinstance(other, BasicAuthentication): + return False + return self._username == other._username and self._password == other._password + class _BearerAuth(AuthBase): """ Custom implementation of Authentication class for bearer token """ + def __init__(self, token): self.token = token @@ -130,6 +149,11 @@ def set_http_session(self, http_session): def get_exceptions(self): return () + def __eq__(self, other): + if not isinstance(other, JWTAuthentication): + return False + return self.token == other.token + def handle_redirect_auth_url(auth_url): print("Open the following URL in browser for the external authentication:") @@ -140,6 +164,7 @@ class _OAuth2TokenBearer(AuthBase): """ Custom implementation of Trino Oauth2 based authorization to get the token """ + MAX_OAUTH_ATTEMPTS = 5 class _AuthStep(Enum): @@ -160,16 +185,16 @@ def __call__(self, r): self._thread_local.auth_step = self._AuthStep.GET_REDIRECT_SERVER if self._thread_local.token: - r.headers['Authorization'] = "Bearer " + self._thread_local.token + r.headers['Authorization'] = f"Bearer {self._thread_local.token}" r.register_hook('response', self.__authenticate) return r def __authenticate(self, r, **kwargs): - if (self._thread_local.auth_step == self._AuthStep.GET_REDIRECT_SERVER): + if self._thread_local.auth_step == self._AuthStep.GET_REDIRECT_SERVER: self.__process_get_redirect_server(r) - elif (self._thread_local.auth_step == self._AuthStep.GET_TOKEN): + elif self._thread_local.auth_step == self._AuthStep.GET_TOKEN: self.__process_get_token(r) return r @@ -249,3 +274,8 @@ def set_http_session(self, http_session): def get_exceptions(self): return () + + def __eq__(self, other): + if not isinstance(other, OAuth2Authentication): + return False + return self._redirect_auth_url == other._redirect_auth_url diff --git a/trino/dbapi.py b/trino/dbapi.py index 60b00de7..9eeac880 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -29,10 +29,44 @@ import trino.exceptions import trino.client import trino.logging -from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION - +from trino.transaction import ( + Transaction, + IsolationLevel, + NO_TRANSACTION +) +from trino.exceptions import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) -__all__ = ["connect", "Connection", "Cursor"] +__all__ = [ + # https://www.python.org/dev/peps/pep-0249/#globals + "apilevel", + "threadsafety", + "paramstyle", + "connect", + "Connection", + "Cursor", + # https://www.python.org/dev/peps/pep-0249/#exceptions + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", +] apilevel = "2.0" diff --git a/trino/sqlalchemy/__init__.py b/trino/sqlalchemy/__init__.py new file mode 100644 index 00000000..000d3e08 --- /dev/null +++ b/trino/sqlalchemy/__init__.py @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from sqlalchemy.dialects import registry + +registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect") diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py new file mode 100644 index 00000000..08361b37 --- /dev/null +++ b/trino/sqlalchemy/compiler.py @@ -0,0 +1,143 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from sqlalchemy.sql import compiler + +# https://trino.io/docs/current/language/reserved.html +RESERVED_WORDS = { + "alter", + "and", + "as", + "between", + "by", + "case", + "cast", + "constraint", + "create", + "cross", + "cube", + "current_catalog", + "current_date", + "current_path", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "deallocate", + "delete", + "describe", + "distinct", + "drop", + "else", + "end", + "escape", + "except", + "execute", + "exists", + "extract", + "false", + "for", + "from", + "full", + "group", + "grouping", + "having", + "in", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "left", + "like", + "localtime", + "localtimestamp", + "natural", + "normalize", + "not", + "null", + "on", + "or", + "order", + "outer", + "prepare", + "recursive", + "right", + "rollup", + "select", + "skip", + "table", + "then", + "true", + "uescape", + "union", + "unnest", + "using", + "values", + "when", + "where", + "with", +} + + +class TrinoSQLCompiler(compiler.SQLCompiler): + pass + + +class TrinoDDLCompiler(compiler.DDLCompiler): + pass + + +class TrinoTypeCompiler(compiler.GenericTypeCompiler): + def visit_FLOAT(self, type_, **kw): + precision = type_.precision or 32 + if 0 <= precision <= 32: + return self.visit_REAL(type_, **kw) + elif 32 < precision <= 64: + return self.visit_DOUBLE(type_, **kw) + else: + raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}") + + def visit_DOUBLE(self, type_, **kw): + return "DOUBLE" + + def visit_NUMERIC(self, type_, **kw): + return self.visit_DECIMAL(type_, **kw) + + def visit_NCHAR(self, type_, **kw): + return self.visit_CHAR(type_, **kw) + + def visit_NVARCHAR(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_TEXT(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_BINARY(self, type_, **kw): + return self.visit_VARBINARY(type_, **kw) + + def visit_CLOB(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_NCLOB(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_BLOB(self, type_, **kw): + return self.visit_VARBINARY(type_, **kw) + + def visit_DATETIME(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + +class TrinoIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py new file mode 100644 index 00000000..04ee192b --- /dev/null +++ b/trino/sqlalchemy/datatype.py @@ -0,0 +1,200 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import Iterator, List, Optional, Tuple, Type, Union + +from sqlalchemy import util +from sqlalchemy.sql import sqltypes +from sqlalchemy.sql.type_api import TypeEngine + +SQLType = Union[TypeEngine, Type[TypeEngine]] + + +class DOUBLE(sqltypes.Float): + __visit_name__ = "DOUBLE" + + +class MAP(TypeEngine): + __visit_name__ = "MAP" + + def __init__(self, key_type: SQLType, value_type: SQLType): + if isinstance(key_type, type): + key_type = key_type() + self.key_type: TypeEngine = key_type + + if isinstance(value_type, type): + value_type = value_type() + self.value_type: TypeEngine = value_type + + @property + def python_type(self): + return dict + + +class ROW(TypeEngine): + __visit_name__ = "ROW" + + def __init__(self, attr_types: List[Tuple[Optional[str], SQLType]]): + self.attr_types: List[Tuple[Optional[str], SQLType]] = [] + for attr_name, attr_type in attr_types: + if isinstance(attr_type, type): + attr_type = attr_type() + self.attr_types.append((attr_name, attr_type)) + + @property + def python_type(self): + return list + + +# https://trino.io/docs/current/language/types.html +_type_map = { + # === Boolean === + 'boolean': sqltypes.BOOLEAN, + + # === Integer === + 'tinyint': sqltypes.SMALLINT, + 'smallint': sqltypes.SMALLINT, + 'int': sqltypes.INTEGER, + 'integer': sqltypes.INTEGER, + 'bigint': sqltypes.BIGINT, + + # === Floating-point === + 'real': sqltypes.REAL, + 'double': DOUBLE, + + # === Fixed-precision === + 'decimal': sqltypes.DECIMAL, + + # === String === + 'varchar': sqltypes.VARCHAR, + 'char': sqltypes.CHAR, + 'varbinary': sqltypes.VARBINARY, + 'json': sqltypes.JSON, + + # === Date and time === + 'date': sqltypes.DATE, + 'time': sqltypes.TIME, + 'timestamp': sqltypes.TIMESTAMP, + + # 'interval year to month': + # 'interval day to second': + # + # === Structural === + # 'array': ARRAY, + # 'map': MAP + # 'row': ROW + # + # === Mixed === + # 'ipaddress': IPADDRESS + # 'uuid': UUID, + # 'hyperloglog': HYPERLOGLOG, + # 'p4hyperloglog': P4HYPERLOGLOG, + # 'qdigest': QDIGEST, + # 'tdigest': TDIGEST, +} + + +def unquote(string: str, quote: str = '"', escape: str = '\\') -> str: + """ + If string starts and ends with a quote, unquote it + """ + if string.startswith(quote) and string.endswith(quote): + string = string[1:-1] + string = string.replace(f"{escape}{quote}", quote) \ + .replace(f"{escape}{escape}", escape) + return string + + +def aware_split(string: str, delimiter: str = ',', maxsplit: int = -1, + quote: str = '"', escaped_quote: str = r'\"', + open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]: + """ + A split function that is aware of quotes and brackets/parentheses. + + :param string: string to split + :param delimiter: string defining where to split, usually a comma or space + :param maxsplit: Maximum number of splits to do. -1 (default) means no limit. + :param quote: string, either a single or a double quote + :param escaped_quote: string representing an escaped quote + :param open_bracket: string, either [, {, < or ( + :param close_bracket: string, either ], }, > or ) + """ + parens = 0 + quotes = False + i = 0 + if maxsplit < -1: + raise ValueError(f"maxsplit must be >= -1, got {maxsplit}") + elif maxsplit == 0: + yield string + return + for j, character in enumerate(string): + complete = parens == 0 and not quotes + if complete and character == delimiter: + if maxsplit != -1: + maxsplit -= 1 + yield string[i:j] + i = j + len(delimiter) + if maxsplit == 0: + break + elif character == open_bracket: + parens += 1 + elif character == close_bracket: + parens -= 1 + elif character == quote: + if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: + quotes = False + elif not quotes: + quotes = True + yield string[i:] + + +def parse_sqltype(type_str: str) -> TypeEngine: + type_str = type_str.strip().lower() + match = re.match(r'^(?P\w+)\s*(?:\((?P.*)\))?', type_str) + if not match: + util.warn(f"Could not parse type name '{type_str}'") + return sqltypes.NULLTYPE + type_name = match.group("type") + type_opts = match.group("options") + + if type_name == "array": + item_type = parse_sqltype(type_opts) + if isinstance(item_type, sqltypes.ARRAY): + # Multi-dimensions array is normalized in SQLAlchemy, e.g: + # `ARRAY(ARRAY(INT))` in Trino SQL will become `ARRAY(INT(), dimensions=2)` in SQLAlchemy + dimensions = (item_type.dimensions or 1) + 1 + return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions) + return sqltypes.ARRAY(item_type) + elif type_name == "map": + key_type_str, value_type_str = aware_split(type_opts) + key_type = parse_sqltype(key_type_str) + value_type = parse_sqltype(value_type_str) + return MAP(key_type, value_type) + elif type_name == "row": + attr_types: List[Tuple[Optional[str], SQLType]] = [] + for attr in aware_split(type_opts): + attr_name, attr_type_str = aware_split(attr.strip(), delimiter=' ', maxsplit=1) + attr_name = unquote(attr_name) + attr_type = parse_sqltype(attr_type_str) + attr_types.append((attr_name, attr_type)) + return ROW(attr_types) + + if type_name not in _type_map: + util.warn(f"Did not recognize type '{type_name}'") + return sqltypes.NULLTYPE + type_class = _type_map[type_name] + type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else [] + if type_name in ("time", "timestamp"): + type_kwargs = dict(timezone=type_str.endswith("with time zone")) + # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107) + return type_class(**type_kwargs) + return type_class(*type_args) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py new file mode 100644 index 00000000..f0732c14 --- /dev/null +++ b/trino/sqlalchemy/dialect.py @@ -0,0 +1,298 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from textwrap import dedent +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple + +from sqlalchemy import exc, sql +from sqlalchemy.engine.base import Connection +from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext +from sqlalchemy.engine.url import URL + +from trino import dbapi as trino_dbapi +from trino.auth import BasicAuthentication +from trino.dbapi import Cursor +from trino.sqlalchemy import compiler, datatype, error + + +class TrinoDialect(DefaultDialect): + name = 'trino' + driver = 'rest' + + statement_compiler = compiler.TrinoSQLCompiler + ddl_compiler = compiler.TrinoDDLCompiler + type_compiler = compiler.TrinoTypeCompiler + preparer = compiler.TrinoIdentifierPreparer + + # Data Type + supports_native_enum = False + supports_native_boolean = True + supports_native_decimal = True + + # Column options + supports_sequences = False + supports_comments = True + inline_comments = True + supports_default_values = False + + # DDL + supports_alter = True + + # DML + # Queries of the form `INSERT () VALUES ()` is not supported by Trino. + supports_empty_insert = False + supports_multivalues_insert = True + postfetch_lastrowid = False + + @classmethod + def dbapi(cls): + """ + ref: https://www.python.org/dev/peps/pep-0249/#module-interface + """ + return trino_dbapi + + def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any]]: + args: Sequence[Any] = list() + kwargs: Dict[str, Any] = dict(host=url.host) + + if url.port: + kwargs['port'] = url.port + + db_parts = (url.database or 'system').split('/') + if len(db_parts) == 1: + kwargs['catalog'] = db_parts[0] + elif len(db_parts) == 2: + kwargs['catalog'] = db_parts[0] + kwargs['schema'] = db_parts[1] + else: + raise ValueError(f'Unexpected database format {url.database}') + + if url.username: + kwargs['user'] = url.username + + if url.password: + if not url.username: + raise ValueError('Username is required when specify password in connection URL') + kwargs['http_scheme'] = 'https' + kwargs['auth'] = BasicAuthentication(url.username, url.password) + + return args, kwargs + + def get_columns(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + if not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError(f'schema={schema}, table={table_name}') + return self._get_columns(connection, table_name, schema, **kw) + + def _get_columns(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + schema = schema or self._get_default_schema_name(connection) + query = dedent(''' + SELECT + "column_name", + "data_type", + "column_default", + UPPER("is_nullable") AS "is_nullable" + FROM "information_schema"."columns" + WHERE "table_schema" = :schema + AND "table_name" = :table + ORDER BY "ordinal_position" ASC + ''').strip() + res = connection.execute(sql.text(query), schema=schema, table=table_name) + columns = [] + for record in res: + column = dict( + name=record.column_name, + type=datatype.parse_sqltype(record.data_type), + nullable=record.is_nullable == 'YES', + default=record.column_default, + ) + columns.append(column) + return columns + + def get_pk_constraint(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + """Trino has no support for primary keys. Returns a dummy""" + return dict(name=None, constrained_columns=[]) + + def get_primary_keys(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[str]: + pk = self.get_pk_constraint(connection, table_name, schema) + return pk.get('constrained_columns') # type: ignore + + def get_foreign_keys(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for foreign keys. Returns an empty list.""" + return [] + + def get_schema_names(self, connection: Connection, **kw) -> List[str]: + query = dedent(''' + SELECT "schema_name" + FROM "information_schema"."schemata" + ''').strip() + res = connection.execute(sql.text(query)) + return [row.schema_name for row in res] + + def get_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + schema = schema or self._get_default_schema_name(connection) + if schema is None: + raise exc.NoSuchTableError('schema is required') + query = dedent(''' + SELECT "table_name" + FROM "information_schema"."tables" + WHERE "table_schema" = :schema + ''').strip() + res = connection.execute(sql.text(query), schema=schema) + return [row.table_name for row in res] + + def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for temporary tables. Returns an empty list.""" + return [] + + def get_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + schema = schema or self._get_default_schema_name(connection) + if schema is None: + raise exc.NoSuchTableError('schema is required') + query = dedent(''' + SELECT "table_name" + FROM "information_schema"."views" + WHERE "table_schema" = :schema + ''').strip() + res = connection.execute(sql.text(query), schema=schema) + return [row.table_name for row in res] + + def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for temporary views. Returns an empty list.""" + return [] + + def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str: + schema = schema or self._get_default_schema_name(connection) + if schema is None: + raise exc.NoSuchTableError('schema is required') + query = dedent(''' + SELECT "view_definition" + FROM "information_schema"."views" + WHERE "table_schema" = :schema + AND "table_name" = :view + ''').strip() + res = connection.execute(sql.text(query), schema=schema, view=view_name) + return res.scalar() + + def get_indexes(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + if not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError(f'schema={schema}, table={table_name}') + + partitioned_columns = self._get_columns(connection, f'{table_name}$partitions', schema, **kw) + partition_index = dict( + name='partition', + column_names=[col['name'] for col in partitioned_columns], + unique=False + ) + return [partition_index, ] + + def get_sequence_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for sequences. Returns an empty list.""" + return [] + + def get_unique_constraints(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for unique constraints. Returns an empty list.""" + return [] + + def get_check_constraints(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for check constraints. Returns an empty list.""" + return [] + + def get_table_comment(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + properties_table = self._get_full_table(f'{table_name}$properties', schema) + query = f'SELECT "comment" FROM {properties_table}' + try: + res = connection.execute(sql.text(query)) + return dict(text=res.scalar()) + except error.TrinoQueryError as e: + if e.error_name in ( + error.NOT_FOUND, + error.COLUMN_NOT_FOUND, + error.TABLE_NOT_FOUND, + ): + return dict(text=None) + raise + + def has_schema(self, connection: Connection, schema: str) -> bool: + query = dedent(''' + SELECT "schema_name" + FROM "information_schema"."schemata" + WHERE "schema_name" = :schema + ''').strip() + res = connection.execute(sql.text(query), schema=schema) + return res.first() is not None + + def has_table(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> bool: + schema = schema or self._get_default_schema_name(connection) + if schema is None: + return False + query = dedent(''' + SELECT "table_name" + FROM "information_schema"."tables" + WHERE "table_schema" = :schema + AND "table_name" = :table + ''').strip() + res = connection.execute(sql.text(query), schema=schema, table=table_name) + return res.first() is not None + + def has_sequence(self, connection: Connection, + sequence_name: str, schema: str = None, **kw) -> bool: + """Trino has no support for sequence. Returns False indicate that given sequence does not exists.""" + return False + + def _get_server_version_info(self, connection: Connection) -> Tuple[int, ...]: + query = 'SELECT version()' + res = connection.execute(sql.text(query)) + version = res.scalar() + return tuple([version]) + + def _get_default_schema_name(self, connection: Connection) -> Optional[str]: + dbapi_connection: trino_dbapi.Connection = connection.connection + return dbapi_connection.schema + + def do_execute(self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], + context: DefaultExecutionContext = None): + cursor.execute(statement, parameters) + if context and context.should_autocommit: + # SQL statement only submitted to Trino server when cursor.fetch*() is called. + # For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description + # to force submit statement immediately. + cursor.description # noqa + + def do_rollback(self, dbapi_connection: trino_dbapi.Connection): + if dbapi_connection.transaction is not None: + dbapi_connection.rollback() + + def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level: str) -> None: + dbapi_conn._isolation_level = trino_dbapi.IsolationLevel[level] + + def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: + return dbapi_conn.isolation_level.name + + def get_default_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: + return trino_dbapi.IsolationLevel.AUTOCOMMIT.name + + def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str: + table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name + if schema: + schema_part = self.identifier_preparer.quote_identifier(schema) if quote else schema + return f'{schema_part}.{table_part}' + + return table_part diff --git a/trino/sqlalchemy/error.py b/trino/sqlalchemy/error.py new file mode 100644 index 00000000..3079d6eb --- /dev/null +++ b/trino/sqlalchemy/error.py @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from trino.exceptions import TrinoQueryError # noqa + +# ref: https://github.com/trinodb/trino/blob/master/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +NOT_FOUND = 'NOT_FOUND' +COLUMN_NOT_FOUND = 'COLUMN_NOT_FOUND' +TABLE_NOT_FOUND = 'TABLE_NOT_FOUND' +SCHEMA_NOT_FOUND = 'SCHEMA_NOT_FOUND' +CATALOG_NOT_FOUND = 'CATALOG_NOT_FOUND' + +MISSING_TABLE = 'MISSING_TABLE' +MISSING_COLUMN_NAME = 'MISSING_COLUMN_NAME' +MISSING_SCHEMA_NAME = 'MISSING_SCHEMA_NAME' +MISSING_CATALOG_NAME = 'MISSING_CATALOG_NAME' diff --git a/trino/transaction.py b/trino/transaction.py index b6f3b2f4..f1ebe6b0 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -9,24 +9,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum, unique from typing import Iterable -from trino import constants import trino.client import trino.exceptions import trino.logging - +from trino import constants logger = trino.logging.get_logger(__name__) - NO_TRANSACTION = "NONE" START_TRANSACTION = "START TRANSACTION" ROLLBACK = "ROLLBACK" COMMIT = "COMMIT" -class IsolationLevel(object): +@unique +class IsolationLevel(Enum): AUTOCOMMIT = 0 READ_UNCOMMITTED = 1 READ_COMMITTED = 2 @@ -35,16 +35,16 @@ class IsolationLevel(object): @classmethod def levels(cls) -> Iterable[str]: - return {k for k, v in cls.__dict__.items() if not k.startswith("_") and isinstance(v, int)} + return {isolation_level.name for isolation_level in IsolationLevel} @classmethod def values(cls) -> Iterable[int]: - return {getattr(cls, level) for level in cls.levels()} + return {isolation_level.value for isolation_level in IsolationLevel} @classmethod def check(cls, level: int) -> int: if level not in cls.values(): - raise ValueError("invalid isolation level {}".format(level)) + raise ValueError(f"invalid isolation level {level}") return level @@ -60,9 +60,7 @@ def id(self): def begin(self): response = self._request.post(START_TRANSACTION) if not response.ok: - raise trino.exceptions.DatabaseError( - "failed to start transaction: {}".format(response.status_code) - ) + raise trino.exceptions.DatabaseError(f"failed to start transaction: {response.status_code}") transaction_id = response.headers.get(constants.HEADER_STARTED_TRANSACTION) if transaction_id and transaction_id != NO_TRANSACTION: self._id = response.headers[constants.HEADER_STARTED_TRANSACTION] @@ -74,16 +72,14 @@ def begin(self): self._id = response.headers[constants.HEADER_STARTED_TRANSACTION] status = self._request.process(response) self._request.transaction_id = self._id - logger.info("transaction started: " + self._id) + logger.info("transaction started: %s", self._id) def commit(self): query = trino.client.TrinoQuery(self._request, COMMIT) try: list(query.execute()) except Exception as err: - raise trino.exceptions.DatabaseError( - "failed to commit transaction {}: {}".format(self._id, err) - ) + raise trino.exceptions.DatabaseError(f"failed to commit transaction {self._id}") from err self._id = NO_TRANSACTION self._request.transaction_id = self._id @@ -92,8 +88,6 @@ def rollback(self): try: list(query.execute()) except Exception as err: - raise trino.exceptions.DatabaseError( - "failed to rollback transaction {}: {}".format(self._id, err) - ) + raise trino.exceptions.DatabaseError(f"failed to rollback transaction {self._id}") from err self._id = NO_TRANSACTION self._request.transaction_id = self._id