Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLAlchemy support #81

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
28 changes: 16 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,23 @@

import ast
import re
from setuptools import setup
import textwrap

from setuptools import setup

_version_re = re.compile(r"__version__\s+=\s+(.*)")


with open("trino/__init__.py", "rb") as f:
trino_version = _version_re.search(f.read().decode("utf-8"))
assert trino_version is not None
version = str(ast.literal_eval(trino_version.group(1)))


kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]

all_require = kerberos_require + []
all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click"]
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click", "assertpy"]

setup(
name="trino",
Expand All @@ -44,19 +43,17 @@
description="Client for the Trino distributed SQL Engine",
long_description=textwrap.dedent(
"""
Client for Trino (https://trino.io), a distributed SQL engine for
interactive and batch big data processing. Provides a low-level client and
a DBAPI 2.0 implementation.
"""
Client for Trino (https://trino.io), a distributed SQL engine for
interactive and batch big data processing. Provides a low-level client and
a DBAPI 2.0 implementation.
"""
),
license="Apache 2.0",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX",
"Operating System :: Microsoft :: Windows",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
Expand All @@ -66,13 +63,20 @@
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Topic :: Database",
"Topic :: Database :: Front-Ends",
],
python_requires='>=3.6',
install_requires=["requests"],
extras_require={
"all": all_require,
"kerberos": kerberos_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
},
entry_points={
"sqlalchemy.dialects": [
"trino = trino.sqlalchemy.dialect:TrinoDialect",
]
},
)
Empty file added tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
import pytz

import trino
import trino.dbapi
from trino.exceptions import TrinoQueryError
from trino.transaction import IsolationLevel

Expand Down
11 changes: 11 additions & 0 deletions tests/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
46 changes: 46 additions & 0 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from assertpy import add_extension, assert_that
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType


def assert_sqltype(this: SQLType, that: SQLType):
if isinstance(this, type):
hashhar marked this conversation as resolved.
Show resolved Hide resolved
this = this()
if isinstance(that, type):
that = that()
assert_that(type(this)).is_same_as(type(that))
if isinstance(this, ARRAY):
assert_sqltype(this.item_type, that.item_type)
if this.dimensions is None or this.dimensions == 1:
# ARRAY(dimensions=None) == ARRAY(dimensions=1)
assert_that(that.dimensions).is_in(None, 1)
else:
assert_that(this.dimensions).is_equal_to(this.dimensions)
elif isinstance(this, MAP):
assert_sqltype(this.key_type, that.key_type)
assert_sqltype(this.value_type, that.value_type)
elif isinstance(this, ROW):
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
hashhar marked this conversation as resolved.
Show resolved Hide resolved
assert_that(this_attr[0]).is_equal_to(that_attr[0])
assert_sqltype(this_attr[1], that_attr[1])
else:
assert_that(str(this)).is_equal_to(str(that))
hashhar marked this conversation as resolved.
Show resolved Hide resolved


@add_extension
def is_sqltype(self, that):
this = self.val
assert_sqltype(this, that)
183 changes: 183 additions & 0 deletions tests/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from assertpy import assert_that
from sqlalchemy.sql.sqltypes import (
CHAR, VARCHAR,
ARRAY,
INTEGER, DECIMAL,
DATE, TIME, TIMESTAMP
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW


@pytest.mark.parametrize(
'type_str, sql_type',
datatype._type_map.items(),
ids=datatype._type_map.keys()
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
if not isinstance(actual_type, type):
actual_type = type(actual_type)
assert_that(actual_type).is_equal_to(sql_type)


parse_cases_testcases = {
'char(10)': CHAR(10),
'Char(10)': CHAR(10),
'char': CHAR(),
'cHaR': CHAR(),
'VARCHAR(10)': VARCHAR(10),
'varCHAR(10)': VARCHAR(10),
'VARchar(10)': VARCHAR(10),
dungdm93 marked this conversation as resolved.
Show resolved Hide resolved
'VARCHAR': VARCHAR(),
'VaRchAr': VARCHAR(),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_cases_testcases.items(),
ids=parse_cases_testcases.keys()
)
def test_parse_cases(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_type_options_testcases = {
'CHAR(10)': CHAR(10),
'VARCHAR(10)': VARCHAR(10),
'DECIMAL(20)': DECIMAL(20),
'DECIMAL(20, 3)': DECIMAL(20, 3),
hashhar marked this conversation as resolved.
Show resolved Hide resolved
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_type_options_testcases.items(),
ids=parse_type_options_testcases.keys()
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_array_testcases = {
'array(integer)': ARRAY(INTEGER()),
'array(varchar(10))': ARRAY(VARCHAR(10)),
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
hashhar marked this conversation as resolved.
Show resolved Hide resolved
'array(map(char, integer))': ARRAY(MAP(CHAR(), INTEGER())),
'array(row(a integer, b varchar))': ARRAY(ROW([("a", INTEGER()), ("b", VARCHAR())])),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_array_testcases.items(),
ids=parse_array_testcases.keys()
)
def test_parse_array(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_map_testcases = {
'map(char, integer)': MAP(CHAR(), INTEGER()),
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_map_testcases.items(),
ids=parse_map_testcases.keys()
)
def test_parse_map(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_row_testcases = {
'row(a integer, b varchar)':
ROW(attr_types=[
("a", INTEGER()),
("b", VARCHAR()),
]),
'row(a varchar(20), b decimal(20,3))':
ROW(attr_types=[
("a", VARCHAR(20)),
("b", DECIMAL(20, 3)),
]),
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
dungdm93 marked this conversation as resolved.
Show resolved Hide resolved
ROW(attr_types=[
("x", ARRAY(VARCHAR(10))),
("y", ARRAY(VARCHAR(10), dimensions=2)),
("z", DECIMAL(20, 3)),
]),
'row(min timestamp(6) with time zone, max timestamp(6) with time zone)':
ROW(attr_types=[
("min", TIMESTAMP(timezone=True)),
("max", TIMESTAMP(timezone=True)),
]),
'row("first name" varchar, "last name" varchar)':
ROW(attr_types=[
("first name", VARCHAR()),
("last name", VARCHAR()),
]),
'row("foo,bar" varchar, "foo(bar)" varchar, "foo\\"bar" varchar)':
ROW(attr_types=[
(r'foo,bar', VARCHAR()),
(r'foo(bar)', VARCHAR()),
(r'foo"bar', VARCHAR()),
]),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_row_testcases.items(),
ids=parse_row_testcases.keys()
)
def test_parse_row(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_datetime_testcases = {
hashhar marked this conversation as resolved.
Show resolved Hide resolved
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
'date': DATE(),
'time': TIME(),
'time with time zone': TIME(timezone=True),
'timestamp': TIMESTAMP(),
'timestamp with time zone': TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys()
)
def test_parse_datetime(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)
Loading