diff --git a/tstools/src/ts_driver/drivers/__init__.py b/tstools/src/ts_driver/drivers/__init__.py index 2af0823..8e18324 100644 --- a/tstools/src/ts_driver/drivers/__init__.py +++ b/tstools/src/ts_driver/drivers/__init__.py @@ -1,2 +1,14 @@ """ Timeseries drivers """ +from collections import OrderedDict + +DRIVERS = OrderedDict(( + ('StackedTimeSeries', 'timeseries_stacked'), + ('CCDCTimeSeries', 'timeseries_ccdc'), + ('YATSMTimeSeries', 'timeseries_yatsm'), + ('YATSMMetTimeSeries', 'timeseries_yatsm_met'), + ('YATSMLandsatPALSARTS', 'timeseries_opticalradar'), +)) + +for name, val in DRIVERS.items(): + DRIVERS[name] = 'tstools.ts_driver.drivers.' + val diff --git a/tstools/src/ts_driver/ts_manager.py b/tstools/src/ts_driver/ts_manager.py index 5580aaf..5706e11 100644 --- a/tstools/src/ts_driver/ts_manager.py +++ b/tstools/src/ts_driver/ts_manager.py @@ -1,13 +1,13 @@ """ Find, detect, and make available timeseries drivers implementations -Timeseries drivers must inherit from the Abstract Base Class -"AbstractTimeSeriesDriver" to be detected. +Timeseries drivers must be enumerated in tstools.ts_drivers.drivers.DRIVERS +or be a part of the TSTools.drivers entry point to be detected. """ import importlib -import os -import pkgutil +from pkg_resources import iter_entry_points import sys +from .drivers import DRIVERS from ..logger import logger @@ -24,6 +24,7 @@ class BrokenModule(object): """ def __init__(self, module, message): self.__doc__ = self.__doc__ % (module, message) + self.description = 'Broken: %s' % module class TSManager(object): @@ -35,70 +36,44 @@ class TSManager(object): ts = None def __init__(self, location=None): - # Location of timeseires modules - self.plugin_dir = [] # All available timeseries self.ts_drivers = [] - - if location and os.path.isdir(location): - self.plugin_dir.append(location) - - file_location = os.path.join(os.path.dirname(__file__), 'drivers') - self.plugin_dir.append('./' if file_location == '' else file_location) - self.find_timeseries() def find_timeseries(self): """ Try to find timeseries classes """ - try: - from . import timeseries - except ImportError: - logger.critical('Could not import "timeseries". Check your path') - raise - else: - logger.debug('Found "timeseries" module') - broken = [] - # Use pkgutil to search for timeseries - logger.debug('Module name: {n}'.format(n=__name__)) - for loader, modname, ispkg in pkgutil.walk_packages(self.plugin_dir): - full_path = '%s.drivers.%s' % (__name__.rsplit('.', 1)[0], modname) + for name, import_path in DRIVERS.items(): try: - importlib.import_module(full_path) - except ImportError as e: - logger.error('Cannot import %s: %s' % (modname, e.message)) - broken_module = BrokenModule(modname, - e.args[0] if e.args else - 'Unknown import error') - broken_module.description = 'Broken: %s' % modname + driver = getattr(importlib.import_module(import_path), name) + self.ts_drivers.append(driver) + except ImportError as exc: + logger.error('Cannot import %s: %s' % (name, exc)) + broken_module = BrokenModule(name, exc) broken.append(broken_module) except: logger.error('Cannot import %s: %s' % - (modname, sys.exc_info()[0])) + (name, sys.exc_info()[0])) + + for plugin in iter_entry_points('TSTools.drivers'): + try: + driver = plugin.load() + self.ts_drivers.append(driver) + except ImportError as exc: + logger.error('Cannot import %s: %s' % (plugin.name, exc)) + broken_module = BrokenModule(plugin.name, exc) + broken.append(broken_module) + except: + logger.error('Cannot import %s: %s' % + (plugin.name, sys.exc_info()[0])) raise - self.ts_drivers = timeseries.AbstractTimeSeriesDriver.__subclasses__() for tsd in self.ts_drivers: logger.info('Found driver: {tsd}'.format(tsd=tsd)) - # Find even more descendents - for subclass in self.ts_drivers: - self.recursive_find_subclass(subclass) - self.ts_drivers.extend(broken) - def recursive_find_subclass(self, subclass): - """ Search subclass for descendents """ - - sub_subclasses = subclass.__subclasses__() - - for sub_subclass in sub_subclasses: - if sub_subclass not in self.ts_drivers: - self.ts_drivers.append(sub_subclass) - logger.info('Found driver: {tsd}'.format(tsd=sub_subclass)) - self.recursive_find_subclass(sub_subclass) - # Store timeseries manager tsm = TSManager()