Skip to content

Commit

Permalink
Merge pull request pandas-dev#13 from manahl/hooks_and_backports
Browse files Browse the repository at this point in the history
Fixed the hooks and backported some changes
  • Loading branch information
jamesblackburn committed Jul 14, 2015
2 parents 6368fbc + f72b7e3 commit b8668bb
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 38 deletions.
5 changes: 2 additions & 3 deletions arctic/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from .logging import logger


def authenticate(db, user, password):
"""
Return True / False on authentication success.
Expand All @@ -20,9 +19,9 @@ def authenticate(db, user, password):

Credential = namedtuple("MongoCredentials", ['database', 'user', 'password'])


def get_auth(host, app_name, database_name):
    """
    Authentication hook to allow plugging in custom authentication
    credential providers.

    Arguments are forwarded verbatim to the registered hook; returns
    whatever the hook returns (a Credential, or None when no
    credentials apply).
    """
    # Removed the stale unconditional `return None` that made the hook
    # delegation below unreachable dead code.
    # Imported lazily to avoid a circular import between auth and hooks.
    # NOTE(review): implicit relative import (Python 2 style); the file's
    # other imports are explicit-relative (`from .logging import logger`)
    # -- consider `from .hooks import _get_auth_hook` for consistency.
    from hooks import _get_auth_hook
    return _get_auth_hook(host, app_name, database_name)
8 changes: 7 additions & 1 deletion arctic/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

_resolve_mongodb_hook = lambda env: env
_log_exception_hook = lambda *args, **kwargs: None
_get_auth_hook = lambda *args, **kwargs: None


def get_mongodb_uri(host):
Expand All @@ -16,7 +17,7 @@ def get_mongodb_uri(host):

def register_resolve_mongodb_hook(hook):
    """Install *hook* as the resolver used by get_mongodb_uri.

    The hook receives the host/environment string and returns the
    MongoDB URI to connect to.
    """
    global _resolve_mongodb_hook
    # Removed the stale misspelled assignment `_mongodb_resolve_hook = hook`,
    # which created an unused global instead of updating the real hook.
    _resolve_mongodb_hook = hook


def log_exception(fn_name, exception, retry_count, **kwargs):
Expand All @@ -29,3 +30,8 @@ def log_exception(fn_name, exception, retry_count, **kwargs):
def register_log_exception_hook(hook):
    """Install *hook* as the callable invoked by log_exception.

    The hook is called with (fn_name, exception, retry_count, **kwargs).
    """
    global _log_exception_hook
    _log_exception_hook = hook


def register_get_auth_hook(hook):
    """Install *hook* as the credential provider used by get_auth.

    The hook is called with (host, app_name, database_name) and should
    return credentials, or None when none apply.
    """
    global _get_auth_hook
    _get_auth_hook = hook
7 changes: 4 additions & 3 deletions arctic/store/version_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _write_handler(self, version, symbol, data, **kwargs):
handler = self._bson_handler
return handler

def read(self, symbol, as_of=None, from_version=None, **kwargs):
def read(self, symbol, as_of=None, from_version=None, allow_secondary=None, **kwargs):
"""
Read data for the named symbol. Returns a VersionedItem object with
a data and metadata element (as passed into write).
Expand All @@ -292,9 +292,10 @@ def read(self, symbol, as_of=None, from_version=None, **kwargs):
-------
VersionedItem namedtuple which contains a .data and .metadata element
"""
allow_secondary = self._allow_secondary if allow_secondary is None else allow_secondary
try:
_version = self._read_metadata(symbol, as_of=as_of)
read_preference = ReadPreference.NEAREST if self._allow_secondary else None
read_preference = ReadPreference.NEAREST if allow_secondary else ReadPreference.PRIMARY
_version = self._read_metadata(symbol, as_of=as_of, read_preference=read_preference)
return self._do_read(symbol, _version, from_version, read_preference=read_preference, **kwargs)
except (OperationFailure, AutoReconnect) as e:
# Log the exception so we know how often this is happening
Expand Down
40 changes: 25 additions & 15 deletions arctic/tickstore/tickstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def read(self, symbol, date_range=None, columns=None, include_images=False, _tar
for b in self._collection.find(query, projection=projection).sort([(START, pymongo.ASCENDING)],):
data = self._read_bucket(b, column_set, column_dtypes,
multiple_symbols or (columns is not None and 'SYMBOL' in columns),
include_images)
include_images, columns)
for k, v in data.iteritems():
try:
rtn[k].append(v)
Expand Down Expand Up @@ -325,24 +325,35 @@ def _set_or_promote_dtype(self, column_dtypes, c, dtype):
dtype = np.dtype('f8')
column_dtypes[c] = np.promote_types(column_dtypes.get(c, dtype), dtype)

def _prepend_image(self, document, im):
def _prepend_image(self, document, im, rtn_length, column_dtypes, column_set, columns):
image = im[IMAGE]
first_dt = im['t']
if not first_dt.tzinfo:
first_dt = first_dt.replace(tzinfo=mktz('UTC'))
document[INDEX] = np.insert(document[INDEX], 0, np.uint64(datetime_to_ms(first_dt)))
for field in document:
if field == INDEX or document[field] is None:
for field in image:
if field == INDEX:
continue
if field in image:
val = image[field]
else:
logger.debug("Field %s is missing from image!", field)
val = np.nan
if columns and field not in columns:
continue
if field not in document or document[field] is None:
col_dtype = np.dtype(str if isinstance(image[field], basestring) else 'f8')
document[field] = self._empty(rtn_length, dtype=col_dtype)
column_dtypes[field] = col_dtype
column_set.add(field)
val = image[field]
document[field] = np.insert(document[field], 0, document[field].dtype.type(val))
# Now insert rows for fields in document that are not in the image
for field in set(document).difference(set(image)):
if field == INDEX:
continue
logger.debug("Field %s is missing from image!", field)
if document[field] is not None:
val = np.nan
document[field] = np.insert(document[field], 0, document[field].dtype.type(val))
return document

def _read_bucket(self, doc, columns, column_dtypes, include_symbol, include_images):
def _read_bucket(self, doc, column_set, column_dtypes, include_symbol, include_images, columns):
rtn = {}
if doc[VERSION] != 3:
raise ArcticException("Unhandled document version: %s" % doc[VERSION])
Expand All @@ -351,8 +362,8 @@ def _read_bucket(self, doc, columns, column_dtypes, include_symbol, include_imag
rtn_length = len(rtn[INDEX])
if include_symbol:
rtn['SYMBOL'] = [doc[SYMBOL], ] * rtn_length
columns.update(doc[COLUMNS].keys())
for c in columns:
column_set.update(doc[COLUMNS].keys())
for c in column_set:
try:
coldata = doc[COLUMNS][c]
dtype = np.dtype(coldata[DTYPE])
Expand All @@ -366,7 +377,7 @@ def _read_bucket(self, doc, columns, column_dtypes, include_symbol, include_imag
rtn[c] = None

if include_images and doc.get(IMAGE_DOC, {}).get(IMAGE, {}):
rtn = self._prepend_image(rtn, doc[IMAGE_DOC])
rtn = self._prepend_image(rtn, doc[IMAGE_DOC], rtn_length, column_dtypes, column_set, columns)
return rtn

def _empty(self, length, dtype):
Expand Down Expand Up @@ -493,8 +504,7 @@ def _to_dt(self, date, default_tz=None):
elif date.tzinfo is None:
if default_tz is None:
raise ValueError("Must specify a TimeZone on incoming data")
# Treat naive datetimes as London
return date.replace(tzinfo=mktz())
return date.replace(tzinfo=default_tz)
return date

def _str_dtype(self, dtype):
Expand Down
46 changes: 31 additions & 15 deletions tests/integration/tickstore/test_ts_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import pytest
import pytz

from arctic import arctic as m
from arctic.date import DateRange, mktz, CLOSED_CLOSED, CLOSED_OPEN, OPEN_CLOSED, OPEN_OPEN
from arctic.exceptions import OverlappingDataException, NoDataFoundException
from arctic.exceptions import NoDataFoundException


def test_read(tickstore_lib):
Expand Down Expand Up @@ -356,11 +355,11 @@ def test_read_longs(tickstore_lib):
def test_read_with_image(tickstore_lib):
DUMMY_DATA = [
{'a': 1.,
'index': dt(2013, 6, 1, 12, 00, tzinfo=mktz('Europe/London'))
'index': dt(2013, 1, 1, 11, 00, tzinfo=mktz('Europe/London'))
},
{
'b': 4.,
'index': dt(2013, 6, 1, 13, 00, tzinfo=mktz('Europe/London'))
'index': dt(2013, 1, 1, 12, 00, tzinfo=mktz('Europe/London'))
},
]
# Add an image
Expand All @@ -371,21 +370,38 @@ def test_read_with_image(tickstore_lib):
{'a': 37.,
'c': 2.,
},
't': dt(2013, 6, 1, 11, 0)
't': dt(2013, 1, 1, 10, tzinfo=mktz('Europe/London'))
}
}
}
)

tickstore_lib.read('SYM', columns=None)
read = tickstore_lib.read('SYM', columns=None, date_range=DateRange(dt(2013, 6, 1), dt(2013, 6, 2)))
assert read['a'][0] == 1
dr = DateRange(dt(2013, 1, 1), dt(2013, 1, 2))
# tickstore_lib.read('SYM', columns=None)
df = tickstore_lib.read('SYM', columns=None, date_range=dr)
assert df['a'][0] == 1

# Read with the image as well
read = tickstore_lib.read('SYM', columns=None, date_range=DateRange(dt(2013, 6, 1), dt(2013, 6, 2)),
include_images=True)
assert read['a'][0] == 37
assert read['a'][1] == 1
assert np.isnan(read['b'][0])
assert read['b'][2] == 4
assert read.index[0] == dt(2013, 6, 1, 11)
df = tickstore_lib.read('SYM', columns=None, date_range=dr, include_images=True)
assert set(df.columns) == set(('a', 'b', 'c'))
assert_array_equal(df['a'].values, np.array([37, 1, np.nan]))
assert_array_equal(df['b'].values, np.array([np.nan, np.nan, 4]))
assert_array_equal(df['c'].values, np.array([2, np.nan, np.nan]))
assert df.index[0] == dt(2013, 1, 1, 10)
assert df.index[1] == dt(2013, 1, 1, 11)
assert df.index[2] == dt(2013, 1, 1, 12)

df = tickstore_lib.read('SYM', columns=('a', 'b'), date_range=dr, include_images=True)
assert set(df.columns) == set(('a', 'b'))
assert_array_equal(df['a'].values, np.array([37, 1, np.nan]))
assert_array_equal(df['b'].values, np.array([np.nan, np.nan, 4]))
assert df.index[0] == dt(2013, 1, 1, 10)
assert df.index[1] == dt(2013, 1, 1, 11)
assert df.index[2] == dt(2013, 1, 1, 12)

df = tickstore_lib.read('SYM', columns=['c'], date_range=dr, include_images=True)
assert set(df.columns) == set(['c'])
assert_array_equal(df['c'].values, np.array([2, np.nan, np.nan]))
assert df.index[0] == dt(2013, 1, 1, 10)
assert df.index[1] == dt(2013, 1, 1, 11)
assert df.index[2] == dt(2013, 1, 1, 12)
29 changes: 28 additions & 1 deletion tests/unit/store/test_version_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import pymongo
from pymongo import ReadPreference
from pymongo import ReadPreference, read_preferences

from arctic.date import mktz
from arctic.store import version_store
Expand Down Expand Up @@ -44,6 +44,33 @@ def test_list_versions_localTime():
'snapshots': 'snap'}


def test_get_version_allow_secondary_True():
    """When _allow_secondary is set, read() uses the NEAREST read preference."""
    store = create_autospec(VersionStore, instance=True, _versions=Mock())
    store._allow_secondary = True
    store._find_snapshots.return_value = 'snap'
    version_doc = {'_id': bson.ObjectId.from_datetime(dt(2013, 4, 1, 9, 0)),
                   'symbol': 's', 'version': 10}
    store._versions.find.return_value = [version_doc]

    VersionStore.read(store, "symbol")

    expected_pref = ReadPreference.NEAREST
    assert store._read_metadata.call_args_list == \
        [call('symbol', as_of=None, read_preference=expected_pref)]
    assert store._do_read.call_args_list == \
        [call('symbol', store._read_metadata.return_value, None,
              read_preference=expected_pref)]


def test_get_version_allow_secondary_user_override_False():
    """Ensure user can override read preference when calling read"""
    store = create_autospec(VersionStore, instance=True, _versions=Mock())
    store._allow_secondary = True
    store._find_snapshots.return_value = 'snap'
    version_doc = {'_id': bson.ObjectId.from_datetime(dt(2013, 4, 1, 9, 0)),
                   'symbol': 's', 'version': 10}
    store._versions.find.return_value = [version_doc]

    VersionStore.read(store, "symbol", allow_secondary=False)

    # allow_secondary=False must win over _allow_secondary=True.
    expected_pref = ReadPreference.PRIMARY
    assert store._read_metadata.call_args_list == \
        [call('symbol', as_of=None, read_preference=expected_pref)]
    assert store._do_read.call_args_list == \
        [call('symbol', store._read_metadata.return_value, None,
              read_preference=expected_pref)]


def test_read_as_of_LondonTime():
# When we do a read, with naive as_of, that as_of is treated in London Time.
vs = create_autospec(VersionStore, instance=True,
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from mock import sentinel, call, Mock
from arctic.hooks import register_get_auth_hook, register_log_exception_hook, \
register_resolve_mongodb_hook, get_mongodb_uri, log_exception
from arctic.auth import get_auth


def test_log_exception_hook():
    """A registered hook receives exactly the arguments passed to log_exception."""
    hook = Mock()
    register_log_exception_hook(hook)
    log_exception(sentinel.fn, sentinel.e, sentinel.r)
    expected = [call(sentinel.fn, sentinel.e, sentinel.r)]
    assert hook.call_args_list == expected


def test_get_mongodb_uri_hook():
    """get_mongodb_uri delegates host resolution to the registered hook."""
    resolver = Mock(return_value=sentinel.result)
    register_resolve_mongodb_hook(resolver)
    assert get_mongodb_uri(sentinel.host) == sentinel.result
    assert resolver.call_args_list == [call(sentinel.host)]


def test_get_auth_hook():
    """get_auth forwards host/app/database to the registered auth hook."""
    provider = Mock()
    register_get_auth_hook(provider)
    get_auth(sentinel.host, sentinel.app_name, sentinel.database_name)
    expected = [call(sentinel.host, sentinel.app_name, sentinel.database_name)]
    assert provider.call_args_list == expected

0 comments on commit b8668bb

Please sign in to comment.