diff --git a/.gitignore b/.gitignore
index db178583..859abcfe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@ pydruid/bard.py
 build/
 dist/
 pyDruid.egg-info/
-
+__pycache__
+.eggs
\ No newline at end of file
diff --git a/pydruid/client.py b/pydruid/client.py
index af0ff953..53ef5720 100755
--- a/pydruid/client.py
+++ b/pydruid/client.py
@@ -16,7 +16,9 @@
 from __future__ import division
 from __future__ import absolute_import
 
-from six import iteritems
+import sys
+
+import six
 from six.moves import urllib
 
 try:
@@ -108,7 +110,7 @@ def __init__(self, url, endpoint):
 
     def __post(self, query):
         try:
-            querystr = json.dumps(query).encode('ascii')
+            querystr = json.dumps(query).encode('utf-8')
             if self.url.endswith('/'):
                 url = self.url + self.endpoint
             else:
@@ -176,25 +178,26 @@ def export_tsv(self, dest_path):
            7.0  user_1     2013-10-04T00:00:00.000Z
            6.0  user_2     2013-10-04T00:00:00.000Z
         """
-        f = open(dest_path, 'wb')
-        tsv_file = csv.writer(f, delimiter='\t')
+        if six.PY3:
+            f = open(dest_path, 'w', newline='', encoding='utf-8')
+        else:
+            f = open(dest_path, 'wb')
+        w = UnicodeWriter(f)
 
         if self.query_type == "timeseries":
-            header = self.result[0]['result'].keys()
+            header = list(self.result[0]['result'].keys())
             header.append('timestamp')
-        if self.query_type == 'topN':
-            header = self.result[0]['result'][0].keys()
+        elif self.query_type == 'topN':
+            header = list(self.result[0]['result'][0].keys())
             header.append('timestamp')
         elif self.query_type == "groupBy":
-            header = self.result[0]['event'].keys()
+            header = list(self.result[0]['event'].keys())
             header.append('timestamp')
             header.append('version')
         else:
             raise NotImplementedError('TSV export not implemented for query type: {0}'.format(self.query_type))
 
-        tsv_file.writerow(header)
-
-        w = UnicodeWriter(f)
+        w.writerow(header)
 
         if self.result:
             if self.query_type == "topN" or self.query_type == "timeseries":
@@ -203,15 +206,15 @@ def export_tsv(self, dest_path):
                     result = item['result']
                     if type(result) is list:  # topN
                         for line in result:
-                            w.writerow(line.values() + [timestamp])
+                            w.writerow(list(line.values()) + [timestamp])
                     else:  # timeseries
-                        w.writerow(result.values() + [timestamp])
+                        w.writerow(list(result.values()) + [timestamp])
             elif self.query_type == "groupBy":
                 for item in self.result:
                     timestamp = item['timestamp']
                     version = item['version']
                     w.writerow(
-                        item['event'].values() + [timestamp] + [version])
+                        list(item['event'].values()) + [timestamp] + [version])
 
         f.close()
 
@@ -283,7 +286,7 @@ def validate_query(self, valid_parts, args):
         :raise ValueError: if an invalid object is given
         """
         valid_parts = valid_parts[:] + ['context']
-        for key, val in iteritems(args):
+        for key, val in six.iteritems(args):
             if key not in valid_parts:
                 raise ValueError(
                     'Query component: {0} is not valid for query type: {1}.'
@@ -294,7 +297,7 @@ def validate_query(self, valid_parts, args):
 
     def build_query(self, args):
         query_dict = {'queryType': self.query_type}
-        for key, val in iteritems(args):
+        for key, val in six.iteritems(args):
             if key == 'aggregations':
                 query_dict[key] = build_aggregators(val)
             elif key == 'post_aggregations':
diff --git a/pydruid/utils/postaggregator.py b/pydruid/utils/postaggregator.py
index c6dd5e3a..2b2dcc68 100644
--- a/pydruid/utils/postaggregator.py
+++ b/pydruid/utils/postaggregator.py
@@ -15,6 +15,7 @@
 #
 from __future__ import division
 
+import six
 
 class Postaggregator:
     def __init__(self, fn, fields, name):
@@ -40,6 +41,9 @@ def __div__(self, other):
         return Postaggregator('/', self.fields(other),
                               self.name + 'div' + other.name)
 
+    def __truediv__(self, other):
+        return self.__div__(other)
+
     def fields(self, other):
         return [self.post_aggregator, other.post_aggregator]
 
@@ -50,7 +54,7 @@ def rename_postagg(new_name, post_aggregator):
         return post_aggregator
 
     return [rename_postagg(new_name, postagg.post_aggregator)
-            for (new_name, postagg) in postaggs.iteritems()]
+            for (new_name, postagg) in six.iteritems(postaggs)]
 
 
 class Field(Postaggregator):
diff --git a/pydruid/utils/query_utils.py b/pydruid/utils/query_utils.py
index c7b77bcf..19ec0eb7 100644
--- a/pydruid/utils/query_utils.py
+++ b/pydruid/utils/query_utils.py
@@ -15,8 +15,7 @@
 #
 import csv
 import codecs
-from six import StringIO
-
+import six
 
 # A special CSV writer which will write rows to TSV file "f", which is encoded in utf-8.
 # this is necessary because the values in druid are not all ASCII.
@@ -25,24 +24,21 @@ class UnicodeWriter:
 
     # delimiter="\t"
     def __init__(self, f, dialect="excel-tab", encoding="utf-8", **kwds):
-        # Redirect output to a queue
-        self.queue = StringIO()
-        self.writer = csv.writer(self.queue, dialect=dialect, **kwds)
         self.stream = f
+        self.writer = csv.writer(self.stream, dialect=dialect, **kwds)
         self.encoder = codecs.getincrementalencoder(encoding)()
 
+    def __encode(self, data):
+        data = str(data) if isinstance(data, six.integer_types) else data
+        if not six.PY3:
+            data = data.encode('utf-8') if isinstance(data, unicode) else data
+            data = data.decode('utf-8')
+            return self.encoder.encode(data)
+        return data
+
     def writerow(self, row):
-        self.writer.writerow(
-            [s.encode("utf-8") if isinstance(s, unicode) else s for s in row])
-        # Fetch UTF-8 output from the queue ...
-        data = self.queue.getvalue()
-        data = data.decode("utf-8")
-        # ... and reencode it into the target encoding
-        data = self.encoder.encode(data)
-        # write to the target stream
-        self.stream.write(data)
-        # empty queue
-        self.queue.truncate(0)
+        row = [self.__encode(s) for s in row]
+        self.writer.writerow(row)
 
     def writerows(self, rows):
         for row in rows:
diff --git a/setup.py b/setup.py
index f5101294..4856eec6 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,23 @@
 import sys
-
 from setuptools import setup
+from setuptools.command.test import test as TestCommand
+
+class PyTest(TestCommand):
+    user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")]
+
+    def initialize_options(self):
+        TestCommand.initialize_options(self)
+        self.pytest_args = []
+
+    def finalize_options(self):
+        TestCommand.finalize_options(self)
+        self.test_args = []
+        self.test_suite = True
+
+    def run_tests(self):
+        import pytest
+        status = pytest.main(self.pytest_args)
+        sys.exit(status)
 
 install_requires = ["six >= 1.9.0"]
 
@@ -20,4 +37,6 @@
     description='A Python connector for Druid.',
     long_description='See https://github.com/metamx/pydruid for more information.',
     install_requires=install_requires,
+    tests_require=['pytest'],
+    cmdclass={'test': PyTest},
 )
diff --git a/tests/test_client.py b/tests/test_client.py
new file mode 100644
index 00000000..9f1d3815
--- /dev/null
+++ b/tests/test_client.py
@@ -0,0 +1,97 @@
+# -*- coding: UTF-8 -*-
+
+import os
+import pytest
+import pandas
+from pandas.util.testing import assert_frame_equal
+from six import PY3
+
+from pydruid.client import PyDruid
+from pydruid.utils.aggregators import *
+from pydruid.utils.postaggregator import *
+from pydruid.utils.filters import *
+from pydruid.utils.having import *
+
+def create_client():
+    return PyDruid('http://localhost:8083', 'druid/v2/')
+
+def create_client_with_results():
+    client = create_client()
+    client.query_type = 'timeseries'
+    client.result = [
+        {'result': {'value1': 1, 'value2': '㬓'}, 'timestamp': '2015-01-01T00:00:00.000-05:00'},
+        {'result': {'value1': 2, 'value2': '㬓'}, 'timestamp': '2015-01-02T00:00:00.000-05:00'}
+    ]
+    return client
+
+def line_ending():
+    if PY3:
+        return os.linesep
+    return "\r\n"
+
+class TestClient:
+    def test_build_query(self):
+        client = create_client()
+        assert client.query_dict == None
+
+        client.build_query({
+            'datasource': 'things',
+            'aggregations': {
+                'count': count('thing'),
+            },
+            'post_aggregations': {
+                'avg': Field('sum') / Field('count'),
+            },
+            'paging_spec': {
+                'pagingIdentifies': {},
+                'threshold': 1,
+            },
+            'filter': Dimension('one') == 1,
+            'having': Aggregation('sum') > 1,
+            'new_key': 'value',
+        })
+        expected_query_dict = {
+            'queryType': None,
+            'dataSource': 'things',
+            'aggregations': [{'fieldName': 'thing', 'name': 'count', 'type': 'count'}],
+            'postAggregations': [{
+                'fields': [{
+                    'fieldName': 'sum', 'type': 'fieldAccess',
+                }, {
+                    'fieldName': 'count', 'type': 'fieldAccess',
+                }],
+                'fn': '/',
+                'name': 'avg',
+                'type': 'arithmetic',
+            }],
+            'pagingSpec': {'pagingIdentifies': {}, 'threshold': 1},
+            'filter': {'dimension': 'one', 'type': 'selector', 'value': 1},
+            'having': {'aggregation': 'sum', 'type': 'greaterThan', 'value': 1},
+            'new_key': 'value',
+        }
+        assert client.query_dict == expected_query_dict
+
+    def test_validate_query(self):
+        client = create_client()
+        client.validate_query(['validkey'], {'validkey': 'value'})
+        pytest.raises(ValueError, client.validate_query, *[['validkey'], {'invalidkey': 'value'}])
+
+    def test_export_tsv(self, tmpdir):
+        client = create_client_with_results()
+        file_path = tmpdir.join('out.tsv')
+        client.export_tsv(str(file_path))
+        assert file_path.read() == "value2\tvalue1\ttimestamp" + line_ending() + "㬓\t1\t2015-01-01T00:00:00.000-05:00" + line_ending() + "㬓\t2\t2015-01-02T00:00:00.000-05:00" + line_ending()
+
+    def test_export_pandas(self):
+        client = create_client_with_results()
+        df = client.export_pandas()
+        expected_df = pandas.DataFrame([{
+            'timestamp': '2015-01-01T00:00:00.000-05:00',
+            'value1': 1,
+            'value2': '㬓',
+        }, {
+            'timestamp': '2015-01-02T00:00:00.000-05:00',
+            'value1': 2,
+            'value2': '㬓',
+        }])
+        assert_frame_equal(df, expected_df)
+
diff --git a/tests/utils/test_query_utils.py b/tests/utils/test_query_utils.py
new file mode 100644
index 00000000..ba4f8172
--- /dev/null
+++ b/tests/utils/test_query_utils.py
@@ -0,0 +1,39 @@
+# -*- coding: UTF-8 -*-
+
+import os
+import pytest
+from six import PY3
+
+from pydruid.utils.query_utils import *
+
+def open_file(file_path):
+    if PY3:
+        f = open(file_path, 'w', newline='', encoding='utf-8')
+    else:
+        f = open(file_path, 'wb')
+    return f
+
+def line_ending():
+    if PY3:
+        return os.linesep
+    return "\r\n"
+
+class TestUnicodeWriter:
+    def test_writerow(self, tmpdir):
+        file_path = tmpdir.join("out.tsv")
+        f = open_file(str(file_path))
+        w = UnicodeWriter(f)
+        w.writerow(['value1', '㬓'])
+        f.close()
+        assert file_path.read() == "value1\t㬓" + line_ending()
+
+    def test_writerows(self, tmpdir):
+        file_path = tmpdir.join("out.tsv")
+        f = open_file(str(file_path))
+        w = UnicodeWriter(f)
+        w.writerows([
+            ['header1', 'header2'],
+            ['value1', '㬓']
+        ])
+        f.close()
+        assert file_path.read() == "header1\theader2" + line_ending() + "value1\t㬓" + line_ending()
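
Note on the `__truediv__` alias added to pydruid/utils/postaggregator.py: Python 3 routes the `/` operator to `__truediv__` instead of `__div__`, so without the alias the `Field('sum') / Field('count')` expression exercised by `test_build_query` would raise a TypeError on Python 3. A minimal sketch of what the alias enables, not part of the patch itself (it assumes a pydruid checkout with this patch applied; the expected values follow from `Postaggregator.__div__` and the expected dict in tests/test_client.py above):

    from pydruid.utils.postaggregator import Field

    # On Python 3, `/` dispatches to __truediv__, which delegates to __div__ above.
    avg = Field('sum') / Field('count')

    print(avg.name)                   # 'sumdivcount'
    print(avg.post_aggregator['fn'])  # '/'

Relatedly, the `PyTest` command class registered in setup.py via `cmdclass={'test': PyTest}` wires the new suite into `python setup.py test`, with the `tests_require=['pytest']` dependency fetched automatically by setuptools.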