Skip to content

Commit

Permalink
Add SQLAlchemy dialect for Trino
Browse files Browse the repository at this point in the history
  • Loading branch information
dungdm93 authored and hashhar committed Dec 2, 2021
1 parent f9e68da commit da1441f
Show file tree
Hide file tree
Showing 12 changed files with 1,109 additions and 2 deletions.
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
assert trino_version is not None
version = str(ast.literal_eval(trino_version.group(1)))


kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]

all_require = kerberos_require + []
all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + [
# httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
Expand Down Expand Up @@ -80,6 +80,12 @@
extras_require={
"all": all_require,
"kerberos": kerberos_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
},
entry_points={
"sqlalchemy.dialects": [
"trino = trino.sqlalchemy.dialect:TrinoDialect",
]
},
)
Empty file added tests/__init__.py
Empty file.
Empty file.
46 changes: 46 additions & 0 deletions tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType


@pytest.fixture(scope="session")
def assert_sqltype():
    """Provide a helper that asserts two SQL types are structurally equal.

    The returned callable accepts either type classes or type instances;
    classes are instantiated with no arguments before comparison.
    """

    def _assert_sqltype(this: SQLType, that: SQLType):
        # Normalize classes to instances so both sides can be inspected.
        this = this() if isinstance(this, type) else this
        that = that() if isinstance(that, type) else that

        # Both sides must be of exactly the same type.
        assert type(this) == type(that)

        if isinstance(this, ARRAY):
            _assert_sqltype(this.item_type, that.item_type)
            if this.dimensions in (None, 1):
                # ARRAY(dimensions=None) == ARRAY(dimensions=1)
                assert that.dimensions is None or that.dimensions == 1
            else:
                assert that.dimensions == this.dimensions
            return

        if isinstance(this, MAP):
            _assert_sqltype(this.key_type, that.key_type)
            _assert_sqltype(this.value_type, that.value_type)
            return

        if isinstance(this, ROW):
            assert len(this.attr_types) == len(that.attr_types)
            for expected_attr, actual_attr in zip(this.attr_types, that.attr_types):
                # Element 0 is compared directly, element 1 recursively as a type.
                assert expected_attr[0] == actual_attr[0]
                _assert_sqltype(expected_attr[1], actual_attr[1])
            return

        # Any other type: compare the string renderings.
        assert str(this) == str(that)

    return _assert_sqltype
192 changes: 192 additions & 0 deletions tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from sqlalchemy.sql.sqltypes import (
CHAR,
VARCHAR,
ARRAY,
INTEGER,
DECIMAL,
DATE,
TIME,
TIMESTAMP,
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW


@pytest.mark.parametrize(
    "type_str, sql_type",
    datatype._type_map.items(),
    ids=datatype._type_map.keys(),
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine, assert_sqltype):
    """Check every entry of ``datatype._type_map`` parses to its mapped type.

    The parsed result is reduced to its class when it is an instance, then
    compared against the mapped value.
    """
    parsed = datatype.parse_sqltype(type_str)
    parsed_cls = parsed if isinstance(parsed, type) else type(parsed)
    assert_sqltype(parsed_cls, sql_type)


# Type strings in assorted letter cases, each mapped to the SQL type that
# parsing is expected to produce.
parse_cases_testcases = {
    # CHAR, with and without a length
    "char(10)": CHAR(10),
    "Char(10)": CHAR(10),
    "char": CHAR(),
    "cHaR": CHAR(),
    # VARCHAR, with and without a length
    "VARCHAR(10)": VARCHAR(10),
    "varCHAR(10)": VARCHAR(10),
    "VARchar(10)": VARCHAR(10),
    "VARCHAR": VARCHAR(),
    "VaRchAr": VARCHAR(),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_cases_testcases.items(),
    ids=parse_cases_testcases.keys(),
)
def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
    """Check that mixed-case type strings parse to the expected SQL type."""
    assert_sqltype(datatype.parse_sqltype(type_str), sql_type)


# Type strings carrying parenthesized arguments, each mapped to the SQL type
# constructed with the same arguments.
parse_type_options_testcases = {
    "CHAR(10)": CHAR(10),
    "VARCHAR(10)": VARCHAR(10),
    "DECIMAL(20)": DECIMAL(20),
    "DECIMAL(20, 3)": DECIMAL(20, 3),
    # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_type_options_testcases.items(),
    ids=parse_type_options_testcases.keys(),
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine, assert_sqltype):
    """Check that type arguments in the string are carried into the parsed type."""
    assert_sqltype(datatype.parse_sqltype(type_str), sql_type)


# ``array(...)`` type strings mapped to the expected ARRAY type. A nested
# array is expected as a single ARRAY with ``dimensions=2``.
parse_array_testcases = {
    # scalar item types
    "array(integer)": ARRAY(INTEGER()),
    "array(varchar(10))": ARRAY(VARCHAR(10)),
    "array(decimal(20,3))": ARRAY(DECIMAL(20, 3)),
    # nested array
    "array(array(varchar(10)))": ARRAY(VARCHAR(10), dimensions=2),
    # composite item types
    "array(map(char, integer))": ARRAY(MAP(CHAR(), INTEGER())),
    "array(row(a integer, b varchar))": ARRAY(
        ROW([("a", INTEGER()), ("b", VARCHAR())])
    ),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_array_testcases.items(),
    ids=parse_array_testcases.keys(),
)
def test_parse_array(type_str: str, sql_type: ARRAY, assert_sqltype):
    """Check that ``array(...)`` type strings parse to the expected ARRAY type."""
    assert_sqltype(datatype.parse_sqltype(type_str), sql_type)


# ``map(key, value)`` type strings mapped to the expected MAP type, covering
# scalar and array value types.
parse_map_testcases = {
    # scalar values
    "map(char, integer)": MAP(CHAR(), INTEGER()),
    "map(varchar(10), varchar(10))": MAP(VARCHAR(10), VARCHAR(10)),
    "map(varchar(10), decimal(20,3))": MAP(VARCHAR(10), DECIMAL(20, 3)),
    # array values
    "map(char, array(varchar(10)))": MAP(CHAR(), ARRAY(VARCHAR(10))),
    "map(varchar(10), array(varchar(10)))": MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
    "map(varchar(10), array(array(varchar(10))))": MAP(
        VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)
    ),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_map_testcases.items(),
    ids=parse_map_testcases.keys(),
)
def test_parse_map(type_str: str, sql_type: MAP, assert_sqltype):
    """Check that ``map(...)`` type strings parse to the expected MAP type.

    ``sql_type`` is annotated as ``MAP`` because every value in
    ``parse_map_testcases`` is a ``MAP`` instance.
    """
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


# ``row(...)`` type strings mapped to the expected ROW type. Covers plain and
# parametrized attribute types, nested arrays, timestamps with time zone, and
# double-quoted attribute names containing spaces, commas, parentheses and
# escaped quotes.
parse_row_testcases = {
    "row(a integer, b varchar)": ROW(
        attr_types=[("a", INTEGER()), ("b", VARCHAR())],
    ),
    "row(a varchar(20), b decimal(20,3))": ROW(
        attr_types=[("a", VARCHAR(20)), ("b", DECIMAL(20, 3))],
    ),
    "row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))": ROW(
        attr_types=[
            ("x", ARRAY(VARCHAR(10))),
            ("y", ARRAY(VARCHAR(10), dimensions=2)),
            ("z", DECIMAL(20, 3)),
        ],
    ),
    "row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
        attr_types=[
            ("min", TIMESTAMP(timezone=True)),
            ("max", TIMESTAMP(timezone=True)),
        ],
    ),
    # Quoted attribute names
    'row("first name" varchar, "last name" varchar)': ROW(
        attr_types=[("first name", VARCHAR()), ("last name", VARCHAR())],
    ),
    'row("foo,bar" varchar, "foo(bar)" varchar, "foo\\"bar" varchar)': ROW(
        attr_types=[
            (r"foo,bar", VARCHAR()),
            (r"foo(bar)", VARCHAR()),
            (r'foo"bar', VARCHAR()),
        ],
    ),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_row_testcases.items(),
    ids=parse_row_testcases.keys(),
)
def test_parse_row(type_str: str, sql_type: ROW, assert_sqltype):
    """Check that ``row(...)`` type strings parse to the expected ROW type.

    ``sql_type`` is annotated as ``ROW`` because every value in
    ``parse_row_testcases`` is a ``ROW`` instance.
    """
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


# Date/time type strings mapped to the expected SQL type; the
# ``with time zone`` variants are expected to set ``timezone=True``.
parse_datetime_testcases = {
    # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
    "date": DATE(),
    "time": TIME(),
    "time with time zone": TIME(timezone=True),
    "timestamp": TIMESTAMP(),
    "timestamp with time zone": TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_datetime_testcases.items(),
    ids=parse_datetime_testcases.keys(),
)
def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype):
    """Check that date/time type strings parse to the expected SQL type.

    ``sql_type`` is annotated as ``TypeEngine`` because the values in
    ``parse_datetime_testcases`` are DATE, TIME and TIMESTAMP instances,
    not arrays.
    """
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)
102 changes: 102 additions & 0 deletions tests/unit/sqlalchemy/test_datatype_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import pytest

from trino.sqlalchemy import datatype

# Input strings mapped to the pieces ``aware_split`` is expected to yield with
# its default arguments. In these cases, commas inside double quotes or
# parentheses do not split.
split_string_testcases = {
    # plain numbers
    "10": ["10"],
    "10,3": ["10", "3"],
    # double-quoted segments, including escaped quotes
    '"a,b",c': ['"a,b"', "c"],
    '"a,b","c,d"': ['"a,b"', '"c,d"'],
    r'"a,\"b\",c",d': [r'"a,\"b\",c"', "d"],
    r'"foo(bar,\"baz\")",quiz': [r'"foo(bar,\"baz\")"', "quiz"],
    # simple type names
    "varchar": ["varchar"],
    "varchar,int": ["varchar", "int"],
    "varchar,int,float": ["varchar", "int", "float"],
    # array types
    "array(varchar)": ["array(varchar)"],
    "array(varchar),int": ["array(varchar)", "int"],
    "array(varchar(20))": ["array(varchar(20))"],
    "array(varchar(20)),int": ["array(varchar(20))", "int"],
    "array(varchar(20)),array(varchar(20))": [
        "array(varchar(20))",
        "array(varchar(20))",
    ],
    # map types
    "map(varchar, integer),int": ["map(varchar, integer)", "int"],
    "map(varchar(20), integer),int": ["map(varchar(20), integer)", "int"],
    "map(varchar(20), varchar(20)),int": ["map(varchar(20), varchar(20))", "int"],
    "map(varchar(20), varchar(20)),array(varchar)": [
        "map(varchar(20), varchar(20))",
        "array(varchar)",
    ],
    # row types
    "row(first_name varchar(20), last_name varchar(20)),int": [
        "row(first_name varchar(20), last_name varchar(20))",
        "int",
    ],
    'row("first name" varchar(20), "last name" varchar(20)),int': [
        'row("first name" varchar(20), "last name" varchar(20))',
        "int",
    ],
}


@pytest.mark.parametrize(
    "input_string, output_strings",
    split_string_testcases.items(),
    ids=split_string_testcases.keys(),
)
def test_split_string(input_string: str, output_strings: List[str]):
    """Check ``aware_split`` with default arguments yields the expected pieces."""
    assert list(datatype.aware_split(input_string)) == output_strings


# (input string, delimiter, expected pieces) triples for ``aware_split`` called
# with an explicit ``delimiter``.
split_delimiter_testcases = [
    # a single delimiter kind present in the input
    ("first,second", ",", ["first", "second"]),
    ("first second", " ", ["first", "second"]),
    ("first|second", "|", ["first", "second"]),
    # both comma and space present: only the requested delimiter splits
    ("first,second third", ",", ["first", "second third"]),
    ("first,second third", " ", ["first,second", "third"]),
]


@pytest.mark.parametrize(
    "input_string, delimiter, output_strings",
    split_delimiter_testcases,
)
def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]):
    """Check ``aware_split`` splits only on the explicitly given delimiter."""
    pieces = datatype.aware_split(input_string, delimiter=delimiter)
    assert list(pieces) == output_strings


# (input string, maxsplit, expected pieces) triples for ``aware_split`` called
# with an explicit ``maxsplit``.
split_maxsplit_testcases = [
    # -1 splits everywhere; 0 does not split at all
    ("one,two,three", -1, ["one", "two", "three"]),
    ("one,two,three", 0, ["one,two,three"]),
    # limits below, at, and above the number of delimiters
    ("one,two,three", 1, ["one", "two,three"]),
    ("one,two,three", 2, ["one", "two", "three"]),
    ("one,two,three", 3, ["one", "two", "three"]),
    ("one,two,three", 10, ["one", "two", "three"]),
    # leading delimiter
    (",one,two,three", 0, [",one,two,three"]),
    (",one,two,three", 1, ["", "one,two,three"]),
    # trailing delimiter
    ("one,two,three,", 2, ["one", "two", "three,"]),
    ("one,two,three,", 3, ["one", "two", "three", ""]),
]


@pytest.mark.parametrize(
    "input_string, maxsplit, output_strings",
    split_maxsplit_testcases,
)
def test_split_maxsplit(input_string: str, maxsplit: int, output_strings: List[str]):
    """Check ``aware_split`` honours the ``maxsplit`` limit."""
    pieces = datatype.aware_split(input_string, maxsplit=maxsplit)
    assert list(pieces) == output_strings
Loading

1 comment on commit da1441f

@long2ice
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit lacks dungdm93/sqlalchemy-trino#37, so I can't remove sqlalchemy-trino from my project.

Please sign in to comment.