Skip to content
This repository has been archived by the owner on Jun 14, 2019. It is now read-only.

Commit

Permalink
Specify drivers manually & via setuptools
Browse files Browse the repository at this point in the history
No more recursive import -- should save some load time
and is more straightforward for drivers distributed with
TSTools. Avoids class inheritance gore. Setuptools
iter_entry_points allows for driver plugins to plugin
under 'TSTools.drivers'
  • Loading branch information
ceholden committed Aug 24, 2016
1 parent a6ae437 commit 7ec87bd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 49 deletions.
12 changes: 12 additions & 0 deletions tstools/src/ts_driver/drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
""" Timeseries drivers
"""
from collections import OrderedDict

DRIVERS = OrderedDict((
('StackedTimeSeries', 'timeseries_stacked'),
('CCDCTimeSeries', 'timeseries_ccdc'),
('YATSMTimeSeries', 'timeseries_yatsm'),
('YATSMMetTimeSeries', 'timeseries_yatsm_met'),
('YATSMLandsatPALSARTS', 'timeseries_opticalradar'),
))

for name, val in DRIVERS.items():
DRIVERS[name] = 'tstools.ts_driver.drivers.' + val
73 changes: 24 additions & 49 deletions tstools/src/ts_driver/ts_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
""" Find, detect, and make available timeseries drivers implementations
Timeseries drivers must inherit from the Abstract Base Class
"AbstractTimeSeriesDriver" to be detected.
Timeseries drivers must be enumerated in tstools.ts_drivers.drivers.DRIVERS
or be a part of the TSTools.drivers entry point to be detected.
"""
import importlib
import os
import pkgutil
from pkg_resources import iter_entry_points
import sys

from .drivers import DRIVERS
from ..logger import logger


Expand All @@ -24,6 +24,7 @@ class BrokenModule(object):
"""
def __init__(self, module, message):
self.__doc__ = self.__doc__ % (module, message)
self.description = 'Broken: %s' % module


class TSManager(object):
Expand All @@ -35,70 +36,44 @@ class TSManager(object):
ts = None

def __init__(self, location=None):
# Location of timeseires modules
self.plugin_dir = []
# All available timeseries
self.ts_drivers = []

if location and os.path.isdir(location):
self.plugin_dir.append(location)

file_location = os.path.join(os.path.dirname(__file__), 'drivers')
self.plugin_dir.append('./' if file_location == '' else file_location)

self.find_timeseries()

def find_timeseries(self):
""" Try to find timeseries classes """
try:
from . import timeseries
except ImportError:
logger.critical('Could not import "timeseries". Check your path')
raise
else:
logger.debug('Found "timeseries" module')

broken = []

# Use pkgutil to search for timeseries
logger.debug('Module name: {n}'.format(n=__name__))
for loader, modname, ispkg in pkgutil.walk_packages(self.plugin_dir):
full_path = '%s.drivers.%s' % (__name__.rsplit('.', 1)[0], modname)
for name, import_path in DRIVERS.items():
try:
importlib.import_module(full_path)
except ImportError as e:
logger.error('Cannot import %s: %s' % (modname, e.message))
broken_module = BrokenModule(modname,
e.args[0] if e.args else
'Unknown import error')
broken_module.description = 'Broken: %s' % modname
driver = getattr(importlib.import_module(import_path), name)
self.ts_drivers.append(driver)
except ImportError as exc:
logger.error('Cannot import %s: %s' % (name, exc))
broken_module = BrokenModule(name, exc)
broken.append(broken_module)
except:
logger.error('Cannot import %s: %s' %
(modname, sys.exc_info()[0]))
(name, sys.exc_info()[0]))

for plugin in iter_entry_points('TSTools.drivers'):
try:
driver = plugin.load()
self.ts_drivers.append(driver)
except ImportError as exc:
logger.error('Cannot import %s: %s' % (plugin.name, exc))
broken_module = BrokenModule(plugin.name, exc)
broken.append(broken_module)
except:
logger.error('Cannot import %s: %s' %
(plugin.name, sys.exc_info()[0]))
raise

self.ts_drivers = timeseries.AbstractTimeSeriesDriver.__subclasses__()
for tsd in self.ts_drivers:
logger.info('Found driver: {tsd}'.format(tsd=tsd))

# Find even more descendents
for subclass in self.ts_drivers:
self.recursive_find_subclass(subclass)

self.ts_drivers.extend(broken)

def recursive_find_subclass(self, subclass):
""" Search subclass for descendents """

sub_subclasses = subclass.__subclasses__()

for sub_subclass in sub_subclasses:
if sub_subclass not in self.ts_drivers:
self.ts_drivers.append(sub_subclass)
logger.info('Found driver: {tsd}'.format(tsd=sub_subclass))
self.recursive_find_subclass(sub_subclass)


# Store timeseries manager
tsm = TSManager()
Expand Down

0 comments on commit 7ec87bd

Please sign in to comment.