Skip to content

Commit

Permalink
Support decimal, date, time, timestamp with time zone and timestamp
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Mar 2, 2022
1 parent 4e61be8 commit 1f33e2a
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 31 deletions.
34 changes: 32 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,39 @@ The transaction is created when the first SQL statement is executed.
exits the *with* context and the queries succeed, otherwise
`trino.dbapi.Connection.rollback()` will be called.

## Development
# Improved Python types

### Getting Started With Development
If you enable the `experimental_python_types` flag when creating a cursor, the client converts the results of the
query to the corresponding Python types. For example, if the query returns a `DECIMAL` column, the result will be a
`Decimal` object.

The limitations of these Python types are described in the
[Python types documentation](https://docs.python.org/3/library/datatypes.html). If the query returns a value that
cannot be converted to the corresponding Python type, the client raises a `trino.exceptions.TrinoDataError`
exception.

```python
import trino
import pytz
from datetime import datetime

conn = trino.dbapi.connect(
...
)

cur = conn.cursor(experimental_python_types=True)

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp with time zone"
```

# Development

## Getting Started With Development

Start by forking the repository and then modify the code in your fork.

Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@
kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]

# pytz is listed once, here; every list built on top of all_require
# (such as tests_require below) picks it up from this single place.
all_require = ["pytz"] + kerberos_require + sqlalchemy_require

tests_require = all_require + [
    # httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
    # https://github.com/gabrielfalcao/HTTPretty/issues/425
    "httpretty < 1.1",
    "pytest",
    "pytest-runner",
    "click",
]

Expand Down
292 changes: 285 additions & 7 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from datetime import datetime
from datetime import datetime, time, date, timezone, timedelta
from decimal import Decimal

import pytest
import pytz
Expand Down Expand Up @@ -123,22 +124,267 @@ def test_string_query_param(trino_connection):
assert rows[0][0] == "six'"


def test_python_types_not_used_when_experimental_python_types_is_not_set(trino_connection):
    """Check that a default cursor returns typed literals as plain strings.

    The cursor is created without ``experimental_python_types``, so each
    DECIMAL / DATE / TIMESTAMP / TIME literal selected below is expected to
    come back as the string asserted at the bottom of the test.
    """
    cur = trino_connection.cursor()

    cur.execute("""
        SELECT
            DECIMAL '0.142857',
            DATE '2018-01-01',
            TIMESTAMP '2019-01-01 00:00:00.000+01:00',
            TIMESTAMP '2019-01-01 00:00:00.000 UTC',
            TIMESTAMP '2019-01-01 00:00:00.000',
            TIME '00:00:00.000'
        """)
    rows = cur.fetchall()

    assert rows[0][0] == '0.142857'
    assert rows[0][1] == '2018-01-01'
    assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00'
    assert rows[0][3] == '2019-01-01 00:00:00.000 UTC'
    assert rows[0][4] == '2019-01-01 00:00:00.000'
    assert rows[0][5] == '00:00:00.000'


def test_decimal_query_param(trino_connection):
    """Bind a ``Decimal`` as a query parameter and expect an equal value back."""
    expected = Decimal('0.142857')
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected


def test_null_decimal(trino_connection):
    """A NULL cast to DECIMAL is expected to be returned as ``None``."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT CAST(NULL AS DECIMAL)")
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value is None


def test_biggest_decimal(trino_connection):
    """Bind a large positive integer ``Decimal`` and expect an equal value back."""
    cur = trino_connection.cursor(experimental_python_types=True)

    params = Decimal('99999999999999999999999999999999999999')
    cur.execute("SELECT ?", params=(params,))
    rows = cur.fetchall()

    # The only meaningful check for this test: the value round-trips unchanged.
    assert rows[0][0] == params

def test_smallest_decimal(trino_connection):
    """Bind a large negative integer ``Decimal`` and expect an equal value back."""
    cur = trino_connection.cursor(experimental_python_types=True)

    params = Decimal('-99999999999999999999999999999999999999')
    cur.execute("SELECT ?", params=(params,))
    rows = cur.fetchall()

    # The only meaningful check for this test: the value round-trips unchanged.
    assert rows[0][0] == params


def test_highest_precision_decimal(trino_connection):
    """Bind a purely fractional, many-digit ``Decimal`` and expect it back unchanged."""
    expected = Decimal('0.99999999999999999999999999999999999999')
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected


def test_datetime_query_param(trino_connection):
    """Bind a naive ``datetime`` and expect an equal value typed as ``timestamp``."""
    expected = datetime(2020, 1, 1, 16, 43, 22, 320000)
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected

    # Second field of the first description entry carries the column type name.
    column_type = cursor.description[0][1]
    assert column_type == "timestamp"


def test_datetime_with_trailing_zeros(trino_connection):
    """Select a TIMESTAMP literal whose fraction ends in zeros and compare it to a parsed datetime."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT TIMESTAMP '2001-08-22 03:04:05.321000'")
    result = cursor.fetchall()

    expected = datetime.strptime("2001-08-22 03:04:05.321000", "%Y-%m-%d %H:%M:%S.%f")
    first_value = result[0][0]
    assert first_value == expected


def test_datetime_with_utc_time_zone_query_param(trino_connection):
    """Bind a UTC-aware ``datetime`` and expect it back as ``timestamp with time zone``."""
    utc = pytz.timezone('UTC')
    expected = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=utc)
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected

    column_type = cursor.description[0][1]
    assert column_type == "timestamp with time zone"


def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
    """Bind a ``datetime`` with a fixed -05:30 offset and expect it back as ``timestamp with time zone``."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    # Fixed-offset zone built from the stdlib, five and a half hours behind UTC.
    offset = -timedelta(hours=5, minutes=30)
    fixed_zone = timezone(offset)
    expected = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=fixed_zone)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected

    column_type = cursor.description[0][1]
    assert column_type == "timestamp with time zone"


def test_datetime_with_named_time_zone_query_param(trino_connection):
    """Bind a ``datetime`` in a named zone and expect it back as ``timestamp with time zone``."""
    # NOTE(review): pytz recommends ``tz.localize(dt)`` over ``tzinfo=tz`` for
    # named zones; confirm that attaching the zone via ``tzinfo=`` is intended.
    los_angeles = pytz.timezone('America/Los_Angeles')
    expected = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=los_angeles)
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected

    column_type = cursor.description[0][1]
    assert column_type == "timestamp with time zone"


def test_null_datetime_with_time_zone(trino_connection):
    """A NULL cast to TIMESTAMP WITH TIME ZONE is expected to be returned as ``None``."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT CAST(NULL AS TIMESTAMP WITH TIME ZONE)")
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value is None


def test_datetime_with_time_zone_numeric_offset(trino_connection):
    """Select a TIMESTAMP literal with a -08:00 offset and compare it to a parsed aware datetime."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT TIMESTAMP '2001-08-22 03:04:05.321 -08:00'")
    result = cursor.fetchall()

    expected = datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z")
    first_value = result[0][0]
    assert first_value == expected


def test_unexisting_datetimes_with_time_zone_query_param(trino_connection):
    """Binding 2021-03-28 02:30 in Europe/Brussels is expected to raise ``TrinoUserError``."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    brussels = pytz.timezone('Europe/Brussels')
    nonexistent = datetime(2021, 3, 28, 2, 30, 0, tzinfo=brussels)

    # Both calls stay inside the context manager, as in the original test:
    # the error may surface from either the execute or the fetch.
    with pytest.raises(trino.exceptions.TrinoUserError):
        cursor.execute("SELECT ?", params=(nonexistent,))
        cursor.fetchall()


def test_doubled_datetimes_first_query_param(trino_connection):
    """Bind 2002-10-27 01:30 US/Eastern localized with ``is_dst=True`` and check the returned value."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    eastern = pytz.timezone('US/Eastern')
    wall_clock = datetime(2002, 10, 27, 1, 30, 0)
    bound_value = eastern.localize(wall_clock, is_dst=True)

    cursor.execute("SELECT ?", params=(bound_value,))
    result = cursor.fetchall()

    expected = datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
    first_value = result[0][0]
    assert first_value == expected


def test_doubled_datetimes_second_query_param(trino_connection):
    """Bind 2002-10-27 01:30 US/Eastern localized with ``is_dst=False`` and check the returned value."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    eastern = pytz.timezone('US/Eastern')
    wall_clock = datetime(2002, 10, 27, 1, 30, 0)
    bound_value = eastern.localize(wall_clock, is_dst=False)

    cursor.execute("SELECT ?", params=(bound_value,))
    result = cursor.fetchall()

    expected = datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
    first_value = result[0][0]
    assert first_value == expected


def test_date_query_param(trino_connection):
    """Bind a ``date`` as a query parameter and expect an equal value back."""
    midnight = datetime(2020, 1, 1, 0, 0, 0)
    expected = midnight.date()
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected


def test_null_date(trino_connection):
    """A NULL cast to DATE is expected to be returned as ``None``."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT CAST(NULL AS DATE)")
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value is None


def test_unsupported_python_dates(trino_connection):
    """Selecting DATE literals for year 0 and year -1 is expected to raise ``TrinoDataError``."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    # dates not supported in Python date type
    out_of_range_dates = ('-0001-01-01', '0000-01-01')

    for unsupported_date in out_of_range_dates:
        with pytest.raises(trino.exceptions.TrinoDataError):
            cursor.execute(f"SELECT DATE '{unsupported_date}'")
            cursor.fetchall()


def test_supported_special_dates_query_param(trino_connection):
    """Bind each of a series of edge-case ``date`` values and expect an equal value back."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    special_dates = (
        # first day of AD
        date(1, 1, 1),
        date(12, 12, 12),
        # before julian->gregorian switch
        date(1500, 1, 1),
        # During julian->gregorian switch
        date(1752, 9, 4),
        # before epoch
        date(1952, 4, 3),
        date(1970, 1, 1),
        date(1970, 2, 3),
        # summer on northern hemisphere (possible DST)
        date(2017, 7, 1),
        # winter on northern hemisphere (possible DST on southern hemisphere)
        date(2017, 1, 1),
        # winter on southern hemisphere (possible DST on northern hemisphere)
        date(2017, 12, 31),
        date(1983, 4, 1),
        date(1983, 10, 1),
    )

    # One query per date, on the same cursor, checked immediately.
    for expected in special_dates:
        cursor.execute("SELECT ?", params=(expected,))
        result = cursor.fetchall()

        first_value = result[0][0]
        assert first_value == expected


def test_time_query_param(trino_connection):
    """Bind a naive ``time`` as a query parameter and expect an equal value back."""
    expected = time(12, 3, 44, 333000)
    cursor = trino_connection.cursor(experimental_python_types=True)

    cursor.execute("SELECT ?", params=(expected,))
    result = cursor.fetchall()

    first_value = result[0][0]
    assert first_value == expected


def test_time_with_time_zone_query_param(trino_connection):
    """Binding a ``time`` that carries a ``tzinfo`` is expected to raise ``NotSupportedError``."""
    cur = trino_connection.cursor()

    params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('Asia/Shanghai'))

    # Only the call under test sits inside ``pytest.raises`` so that a
    # NotSupportedError raised while setting up the cursor or the parameter
    # cannot make the test pass by accident.
    with pytest.raises(trino.exceptions.NotSupportedError):
        cur.execute("SELECT ?", params=(params,))


def test_array_query_param(trino_connection):
cur = trino_connection.cursor()

Expand All @@ -158,6 +404,38 @@ def test_array_query_param(trino_connection):
assert rows[0][0] == "array(integer)"


def test_array_timestamp_query_param(trino_connection):
    """Bind a list of naive datetimes; check the returned value and the reported SQL type."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    first_day = datetime(2020, 1, 1, 0, 0, 0)
    second_day = datetime(2020, 1, 2, 0, 0, 0)
    expected = [first_day, second_day]

    cursor.execute("SELECT ?", params=(expected,))
    value_rows = cursor.fetchall()

    assert value_rows[0][0] == expected

    # Second query: ask the server which type it assigned to the parameter.
    cursor.execute("SELECT TYPEOF(?)", params=(expected,))
    type_rows = cursor.fetchall()

    assert type_rows[0][0] == "array(timestamp(6))"


def test_array_timestamp_with_timezone_query_param(trino_connection):
    """Bind a list of UTC-aware datetimes; check the returned value and the reported SQL type."""
    cursor = trino_connection.cursor(experimental_python_types=True)

    first_day = datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
    second_day = datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)
    expected = [first_day, second_day]

    cursor.execute("SELECT ?", params=(expected,))
    value_rows = cursor.fetchall()

    assert value_rows[0][0] == expected

    # Second query: ask the server which type it assigned to the parameter.
    cursor.execute("SELECT TYPEOF(?)", params=(expected,))
    type_rows = cursor.fetchall()

    assert type_rows[0][0] == "array(timestamp(6) with time zone)"


def test_dict_query_param(trino_connection):
cur = trino_connection.cursor()

Expand Down
Loading

0 comments on commit 1f33e2a

Please sign in to comment.