diff --git a/pyhive/presto.py b/pyhive/presto.py index a38cd891..3217f4c2 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -9,6 +9,8 @@ from __future__ import unicode_literals from builtins import object +from decimal import Decimal + from pyhive import common from pyhive.common import DBAPITypeObject # Make all exceptions visible in this module per DB-API @@ -34,6 +36,11 @@ _logger = logging.getLogger(__name__) +TYPES_CONVERTER = { + "decimal": Decimal, + # As of Presto 0.69, binary data is returned as the varbinary type in base64 format + "varbinary": base64.b64decode +} class PrestoParamEscaper(common.ParamEscaper): def escape_datetime(self, item, format): @@ -307,14 +314,13 @@ def _fetch_more(self): """Fetch the next URI and update state""" self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) - def _decode_binary(self, rows): - # As of Presto 0.69, binary data is returned as the varbinary type in base64 format - # This function decodes base64 data in place + def _process_data(self, rows): for i, col in enumerate(self.description): - if col[1] == 'varbinary': + col_type = col[1].split("(")[0].lower() + if col_type in TYPES_CONVERTER: for row in rows: if row[i] is not None: - row[i] = base64.b64decode(row[i]) + row[i] = TYPES_CONVERTER[col_type](row[i]) def _process_response(self, response): """Given the JSON response from Presto's REST API, update the internal state with the next @@ -341,7 +347,7 @@ def _process_response(self, response): if 'data' in response_json: assert self._columns new_data = response_json['data'] - self._decode_binary(new_data) + self._process_data(new_data) self._data += map(tuple, new_data) if 'nextUri' not in response_json: self._state = self._STATE_FINISHED diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 7c74f057..187b1c21 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -9,6 +9,8 @@ import contextlib import os +from decimal import Decimal + import requests from pyhive import exc @@ -93,7 +95,7 @@ def test_complex(self, cursor): {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str diff --git a/pyhive/tests/test_trino.py b/pyhive/tests/test_trino.py index cdc8bb43..41bb489b 100644 --- a/pyhive/tests/test_trino.py +++ b/pyhive/tests/test_trino.py @@ -9,6 +9,8 @@ import contextlib import os +from decimal import Decimal + import requests from pyhive import exc @@ -89,7 +91,7 @@ def test_complex(self, cursor): {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str diff --git a/pyhive/trino.py b/pyhive/trino.py index e8a1aabd..658457a3 100644 --- a/pyhive/trino.py +++ b/pyhive/trino.py @@ -124,7 +124,7 @@ def _process_response(self, response): if 'data' in response_json: assert self._columns new_data = response_json['data'] - self._decode_binary(new_data) + self._process_data(new_data) self._data += map(tuple, new_data) if 'nextUri' not in response_json: self._state = self._STATE_FINISHED