
Commit 616a93e

Fix regressions introduced by support for Python 3 and add initial test coverage

griffy committed Jun 30, 2015
1 parent faccbc0 commit 616a93e
Showing 7 changed files with 194 additions and 35 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ pydruid/bard.py
 build/
 dist/
 pyDruid.egg-info/
-
+__pycache__
+.eggs
35 changes: 19 additions & 16 deletions pydruid/client.py
@@ -16,7 +16,9 @@
 from __future__ import division
 from __future__ import absolute_import

-from six import iteritems
+import sys
+
+import six
 from six.moves import urllib

 try:
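
The import block drops the lone `iteritems` import in favor of the whole `six` module plus the `six.moves.urllib` shim, which papers over the urllib2/urllib.request split. A minimal sketch of what that shim buys (the broker URL is a placeholder, and the request is left un-sent):

    # six.moves.urllib resolves to urllib2 on Python 2 and urllib.request
    # on Python 3, so one import works everywhere.
    from six.moves import urllib

    req = urllib.request.Request(
        'http://localhost:8083/druid/v2/',      # placeholder broker URL
        b'{"queryType": "timeseries"}',         # pre-encoded JSON body
        {'Content-Type': 'application/json'})
    # res = urllib.request.urlopen(req)         # would POST the query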
@@ -108,7 +110,7 @@ def __init__(self, url, endpoint):

     def __post(self, query):
         try:
-            querystr = json.dumps(query).encode('ascii')
+            querystr = json.dumps(query).encode('utf-8')
             if self.url.endswith('/'):
                 url = self.url + self.endpoint
             else:
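
Why `encode('utf-8')` and not `encode('ascii')`: `json.dumps` escapes non-ASCII characters by default, but any query value that reaches the wire unescaped (for example with `ensure_ascii=False`) blows up the ascii codec, while UTF-8 accepts everything. A standalone sketch with a hypothetical filter value:

    import json

    query = {'queryType': 'timeseries',
             'filter': {'dimension': 'user', 'value': '㬓'}}  # hypothetical

    # Default json.dumps escapes to \uXXXX, so ascii happens to survive:
    json.dumps(query).encode('ascii')

    # With ensure_ascii=False the payload is real UTF-8 text:
    body = json.dumps(query, ensure_ascii=False)
    body.encode('utf-8')    # fine
    # body.encode('ascii')  # UnicodeEncodeError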
@@ -176,25 +178,26 @@ def export_tsv(self, dest_path):
         7.0 user_1 2013-10-04T00:00:00.000Z
         6.0 user_2 2013-10-04T00:00:00.000Z
         """
-        f = open(dest_path, 'wb')
-        tsv_file = csv.writer(f, delimiter='\t')
+        if six.PY3:
+            f = open(dest_path, 'w', newline='', encoding='utf-8')
+        else:
+            f = open(dest_path, 'wb')
+        w = UnicodeWriter(f)

         if self.query_type == "timeseries":
-            header = self.result[0]['result'].keys()
+            header = list(self.result[0]['result'].keys())
             header.append('timestamp')
-        if self.query_type == 'topN':
-            header = self.result[0]['result'][0].keys()
+        elif self.query_type == 'topN':
+            header = list(self.result[0]['result'][0].keys())
             header.append('timestamp')
         elif self.query_type == "groupBy":
-            header = self.result[0]['event'].keys()
+            header = list(self.result[0]['event'].keys())
             header.append('timestamp')
             header.append('version')
         else:
             raise NotImplementedError('TSV export not implemented for query type: {0}'.format(self.query_type))

-        tsv_file.writerow(header)
-
-        w = UnicodeWriter(f)
+        w.writerow(header)

         if self.result:
             if self.query_type == "topN" or self.query_type == "timeseries":
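
The old unconditional `open(dest_path, 'wb')` broke TSV export on Python 3, where the csv module wants a text-mode stream; `newline=''` stops Python from translating the line endings csv writes itself. A minimal sketch of the split, assuming only the standard library:

    import csv

    import six

    dest_path = 'out.tsv'  # placeholder path
    if six.PY3:
        # text mode; newline='' leaves line-ending control to csv
        f = open(dest_path, 'w', newline='', encoding='utf-8')
    else:
        # Python 2's csv module writes byte strings, so binary mode
        f = open(dest_path, 'wb')
    csv.writer(f, dialect='excel-tab').writerow(['value1', 'value2'])
    f.close()  # non-ASCII values still need the UnicodeWriter wrapper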
@@ -203,15 +206,15 @@ def export_tsv(self, dest_path):
                     result = item['result']
                     if type(result) is list:  # topN
                         for line in result:
-                            w.writerow(line.values() + [timestamp])
+                            w.writerow(list(line.values()) + [timestamp])
                     else:  # timeseries
-                        w.writerow(result.values() + [timestamp])
+                        w.writerow(list(result.values()) + [timestamp])
             elif self.query_type == "groupBy":
                 for item in self.result:
                     timestamp = item['timestamp']
                     version = item['version']
                     w.writerow(
-                        item['event'].values() + [timestamp] + [version])
+                        list(item['event'].values()) + [timestamp] + [version])

         f.close()

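The writerow fixes address another Python 3 regression: `dict.values()` now returns a view, and view + list raises TypeError, so each row has to be materialized with `list()` before concatenating the timestamp columns. A two-line sketch with a hypothetical result row:

    row = {'value1': 1, 'value2': 2}  # hypothetical result row
    # row.values() + ['ts'] raises TypeError on Python 3 (view + list)
    print(list(row.values()) + ['2015-01-01T00:00:00.000-05:00'])
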
@@ -283,7 +286,7 @@ def validate_query(self, valid_parts, args):
         :raise ValueError: if an invalid object is given
         """
         valid_parts = valid_parts[:] + ['context']
-        for key, val in iteritems(args):
+        for key, val in six.iteritems(args):
             if key not in valid_parts:
                 raise ValueError(
                     'Query component: {0} is not valid for query type: {1}.'
@@ -294,7 +297,7 @@ def build_query(self, args):
     def build_query(self, args):
         query_dict = {'queryType': self.query_type}

-        for key, val in iteritems(args):
+        for key, val in six.iteritems(args):
             if key == 'aggregations':
                 query_dict[key] = build_aggregators(val)
             elif key == 'post_aggregations':
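
`dict.iteritems()` no longer exists on Python 3, which is exactly the regression these two hunks fix; `six.iteritems` picks the right spelling at runtime. A tiny sketch with hypothetical query arguments:

    import six

    args = {'datasource': 'things', 'granularity': 'day'}  # hypothetical
    for key, val in six.iteritems(args):  # iteritems() on Py2, items() on Py3
        print(key, val)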
6 changes: 5 additions & 1 deletion pydruid/utils/postaggregator.py
@@ -15,6 +15,7 @@
 #
 from __future__ import division

+import six

 class Postaggregator:
     def __init__(self, fn, fields, name):
@@ -40,6 +41,9 @@ def __div__(self, other):
         return Postaggregator('/', self.fields(other),
                               self.name + 'div' + other.name)

+    def __truediv__(self, other):
+        return self.__div__(other)
+
     def fields(self, other):
         return [self.post_aggregator, other.post_aggregator]

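On Python 3 (and under `from __future__ import division`) the `/` operator dispatches to `__truediv__`, never `__div__`, so post-aggregator arithmetic like `Field('sum') / Field('count')` raised TypeError before this alias. A minimal standalone sketch of the pattern, with `Ratio` as a stand-in class rather than the real Postaggregator:

    from __future__ import division

    class Ratio:
        def __init__(self, name):
            self.name = name

        def __div__(self, other):      # Python 2 spelling of '/'
            return Ratio(self.name + 'div' + other.name)

        def __truediv__(self, other):  # Python 3 / future-division spelling
            return self.__div__(other)

    print((Ratio('sum') / Ratio('count')).name)  # -> sumdivcount
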
@@ -50,7 +54,7 @@ def rename_postagg(new_name, post_aggregator):
         return post_aggregator

     return [rename_postagg(new_name, postagg.post_aggregator)
-            for (new_name, postagg) in postaggs.iteritems()]
+            for (new_name, postagg) in six.iteritems(postaggs)]


 class Field(Postaggregator):
28 changes: 12 additions & 16 deletions pydruid/utils/query_utils.py
@@ -15,8 +15,7 @@
 #
 import csv
 import codecs
-from six import StringIO
-
+import six
 # A special CSV writer which will write rows to TSV file "f", which is encoded in utf-8.
 # this is necessary because the values in druid are not all ASCII.

@@ -25,24 +24,21 @@ class UnicodeWriter:

     # delimiter="\t"
     def __init__(self, f, dialect="excel-tab", encoding="utf-8", **kwds):
-        # Redirect output to a queue
-        self.queue = StringIO()
-        self.writer = csv.writer(self.queue, dialect=dialect, **kwds)
         self.stream = f
+        self.writer = csv.writer(self.stream, dialect=dialect, **kwds)
         self.encoder = codecs.getincrementalencoder(encoding)()

+    def __encode(self, data):
+        data = str(data) if isinstance(data, six.integer_types) else data
+        if not six.PY3:
+            data = data.encode('utf-8') if isinstance(data, unicode) else data
+            data = data.decode('utf-8')
+            return self.encoder.encode(data)
+        return data
+
     def writerow(self, row):
-        self.writer.writerow(
-            [s.encode("utf-8") if isinstance(s, unicode) else s for s in row])
-        # Fetch UTF-8 output from the queue ...
-        data = self.queue.getvalue()
-        data = data.decode("utf-8")
-        # ... and reencode it into the target encoding
-        data = self.encoder.encode(data)
-        # write to the target stream
-        self.stream.write(data)
-        # empty queue
-        self.queue.truncate(0)
+        row = [self.__encode(s) for s in row]
+        self.writer.writerow(row)

     def writerows(self, rows):
         for row in rows:
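The rewrite streams rows straight to the target file instead of round-tripping every row through a StringIO queue, and funnels each value through the private `__encode` hook (ints are stringified; on Python 2, unicode goes through the incremental UTF-8 encoder). A usage sketch, assuming the Python 3 branch:

    from pydruid.utils.query_utils import UnicodeWriter

    # The stream must match the branch __encode expects: text mode with
    # newline='' on Python 3 (binary mode on Python 2).
    with open('out.tsv', 'w', newline='', encoding='utf-8') as f:
        w = UnicodeWriter(f)             # defaults to the excel-tab dialect
        w.writerow(['value1', 1, '㬓'])  # the int is str()'d by __encode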
21 changes: 20 additions & 1 deletion setup.py
@@ -1,6 +1,23 @@
+import sys
+
 from setuptools import setup
+from setuptools.command.test import test as TestCommand
+
+class PyTest(TestCommand):
+    user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")]
+
+    def initialize_options(self):
+        TestCommand.initialize_options(self)
+        self.pytest_args = []
+
+    def finalize_options(self):
+        TestCommand.finalize_options(self)
+        self.test_args = []
+        self.test_suite = True
+
+    def run_tests(self):
+        import pytest
+        status = pytest.main(self.pytest_args)
+        sys.exit(status)


 install_requires = ["six >= 1.9.0"]
@@ -20,4 +37,6 @@
     description='A Python connector for Druid.',
     long_description='See https://github.com/metamx/pydruid for more information.',
     install_requires=install_requires,
+    tests_require=['pytest'],
+    cmdclass={'test': PyTest},
 )
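
With the `cmdclass` hook in place, `python setup.py test` installs the `tests_require` dependencies and hands control to pytest, with `-a`/`--pytest-args` forwarding extra flags. Running pytest directly is equivalent, as this sketch shows (the `-k export` filter is just an example):

    import sys

    import pytest

    # Same effect as `python setup.py test -a "-k export"`:
    status = pytest.main(['-k', 'export'])  # returns pytest's exit status
    sys.exit(status)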
97 changes: 97 additions & 0 deletions tests/test_client.py
@@ -0,0 +1,97 @@
# -*- coding: UTF-8 -*-

import os
import pytest
import pandas
from pandas.util.testing import assert_frame_equal
from six import PY3
from pydruid.client import PyDruid
from pydruid.utils.aggregators import *
from pydruid.utils.postaggregator import *
from pydruid.utils.filters import *
from pydruid.utils.having import *

def create_client():
    return PyDruid('http://localhost:8083', 'druid/v2/')

def create_client_with_results():
    client = create_client()
    client.query_type = 'timeseries'
    client.result = [
        {'result': {'value1': 1, 'value2': '㬓'}, 'timestamp': '2015-01-01T00:00:00.000-05:00'},
        {'result': {'value1': 2, 'value2': '㬓'}, 'timestamp': '2015-01-02T00:00:00.000-05:00'}
    ]
    return client

def line_ending():
    if PY3:
        return os.linesep
    return "\r\n"

class TestClient:
    def test_build_query(self):
        client = create_client()
        assert client.query_dict == None

        client.build_query({
            'datasource': 'things',
            'aggregations': {
                'count': count('thing'),
            },
            'post_aggregations': {
                'avg': Field('sum') / Field('count'),
            },
            'paging_spec': {
                'pagingIdentifies': {},
                'threshold': 1,
            },
            'filter': Dimension('one') == 1,
            'having': Aggregation('sum') > 1,
            'new_key': 'value',
        })
        expected_query_dict = {
            'queryType': None,
            'dataSource': 'things',
            'aggregations': [{'fieldName': 'thing', 'name': 'count', 'type': 'count'}],
            'postAggregations': [{
                'fields': [{
                    'fieldName': 'sum', 'type': 'fieldAccess',
                }, {
                    'fieldName': 'count', 'type': 'fieldAccess',
                }],
                'fn': '/',
                'name': 'avg',
                'type': 'arithmetic',
            }],
            'pagingSpec': {'pagingIdentifies': {}, 'threshold': 1},
            'filter': {'dimension': 'one', 'type': 'selector', 'value': 1},
            'having': {'aggregation': 'sum', 'type': 'greaterThan', 'value': 1},
            'new_key': 'value',
        }
        assert client.query_dict == expected_query_dict

    def test_validate_query(self):
        client = create_client()
        client.validate_query(['validkey'], {'validkey': 'value'})
        pytest.raises(ValueError, client.validate_query, *[['validkey'], {'invalidkey': 'value'}])

    def test_export_tsv(self, tmpdir):
        client = create_client_with_results()
        file_path = tmpdir.join('out.tsv')
        client.export_tsv(str(file_path))
        assert file_path.read() == "value2\tvalue1\ttimestamp" + line_ending() + "㬓\t1\t2015-01-01T00:00:00.000-05:00" + line_ending() + "㬓\t2\t2015-01-02T00:00:00.000-05:00" + line_ending()

    def test_export_pandas(self):
        client = create_client_with_results()
        df = client.export_pandas()
        expected_df = pandas.DataFrame([{
            'timestamp': '2015-01-01T00:00:00.000-05:00',
            'value1': 1,
            'value2': '㬓',
        }, {
            'timestamp': '2015-01-02T00:00:00.000-05:00',
            'value1': 2,
            'value2': '㬓',
        }])
        assert_frame_equal(df, expected_df)

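Both new test modules lean on pytest's built-in `tmpdir` fixture, which hands every test a fresh temporary directory as a `py.path.local` object, so the TSV files written by the export tests never collide. A minimal sketch of the fixture on its own:

    def test_tmpdir_roundtrip(tmpdir):
        p = tmpdir.join('out.tsv')  # py.path.local; str(p) is the real path
        p.write('value1\t1\n')
        assert p.read() == 'value1\t1\n'
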
39 changes: 39 additions & 0 deletions tests/utils/test_query_utils.py
@@ -0,0 +1,39 @@
# -*- coding: UTF-8 -*-

import os
import pytest
from six import PY3
from pydruid.utils.query_utils import *

def open_file(file_path):
    if PY3:
        f = open(file_path, 'w', newline='', encoding='utf-8')
    else:
        f = open(file_path, 'wb')
    return f

def line_ending():
    if PY3:
        return os.linesep
    return "\r\n"

class TestUnicodeWriter:
    def test_writerow(self, tmpdir):
        file_path = tmpdir.join("out.tsv")
        f = open_file(str(file_path))
        w = UnicodeWriter(f)
        w.writerow(['value1', '㬓'])
        f.close()
        assert file_path.read() == "value1\t㬓" + line_ending()

    def test_writerows(self, tmpdir):
        file_path = tmpdir.join("out.tsv")
        f = open_file(str(file_path))
        w = UnicodeWriter(f)
        w.writerows([
            ['header1', 'header2'],
            ['value1', '㬓']
        ])
        f.close()
        assert file_path.read() == "header1\theader2" + line_ending() + "value1\t㬓" + line_ending()
