diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..d948a9e --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,27 @@ +name: Publish tagged releases to PyPI + +on: + push: + tags: + - "v*" + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v1 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: '3.9' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + - name: Build and publish + env: + HATCH_INDEX_USER: __token__ + HATCH_INDEX_AUTH: ${{ secrets.PYPI_PUBLISH_TOKEN }} + run: | + hatch build + hatch publish diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..9e9a6af --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,39 @@ +name: Run Tests + +on: + push: + branches: + - "**" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + fail-fast: false + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch hatch-vcs + - name: Run tests + run: hatch run tests + - if: matrix.python-version == '3.9' + name: Lint + run: hatch run lint:check + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1.0.10 + with: + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON diff --git a/.gitignore b/.gitignore index 087f133..c6f2a06 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ var/ *.egg-info/ .installed.cfg *.egg +_version_info.py # PyInstaller # Usually these files are written by a python script from a template diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 1689ff5..0000000 --- a/.travis.yml +++ /dev/null @@ -1,27 +0,0 @@ -language: python -python: - - "2.7" - - "3.6" - - "3.7" - - "3.8" - -before_install: - - pip install flake8 wheel tox-travis - -install: - - pip install ".[all]" - -script: - - "tox" - -# Deploy to pypi automatically from tagged releases on the stable branch -deploy: - provider: pypi - user: mister.wardrop - password: - secure: "rCsAx55Z/IS8MMJi5PlMQLg8bPHC7ava19+hU0I8AYg1lY+yxBisO48i54g4veMISR2I0HUqoK7WSFxPfOceLCYkTNV6XT8RkHjBrsmpdwS1GdeLBPaI+QOe5AfR/G2F3KXRRD2NpyIkzeo3wS7n3+sqkGBSNB1+4KtDu8PUqWKk3mv1CqKLS6ZHwwrRvk9K4oF3AD70qpu2CaPUT08xNU50tFAuzz7bwG3DIOWxzr7YefewwjiaGlsffMTp1sJxlmTYtKL1inCI7VDk3Xxymq4GQu0gIW3zWIjvZEuJlAnYoq1JhR1Btw8KNF23uz4PutbkOHq+c3j2ZyOsyzDz2pK/ywS6DnMIHpbceHH+JnOstxts05IOBIP++Wti0lfIXuOMh/lKWUNkW1KrOVr3Qz4A0UMgbJ82FzTsE2Ei/ShgLeVhpvYgN/ZaJB5g9HmL9HQAWsgHZACa2BHM90SidwZtTY1qY2Hxp9Rvj6gg3Q59CTIfTYfWOnobgiHL6ClOOk0oES36KrwU/AVOxaH1dddtWkLbX4qG0/Ur+NY7iGgD+GIwH/fAxxuTyZ81+jQjhLXpMYr0HewYI0MruqEPhekbJffqh5xRlL+S/A9Xf4+X6Ox7PKn+MQBPPwps8ZTKT8TlqUM/VgtVLBMr660kLS2Jr0LXwSUJF2JvUTOsFww=" - on: - tags: true - python: "3.6" - distributions: sdist bdist_wheel - repo: airbnb/omniduct diff --git a/README.md b/README.md index 5570e51..dc0e932 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![PyPI - Version](https://img.shields.io/pypi/v/omniduct.svg)](https://pypi.org/project/omniduct/) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/omniduct.svg) ![PyPI - 
Status](https://img.shields.io/pypi/status/omniduct.svg) -[![Build Status](https://travis-ci.org/airbnb/omniduct.svg?branch=master)](https://travis-ci.org/airbnb/omniduct) +[![Build Status](https://img.shields.io/github/actions/workflow/status/airbnb/omniduct/tests.yml?branch=main)](https://github.com/airbnb/omniduct/actions?query=workflow%3A%22Run+Tests%22) [![Documentation Status](https://readthedocs.org/projects/omniduct/badge/?version=latest)](http://omniduct.readthedocs.io/en/latest/?badge=latest) @@ -19,6 +19,3 @@ It provides: - Automatic port forwarding of remote services over SSH where connections cannot be made directly. - Convenient IPython magic functions for interfacing with data providers from within IPython and Jupyter Notebook sessions. - Utility classes and methods to assist in maintaining registries of useful services. - -**Note:** Omniduct 1.1.x is the last version series to support Python 2. Going -forward it will support Python 3.6+. diff --git a/omniduct/__init__.py b/omniduct/__init__.py index 5f0140e..437e5eb 100644 --- a/omniduct/__init__.py +++ b/omniduct/__init__.py @@ -17,17 +17,19 @@ def about(): "Omniduct", version=__version__, logo=__logo__, - maintainers=OrderedDict(zip( - [a.strip() for a in __author__.split(',')], - [a.strip() for a in __author_email__.split(',')] - )), + maintainers=OrderedDict( + zip( + [a.strip() for a in __author__.split(",")], + [a.strip() for a in __author_email__.split(",")], + ) + ), attributes={ - 'Documentation': __docs_url__, + "Documentation": __docs_url__, }, description=""" Omniduct provides uniform interfaces for connecting to and extracting data from a wide variety of (potentially remote) data stores (including HDFS, Hive, Presto, MySQL, etc). """, - endorse_omniduct=False + endorse_omniduct=False, ) diff --git a/omniduct/_version.py b/omniduct/_version.py index 8fa6e7b..522edb8 100644 --- a/omniduct/_version.py +++ b/omniduct/_version.py @@ -1,107 +1,102 @@ import os import sys -__all__ = ['__author__', '__author_email__', '__version__', '__logo__', '__docs_url__'] +try: + from ._version_info import __version__, __version_tuple__ +except ImportError: + __version__ = "unknown" + __version_tuple__ = (0, 0, 0, "+unknown") + +__all__ = [ + "__author__", + "__author_email__", + "__version__", + "__version_tuple__", + "__logo__", + "__docs_url__", +] __author__ = "Matthew Wardrop, Dan Frank" __author_email__ = "mpwardrop@gmail.com, danfrankj@gmail.com" -__version__ = "1.1.19" -__logo__ = os.path.join(os.path.dirname(__file__), 'logo.png') if '__file__' in globals() else None +__logo__ = ( + os.path.join(os.path.dirname(__file__), "logo.png") + if "__file__" in globals() + else None +) __docs_url__ = "https://omniduct.readthedocs.io/" # These are the core dependencies, and should not include those which are used only in handling specific protocols. 
# Order matters since installation happens from the end of the list __dependencies__ = [ - "future", # Python 2/3 support - "six", # Python 2/3 support - "interface_meta>=1.1.0,<2", # Metaclass for creating an extensible well-documented architecture "pyyaml", # YAML configuration parsing "decorator", # Decorators used by caching and documentation routines "progressbar2>=3.30.0", # Support for progressbars in logging routines "wrapt", # Object proxying for conveniently exposing ducts in registry - # Database querying libraries "jinja2", # Templating support in databases - "pandas>=0.17.1", # Various results including database queries are returned as pandas dataframes + "pandas>=0.20.3", # Various results including database queries are returned as pandas dataframes "sqlparse", # Neatening of SQL based queries (mainly to avoid missing the cache) "sqlalchemy", # Various integration endpoints in the database stack - # Utility libraries "python-dateutil", # Used for its `relativedelta` class for Cache instances "lazy-object-proxy", # Schema traversal ] -if sys.version_info.major < 3 or sys.version_info.major == 3 and sys.version_info.minor < 4: - __dependencies__.append("enum34") # Python 3.4+ style enums in older versions of python - -PY2 = sys.version_info[0] == 2 -if os.name == 'posix' and PY2: - __dependencies__.append('subprocess32') # Python 3.2+ subprocess handling for Python 2 __optional_dependencies__ = { # Databases - 'druid': [ - 'pydruid>=0.4.0', # Primary client + "druid": [ + "pydruid>=0.4.0", # Primary client ], - - 'hiveserver2': [ - 'pyhive[hive]>=0.4', # Primary client - 'thrift>=0.10.0', # Thrift dependency which seems not to be installed with upstream deps + "hiveserver2": [ + "pyhive[hive]>=0.4", # Primary client + "thrift>=0.10.0", # Thrift dependency which seems not to be installed with upstream deps ], - - 'presto': [ - 'pyhive[presto]>=0.4', # Primary client + "presto": [ + "pyhive[presto]>=0.4", # Primary client ], - - 'pyspark': [ - 'pyspark', # Primary client + "pyspark": [ + "pyspark", # Primary client ], - - 'snowflake': [ - 'snowflake-sqlalchemy', + "snowflake": [ + "snowflake-sqlalchemy", ], - - 'exasol': ['pyexasol'] if sys.version_info.major > 2 else [], - + "exasol": ["pyexasol"] if sys.version_info.major > 2 else [], # Filesystems - 'webhdfs': [ - 'pywebhdfs', # Primary client - 'requests', # For rerouting redirect queries to our port-forwarded services + "webhdfs": [ + "pywebhdfs", # Primary client + "requests", # For rerouting redirect queries to our port-forwarded services ], - - 's3': [ - 'boto3', # AWS client library + "s3": [ + "boto3", # AWS client library ], - # Remotes - 'ssh': [ - 'pexpect', # Command line handling (including smartcard activation) + "ssh": [ + "pexpect", # Command line handling (including smartcard activation) ], - - 'ssh_paramiko': [ - 'paramiko', # Primary client - 'pexpect', # Command line handling (including smartcard activation) + "ssh_paramiko": [ + "paramiko", # Primary client + "pexpect", # Command line handling (including smartcard activation) ], - # Rest clients - 'rest': [ - 'requests', # Library to handle underlying REST queries + "rest": [ + "requests", # Library to handle underlying REST queries ], - # Documentation requirements - 'docs': [ - 'sphinx', # The documentation engine - 'sphinx_autobuild', # A Sphinx plugin used during development of docs - 'sphinx_rtd_theme', # The Spinx theme used by the docs + "docs": [ + "sphinx", # The documentation engine + "sphinx_autobuild", # A Sphinx plugin used during development 
of docs + "sphinx_rtd_theme", # The Sphinx theme used by the docs + ], + "test": [ + "nose", # test runner + "mock", # mocking + "pyfakefs", # mock filesystem + "coverage", # test coverage monitoring + "flake8", # Code linting ], - - 'test': [ - 'nose', # test runner - 'mock', # mocking - 'pyfakefs', # mock filesystem - 'coverage', # test coverage monitoring - 'flake8', # Code linting - ] } -__optional_dependencies__['all'] = [dep for deps in __optional_dependencies__.values() for dep in deps] +__optional_dependencies__["all"] = [ + dep for deps in __optional_dependencies__.values() for dep in deps +] diff --git a/omniduct/caches/_serializers.py b/omniduct/caches/_serializers.py index 5ff363c..a512151 100644 --- a/omniduct/caches/_serializers.py +++ b/omniduct/caches/_serializers.py @@ -1,11 +1,9 @@ import pickle -from distutils.version import LooseVersion import pandas -class Serializer(object): - +class Serializer: @property def file_extension(self): return "" @@ -18,13 +16,14 @@ def deserialize(self, fh): class BytesSerializer(Serializer): - @property def file_extension(self): return ".bytes" def serialize(self, obj, fh): - assert isinstance(obj, bytes), "BytesSerializer requires incoming data be already encoded into a bytestring." + assert isinstance( + obj, bytes + ), "BytesSerializer requires incoming data be already encoded into a bytestring." fh.write(obj) def deserialize(self, fh): @@ -32,7 +31,6 @@ def deserialize(self, fh): class PickleSerializer(Serializer): - @property def file_extension(self): return ".pickle" @@ -45,19 +43,12 @@ def deserialize(self, fh): class PandasSerializer(Serializer): - @property def file_extension(self): return ".pandas" - @classmethod - def serialize(cls, formatted_data, fh): - # compat: if pandas is old, to_pickle does not accept file handles - if LooseVersion(pandas.__version__) <= LooseVersion('0.20.3'): - fh.close() - fh = fh.name - return pandas.to_pickle(formatted_data, fh, compression=None) + def serialize(self, obj, fh): + return pandas.to_pickle(obj, fh, compression=None) - @classmethod - def deserialize(cls, fh): + def deserialize(self, fh): return pandas.read_pickle(fh, compression=None) diff --git a/omniduct/caches/base.py b/omniduct/caches/base.py index 78c4992..08ce017 100644 --- a/omniduct/caches/base.py +++ b/omniduct/caches/base.py @@ -1,11 +1,9 @@ import datetime import functools -import sys from abc import abstractmethod import dateutil import pandas -import six import yaml from decorator import decorator from interface_meta import quirk_docs @@ -18,22 +16,22 @@ from ._serializers import PickleSerializer config.register( - 'cache_fail_hard', - description='Raise an exception if a cache fails to save (otherwise errors are logged and suppressed).', - default=False + "cache_fail_hard", + description="Raise an exception if a cache fails to save (otherwise errors are logged and suppressed).", + default=False, ) def cached_method( - key, - namespace=lambda self, kwargs: ( - self.cache_namespace or "{}.{}".format(self.__class__.__name__, self.name) - ), - cache=lambda self, kwargs: self.cache, - use_cache=lambda self, kwargs: kwargs.pop('use_cache', True), - renew=lambda self, kwargs: kwargs.pop('renew', False), - serializer=lambda self, kwargs: PickleSerializer(), - metadata=lambda self, kwargs: None + key, + namespace=lambda self, kwargs: ( + self.cache_namespace or f"{self.__class__.__name__}.{self.name}" + ), + cache=lambda self, kwargs: self.cache, + use_cache=lambda self, kwargs: kwargs.pop("use_cache", True), + renew=lambda self, 
kwargs: kwargs.pop("renew", False), + serializer=lambda self, kwargs: PickleSerializer(), + metadata=lambda self, kwargs: None, ): """ Wrap a method of a `Duct` class and add caching capabilities. @@ -73,7 +71,7 @@ def cached_method( @decorator def wrapped(method, self, *args, **kwargs): kwargs = function_args_as_kwargs(method, self, *args, **kwargs) - kwargs.pop('self') + kwargs.pop("self") _key = key(self, kwargs) _namespace = namespace(self, kwargs) @@ -86,25 +84,26 @@ def wrapped(method, self, *args, **kwargs): if _cache is None or not _use_cache: return method(self, **kwargs) - if _cache.has_key(_key, namespace=_namespace) and not _renew: # noqa: has_key is not of a dictionary here + if ( + _cache.has_key(_key, namespace=_namespace) and not _renew + ): # noqa: has_key is not of a dictionary here try: - return _cache.get( - _key, - namespace=_namespace, - serializer=_serializer + return _cache.get(_key, namespace=_namespace, serializer=_serializer) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + "Failed to retrieve results from cache [%s]. Renewing the cache...", + e, ) - except: - logger.warning("Failed to retrieve results from cache. Renewing the cache...") if config.cache_fail_hard: - six.reraise(*sys.exc_info()) + raise finally: - logger.caveat('Loaded from cache') + logger.caveat("Loaded from cache") # Renewing/creating cache value = method(self, **kwargs) if value is None: logger.warning("Method value returned None. Not saving to cache.") - return + return None try: _cache.set( @@ -112,19 +111,17 @@ def wrapped(method, self, *args, **kwargs): value=value, namespace=_namespace, serializer=_serializer, - metadata=_metadata + metadata=_metadata, ) # Return from cache every time, just in case serialization operation was # destructive (e.g. reading from cursors) - return _cache.get( - _key, - namespace=_namespace, - serializer=_serializer + return _cache.get(_key, namespace=_namespace, serializer=_serializer) + except: # pylint: disable=bare-except + logger.warning( + "Failed to save results to cache. If needed, please save them manually." ) - except: - logger.warning("Failed to save results to cache. If needed, please save them manually.") if config.cache_fail_hard: - six.reraise(*sys.exc_info()) + raise return value # As a last resort, return value object (which could be mutated by serialization). 
return wrapped @@ -137,8 +134,8 @@ class Cache(Duct): DUCT_TYPE = Duct.Type.CACHE - @quirk_docs('_init', mro=True) - def __init__(self, **kwargs): + @quirk_docs("_init", mro=True) + def __init__(self, **kwargs): # pylint: disable=super-init-not-called Duct.__init_with_kwargs__(self, kwargs) self._init(**kwargs) @@ -165,12 +162,18 @@ def set(self, key, value, namespace=None, serializer=None, metadata=None): namespace, key = self._namespace(namespace), self._key(key) serializer = serializer or PickleSerializer() try: - with self._get_stream_for_key(namespace, key, 'data{}'.format(serializer.file_extension), mode='wb', create=True) as fh: + with self._get_stream_for_key( + namespace, + key, + f"data{serializer.file_extension}", + mode="wb", + create=True, + ) as fh: serializer.serialize(value, fh) self.set_metadata(key, metadata, namespace=namespace, replace=True) - except: + except: # pylint: disable=bare-except self.unset(key, namespace=namespace) - six.reraise(*sys.exc_info()) + raise @require_connection def set_metadata(self, key, metadata, namespace=None, replace=False): @@ -189,13 +192,15 @@ def set_metadata(self, key, metadata, namespace=None, replace=False): """ namespace, key = self._namespace(namespace), self._key(key) if replace: - orig_metadata = {'created': datetime.datetime.utcnow()} + orig_metadata = {"created": datetime.datetime.utcnow()} else: orig_metadata = self.get_metadata(key, namespace=namespace) orig_metadata.update(metadata or {}) - with self._get_stream_for_key(namespace, key, 'metadata', mode='w', create=True) as fh: + with self._get_stream_for_key( + namespace, key, "metadata", mode="w", create=True + ) as fh: yaml.safe_dump(orig_metadata, fh, default_flow_style=False) @require_connection @@ -215,12 +220,22 @@ def get(self, key, namespace=None, serializer=None): namespace, key = self._namespace(namespace), self._key(key) serializer = serializer or PickleSerializer() if not self._has_key(namespace, key): - raise KeyError("{} (namespace: {})".format(key, namespace)) + raise KeyError(f"{key} (namespace: {namespace})") try: - with self._get_stream_for_key(namespace, key, 'data{}'.format(serializer.file_extension), mode='rb', create=False) as fh: + with self._get_stream_for_key( + namespace, + key, + f"data{serializer.file_extension}", + mode="rb", + create=False, + ) as fh: return serializer.deserialize(fh) finally: - self.set_metadata(key, namespace=namespace, metadata={'last_accessed': datetime.datetime.utcnow()}) + self.set_metadata( + key, + namespace=namespace, + metadata={"last_accessed": datetime.datetime.utcnow()}, + ) @require_connection def get_bytecount(self, key, namespace=None): @@ -240,7 +255,7 @@ def get_bytecount(self, key, namespace=None): """ namespace, key = self._namespace(namespace), self._key(key) if not self._has_key(namespace, key): - raise KeyError("{} (namespace: {})".format(key, namespace)) + raise KeyError(f"{key} (namespace: {namespace})") return self._get_bytecount_for_key(namespace, key) @require_connection @@ -257,11 +272,13 @@ def get_metadata(self, key, namespace=None): """ namespace, key = self._namespace(namespace), self._key(key) if not self._has_key(namespace, key): - raise KeyError("{} (namespace: {})".format(key, namespace)) + raise KeyError(f"{key} (namespace: {namespace})") try: - with self._get_stream_for_key(namespace, key, 'metadata', mode='r', create=False) as fh: + with self._get_stream_for_key( + namespace, key, "metadata", mode="r", create=False + ) as fh: return yaml.safe_load(fh) - except: + except: # pylint: 
disable=bare-except return {} @require_connection @@ -275,7 +292,7 @@ def unset(self, key, namespace=None): """ namespace, key = self._namespace(namespace), self._key(key) if not self._has_key(namespace, key): - raise KeyError("{} (namespace: {})".format(key, namespace)) + raise KeyError(f"{key} (namespace: {namespace})") self._remove_key(namespace, key) @require_connection @@ -288,7 +305,7 @@ def unset_namespace(self, namespace=None): """ namespace = self._namespace(namespace) if not self._has_namespace(namespace): - raise KeyError("namespace: {}".format(namespace)) + raise KeyError(f"namespace: {namespace}") self._remove_namespace(namespace) # Top-level descriptions @@ -392,34 +409,37 @@ def describe(self, namespaces=None): for namespace in namespaces: for key in self.keys(namespace=namespace): usage = { - 'bytes': self.get_bytecount(key, namespace=namespace), - 'namespace': namespace, - 'key': key, - 'created': None, - 'last_accessed': None + "bytes": self.get_bytecount(key, namespace=namespace), + "namespace": namespace, + "key": key, + "created": None, + "last_accessed": None, } usage.update(self.get_metadata(key, namespace=namespace)) out.append(usage) - required_columns = ['bytes', 'namespace', 'key', 'created', 'last_accessed'] + required_columns = ["bytes", "namespace", "key", "created", "last_accessed"] if out: df = pandas.DataFrame(out) - order = required_columns + sorted(set(df.columns).difference(required_columns)) - return ( - df - .sort_values('last_accessed', ascending=False) - .reset_index(drop=True) - [order] + order = required_columns + sorted( + set(df.columns).difference(required_columns) ) + return df.sort_values("last_accessed", ascending=False).reset_index( + drop=True + )[order] - return pandas.DataFrame( - data=[], - columns=required_columns - ) + return pandas.DataFrame(data=[], columns=required_columns) # Cache pruning - def prune(self, namespaces=None, max_age=None, max_bytes=None, total_bytes=None, total_count=None): + def prune( + self, + namespaces=None, + max_age=None, + max_bytes=None, + total_bytes=None, + total_count=None, + ): """ Remove keys from the cache in order to satisfy nominated constraints. @@ -441,7 +461,9 @@ def prune(self, namespaces=None, max_age=None, max_bytes=None, total_bytes=None, constraint will be applied after max_age and max_bytes. 
""" usage = self.describe(namespaces=namespaces) - if usage.shape[0] == 0: # Abort early if the cache is empty (and hence has no index, which would cause problems later on) + if ( + usage.shape[0] == 0 + ): # Abort early if the cache is empty (and hence has no index, which would cause problems later on) return constraints = [] @@ -450,36 +472,54 @@ def prune(self, namespaces=None, max_age=None, max_bytes=None, total_bytes=None, if max_age is not None: if isinstance(max_age, int): max_age = datetime.timedelta(max_age) - if isinstance(max_age, (datetime.timedelta, dateutil.relativedelta.relativedelta)): + if isinstance( + max_age, (datetime.timedelta, dateutil.relativedelta.relativedelta) + ): max_age = datetime.datetime.now() - max_age if not isinstance(max_age, (datetime.datetime, datetime.date)): - raise ValueError("Invalid type specified for `max_age`: {}".format(max_age.__repr__())) + raise ValueError( + f"Invalid type specified for `max_age`: {repr(max_age)}" + ) constraints.append(usage.last_accessed < max_age) if max_bytes is not None: if not isinstance(max_bytes, int): - raise ValueError("Invalid type specified for `max_bytes`: {}".format(max_bytes.__repr__())) + raise ValueError( + f"Invalid type specified for `max_bytes`: {repr(max_bytes)}" + ) constraints.append(usage.bytes > max_bytes) if constraints: to_unset = usage[functools.reduce(lambda x, y: x | y, constraints, False)] - for i, row in to_unset.iterrows(): - logger.info("Unsetting key '{}' (namespace: '{}')...".format(row.key, row.namespace)) + for _, row in to_unset.iterrows(): + logger.info( + f"Unsetting key '{row.key}' (namespace: '{row.namespace}')..." + ) self.unset(row.key, namespace=row.namespace) # Unset keys according to global constraints if total_bytes is not None or total_count is not None: if total_bytes is not None and not isinstance(total_bytes, int): - raise ValueError("Invalid type specified for `total_bytes`: {}".format(total_bytes.__repr__())) + raise ValueError( + f"Invalid type specified for `total_bytes`: {repr(total_bytes)}" + ) if total_count is not None and not isinstance(total_count, int): - raise ValueError("Invalid type specified for `total_count`: {}".format(total_bytes.__repr__())) - usage = self.describe(namespaces=namespaces).assign(cum_bytes=lambda x: x.bytes.cumsum()) + raise ValueError( + f"Invalid type specified for `total_count`: {repr(total_bytes)}" + ) + usage = self.describe(namespaces=namespaces).assign( + cum_bytes=lambda x: x.bytes.cumsum() + ) unset_index = total_count if total_count is not None else len(usage) if total_bytes is not None: - unset_index = min(unset_index, usage.cum_bytes.searchsorted(total_bytes, side='right')) - for i, row in usage.loc[unset_index:].iterrows(): - logger.info("Unsetting key '{}' (namespace: '{}')...".format(row.key, row.namespace)) + unset_index = min( + unset_index, usage.cum_bytes.searchsorted(total_bytes, side="right") + ) + for _, row in usage.loc[unset_index:].iterrows(): + logger.info( + f"Unsetting key '{row.key}' (namespace: '{row.namespace}')..." + ) self.unset(row.key, namespace=row.namespace) # Methods for subclasses to implement diff --git a/omniduct/caches/filesystem.py b/omniduct/caches/filesystem.py index b83b3e6..cc72ddd 100644 --- a/omniduct/caches/filesystem.py +++ b/omniduct/caches/filesystem.py @@ -1,4 +1,3 @@ -import six import yaml from interface_meta import override @@ -13,10 +12,10 @@ class FileSystemCache(Cache): An implementation of `Cache` that wraps around a `FilesystemClient`. 
""" - PROTOCOLS = ['filesystem_cache'] + PROTOCOLS = ["filesystem_cache"] @override - def _init(self, path, fs=None): + def _init(self, path, fs=None): # pylint: disable=arguments-differ """ path (str): The top-level path of the cache in the filesystem. fs (FileSystemClient, str): The filesystem client to use as the @@ -30,51 +29,53 @@ def _init(self, path, fs=None): self.path = path # Currently config is not used, but will be in future versions self._config = None - self.connection_fields += ('fs',) + self.connection_fields += ("fs",) @override def _prepare(self): Cache._prepare(self) if self.registry is not None: - if isinstance(self.fs, six.string_types): - self.fs = self.registry.lookup(self.fs, kind=FileSystemCache.Type.FILESYSTEM) - assert isinstance(self.fs, FileSystemClient), "Provided cache is not an instance of `omniduct.filesystems.base.FileSystemClient`." + if isinstance(self.fs, str): + self.fs = self.registry.lookup( # pylint: disable=attribute-defined-outside-init + self.fs, kind=FileSystemCache.Type.FILESYSTEM + ) + assert isinstance( + self.fs, FileSystemClient + ), "Provided cache is not an instance of `omniduct.filesystems.base.FileSystemClient`." self._prepare_cache() def _prepare_cache(self): - config_path = self.fs.path_join(self.path, 'config') + config_path = self.fs.path_join(self.path, "config") if self.fs.exists(config_path): with self.fs.open(config_path) as fh: try: return yaml.safe_load(fh) - except yaml.error.YAMLError: + except yaml.error.YAMLError as e: raise RuntimeError( - "Path nominated for cache ('{}') has a corrupt " - "configuration. Please manually empty or delete this " - "path cache, and try again.".format(self.path) - ) + f"Path nominated for cache ('{self.path}') has a corrupt " + "configuration. Please manually empty or delete this path " + "cache, and try again." + ) from e # Cache needs initialising if self.fs.exists(self.path): if not self.fs.isdir(self.path): raise RuntimeError( - "Path nominated for cache ('{}') is not a directory.".format(self.path) + f"Path nominated for cache ('{self.path}') is not a directory." ) - elif self.fs.listdir(self.path): + if self.fs.listdir(self.path): raise RuntimeError( - "Cache directory ({}) needs to be initialised, and is not " - "empty. Please manually delete and/or empty this path, and " - "try again.".format(self.path) + f"Cache directory ({self.path}) needs to be initialised, and is not empty. Please manually delete and/or empty this path, and try again." 
) else: # Create cache directory self.fs.mkdir(self.path, recursive=True, exist_ok=True) # Write config file to mark cache as initialised - with self.fs.open(config_path, 'w') as fh: - yaml.safe_dump({'version': 1}, fh, default_flow_style=False) - return {'version': 1} + with self.fs.open(config_path, "w") as fh: + yaml.safe_dump({"version": 1}, fh, default_flow_style=False) + return {"version": 1} @override def _connect(self): @@ -92,13 +93,13 @@ def _disconnect(self): @override def _namespace(self, namespace): if namespace is None: - return '__default__' - assert isinstance(namespace, str) and namespace != 'config' + return "__default__" + assert isinstance(namespace, str) and namespace != "config" return namespace @override def _get_namespaces(self): - return [d for d in self.fs.listdir(self.path) if d != 'config'] + return [d for d in self.fs.listdir(self.path) if d != "config"] @override def _has_namespace(self, namespace): @@ -118,15 +119,14 @@ def _has_key(self, namespace, key): @override def _remove_key(self, namespace, key): - return self.fs.remove(self.fs.path_join(self.path, namespace, key), recursive=True) + return self.fs.remove( + self.fs.path_join(self.path, namespace, key), recursive=True + ) @override def _get_bytecount_for_key(self, namespace, key): path = self.fs.path_join(self.path, namespace, key) - return sum([ - f.bytes - for f in self.fs.dir(path) - ]) + return sum(f.bytes for f in self.fs.dir(path)) @override def _get_stream_for_key(self, namespace, key, stream_name, mode, create): diff --git a/omniduct/databases/_cursor_formatters.py b/omniduct/databases/_cursor_formatters.py index 5d168d1..6814918 100644 --- a/omniduct/databases/_cursor_formatters.py +++ b/omniduct/databases/_cursor_formatters.py @@ -1,17 +1,16 @@ import csv import io -import six from omniduct.utils.debug import logger COLUMN_NAME_FORMATTERS = { None: lambda x: x, - 'lowercase': lambda x: x.lower(), - 'uppercase': lambda x: x.upper() + "lowercase": lambda x: x.lower(), + "uppercase": lambda x: x.upper(), } -class CursorFormatter(object): +class CursorFormatter: """ An abstract base class for all cursor formatters. @@ -35,7 +34,8 @@ def __init__(self, cursor, column_name_formatter=None, **kwargs): """ self.cursor = cursor self.column_name_formatter = ( - column_name_formatter if callable(column_name_formatter) + column_name_formatter + if callable(column_name_formatter) else COLUMN_NAME_FORMATTERS[column_name_formatter] ) self._init(**kwargs) @@ -99,10 +99,14 @@ def _prepare_row(self, row): return row def _format_dump(self, data): - raise NotImplementedError("{} does not support formatting dumped data.".format(self.__class__.__name__)) + raise NotImplementedError( + f"{self.__class__.__name__} does not support formatting dumped data." + ) def _format_row(self, row): - raise NotImplementedError("{} does not support formatting streaming data.".format(self.__class__.__name__)) + raise NotImplementedError( + f"{self.__class__.__name__} does not support formatting streaming data." + ) class PandasCursorFormatter(CursorFormatter): @@ -125,9 +129,10 @@ def _format_dump(self, data): if self.date_fields is not None: try: df = pd.io.sql._parse_date_columns(df, self.date_fields) - except Exception as e: - logger.warning('Unable to parse date columns. Perhaps your version of pandas is outdated.' - 'Original error message was: {}: {}'.format(e.__class__.__name__, str(e))) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning( + f"Unable to parse date columns. 
Perhaps your version of pandas is outdated. Original error message was: {e.__class__.__name__}: {str(e)}" + ) if self.index_fields is not None: df.set_index(self.index_fields, inplace=True) @@ -184,16 +189,16 @@ class CsvCursorFormatter(CursorFormatter): """ FORMAT_PARAMS = { - 'delimiter': ',', - 'doublequote': False, - 'escapechar': '\\', - 'lineterminator': '\r\n', - 'quotechar': '"', - 'quoting': csv.QUOTE_MINIMAL + "delimiter": ",", + "doublequote": False, + "escapechar": "\\", + "lineterminator": "\r\n", + "quotechar": '"', + "quoting": csv.QUOTE_MINIMAL, } def _init(self, include_header=True): - self.output = io.StringIO() if six.PY3 else io.BytesIO() + self.output = io.StringIO() self.include_header = include_header self.writer = csv.writer(self.output, **self.FORMAT_PARAMS) @@ -225,17 +230,17 @@ class HiveCursorFormatter(CsvCursorFormatter): """ FORMAT_PARAMS = { - 'delimiter': '\t', - 'doublequote': False, - 'escapechar': '', - 'lineterminator': '\n', - 'quotechar': '', - 'quoting': csv.QUOTE_NONE + "delimiter": "\t", + "doublequote": False, + "escapechar": "", + "lineterminator": "\n", + "quotechar": "", + "quoting": csv.QUOTE_NONE, } - def _init(self): + def _init(self): # pylint: disable=arguments-differ CsvCursorFormatter._init(self, include_header=False) # Convert null values to '\N'. def _prepare_row(self, row): - return [r'\N' if v is None else str(v).replace('\t', r'\t') for v in row] + return [r"\N" if v is None else str(v).replace("\t", r"\t") for v in row] diff --git a/omniduct/databases/_cursor_serializer.py b/omniduct/databases/_cursor_serializer.py index 73b9ae3..a33ec6c 100644 --- a/omniduct/databases/_cursor_serializer.py +++ b/omniduct/databases/_cursor_serializer.py @@ -13,17 +13,17 @@ def file_extension(self): """str: The file extension to use when storing in the cache.""" return ".pickled_cursor" - def serialize(self, cursor, fh): + def serialize(self, obj, fh): """ Serialize a cursor object into a nominated file handle. Args: - cursor (DB-API 2.0 cursor): The cursor to serialize. + obj (DB-API 2.0 cursor): The cursor to serialize. fh (binary file handle): A file-like object opened in binary mode capable of being written into. """ - description = cursor.description - rows = cursor.fetchall() + description = obj.description + rows = obj.fetchall() pickle.dump((description, rows), fh) def deserialize(self, fh): @@ -42,7 +42,7 @@ def deserialize(self, fh): return CachedCursor(description, rows) -class CachedCursor(object): +class CachedCursor: """ A DB-API 2.0 cursor implementation atop of static data. @@ -58,7 +58,7 @@ def __init__(self, description, rows): @property def iter(self): - if not getattr(self, '_iter'): + if not getattr(self, "_iter"): self._iter = (row for row in self._rows) return self._iter @@ -75,13 +75,13 @@ def row_count(self): def close(self): pass - def execute(operation, parameters=None): + def execute(self, operation, parameters=None): raise NotImplementedError( "Cached cursors are not connected to a database, and cannot be " "used for database operations." ) - def executemany(operation, seq_of_parameters=None): + def executemany(self, operation, seq_of_parameters=None): raise NotImplementedError( "Cached cursors are not connected to a database, and cannot be " "used for database operations." 
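To make the serializer's round-trip concrete: serialize() pickles the cursor's description alongside its fully fetched rows, and deserialize() replays both through a CachedCursor. A minimal sketch of that flow, using Python's built-in sqlite3 purely as a stand-in data source (sqlite3 is not part of this diff, and the fetchall() on CachedCursor is assumed from its stated DB-API 2.0 contract rather than shown in this hunk):

import io
import sqlite3

from omniduct.databases._cursor_serializer import CursorSerializer

# A throwaway DB-API cursor to serialize; sqlite3 is illustrative only.
cursor = sqlite3.connect(":memory:").execute("SELECT 1, 'x'")

serializer = CursorSerializer()
buffer = io.BytesIO()
serializer.serialize(cursor, buffer)  # pickles (cursor.description, cursor.fetchall())

buffer.seek(0)
cached = serializer.deserialize(buffer)  # rebuilds a CachedCursor over the static rows
print(cached.fetchall())  # [(1, 'x')] -- replayed with no live database connection

This replay-from-static-data behaviour is what lets the cached execute() path in databases/base.py below hand back a usable cursor without re-running the statement.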
diff --git a/omniduct/databases/_namespaces.py b/omniduct/databases/_namespaces.py index b5b54b2..ec27c25 100644 --- a/omniduct/databases/_namespaces.py +++ b/omniduct/databases/_namespaces.py @@ -2,7 +2,7 @@ from collections import OrderedDict -class ParsedNamespaces(object): +class ParsedNamespaces: """ A namespace parser for DatabaseClient subclasses. @@ -21,7 +21,7 @@ class ParsedNamespaces: """ @classmethod - def from_name(cls, name, namespaces, quote_char='"', separator='.', defaults=None): + def from_name(cls, name, namespaces, quote_char='"', separator=".", defaults=None): """ Return an instance of `ParsedNamespaces` from a given name. @@ -54,35 +54,38 @@ def from_name(cls, name, namespaces, quote_char='"', separator='.', defaults=Non extra_namespaces = set(name.namespaces).difference(namespaces) if extra_namespaces: raise ValueError( - "ParsedNamespace is not encapsulated by the namespaces " - "provided to this constructor. It has extra namespaces: {}." - .format(extra_namespaces) + f"ParsedNamespace is not encapsulated by the namespaces provided to this constructor. It has extra namespaces: {extra_namespaces}." ) parsed = name.as_dict() elif isinstance(name, str): namespace_matcher = re.compile( r"([^{sep}{qc}]+)|{qc}([^`]*?){qc}".format( - qc=re.escape(quote_char), - sep=re.escape(separator) + qc=re.escape(quote_char), sep=re.escape(separator) ) ) - names = [''.join(t) for t in namespace_matcher.findall(name)] if name else [] + names = ( + ["".join(t) for t in namespace_matcher.findall(name)] if name else [] + ) if len(names) > len(namespaces): raise ValueError( - "Name '{}' has too many namespaces. Should be of form: <{}>." - .format(name, ">{sep}<".format(sep=separator).join(namespaces)) + f"Name '{name}' has too many namespaces. Should be of form: <{f'>{separator}<'.join(namespaces)}>." ) - parsed = OrderedDict(reversed([ - (namespace, names.pop() if names else None) - for namespace in namespaces[::-1] - ])) + parsed = OrderedDict( + reversed( + [ + (namespace, names.pop() if names else None) + for namespace in namespaces[::-1] + ] + ) + ) else: - raise ValueError("Cannot construct `ParsedNamespaces` instance from " - "name of type: `{}`.".format(type(name))) + raise ValueError( + f"Cannot construct `ParsedNamespaces` instance from name of type: `{type(name)}`." 
+ ) for namespace in namespaces[::-1]: if not parsed.get(namespace) and namespace in defaults: @@ -92,11 +95,10 @@ def from_name(cls, name, namespaces, quote_char='"', separator='.', defaults=Non return cls(parsed, quote_char=quote_char, separator=separator) - def __init__(self, names, namespaces=None, quote_char='"', separator='.'): + def __init__(self, names, namespaces=None, quote_char='"', separator="."): if namespaces: names = OrderedDict( - (namespace, names.get(namespace, None)) - for namespace in namespaces + (namespace, names.get(namespace, None)) for namespace in namespaces ) self._names = names @@ -104,15 +106,15 @@ def __init__(self, names, namespaces=None, quote_char='"', separator='.'): self._separator = separator def __getattr__(self, name): - if '_names' in self.__dict__ and name in self._names: + if "_names" in self.__dict__ and name in self._names: return self._names[name] raise AttributeError(name) def __setattr__(self, name, value): - if '_names' in self.__dict__ and name in self._names: + if "_names" in self.__dict__ and name in self._names: self._names[name] = value else: - super(ParsedNamespaces, self).__setattr__(name, value) + super().__setattr__(name, value) def __bool__(self): return bool(self.name) @@ -136,9 +138,7 @@ def parent(self): names = self._names.copy() names.popitem() return ParsedNamespaces( - names=names, - quote_char=self._quote_char, - separator=self._separator + names=names, quote_char=self._quote_char, separator=self._separator ) def as_dict(self): @@ -152,9 +152,7 @@ def render(self, quote_char=None, separator=None): separator = self._separator names = [ - self._names[namespace] - for namespace, name in self._names.items() - if name + self._names[namespace] for namespace, name in self._names.items() if name ] if len(names) == 0: return "" @@ -168,4 +166,4 @@ def __str__(self): return self.name def __repr__(self): - return "Namespace<{}>".format(self.name) + return f"Namespace<{self.name}>" diff --git a/omniduct/databases/_pandas.py b/omniduct/databases/_pandas.py index 2a73ee1..068e174 100644 --- a/omniduct/databases/_pandas.py +++ b/omniduct/databases/_pandas.py @@ -1,43 +1,55 @@ from pandas.io.sql import SQLTable, SQLDatabase -def to_sql(df, name, schema, con, index, if_exists, mode='default', **kwargs): +def to_sql(df, name, schema, con, index, if_exists, mode="default", **kwargs): """ Override the default `pandas.to_sql` method to allow for insertion of multiple rows of data at once. This is derived from the upstream patch at https://github.com/pandas-dev/pandas/pull/21401, and can be deprecated once it is merged and released in a new version of `pandas`. 
""" - assert mode in ('default', 'multi'), 'unexpected `to_sql` mode {}'.format(mode) - if mode == 'default': + assert mode in ("default", "multi"), f"unexpected `to_sql` mode {mode}" + if mode == "default": return df.to_sql( - name=name, schema=schema, con=con, index=index, if_exists=if_exists, **kwargs + name=name, + schema=schema, + con=con, + index=index, + if_exists=if_exists, + **kwargs, ) - else: - nrows = len(df) - if nrows == 0: - return - chunksize = kwargs.get('chunksize', nrows) - if chunksize == 0: - raise ValueError('chunksize argument should be non-zero') - chunks = int(nrows / chunksize) + 1 + nrows = len(df) + if nrows == 0: + return None - pd_sql = SQLDatabase(con) - pd_table = SQLTable( - name, pd_sql, frame=df, index=index, if_exists=if_exists, - index_label=kwargs.get('insert_label'), schema=schema, dtype=kwargs.get('dtype') - ) - pd_table.create() - keys, data_list = pd_table.insert_data() + chunksize = kwargs.get("chunksize", nrows) + if chunksize == 0: + raise ValueError("chunksize argument should be non-zero") + chunks = int(nrows / chunksize) + 1 + + pd_sql = SQLDatabase(con) + pd_table = SQLTable( + name, + pd_sql, + frame=df, + index=index, + if_exists=if_exists, + index_label=kwargs.get("insert_label"), + schema=schema, + dtype=kwargs.get("dtype"), + ) + pd_table.create() + keys, data_list = pd_table.insert_data() - with pd_sql.run_transaction() as conn: - for i in range(chunks): - start_i = i * chunksize - end_i = min((i + 1) * chunksize, nrows) - if start_i >= end_i: - break + with pd_sql.run_transaction() as conn: + for i in range(chunks): + start_i = i * chunksize + end_i = min((i + 1) * chunksize, nrows) + if start_i >= end_i: + break - chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list]) - data = [{k: v for k, v in zip(keys, row)} for row in chunk_iter] - conn.execute(pd_table.table.insert(data)) # multivalues insert + chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list]) + data = [dict(zip(keys, row)) for row in chunk_iter] + conn.execute(pd_table.table.insert(data)) # multivalues insert + return None diff --git a/omniduct/databases/_schemas.py b/omniduct/databases/_schemas.py index 5e0587f..20d07fc 100644 --- a/omniduct/databases/_schemas.py +++ b/omniduct/databases/_schemas.py @@ -1,3 +1,5 @@ +# pylint: disable=abstract-method + from __future__ import absolute_import import pandas as pd @@ -14,15 +16,15 @@ def get_columns(self, connection, table_name, schema=None, **kw): # Extend types supported by PrestoDialect as defined in PyHive type_map = { - 'bigint': sql_types.BigInteger, - 'integer': sql_types.Integer, - 'boolean': sql_types.Boolean, - 'double': sql_types.Float, - 'varchar': sql_types.String, - 'timestamp': sql_types.TIMESTAMP, - 'date': sql_types.DATE, - 'array': sql_types.ARRAY(sql_types.Integer), - 'array': sql_types.ARRAY(sql_types.String) + "bigint": sql_types.BigInteger, + "integer": sql_types.Integer, + "boolean": sql_types.Boolean, + "double": sql_types.Float, + "varchar": sql_types.String, + "timestamp": sql_types.TIMESTAMP, + "date": sql_types.DATE, + "array": sql_types.ARRAY(sql_types.Integer), + "array": sql_types.ARRAY(sql_types.String), } rows = self._get_table_columns(connection, table_name, schema) @@ -31,23 +33,29 @@ def get_columns(self, connection, table_name, schema=None, **kw): try: coltype = type_map[row.Type] except KeyError: - logger.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column)) + logger.warn( + f"Did not recognize type '{row.Type}' of column '{row.Column}'" + ) coltype = 
sql_types.NullType - result.append({ - 'name': row.Column, - 'type': coltype, - # newer Presto no longer includes this column - 'nullable': getattr(row, 'Null', True), - 'default': None, - }) + result.append( + { + "name": row.Column, + "type": coltype, + # newer Presto no longer includes this column + "nullable": getattr(row, "Null", True), + "default": None, + } + ) return result PrestoDialect.get_columns = get_columns except ImportError: - logger.debug("Not monkey patching pyhive's PrestoDialect.get_columns due to missing dependencies.") + logger.debug( + "Not monkey patching pyhive's PrestoDialect.get_columns due to missing dependencies." + ) -class SchemasMixin(object): +class SchemasMixin: """ Attaches a tab-completable `.schemas` attribute to a `DatabaseClient` instance. @@ -69,13 +77,13 @@ def schemas(self): from lazy_object_proxy import Proxy def get_schemas(): - if not getattr(self, '_schemas', None): - assert getattr(self, '_sqlalchemy_metadata', None) is not None, ( - "`{class_name}` instances do not provide the required sqlalchemy metadata " - "for schema exploration.".format(class_name=self.__class__.__name__) - ) + if not getattr(self, "_schemas", None): + assert ( + getattr(self, "_sqlalchemy_metadata", None) is not None + ), f"`{self.__class__.__name__}` instances do not provide the required sqlalchemy metadata for schema exploration." self._schemas = Schemas(self._sqlalchemy_metadata) return self._schemas + return Proxy(get_schemas) @@ -88,8 +96,11 @@ class TableDesc(Table): def desc(self): """pandas.DataFrame: The description of this SQL table.""" return pd.DataFrame( - [[col.name, col.type.compile(self.bind.dialect)] for col in self.columns.values()], - columns=['name', 'type'] + [ + [col.name, col.type.compile(self.bind.dialect)] + for col in self.columns.values() + ], + columns=["name", "type"], ) def head(self, n=10): @@ -122,7 +133,7 @@ def __repr__(self): # Define helpers to allow for table completion/etc -class Schemas(object): +class Schemas: """ An object which has as its attributes all of the schemas in a nominated database. @@ -140,7 +151,9 @@ def __init__(self, metadata): def all(self): "list: The list of schema names." if self._schema_names is None: - self._schema_names = sqlalchemy.inspect(self._metadata.bind).get_schema_names() + self._schema_names = sqlalchemy.inspect( + self._metadata.bind + ).get_schema_names() return self._schema_names def __dir__(self): @@ -149,12 +162,14 @@ def __dir__(self): def __getattr__(self, value): if value in self.all: if value not in self._schema_cache: - self._schema_cache[value] = Schema(metadata=self._metadata, schema=value) + self._schema_cache[value] = Schema( + metadata=self._metadata, schema=value + ) return self._schema_cache[value] - raise AttributeError("No such schema {}".format(value)) + raise AttributeError(f"No such schema {value}") def __repr__(self): - return "<Schemas: {} schemas>".format(len(self.all)) + return f"<Schemas: {len(self.all)} schemas>" def __iter__(self): for schema in self.all: @@ -164,7 +179,7 @@ def __len__(self): return len(self.all) -class Schema(object): +class Schema: """ An object which has as its attributes all of the tables in a nominated database schema. 
@@ -184,7 +199,9 @@ def __init__(self, metadata, schema): def all(self): """list: The table names in this database schema.""" if self._table_names is None: - self._table_names = sqlalchemy.inspect(self._metadata.bind).get_table_names(self._schema) + self._table_names = sqlalchemy.inspect(self._metadata.bind).get_table_names( + self._schema + ) return self._table_names def __dir__(self): @@ -194,13 +211,16 @@ def __getattr__(self, table): if table in self.all: if table not in self._table_cache: self._table_cache[table] = TableDesc( - '{}'.format(table), self._metadata, autoload=True, schema=self._schema + f"{table}", + self._metadata, + autoload=True, + schema=self._schema, ) return self._table_cache[table] - raise AttributeError("No such table {}".format(table)) + raise AttributeError(f"No such table {table}") def __repr__(self): - return "<Schema `{}`: {} tables>".format(self._schema, len(self.all)) + return f"<Schema `{self._schema}`: {len(self.all)} tables>" def __iter__(self): for schema in self.all: diff --git a/omniduct/databases/base.py b/omniduct/databases/base.py index f9076a5..f1cf9da 100644 --- a/omniduct/databases/base.py +++ b/omniduct/databases/base.py @@ -5,7 +5,6 @@ import itertools import logging import os -import sys from abc import abstractmethod import jinja2 @@ -19,14 +18,17 @@ from omniduct.filesystems.local import LocalFsClient from omniduct.utils.debug import logger, logging_scope from omniduct.utils.decorators import require_connection -from omniduct.utils.magics import (MagicsProvider, process_line_arguments, - process_line_cell_arguments) +from omniduct.utils.magics import ( + MagicsProvider, + process_line_arguments, + process_line_cell_arguments, +) from . import _cursor_formatters from ._cursor_serializer import CursorSerializer from ._namespaces import ParsedNamespaces -logging.getLogger('requests').setLevel(logging.WARNING) +logging.getLogger("requests").setLevel(logging.WARNING) @decorator @@ -37,10 +39,10 @@ def render_statement(method, self, statement, *args, **kwargs): This decorator expects to act as wrapper on functions which takes statements as the second argument. """ - if kwargs.pop('template', True): + if kwargs.pop("template", True): statement = self.template_render( statement, - context=kwargs.pop('context', {}), + context=kwargs.pop("context", {}), by_name=False, ) return method(self, statement, *args, **kwargs) @@ -65,18 +67,18 @@ class DatabaseClient(Duct, MagicsProvider): DEFAULT_PORT = None CURSOR_FORMATTERS = { - 'pandas': _cursor_formatters.PandasCursorFormatter, - 'hive': _cursor_formatters.HiveCursorFormatter, - 'csv': _cursor_formatters.CsvCursorFormatter, - 'tuple': _cursor_formatters.TupleCursorFormatter, - 'dict': _cursor_formatters.DictCursorFormatter, - 'raw': _cursor_formatters.RawCursorFormatter, + "pandas": _cursor_formatters.PandasCursorFormatter, + "hive": _cursor_formatters.HiveCursorFormatter, + "csv": _cursor_formatters.CsvCursorFormatter, + "tuple": _cursor_formatters.TupleCursorFormatter, + "dict": _cursor_formatters.DictCursorFormatter, + "raw": _cursor_formatters.RawCursorFormatter, } - DEFAULT_CURSOR_FORMATTER = 'pandas' + DEFAULT_CURSOR_FORMATTER = "pandas" SUPPORTS_SESSION_PROPERTIES = False - NAMESPACE_NAMES = ['database', 'table'] + NAMESPACE_NAMES = ["database", "table"] NAMESPACE_QUOTECHAR = '"' - NAMESPACE_SEPARATOR = '.' + NAMESPACE_SEPARATOR = "." 
NAMESPACE_DEFAULT = None # DEPRECATED (use NAMESPACE_DEFAULTS_READ instead): Will be removed in Omniduct 2.0.0 @@ -94,10 +96,15 @@ def NAMESPACE_DEFAULTS_WRITE(self): """ return self.NAMESPACE_DEFAULTS_READ - @quirk_docs('_init', mro=True) + @quirk_docs("_init", mro=True) + # pylint: disable-next=super-init-not-called def __init__( - self, session_properties=None, templates=None, template_context=None, default_format_opts=None, - **kwargs + self, + session_properties=None, + templates=None, + template_context=None, + default_format_opts=None, + **kwargs, ): """ session_properties (dict): A mapping of default session properties @@ -205,7 +212,7 @@ def _statement_split(self, statements): """ for statement in sqlparse.split(statements): statement = statement.strip() - if statement.endswith(';'): + if statement.endswith(";"): statement = statement[:-1].strip() if statement: # remove empty statements yield statement @@ -226,12 +233,7 @@ def statement_hash(cls, statement, cleanup=True): """ if cleanup: statement = cls.statement_cleanup(statement) - if ( - sys.version_info.major == 3 - or sys.version_info.major == 2 and isinstance(statement, unicode) # noqa: F821 - ): - statement = statement.encode('utf8') - return hashlib.sha256(statement).hexdigest() + return hashlib.sha256(statement.encode("utf8")).hexdigest() @classmethod def statement_cleanup(cls, statement): @@ -257,19 +259,20 @@ def statement_cleanup(cls, statement): @render_statement @cached_method( key=lambda self, kwargs: self.statement_hash( - statement=kwargs['statement'], - cleanup=kwargs.pop('cleanup', True) + statement=kwargs["statement"], cleanup=kwargs.pop("cleanup", True) ), serializer=lambda self, kwargs: CursorSerializer(), - use_cache=lambda self, kwargs: kwargs.pop('use_cache', False), + use_cache=lambda self, kwargs: kwargs.pop("use_cache", False), metadata=lambda self, kwargs: { - 'statement': kwargs['statement'], - 'session_properties': kwargs['session_properties'] - } + "statement": kwargs["statement"], + "session_properties": kwargs["session_properties"], + }, ) - @quirk_docs('_execute') + @quirk_docs("_execute") @require_connection - def execute(self, statement, wait=True, cursor=None, session_properties=None, **kwargs): + def execute( + self, statement, wait=True, cursor=None, session_properties=None, **kwargs + ): """ Execute a statement against this database and return a cursor object. @@ -309,20 +312,37 @@ def execute(self, statement, wait=True, cursor=None, session_properties=None, ** session_properties = self._get_session_properties(overrides=session_properties) - statements = list(self._statement_split( - self._statement_prepare(statement, session_properties=session_properties, **kwargs) - )) + statements = list( + self._statement_split( + self._statement_prepare( + statement, session_properties=session_properties, **kwargs + ) + ) + ) assert len(statements) > 0, "No non-empty statements were provided." 
- for statement in statements[:-1]: - cursor = self._execute(statement, cursor=cursor, wait=True, session_properties=session_properties, **kwargs) - cursor = self._execute(statements[-1], cursor=cursor, wait=wait, session_properties=session_properties, **kwargs) + for stmt in statements[:-1]: + cursor = self._execute( + stmt, + cursor=cursor, + wait=True, + session_properties=session_properties, + **kwargs, + ) + cursor = self._execute( + statements[-1], + cursor=cursor, + wait=wait, + session_properties=session_properties, + **kwargs, + ) return cursor @logging_scope("Query", timed=True) @render_statement - def query(self, statement, format=None, format_opts={}, use_cache=True, **kwargs): + # pylint: disable-next=redefined-builtin + def query(self, statement, format=None, format_opts=None, use_cache=True, **kwargs): """ Execute a statement against this database and collect formatted data. @@ -342,7 +362,10 @@ def query(self, statement, format=None, format_opts={}, use_cache=True, **kwargs Returns: The results of the query formatted as nominated. """ - cursor = self.execute(statement, wait=True, template=False, use_cache=use_cache, **kwargs) + format_opts = format_opts or {} + cursor = self.execute( + statement, wait=True, template=False, use_cache=use_cache, **kwargs + ) # Some DBAPI2 cursor implementations error if attempting to extract # data from an empty cursor, and if so, we simply return None. @@ -352,7 +375,8 @@ def query(self, statement, format=None, format_opts={}, use_cache=True, **kwargs formatter = self._get_formatter(format, cursor, **format_opts) return formatter.dump() - def stream(self, statement, format=None, format_opts={}, batch=None, **kwargs): + # pylint: disable-next=redefined-builtin + def stream(self, statement, format=None, format_opts=None, batch=None, **kwargs): """ Execute a statement against this database and stream formatted results. @@ -375,6 +399,7 @@ def stream(self, statement, format=None, format_opts={}, batch=None, **kwargs): iterator: An iterator over objects of the nominated format or, if batched, a list of such objects. """ + format_opts = format_opts or {} cursor = self.execute(statement, wait=True, **kwargs) formatter = self._get_formatter(format, cursor, **format_opts) @@ -383,13 +408,21 @@ def stream(self, statement, format=None, format_opts={}, batch=None, **kwargs): def _get_formatter(self, formatter, cursor, **kwargs): formatter = formatter or self.DEFAULT_CURSOR_FORMATTER - if not (inspect.isclass(formatter) and issubclass(formatter, _cursor_formatters.CursorFormatter)): - assert formatter in self.CURSOR_FORMATTERS, "Invalid format '{}'. Choose from: {}".format(formatter, ','.join(self.CURSOR_FORMATTERS.keys())) + if not ( + inspect.isclass(formatter) + and issubclass(formatter, _cursor_formatters.CursorFormatter) + ): + assert ( + formatter in self.CURSOR_FORMATTERS + ), f"Invalid format '{formatter}'. Choose from: {','.join(self.CURSOR_FORMATTERS.keys())}" formatter = self.CURSOR_FORMATTERS[formatter] - format_opts = dict(itertools.chain(self._default_format_opts.items(), kwargs.items())) + format_opts = dict( + itertools.chain(self._default_format_opts.items(), kwargs.items()) + ) return formatter(cursor, **format_opts) - def stream_to_file(self, statement, file, format='csv', fs=None, **kwargs): + # pylint: disable-next=redefined-builtin + def stream_to_file(self, statement, file, format="csv", fs=None, **kwargs): """ Execute a statement against this database and stream results to a file. 
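For context on how the formatter plumbing above surfaces to callers, a hypothetical session against an already-connected client might look as follows; the db handle, table name, and process callback are illustrative placeholders, and the date_fields option is assumed from the attribute usage in _cursor_formatters.py above rather than from a documented signature:

# `db` stands in for any connected DatabaseClient subclass (e.g. a Presto or Hive client).
df = db.query(
    "SELECT id, created_at FROM events LIMIT 10",
    format="pandas",  # one of: pandas, hive, csv, tuple, dict, raw
    format_opts={"date_fields": ["created_at"]},  # forwarded to the formatter's _init
    use_cache=False,
)

# Streaming variant: iterate dict-shaped rows in batches of 100.
for batch in db.stream("SELECT id FROM events", format="dict", batch=100):
    process(batch)

Because _get_formatter() chains _default_format_opts before the per-call kwargs when building the options dict, per-call options win over any defaults configured at construction time.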
@@ -413,7 +446,7 @@ def stream_to_file(self, statement, file, format='csv', fs=None, **kwargs): """ close_later = False if isinstance(file, str): - file = (fs or LocalFsClient()).open(file, 'w') + file = (fs or LocalFsClient()).open(file, "w") close_later = True try: @@ -441,7 +474,7 @@ def execute_from_file(self, file, fs=None, **kwargs): """ close_later = False if isinstance(file, str): - file = (fs or LocalFsClient()).open(file, 'r') + file = (fs or LocalFsClient()).open(file, "r") close_later = True try: @@ -469,7 +502,7 @@ def query_from_file(self, file, fs=None, **kwargs): """ close_later = False if isinstance(file, str): - file = (fs or LocalFsClient()).open(file, 'r') + file = (fs or LocalFsClient()).open(file, "r") close_later = True try: @@ -517,7 +550,7 @@ def template_get(self, name): str: The requested template. """ if name not in self._templates: - raise ValueError("No such template named: '{}'.".format(name)) + raise ValueError(f"No such template named: '{name}'.") return self._templates[name] def template_variables(self, name_or_statement, by_name=False): @@ -539,8 +572,14 @@ def template_variables(self, name_or_statement, by_name=False): ) return jinja2.meta.find_undeclared_variables(ast) - def template_render(self, name_or_statement, context=None, by_name=False, - cleanup=False, meta_only=False): + def template_render( + self, + name_or_statement, + context=None, + by_name=False, + cleanup=False, + meta_only=False, + ): """ Render a template by name or value. @@ -588,15 +627,18 @@ def template_render(self, name_or_statement, context=None, by_name=False, """ if by_name: if name_or_statement not in self._templates: - raise ValueError("No such template of name: '{}'.".format(name_or_statement)) + raise ValueError(f"No such template of name: '{name_or_statement}'.") statement = self._templates[name_or_statement] else: statement = name_or_statement try: from sqlalchemy.sql.base import Executable + if isinstance(statement, Executable): - statement = str(statement.compile(compile_kwargs={"literal_binds": True})) + statement = str( + statement.compile(compile_kwargs={"literal_binds": True}) + ) except ImportError: pass @@ -609,32 +651,26 @@ def template_render(self, name_or_statement, context=None, by_name=False, intersection = set(self._template_context.keys()) & set(context.keys()) if intersection: logger.warning( - "The following default template context keys have been overridden " - "by the local context: {}." - .format(intersection) + f"The following default template context keys have been overridden by the local context: {intersection}." 
) # Substitute in any other named statements recursively - while '{{{' in statement or '{{%' in statement: - statement = ( - jinja2.Template( - statement, - block_start_string='{{%', - block_end_string='%}}', - variable_start_string='{{{', - variable_end_string='}}}', - comment_start_string='{{#', - comment_end_string='#}}', - undefined=jinja2.StrictUndefined - ) - .render(getattr(self, '_templates', {})) - ) + while "{{{" in statement or "{{%" in statement: + statement = jinja2.Template( + statement, + block_start_string="{{%", + block_end_string="%}}", + variable_start_string="{{{", + variable_end_string="}}}", + comment_start_string="{{#", + comment_end_string="#}}", + undefined=jinja2.StrictUndefined, + ).render(getattr(self, "_templates", {})) if not meta_only: - statement = ( - jinja2.Template(statement, undefined=jinja2.StrictUndefined) - .render(template_context) - ) + statement = jinja2.Template( + statement, undefined=jinja2.StrictUndefined + ).render(template_context) if cleanup: statement = self.statement_cleanup(statement) @@ -673,9 +709,9 @@ def query_from_template(self, name, context=None, **kwargs): return self.query(statement, **kwargs) # Uploading/querying data into data store - @logging_scope('Query [CTAS]', timed=True) - @quirk_docs('_query_to_table') - def query_to_table(self, statement, table, if_exists='fail', **kwargs): + @logging_scope("Query [CTAS]", timed=True) + @quirk_docs("_query_to_table") + def query_to_table(self, statement, table, if_exists="fail", **kwargs): """ Run a query and store the results in a table in this database. @@ -693,14 +729,14 @@ def query_to_table(self, statement, table, if_exists='fail', **kwargs): Returns: DB-API cursor: The cursor object associated with the execution. """ - assert if_exists in {'fail', 'replace', 'append'} + assert if_exists in {"fail", "replace", "append"} table = self._parse_namespaces(table, write=True) return self._query_to_table(statement, table, if_exists=if_exists, **kwargs) - @logging_scope('Dataframe Upload', timed=True) - @quirk_docs('_dataframe_to_table') + @logging_scope("Dataframe Upload", timed=True) + @quirk_docs("_dataframe_to_table") @require_connection - def dataframe_to_table(self, df, table, if_exists='fail', **kwargs): + def dataframe_to_table(self, df, table, if_exists="fail", **kwargs): """ Upload a local pandas dataframe into a table in this database. @@ -715,8 +751,10 @@ def dataframe_to_table(self, df, table, if_exists='fail', **kwargs): **kwargs (dict): Additional keyword arguments to pass onto `DatabaseClient._dataframe_to_table`. 
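A note on the two-pass rendering in `template_render` above: the first pass expands triple-braced references to other named templates using custom Jinja2 delimiters, and only afterwards are ordinary double-braced variables rendered from the user context. A standalone sketch under an assumed template registry (the `base` template and context values are invented):

    import jinja2

    templates = {"base": "SELECT * FROM {{ table }}"}
    statement = "{{{ base }}} LIMIT {{ n }}"

    # Pass 1: expand {{{ ... }}} references to other named templates.
    while "{{{" in statement or "{{%" in statement:
        statement = jinja2.Template(
            statement,
            block_start_string="{{%",
            block_end_string="%}}",
            variable_start_string="{{{",
            variable_end_string="}}}",
            comment_start_string="{{#",
            comment_end_string="#}}",
            undefined=jinja2.StrictUndefined,
        ).render(templates)

    # Pass 2: render ordinary {{ ... }} variables from the context.
    rendered = jinja2.Template(statement, undefined=jinja2.StrictUndefined).render(table="events", n=10)
    print(rendered)  # SELECT * FROM events LIMIT 10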
""" - assert if_exists in {'fail', 'replace', 'append'} - self._dataframe_to_table(df, self._parse_namespaces(table, write=True), if_exists=if_exists, **kwargs) + assert if_exists in {"fail", "replace", "append"} + self._dataframe_to_table( + df, self._parse_namespaces(table, write=True), if_exists=if_exists, **kwargs + ) # Table properties @@ -727,7 +765,7 @@ def _execute(self, statement, cursor, wait, session_properties): def _query_to_table(self, statement, table, if_exists, **kwargs): raise NotImplementedError - def _dataframe_to_table(self, df, table, if_exists='fail', **kwargs): + def _dataframe_to_table(self, df, table, if_exists="fail", **kwargs): raise NotImplementedError def _cursor_empty(self, cursor): @@ -739,10 +777,14 @@ def _parse_namespaces(self, name, level=0, defaults=None, write=False): self.NAMESPACE_NAMES[:-level] if level > 0 else self.NAMESPACE_NAMES, quote_char=self.NAMESPACE_QUOTECHAR, separator=self.NAMESPACE_SEPARATOR, - defaults=defaults if defaults else (self.NAMESPACE_DEFAULTS_WRITE if write else self.NAMESPACE_DEFAULTS_READ), + defaults=defaults + if defaults + else ( + self.NAMESPACE_DEFAULTS_WRITE if write else self.NAMESPACE_DEFAULTS_READ + ), ) - @quirk_docs('_table_list') + @quirk_docs("_table_list") def table_list(self, namespace=None, renew=True, **kwargs): """ Return a list of table names in the data source as a DataFrame. @@ -756,13 +798,15 @@ def table_list(self, namespace=None, renew=True, **kwargs): Returns: list: The names of schemas in this database. """ - return self._table_list(self._parse_namespaces(namespace, level=1), renew=renew, **kwargs) + return self._table_list( + self._parse_namespaces(namespace, level=1), renew=renew, **kwargs + ) @abstractmethod def _table_list(self, namespace, **kwargs): pass - @quirk_docs('_table_exists') + @quirk_docs("_table_exists") def table_exists(self, table, renew=True, **kwargs): """ Check whether a table exists. @@ -776,13 +820,15 @@ def table_exists(self, table, renew=True, **kwargs): Returns: bool: `True` if table exists, and `False` otherwise. """ - return self._table_exists(table=self._parse_namespaces(table), renew=renew, **kwargs) + return self._table_exists( + table=self._parse_namespaces(table), renew=renew, **kwargs + ) @abstractmethod def _table_exists(self, table, **kwargs): pass - @quirk_docs('_table_drop') + @quirk_docs("_table_drop") def table_drop(self, table, **kwargs): """ Remove a table from the database. @@ -794,13 +840,15 @@ def table_drop(self, table, **kwargs): Returns: DB-API cursor: The cursor associated with this execution. """ - return self._table_drop(table=self._parse_namespaces(table, write=True), **kwargs) + return self._table_drop( + table=self._parse_namespaces(table, write=True), **kwargs + ) @abstractmethod def _table_drop(self, table, **kwargs): pass - @quirk_docs('_table_desc') + @quirk_docs("_table_desc") def table_desc(self, table, renew=True, **kwargs): """ Describe a table in the database. @@ -813,13 +861,15 @@ def table_desc(self, table, renew=True, **kwargs): Returns: pandas.DataFrame: A dataframe description of the table. 
""" - return self._table_desc(table=self._parse_namespaces(table), renew=renew, **kwargs) + return self._table_desc( + table=self._parse_namespaces(table), renew=renew, **kwargs + ) @abstractmethod def _table_desc(self, table, **kwargs): pass - @quirk_docs('_table_partition_cols') + @quirk_docs("_table_partition_cols") def table_partition_cols(self, table, renew=True, **kwargs): """ Extract the columns by which a table is partitioned (if database supports partitions). @@ -832,16 +882,16 @@ def table_partition_cols(self, table, renew=True, **kwargs): Returns: list: A list of columns by which table is partitioned. """ - return self._table_partition_cols(table=self._parse_namespaces(table), renew=renew, **kwargs) + return self._table_partition_cols( + table=self._parse_namespaces(table), renew=renew, **kwargs + ) def _table_partition_cols(self, table, **kwargs): raise NotImplementedError( - "Database backend `{}` does not support, or has not implemented, " - "support for extracting partition columns." - .format(self.__class__.__name__) + f"Database backend `{self.__class__.__name__}` does not support, or has not implemented, support for extracting partition columns." ) - @quirk_docs('_table_head') + @quirk_docs("_table_head") def table_head(self, table, n=10, renew=True, **kwargs): """ Retrieve the first `n` rows from a table. @@ -857,13 +907,15 @@ def table_head(self, table, n=10, renew=True, **kwargs): pandas.DataFrame: A dataframe representation of the first `n` rows of the nominated table. """ - return self._table_head(table=self._parse_namespaces(table), n=n, renew=renew, **kwargs) + return self._table_head( + table=self._parse_namespaces(table), n=n, renew=renew, **kwargs + ) @abstractmethod def _table_head(self, table, n=10, **kwargs): pass - @quirk_docs('_table_props') + @quirk_docs("_table_props") def table_props(self, table, renew=True, **kwargs): """ Retrieve the properties associated with a table. @@ -878,7 +930,9 @@ def table_props(self, table, renew=True, **kwargs): pandas.DataFrame: A dataframe representation of the table properties. """ - return self._table_props(table=self._parse_namespaces(table), renew=renew, **kwargs) + return self._table_props( + table=self._parse_namespaces(table), renew=renew, **kwargs + ) @abstractmethod def _table_props(self, table, **kwargs): @@ -910,10 +964,22 @@ def _register_magics(self, base_name): Documentation for these magics is provided online. 
""" from IPython import get_ipython - from IPython.core.magic import register_line_magic, register_cell_magic, register_line_cell_magic - - def statement_executor_magic(executor, statement, variable=None, show='head', transpose=False, template=True, context=None, **kwargs): + from IPython.core.magic import ( + register_line_magic, + register_cell_magic, + register_line_cell_magic, + ) + def statement_executor_magic( + executor, + statement, + variable=None, + show="head", + transpose=False, + template=True, + context=None, + **kwargs, + ): ip = get_ipython() if context is None: @@ -924,91 +990,106 @@ def statement_executor_magic(executor, statement, variable=None, show='head', tr return self.query_from_template(variable, context=context, **kwargs) # Cell magic - result = getattr(self, executor)(statement, template=template, context=context, **kwargs) + result = getattr(self, executor)( + statement, template=template, context=context, **kwargs + ) if variable is not None: ip.user_ns[variable] = result - if executor != 'query': + if executor != "query": if variable is None: return result - return - elif variable is None: + return None + if variable is None: return result - format = kwargs.get('format', self.DEFAULT_CURSOR_FORMATTER) - if show == 'head': + # pylint: disable=redefined-builtin + format = kwargs.get("format", self.DEFAULT_CURSOR_FORMATTER) + if show == "head": show = 10 if isinstance(show, int): - r = result.head(show) if format == 'pandas' else result[:show] - elif show == 'all': + r = result.head(show) if format == "pandas" else result[:show] + elif show == "all": r = result - elif show in (None, 'none'): + elif show in (None, "none"): return None else: - raise ValueError("Omniduct does not recognise the argument show='{0}' in cell magic.".format(show)) + raise ValueError( + f"Omniduct does not recognise the argument show='{show}' in cell magic." + ) - if format == 'pandas' and transpose: + if format == "pandas" and transpose: return r.T return r @register_line_cell_magic(base_name) @process_line_cell_arguments def query_magic(*args, **kwargs): - return statement_executor_magic('query', *args, **kwargs) + return statement_executor_magic("query", *args, **kwargs) - @register_line_cell_magic("{}.{}".format(base_name, 'execute')) + @register_line_cell_magic(f"{base_name}.execute") @process_line_cell_arguments def execute_magic(*args, **kwargs): - return statement_executor_magic('execute', *args, **kwargs) + return statement_executor_magic("execute", *args, **kwargs) - @register_line_cell_magic("{}.{}".format(base_name, 'stream')) + @register_line_cell_magic(f"{base_name}.stream") @process_line_cell_arguments def stream_magic(*args, **kwargs): - return statement_executor_magic('stream', *args, **kwargs) + return statement_executor_magic("stream", *args, **kwargs) - @register_cell_magic("{}.{}".format(base_name, 'template')) + @register_cell_magic(f"{base_name}.template") @process_line_arguments def template_add(body, name): self.template_add(name, body) - @register_line_cell_magic("{}.{}".format(base_name, 'render')) + @register_line_cell_magic(f"{base_name}.render") @process_line_cell_arguments - def template_render_magic(body=None, name=None, context=None, show=True, - cleanup=False, meta_only=False): - + def template_render_magic( + body=None, + name=None, + context=None, + show=True, + cleanup=False, + meta_only=False, + ): ip = get_ipython() if body is None: assert name is not None, "Name must be specified in line-mode." 
rendered = self.template_render( - name, context=context or ip.user_ns, by_name=True, - cleanup=cleanup, meta_only=meta_only + name, + context=context or ip.user_ns, + by_name=True, + cleanup=cleanup, + meta_only=meta_only, ) else: rendered = self.template_render( - body, context=context or ip.user_ns, by_name=False, - cleanup=cleanup, meta_only=meta_only + body, + context=context or ip.user_ns, + by_name=False, + cleanup=cleanup, + meta_only=meta_only, ) if name is not None: ip.user_ns[name] = rendered if show: - print(rendered) - else: - return rendered + return print(rendered) + return rendered - @register_line_magic("{}.{}".format(base_name, 'desc')) + @register_line_magic(f"{base_name}.desc") @process_line_arguments def table_desc(table_name, **kwargs): return self.table_desc(table_name, **kwargs) - @register_line_magic("{}.{}".format(base_name, 'head')) + @register_line_magic(f"{base_name}.head") @process_line_arguments def table_head(table_name, **kwargs): return self.table_head(table_name, **kwargs) - @register_line_magic("{}.{}".format(base_name, 'props')) + @register_line_magic(f"{base_name}.props") @process_line_arguments def table_props(table_name, **kwargs): return self.table_props(table_name, **kwargs) diff --git a/omniduct/databases/druid.py b/omniduct/databases/druid.py index fad956d..6e1cc18 100644 --- a/omniduct/databases/druid.py +++ b/omniduct/databases/druid.py @@ -12,11 +12,11 @@ class DruidClient(DatabaseClient): This Duct connects to a Druid server using the `pydruid` python library. """ - PROTOCOLS = ['druid'] + PROTOCOLS = ["druid"] DEFAULT_PORT = 80 - NAMESPACE_NAMES = ['table'] + NAMESPACE_NAMES = ["table"] NAMESPACE_QUOTECHAR = '"' - NAMESPACE_SEPARATOR = '.' + NAMESPACE_SEPARATOR = "." @override def _init(self): @@ -25,13 +25,16 @@ def _init(self): # Connection @override def _connect(self): - from pydruid.db import connect - logger.info('Connecting to Druid database ...') - self.__druid = connect(self.host, self.port, path='/druid/v2/sql/', scheme='http') + from pydruid.db import connect # pylint: disable=import-error + + logger.info("Connecting to Druid database ...") + self.__druid = connect( # pylint: disable=attribute-defined-outside-init + self.host, self.port, path="/druid/v2/sql/", scheme="http" + ) if self.username or self.password: logger.warning( - 'Duct username and password not passed to pydruid connection. ' - 'pydruid connection currently does not allow these fields to be passed.' + "Duct username and password not passed to pydruid connection. " + "pydruid connection currently does not allow these fields to be passed." 
) @override @@ -40,12 +43,12 @@ def _is_connected(self): @override def _disconnect(self): - logger.info('Disconnecting from Druid database ...') + logger.info("Disconnecting from Druid database ...") try: self.__druid.close() - except Exception: + except: # pylint: disable=bare-except pass - self.__druid = None + self.__druid = None # pylint: disable=attribute-defined-outside-init # Querying @override @@ -65,7 +68,7 @@ def _table_exists(self, table, **kwargs): try: self.table_desc(table, **kwargs) return True - except: + except: # pylint: disable=bare-except return False finally: logger.disabled = False @@ -76,22 +79,12 @@ def _table_drop(self, table, **kwargs): @override def _table_desc(self, table, **kwargs): - query = (""" - SELECT - TABLE_SCHEMA - , TABLE_NAME - , COLUMN_NAME - , ORDINAL_POSITION - , COLUMN_DEFAULT - , IS_NULLABLE - , DATA_TYPE - FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_NAME = '{}'""").format(table) + query = f"SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, ORDINAL_POSITION, COLUMN_DEFAULT, IS_NULLABLE, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{table}'" return self.query(query, **kwargs) @override def _table_head(self, table, n=10, **kwargs): - return self.query("SELECT * FROM {} LIMIT {}".format(table, n), **kwargs) + return self.query(f"SELECT * FROM {table} LIMIT {n}", **kwargs) @override def _table_props(self, table, **kwargs): diff --git a/omniduct/databases/exasol.py b/omniduct/databases/exasol.py index d10a679..c0d7da1 100644 --- a/omniduct/databases/exasol.py +++ b/omniduct/databases/exasol.py @@ -49,11 +49,12 @@ def _connect(self): import pyexasol logger.info("Connecting to Exasol ...") + # pylint: disable-next=attribute-defined-outside-init self.__exasol = pyexasol.connect( - dsn="{host}:{port}".format(host=self.host, port=self.port), + dsn=f"{self.host}:{self.port}", user=self.username, password=self.password, - **self.engine_opts + **self.engine_opts, ) @override @@ -64,8 +65,9 @@ def _is_connected(self): def _disconnect(self): try: self.__exasol.close() - except Exception: + except: # pylint: disable=bare-except pass + # pylint: disable-next=attribute-defined-outside-init self.__exasol = None @override @@ -97,29 +99,23 @@ def _query_to_table(self, statement, table, if_exists, **kwargs): statements = [] if if_exists == "fail" and self.table_exists(table): - raise RuntimeError("Table {} already exists!".format(table)) - elif if_exists == "replace": - statements.append("DROP TABLE IF EXISTS {};".format(table)) + raise RuntimeError(f"Table {table} already exists!") + if if_exists == "replace": + statements.append(f"DROP TABLE IF EXISTS {table};") elif if_exists == "append": raise NotImplementedError( - "Append operations have not been implemented for {}.".format( - self.__class__.__name__ - ) + f"Append operations have not been implemented for {self.__class__.__name__}." ) - statement = "CREATE TABLE {table} AS ({statement})".format( - table=table, statement=statement - ) + statement = f"CREATE TABLE {table} AS ({statement})" return self.execute(statement, **kwargs) @override def _table_list(self, namespace, **kwargs): # Since this namespace is a conditional, exasol requires single quotations # instead of double quotations.
" -> ' - query = ( - "SELECT TABLE_NAME FROM EXA_ALL_TABLES WHERE table_schema={}" - .format(namespace.render(quote_char="'")) - ) + exasol_namespace = namespace.render(quote_char="'") + query = f"SELECT TABLE_NAME FROM EXA_ALL_TABLES WHERE table_schema={exasol_namespace}" return self.query(query, **kwargs) @override @@ -128,7 +124,7 @@ def _table_exists(self, table, **kwargs): try: self.table_desc(table, **kwargs) return True - except: + except: # pylint: disable=bare-except return False finally: logger.disabled = False @@ -136,26 +132,17 @@ def _table_exists(self, table, **kwargs): @override def _table_drop(self, table, **kwargs): # Schema and tables are always under uppercase namespaces. - return self.execute( - "DROP TABLE {table}".format(table=str(table).upper()), - **kwargs - ) + return self.execute(f"DROP TABLE {str(table).upper()}", **kwargs) @override def _table_desc(self, table, **kwargs): # Schema and tables are always under uppercase namespaces. - return self.query( - "DESCRIBE {0}".format(str(table).upper()), - **kwargs - ) + return self.query(f"DESCRIBE {str(table).upper()}", **kwargs) @override def _table_head(self, table, n=10, **kwargs): # Schema and tables are always under uppercase namespaces. - return self.query( - "SELECT * FROM {} LIMIT {}".format(str(table).upper(), n), - **kwargs - ) + return self.query(f"SELECT * FROM {str(table).upper()} LIMIT {n}", **kwargs) @override def _table_props(self, table, **kwargs): diff --git a/omniduct/databases/hiveserver2.py b/omniduct/databases/hiveserver2.py index 161506f..7424c41 100644 --- a/omniduct/databases/hiveserver2.py +++ b/omniduct/databases/hiveserver2.py @@ -1,3 +1,5 @@ +# pylint: disable=abstract-method,consider-using-f-string + from __future__ import absolute_import import json @@ -45,32 +47,36 @@ class HiveServer2Client(DatabaseClient, SchemasMixin): `.connect()` methods of the drivers. """ - PROTOCOLS = ['hiveserver2'] + PROTOCOLS = ["hiveserver2"] DEFAULT_PORT = 3623 SUPPORTS_SESSION_PROPERTIES = True - NAMESPACE_NAMES = ['schema', 'table'] - NAMESPACE_QUOTECHAR = '`' - NAMESPACE_SEPARATOR = '.' + NAMESPACE_NAMES = ["schema", "table"] + NAMESPACE_QUOTECHAR = "`" + NAMESPACE_SEPARATOR = "." @property @override def NAMESPACE_DEFAULT(self): - return { - 'schema': self.schema - } + return {"schema": self.schema} @property @override def NAMESPACE_DEFAULTS_WRITE(self): defaults = self.NAMESPACE_DEFAULTS_READ.copy() - defaults['schema'] = self.username + defaults["schema"] = self.username return defaults @override - def _init(self, schema=None, driver='pyhive', auth_mechanism='NOSASL', - push_using_hive_cli=False, default_table_props=None, - thrift_transport=None, **connection_options - ): + def _init( + self, + schema=None, + driver="pyhive", + auth_mechanism="NOSASL", + push_using_hive_cli=False, + default_table_props=None, + thrift_transport=None, + **connection_options, + ): """ schema (str, None): The default database/schema to use for queries (will default to server-default if not specified). @@ -102,57 +108,73 @@ def _init(self, schema=None, driver='pyhive', auth_mechanism='NOSASL', self.default_table_props = default_table_props or {} self._thrift_transport = thrift_transport self.__hive = None - self.connection_fields += ('schema',) + self.connection_fields += ("schema",) - assert self.driver in ('pyhive', 'impyla'), "Supported drivers are pyhive and impyla." + assert self.driver in ( + "pyhive", + "impyla", + ), "Supported drivers are pyhive and impyla." 
@override def _connect(self): from sqlalchemy import create_engine, MetaData - if self.driver == 'pyhive': + + if self.driver == "pyhive": try: import pyhive.hive - except ImportError: - raise ImportError(""" - Omniduct is attempting to use the 'pyhive' driver, but it - is not installed. Please either install the pyhive package, - or reconfigure this Duct to use the 'impyla' driver. - """) - self.__hive = pyhive.hive.connect(host=None if self._thrift_transport else self.host, - port=None if self._thrift_transport else self.port, - auth=None if self._thrift_transport else self.auth_mechanism, - database=self.schema, - username=self.username, - password=None if self._thrift_transport else self.password, - thrift_transport=self._thrift_transport, - **self.connection_options) - self._sqlalchemy_engine = create_engine('hive://{}:{}/{}'.format(self.host, self.port, self.schema)) + except ImportError as e: + raise ImportError( + "Omniduct is attempting to use the 'pyhive' driver, but it " + "is not installed. Please either install the pyhive package, " + "or reconfigure this Duct to use the 'impyla' driver." + ) from e + # pylint: disable-next=attribute-defined-outside-init + self.__hive = pyhive.hive.connect( + host=None if self._thrift_transport else self.host, + port=None if self._thrift_transport else self.port, + auth=None if self._thrift_transport else self.auth_mechanism, + database=self.schema, + username=self.username, + password=None if self._thrift_transport else self.password, + thrift_transport=self._thrift_transport, + **self.connection_options, + ) + self._sqlalchemy_engine = create_engine( + f"hive://{self.host}:{self.port}/{self.schema}" + ) self._sqlalchemy_metadata = MetaData(self._sqlalchemy_engine) - elif self.driver == 'impyla': + elif self.driver == "impyla": try: import impala.dbapi - except ImportError: - raise ImportError(""" - Omniduct is attempting to use the 'impyla' driver, but it - is not installed. Please either install the impyla package, - or reconfigure this Duct to use the 'pyhive' driver. - """) - self.__hive = impala.dbapi.connect(host=self.host, - port=self.port, - auth_mechanism=self.auth_mechanism, - database=self.schema, - user=self.username, - password=self.password, - **self.connection_options) - self._sqlalchemy_engine = create_engine('impala://{}:{}/{}'.format(self.host, self.port, self.schema)) + except ImportError as e: + raise ImportError( + "Omniduct is attempting to use the 'impyla' driver, but it " + "is not installed. Please either install the impyla package, " + "or reconfigure this Duct to use the 'pyhive' driver." 
+ ) from e + # pylint: disable-next=attribute-defined-outside-init + self.__hive = impala.dbapi.connect( + host=self.host, + port=self.port, + auth_mechanism=self.auth_mechanism, + database=self.schema, + user=self.username, + password=self.password, + **self.connection_options, + ) + self._sqlalchemy_engine = create_engine( + f"impala://{self.host}:{self.port}/{self.schema}" + ) self._sqlalchemy_metadata = MetaData(self._sqlalchemy_engine) def __hive_cursor(self): - if self.driver == 'impyla': # Impyla seems to have all manner of connection issues, attempt to restore connection + if ( + self.driver == "impyla" + ): # Impyla seems to have all manner of connection issues, attempt to restore connection try: with Timeout(1): return self.__hive.cursor() - except: + except: # pylint: disable=bare-except self._connect() return self.__hive.cursor() @@ -162,23 +184,25 @@ def _is_connected(self): @override def _disconnect(self): - logger.info('Disconnecting from Hive coordinator...') + logger.info("Disconnecting from Hive coordinator...") try: self.__hive.close() - except: + except: # pylint: disable=bare-except pass + # pylint: disable-next=attribute-defined-outside-init self.__hive = None self._sqlalchemy_engine = None self._sqlalchemy_metadata = None + # pylint: disable-next=attribute-defined-outside-init self._schemas = None @override def _statement_prepare(self, statement, session_properties, **kwargs): return ( "\n".join( - "SET {key} = {value};".format(key=key, value=value) - for key, value in session_properties.items() - ) + statement + f"SET {key} = {value};" for key, value in session_properties.items() + ) + + statement ) @override @@ -191,18 +215,22 @@ def _execute(self, statement, cursor, wait, session_properties, poll_interval=1) cursor = cursor or self.__hive_cursor() log_offset = 0 - if self.driver == 'pyhive': + if self.driver == "pyhive": from TCLIService.ttypes import TOperationState # noqa: F821 - cursor.execute(statement, **{'async': True}) + + cursor.execute(statement, **{"async": True}) if wait: status = cursor.poll().operationState - while status in (TOperationState.INITIALIZED_STATE, TOperationState.RUNNING_STATE): + while status in ( + TOperationState.INITIALIZED_STATE, + TOperationState.RUNNING_STATE, + ): log_offset = self._log_status(cursor, log_offset) time.sleep(poll_interval) status = cursor.poll().operationState - elif self.driver == 'impyla': + elif self.driver == "impyla": cursor.execute_async(statement) if wait: while cursor.is_executing(): @@ -213,33 +241,37 @@ def _execute(self, statement, cursor, wait, session_properties, poll_interval=1) @override def _cursor_empty(self, cursor): - if self.driver == 'impyla': + if self.driver == "impyla": return not cursor.has_result_set - elif self.driver == 'pyhive': + if self.driver == "pyhive": return cursor.description is None return False def _cursor_wait(self, cursor, poll_interval=1): from TCLIService.ttypes import TOperationState # noqa: F821 + status = cursor.poll().operationState - while status in (TOperationState.INITIALIZED_STATE, TOperationState.RUNNING_STATE): + while status in ( + TOperationState.INITIALIZED_STATE, + TOperationState.RUNNING_STATE, + ): time.sleep(poll_interval) status = cursor.poll().operationState def _log_status(self, cursor, log_offset=0): - matcher = re.compile('[0-9/]+ [0-9:]+ (INFO )?') + matcher = re.compile("[0-9/]+ [0-9:]+ (INFO )?") - if self.driver == 'pyhive': + if self.driver == "pyhive": log = cursor.fetch_logs() else: - log = cursor.get_log().strip().split('\n') + log = 
cursor.get_log().strip().split("\n") for line in log[log_offset:]: if not line: continue m = matcher.match(line) if m: - line = line[len(m.group(0)):] + line = line[len(m.group(0)) :] logger.info(line) return len(log) @@ -248,23 +280,30 @@ def _log_status(self, cursor, log_offset=0): def _query_to_table(self, statement, table, if_exists, **kwargs): statements = [] - if if_exists == 'fail' and self.table_exists(table): - raise RuntimeError("Table {} already exists!".format(table)) - elif if_exists == 'replace': - statements.append('DROP TABLE IF EXISTS {};'.format(table)) - elif if_exists == 'append': - raise NotImplementedError("Append operations have not been implemented for {}.".format(self.__class__.__name__)) + if if_exists == "fail" and self.table_exists(table): + raise RuntimeError(f"Table {table} already exists!") + if if_exists == "replace": + statements.append(f"DROP TABLE IF EXISTS {table};") + elif if_exists == "append": + raise NotImplementedError( + f"Append operations have not been implemented for {self.__class__.__name__}." + ) - statements.append("CREATE TABLE {table} AS ({statement})".format( - table=table, - statement=statement - )) - return self.execute('\n'.join(statements), **kwargs) + statements.append(f"CREATE TABLE {table} AS ({statement})") + return self.execute("\n".join(statements), **kwargs) @override def _dataframe_to_table( - self, df, table, if_exists='fail', use_hive_cli=None, - partition=None, sep=chr(1), table_props=None, dtype_overrides=None, **kwargs + self, + df, + table, + if_exists="fail", + use_hive_cli=None, + partition=None, + sep=chr(1), + table_props=None, + dtype_overrides=None, + **kwargs, ): """ If `use_hive_cli` (or if not specified `.push_using_hive_cli`) is @@ -322,30 +361,38 @@ def _dataframe_to_table( ) try: return _pandas.to_sql( - df=df, name=table.table, schema=table.schema, con=self._sqlalchemy_engine, - index=False, if_exists=if_exists, **kwargs + df=df, + name=table.table, + schema=table.schema, + con=self._sqlalchemy_engine, + index=False, + if_exists=if_exists, + **kwargs, ) except Exception as e: raise RuntimeError( "Push unsuccessful. Your version of Hive may be too old to " "support the `INSERT` keyword. You might want to try setting " - "`.push_using_hive_cli = True` if your local or remote " - "machine has access to the `hive` CLI executable. The " - "original exception was: {}".format(e.args[0]) - ) + "`.push_using_hive_cli = True` if your local or remote machine " + "has access to the `hive` CLI executable. The original " + f"exception was: {e}" + ) from e # Try using Hive CLI # If `partition` is specified, the associated columns must not be # present in the dataframe. - assert len(set(partition).intersection(df.columns)) == 0, "The dataframe to be uploaded must not have any partitioned fields. Please remove the field(s): {}.".format(','.join(set(partition).intersection(df.columns))) + assert ( + len(set(partition).intersection(df.columns)) == 0 + ), f"The dataframe to be uploaded must not have any partitioned fields. Please remove the field(s): {','.join(set(partition).intersection(df.columns))}." # Save dataframe to file and send it to the remote server if necessary - temp_dir = tempfile.mkdtemp(prefix='omniduct_hiveserver2') - tmp_fname = os.path.join(temp_dir, 'data_{}.csv'.format(time.time())) - logger.info('Saving dataframe to file... 
{}'.format(tmp_fname)) - df.fillna(r'\N').to_csv(tmp_fname, index=False, header=False, - sep=sep, encoding='utf-8') + temp_dir = tempfile.mkdtemp(prefix="omniduct_hiveserver2") + tmp_fname = os.path.join(temp_dir, f"data_{time.time()}.csv") + logger.info(f"Saving dataframe to file... {tmp_fname}") + df.fillna(r"\N").to_csv( + tmp_fname, index=False, header=False, sep=sep, encoding="utf-8" + ) if self.remote: logger.info("Uploading data to remote host...") @@ -355,14 +402,7 @@ def _dataframe_to_table( auto_table_props = set(self.default_table_props).difference(table_props) if len(auto_table_props) > 0: logger.warning( - "In addition to any specified table properties, this " - "HiveServer2Client has added the following default table " - "properties:\n{default_props}\nTo override them, please " - "specify overrides using: `.push(..., table_props={{...}}).`" - .format(default_props=json.dumps({ - prop: value for prop, value in self.default_table_props.items() - if prop in auto_table_props - }, indent=True)) + f"In addition to any specified table properties, this HiveServer2Client has added the following default table properties:\n{json.dumps({prop: value for prop, value in self.default_table_props.items() if prop in auto_table_props}, indent=True)}\nTo override them, please specify overrides using: `.push(..., table_props={{...}}).`" ) tblprops = self.default_table_props.copy() @@ -370,60 +410,49 @@ def _dataframe_to_table( cts = self._create_table_statement_from_df( df=df, table=table, - drop=(if_exists == 'replace') and not partition, + drop=(if_exists == "replace") and not partition, text=True, sep=sep, table_props=tblprops, partition_cols=list(partition), - dtype_overrides=dtype_overrides + dtype_overrides=dtype_overrides, ) # Generate load data statement. partition_clause = ( - '' + "" if not partition - else 'PARTITION ({})'.format( - ','.join("{key} = '{value}'".format(key=key, value=value) for key, value in partition.items()) + else "PARTITION ({})".format( + ",".join(f"{key} = '{value}'" for key, value in partition.items()) ) ) - lds = '\nLOAD DATA LOCAL INPATH "{path}" {overwrite} INTO TABLE {table} {partition_clause};'.format( - path=os.path.basename(tmp_fname) if self.remote else tmp_fname, - overwrite="OVERWRITE" if if_exists == "replace" else "", - table=table, - partition_clause=partition_clause - ) + lds = f"\nLOAD DATA LOCAL INPATH \"{os.path.basename(tmp_fname) if self.remote else tmp_fname}\" {'OVERWRITE' if if_exists == 'replace' else ''} INTO TABLE {table} {partition_clause};" # Run create table statement and load data statements logger.info( - "Creating hive table `{table}` if it does not " - "already exist, and inserting the provided data{partition}." - .format( - table=table, - partition=" into {}".format(partition_clause) if partition_clause else "" - ) + f"Creating hive table `{table}` if it does not already exist, and inserting the provided data{f' into {partition_clause}' if partition_clause else ''}."
) try: - stmts = '\n'.join([cts, lds]) + stmts = "\n".join([cts, lds]) logger.debug(stmts) proc = self._run_in_hivecli(stmts) if proc.returncode != 0: - raise RuntimeError(proc.stderr.decode('utf-8')) + raise RuntimeError(proc.stderr.decode("utf-8")) finally: # Clean up files if self.remote: - self.remote.execute('rm -rf {}'.format(tmp_fname)) + self.remote.execute(f"rm -rf {tmp_fname}") shutil.rmtree(temp_dir, ignore_errors=True) - logger.info("Successfully uploaded dataframe {partition}`{table}`.".format( - table=table, - partition="into {} of ".format(partition_clause) if partition_clause else "" - )) + logger.info( + f"Successfully uploaded dataframe {f'into {partition_clause} of ' if partition_clause else ''}`{table}`." + ) + return None @override - def _table_list(self, namespace, like='*', **kwargs): + def _table_list(self, namespace, like="*", **kwargs): schema = namespace.name or self.schema - return self.query("SHOW TABLES IN {0} '{1}'".format(schema, like), - **kwargs) + return self.query(f"SHOW TABLES IN {schema} '{like}'", **kwargs) @override def _table_exists(self, table, **kwargs): @@ -431,48 +460,53 @@ def _table_exists(self, table, **kwargs): try: self.table_desc(table, **kwargs) return True - except: + except: # pylint: disable=bare-except return False finally: logger.disabled = False @override def _table_drop(self, table, **kwargs): - return self.execute("DROP TABLE {table}".format(table=table)) + return self.execute(f"DROP TABLE {table}") @override def _table_desc(self, table, **kwargs): - records = self.query("DESCRIBE {0}".format(table), **kwargs) + records = self.query(f"DESCRIBE {table}", **kwargs) + + if not records: + raise RuntimeError(f"Table {table} does not appear to have any fields.") # pretty hacky but hive doesn't return DESCRIBE results in a nice format # TODO is there any information we should pull out of DESCRIBE EXTENDED + i = 0 for i, record in enumerate(records): - if record[0] == '': + if record[0] == "": break - columns = ['col_name', 'data_type', 'comment'] + columns = ["col_name", "data_type", "comment"] fields_df = pd.DataFrame(records[:i], columns=columns) - partitions_df = pd.DataFrame(records[i + 4:], columns=columns) - partitions_df['comment'] = "PARTITION " + partitions_df['comment'] + partitions_df = pd.DataFrame(records[i + 4 :], columns=columns) + partitions_df["comment"] = "PARTITION " + partitions_df["comment"] return pd.concat((fields_df, partitions_df)) @override def _table_head(self, table, n=10, **kwargs): - return self.query("SELECT * FROM {} LIMIT {}".format(table, n), **kwargs) + return self.query(f"SELECT * FROM {table} LIMIT {n}", **kwargs) @override def _table_props(self, table, **kwargs): - return self.query('SHOW TBLPROPERTIES {0}'.format(table), **kwargs) + return self.query(f"SHOW TBLPROPERTIES {table}", **kwargs) def _run_in_hivecli(self, cmd): """Run a query using hive cli in a subprocess.""" # Turn hive command into quotable string. - double_escaped = re.sub('\\' * 2, '\\' * 4, cmd) - backtick_escape = '\\\\\\`' if self.remote else '\\`' - sys_cmd = 'hive -e "{0}"'.format(re.sub('"', '\\"', double_escaped)) \ - .replace('`', backtick_escape) + double_escaped = re.sub("\\" * 2, "\\" * 4, cmd) + backtick_escape = "\\\\\\`" if self.remote else "\\`" + sys_cmd = 'hive -e "{0}"'.format(re.sub('"', '\\"', double_escaped)).replace( + "`", backtick_escape + ) # Execute command in a subprocess. 
if self.remote: proc = self.remote.execute(sys_cmd) @@ -481,10 +515,18 @@ def _run_in_hivecli(self, cmd): return proc @classmethod - def _create_table_statement_from_df(cls, df, table, drop=False, - text=True, sep=chr(1), loc=None, - table_props=None, partition_cols=None, - dtype_overrides=None): + def _create_table_statement_from_df( + cls, + df, + table, + drop=False, + text=True, + sep=chr(1), + loc=None, + table_props=None, + partition_cols=None, + dtype_overrides=None, + ): """ Return create table statement for new hive table based on pandas dataframe. @@ -512,40 +554,39 @@ def _create_table_statement_from_df(cls, df, table, drop=False, # dtype kind to hive type mapping dict. DTYPE_KIND_HIVE_TYPE = { - 'b': 'BOOLEAN', # boolean - 'i': 'BIGINT', # signed integer - 'u': 'BIGINT', # unsigned integer - 'f': 'DOUBLE', # floating-point - 'c': 'STRING', # complex floating-point - 'O': 'STRING', # object - 'S': 'STRING', # (byte-)string - 'U': 'STRING', # Unicode - 'V': 'STRING' # void + "b": "BOOLEAN", # boolean + "i": "BIGINT", # signed integer + "u": "BIGINT", # unsigned integer + "f": "DOUBLE", # floating-point + "c": "STRING", # complex floating-point + "O": "STRING", # object + "S": "STRING", # (byte-)string + "U": "STRING", # Unicode + "V": "STRING", # void } # Sanitise column names and map numpy/pandas data-types to hive types. columns = [] for col, dtype in df.dtypes.iteritems(): - col_sanitized = re.sub(r'\W', '', col.lower().replace(' ', '_')) + col_sanitized = re.sub(r"\W", "", col.lower().replace(" ", "_")) hive_type = dtype_overrides.get(col) or DTYPE_KIND_HIVE_TYPE.get(dtype.kind) if hive_type is None: - hive_type = DTYPE_KIND_HIVE_TYPE['O'] + hive_type = DTYPE_KIND_HIVE_TYPE["O"] logger.warning( - 'Unable to determine hive type for dataframe column {col} of pandas dtype {dtype}. ' - 'Defaulting to hive type {hive_type}. If other column type is desired, ' - 'please specify via `dtype_overrides`' - .format(**locals()) + "Unable to determine hive type for dataframe column {col} of pandas dtype {dtype}. " + "Defaulting to hive type {hive_type}. If other column type is desired, " + "please specify via `dtype_overrides`".format(**locals()) ) - columns.append( - ' {column} {type}'.format(column=col_sanitized, type=hive_type) - ) + columns.append(f" {col_sanitized} {hive_type}") - partition_columns = ['{} STRING'.format(col) for col in partition_cols] + # pylint: disable-next=possibly-unused-variable + partition_columns = [f"{col} STRING" for col in partition_cols] - tblprops = ["'{key}' = '{value}'".format(key=key, value=value) for key, value in table_props.items()] - tblprops = "TBLPROPERTIES({})".format(",".join(tblprops)) if len(tblprops) > 0 else "" + tblprops = [f"'{key}' = '{value}'" for key, value in table_props.items()] + tblprops = f"TBLPROPERTIES({','.join(tblprops)})" if len(tblprops) > 0 else "" - cmd = Template(""" + cmd = Template( + """ {% if drop %} DROP TABLE IF EXISTS {{ table }}; {% endif -%} @@ -571,6 +612,7 @@ def _create_table_statement_from_df(cls, df, table, drop=False, {%- endif %} {{ tblprops }} ; - """).render(**locals()) + """ + ).render(**locals()) return cmd diff --git a/omniduct/databases/neo4j.py b/omniduct/databases/neo4j.py index 5ddde15..854ce1a 100644 --- a/omniduct/databases/neo4j.py +++ b/omniduct/databases/neo4j.py @@ -1,3 +1,5 @@ +# pylint: disable=abstract-method + from __future__ import absolute_import from interface_meta import override @@ -13,9 +15,9 @@ class Neo4jClient(DatabaseClient): library. 
""" - PROTOCOLS = ['neo4j'] + PROTOCOLS = ["neo4j"] DEFAULT_PORT = 7687 - DEFAULT_CURSOR_FORMATTER = 'raw' + DEFAULT_CURSOR_FORMATTER = "raw" @override @classmethod @@ -30,22 +32,26 @@ def _init(self): @override def _connect(self): from neo4j.v1 import GraphDatabase - logger.info('Connecting to Neo4J graph database ...') + + logger.info("Connecting to Neo4J graph database ...") auth = (self.username, self.password) if self.username else None - self.__driver = GraphDatabase.driver("bolt://{}:{}".format(self.host, self.port), auth=auth) # TODO: Add kerberos support + # pylint: disable-next=attribute-defined-outside-init + self.__driver = GraphDatabase.driver( + f"bolt://{self.host}:{self.port}", auth=auth + ) # TODO: Add kerberos support @override def _is_connected(self): - return hasattr(self, '__driver') and self.__driver is not None + return hasattr(self, "__driver") and self.__driver is not None @override def _disconnect(self): - logger.info('Disconnecting from Neo4J graph database ...') + logger.info("Disconnecting from Neo4J graph database ...") try: self.__driver.close() - except Exception: + except: # pylint: disable=bare-except pass - self.__driver = None + self.__driver = None # pylint: disable=attribute-defined-outside-init # Querying @override @@ -60,24 +66,24 @@ def _execute(self, statement, cursor, wait, session_properties): @override def _table_exists(self, table, **kwargs): - raise Exception('tables do not apply to the Neo4J graph database') + raise RuntimeError("tables do not apply to the Neo4J graph database") @override def _table_drop(self, table, **kwargs): - raise Exception('tables do not apply to the Neo4J graph database') + raise RuntimeError("tables do not apply to the Neo4J graph database") @override def _table_desc(self, table, **kwargs): - raise Exception('tables do not apply to the Neo4J graph database') + raise RuntimeError("tables do not apply to the Neo4J graph database") @override def _table_head(self, table, n=10, **kwargs): - raise Exception('tables do not apply to the Neo4J graph database') + raise RuntimeError("tables do not apply to the Neo4J graph database") @override def _table_list(self, namespace, **kwargs): - raise Exception('tables do not apply to the Neo4J graph database') + raise RuntimeError("tables do not apply to the Neo4J graph database") @override def _table_props(self, table, **kwargs): - raise Exception('tables do not apply to the Neo4J graph database') + raise RuntimeError("tables do not apply to the Neo4J graph database") diff --git a/omniduct/databases/presto.py b/omniduct/databases/presto.py index af36617..f26308b 100644 --- a/omniduct/databases/presto.py +++ b/omniduct/databases/presto.py @@ -1,3 +1,5 @@ +# pylint: disable=consider-using-f-string + from __future__ import absolute_import import ast @@ -6,9 +8,7 @@ import sys import pandas.io.sql -import six from interface_meta import override -from future.utils import raise_with_traceback from omniduct.utils.debug import logger @@ -33,30 +33,34 @@ class PrestoClient(DatabaseClient, SchemasMixin): `pyhive.presto.connect(...)`. """ - PROTOCOLS = ['presto'] + PROTOCOLS = ["presto"] DEFAULT_PORT = 3506 SUPPORTS_SESSION_PROPERTIES = True - NAMESPACE_NAMES = ['catalog', 'schema', 'table'] + NAMESPACE_NAMES = ["catalog", "schema", "table"] NAMESPACE_QUOTECHAR = '"' - NAMESPACE_SEPARATOR = '.' + NAMESPACE_SEPARATOR = "." 
@property @override def NAMESPACE_DEFAULT(self): - return { - 'catalog': self.catalog, - 'schema': self.schema - } + return {"catalog": self.catalog, "schema": self.schema} @property @override def NAMESPACE_DEFAULTS_WRITE(self): defaults = self.NAMESPACE_DEFAULTS_READ.copy() - defaults['schema'] = self.username + defaults["schema"] = self.username return defaults @override - def _init(self, catalog='default', schema='default', server_protocol='http', source=None, requests_session=None): + def _init( + self, + catalog="default", + schema="default", + server_protocol="http", + source=None, + requests_session=None, + ): """ catalog (str): The default catalog to use in database queries. schema (str): The default schema/database to use in database queries. @@ -73,7 +77,7 @@ def _init(self, catalog='default', schema='default', server_protocol='http', sou self.server_protocol = server_protocol self.source = source self.__presto = None - self.connection_fields += ('catalog', 'schema') + self.connection_fields += ("catalog", "schema") self._requests_session = requests_session @property @@ -82,35 +86,38 @@ def source(self): @source.setter def source(self, source): - self._source = source or 'omniduct' + self._source = source or "omniduct" # Connection @override def _connect(self): from sqlalchemy import create_engine, MetaData - logging.getLogger('pyhive').setLevel(1000) # Silence pyhive logging. - logger.info('Connecting to Presto coordinator...') - self._sqlalchemy_engine = create_engine('presto://{}:{}/{}/{}'.format(self.host, self.port, self.catalog, self.schema)) + + logging.getLogger("pyhive").setLevel(1000) # Silence pyhive logging. + logger.info("Connecting to Presto coordinator...") + self._sqlalchemy_engine = create_engine( + f"presto://{self.host}:{self.port}/{self.catalog}/{self.schema}" + ) self._sqlalchemy_metadata = MetaData(self._sqlalchemy_engine) @override def _is_connected(self): try: return self.__presto is not None - except: + except: # pylint: disable=bare-except return False @override def _disconnect(self): - logger.info('Disconnecting from Presto coordinator...') + logger.info("Disconnecting from Presto coordinator...") try: self.__presto.close() - except: + except: # pylint: disable=bare-except pass self._sqlalchemy_engine = None self._sqlalchemy_metadata = None - self._schemas = None + self._schemas = None # pylint: disable=attribute-defined-outside-init # Querying @override @@ -120,14 +127,22 @@ def _execute(self, statement, cursor, wait, session_properties): log and present the user with useful debugging information. If that fails, the full traceback will be raised instead. 
""" - from pyhive import presto # Imported here due to slow import performance in Python 3 - from pyhive.exc import DatabaseError # Imported here due to slow import performance in Python 3 + from pyhive import presto + from pyhive.exc import DatabaseError + try: cursor = cursor or presto.Cursor( - host=self.host, port=self.port, username=self.username, password=self.password, - catalog=self.catalog, schema=self.schema, session_props=session_properties, - poll_interval=1, source=self.source, protocol=self.server_protocol, - requests_session=self._requests_session + host=self.host, + port=self.port, + username=self.username, + password=self.password, + catalog=self.catalog, + schema=self.schema, + session_props=session_properties, + poll_interval=1, + source=self.source, + protocol=self.server_protocol, + requests_session=self._requests_session, ) cursor.execute(statement) status = cursor.poll() @@ -135,9 +150,13 @@ def _execute(self, statement, cursor, wait, session_properties): logger.progress(0) # status None means command executed successfully # See https://github.com/dropbox/PyHive/blob/master/pyhive/presto.py#L234 - while status is not None and status['stats']['state'] != "FINISHED": - if status['stats'].get('totalSplits', 0) > 0: - pct_complete = round(status['stats']['completedSplits'] / float(status['stats']['totalSplits']), 4) + while status is not None and status["stats"]["state"] != "FINISHED": + if status["stats"].get("totalSplits", 0) > 0: + pct_complete = round( + status["stats"]["completedSplits"] + / float(status["stats"]["totalSplits"]), + 4, + ) logger.progress(pct_complete * 100) status = cursor.poll() logger.progress(100, complete=True) @@ -149,59 +168,80 @@ def _execute(self, statement, cursor, wait, session_properties): try: message = e.args[0] - if isinstance(message, six.string_types): - message = ast.literal_eval(re.match("[^{]*({.*})[^}]*$", message).group(1)) + if isinstance(message, str): + message = ast.literal_eval( + re.match("[^{]*({.*})[^}]*$", message).group(1) + ) - linenumber = message['errorLocation']['lineNumber'] - 1 + linenumber = message["errorLocation"]["lineNumber"] - 1 splt = statement.splitlines() - splt[linenumber] += ' <-- {errorType} ({errorName}) occurred. {message} '.format(**message) - context = '\n\n[Error Context]\n{}\n'.format('\n'.join([splt[ln] for ln in range(max(linenumber - 1, 0), - min(linenumber + 2, len(splt)))])) - - class ErrContext(object): - + splt[ + linenumber + ] += " <-- {errorType} ({errorName}) occurred. {message} ".format( + **message + ) + context = "\n\n[Error Context]\n{}\n".format( + "\n".join( + [ + splt[ln] + for ln in range( + max(linenumber - 1, 0), min(linenumber + 2, len(splt)) + ) + ] + ) + ) + + class ErrContext: def __repr__(self): return context # logged twice so that both notebook and console users see the error context exception_args.args = [exception_args, ErrContext()] logger.error(context) - except: - logger.warn(("Omniduct was unable to parse the database error messages. Refer to the " - "traceback below for full error details.")) + except: # pylint: disable=bare-except + logger.warn( + ( + "Omniduct was unable to parse the database error messages. Refer to the " + "traceback below for full error details." 
+ ) + ) if isinstance(exception, type): exception = exception(exception_args) - raise_with_traceback(exception, traceback) + raise exception.with_traceback(traceback) @override def _query_to_table(self, statement, table, if_exists, **kwargs): statements = [] - if if_exists == 'fail' and self.table_exists(table): - raise RuntimeError("Table {} already exists!".format(table)) - elif if_exists == 'replace': - statements.append('DROP TABLE IF EXISTS {};\n'.format(table)) - elif if_exists == 'append': - raise NotImplementedError("Append operations have not been implemented for {}.".format(self.__class__.__name__)) + if if_exists == "fail" and self.table_exists(table): + raise RuntimeError(f"Table {table} already exists!") + if if_exists == "replace": + statements.append(f"DROP TABLE IF EXISTS {table};\n") + elif if_exists == "append": + raise NotImplementedError( + f"Append operations have not been implemented for {self.__class__.__name__}." + ) - statements.append("CREATE TABLE {table} AS ({statement})".format( - table=table, - statement=statement - )) - return self.execute('\n'.join(statements), **kwargs) + statements.append(f"CREATE TABLE {table} AS ({statement})") + return self.execute("\n".join(statements), **kwargs) @override - def _dataframe_to_table(self, df, table, if_exists='fail', **kwargs): + def _dataframe_to_table(self, df, table, if_exists="fail", **kwargs): """ If if the schema namespace is not specified, `table.schema` will be defaulted to your username. Catalog overrides will be ignored, and will default to `self.catalog`. """ return _pandas.to_sql( - df=df, name=table.table, schema=table.schema, con=self._sqlalchemy_engine, - index=False, if_exists=if_exists, **kwargs + df=df, + name=table.table, + schema=table.schema, + con=self._sqlalchemy_engine, + index=False, + if_exists=if_exists, + **kwargs, ) @override @@ -232,22 +272,22 @@ def _table_exists(self, table, **kwargs): @override def _table_drop(self, table, **kwargs): - return self.execute("DROP TABLE {table}".format(table=table)) + return self.execute(f"DROP TABLE {table}") @override def _table_desc(self, table, **kwargs): - return self.query("DESCRIBE {0}".format(table), **kwargs) + return self.query(f"DESCRIBE {table}", **kwargs) @override def _table_partition_cols(self, table, **kwargs): desc = self._table_desc(table, **kwargs) - if 'Extra' in desc: - return list(desc[desc['Extra'].str.contains('partition key')]['Column']) + if "Extra" in desc: + return list(desc[desc["Extra"].str.contains("partition key")]["Column"]) return [] @override def _table_head(self, table, n=10, **kwargs): - return self.query("SELECT * FROM {} LIMIT {}".format(table, n), **kwargs) + return self.query(f"SELECT * FROM {table} LIMIT {n}", **kwargs) @override def _table_props(self, table, **kwargs): diff --git a/omniduct/databases/pyspark.py b/omniduct/databases/pyspark.py index fe93093..79af2ba 100644 --- a/omniduct/databases/pyspark.py +++ b/omniduct/databases/pyspark.py @@ -1,3 +1,5 @@ +# pylint: disable=abstract-method + from interface_meta import override from omniduct.databases.base import DatabaseClient @@ -9,15 +11,17 @@ class PySparkClient(DatabaseClient): This Duct connects to a local PySpark session using the `pyspark` library. """ - PROTOCOLS = ['pyspark'] + PROTOCOLS = ["pyspark"] DEFAULT_PORT = None SUPPORTS_SESSION_PROPERTIES = True - NAMESPACE_NAMES = ['schema', 'table'] - NAMESPACE_QUOTECHAR = '`' - NAMESPACE_SEPARATOR = '.' + NAMESPACE_NAMES = ["schema", "table"] + NAMESPACE_QUOTECHAR = "`" + NAMESPACE_SEPARATOR = "." 
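The pyspark `_connect` further down amounts to the standard SparkSession builder pattern. A standalone sketch, assuming pyspark is installed (the app name and config key are illustrative):

    from pyspark.sql import SparkSession

    builder = SparkSession.builder.appName("omniduct")
    for key, value in {"spark.sql.shuffle.partitions": "8"}.items():
        # Builder.config returns the builder, so calls can be chained or looped.
        builder = builder.config(key, value)
    spark = builder.getOrCreate()
    print(spark.sql("SELECT 1 AS x").collect())  # [Row(x=1)]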
@override - def _init(self, app_name='omniduct', config=None, master=None, enable_hive_support=False): + def _init( + self, app_name="omniduct", config=None, master=None, enable_hive_support=False + ): """ Args: app_name (str): The application name of the SparkSession. @@ -51,6 +55,7 @@ def _connect(self): for key, value in self.config.items(): builder.config(key, value) + # pylint: disable-next=attribute-defined-outside-init self._spark_session = builder.getOrCreate() @override @@ -66,19 +71,23 @@ def _disconnect(self): def _statement_prepare(self, statement, session_properties, **kwargs): return ( "\n".join( - "SET {key} = {value};".format(key=key, value=value) - for key, value in session_properties.items() - ) + statement + f"SET {key} = {value};" for key, value in session_properties.items() + ) + + statement ) @override def _execute(self, statement, cursor, wait, session_properties): - assert wait is True, "This Spark backend does not support asynchronous operations." + assert ( + wait is True + ), "This Spark backend does not support asynchronous operations." return SparkCursor(self._spark_session.sql(statement)) @override def _query_to_table(self, statement, table, if_exists, **kwargs): - return HiveServer2Client._query_to_table(self, statement, table, if_exists, **kwargs) + return HiveServer2Client._query_to_table( + self, statement, table, if_exists, **kwargs + ) @override def _table_list(self, namespace, **kwargs): @@ -105,7 +114,7 @@ def _table_props(self, table, **kwargs): return HiveServer2Client._table_props(self, table, **kwargs) -class SparkCursor(object): +class SparkCursor: """ This DBAPI2 compatible cursor wraps around a Spark DataFrame """ @@ -116,7 +125,7 @@ def __init__(self, df): @property def df_iter(self): - if not getattr(self, '_df_iter'): + if not getattr(self, "_df_iter"): self._df_iter = self.df.toLocalIterator() return self._df_iter @@ -124,10 +133,10 @@ def df_iter(self): @property def description(self): - return tuple([ + return tuple( (name, type_, None, None, None, None, None) for name, type_ in self.df.dtypes - ]) + ) @property def row_count(self): @@ -136,17 +145,14 @@ def row_count(self): def close(self): pass - def execute(operation, parameters=None): + def execute(self, operation, parameters=None): raise NotImplementedError - def executemany(operation, seq_of_parameters=None): + def executemany(self, operation, seq_of_parameters=None): raise NotImplementedError def fetchone(self): - return [ - value or None - for value in next(self.df_iter) - ] + return [value or None for value in next(self.df_iter)] def fetchmany(self, size=None): size = size or self.arraysize diff --git a/omniduct/databases/sqlalchemy.py b/omniduct/databases/sqlalchemy.py index 4238894..f4b3bcc 100644 --- a/omniduct/databases/sqlalchemy.py +++ b/omniduct/databases/sqlalchemy.py @@ -17,25 +17,33 @@ class SQLAlchemyClient(DatabaseClient, SchemasMixin): clients. """ - PROTOCOLS = ['sqlalchemy', 'firebird', 'mssql', 'mysql', 'oracle', 'postgresql', 'sybase', 'snowflake'] - NAMESPACE_NAMES = ['database', 'table'] + PROTOCOLS = [ + "sqlalchemy", + "firebird", + "mssql", + "mysql", + "oracle", + "postgresql", + "sybase", + "snowflake", + ] + NAMESPACE_NAMES = ["database", "table"] NAMESPACE_QUOTECHAR = '"' # TODO: Apply overrides depending on protocol? - NAMESPACE_SEPARATOR = '.' + NAMESPACE_SEPARATOR = "." 
@property @override def NAMESPACE_DEFAULT(self): - return { - 'database': self.database - } + return {"database": self.database} @override - def _init(self, dialect=None, driver=None, database='', engine_opts=None): - - assert self._port is not None, "Omniduct requires SQLAlchemy databases to manually specify a port, as " \ - "it will often be the case that ports are being forwarded." + def _init(self, dialect=None, driver=None, database="", engine_opts=None): + assert self._port is not None, ( + "Omniduct requires SQLAlchemy databases to manually specify a port, as " + "it will often be the case that ports are being forwarded." + ) - if self.protocol != 'sqlalchemy': + if self.protocol != "sqlalchemy": self.dialect = self.protocol else: self.dialect = dialect @@ -43,7 +51,7 @@ def _init(self, dialect=None, driver=None, database='', engine_opts=None): self.driver = driver self.database = database - self.connection_fields += ('schema',) + self.connection_fields += ("schema",) self.engine_opts = engine_opts or {} self.engine = None @@ -51,23 +59,28 @@ def _init(self, dialect=None, driver=None, database='', engine_opts=None): @property def db_uri(self): - return '{dialect}://{login}@{host_port}/{database}'.format( - dialect=self.dialect + ("+{}".format(self.driver) if self.driver else ''), - login=self.username + (":{}".format(self.password) if self.password else ''), - host_port=self.host + (":{}".format(self.port) if self.port else ''), - database=self.database + # pylint: disable-next=consider-using-f-string + return "{dialect}://{login}@{host_port}/{database}".format( + dialect=self.dialect + (f"+{self.driver}" if self.driver else ""), + login=self.username + (f":{self.password}" if self.password else ""), + host_port=self.host + (f":{self.port}" if self.port else ""), + database=self.database, ) @override def _connect(self): import sqlalchemy - if self.protocol not in ['mysql']: - logger.warning("While querying and executing should work as " - "expected, some operations on this database client " - "(such as listing tables, querying to tables, etc) " - "may not function as expected due to the backend " - "not supporting ANSI SQL.") + if self.protocol not in ["mysql"]: + logger.warning( + "While querying and executing should work as " + "expected, some operations on this database client " + "(such as listing tables, querying to tables, etc) " + "may not function as expected due to the backend " + "not supporting ANSI SQL." + ) + + # pylint: disable-next=attribute-defined-outside-init self.engine = sqlalchemy.create_engine(self.db_uri, **self.engine_opts) self._sqlalchemy_metadata = sqlalchemy.MetaData(self.engine) @@ -77,12 +90,16 @@ def _is_connected(self): @override def _disconnect(self): + # pylint: disable-next=attribute-defined-outside-init self.engine = None self._sqlalchemy_metadata = None + # pylint: disable-next=attribute-defined-outside-init self._schemas = None @override - def _execute(self, statement, cursor, wait, session_properties, query=True, **kwargs): + def _execute( + self, statement, cursor, wait, session_properties, query=True, **kwargs + ): assert wait, "`SQLAlchemyClient` does not support asynchronous operations." 
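The `db_uri` property above assembles a standard SQLAlchemy URL from optional parts. A standalone sketch with placeholder values showing how each optional component folds in:

    dialect, driver = "postgresql", "psycopg2"  # placeholders
    username, password = "analyst", None
    host, port, database = "localhost", 5432, "app"

    uri = "{}://{}@{}/{}".format(
        dialect + ("+{}".format(driver) if driver else ""),
        username + (":{}".format(password) if password else ""),
        host + (":{}".format(port) if port else ""),
        database,
    )
    print(uri)  # postgresql+psycopg2://analyst@localhost:5432/app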
         if cursor:
             cursor.execute(statement)
@@ -94,29 +111,33 @@ def _execute(self, statement, cursor, wait, session_properties, query=True, **kwargs):
     def _query_to_table(self, statement, table, if_exists, **kwargs):
         statements = []
 
-        if if_exists == 'fail' and self.table_exists(table):
-            raise RuntimeError("Table {} already exists!".format(table))
-        elif if_exists == 'replace':
-            statements.append('DROP TABLE IF EXISTS {};'.format(table))
-        elif if_exists == 'append':
-            raise NotImplementedError("Append operations have not been implemented for {}.".format(self.__class__.__name__))
+        if if_exists == "fail" and self.table_exists(table):
+            raise RuntimeError(f"Table {table} already exists!")
+        if if_exists == "replace":
+            statements.append(f"DROP TABLE IF EXISTS {table};")
+        elif if_exists == "append":
+            raise NotImplementedError(
+                f"Append operations have not been implemented for {self.__class__.__name__}."
+            )
 
-        statement = "CREATE TABLE {table} AS ({statement})".format(
-            table=table,
-            statement=statement
-        )
+        statement = f"CREATE TABLE {table} AS ({statement})"
         return self.execute(statement, **kwargs)
 
     @override
-    def _dataframe_to_table(self, df, table, if_exists='fail', **kwargs):
+    def _dataframe_to_table(self, df, table, if_exists="fail", **kwargs):
         return _pandas.to_sql(
-            df=df, name=table.table, schema=table.database, con=self.engine,
-            index=False, if_exists=if_exists, **kwargs
+            df=df,
+            name=table.table,
+            schema=table.database,
+            con=self.engine,
+            index=False,
+            if_exists=if_exists,
+            **kwargs,
         )
 
     @override
     def _table_list(self, namespace, **kwargs):
-        return self.query("SHOW TABLES IN {}".format(namespace), **kwargs)
+        return self.query(f"SHOW TABLES IN {namespace}", **kwargs)
 
     @override
     def _table_exists(self, table, **kwargs):
@@ -124,22 +145,22 @@ def _table_exists(self, table, **kwargs):
         try:
             self.table_desc(table, **kwargs)
             return True
-        except:
+        except:  # pylint: disable=bare-except
             return False
         finally:
             logger.disabled = False
 
     @override
     def _table_drop(self, table, **kwargs):
-        return self.execute("DROP TABLE {table}".format(table=table))
+        return self.execute(f"DROP TABLE {table}")
 
     @override
     def _table_desc(self, table, **kwargs):
-        return self.query("DESCRIBE {0}".format(table), **kwargs)
+        return self.query(f"DESCRIBE {table}", **kwargs)
 
     @override
     def _table_head(self, table, n=10, **kwargs):
-        return self.query("SELECT * FROM {} LIMIT {}".format(table, n), **kwargs)
+        return self.query(f"SELECT * FROM {table} LIMIT {n}", **kwargs)
 
     @override
     def _table_props(self, table, **kwargs):
diff --git a/omniduct/databases/stub.py b/omniduct/databases/stub.py
index b8113a5..0ed87f8 100644
--- a/omniduct/databases/stub.py
+++ b/omniduct/databases/stub.py
@@ -2,7 +2,6 @@
 
 
 class StubDatabaseClient(DatabaseClient):
-
     PROTOCOLS = []
     DEFAULT_PORT = None
 
@@ -25,7 +24,7 @@ def _disconnect(self):
     def _execute(self, statement, cursor, wait, session_properties, **kwargs):
         raise NotImplementedError
 
-    def _table_list(self, **kwargs):
+    def _table_list(self, namespace, **kwargs):
         raise NotImplementedError
 
     def _table_exists(self, table, **kwargs):
diff --git a/omniduct/duct.py b/omniduct/duct.py
index 02fe775..98cddd7 100644
--- a/omniduct/duct.py
+++ b/omniduct/duct.py
@@ -9,8 +9,6 @@
 from builtins import input
 from enum import Enum
 
-import six
-from future.utils import raise_with_traceback, with_metaclass
 from interface_meta import InterfaceMeta, quirk_docs
 
 from omniduct.errors import DuctProtocolUnknown, DuctServerUnreachable
@@ -19,7 +17,7 @@
 from omniduct.utils.ports import is_port_bound, naive_load_balancer
 
 
-class Duct(with_metaclass(InterfaceMeta, object)):
+class Duct(metaclass=InterfaceMeta):
     """
     The abstract base class for all protocol implementations.
 
@@ -32,6 +30,7 @@ class Duct(with_metaclass(InterfaceMeta, object)):
     connect and disconnect as required, and so manual intervention is not
     typically required to maintain connections.
     """
+
     __doc_attrs = """
     protocol (str): The name of the protocol for which this instance was
         created (especially useful if a `Duct` subclass supports multiple
@@ -77,27 +76,38 @@ class Duct(with_metaclass(InterfaceMeta, object)):
     """
     __doc_cls_attrs__ = None
 
-    INTERFACE_SKIPPED_NAMES = {'__init__', '_init'}
+    INTERFACE_SKIPPED_NAMES = {"__init__", "_init"}
 
     class Type(Enum):
         """
        The `Duct.Type` enum specifies all of the permissible values of
         `Duct.DUCT_TYPE`. Also determines the order in which ducts are loaded
         by DuctRegistry.
         """
-        REMOTE = 'remotes'
-        FILESYSTEM = 'filesystems'
-        CACHE = 'caches'
-        RESTFUL = 'rest_clients'
-        DATABASE = 'databases'
-        OTHER = 'other'
+
+        REMOTE = "remotes"
+        FILESYSTEM = "filesystems"
+        CACHE = "caches"
+        RESTFUL = "rest_clients"
+        DATABASE = "databases"
+        OTHER = "other"
 
     AUTO_LOGGING_SCOPE = True
     DUCT_TYPE = None
     PROTOCOLS = None
 
-    def __init__(self, protocol=None, name=None, registry=None, remote=None,
-                 host=None, port=None, username=None, password=None, cache=None,
-                 cache_namespace=None):
+    def __init__(
+        self,
+        protocol=None,
+        name=None,
+        registry=None,
+        remote=None,
+        host=None,
+        port=None,
+        username=None,
+        password=None,
+        cache=None,
+        cache_namespace=None,
+    ):
         """
         protocol (str, None): Name of protocol (used by Duct registries to
             inform Duct instances of how they were instantiated).
@@ -132,29 +142,36 @@ class name if not specified).
         self.cache = cache
         self.cache_namespace = cache_namespace
 
-        self.connection_fields = ('host', 'port', 'remote', 'username', 'password')
-        self.prepared_fields = ('_host', '_port', '_username', '_password')
+        self.connection_fields = ("host", "port", "remote", "username", "password")
+        self.prepared_fields = ("_host", "_port", "_username", "_password")
 
         atexit.register(self.disconnect)
         self.__prepared = False
-        self.__getting = False
+        self.__getting = False  # pylint: disable=unused-private-member
         self.__connected = False
-        self.__disconnecting = False
+        self.__disconnecting = False  # pylint: disable=unused-private-member
         self.__cached_auth = {}
         self.__prepreparation_values = {}
 
     @classmethod
     def __register_implementation__(cls):
-        if not hasattr(cls, '_protocols'):
+        if not hasattr(cls, "_protocols"):
             cls._protocols = {}
 
         cls._protocols[cls.__name__] = cls
 
-        registry_keys = getattr(cls, 'PROTOCOLS', []) or []
+        registry_keys = getattr(cls, "PROTOCOLS", []) or []
         if registry_keys:
             for key in registry_keys:
-                if key in cls._protocols and cls.__name__ != cls._protocols[key].__name__:
-                    logger.info("Ignoring attempt by class `{}` to register key '{}', which is already registered for class `{}`.".format(cls.__name__, key, cls._protocols[key].__name__))
+                if (
+                    key in cls._protocols
+                    and cls.__name__ != cls._protocols[key].__name__
+                ):
+                    logger.info(
+                        f"Ignoring attempt by class `{cls.__name__}` to register "
+                        f"key '{key}', which is already registered for class "
+                        f"`{cls._protocols[key].__name__}`."
+                    )
                 else:
                     cls._protocols[key] = cls
 
@@ -176,32 +193,31 @@ def for_protocol(cls, protocol):
             named protocol.
         """
         if protocol not in cls._protocols:
-            raise DuctProtocolUnknown("Missing `Duct` implementation for protocol: '{}'.".format(protocol))
+            raise DuctProtocolUnknown(
+                f"Missing `Duct` implementation for protocol: '{protocol}'."
+            )
         return functools.partial(cls._protocols[protocol], protocol=protocol)
 
     @property
-    def __prepare_triggers(self):
-        return (
-            ('cache',)
-            + object.__getattribute__(self, 'connection_fields')
-        )
+    def __prepare_triggers(self):  # pylint: disable=unused-private-member
+        return ("cache",) + object.__getattribute__(self, "connection_fields")
 
     @classmethod
     def __init_with_kwargs__(cls, self, kwargs, **fallbacks):
-        if not hasattr(self, '_Duct__inited_using_kwargs'):
+        if not hasattr(self, "_Duct__inited_using_kwargs"):
             self._Duct__inited_using_kwargs = {}
-        for cls_parent in reversed([
-            parent for parent in inspect.getmro(cls)
-            if issubclass(parent, Duct)
-            and parent not in self._Duct__inited_using_kwargs
-            and '__init__' in parent.__dict__
-        ]):
+        for cls_parent in reversed(
+            [
+                parent
+                for parent in inspect.getmro(cls)
+                if issubclass(parent, Duct)
+                and parent not in self._Duct__inited_using_kwargs
+                and "__init__" in parent.__dict__
+            ]
+        ):
             self._Duct__inited_using_kwargs[cls_parent] = True
-            if six.PY3:
-                argspec = inspect.getfullargspec(cls_parent.__init__)
-                keys = argspec.args[1:] + argspec.kwonlyargs
-            else:
-                keys = inspect.getargspec(cls_parent.__init__).args[1:]
+            argspec = inspect.getfullargspec(cls_parent.__init__)
+            keys = argspec.args[1:] + argspec.kwonlyargs
             params = {}
             for key in keys:
                 if key in kwargs:
@@ -212,34 +228,40 @@ def __init_with_kwargs__(cls, self, kwargs, **fallbacks):
 
     def __getattribute__(self, key):
         try:
-            if (not object.__getattribute__(self, '_Duct__prepared')
-                    and not object.__getattribute__(self, '_Duct__getting')
-                    and not object.__getattribute__(self, '_Duct__disconnecting')
-                    and key in object.__getattribute__(self, '_Duct__prepare_triggers')):
-                object.__setattr__(self, '_Duct__getting', True)
-                object.__getattribute__(self, 'prepare')()
-                object.__setattr__(self, '_Duct__getting', False)
+            if (
+                not object.__getattribute__(self, "_Duct__prepared")
+                and not object.__getattribute__(self, "_Duct__getting")
+                and not object.__getattribute__(self, "_Duct__disconnecting")
+                and key in object.__getattribute__(self, "_Duct__prepare_triggers")
+            ):
+                object.__setattr__(self, "_Duct__getting", True)
+                object.__getattribute__(self, "prepare")()
+                object.__setattr__(self, "_Duct__getting", False)
         except AttributeError:
             pass
-        except Exception as e:
-            object.__setattr__(self, '_Duct__getting', False)
-            raise_with_traceback(e)
+        except:  # pylint: disable=bare-except
+            object.__setattr__(self, "_Duct__getting", False)
+            raise
         return object.__getattribute__(self, key)
 
     def __setattr__(self, key, value):
         try:
-            if (object.__getattribute__(self, '_Duct__prepared')
-                    and object.__getattribute__(self, 'connection_fields')
-                    and key in self.connection_fields
-                    and self.is_connected()):
-                logger.warn('Disconnecting prior to changing field that connection is based on: {}.'.format(key))
+            if (
+                object.__getattribute__(self, "_Duct__prepared")
+                and object.__getattribute__(self, "connection_fields")
+                and key in self.connection_fields
+                and self.is_connected()
+            ):
+                logger.warn(
+                    f"Disconnecting prior to changing field that connection is based on: {key}."
+                )
                 self.disconnect()
                 self.__prepared = False
         except AttributeError:
             pass
         object.__setattr__(self, key, value)
 
-    @quirk_docs('_prepare')
+    @quirk_docs("_prepare")
     def prepare(self):
         """
         Prepare a Duct subclass for use (if not already prepared).
@@ -280,37 +302,43 @@ def _prepare(self):
         from omniduct.remotes.base import RemoteClient
 
         # Check registry is of an appropriate type (if present)
-        assert (self.registry is None) or isinstance(self.registry, DuctRegistry), "Provided registry is not an instance of `omniduct.registry.DuctRegistry`."
+        assert (self.registry is None) or isinstance(
+            self.registry, DuctRegistry
+        ), "Provided registry is not an instance of `omniduct.registry.DuctRegistry`."
 
         # If registry is present, lookup remotes and caches if necessary
         if self.registry is not None:
-            if self.remote and isinstance(self.remote, six.string_types):
-                self.__prepreparation_values['remote'] = self.remote
+            if self.remote and isinstance(self.remote, str):
+                self.__prepreparation_values["remote"] = self.remote
                 self.remote = self.registry.lookup(self.remote, kind=Duct.Type.REMOTE)
-            if self.cache and isinstance(self.cache, six.string_types):
-                self.__prepreparation_values['cache'] = self.cache
+            if self.cache and isinstance(self.cache, str):
+                self.__prepreparation_values["cache"] = self.cache
                 self.cache = self.registry.lookup(self.cache, kind=Duct.Type.CACHE)
 
         # Check if remote and cache objects are of correct type (if present)
-        assert (self.remote is None) or isinstance(self.remote, RemoteClient), "Provided remote is not an instance of `omniduct.remotes.base.RemoteClient`."
-        assert (self.cache is None) or isinstance(self.cache, Cache), "Provided cache is not an instance of `omniduct.caches.base.Cache`."
+        assert (self.remote is None) or isinstance(
+            self.remote, RemoteClient
+        ), "Provided remote is not an instance of `omniduct.remotes.base.RemoteClient`."
+        assert (self.cache is None) or isinstance(
+            self.cache, Cache
+        ), "Provided cache is not an instance of `omniduct.caches.base.Cache`."
 
         # Replace prepared fields with the result of calling existing values
         # with a reference to `self`.
         for field in self.prepared_fields:
             value = getattr(self, field)
-            if hasattr(value, '__call__'):
+            if hasattr(value, "__call__"):
                 self.__prepreparation_values[field] = value
                 setattr(self, field, value(self))
 
         if isinstance(self._host, (list, tuple)):
-            if '_host' not in self.__prepreparation_values:
-                self.__prepreparation_values['_host'] = self._host
+            if "_host" not in self.__prepreparation_values:
+                self.__prepreparation_values["_host"] = self._host
             self._host = naive_load_balancer(self._host, port=self._port)
 
         # If host has a port included in it, override the value of self._port
-        if self._host is not None and re.match(r'[^\:]+:[0-9]{1,5}', self._host):
-            self._host, self._port = self._host.split(':')
+        if self._host is not None and re.match(r"[^\:]+:[0-9]{1,5}", self._host):
+            self._host, self._port = self._host.split(":")
 
         # Ensure port is an integer value
         self.port = int(self._port) if self._port else None
@@ -345,7 +373,7 @@ def host(self):
         at runtime using: `duct.host = '<host>'`.
         """
         if self.remote:
-            return '127.0.0.1'  # TODO: Make this configurable.
+            return "127.0.0.1"  # TODO: Make this configurable.
         return self._host
 
     @host.setter
@@ -361,7 +389,7 @@ def port(self):
         at runtime using: `duct.port = <port>`.
         """
         if self.remote:
-            return self.remote.port_forward('{}:{}'.format(self._host, self._port))
+            return self.remote.port_forward(f"{self._host}:{self._port}")
         return self._port
 
     @port.setter
@@ -379,12 +407,14 @@ def username(self):
         specify a different username at runtime using: `duct.username = '<username>'`.
         """
         if self._username is True:
-            if 'username' not in self.__cached_auth:
-                self.__cached_auth['username'] = input("Enter username for '{}':".format(self.name))
-            return self.__cached_auth['username']
-        elif self._username is False:
+            if "username" not in self.__cached_auth:
+                self.__cached_auth["username"] = input(
+                    f"Enter username for '{self.name}':"
+                )
+            return self.__cached_auth["username"]
+        if self._username is False:
             return None
-        elif not self._username:
+        if not self._username:
             try:
                 username = os.getlogin()
             except OSError:
@@ -407,10 +437,12 @@ def password(self):
             using: `duct.password = '<password>'`.
         """
         if self._password is True:
-            if 'password' not in self.__cached_auth:
-                self.__cached_auth['password'] = getpass.getpass("Enter password for '{}':".format(self.name))
-            return self.__cached_auth['password']
-        elif self._password is False:
+            if "password" not in self.__cached_auth:
+                self.__cached_auth["password"] = getpass.getpass(
+                    f"Enter password for '{self.name}':"
+                )
+            return self.__cached_auth["password"]
+        if self._password is False:
             return None
         return self._password
 
@@ -431,17 +463,20 @@ def __assert_server_reachable(self):
         if self.remote and not self.remote.is_port_bound(self._host, self._port):
             self.disconnect()
             raise DuctServerUnreachable(
-                "Remote '{}' cannot connect to '{}:{}'. Please check your settings before trying again.".format(
-                    self.remote.name, self._host, self._port))
-        elif not self.remote:
+                f"Remote '{self.remote.name}' cannot connect to "
+                f"'{self._host}:{self._port}'. Please check your settings "
+                "before trying again."
+            )
+        if not self.remote:
             self.disconnect()
             raise DuctServerUnreachable(
-                "Cannot connect to '{}:{}' on your current connection. Please check your connection before trying again.".format(
-                    self.host, self.port))
+                f"Cannot connect to '{self.host}:{self.port}' on your current "
+                "connection. Please check your connection before trying again."
+            )
 
     # Connection
     @logging_scope("Connecting")
-    @quirk_docs('_connect')
+    @quirk_docs("_connect")
     def connect(self):
         """
         Connect to the service backing this client.
@@ -454,27 +489,21 @@ def connect(self):
         """
         if self.host:
             logger.info(
-                "Connecting to {host}:{port}{remote}.".format(
-                    host=self._host,
-                    port=self._port,
-                    remote="on {}".format(self.remote.host) if self.remote else ""
-                )
+                f"Connecting to {self._host}:{self._port}"
+                f"{f'on {self.remote.host}' if self.remote else ''}."
             )
         self.__assert_server_reachable()
         if not self.is_connected():
             try:
                 self._connect()
-            except Exception as e:
+            except:  # pylint: disable=bare-except
                 self.reset()
-                raise_with_traceback(e)
+                raise
             self.__connected = True
         if self.host:
             logger.info(
-                "Connected to {host}:{port}{remote}.".format(
-                    host=self._host,
-                    port=self._port,
-                    remote="on {}".format(self.remote.host) if self.remote else ""
-                )
+                f"Connected to {self._host}:{self._port}"
+                f"{f'on {self.remote.host}' if self.remote else ''}."
             )
         return self
 
@@ -482,7 +511,7 @@ def connect(self):
     def _connect(self):
         raise NotImplementedError
 
-    @quirk_docs('_is_connected')
+    @quirk_docs("_is_connected")
     def is_connected(self):
         """
         Check whether this `Duct` instance is currently connected.
@@ -501,7 +530,7 @@ def is_connected(self):
         if self.remote:
             if not self.remote.has_port_forward(self._host, self._port):
                 return False
-            elif not is_port_bound(self.host, self.port):
+            if not is_port_bound(self.host, self.port):
                 self.disconnect()
                 return False
 
@@ -511,7 +540,7 @@ def is_connected(self):
     def _is_connected(self):
         raise NotImplementedError
 
-    @quirk_docs('_disconnect')
+    @quirk_docs("_disconnect")
     def disconnect(self):
         """
         Disconnect this client from backing service.
@@ -526,18 +555,18 @@ def disconnect(self):
             `Duct` instance: A reference to this object.
         """
         if not self.__prepared:
-            return
-        self.__disconnecting = True
+            return None
+        self.__disconnecting = True  # pylint: disable=unused-private-member
         self.__connected = False
 
         try:
             self._disconnect()
             if self.remote and self.remote.has_port_forward(self._host, self._port):
-                logger.info('Freeing up local port {0}...'.format(self.port))
+                logger.info(f"Freeing up local port {self.port}...")
                 self.remote.port_forward_stop(local_port=self.port)
         finally:
-            self.__disconnecting = False
+            self.__disconnecting = False  # pylint: disable=unused-private-member
 
         return self
diff --git a/omniduct/filesystems/_pyarrow_compat.py b/omniduct/filesystems/_pyarrow_compat.py
index 0372b41..63057e8 100644
--- a/omniduct/filesystems/_pyarrow_compat.py
+++ b/omniduct/filesystems/_pyarrow_compat.py
@@ -35,7 +35,7 @@ def mkdir(self, path, create_parents=True):
         return self.fs.mkdir(_stringify_path(path), recursive=create_parents)
 
     @implements(FileSystem.open)
-    def open(self, path, mode='rb'):
+    def open(self, path, mode="rb"):
         return self.fs.open(_stringify_path(path), mode=mode)
 
     @implements(FileSystem.ls)
diff --git a/omniduct/filesystems/_webhdfs_helpers.py b/omniduct/filesystems/_webhdfs_helpers.py
index 20d201c..ae7c1cc 100644
--- a/omniduct/filesystems/_webhdfs_helpers.py
+++ b/omniduct/filesystems/_webhdfs_helpers.py
@@ -1,12 +1,16 @@
+import http.client
 import json
 import xml.dom.minidom
+
 import requests
-from six.moves import http_client
 from pywebhdfs import errors
-from pywebhdfs.webhdfs import (PyWebHdfsClient, _is_standby_exception,
-                               _move_active_host_to_head)
+from pywebhdfs.webhdfs import (
+    PyWebHdfsClient,
+    _is_standby_exception,
+    _move_active_host_to_head,
+)
 
 
 class OmniductPyWebHdfsClient(PyWebHdfsClient):
@@ -22,17 +26,18 @@ def __init__(self, remote=None, namenodes=None, **kwargs):
 
         PyWebHdfsClient.__init__(self, **kwargs)
 
-        if self.namenodes and 'path_to_hosts' not in kwargs:
-            self.path_to_hosts = [('.*', self.namenodes)]
+        if self.namenodes and "path_to_hosts" not in kwargs:
+            self.path_to_hosts = [(".*", self.namenodes)]
 
         # Override base uri
-        self.base_uri_pattern = kwargs.get('base_uri_pattern', "http://{host}/webhdfs/v1/").format(
-            host="{host}")
+        self.base_uri_pattern = kwargs.get(
+            "base_uri_pattern", "http://{host}/webhdfs/v1/"
+        ).format(host="{host}")
 
     @property
     def host(self):
-        host = 'localhost' if self.remote else self._host
-        return '{}:{}'.format(host, str(self.port))
+        host = "localhost" if self.remote else self._host
+        return f"{host}:{str(self.port)}"
 
     @host.setter
     def host(self, host):
@@ -41,7 +46,7 @@ def host(self, host):
     @property
     def port(self):
         if self.remote:
-            return self.remote.port_forward('{}:{}'.format(self._host, self._port))
+            return self.remote.port_forward(f"{self._host}:{self._port}")
         return self._port
 
     @port.setter
@@ -51,9 +56,10 @@ def port(self, port):
     @property
     def namenodes(self):
         if self.remote:
-            return ['localhost:{}'.format(self.remote.port_forward(nn)) for nn in self._namenodes]
-        else:
-            return self._namenodes
+            return [
+                f"localhost:{self.remote.port_forward(nn)}" for nn in self._namenodes
+            ]
+        return self._namenodes
 
     @namenodes.setter
     def namenodes(self, namenodes):
@@ -66,35 +72,46 @@ def _make_uri_local(self, uri):
         return uri
 
     def get_home_directory(self):
-        response = self._resolve_host(requests.get, True, '/', operation='GETHOMEDIRECTORY')
+        response = self._resolve_host(
+            requests.get, True, "/", operation="GETHOMEDIRECTORY"
+        )
         if response.ok:
-            return json.loads(response.content)['Path']
-        return '/'
+            return json.loads(response.content)["Path"]
+        return "/"
 
-    def _resolve_host(self, req_func, allow_redirect,
-                      path, operation, **kwargs):
+    def _resolve_host(self, req_func, allow_redirect, path, operation, **kwargs):
         """
         This is where the magic happens, and where omniduct handles
         redirects during federation and HA.
         """
-        import requests
         uri_without_host = self._create_uri(path, operation, **kwargs)
         hosts = self._resolve_federation(path)
         for host in hosts:
             uri = uri_without_host.format(host=host)
             try:
                 while True:
-                    response = req_func(uri, allow_redirects=False,
-                                        timeout=self.timeout,
-                                        **self.request_extra_opts)
-
-                    if allow_redirect and response.status_code == http_client.TEMPORARY_REDIRECT:
-                        uri = self._make_uri_local(response.headers['location'])
+                    response = req_func(
+                        uri,
+                        allow_redirects=False,
+                        timeout=self.timeout,
+                        **self.request_extra_opts,
+                    )
+
+                    if (
+                        allow_redirect
+                        and response.status_code == http.client.TEMPORARY_REDIRECT
+                    ):
+                        uri = self._make_uri_local(response.headers["location"])
                     else:
                         break
 
-                if not allow_redirect and response.status_code == http_client.TEMPORARY_REDIRECT:
-                    response.headers['location'] = self._make_uri_local(response.headers['location'])
+                if (
+                    not allow_redirect
+                    and response.status_code == http.client.TEMPORARY_REDIRECT
+                ):
+                    response.headers["location"] = self._make_uri_local(
+                        response.headers["location"]
+                    )
 
                 if not _is_standby_exception(response):
                     _move_active_host_to_head(hosts, host)
@@ -104,7 +121,7 @@ def _resolve_host(self, req_func, allow_redirect,
         raise errors.ActiveHostNotFound(msg="Could not find active host")
 
 
-class CdhHdfsConfParser(object):
+class CdhHdfsConfParser:
     """
     This class serves to automatically extract HDFS cluster information from
     Cloudera configuration files.
@@ -118,11 +135,12 @@ def __init__(self, fs, conf_path=None):
             conf_path (str): The path of the configuration file to be parsed.
""" self.fs = fs - self.conf_path = conf_path or '/etc/hadoop/conf.cloudera.hdfs2/hdfs-site.xml' + self.conf_path = conf_path or "/etc/hadoop/conf.cloudera.hdfs2/hdfs-site.xml" @property def config(self): - if not hasattr(self, '_config'): + if not hasattr(self, "_config"): + # pylint: disable-next=attribute-defined-outside-init self._config = self._get_config() return self._config @@ -130,11 +148,14 @@ def _get_config(self): with self.fs.open(self.conf_path) as f: d = xml.dom.minidom.parseString(f.read()) - properties = d.getElementsByTagName('property') + properties = d.getElementsByTagName("property") return { - prop.getElementsByTagName('name')[0].childNodes[0].wholeText: - prop.getElementsByTagName('value')[0].childNodes[0].wholeText + prop.getElementsByTagName("name")[0] + .childNodes[0] + .wholeText: prop.getElementsByTagName("value")[0] + .childNodes[0] + .wholeText for prop in properties } @@ -142,13 +163,13 @@ def _get_config(self): def clusters(self): clusters = [] for key in self.config: - if key.startswith('dfs.ha.namenodes.'): - clusters.append(key[len('dfs.ha.namenodes.'):]) + if key.startswith("dfs.ha.namenodes."): + clusters.append(key[len("dfs.ha.namenodes.") :]) return clusters def namenodes(self, cluster): - namenodes = self.config['dfs.ha.namenodes.{}'.format(cluster)].split(',') + namenodes = self.config[f"dfs.ha.namenodes.{cluster}"].split(",") return [ - self.config['dfs.namenode.http-address.{}.{}'.format(cluster, namenode)] + self.config[f"dfs.namenode.http-address.{cluster}.{namenode}"] for namenode in namenodes ] diff --git a/omniduct/filesystems/base.py b/omniduct/filesystems/base.py index 74fc321..f638818 100644 --- a/omniduct/filesystems/base.py +++ b/omniduct/filesystems/base.py @@ -23,8 +23,10 @@ class FileSystemClient(Duct, MagicsProvider): DUCT_TYPE = Duct.Type.FILESYSTEM DEFAULT_PORT = None - @quirk_docs('_init', mro=True) - def __init__(self, cwd=None, home=None, read_only=False, global_writes=False, **kwargs): + @quirk_docs("_init", mro=True) + def __init__( # pylint: disable=super-init-not-called + self, cwd=None, home=None, read_only=False, global_writes=False, **kwargs + ): """ cwd (None, str): The path prefix to use as the current working directory (if None, the user's home directory is used where that makes sense). @@ -51,7 +53,7 @@ def _init(self): # Path properties and helpers @property - @quirk_docs('_path_home') + @quirk_docs("_path_home") @require_connection def path_home(self): """ @@ -69,7 +71,9 @@ def path_home(self): @path_home.setter def path_home(self, path_home): if path_home is not None and not path_home.startswith(self.path_separator): - raise ValueError("The home path must be absolute. Received: '{}'.".format(path_home)) + raise ValueError( + f"The home path must be absolute. Received: '{path_home}'." + ) self.__path_home = path_home @abstractmethod @@ -92,7 +96,7 @@ def path_cwd(self, path_cwd): self._path_cwd = path_cwd @property - @quirk_docs('_path_separator') + @quirk_docs("_path_separator") def path_separator(self): """ str: The character(s) to use in separating path components. Typically @@ -124,12 +128,12 @@ def path_join(self, path, *components): in order, to the base path. 
""" for component in components: - if component.startswith('~'): + if component.startswith("~"): path = self.path_home + component[1:] elif component.startswith(self.path_separator): path = component else: - path = '{}{}{}'.format(path, self.path_separator if not path.endswith(self.path_separator) else '', component) + path = f"{path}{self.path_separator if not path.endswith(self.path_separator) else ''}{component}" return path def path_basename(self, path): @@ -163,7 +167,9 @@ def path_dirname(self, path): Returns: str: The extracted directory path. """ - return self.path_separator.join(self._path(path).split(self.path_separator)[:-1]) + return self.path_separator.join( + self._path(path).split(self.path_separator)[:-1] + ) def path_normpath(self, path): """ @@ -181,18 +187,20 @@ def path_normpath(self, path): components = self._path(path).split(self.path_separator) out_path = [] for component in components: - if component == '' and len(out_path) > 0: + if component == "" and len(out_path) > 0: continue - if component == '.': + if component == ".": continue - elif component == '..': + if component == "..": if len(out_path) > 1: out_path.pop() else: - raise RuntimeError("Cannot access parent directory of filesystem root.") + raise RuntimeError( + "Cannot access parent directory of filesystem root." + ) else: out_path.append(component) - if len(out_path) == 1 and out_path[0] == '': + if len(out_path) == 1 and out_path[0] == "": return self.path_separator return self.path_separator.join(out_path) @@ -229,14 +237,18 @@ def global_writes(self, global_writes): def _assert_path_is_writable(self, path): if self.read_only: - raise RuntimeError("This filesystem client is configured for read-only access. Set `{}`.`read_only` to `False` to override.".format(self.name)) + raise RuntimeError( + f"This filesystem client is configured for read-only access. Set `{self.name}`.`read_only` to `False` to override." + ) if not self.global_writes and not self._path_in_home_dir(path): - raise RuntimeError("Attempt to write outside of home directory without setting `{}`.`global_writes` to `True`.".format(self.name)) + raise RuntimeError( + f"Attempt to write outside of home directory without setting `{self.name}`.`global_writes` to `True`." + ) return True # Filesystem accessors - @quirk_docs('_exists') + @quirk_docs("_exists") @require_connection def exists(self, path): """ @@ -255,7 +267,7 @@ def exists(self, path): def _exists(self, path): raise NotImplementedError - @quirk_docs('_isdir') + @quirk_docs("_isdir") @require_connection def isdir(self, path): """ @@ -274,7 +286,7 @@ def isdir(self, path): def _isdir(self, path): raise NotImplementedError - @quirk_docs('_isfile') + @quirk_docs("_isfile") @require_connection def isfile(self, path): """ @@ -301,7 +313,7 @@ def _dir(self, path): """ raise NotImplementedError - @quirk_docs('_dir') + @quirk_docs("_dir") @require_connection def dir(self, path=None): """ @@ -321,7 +333,7 @@ def dir(self, path=None): generator: The children of `path` represented as `FileSystemFileDesc` objects. """ - assert self.isdir(path), "'{}' is not a valid directory.".format(path) + assert self.isdir(path), f"'{path}' is not a valid directory." return self._dir(self._path(path)) def listdir(self, path=None): @@ -359,7 +371,7 @@ def showdir(self, path=None): pandas.DataFrame: A DataFrame representation of the contents of the nominated directory. """ - assert self.isdir(path), "'{}' is not a valid directory.".format(path) + assert self.isdir(path), f"'{path}' is not a valid directory." 
         return self._showdir(self._path(path))
 
     def _showdir(self, path):
@@ -367,15 +379,14 @@ def _showdir(self, path):
         if len(data) > 0:
             return (
                 pd.DataFrame(data)
-                .sort_values(['type', 'name'])
+                .sort_values(["type", "name"])
                 .reset_index(drop=True)
-                .dropna(axis='columns', how='all')
-                .drop(axis=1, labels=['fs', 'path'])
+                .dropna(axis="columns", how="all")
+                .drop(axis=1, labels=["fs", "path"])
             )
-        else:
-            return "Directory has no contents."
+        return "Directory has no contents."
 
-    @quirk_docs('_walk')
+    @quirk_docs("_walk")
     @require_connection
     def walk(self, path=None):
         """
@@ -393,24 +404,26 @@ def walk(self, path=None):
             generator: A generator of tuples, each tuple being associated with
                 one directory that is either `path` or one of its descendants.
         """
-        assert self.isdir(path), "'{}' is not a valid directory.".format(path)
+        assert self.isdir(path), f"'{path}' is not a valid directory."
         return self._walk(self._path(path))
 
     def _walk(self, path):
         dirs = []
         files = []
         for f in self._dir(path):
-            if f.type == 'directory':
+            if f.type == "directory":
                 dirs.append(f.name)
             else:
                 files.append(f.name)
         yield (path, dirs, files)
 
-        for dir in dirs:
-            for walked in self._walk(self._path(self.path_join(path, dir))):  # Note: using _walk directly here, which may fail if disconnected during walk.
+        for dirname in dirs:
+            for walked in self._walk(
+                self._path(self.path_join(path, dirname))
+            ):  # Note: using _walk directly here, which may fail if disconnected during walk.
                 yield walked
 
-    @quirk_docs('_find')
+    @quirk_docs("_find")
     @require_connection
     def find(self, path_prefix=None, **attrs):
         """
@@ -435,31 +448,34 @@ def find(self, path_prefix=None, **attrs):
            objects that are descendants of `path_prefix` and which satisfy
            provided constraints.
         """
-        assert self.isdir(path_prefix), "'{0}' is not a valid directory. Did you mean `.find(name='{0}')`?".format(path_prefix)
+        assert self.isdir(
+            path_prefix
+        ), f"'{path_prefix}' is not a valid directory. Did you mean `.find(name='{path_prefix}')`?"
         return self._find(self._path(path_prefix), **attrs)
 
     def _find(self, path_prefix, **attrs):
-
         def is_match(f):
             for attr, value in attrs.items():
-                if hasattr(value, '__call__') and not value(f.as_dict().get(attr)):
+                if hasattr(value, "__call__") and not value(f.as_dict().get(attr)):
                    return False
-                elif value != f.as_dict().get(attr):
+                if value != f.as_dict().get(attr):
                    return False
            return True
 
         dirs = []
         for f in self._dir(path_prefix):
-            if f.type == 'directory':
+            if f.type == "directory":
                 dirs.append(f.name)
             if is_match(f):
                 yield f
 
-        for dir in dirs:
-            for match in self._find(self._path(self.path_join(path_prefix, dir)), **attrs):  # Note: using _find directly here, which may fail if disconnected during find.
+        for dirname in dirs:
+            for match in self._find(
+                self._path(self.path_join(path_prefix, dirname)), **attrs
+            ):  # Note: using _find directly here, which may fail if disconnected during find.
                 yield match
 
-    @quirk_docs('_mkdir')
+    @quirk_docs("_mkdir")
     @require_connection
     def mkdir(self, path, recursive=True, exist_ok=False):
         """
@@ -482,7 +498,7 @@ def mkdir(self, path, recursive=True, exist_ok=False):
     def _mkdir(self, path, recursive, exist_ok):
         raise NotImplementedError
 
-    @quirk_docs('_remove')
+    @quirk_docs("_remove")
     @require_connection
     def remove(self, path, recursive=False):
         """
@@ -498,9 +514,11 @@ def remove(self, path, recursive=False):
         """
         self._assert_path_is_writable(path)
         if not self.exists(path):
-            raise IOError("No file(s) exist at path '{}'.".format(path))
+            raise IOError(f"No file(s) exist at path '{path}'.")
         if self.isdir(path) and not recursive:
-            raise IOError("Attempt to remove directory '{}' without passing `recursive=True`.".format(path))
+            raise IOError(
+                f"Attempt to remove directory '{path}' without passing `recursive=True`."
+            )
         return self._remove(self._path(path), recursive)
 
     @abstractmethod
@@ -509,9 +527,9 @@ def _remove(self, path, recursive):
 
     # File handling
 
-    @quirk_docs('_open')
+    @quirk_docs("_open")
     @require_connection
-    def open(self, path, mode='rt'):
+    def open(self, path, mode="rt"):
         """
         Open a file for reading and/or writing.
 
@@ -528,14 +546,14 @@ def open(self, path, mode='rt'):
         Returns:
             FileSystemFile or file-like: An opened file-like object.
         """
-        if 'w' in mode or 'a' in mode or '+' in mode:
+        if "w" in mode or "a" in mode or "+" in mode:
             self._assert_path_is_writable(path)
         return self._open(self._path(path), mode=mode)
 
     def _open(self, path, mode):
         return FileSystemFile(self, path, mode)
 
-    @quirk_docs('_file_read_')
+    @quirk_docs("_file_read_")
     @require_connection
     def _file_read(self, path, size=-1, offset=0, binary=False):
         """
@@ -552,12 +570,14 @@ def _file_read(self, path, size=-1, offset=0, binary=False):
         Returns:
             str or bytes: The contents of the file.
         """
-        return self._file_read_(self._path(path), size=size, offset=offset, binary=binary)
+        return self._file_read_(
+            self._path(path), size=size, offset=offset, binary=binary
+        )
 
     def _file_read_(self, path, size=-1, offset=0, binary=False):
         raise NotImplementedError
 
-    @quirk_docs('_file_write_')
+    @quirk_docs("_file_write_")
     @require_connection
     def _file_write(self, path, s, binary=False):
         """
@@ -579,7 +599,7 @@ def _file_write(self, path, s, binary=False):
     def _file_write_(self, path, s, binary):
         raise NotImplementedError
 
-    @quirk_docs('_file_append_')
+    @quirk_docs("_file_append_")
     @require_connection
     def _file_append(self, path, s, binary=False):
         """
@@ -603,7 +623,7 @@ def _file_append_(self, path, s, binary):
 
     # File transfer
 
-    @quirk_docs('_download')
+    @quirk_docs("_download")
     def download(self, source, dest=None, overwrite=False, fs=None):
         """
         Download files to another filesystem.
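[Editor's sketch, not part of this patch: the file-handling surface above in practice. Write modes pass through `_assert_path_is_writable`, so paths under the home directory are writable by default while anything else requires `global_writes=True`. The file name is hypothetical.]

    from omniduct.filesystems.local import LocalFsClient

    fs = LocalFsClient()
    with fs.open("~/omniduct-demo.txt", "w") as f:  # "~" expands via `path_home`
        f.write("hello")
    with fs.open("~/omniduct-demo.txt") as f:       # default mode is "rt"
        print(f.read())
    fs.remove("~/omniduct-demo.txt")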
@@ -635,13 +655,14 @@ def download(self, source, dest=None, overwrite=False, fs=None):
         """
         if fs is None:
             from .local import LocalFsClient
+
             fs = LocalFsClient()
 
         source = self._path(source)
         dest = fs._path(dest or self.path_basename(source))
 
         if dest.endswith(fs.path_separator):
-            assert fs.isdir(dest), "No such directory `{}`".format(dest)
+            assert fs.isdir(dest), f"No such directory `{dest}`"
             if not source.endswith(self.path_separator):
                 dest = fs.path_join(fs._path(dest), self.path_basename(source))
 
@@ -651,25 +672,41 @@ def download(self, source, dest=None, overwrite=False, fs=None):
 
         if self.isdir(source):
             target_prefix = (
-                source if source.endswith(self.path_separator) else source + self.path_separator
+                source
+                if source.endswith(self.path_separator)
+                else source + self.path_separator
             )
             targets.append((source, dest, True))
             for path, dirs, files in self.walk(source):
-                for dir in dirs:
-                    target_source = self.path_join(path, dir)
-                    targets.append((
-                        target_source,
-                        fs.path_join(dest, *target_source[len(target_prefix):].split(self.path_separator)),
-                        True
-                    ))
+                for dirname in dirs:
+                    target_source = self.path_join(path, dirname)
+                    targets.append(
+                        (
+                            target_source,
+                            fs.path_join(
+                                dest,
+                                *target_source[len(target_prefix) :].split(
+                                    self.path_separator
+                                ),
+                            ),
+                            True,
+                        )
+                    )
                 for file in files:
                     target_source = self.path_join(path, file)
-                    targets.append((
-                        target_source,
-                        fs.path_join(dest, *target_source[len(target_prefix):].split(self.path_separator)),
-                        False
-                    ))
+                    targets.append(
+                        (
+                            target_source,
+                            fs.path_join(
+                                dest,
+                                *target_source[len(target_prefix) :].split(
+                                    self.path_separator
+                                ),
+                            ),
+                            False,
+                        )
+                    )
         else:
             targets.append((source, dest, False))
 
@@ -682,8 +719,8 @@ def download(self, source, dest=None, overwrite=False, fs=None):
     def _download(self, source, dest, overwrite, fs):
         if not overwrite and fs.exists(dest):
             raise RuntimeError("File already exists on filesystem.")
-        with self.open(source, 'rb') as f_src:
-            with fs.open(dest, 'wb') as f_dest:
+        with self.open(source, "rb") as f_src:
+            with fs.open(dest, "wb") as f_dest:
                 f_dest.write(f_src.read())
 
     def upload(self, source, dest=None, overwrite=False, fs=None):
@@ -713,6 +750,7 @@ def upload(self, source, dest=None, overwrite=False, fs=None):
         """
         if fs is None:
             from .local import LocalFsClient
+
             fs = LocalFsClient()
         return fs.download(source, dest, overwrite, self)
 
@@ -721,43 +759,44 @@ def upload(self, source, dest=None, overwrite=False, fs=None):
     def _register_magics(self, base_name):
         from IPython.core.magic import register_line_magic, register_cell_magic
 
-        @register_line_magic("{}.listdir".format(base_name))
+        @register_line_magic(f"{base_name}.listdir")
         @process_line_arguments
-        def listdir(path=''):
+        def listdir(path=""):
             return self.listdir(path)
 
-        @register_line_magic("{}.showdir".format(base_name))
+        @register_line_magic(f"{base_name}.showdir")
         @process_line_arguments
-        def showdir(path=''):
+        def showdir(path=""):
             return self.showdir(path)
 
-        @register_line_magic("{}.read".format(base_name))
+        @register_line_magic(f"{base_name}.read")
         @process_line_arguments
         def read_file(path):
             with self.open(path) as f:
                 return f.read()
 
-        @register_cell_magic("{}.write".format(base_name))
+        @register_cell_magic(f"{base_name}.write")
         @process_line_arguments
         def write_file(cell, path):
-            with self.open(path, 'w') as f:
+            with self.open(path, "w") as f:
                 f.write(cell)
 
     # PyArrow compat
     @property
     def pyarrow_fs(self):
         from ._pyarrow_compat import OmniductFileSystem
+
         return OmniductFileSystem(self)
 
 
-class FileSystemFile(object):
+class FileSystemFile:
     """
     A file-like implementation that is interchangeable with native Python file
     objects, allowing remote files to be treated identically to local files
     both by omniduct, the user and other libraries.
     """
 
-    def __init__(self, fs, path, mode='r'):
+    def __init__(self, fs, path, mode="r"):
         self.fs = fs
         self.path = path
         self.mode = mode
@@ -770,8 +809,10 @@ def __init__(self, fs, path, mode='r'):
         else:
             self.__io_buffer = io.StringIO()
 
-        if 'w' not in self.mode:
-            self.__io_buffer.write(self.fs._file_read(self.path, binary=self.binary_mode))
+        if "w" not in self.mode:
+            self.__io_buffer.write(
+                self.fs._file_read(self.path, binary=self.binary_mode)
+            )
             if not self.appending:
                 self.__io_buffer.seek(0)
 
@@ -787,20 +828,22 @@ def mode(self):
     def mode(self, mode):
         try:
             assert len(set(mode)) == len(mode)
-            assert sum(opt in mode for opt in ['r', 'w', 'a', '+', 't', 'b']) == len(mode)
-            assert sum(opt in mode for opt in ['r', 'w', 'a']) == 1
-            assert sum(opt in mode for opt in ['t', 'b']) < 2
-        except AssertionError:
-            raise ValueError("invalid mode: '{}'".format(mode))
+            assert sum(opt in mode for opt in ["r", "w", "a", "+", "t", "b"]) == len(
+                mode
+            )
+            assert sum(opt in mode for opt in ["r", "w", "a"]) == 1
+            assert sum(opt in mode for opt in ["t", "b"]) < 2
+        except AssertionError as e:
+            raise ValueError(f"invalid mode: '{mode}'") from e
         self.__mode = mode
 
     @property
     def readable(self):
-        return 'r' in self.mode or '+' in self.mode
+        return "r" in self.mode or "+" in self.mode
 
     @property
     def writable(self):
-        return 'w' in self.mode or 'a' in self.mode or '+' in self.mode
+        return "w" in self.mode or "a" in self.mode or "+" in self.mode
 
     @property
     def seekable(self):
@@ -808,16 +851,16 @@ def seekable(self):
 
     @property
     def appending(self):
-        return 'a' in self.mode
+        return "a" in self.mode
 
     @property
     def binary_mode(self):
-        return 'b' in self.mode
+        return "b" in self.mode
 
     def __enter__(self):
         return self
 
-    def __exit__(self, type, value, tb):
+    def __exit__(self, type, value, tb):  # pylint: disable=redefined-builtin
         self.close()
 
     def close(self):
@@ -844,7 +887,7 @@ def isatty(self):
 
     @property
     def newlines(self):
-        return '\n'  # TODO: Support non-Unix newlines?
+        return "\n"  # TODO: Support non-Unix newlines?
 
     def read(self, size=-1):
         if not self.readable:
@@ -894,27 +937,32 @@ def detach(self):
 
     def readinto(self, buffer):
         data = self.read()
-        buffer[:len(data)] = data
+        buffer[: len(data)] = data
         return len(data)
 
     def readinto1(self, buffer):
         return self.readinto(buffer)
 
 
-class FileSystemFileDesc(namedtuple('Node', [
-    'fs',
-    'path',
-    'name',
-    'type',
-    'bytes',
-    'owner',
-    'group',
-    'permissions',
-    'created',
-    'last_modified',
-    'last_accessed',
-    'extra',
-])):
+class FileSystemFileDesc(
+    namedtuple(
+        "Node",
+        [
+            "fs",
+            "path",
+            "name",
+            "type",
+            "bytes",
+            "owner",
+            "group",
+            "permissions",
+            "created",
+            "last_modified",
+            "last_accessed",
+            "extra",
+        ],
+    )
+):
     """
     A representation of a file/directory stored within an Omniduct FileSystemClient.
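[Editor's sketch, not part of this patch: how these descriptor objects surface in practice. `dir()` yields one `FileSystemFileDesc` per child, and the convenience methods below delegate back to the owning filesystem via `desc.fs`.]

    from omniduct.filesystems.local import LocalFsClient

    fs = LocalFsClient()
    for desc in fs.dir("~"):
        if desc.type == "file":
            print(desc.name, desc.bytes)  # metadata fields of the namedtuple
            break                         # desc.open()/desc.download() also delegate to desc.fs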
@@ -922,64 +970,85 @@ class FileSystemFileDesc(namedtuple('Node', [
 
     __slots__ = ()
 
-    def __new__(cls, fs, path, name, type, bytes=None, owner=None,
-                group=None, permissions=None, created=None, last_modified=None,
-                last_accessed=None, **extra):
-        assert type in ('directory', 'file')
-        return (
-            super(FileSystemFileDesc, cls)
-            .__new__(cls,
-                     fs=fs,
-                     path=path,
-                     name=name,
-                     type=type,
-                     bytes=bytes,
-                     owner=owner,
-                     group=group,
-                     permissions=permissions,
-                     created=created,
-                     last_modified=last_modified,
-                     last_accessed=last_accessed,
-                     extra=extra)
+    def __new__(
+        cls,
+        fs,
+        path,
+        name,
+        type,  # pylint: disable=redefined-builtin
+        bytes=None,  # pylint: disable=redefined-builtin
+        owner=None,
+        group=None,
+        permissions=None,
+        created=None,
+        last_modified=None,
+        last_accessed=None,
+        **extra,
+    ):
+        assert type in ("directory", "file")
+        return super(FileSystemFileDesc, cls).__new__(
+            cls,
+            fs=fs,
+            path=path,
+            name=name,
+            type=type,
+            bytes=bytes,
+            owner=owner,
+            group=group,
+            permissions=permissions,
+            created=created,
+            last_modified=last_modified,
+            last_accessed=last_accessed,
+            extra=extra,
         )
 
     def as_dict(self):
-        d = OrderedDict([
-            ('fs', self.fs),
-            ('path', self.path),
-            ('type', self.type),
-            ('name', self.name),
-            ('bytes', self.bytes),
-            ('owner', self.owner),
-            ('group', self.group),
-            ('permissions', self.permissions),
-            ('created', self.created),
-            ('last_modified', self.last_modified),
-            ('last_accessed', self.last_accessed),
-        ])
+        d = OrderedDict(
+            [
+                ("fs", self.fs),
+                ("path", self.path),
+                ("type", self.type),
+                ("name", self.name),
+                ("bytes", self.bytes),
+                ("owner", self.owner),
+                ("group", self.group),
+                ("permissions", self.permissions),
+                ("created", self.created),
+                ("last_modified", self.last_modified),
+                ("last_accessed", self.last_accessed),
+            ]
+        )
         d.update(self.extra)
         return d
 
     # Convenience methods
 
-    def open(self, mode='rt'):
-        assert self.type == 'file', "`.open(...)` is only appropriate for files."
+    def open(self, mode="rt"):
+        assert self.type == "file", "`.open(...)` is only appropriate for files."
         return self.fs.open(self.path, mode=mode)
 
     def dir(self):
-        assert self.type == 'directory', "`.dir(...)` is only appropriate for directories."
+        assert (
+            self.type == "directory"
+        ), "`.dir(...)` is only appropriate for directories."
         return self.fs.dir(self.path)
 
     def listdir(self):
-        assert self.type == 'directory', "`.listdir(...)` is only appropriate for directories."
+        assert (
+            self.type == "directory"
+        ), "`.listdir(...)` is only appropriate for directories."
         return self.fs.listdir(self.path)
 
     def showdir(self):
-        assert self.type == 'directory', "`.showdir(...)` is only appropriate for directories."
+        assert (
+            self.type == "directory"
+        ), "`.showdir(...)` is only appropriate for directories."
         return self.fs.showdir(self.path)
 
     def find(self, **attrs):
-        assert self.type == 'directory', "`.find(...)` is only appropriate for directories."
+        assert (
+            self.type == "directory"
+        ), "`.find(...)` is only appropriate for directories."
         return self.fs.find(self.path, **attrs)
 
     def download(self, dest=None, overwrite=False, fs=None):
diff --git a/omniduct/filesystems/local.py b/omniduct/filesystems/local.py
index fa4cc00..7827a4b 100644
--- a/omniduct/filesystems/local.py
+++ b/omniduct/filesystems/local.py
@@ -1,9 +1,9 @@
+# pylint: disable=abstract-method  # We export a different type of file-handle.
+
 import datetime
 import errno
 import os
 import shutil
-import six
-import sys
 from io import open
 
 from interface_meta import override
@@ -24,7 +24,7 @@ class LocalFsClient(FileSystemClient):
     ```
     """
 
-    PROTOCOLS = ['localfs']
+    PROTOCOLS = ["localfs"]
 
     @override
     def _init(self):
@@ -32,8 +32,10 @@ def _init(self):
 
     @override
     def _prepare(self):
-        assert self.remote is None, "LocalFsClient cannot be used in conjunction with a remote client."
-        super(LocalFsClient, self)._prepare()
+        assert (
+            self.remote is None
+        ), "LocalFsClient cannot be used in conjunction with a remote client."
+        super()._prepare()
 
     @override
     def _connect(self):
@@ -50,7 +52,7 @@ def _disconnect(self):
     # File enumeration
     @override
     def _path_home(self):
-        return os.path.expanduser('~')
+        return os.path.expanduser("~")
 
     @override
     def _path_separator(self):
@@ -77,28 +79,34 @@ def _dir(self, path):
 
             attrs = {}
 
-            if os.name == 'posix':
+            if os.name == "posix":
                 import grp
                 import pwd
 
                 stat = os.stat(f_path)
-                attrs.update({
-                    'owner': pwd.getpwuid(stat.st_uid).pw_name,
-                    'group': grp.getgrgid(stat.st_gid).gr_name,
-                    'permissions': oct(stat.st_mode),
-                    'created': str(datetime.datetime.fromtimestamp(stat.st_ctime)),
-                    'last_modified': str(datetime.datetime.fromtimestamp(stat.st_mtime)),
-                    'last_accessed': str(datetime.datetime.fromtimestamp(stat.st_atime)),
-                })
+                attrs.update(
+                    {
+                        "owner": pwd.getpwuid(stat.st_uid).pw_name,
+                        "group": grp.getgrgid(stat.st_gid).gr_name,
+                        "permissions": oct(stat.st_mode),
+                        "created": str(datetime.datetime.fromtimestamp(stat.st_ctime)),
+                        "last_modified": str(
+                            datetime.datetime.fromtimestamp(stat.st_mtime)
+                        ),
+                        "last_accessed": str(
+                            datetime.datetime.fromtimestamp(stat.st_atime)
+                        ),
+                    }
+                )
 
             yield FileSystemFileDesc(
                 fs=self,
                 path=f_path,
                 name=f,
-                type='directory' if os.path.isdir(f_path) else 'file',
+                type="directory" if os.path.isdir(f_path) else "file",
                 bytes=os.path.getsize(f_path),
-                **attrs
+                **attrs,
             )
 
     @override
@@ -111,7 +119,7 @@ def _mkdir(self, path, recursive, exist_ok):
             os.makedirs(path) if recursive else os.mkdir(path)
         except OSError as exc:  # Python >2.5
             if exc.errno != errno.EEXIST or not exist_ok or not os.path.isdir(path):
-                six.reraise(*sys.exc_info())
+                raise
 
     @override
     def _remove(self, path, recursive):
@@ -123,4 +131,4 @@ def _remove(self, path, recursive):
     # File opening
     @override
     def _open(self, path, mode):
-        return open(path, mode=mode, encoding=None if 'b' in mode else 'utf-8')
+        return open(path, mode=mode, encoding=None if "b" in mode else "utf-8")
diff --git a/omniduct/filesystems/s3.py b/omniduct/filesystems/s3.py
index e904eff..1e85bd7 100644
--- a/omniduct/filesystems/s3.py
+++ b/omniduct/filesystems/s3.py
@@ -1,15 +1,10 @@
+# pylint: disable=attribute-defined-outside-init
 import logging
 
 from interface_meta import override
 
 from omniduct.filesystems.base import FileSystemClient, FileSystemFileDesc
 
-# Python 2 compatibility imports
-try:
-    FileNotFoundError
-except NameError:
-    FileNotFoundError = IOError
-
 
 class S3Client(FileSystemClient):
     """
@@ -24,12 +19,19 @@ class S3Client(FileSystemClient):
     library, which is also aware of environment variables.
     """
 
-    PROTOCOLS = ['s3']
+    PROTOCOLS = ["s3"]
     DEFAULT_PORT = 80
 
     @override
-    def _init(self, bucket=None, aws_profile=None, use_opinel=False,
-              session=None, path_separator='/', skip_hadoop_artifacts=True):
+    def _init(
+        self,
+        bucket=None,
+        aws_profile=None,
+        use_opinel=False,
+        session=None,
+        path_separator="/",
+        skip_hadoop_artifacts=True,
+    ):
         """
         bucket (str): The name of the Amazon S3 bucket to use.
         aws_profile (str): The name of configured AWS profile to use. This should
@@ -60,7 +62,9 @@ def _init(self, bucket=None, aws_profile=None, use_opinel=False,
             environments to subclass `S3Client` and override the
             `_get_boto3_session` method to suit your needs.
         """
-        assert bucket is not None, 'S3 Bucket must be specified using the `bucket` kwarg.'
+        assert (
+            bucket is not None
+        ), "S3 Bucket must be specified using the `bucket` kwarg."
         self.bucket = bucket
         self.aws_profile = aws_profile
         self.use_opinel = use_opinel
@@ -71,18 +75,17 @@ def _init(self, bucket=None, aws_profile=None, use_opinel=False,
 
         # Ensure self.host is updated with correct AWS region
         import boto3
-        self.host = 'autoscaling.{}.amazonaws.com'.format(
-            (session or boto3.Session(profile_name=self.aws_profile)).region_name or 'us-east-1'
-        )
+
+        self.host = f"autoscaling.{(session or boto3.Session(profile_name=self.aws_profile)).region_name or 'us-east-1'}.amazonaws.com"
 
         # Mask logging from botocore's vendored libraries
-        logging.getLogger('botocore.vendored').setLevel(100)
+        logging.getLogger("botocore.vendored").setLevel(100)
 
     @override
     def _connect(self):
         self._session = self._session or self._get_boto3_session()
-        self._client = self._session.client('s3')
-        self._resource = self._session.resource('s3')
+        self._client = self._session.client("s3")
+        self._resource = self._session.resource("s3")
 
     def _get_boto3_session(self):
         import boto3
@@ -94,9 +97,9 @@ def _get_boto3_session(self):
             self._credentials = read_creds(self.aws_profile)
 
             return boto3.Session(
-                aws_access_key_id=self._credentials['AccessKeyId'],
-                aws_secret_access_key=self._credentials['SecretAccessKey'],
-                aws_session_token=self._credentials['SessionToken'],
+                aws_access_key_id=self._credentials["AccessKeyId"],
+                aws_secret_access_key=self._credentials["SecretAccessKey"],
+                aws_session_token=self._credentials["SessionToken"],
                 profile_name=self.aws_profile,
             )
 
@@ -108,14 +111,17 @@ def _is_connected(self):
             return False
         # Check if still able to perform requests against AWS
         import botocore
+
         try:
             self._client.list_buckets()
+            return True
         except botocore.exceptions.ClientError as e:
             if len(e.args) > 0:
-                if 'ExpiredToken' in e.args[0] or 'InvalidToken' in e.args[0]:
+                if "ExpiredToken" in e.args[0] or "InvalidToken" in e.args[0]:
                     return False
-                elif 'AccessDenied' in e.args[0]:
+                if "AccessDenied" in e.args[0]:
                     return True
+            return False
 
     @override
     def _disconnect(self):
@@ -137,36 +143,36 @@ def _exists(self, path):
 
     def _s3_path(self, path):
         if path.startswith(self.path_separator):
-            path = path[len(self.path_separator):]
+            path = path[len(self.path_separator) :]
         if path.endswith(self.path_separator):
-            path = path[:-len(self.path_separator)]
+            path = path[: -len(self.path_separator)]
         return path
 
     @override
     def _isdir(self, path):
         response = next(iter(self.__dir_paginator(path)))
-        if 'CommonPrefixes' in response or 'Contents' in response:
+        if "CommonPrefixes" in response or "Contents" in response:
             return True
         return False
 
     @override
     def _isfile(self, path):
         try:
-            self._client.get_object(Bucket=self.bucket, Key=self._s3_path(path) or '')
+            self._client.get_object(Bucket=self.bucket, Key=self._s3_path(path) or "")
             return True
-        except:
+        except:  # pylint: disable=bare-except
             return False
 
     # Directory handling and enumeration
 
     def __dir_paginator(self, path):
         path = self._s3_path(path)
-        paginator = self._client.get_paginator('list_objects')
+        paginator = self._client.get_paginator("list_objects")
         iterator = paginator.paginate(
             Bucket=self.bucket,
-            Prefix=path + (self.path_separator if path else ''),
+            Prefix=path + (self.path_separator if path else ""),
             Delimiter=self.path_separator,
-            PaginationConfig={'PageSize': 500}
+            PaginationConfig={"PageSize": 500},
         )
         return iterator
@@ -175,24 +181,28 @@ def _dir(self, path):
         iterator = self.__dir_paginator(path)
 
         for response_data in iterator:
-            for prefix in response_data.get('CommonPrefixes', []):
+            for prefix in response_data.get("CommonPrefixes", []):
                 yield FileSystemFileDesc(
                     fs=self,
-                    path=prefix['Prefix'][:-len(self.path_separator)],
-                    name=prefix['Prefix'][:-len(self.path_separator)].split(self.path_separator)[-1],  # Remove trailing slash
-                    type='directory',
+                    path=prefix["Prefix"][: -len(self.path_separator)],
+                    name=prefix["Prefix"][: -len(self.path_separator)].split(
+                        self.path_separator
+                    )[
+                        -1
+                    ],  # Remove trailing slash
+                    type="directory",
                 )
-            for prefix in response_data.get('Contents', []):
-                if self.skip_hadoop_artifacts and prefix['Key'].endswith('_$folder$'):
+            for prefix in response_data.get("Contents", []):
+                if self.skip_hadoop_artifacts and prefix["Key"].endswith("_$folder$"):
                     continue
                 yield FileSystemFileDesc(
                     fs=self,
-                    path=prefix['Key'],
-                    name=prefix['Key'].split(self.path_separator)[-1],
-                    type='file',
-                    bytes=prefix['Size'],
-                    owner=prefix['Owner']['DisplayName'] if 'Owner' in prefix else None,
-                    last_modified=prefix['LastModified']
+                    path=prefix["Key"],
+                    name=prefix["Key"].split(self.path_separator)[-1],
+                    type="file",
+                    bytes=prefix["Size"],
+                    owner=prefix["Owner"]["DisplayName"] if "Owner" in prefix else None,
+                    last_modified=prefix["LastModified"],
                 )
 
     # TODO: Interestingly, directly using Amazon S3 methods seems slower than generic approach. Hypothesis: keys is not asynchronous.
@@ -225,24 +235,30 @@ def _remove(self, path, recursive):
             bucket = self._resource.Bucket(self.bucket)
             to_delete = []
             for obj in bucket.objects.filter(Prefix=path + self.path_separator):
-                to_delete.append({'Key': obj.key})
-                if len(to_delete) == 1000:  # Maximum number of simultaneous deletes is 1000
-                    self._client.delete_objects(Bucket=self.bucket, Delete={'Objects': to_delete})
+                to_delete.append({"Key": obj.key})
+                if (
+                    len(to_delete) == 1000
+                ):  # Maximum number of simultaneous deletes is 1000
+                    self._client.delete_objects(
+                        Bucket=self.bucket, Delete={"Objects": to_delete}
+                    )
                     to_delete = []
-            self._client.delete_objects(Bucket=self.bucket, Delete={'Objects': to_delete})
+            self._client.delete_objects(
+                Bucket=self.bucket, Delete={"Objects": to_delete}
+            )
         self._client.delete_object(Bucket=self.bucket, Key=path)
 
     # File handling
 
     @override
     def _file_read_(self, path, size=-1, offset=0, binary=False):
         if not self.isfile(path):
-            raise FileNotFoundError("File `{}` does not exist.".format(path))
+            raise FileNotFoundError(f"File `{path}` does not exist.")
         obj = self._resource.Object(self.bucket, self._s3_path(path))
-        body = obj.get()['Body'].read()
+        body = obj.get()["Body"].read()
         if not binary:
-            body = body.decode('utf-8')
+            body = body.decode("utf-8")
         if offset > 0:
             body = body[offset:]
         if size >= 0:
@@ -251,12 +267,14 @@ def _file_read_(self, path, size=-1, offset=0, binary=False):
 
     @override
     def _file_append_(self, path, s, binary):
-        raise NotImplementedError("Support for S3 append operation has yet to be implemented.")
+        raise NotImplementedError(
+            "Support for S3 append operation has yet to be implemented."
+ ) @override def _file_write_(self, path, s, binary): obj = self._resource.Object(self.bucket, self._s3_path(path)) if not binary: - s = s.encode('utf-8') + s = s.encode("utf-8") obj.put(Body=s) return True diff --git a/omniduct/filesystems/stub.py b/omniduct/filesystems/stub.py index 3443b37..2d2bbbc 100644 --- a/omniduct/filesystems/stub.py +++ b/omniduct/filesystems/stub.py @@ -2,7 +2,6 @@ class StubFsClient(FileSystemClient): - PROTOCOLS = [] DEFAULT_PORT = None diff --git a/omniduct/filesystems/webhdfs.py b/omniduct/filesystems/webhdfs.py index 2b5a500..fe2b4f3 100644 --- a/omniduct/filesystems/webhdfs.py +++ b/omniduct/filesystems/webhdfs.py @@ -7,12 +7,6 @@ from .base import FileSystemClient, FileSystemFileDesc from .local import LocalFsClient -# Python 2 compatibility imports -try: - FileNotFoundError -except NameError: - FileNotFoundError = IOError - class WebHdfsClient(FileSystemClient): """ @@ -23,12 +17,18 @@ class WebHdfsClient(FileSystemClient): the HDFS cluster in form ":". """ - PROTOCOLS = ['webhdfs'] + PROTOCOLS = ["webhdfs"] DEFAULT_PORT = 50070 @override - def _init(self, namenodes=None, auto_conf=False, auto_conf_cluster=None, - auto_conf_path=None, **kwargs): + def _init( + self, + namenodes=None, + auto_conf=False, + auto_conf_cluster=None, + auto_conf_path=None, + **kwargs, + ): """ namenodes (list): A list of hosts that are acting as namenodes for the HDFS cluster in form ":". @@ -48,31 +48,41 @@ def _init(self, namenodes=None, auto_conf=False, auto_conf_cluster=None, if auto_conf: from ._webhdfs_helpers import CdhHdfsConfParser - assert auto_conf_cluster is not None, "You must specify a cluster via `auto_conf_cluster` for auto-detection to work." + assert ( + auto_conf_cluster is not None + ), "You must specify a cluster via `auto_conf_cluster` for auto-detection to work." 
def get_host_and_set_namenodes(duct, cluster, conf_path): - conf_parser = CdhHdfsConfParser(duct.remote or LocalFsClient(), conf_path=conf_path) + conf_parser = CdhHdfsConfParser( + duct.remote or LocalFsClient(), conf_path=conf_path + ) duct.namenodes = conf_parser.namenodes(cluster) return random.choice(duct.namenodes) - self._host = partial(get_host_and_set_namenodes, cluster=auto_conf_cluster, conf_path=auto_conf_path) + self._host = partial( + get_host_and_set_namenodes, + cluster=auto_conf_cluster, + conf_path=auto_conf_path, + ) elif not self._host and namenodes: self._host = random.choice(self.namenodes) self.__webhdfs = None self.__webhdfs_kwargs = kwargs - self.prepared_fields += ('namenodes',) + self.prepared_fields += ("namenodes",) @override def _connect(self): from ._webhdfs_helpers import OmniductPyWebHdfsClient + + # pylint: disable-next=attribute-defined-outside-init self.__webhdfs = OmniductPyWebHdfsClient( host=self._host, port=self._port, remote=self.remote, namenodes=self.namenodes, user_name=self.username, - **self.__webhdfs_kwargs + **self.__webhdfs_kwargs, ) @override @@ -81,11 +91,12 @@ def _is_connected(self): if self.remote and not self.remote.is_connected(): return False return self.__webhdfs is not None - except: + except: # pylint: disable=bare-except return False @override def _disconnect(self): + # pylint: disable-next=attribute-defined-outside-init self.__webhdfs = None # Path properties and helpers @@ -95,12 +106,13 @@ def _path_home(self): @override def _path_separator(self): - return '/' + return "/" # File node properties @override def _exists(self, path): from pywebhdfs.errors import FileNotFound + try: self.__webhdfs.get_file_dir_status(path) return True @@ -110,18 +122,20 @@ def _exists(self, path): @override def _isdir(self, path): from pywebhdfs.errors import FileNotFound + try: stats = self.__webhdfs.get_file_dir_status(path) - return stats['FileStatus']['type'] == 'DIRECTORY' + return stats["FileStatus"]["type"] == "DIRECTORY" except FileNotFound: return False @override def _isfile(self, path): from pywebhdfs.errors import FileNotFound + try: stats = self.__webhdfs.get_file_dir_status(path) - return stats['FileStatus']['type'] == 'FILE' + return stats["FileStatus"]["type"] == "FILE" except FileNotFound: return False @@ -129,27 +143,27 @@ def _isfile(self, path): @override def _dir(self, path): files = self.__webhdfs.list_dir(path) - for f in files['FileStatuses']['FileStatus']: + for f in files["FileStatuses"]["FileStatus"]: yield FileSystemFileDesc( fs=self, - path=posixpath.join(path, f['pathSuffix']), - name=f['pathSuffix'], - type=f['type'].lower(), - bytes=f['length'], - owner=f['owner'], - group=f['group'], - last_modified=f['modificationTime'], - last_accessed=f['accessTime'], - permissions=f['permission'], - replication=f['replication'] + path=posixpath.join(path, f["pathSuffix"]), + name=f["pathSuffix"], + type=f["type"].lower(), + bytes=f["length"], + owner=f["owner"], + group=f["group"], + last_modified=f["modificationTime"], + last_accessed=f["accessTime"], + permissions=f["permission"], + replication=f["replication"], ) @override def _mkdir(self, path, recursive, exist_ok): if not recursive and not self._isdir(self.path_basename(path)): - raise IOError("No parent directory found for {}.".format(path)) + raise IOError(f"No parent directory found for {path}.") if not exist_ok and self._exists(path): - raise IOError("Path already exists at {}.".format(path)) + raise IOError(f"Path already exists at {path}.") 
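# Note: the underlying WebHDFS MKDIRS operation behaves like `mkdir -p` and
# creates missing parent directories automatically, so the non-recursive
# semantics are enforced client-side by the checks above.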
 self.__webhdfs.make_dir(path)

 @override
@@ -160,11 +174,13 @@ def _remove(self, path, recursive):
 @override
 def _file_read_(self, path, size=-1, offset=0, binary=False):
 if not self.isfile(path):
- raise FileNotFoundError("File `{}` does not exist.".format(path))
+ raise FileNotFoundError(f"File `{path}` does not exist.")

- read = self.__webhdfs.read_file(path, offset=offset, length='null' if size < 0 else size)
+ read = self.__webhdfs.read_file(
+ path, offset=offset, length="null" if size < 0 else size
+ )
 if not binary:
- read = read.decode('utf-8')
+ read = read.decode("utf-8")
 return read

 @override
diff --git a/omniduct/protocols.py b/omniduct/protocols.py
index 8014788..18b4d07 100644
--- a/omniduct/protocols.py
+++ b/omniduct/protocols.py
@@ -1,5 +1,3 @@
-# flake8: noqa
-
 # Omniduct's automatic registration of Duct protocols requires that the subclass
 # implementation be loaded into memory. Any protocol that should be enabled by
 # default should be imported here.
@@ -18,3 +16,20 @@
 from .remotes.ssh import SSHClient
 from .remotes.ssh_paramiko import ParamikoSSHClient
 from .restful.base import RestClient
+
+__all__ = [
+ "FileSystemCache",
+ "DruidClient",
+ "ExasolClient",
+ "HiveServer2Client",
+ "Neo4jClient",
+ "PrestoClient",
+ "PySparkClient",
+ "SQLAlchemyClient",
+ "LocalFsClient",
+ "S3Client",
+ "WebHdfsClient",
+ "SSHClient",
+ "ParamikoSSHClient",
+ "RestClient",
+]
diff --git a/omniduct/registry.py b/omniduct/registry.py
index 40e9c4e..93f4348 100644
--- a/omniduct/registry.py
+++ b/omniduct/registry.py
@@ -1,4 +1,3 @@
-import six
 import yaml

 from omniduct.duct import Duct
@@ -8,7 +7,7 @@
 from omniduct.utils.proxies import TreeProxy


-class DuctRegistry(object):
+class DuctRegistry:
 """
 A convenient registry for `Duct` instances.

@@ -30,7 +29,7 @@ def __init__(self, config=None):
 self.register_from_config(config)

 def __repr__(self):
- return "<DuctRegistry with {} ducts>".format(len(self._registry))
+ return f"<DuctRegistry with {len(self._registry)} ducts>"

 # Registration methods
 def register(self, duct, name=None, override=False, register_magics=True):
@@ -58,14 +57,18 @@
 """
 name = name or duct.name
 if name is None:
- raise ValueError("`Duct` instances must be named to be registered. Please either specify a name to this method call, or add a name to the Duct using `duct.name = '...'`.")
- names = [n.strip() for n in name.split(',')]
- for name in names:
- if name in self._registry and not override:
- raise ValueError("`Duct` with the same name ('{}') already present in the registry. Please pass `override=True` if you want to override the existing instance, or `name='...'` to specify a new name.".format(name))
+ raise ValueError(
+ "`Duct` instances must be named to be registered. Please either specify a name to this method call, or add a name to the Duct using `duct.name = '...'`."
+ )
+ aliases = [n.strip() for n in name.split(",")]
+ for alias in aliases:
+ if alias in self._registry and not override:
+ raise ValueError(
+ f"`Duct` with the same name ('{alias}') already present in the registry. Please pass `override=True` if you want to override the existing instance, or `name='...'` to specify a new name."
+ ) if register_magics and isinstance(duct, MagicsProvider): - duct.register_magics(base_name=name) - self._registry[name] = duct + duct.register_magics(base_name=alias) + self._registry[alias] = duct return duct def new(self, name, protocol, override=False, register_magics=True, **kwargs): @@ -90,13 +93,11 @@ def new(self, name, protocol, override=False, register_magics=True, **kwargs): """ return self.register( Duct.for_protocol(protocol)( - name=name.split(',')[0].strip(), - registry=self, - **kwargs + name=name.split(",")[0].strip(), registry=self, **kwargs ), name=name, override=override, - register_magics=register_magics + register_magics=register_magics, ) # Inspection and retrieval methods @@ -136,7 +137,9 @@ def lookup(self, name, kind=None): raise DuctNotFound(name) duct = self._registry[name] if kind and duct.DUCT_TYPE != kind: - raise DuctNotFound("Duct named '{}' exists, but is not of kind '{}'.".format(name, kind.value)) + raise DuctNotFound( + f"Duct named '{name}' exists, but is not of kind '{kind.value}'." + ) return duct # Exposing `Duct` instances. @@ -164,10 +167,15 @@ def populate_namespace(self, namespace=None, names=None, kinds=None): if namespace is None: namespace = {} if kinds is not None: - kinds = [Duct.Type(kind) if not isinstance(kind, Duct.Type) else kind for kind in kinds] + kinds = [ + Duct.Type(kind) if not isinstance(kind, Duct.Type) else kind + for kind in kinds + ] for name, duct in self._registry.items(): - if (kinds is None or duct.DUCT_TYPE in kinds) and (names is None or name in names): - namespace[name.split('/')[-1]] = duct + if (kinds is None or duct.DUCT_TYPE in kinds) and ( + names is None or name in names + ): + namespace[name.split("/")[-1]] = duct return namespace def get_proxy(self, by_kind=True): @@ -191,16 +199,17 @@ def get_proxy(self, by_kind=True): Returns: ServicesProxy: The proxy object. """ + def key_parser(k, v): - keys = k.split('/') - if by_kind and getattr(v, 'DUCT_TYPE', None) is not None: + keys = k.split("/") + if by_kind and getattr(v, "DUCT_TYPE", None) is not None: keys.insert(0, v.DUCT_TYPE.value) return keys dct = self._registry.copy() - dct['registry'] = self + dct["registry"] = self - return TreeProxy._for_dict(dct, key_parser=key_parser, name='services') + return TreeProxy._for_dict(dct, key_parser=key_parser, name="services") # Batch registration of duct configurations def register_from_config(self, config, override=False): @@ -236,22 +245,32 @@ def register_from_config(self, config, override=False): exception. """ # Extract configuration from a file if necessary, and then process it. - if isinstance(config, six.string_types): - if '\n' in config: + if isinstance(config, str): + if "\n" in config: config = yaml.safe_load(config) else: - with open(config) as f: + with open(config, encoding="utf-8") as f: config = yaml.safe_load(f.read()) config = self._process_config(config) for duct_config in config: - names = duct_config.pop('name') - protocol = duct_config.pop('protocol') - register_magics = duct_config.pop('register_magics', True) + names = duct_config.pop("name") + protocol = duct_config.pop("protocol") + register_magics = duct_config.pop("register_magics", True) try: - self.new(names, protocol, register_magics=register_magics, override=override, **duct_config) + self.new( + names, + protocol, + register_magics=register_magics, + override=override, + **duct_config, + ) except DuctProtocolUnknown as e: - logger.error("Failed to configure `Duct` instance(s) '{}'. 
{}".format("', '".join(names.split(',')), str(e))) + logger.error( + "Failed to configure `Duct` instance(s) '%s'. %s", + "', '".join(names.split(",")), + e, + ) return self @@ -261,18 +280,21 @@ def _process_config(self, config, name=None): arguments; each corresponding to a duct instance. """ - if isinstance(config, dict) and (name is not None or 'name' in config) and 'protocol' in config and not config.get('__OMNIDUCT_SKIP__', False): + if ( + isinstance(config, dict) + and (name is not None or "name" in config) + and "protocol" in config + and not config.get("__OMNIDUCT_SKIP__", False) + ): kwargs = config.copy() - if 'name' not in config: - kwargs['name'] = name + if "name" not in config: + kwargs["name"] = name yield kwargs elif isinstance(config, dict): - for name, subconfig in config.items(): - for config in self._process_config(subconfig, name=name): - yield config + for key, subconfig in config.items(): + yield from self._process_config(subconfig, name=key) elif isinstance(config, list): for subconfig in config: - for config in self._process_config(subconfig): - yield config + yield from self._process_config(subconfig) diff --git a/omniduct/remotes/base.py b/omniduct/remotes/base.py index 882e1cd..4a4d65e 100644 --- a/omniduct/remotes/base.py +++ b/omniduct/remotes/base.py @@ -2,9 +2,7 @@ import re from abc import abstractmethod -import six from interface_meta import quirk_docs, override -from future.utils import raise_with_traceback from omniduct.duct import Duct from omniduct.errors import DuctAuthenticationError, DuctServerUnreachable @@ -18,7 +16,7 @@ from urlparse import urlparse, urlunparse -class PortForwardingRegister(object): +class PortForwardingRegister: """ A register of all port forwards initiated by a particular Duct. """ @@ -38,7 +36,7 @@ def lookup(self, remote_host, remote_port): tuple, None: A tuple of local port and implementation-specific connection artifact, if it exists, and `None` otherwise. """ - return self._register.get('{}:{}'.format(remote_host, remote_port)) + return self._register.get(f"{remote_host}:{remote_port}") def lookup_port(self, remote_host, remote_port): """ @@ -55,6 +53,7 @@ def lookup_port(self, remote_host, remote_port): entry = self.lookup(remote_host, remote_port) if entry is not None: return entry[0] + return None def reverse_lookup(self, local_port): """ @@ -69,7 +68,7 @@ def reverse_lookup(self, local_port): """ for key, (port, connection) in self._register.items(): if port == local_port: - return key.split(':') + [connection] + return key.split(":") + [connection] return None def register(self, remote_host, remote_port, local_port, connection): @@ -82,9 +81,11 @@ def register(self, remote_host, remote_port, local_port, connection): local_port (int): The local port. connection (object): Implementation-specific connection artifact. """ - key = '{}:{}'.format(remote_host, remote_port) + key = f"{remote_host}:{remote_port}" if key in self._register: - raise RuntimeError("Remote host/port combination ({}) is already registered.".format(key)) + raise RuntimeError( + f"Remote host/port combination ({key}) is already registered." + ) self._register[key] = (local_port, connection) def deregister(self, remote_host, remote_port): @@ -99,7 +100,7 @@ def deregister(self, remote_host, remote_port): tuple: A tuple of local port and implementation-specific connection artifact, if it exists, and `None` otherwise. 
""" - return self._register.pop('{}:{}'.format(remote_host, remote_port)) + return self._register.pop(f"{remote_host}:{remote_port}") class RemoteClient(FileSystemClient): @@ -110,6 +111,7 @@ class RemoteClient(FileSystemClient): smartcard (dict): Mapping of smartcard names to system libraries compatible with `ssh-add -s '' ...`. """ + __doc_attrs = """ smartcard (dict): Mapping of smartcard names to system libraries compatible with `ssh-add -s '' ...`. @@ -118,8 +120,10 @@ class RemoteClient(FileSystemClient): DUCT_TYPE = Duct.Type.REMOTE DEFAULT_PORT = None - @quirk_docs('_init', mro=True) - def __init__(self, smartcards=None, **kwargs): + @quirk_docs("_init", mro=True) + def __init__( + self, smartcards=None, **kwargs + ): # pylint: disable=super-init-not-called """ Args: smartcards (dict): Mapping of smartcard names to system libraries @@ -157,13 +161,10 @@ def connect(self): """ try: Duct.connect(self) - except DuctServerUnreachable as e: - raise_with_traceback(e) - except DuctAuthenticationError as e: + except DuctAuthenticationError: if self.smartcards and self.prepare_smartcards(): Duct.connect(self) - else: - raise_with_traceback(e) + raise return self def prepare_smartcards(self): @@ -189,26 +190,34 @@ def prepare_smartcards(self): def _prepare_smartcard(self, name, filename): import pexpect - remover = pexpect.spawn('ssh-add -e "{}"'.format(filename)) + remover = pexpect.spawn(f'ssh-add -e "{filename}"') i = remover.expect(["Card removed:", "Could not remove card", pexpect.TIMEOUT]) if i == 2: - raise RuntimeError("Unable to reset card using ssh-agent. Output of ssh-agent was: \n{}\n\n" - "Please report this error!".format(remover.before)) + raise RuntimeError( + f"Unable to reset card using ssh-agent. Output of ssh-agent was: \n{remover.before}\n\nPlease report this error!" + ) - adder = pexpect.spawn('ssh-add -s "{}" -t 14400'.format(filename)) - i = adder.expect(['Enter passphrase for PKCS#11:', pexpect.TIMEOUT]) + adder = pexpect.spawn(f'ssh-add -s "{filename}" -t 14400') + i = adder.expect(["Enter passphrase for PKCS#11:", pexpect.TIMEOUT]) if i == 0: - adder.sendline(getpass.getpass('Please enter your passcode to unlock your "{}" smartcard: '.format(name))) + adder.sendline( + getpass.getpass( + f'Please enter your passcode to unlock your "{name}" smartcard: ' + ) + ) else: - raise RuntimeError("Unable to add card using ssh-agent. Output of ssh-agent was: \n{}\n\n" - "Please report this error!".format(remover.before)) - i = adder.expect(['Card added:', pexpect.TIMEOUT]) + raise RuntimeError( + f"Unable to add card using ssh-agent. Output of ssh-agent was: \n{remover.before}\n\nPlease report this error!" + ) + i = adder.expect(["Card added:", pexpect.TIMEOUT]) if i != 0: - raise RuntimeError("Unexpected error while adding card. Check your passcode and try again.") + raise RuntimeError( + "Unexpected error while adding card. Check your passcode and try again." + ) return True - @quirk_docs('_execute') + @quirk_docs("_execute") @require_connection def execute(self, cmd, **kwargs): """ @@ -232,20 +241,30 @@ def _execute(self, cmd, **kwargs): # Port forwarding code def _extract_host_and_ports(self, remote_host, remote_port, local_port): - assert remote_host is None or isinstance(remote_host, six.string_types), "Remote host, if specified, must be a string of form 'hostname(:port)'." - assert remote_port is None or isinstance(remote_port, int), "Remote port, if specified, must be an integer." 
- assert local_port is None or isinstance(local_port, int), "Local port, if specified, must be an integer."
+ assert remote_host is None or isinstance(
+ remote_host, str
+ ), "Remote host, if specified, must be a string of form 'hostname(:port)'."
+ assert remote_port is None or isinstance(
+ remote_port, int
+ ), "Remote port, if specified, must be an integer."
+ assert local_port is None or isinstance(
+ local_port, int
+ ), "Local port, if specified, must be an integer."

 host = port = None
 if remote_host is not None:
- m = re.match(r'(?P<host>[a-zA-Z0-9\-.]+)(?::(?P<port>[0-9]+))?', remote_host)
- assert m, "Host not valid: {}. Must be a string of form 'hostname(:port)'.".format(remote_host)
-
- host = m.group('host')
- port = m.group('port') or remote_port
+ m = re.match(
+ r"(?P<host>[a-zA-Z0-9\-.]+)(?::(?P<port>[0-9]+))?", remote_host
+ )
+ assert (
+ m
+ ), f"Host not valid: {remote_host}. Must be a string of form 'hostname(:port)'."
+
+ host = m.group("host")
+ port = m.group("port") or remote_port

 return host, port, local_port

- @quirk_docs('_port_forward_start')
+ @quirk_docs("_port_forward_start")
 @require_connection
 def port_forward(self, remote_host, remote_port=None, local_port=None):
 """
@@ -268,12 +287,16 @@
 """

 # Hostname and port extraction
- remote_host, remote_port, local_port = self._extract_host_and_ports(remote_host, remote_port, local_port)
+ remote_host, remote_port, local_port = self._extract_host_and_ports(
+ remote_host, remote_port, local_port
+ )

 assert remote_host is not None, "Remote host must be specified."
 assert remote_port is not None, "Remote port must be specified."

 # Actual port forwarding
- registered_port = self.__port_forwarding_register.lookup_port(remote_host, remote_port)
+ registered_port = self.__port_forwarding_register.lookup_port(
+ remote_host, remote_port
+ )
 if registered_port is not None:
 if local_port is not None and registered_port != local_port:
 self.port_forward_stop(registered_port)
@@ -283,12 +306,18 @@
 if local_port is None:
 local_port = get_free_local_port()
 else:
- assert is_local_port_free(local_port), "Specified local port is in use, and cannot be used."
+ assert is_local_port_free(
+ local_port
+ ), "Specified local port is in use, and cannot be used."

 if not self.is_port_bound(remote_host, remote_port):
- raise DuctServerUnreachable("Server specified for port forwarding ({}:{}) is unreachable via '{}' ({}).".format(remote_host, remote_port, self.name, self.__class__.__name__))
+ raise DuctServerUnreachable(
+ f"Server specified for port forwarding ({remote_host}:{remote_port}) is unreachable via '{self.name}' ({self.__class__.__name__})."
+ )

 connection = self._port_forward_start(local_port, remote_host, remote_port)
- self.__port_forwarding_register.register(remote_host, remote_port, local_port, connection)
+ self.__port_forwarding_register.register(
+ remote_host, remote_port, local_port, connection
+ )

 return local_port

@@ -308,16 +337,24 @@ def has_port_forward(self, remote_host=None, remote_port=None, local_port=None):
 port forwarding.
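 Example (endpoint and ports are illustrative):

     client.has_port_forward("db.internal", 5432)  # by remote endpoint
     client.has_port_forward(local_port=52431)  # by local port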
""" # Hostname and port extraction - remote_host, remote_port, local_port = self._extract_host_and_ports(remote_host, remote_port, local_port) + remote_host, remote_port, local_port = self._extract_host_and_ports( + remote_host, remote_port, local_port + ) - assert remote_host is not None and remote_port is not None or local_port is not None, "Either remote host and port must be specified, or the local port must be specified." + assert ( + remote_host is not None + and remote_port is not None + or local_port is not None + ), "Either remote host and port must be specified, or the local port must be specified." if remote_host is not None and remote_port is not None: - return self.__port_forwarding_register.lookup(remote_host, remote_port) is not None - else: - return self.__port_forwarding_register.reverse_lookup(local_port) is not None + return ( + self.__port_forwarding_register.lookup(remote_host, remote_port) + is not None + ) + return self.__port_forwarding_register.reverse_lookup(local_port) is not None - @quirk_docs('_port_forward_stop') + @quirk_docs("_port_forward_stop") def port_forward_stop(self, local_port=None, remote_host=None, remote_port=None): """ Disconnect an existing port forward connection. @@ -333,14 +370,26 @@ def port_forward_stop(self, local_port=None, remote_host=None, remote_port=None) local_port (int, None): The port used locally. """ # Hostname and port extraction - remote_host, remote_port, local_port = self._extract_host_and_ports(remote_host, remote_port, local_port) + remote_host, remote_port, local_port = self._extract_host_and_ports( + remote_host, remote_port, local_port + ) - assert remote_host is not None and remote_port is not None or local_port is not None, "Either remote host and port must be specified, or the local port must be specified." + assert ( + remote_host is not None + and remote_port is not None + or local_port is not None + ), "Either remote host and port must be specified, or the local port must be specified." if remote_host is not None and remote_port is not None: - local_port, connection = self.__port_forwarding_register.lookup(remote_host, remote_port) + local_port, connection = self.__port_forwarding_register.lookup( + remote_host, remote_port + ) else: - remote_host, remote_port, connection = self.__port_forwarding_register.reverse_lookup(local_port) + ( + remote_host, + remote_port, + connection, + ) = self.__port_forwarding_register.reverse_lookup(local_port) self._port_forward_stop(local_port, remote_host, remote_port, connection) self.__port_forwarding_register.deregister(remote_host, remote_port) @@ -367,7 +416,11 @@ def get_local_uri(self, uri): str: A local uri that tunnels all traffic to the remote host. 
""" parsed_uri = urlparse(uri) - return urlunparse(parsed_uri._replace(netloc='localhost:{}'.format(self.port_forward(parsed_uri.netloc)))) + return urlunparse( + parsed_uri._replace( + netloc=f"localhost:{self.port_forward(parsed_uri.netloc)}" + ) + ) def show_port_forwards(self): """ @@ -375,8 +428,16 @@ def show_port_forwards(self): """ if len(self.__port_forwarding_register._register) == 0: print("No port forwards currently in use.") - for remote_host, (local_port, _) in self.__port_forwarding_register._register.items(): - print("localhost:{}".format(local_port), "->", remote_host, "(on {})".format(self._host)) + for remote_host, ( + local_port, + _, + ) in self.__port_forwarding_register._register.items(): + print( + f"localhost:{local_port}", + "->", + remote_host, + f"(on {self._host})", + ) @abstractmethod def _port_forward_start(self, local_port, remote_host, remote_port): @@ -386,7 +447,7 @@ def _port_forward_start(self, local_port, remote_host, remote_port): def _port_forward_stop(self, local_port, remote_host, remote_port, connection): raise NotImplementedError - @quirk_docs('_is_port_bound') + @quirk_docs("_is_port_bound") @require_connection def is_port_bound(self, host, port): """ diff --git a/omniduct/remotes/ssh.py b/omniduct/remotes/ssh.py index ed36068..65232ce 100644 --- a/omniduct/remotes/ssh.py +++ b/omniduct/remotes/ssh.py @@ -6,6 +6,7 @@ import tempfile from builtins import input from io import open +from shlex import quote as escape_path import pandas as pd from interface_meta import override @@ -17,13 +18,8 @@ from omniduct.utils.decorators import require_connection from omniduct.utils.processes import run_in_subprocess -try: # Python 3 - from shlex import quote as escape_path -except ImportError: # Python 2.7 - from pipes import quote as escape_path - -SSH_ASKPASS = '{omniduct_dir}/utils/ssh_askpass'.format(omniduct_dir=os.path.dirname(__file__)) +SSH_ASKPASS = f"{os.path.dirname(__file__)}/utils/ssh_askpass" SESSION_SSH_USERNAME = None SESSION_REMOTE_HOST = None SESSION_SSH_ASKPASS = False @@ -46,7 +42,7 @@ class SSHClient(RemoteClient): (default: `False`) """ - PROTOCOLS = ['ssh', 'ssh_cli'] + PROTOCOLS = ["ssh", "ssh_cli"] DEFAULT_PORT = 22 @override @@ -77,28 +73,32 @@ def _connect(self): if not os.path.exists(socket_dir): os.makedirs(socket_dir) # Create persistent master connection and exit. 
- cmd = ''.join([
- "ssh {login} -MT ",
- "-S {socket} ",
- "-o ControlPersist=yes ",
- "-o StrictHostKeyChecking=no ",
- "-o UserKnownHostsFile=/dev/null " if not self.check_known_hosts else "",
- "-o NoHostAuthenticationForLocalhost=yes ",
- "-o ServerAliveInterval=60 ",
- "-o ServerAliveCountMax=2 ",
- "'exit'",
- ]).format(login=self._login_info, socket=self._socket_path)
+ cmd = "".join(
+ [
+ "ssh {login} -MT ",
+ "-S {socket} ",
+ "-o ControlPersist=yes ",
+ "-o StrictHostKeyChecking=no ",
+ "-o UserKnownHostsFile=/dev/null "
+ if not self.check_known_hosts
+ else "",
+ "-o NoHostAuthenticationForLocalhost=yes ",
+ "-o ServerAliveInterval=60 ",
+ "-o ServerAliveCountMax=2 ",
+ "'exit'",
+ ]
+ ).format(login=self._login_info, socket=self._socket_path)

 expected = [
- "WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!", # 0
- "(?i)are you sure you want to continue connecting", # 1
- "(?i)(?:(?:password)|(?:passphrase for key)):", # 2
- "(?i)permission denied", # 3
- "(?i)terminal type", # 4
- pexpect.TIMEOUT, # 5
- "(?i)connection closed by remote host", # 6
- "(?i)could not resolve hostname", # 7
- pexpect.EOF # 8
+ "WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!", # 0
+ "(?i)are you sure you want to continue connecting", # 1
+ "(?i)(?:(?:password)|(?:passphrase for key)):", # 2
+ "(?i)permission denied", # 3
+ "(?i)terminal type", # 4
+ pexpect.TIMEOUT, # 5
+ "(?i)connection closed by remote host", # 6
+ "(?i)could not resolve hostname", # 7
+ pexpect.EOF, # 8
 ]

 try:
@@ -106,65 +106,78 @@
 i = expect.expect(expected, timeout=10)

 # First phase
- if i == 0: # If host identification changed, arrest any further attempts to connect
- error_message = (
- 'Host identification for {} has changed! This is most likely '
- 'due to the the server being redeployed or reconfigured but '
- 'may also be due to a man-in-the-middle attack. If you trust '
- 'your network connection, you should be safe to update the '
- 'host keys for this host. To do this manually, please remove '
- 'the line corresponding to this host in ~/.ssh/known_hosts; '
- 'or call the `update_host_keys` method of this client.'.format(self._host)
- )
+ if (
+ i == 0
+ ): # If host identification changed, arrest any further attempts to connect
+ error_message = f"Host identification for {self._host} has changed! This is most likely due to the server being redeployed or reconfigured but may also be due to a man-in-the-middle attack. If you trust your network connection, you should be safe to update the host keys for this host. To do this manually, please remove the line corresponding to this host in ~/.ssh/known_hosts; or call the `update_host_keys` method of this client."
 if self.interactive:
 logger.error(error_message)
- auto_fix = input('Would you like this client to do this for you? (y/n)')
- if auto_fix == 'y':
+ auto_fix = input(
+ "Would you like this client to do this for you? (y/n)"
+ )
+ if auto_fix == "y":
 self.update_host_keys()
- return self.connect()
- else:
- raise RuntimeError("Host keys not updated. Please update keys manually.")
- else:
- raise RuntimeError(error_message)
- if i == 1: # Request to authorize host certificate (i.e. host not in the 'known_hosts' file)
+ self.connect()
+ return
+ raise RuntimeError(
+ "Host keys not updated. Please update keys manually."
+ )
+ raise RuntimeError(error_message)
+ if (
+ i == 1
+ ): # Request to authorize host certificate (i.e. host not in the 'known_hosts' file)
 expect.sendline("yes")
 i = expect.expect(expected)
 if i == 2: # Request for password/passphrase
- expect.sendline(self.password or getpass.getpass('Password: '))
+ expect.sendline(self.password or getpass.getpass("Password: "))
 i = expect.expect(expected)
 if i == 4: # Request for terminal type
- expect.sendline('ascii')
+ expect.sendline("ascii")
 i = expect.expect(expected)

 # Second phase
- if i == 1: # Another request to authorize host certificate (i.e. host not in the 'known_hosts' file)
- raise RuntimeError('Received a second request to authorize host key. This should not have happened!')
- elif i in (2, 3): # Second request for password/passphrase or rejection of credentials. For now, give up.
- raise DuctAuthenticationError('Invalid username and/or password, or private key is not unlocked.')
- elif i == 4: # Another request for terminal type.
- raise RuntimeError('Received a second request for terminal type. This should not have happened!')
- elif i == 5: # Timeout
+ if (
+ i == 1
+ ): # Another request to authorize host certificate (i.e. host not in the 'known_hosts' file)
+ raise RuntimeError(
+ "Received a second request to authorize host key. This should not have happened!"
+ )
+ if i in (
+ 2,
+ 3,
+ ): # Second request for password/passphrase or rejection of credentials. For now, give up.
+ raise DuctAuthenticationError(
+ "Invalid username and/or password, or private key is not unlocked."
+ )
+ if i == 4: # Another request for terminal type.
+ raise RuntimeError(
+ "Received a second request for terminal type. This should not have happened!"
+ )
+ if i == 5: # Timeout
 # In our instance, this means that we have not handled some or another aspect of the login procedure.
 # Since we are expecting an EOF when we have successfully logged in, hanging means that the SSH login
 # procedure is waiting for more information. Since we have no more to give, this means our login
 # was unsuccessful.
- raise RuntimeError('SSH client seems to be awaiting more information, but we have no more to give. The '
- 'messages received so far are:\n{}'.format(expect.before))
- elif i == 6: # Connection closed by remote host
+ raise RuntimeError(
+ f"SSH client seems to be awaiting more information, but we have no more to give. The messages received so far are:\n{expect.before}"
+ )
+ if i == 6: # Connection closed by remote host
 raise RuntimeError("Remote closed SSH connection")
- elif i == 7:
- raise RuntimeError("Cannot connect to {} on your current network connection".format(self.host))
+ if i == 7:
+ raise RuntimeError(
+ f"Cannot connect to {self.host} on your current network connection"
+ )
 finally:
 expect.close()

 # We should be logged in at this point, but let us make doubly sure
- assert self._is_connected(), 'Unexpected failure to establish a connection with the remote host with command: \n ' \
- '{}\n\n Please report this!'.format(cmd)
+ assert (
+ self._is_connected()
+ ), f"Unexpected failure to establish a connection with the remote host with command: \n {cmd}\n\n Please report this!"

 @override
 def _is_connected(self):
- cmd = "ssh {login} -T -S {socket} -O check".format(login=self._login_info,
- socket=self._socket_path)
+ cmd = f"ssh {self._login_info} -T -S {self._socket_path} -O check"
 proc = run_in_subprocess(cmd)

 if proc.returncode != 0:
@@ -176,8 +189,7 @@
 @override
 def _disconnect(self):
 # Send exit request to control socket.
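# (`-O check`, `-O exit`, `-O forward` and `-O cancel` are the standard
# OpenSSH control operations issued against the master socket; they are
# used throughout this client.)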
- cmd = "ssh {login} -T -S {socket} -O exit".format(login=self._login_info, - socket=self._socket_path) + cmd = f"ssh {self._login_info} -T -S {self._socket_path} -O exit" run_in_subprocess(cmd) # RemoteClient implementation @@ -190,112 +202,164 @@ def _execute(self, cmd, skip_cwd=False, **kwargs): command. This is mainly useful to methods internal to this class. """ - template = 'ssh {login} -T -o ControlPath={socket} << EOF\n{cwd}{cmd}\nEOF' + template = "ssh {login} -T -o ControlPath={socket} << EOF\n{cwd}{cmd}\nEOF" config = dict(self._subprocess_config) config.update(kwargs) - cwd = 'cd "{path}"\n'.format(path=escape_path(self.path_cwd)) if not skip_cwd else '' - return run_in_subprocess(template.format(login=self._login_info, - socket=self._socket_path, - cwd=cwd, - cmd=cmd), - check_output=True, - **config) + cwd = f'cd "{escape_path(self.path_cwd)}"\n' if not skip_cwd else "" + return run_in_subprocess( + template.format( + login=self._login_info, socket=self._socket_path, cwd=cwd, cmd=cmd + ), + check_output=True, + **config, + ) @override @require_connection def _port_forward_start(self, local_port, remote_host, remote_port): - logger.info('Establishing port forward...') - cmd_template = 'ssh {login} -T -O forward -S {socket} -L localhost:{local_port}:{remote_host}:{remote_port}' - cmd = cmd_template.format(login=self._login_info, - socket=self._socket_path, - local_port=local_port, - remote_host=remote_host, - remote_port=remote_port) + logger.info("Establishing port forward...") + cmd_template = "ssh {login} -T -O forward -S {socket} -L localhost:{local_port}:{remote_host}:{remote_port}" + cmd = cmd_template.format( + login=self._login_info, + socket=self._socket_path, + local_port=local_port, + remote_host=remote_host, + remote_port=remote_port, + ) proc = run_in_subprocess(cmd) if proc.returncode != 0: - raise Exception('Unable to port forward with command: {}'.format(cmd)) - logger.info(proc.stderr or 'Success') + raise RuntimeError(f"Unable to port forward with command: {cmd}") + logger.info(proc.stderr or "Success") return proc @override def _port_forward_stop(self, local_port, remote_host, remote_port, connection): - logger.info('Cancelling port forward...') - cmd_template = 'ssh {login} -T -O cancel -S {socket} -L localhost:{local_port}:{remote_host}:{remote_port}' - cmd = cmd_template.format(login=self._login_info, - socket=self._socket_path, - local_port=local_port, - remote_host=remote_host, - remote_port=remote_port) + logger.info("Cancelling port forward...") + cmd_template = "ssh {login} -T -O cancel -S {socket} -L localhost:{local_port}:{remote_host}:{remote_port}" + cmd = cmd_template.format( + login=self._login_info, + socket=self._socket_path, + local_port=local_port, + remote_host=remote_host, + remote_port=remote_port, + ) proc = run_in_subprocess(cmd) - logger.info('Port forward succesfully stopped.' if proc.returncode == 0 else 'Failed to stop port forwarding.') + logger.info( + "Port forward succesfully stopped." + if proc.returncode == 0 + else "Failed to stop port forwarding." + ) @override def _is_port_bound(self, host, port): - return self.execute('which nc; if [ $? -eq 0 ]; then nc -z -w2 {} {}; fi'.format(host, port)).returncode == 0 + return ( + self.execute( + f"which nc; if [ $? 
-eq 0 ]; then nc -z -w2 {host} {port}; fi" + ).returncode + == 0 + ) # FileSystem methods # Path properties and helpers @override def _path_home(self): - return self.execute('echo ~', skip_cwd=True).stdout.decode('utf-8').strip() + return self.execute("echo ~", skip_cwd=True).stdout.decode("utf-8").strip() @override def _path_separator(self): - return '/' + return "/" # File node properties @override def _exists(self, path): - return self.execute('if [ ! -e {} ]; then exit 1; fi'.format(path)).returncode == 0 + return self.execute(f"if [ ! -e {path} ]; then exit 1; fi").returncode == 0 @override def _isdir(self, path): - return self.execute('if [ ! -d {} ]; then exit 1; fi'.format(path)).returncode == 0 + return self.execute(f"if [ ! -d {path} ]; then exit 1; fi").returncode == 0 @override def _isfile(self, path): - return self.execute('if [ ! -f {} ]; then exit 1; fi'.format(path)).returncode == 0 + return self.execute(f"if [ ! -f {path} ]; then exit 1; fi").returncode == 0 # Directory handling and enumeration @override def _dir(self, path): # TODO: Currently we strip link annotations below with ...[:9]. Should we capture them? - dir = pd.DataFrame(sorted([re.split(r'\s+', f)[:9] for f in self.execute('ls -Al {}'.format(path)).stdout.decode('utf-8').strip().split('\n')[1:]]), - columns=['file_mode', 'link_count', 'owner', 'group', 'bytes', 'month', 'day', 'time', 'path']) + contents = pd.DataFrame( + sorted( + [ + re.split(r"\s+", f)[:9] + for f in self.execute(f"ls -Al {path}") + .stdout.decode("utf-8") + .strip() + .split("\n")[1:] + ] + ), + columns=[ + "file_mode", + "link_count", + "owner", + "group", + "bytes", + "month", + "day", + "time", + "path", + ], + ) def convert_to_datetime(x): - months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] - year = datetime.datetime.now().year if ':' in x.time else x.time - time = x.time if ':' in x.time else None + months = [ + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + ] + year = datetime.datetime.now().year if ":" in x.time else x.time + time = x.time if ":" in x.time else None return datetime.datetime( year=int(year), month=months.index(x.month) + 1, day=int(x.day), - hour=int(time.split(':')[0]) if time is not None else 0, - minute=int(time.split(':')[1]) if time is not None else 0 + hour=int(time.split(":")[0]) if time is not None else 0, + minute=int(time.split(":")[1]) if time is not None else 0, ) - if len(dir) == 0: # Directory is empty + if len(contents) == 0: # Directory is empty return - dir = dir.assign( - last_modified=lambda x: x.apply(convert_to_datetime, axis=1), - type=lambda x: x.apply(lambda x: 'directory' if x.file_mode.startswith('d') else 'file', axis=1) - ).drop( - ['month', 'day', 'time'], - axis=1 - ).sort_values( - ['type', 'path'] - ).reset_index(drop=True) - - for i, row in dir.iterrows(): + contents = ( + contents.assign( + last_modified=lambda x: x.apply(convert_to_datetime, axis=1), + type=lambda x: x.apply( + lambda x: "directory" if x.file_mode.startswith("d") else "file", + axis=1, + ), + ) + .drop(["month", "day", "time"], axis=1) + .sort_values(["type", "path"]) + .reset_index(drop=True) + ) + + for _, row in contents.iterrows(): yield FileSystemFileDesc( fs=self, path=posixpath.join(path, row.path), name=row.path, - type='directory' if row.file_mode.startswith('d') else 'file', # TODO: What about links, which are of form: lrwxrwxrwx? 
+ type="directory" + if row.file_mode.startswith("d") + else "file", # TODO: What about links, which are of form: lrwxrwxrwx? bytes=row.bytes, owner=row.owner, group=row.group, @@ -306,18 +370,28 @@ def convert_to_datetime(x): def _mkdir(self, path, recursive, exist_ok): if exist_ok and self.isdir(path): return - assert self.execute('mkdir ' + ('-p ' if recursive else '') + '"{}"'.format(path)).returncode == 0, "Failed to create directory at: `{}`".format(path) + assert ( + self.execute( + "mkdir " + ("-p " if recursive else "") + f'"{path}"' + ).returncode + == 0 + ), f"Failed to create directory at: `{path}`" @override def _remove(self, path, recursive): - assert self.execute('rm -f ' + ('-r ' if recursive else '') + '"{}"'.format(path)).returncode == 0, "Failed to remove file(s) at: `{}`".format(path) + assert ( + self.execute( + "rm -f " + ("-r " if recursive else "") + f'"{path}"' + ).returncode + == 0 + ), f"Failed to remove file(s) at: `{path}`" # File handling @override def _file_read_(self, path, size=-1, offset=0, binary=False): - read = self.execute('cat {}'.format(path)).stdout + read = self.execute(f"cat {path}").stdout if not binary: - read = read.decode('utf-8') + read = read.decode("utf-8") return read @override @@ -332,7 +406,11 @@ def _file_write_(self, path, s, binary): fd, tmp_path = tempfile.mkstemp(text=True) os.close(fd) - with open(tmp_path, 'w' + ('b' if binary else ''), encoding=None if binary else 'utf-8') as f: + with open( + tmp_path, + "w" + ("b" if binary else ""), + encoding=None if binary else "utf-8", + ) as f: f.write(s) return self.upload(tmp_path, path, overwrite=True) @@ -374,20 +452,19 @@ def download(self, source, dest=None, overwrite=False, fs=None): from ..filesystems.local import LocalFsClient if fs is None or isinstance(fs, LocalFsClient): - logger.info('Copying file to local...') + logger.info("Copying file to local...") dest = dest or posixpath.basename(source) - cmd = ( - "scp -r -o ControlPath={socket} {login}:'{remote_file}' '{local_file}'".format( - socket=self._socket_path, - login=self._login_info, - remote_file=dest.replace('"', r'\"'), - local_file=source.replace('"', r'\"'), # quote escaped for bash - ) + # pylint: disable-next=consider-using-f-string + cmd = "scp -r -o ControlPath={socket} {login}:'{remote_file}' '{local_file}'".format( + socket=self._socket_path, + login=self._login_info, + remote_file=dest.replace('"', r"\""), + local_file=source.replace('"', r"\""), # quote escaped for bash ) proc = run_in_subprocess(cmd, check_output=True) - logger.info(proc.stderr or 'Success') + logger.info(proc.stderr or "Success") else: - return super(RemoteClient, self).download(source, dest, overwrite, fs) + super(RemoteClient, self).download(source, dest, overwrite, fs) @override @require_connection @@ -424,33 +501,32 @@ def upload(self, source, dest=None, overwrite=False, fs=None): from ..filesystems.local import LocalFsClient if fs is None or isinstance(fs, LocalFsClient): - logger.info('Copying file from local...') + logger.info("Copying file from local...") dest = dest or posixpath.basename(source) - cmd = ( - "scp -r -o ControlPath={socket} '{local_file}' {login}:'{remote_file}'".format( - socket=self._socket_path, - local_file=source.replace('"', r'\"'), # quote escaped for bash - login=self._login_info, - remote_file=dest.replace('"', r'\"'), - ) + # pylint: disable-next=consider-using-f-string + cmd = "scp -r -o ControlPath={socket} '{local_file}' {login}:'{remote_file}'".format( + socket=self._socket_path, + 
local_file=source.replace('"', r"\""), # quote escaped for bash + login=self._login_info, + remote_file=dest.replace('"', r"\""), ) proc = run_in_subprocess(cmd, check_output=True) - logger.info(proc.stderr or 'Success') + logger.info(proc.stderr or "Success") else: - return super(RemoteClient, self).upload(source, dest, overwrite, fs) + super(RemoteClient, self).upload(source, dest, overwrite, fs) # Helper methods @property def _login_info(self): - return '@'.join([self.username, self.host]) + return "@".join([self.username, self.host]) @property def _socket_path(self): # On Linux the maximum socket path length is 108 characters, and on Mac OS X it is 104 characters, including # the final sentinel character (or so it seems). SSH appends a '.' character, followed by random sequence of 16 # characters. We therefore need the rest of the path to be less than 86 characters. - return os.path.expanduser('~/.ssh/omniduct/{}'.format(self._login_info))[:86] + return os.path.expanduser(f"~/.ssh/omniduct/{self._login_info}")[:86] @property def _subprocess_config(self): @@ -465,10 +541,13 @@ def update_host_keys(self): for example, redeployed and have different host keys. """ assert not self.remote, "Updating host key only works for local connections." - cmd = "ssh-keygen -R {host} && ssh-keyscan {host} >> ~/.ssh/known_hosts".format(host=self.host) + cmd = "ssh-keygen -R {host} && ssh-keyscan {host} >> ~/.ssh/known_hosts".format( + host=self.host + ) proc = run_in_subprocess(cmd, True) if proc.returncode != 0: raise RuntimeError( "Could not update host keys! Please handle this manually. The " - "error was:\n" + '\n'.join([proc.stdout.decode('utf-8'), proc.stderr.decode('utf-8')]) + "error was:\n" + + "\n".join([proc.stdout.decode("utf-8"), proc.stderr.decode("utf-8")]) ) diff --git a/omniduct/remotes/ssh_paramiko.py b/omniduct/remotes/ssh_paramiko.py index b92a633..55bd4fd 100644 --- a/omniduct/remotes/ssh_paramiko.py +++ b/omniduct/remotes/ssh_paramiko.py @@ -1,5 +1,8 @@ +# pylint: disable=abstract-method + import posixpath import select +import socketserver import stat import threading @@ -11,18 +14,7 @@ from omniduct.utils.debug import logger from omniduct.utils.processes import SubprocessResults -# Python 2 compatibility imports -try: - import SocketServer -except ImportError: - import socketserver as SocketServer - -try: - FileNotFoundError -except NameError: - FileNotFoundError = IOError - -__all__ = ['ParamikoSSHClient'] +__all__ = ["ParamikoSSHClient"] class ParamikoSSHClient(RemoteClient): @@ -33,62 +25,71 @@ class ParamikoSSHClient(RemoteClient): client. """ - PROTOCOLS = ['ssh_paramiko'] + PROTOCOLS = ["ssh_paramiko"] DEFAULT_PORT = 22 @override def _init(self): - logger.warning("The Paramiko SSH client is still under development, \ - and is not ready for use as a daily driver.") + logger.warning( + "The Paramiko SSH client is still under development, \ + and is not ready for use as a daily driver." 
+ ) @override def _connect(self): - import paramiko # Imported here due to relatively slow import + import paramiko + + # pylint: disable-next=attribute-defined-outside-init self.__client = paramiko.SSHClient() self.__client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy()) self.__client.load_system_host_keys() try: self.__client.connect(self.host, username=self.username) - self.__client_sftp = paramiko.SFTPClient.from_transport(self.__client.get_transport()) + # pylint: disable-next=attribute-defined-outside-init + self.__client_sftp = paramiko.SFTPClient.from_transport( + self.__client.get_transport() + ) except paramiko.SSHException as e: - if len(e.args) == 1 and e.args[0] == 'No authentication methods available': - raise DuctAuthenticationError(e.args[0]) + if len(e.args) == 1 and e.args[0] == "No authentication methods available": + raise DuctAuthenticationError(e.args[0]) from e raise e @override def _is_connected(self): try: return self.__client.get_transport().is_active() - except: + except: # pylint: disable=bare-except return False @override def _disconnect(self): try: self.__client_sftp.close() - return self.__client.close() - except: + self.__client.close() + except: # pylint: disable=bare-except pass @override def _execute(self, cmd, **kwargs): - stdin, stdout, stderr = self.__client.exec_command(cmd) + _, stdout, stderr = self.__client.exec_command(cmd) returncode = stdout.channel.recv_exit_status() return SubprocessResults( - returncode=returncode, - stdout=stdout.read(), - stderr=stderr.read() + returncode=returncode, stdout=stdout.read(), stderr=stderr.read() ) @override def _port_forward_start(self, local_port, remote_host, remote_port): - logger.debug('Now forwarding port {} to {}:{} ...'.format(local_port, remote_host, remote_port)) + logger.debug( + f"Now forwarding port {local_port} to {remote_host}:{remote_port} ..." + ) try: - server = forward_tunnel(local_port, remote_host, remote_port, self.__client.get_transport()) + server = forward_tunnel( + local_port, remote_host, remote_port, self.__client.get_transport() + ) except KeyboardInterrupt: - print('C-c: Port forwarding stopped.') + print("C-c: Port forwarding stopped.") return server @override @@ -102,11 +103,11 @@ def _is_port_bound(self, host, port): # Path properties and helpers @override def _path_home(self): - return self.execute('echo ~', skip_cwd=True).stdout.decode('utf-8').strip() + return self.execute("echo ~", skip_cwd=True).stdout.decode("utf-8").strip() @override def _path_separator(self): - return '/' + return "/" # File node properties @override @@ -139,7 +140,9 @@ def _dir(self, path): fs=self, path=posixpath.join(path, attrs.filename), name=attrs.filename, - type='directory' if stat.S_ISDIR(attrs.st_mode) else 'file', # TODO: What about links, which are of form: lrwxrwxrwx? + type="directory" + if stat.S_ISDIR(attrs.st_mode) + else "file", # TODO: What about links, which are of form: lrwxrwxrwx? 
bytes=attrs.st_size, owner=attrs.st_uid, group=attrs.st_gid, @@ -150,11 +153,21 @@ def _dir(self, path): def _mkdir(self, path, recursive, exist_ok): if exist_ok and self.isdir(path): return - assert self.execute('mkdir ' + ('-p ' if recursive else '') + '"{}"'.format(path)).returncode == 0, "Failed to create directory at: `{}`".format(path) + assert ( + self.execute( + "mkdir " + ("-p " if recursive else "") + f'"{path}"' + ).returncode + == 0 + ), f"Failed to create directory at: `{path}`" @override def _remove(self, path, recursive): - assert self.execute('rm -f ' + ('-r ' if recursive else '') + '"{}"'.format(path)).returncode == 0, "Failed to remove file(s) at: `{}`".format(path) + assert ( + self.execute( + "rm -f " + ("-r " if recursive else "") + f'"{path}"' + ).returncode + == 0 + ), f"Failed to remove file(s) at: `{path}`" # File handling @override @@ -170,32 +183,41 @@ def _open(self, path, mode): # Port Forwarding Utility Code # Largely based on code from: https://github.com/paramiko/paramiko/blob/master/demos/forward.py -class ForwardServer (SocketServer.ThreadingTCPServer): + +class ForwardServer(socketserver.ThreadingTCPServer): daemon_threads = True allow_reuse_address = True -class Handler (SocketServer.BaseRequestHandler): - +class Handler(socketserver.BaseRequestHandler): def handle(self): try: - chan = self.ssh_transport.open_channel('direct-tcpip', - (self.chain_host, self.chain_port), - self.request.getpeername()) - except Exception as e: - logger.info('Incoming request to %s:%d failed: %s' % (self.chain_host, - self.chain_port, - repr(e))) + chan = self.ssh_transport.open_channel( + "direct-tcpip", + (self.chain_host, self.chain_port), + self.request.getpeername(), + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.info( + "Incoming request to %s:%d failed: %s", + self.chain_host, + self.chain_port, + repr(e), + ) return if chan is None: - logger.info('Incoming request to %s:%d was rejected by the SSH server.' % - (self.chain_host, self.chain_port)) + logger.info( + "Incoming request to %s:%d was rejected by the SSH server.", + self.chain_host, + self.chain_port, + ) return - logger.info('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(), - chan.getpeername(), (self.chain_host, self.chain_port))) + logger.info( + f"Connected! Tunnel open {self.request.getpeername()!r} -> {chan.getpeername()!r} -> {self.chain_host, self.chain_port!r}" + ) while True: - r, w, x = select.select([self.request, chan], [], []) + r, _, _ = select.select([self.request, chan], [], []) if self.request in r: data = self.request.recv(1024) if len(data) == 0: @@ -210,21 +232,22 @@ def handle(self): peername = self.request.getpeername() chan.close() self.request.close() - logger.info('Tunnel closed from %r' % (peername,)) + logger.info(f"Tunnel closed from {peername!r}") def forward_tunnel(local_port, remote_host, remote_port, transport): # this is a little convoluted, but lets me configure things for the Handler - # object. (SocketServer doesn't give Handlers any way to access the outer + # object. (socketserver doesn't give Handlers any way to access the outer # server normally.) 
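# A usage sketch (ports and host are illustrative): the returned server can
# later be shut down to stop forwarding.
#
#     server = forward_tunnel(8080, "internal-host", 80, transport)
#     ...  # traffic to localhost:8080 now tunnels over the SSH transport
#     server.shutdown()  # stop the serve_forever loop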
class SubHandler(Handler): chain_host = remote_host chain_port = remote_port ssh_transport = transport - server = ForwardServer(('', local_port), SubHandler) + + server = ForwardServer(("", local_port), SubHandler) t = threading.Thread(target=server.serve_forever) - t.setDaemon(True) # don't hang on exit + t.daemon = True # don't hang on exit t.start() return server diff --git a/omniduct/remotes/stub.py b/omniduct/remotes/stub.py index eb43fdb..52f8779 100644 --- a/omniduct/remotes/stub.py +++ b/omniduct/remotes/stub.py @@ -2,7 +2,6 @@ class StubFsClient(RemoteClient): - PROTOCOLS = [] DEFAULT_PORT = None diff --git a/omniduct/restful/base.py b/omniduct/restful/base.py index 66201e3..69a5697 100644 --- a/omniduct/restful/base.py +++ b/omniduct/restful/base.py @@ -1,6 +1,6 @@ import json +from urllib.parse import urljoin -from future.moves.urllib.parse import urljoin from interface_meta import quirk_docs, override from omniduct.duct import Duct @@ -29,8 +29,15 @@ class RestClientBase(Duct): DUCT_TYPE = Duct.Type.RESTFUL - @quirk_docs('_init', mro=True) - def __init__(self, server_protocol='http', assume_json=False, endpoint_prefix='', **kwargs): + @quirk_docs("_init", mro=True) + def __init__( # pylint: disable=super-init-not-called + self, + server_protocol="http", + assume_json=False, + endpoint_prefix="", + default_timeout=None, + **kwargs, + ): """ Args: server_protocol (str): The protocol to use when connecting to the @@ -39,6 +46,9 @@ def __init__(self, server_protocol='http', assume_json=False, endpoint_prefix='' instances of this class (default: `False`). endpoint_prefix (str): The base_url path relative to the host at which the API is accessible (default: `''`). + default_timeout (optional float): The number of seconds to wait for + a response. Will be used except where overridden by specific + requests. **kwargs (dict): Additional keyword arguments passed on to subclasses. """ @@ -47,13 +57,14 @@ def __init__(self, server_protocol='http', assume_json=False, endpoint_prefix='' self.server_protocol = server_protocol self.assume_json = assume_json self.endpoint_prefix = endpoint_prefix + self.default_timeout = default_timeout self._init(**kwargs) def _init(self): pass - def __call__(self, endpoint, method='get', **kwargs): + def __call__(self, endpoint, method="get", **kwargs): if self.assume_json: return self.request_json(endpoint, method=method, **kwargs) return self.request(endpoint, method=method, **kwargs) @@ -61,13 +72,16 @@ def __call__(self, endpoint, method='get', **kwargs): @property def base_url(self): """str: The base url of the REST API.""" - url = urljoin('{}://{}:{}'.format(self.server_protocol, self.host, self.port or 80), self.endpoint_prefix) - if not url.endswith('/'): - url += '/' + url = urljoin( + f"{self.server_protocol}://{self.host}:{self.port or 80}", + self.endpoint_prefix, + ) + if not url.endswith("/"): + url += "/" return url @require_connection - def request(self, endpoint, method='get', **kwargs): + def request(self, endpoint, method="get", **kwargs): """ Request data from a nominated endpoint. @@ -81,10 +95,13 @@ def request(self, endpoint, method='get', **kwargs): requests.Response: The response object associated with this request. 
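 Example (endpoint and timeout are illustrative; an explicit `timeout`
 overrides `default_timeout`):

     response = client.request("status", method="get", timeout=5)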
""" import requests + url = urljoin(self.base_url, endpoint) - return requests.request(method, url, **kwargs) + return requests.request( + method, url, **{"timeout": self.default_timeout, **kwargs} + ) - def request_json(self, endpoint, method='get', **kwargs): + def request_json(self, endpoint, method="get", **kwargs): """ Request JSON data from a nominated endpoint. @@ -100,9 +117,13 @@ def request_json(self, endpoint, method='get', **kwargs): request = self.request(endpoint, method=method, **kwargs) if not request.status_code == 200: try: - raise RuntimeError("Server responded with HTTP response code {}, with content: {}.".format(request.status_code, json.dumps(request.json()))) - except: - raise RuntimeError("Server responded with HTTP response code {}, with content: {}.".format(request.status_code, request.content.decode('utf-8'))) + raise RuntimeError( + f"Server responded with HTTP response code {request.status_code}, with content: {json.dumps(request.json())}." + ) + except Exception as e: # pylint: disable=broad-exception-caught + raise RuntimeError( + f"Server responded with HTTP response code {request.status_code}, with content: {request.content.decode('utf-8')}." + ) from e return request.json() @override @@ -122,4 +143,5 @@ class RestClient(RestClientBase): """ A trivial implementation of `RestClientBase` for basic REST access. """ - PROTOCOLS = ['rest'] + + PROTOCOLS = ["rest"] diff --git a/omniduct/utils/about.py b/omniduct/utils/about.py index 8b0b5c9..c1349f1 100644 --- a/omniduct/utils/about.py +++ b/omniduct/utils/about.py @@ -1,8 +1,8 @@ import base64 import textwrap +import urllib.parse import jinja2 -import six.moves.urllib as urllib import omniduct @@ -56,8 +56,16 @@ """.strip() -def show_about(name, version=None, logo=None, maintainers=None, attributes=None, - description=None, endorsements=None, endorse_omniduct=True): +def show_about( + name, + version=None, + logo=None, + maintainers=None, + attributes=None, + description=None, + endorsements=None, + endorse_omniduct=True, +): """ Output information about a project in HTML for notebooks and text otherwise. 
@@ -79,36 +87,41 @@ def show_about(name, version=None, logo=None, maintainers=None, attributes=None, endorsements = endorsements or [] if endorse_omniduct: - endorsements.append({ - 'name': 'Omniduct', - 'version': omniduct.__version__, - 'logo': omniduct.__logo__ - }) + endorsements.append( + { + "name": "Omniduct", + "version": omniduct.__version__, + "logo": omniduct.__logo__, + } + ) for endorsement in endorsements: - endorsement['logo'] = get_image_url(endorsement.get('logo')) - endorsements = sorted(endorsements, key=lambda x: x['name']) + endorsement["logo"] = get_image_url(endorsement.get("logo")) + endorsements = sorted(endorsements, key=lambda x: x["name"]) context = { - 'name': name, - 'version': version, - 'logo': get_image_url(logo), - 'maintainers': maintainers or {}, - 'attributes': attributes or {}, - 'description': textwrap.dedent(description).strip() if description else None, - 'endorsements': endorsements + "name": name, + "version": version, + "logo": get_image_url(logo), + "maintainers": maintainers or {}, + "attributes": attributes or {}, + "description": textwrap.dedent(description).strip() if description else None, + "endorsements": endorsements, } try: from IPython import get_ipython from IPython.display import display, HTML + ip = get_ipython() - if ip is not None and ip.has_trait('kernel'): + if ip is not None and ip.has_trait("kernel"): return display(HTML(jinja2.Template(ABOUT_TEMPLATE_HTML).render(**context))) - except: + except: # pylint: disable=bare-except pass # Textual fallback if HTML not running in a notebook - print(textwrap.dedent(jinja2.Template(ABOUT_TEMPLATE_TEXT).render(**context))) + return print( + textwrap.dedent(jinja2.Template(ABOUT_TEMPLATE_TEXT).render(**context)) + ) def get_image_url(uri): @@ -125,9 +138,9 @@ def get_image_url(uri): str: The uri of the image suitable for rendering in a notebook. """ if not uri: - return + return None parsed = urllib.parse.urlparse(uri) - if parsed.scheme in ('', 'file'): - with open(parsed.path, 'rb') as image: - return "data:image/png;base64,{}".format(base64.b64encode(image.read()).decode()) + if parsed.scheme in ("", "file"): + with open(parsed.path, "rb") as image: + return f"data:image/png;base64,{base64.b64encode(image.read()).decode()}" return uri diff --git a/omniduct/utils/config.py b/omniduct/utils/config.py index 8b570e4..70fcb6a 100644 --- a/omniduct/utils/config.py +++ b/omniduct/utils/config.py @@ -5,8 +5,6 @@ import os import sys -import six - def ensure_path_exists(path): path = os.path.expanduser(path) @@ -18,13 +16,20 @@ def ensure_path_exists(path): logger = logging.getLogger(__name__) -class ConfigurationRegistry(object): - +class ConfigurationRegistry: def __init__(self): self._register = {} - def register(self, key, description=None, default=None, onchange=None, onload=None, - type=None, host=None): + def register( + self, + key, + description=None, + default=None, + onchange=None, + onload=None, + type=None, # pylint: disable=redefined-builtin + host=None, + ): """ Register a configuration key that can be set by the user. As noted in the class level documentation, these keys should not lead to changes in the @@ -43,26 +48,32 @@ class level documentation, these keys should not lead to changes in the * If not specified, these fields default to None. 
""" - if key in dir(self): - raise KeyError("Key `{0}` cannot be registered as it conflicts with a method of OmniductConfiguration.".format(key)) + if key in dir(self.__class__): + raise KeyError( + f"Key `{key}` cannot be registered as it conflicts with a method of OmniductConfiguration." + ) if key in self._register: - logger.warn("Overwriting existing key `{0}`, previously registered by {1}".format(key, self._register[key]['host'])) + logger.debug( + "Overwriting existing omniduct registry key `%s`, previously registered by %s", + key, + self._register[key]["host"], + ) try: - caller_frame = inspect.current_frame().f_back + caller_frame = inspect.currentframe().f_back host = inspect.getmodule(caller_frame).__name__ - except: - host = 'unknown' + except: # pylint: disable=bare-except + host = "unknown" if default is not None and type is not None: assert isinstance(default, type) self._register[key] = { - 'description': description, - 'host': host, - 'default': default, - 'onchange': onchange, - 'onload': onload, - 'type': type, + "description": description, + "host": host, + "default": default, + "onchange": onchange, + "onload": onload, + "type": type, } def show(self): @@ -72,12 +83,12 @@ def show(self): registered. """ for key in sorted(self._register.keys()): - desc = self._register[key].get('description') + desc = self._register[key].get("description") if desc is None: - desc = 'No description' - print('{0} with default = {1}'.format(key, self._register[key]['default'])) - print('\t{0}'.format(desc)) - print('\t({0})'.format(self._register[key]['host'])) + desc = "No description" + print(f"{key} with default = {self._register[key]['default']}") + print(f"\t{desc}") + print(f"\t({self._register[key]['host']})") class Configuration(ConfigurationRegistry): @@ -109,7 +120,7 @@ def __init__(self, *registries, **kwargs): self.register(key, **props) self._config = {} - self.__config_path = kwargs.pop('config_path', None) + self.__config_path = kwargs.pop("config_path", None) def __dir__(self): return sorted(self._register.keys()) @@ -126,9 +137,10 @@ def _config_path(self, path): # Restore configuration try: self.load(force=True) - except: + except Exception as e: # pylint: disable=broad-exception-caught raise RuntimeError( - "Configuration file at {0} cannot be loaded. Perhaps try deleting it.".format(self.__config_path)) + f"Configuration file at {self.__config_path} cannot be loaded. Perhaps try deleting it." + ) from e def all(self): """ @@ -146,13 +158,13 @@ def show(self): registered. """ for key in sorted(self._register.keys()): - desc = self._register[key].get('description') + desc = self._register[key].get("description") if desc is None: - desc = 'No description' - val = str(self._config.get(key, '')) - print('{0} = {1} (default = {2})'.format(key, val, self._register[key]['default'])) - print('\t{0}'.format(desc)) - print('\t({0})'.format(self._register[key]['host'])) + desc = "No description" + val = str(self._config.get(key, "")) + print(f"{key} = {val} (default = {self._register[key]['default']})") + print(f"\t{desc}") + print(f"\t({self._register[key]['host']})") def __setattr__(self, key, value): """ @@ -161,18 +173,19 @@ def __setattr__(self, key, value): Attributes prefixed with '_' are loaded from this class. 
""" - if key.startswith('_'): + if key.startswith("_"): object.__setattr__(self, key, value) elif key in self._register: - if self._register[key]['type'] is not None: - if not isinstance(value, self._register[key]['type']): + if self._register[key]["type"] is not None: + if not isinstance(value, self._register[key]["type"]): raise ValueError( - "{} must be in type(s) {}".format(key, self._register[key]['type'])) - if self._register[key]['onchange'] is not None: - self._register[key]['onchange'](value) + f"{key} must be in type(s) {self._register[key]['type']}" + ) + if self._register[key]["onchange"] is not None: + self._register[key]["onchange"](value) self._config[key] = value else: - raise KeyError("No such configuration key `{0}`.".format(key)) + raise KeyError(f"No such configuration key `{key}`.") def __getattr__(self, key): """ @@ -181,18 +194,21 @@ def __getattr__(self, key): Attributes prefixed with '_' are loaded from this class. """ - if key.startswith('_'): + if key.startswith("_"): return object.__getattr__(self, key) if key in self._register: if key in self._config: return self._config[key] # if a lazy loader is specified, use it - if self._register[key]['default'] is None and self._register[key]['onload'] is not None: - setattr(self, key, self._register[key]['onload']()) + if ( + self._register[key]["default"] is None + and self._register[key]["onload"] is not None + ): + setattr(self, key, self._register[key]["onload"]()) - return self._config.get(key, self._register[key]['default']) - raise AttributeError("No such configuration key `{0}`.".format(key)) + return self._config.get(key, self._register[key]["default"]) + raise AttributeError(f"No such configuration key `{key}`.") def reset(self, *keys, **target_config): """ @@ -216,18 +232,24 @@ def reset(self, *keys, **target_config): for key, value in target_config.items(): self._config[key] = value if key in self._register: - if value == self._register[key]['default']: + if value == self._register[key]["default"]: self._config.pop(key) - if self._register[key]['onchange'] is not None: - self._register[key]['onchange'](getattr(self, key)) + if self._register[key]["onchange"] is not None: + self._register[key]["onchange"](getattr(self, key)) else: # Allow users to delete deprecated keys - logger.warning("Added value for configuration key `{0}` which has yet to be registered.".format(key)) + logger.warning( + "Added value for configuration key `%s` which has yet to be registered.", + key, + ) for key in reset_keys: if key in self._config: self._config.pop(key) - if key in self._register and self._register[key]['onchange'] is not None: - self._register[key]['onchange'](getattr(self, key)) + if ( + key in self._register + and self._register[key]["onchange"] is not None + ): + self._register[key]["onchange"](getattr(self, key)) def __restrict_keys(self, d, keys): if keys is None: @@ -248,19 +270,21 @@ def save(self, filename=None, keys=None, replace=None): or `False` if specific keys are specified. 
            (default=None)
         """
         filename = filename or self._config_path
-        filename = os.path.join(ensure_path_exists(os.path.dirname(filename)), os.path.basename(filename))
+        filename = os.path.join(
+            ensure_path_exists(os.path.dirname(filename)), os.path.basename(filename)
+        )
         config = {}
         if replace is None:
-            replace = True if keys is None else False
+            replace = keys is None
         if keys is None:
             replace = True
         if not replace and os.path.exists(filename):
-            with io.open(filename, 'r') as f:
+            with io.open(filename, "r", encoding="utf-8") as f:
                 config = json.load(f)
         config.update(self.__restrict_keys(self._config, keys))
-        with io.open(filename, 'w') as f:
+        with io.open(filename, "w", encoding="utf-8") as f:
             json_config = json.dumps(config, ensure_ascii=False, indent=4)
-            if sys.version_info.major == 2 and isinstance(json_config, six.string_types):
-                json_config = json_config.decode("utf-8")
             f.write(json_config)

@@ -283,10 +307,10 @@ def load(self, filename=None, keys=None, replace=None, force=False):
         """
         filename = filename or self._config_path
         if replace is None:
-            replace = True if keys is None else False
+            replace = keys is None
         if keys is None:
             replace = True
-        with io.open(filename, 'r') as f:
+        with io.open(filename, "r", encoding="utf-8") as f:
             config = self.__restrict_keys(json.load(f), keys)
         if force:
             self._config = config
diff --git a/omniduct/utils/debug.py b/omniduct/utils/debug.py
index f538dfa..322edda 100644
--- a/omniduct/utils/debug.py
+++ b/omniduct/utils/debug.py
@@ -6,19 +6,19 @@
 import time

 import progressbar
-import six
 from decorator import decorate
-from future.utils import raise_with_traceback

 from .config import config

-config.register('logging_level',
-                description='Set the default logging level.',
-                default=logging.INFO,
-                onchange=lambda level: logger.setLevel(level, context='omniduct'))
+config.register(
+    "logging_level",
+    description="Set the default logging level.",
+    default=logging.INFO,
+    onchange=lambda level: logger.setLevel(level, context="omniduct"),
+)


-class StatusLogger(object):
+class StatusLogger:
     """
     StatusLogger is a wrapper around `logging.Logger` that allows for
     consistent treatment of logging messages.
While not strictly required, @@ -38,11 +38,13 @@ def __init__(self, auto_scoping=False): self.__scopes = [] ch = LoggingHandler() - formatter = logging.Formatter("%(levelname)s: %(name)s (%(funcName)s:%(lineno)s): %(message)s") + formatter = logging.Formatter( + "%(levelname)s: %(name)s (%(funcName)s:%(lineno)s): %(message)s" + ) ch.setFormatter(formatter) - self.setLevel(config.logging_level, context='omniduct') - omniductLogger = self.__get_logger_instance(context='omniduct') + self.setLevel(config.logging_level, context="omniduct") + omniductLogger = self.__get_logger_instance(context="omniduct") omniductLogger.addHandler(ch) omniductLogger.propagate = False @@ -58,49 +60,58 @@ def disabled(self, disabled): def _scope_enter(self, name, timed=False, extra=None): if config.logging_level < logging.INFO: - print("\t" * len(self.__scopes) + "Entering manual scope: {}".format(name), file=sys.stderr) - props = {'name': name} + print( + "\t" * len(self.__scopes) + f"Entering manual scope: {name}", + file=sys.stderr, + ) + props = {"name": name} if timed: - props['time'] = time.time() + props["time"] = time.time() if extra is not None: - props['extra'] = extra - props['caveats'] = [] + props["extra"] = extra + props["caveats"] = [] self.__scopes.append(props) def _scope_exit(self, success=True): if self._progress_bar is not None: self.progress(100, complete=True) props = self.__scopes[-1] - if 'time' in props: + if "time" in props: logger.warning( - "{} after {} on {}.".format( - 'Complete' if success else 'Failed', - self.__get_time(time.time() - props['time']), - time.strftime('%Y-%m-%d') - ) + (' CAVEATS: {}.'.format('; '.join(props['caveats'])) if props['caveats'] else '') + f"{'Complete' if success else 'Failed'} after {self.__get_time(time.time() - props['time'])} on {time.strftime('%Y-%m-%d')}." + + ( + f" CAVEATS: {'; '.join(props['caveats'])}." 
+ if props["caveats"] + else "" + ) ) scope = self.__scopes.pop() if config.logging_level < logging.INFO: - print("\t" * len(self.__scopes) + "Exited manual scope: {}".format(scope['name']), file=sys.stderr) - elif 'has_logged' in scope: + print( + "\t" * len(self.__scopes) + f"Exited manual scope: {scope['name']}", + file=sys.stderr, + ) + elif "has_logged" in scope: if len(self.__scopes) != 0: - self.current_scope_props['has_logged'] = self.current_scope_props.get('has_logged') or props.get('has_logged', False) + self.current_scope_props["has_logged"] = self.current_scope_props.get( + "has_logged" + ) or props.get("has_logged", False) def __get_time(self, seconds): m, s = divmod(seconds, 60) h, m = divmod(m, 60) if h > 0: - return "{:.0f} hrs, {:.0f} min".format(h, m) + return f"{h:.0f} hrs, {m:.0f} min" if m > 0: - return "{:.0f} min, {:.0f} sec".format(m, s) - return "{:.2f} sec".format(s) + return f"{m:.0f} min, {s:.0f} sec" + return f"{s:.2f} sec" def caveat(self, caveat): if len(self.__scopes) == 0: - self.warning("CAVEAT: {}".format(caveat)) + self.warning(f"CAVEAT: {caveat}") else: - self.current_scope_props['caveats'].append(caveat) + self.current_scope_props["caveats"].append(caveat) @property def current_scopes(self): @@ -125,10 +136,18 @@ def __get_progress_bar(self, indeterminate=False): prefix = ": ".join(self.current_scopes) + ": " else: prefix = "\t" * len(self.current_scopes) - self._progress_bar = progressbar.ProgressBar(widgets=[prefix, progressbar.widgets.RotatingMarker() if indeterminate else progressbar.widgets.Bar(), progressbar.widgets.Timer(format=' %(elapsed)s')], - redirect_stderr=True, - redirect_stdout=True, - max_value=100).start() + self._progress_bar = progressbar.ProgressBar( + widgets=[ + prefix, + progressbar.widgets.RotatingMarker() + if indeterminate + else progressbar.widgets.Bar(), + progressbar.widgets.Timer(format=" %(elapsed)s"), + ], + redirect_stderr=True, + redirect_stdout=True, + max_value=100, + ).start() return self._progress_bar @@ -137,7 +156,9 @@ def progress(self, progress=None, complete=False, indeterminate=False): Set the current progress to `progress`, and if not already showing, display a progress bar. If `complete` evaluates to True, then finish displaying the progress. 
""" - complete = complete or (self.current_scope_props is None) # Only leave progress bar open if within a scope + complete = complete or ( + self.current_scope_props is None + ) # Only leave progress bar open if within a scope if config.logging_level <= logging.INFO: self.__get_progress_bar(indeterminate=indeterminate).update(progress) if complete: @@ -155,10 +176,10 @@ def __get_logger_instance(self, context=None): try: caller = inspect.stack()[2] context = inspect.getmodule(caller.frame).__name__ - except: - context = 'omniduct' - if not context == 'omniduct' and not context.startswith('omniduct.'): - context = 'omniduct.external.{}'.format(context) + except: # pylint: disable=bare-except + context = "omniduct" + if context != "omniduct" and not context.startswith("omniduct."): + context = f"omniduct.external.{context}" return logging.getLogger(context) def __getattr__(self, name): @@ -182,13 +203,14 @@ def detect_scopes(): current_frame = inspect.currentframe() while current_frame is not None: - if current_frame.f_code.co_name == 'logging_scope': - scopes.append(current_frame.f_locals['name']) + if current_frame.f_code.co_name == "logging_scope": + scopes.append(current_frame.f_locals["name"]) else: argvalues = inspect.getargvalues(current_frame) - if 'self' in argvalues.args and getattr(argvalues.locals['self'].__class__, 'AUTO_LOGGING_SCOPE', - False): - scopes.append(argvalues.locals['self']) + if "self" in argvalues.args and getattr( + argvalues.locals["self"].__class__, "AUTO_LOGGING_SCOPE", False + ): + scopes.append(argvalues.locals["self"]) current_frame = current_frame.f_back out_scopes = [] @@ -197,12 +219,18 @@ def detect_scopes(): if scope not in seen: out_scopes.append( scope - if isinstance(scope, six.string_types) else - (getattr(scope, "LOGGING_SCOPE", None) or getattr(scope, "name", None) or scope.__class__.__name__)) + if isinstance(scope, str) + else ( + getattr(scope, "LOGGING_SCOPE", None) + or getattr(scope, "name", None) + or scope.__class__.__name__ + ) + ) seen.add(scope) return out_scopes +# pylint: disable-next=abstract-method class LoggingHandler(logging.Handler): """ An implementation of Logging.Handler to render the logging methods shown in Omniduct and derivatives. 
@@ -210,50 +238,57 @@ class LoggingHandler(logging.Handler):
     def __init__(self, level=logging.NOTSET):
         logging.Handler.__init__(self, level=level)
-        self.setFormatter(logging.Formatter("%(levelname)s: %(name)s (%(funcName)s:%(lineno)s): %(message)s"))
+        self.setFormatter(
+            logging.Formatter(
+                "%(levelname)s: %(name)s (%(funcName)s:%(lineno)s): %(message)s"
+            )
+        )

     def format_simple(self, record):
-        return "{}".format(record.getMessage())
+        return record.getMessage()

     def handle(self, record):
         try:
             scopes = logger.current_scopes
-        except:
+        except:  # pylint: disable=bare-except
             scopes = []

         if config.logging_level < logging.INFO:
             # Print everything verbosely
-            prefix = '\t' * len(scopes)
-            self._overwrite(prefix + self.format(record),
-                            overwritable=False,
-                            truncate=False)
+            prefix = "\t" * len(scopes)
+            self._overwrite(
+                prefix + self.format(record), overwritable=False, truncate=False
+            )
         else:
             prefix = ""
-            important = (record.levelno >= logging.WARNING or
-                         logger._progress_bar is not None or
-                         len(scopes) == 0)
+            important = (
+                record.levelno >= logging.WARNING
+                or logger._progress_bar is not None
+                or len(scopes) == 0
+            )

             if len(scopes) > 0:
                 prefix = ": ".join(scopes) + ": "

             if logger.current_scope_props is not None:
-                logger.current_scope_props['has_logged'] = True
+                logger.current_scope_props["has_logged"] = True

-            self._overwrite(prefix + self.format_simple(record),
-                            overwritable=not important,
-                            truncate=not important
-                            )
+            self._overwrite(
+                prefix + self.format_simple(record),
+                overwritable=not important,
+                truncate=not important,
+            )
             sys.stderr.flush()

     def _overwrite(self, text, overwritable=True, truncate=True, file=sys.stderr):
-        w, h = progressbar.utils.get_terminal_size()
-        file.write('\r' + ' ' * w + '\r')  # Clear current line
+        w, _ = progressbar.utils.get_terminal_size()
+        file.write("\r" + " " * w + "\r")  # Clear current line
         if overwritable:
-            text.replace('\n', ' ')
+            text = text.replace("\n", " ")
         if truncate:
             if len(text) > w:
-                text = text[:w - 3] + '...'
+                text = text[: w - 3] + "..."
         if not overwritable:
-            text += '\n'
+            text += "\n"
         file.write(text)


@@ -264,17 +299,19 @@ def logging_scope(name, *wargs, **wkwargs):
     supported keyword arguments are "timed", in which case when the scope
     closes, the duration of the call is shown.
""" + def logging_scope(func, *args, **kwargs): logger._scope_enter(name, *wargs, **wkwargs) success = True try: f = func(*args, **kwargs) return f - except Exception as e: + except Exception: # pylint: disable=broad-exception-caught success = False - raise_with_traceback(e) + raise finally: logger._scope_exit(success) + return lambda func: decorate(func, logging_scope) diff --git a/omniduct/utils/decorators.py b/omniduct/utils/decorators.py index c0b9650..82519bd 100644 --- a/omniduct/utils/decorators.py +++ b/omniduct/utils/decorators.py @@ -1,16 +1,10 @@ import inspect -import sys import decorator -import six -from future.utils import raise_with_traceback def function_args_as_kwargs(func, *args, **kwargs): - if six.PY3 and not hasattr(sys, 'pypy_version_info'): - arguments = inspect.signature(func).parameters.keys() - else: - arguments = inspect.getargspec(func).args + arguments = inspect.signature(func).parameters.keys() kwargs.update(dict(zip(list(arguments), args))) return kwargs @@ -27,7 +21,7 @@ def require_connection(f, self, *args, **kwargs): try: return f(self, *args, **kwargs) - except Exception as e: + except Exception: # pylint: disable=broad-exception-caught # Check to see if it is possible that we failed due to connection issues. # If so, try again once more. If we fail again, raise. # TODO: Explore adding a DuctConnectionError class and filter this @@ -35,4 +29,4 @@ def require_connection(f, self, *args, **kwargs): if not self.is_connected(): self.connect() return f(self, *args, **kwargs) - raise_with_traceback(e) + raise diff --git a/omniduct/utils/dependencies.py b/omniduct/utils/dependencies.py index d3db157..38060f4 100644 --- a/omniduct/utils/dependencies.py +++ b/omniduct/utils/dependencies.py @@ -1,13 +1,28 @@ import importlib import re +from typing import Optional -import pkg_resources -from pkg_resources import VersionConflict +import packaging.requirements from omniduct._version import __optional_dependencies__ from omniduct.utils.debug import logger +def get_package_version(package_name: str) -> Optional[str]: + """ + Return the version of the given package, or None if the package is not + installed. + """ + try: # Python 3.8+ + import importlib.metadata + + return importlib.metadata.version(package_name) + except ImportError: # Python <3.12 + import pkg_resources + + return pkg_resources.get_distribution(package_name).version + + def check_dependencies(protocols, message=None): if protocols is None: return @@ -18,36 +33,42 @@ def check_dependencies(protocols, message=None): warning_deps = {} for dep in dependencies: - m = re.match('^[a-z_][a-z0-9]*', dep) + m = re.match("^[a-z_][a-z0-9]*", dep) if not m: - logger.warning('Invalid dependency requested: {}'.format(dep)) + logger.warning(f"Invalid dependency requested: {dep}") package_name = m.group(0) accept_any_version = package_name == dep - try: - pkg_resources.get_distribution(dep) - except VersionConflict: - warning_deps[dep] = "{}=={}".format(package_name, pkg_resources.get_distribution(m.group(0)).version) - except: + dep_req = packaging.requirements.Requirement(dep) + package_name = dep_req.name + package_version = get_package_version(package_name) + + if package_version is None: # Some packages may be available, but not installed. If so, we # should accept them with warnings (if version specified in dep). 
            try:
                 importlib.import_module(package_name)
                 if not accept_any_version:
-                    warning_deps.append('{}=='.format(package_name))
+                    warning_deps[dep] = f"{package_name}=="
-            except:  # ImportError in python 2, ModuleNotFoundError in Python 3
+            except ModuleNotFoundError:
                 missing_deps.append(dep)
+        elif dep_req.specifier and not dep_req.specifier.contains(package_version):
+            warning_deps[dep] = f"{package_name}=={package_version}"

     if warning_deps:
         message = "You may have some outdated packages:\n"
         for key in sorted(warning_deps):
-            message += '\t- Want {}, found {}'.format(key, warning_deps[key])
+            message += f"\t- Want {key}, found {warning_deps[key]}\n"
         logger.warning(message)
     if missing_deps:
-        message = message or "Whoops! You do not seem to have all the dependencies required."
-        fix = ("You can fix this by running:\n\n"
-               "\t{install_command}\n\n"
-               "Note: Depending on your system's installation of Python, you may "
-               "need to use `pip2` or `pip3` instead of `pip`.").format(install_command='pip install --upgrade ' + ' '.join(missing_deps))
-        raise RuntimeError('\n\n'.join([message, fix]))
+        message = (
+            message or "Whoops! You do not seem to have all the dependencies required."
+        )
+        fix = (
+            "You can fix this by running:\n\n"
+            f"\tpip install --upgrade {' '.join(missing_deps)}\n\n"
+            "Note: Depending on your system's installation of Python, you may "
+            "need to use `pip3` instead of `pip`."
+        )
+        raise RuntimeError("\n\n".join([message, fix]))
diff --git a/omniduct/utils/magics.py b/omniduct/utils/magics.py
index d2d1349..3044c4d 100644
--- a/omniduct/utils/magics.py
+++ b/omniduct/utils/magics.py
@@ -1,7 +1,5 @@
 from abc import ABCMeta, abstractmethod

-from future.utils import with_metaclass
-
 from interface_meta import quirk_docs

@@ -12,6 +10,7 @@
     def wrapped(*args, **kwargs):
         args += new_args
         kwargs.update(new_kwargs)
         return f(*args, **kwargs)
+
     return wrapped

@@ -24,32 +23,33 @@
     def wrapped(*args, **kwargs):
         args += new_args
         kwargs.update(new_kwargs)
         return f(*args, **kwargs)
+
     return wrapped


 def _process_line_arguments(line_arguments):
     from IPython import get_ipython
+
     args = []
     kwargs = {}
     reached_kwargs = False
     for arg in line_arguments.split():
-        if '=' in arg:
+        if "=" in arg:
             reached_kwargs = True
-            key, value = arg.split('=')
+            key, value = arg.split("=")
             value = eval(value, get_ipython().user_ns)
             if key in kwargs:
-                raise ValueError('Duplicate keyword argument `{}`.'.format(key))
+                raise ValueError(f"Duplicate keyword argument `{key}`.")
             kwargs[key] = value
         else:
             if reached_kwargs:
-                raise ValueError('Positional argument `{}` after keyword argument.'.format(arg))
+                raise ValueError(f"Positional argument `{arg}` after keyword argument.")
             args.append(arg)
     return args, kwargs


-class MagicsProvider(with_metaclass(ABCMeta, object)):
-
-    @quirk_docs('_register_magics')
+class MagicsProvider(metaclass=ABCMeta):
+    @quirk_docs("_register_magics")
     def register_magics(self, base_name=None):
         base_name = base_name or self.name
         if base_name is None:
@@ -57,10 +57,11 @@ def register_magics(self, base_name=None):
         try:
             from IPython import get_ipython
+
             ip = get_ipython()
             assert ip is not None
             has_ipython = True
-        except Exception:
+        except Exception:  # pylint: disable=broad-exception-caught
             has_ipython = False

         if has_ipython:
diff --git a/omniduct/utils/ports.py b/omniduct/utils/ports.py
index 97e6b7e..d616acc 100644
--- a/omniduct/utils/ports.py
+++ b/omniduct/utils/ports.py
@@ -42,7 +42,7 @@ def get_free_local_port():
     s.bind(("", 0))
     free_port = s.getsockname()[1]
     s.close()
-    logger.info('found port {0}'.format(free_port))
+    logger.info(f"found port {free_port}")
     return free_port

@@ -52,7 +52,7 @@ def is_port_bound(hostname, port, timeout=None):
     s.settimeout(timeout)
     try:
         s.connect((hostname, port))
-    except:
+    except:  # pylint: disable=bare-except
         return False
     finally:
         s.close()
@@ -66,13 +66,12 @@ def naive_load_balancer(hosts, port):
     random.shuffle(hosts)

     # Check if host is available and if so return it
-    pattern = re.compile(r'(?P<host>[^\:]+)(?::(?P<port>[0-9]{1,5}))?')
+    pattern = re.compile(r"(?P<host>[^\:]+)(?::(?P<port>[0-9]{1,5}))?")
     for host in hosts:
         m = pattern.match(host)
-        if is_port_bound(m.group('host'), int(m.group('port') or port), timeout=1):
+        if is_port_bound(m.group("host"), int(m.group("port") or port), timeout=1):
             return host
-        else:
-            logger.warning("Avoiding down or inaccessible host: '{}'.".format(host))
+        logger.warning(f"Avoiding down or inaccessible host: '{host}'.")

     raise RuntimeError(
         "Unable to connect to any of the hosts associated with this service. "
diff --git a/omniduct/utils/processes.py b/omniduct/utils/processes.py
index 1e6e3da..dba4533 100644
--- a/omniduct/utils/processes.py
+++ b/omniduct/utils/processes.py
@@ -1,31 +1,25 @@
 import os
 import signal
-import sys
+import subprocess
+from subprocess import TimeoutExpired

 from omniduct.utils.config import config as omniduct_config
 from omniduct.utils.debug import logger

-if os.name == 'posix' and sys.version_info[0] < 3:
-    import subprocess32 as subprocess
-    from subprocess32 import TimeoutExpired
-else:
-    import subprocess
-    from subprocess import TimeoutExpired

-__all__ = ['run_in_subprocess', 'TimeoutExpired', 'Timeout', 'TimeoutError']
+__all__ = ["run_in_subprocess", "TimeoutExpired", "Timeout"]

 DEFAULT_SUBPROCESS_CONFIG = {
-    'shell': True,
-    'close_fds': False,
-    'stdin': None,
-    'stdout': subprocess.PIPE,
-    'stderr': subprocess.PIPE,
-    'preexec_fn': os.setsid  # Set the process as the group leader, so we can kill recursively
+    "shell": True,
+    "close_fds": False,
+    "stdin": None,
+    "stdout": subprocess.PIPE,
+    "stderr": subprocess.PIPE,
+    "preexec_fn": os.setsid,  # Set the process as the group leader, so we can kill recursively
 }


-class SubprocessResults(object):
-
+class SubprocessResults:
     def __init__(self, **kwargs):
         for key, value in kwargs.items():
             setattr(self, key, value)
@@ -48,35 +42,38 @@ def run_in_subprocess(cmd, check_output=False, **kwargs):
         Subprocess used to run command.
""" - logger.debug('Executing command: {0}'.format(cmd)) - config = DEFAULT_SUBPROCESS_CONFIG.copy() - config.update(kwargs) - if not check_output: - if omniduct_config.logging_level < 20: - config['stdout'] = None - config['stderr'] = None - else: - config['stdout'] = open(os.devnull, 'w') - config['stderr'] = open(os.devnull, 'w') - timeout = config.pop('timeout', None) - - process = subprocess.Popen(cmd, **config) - try: - stdout, stderr = process.communicate(None, timeout=timeout) - except subprocess.TimeoutExpired: - os.killpg(os.getpgid(process.pid), signal.SIGINT) # send signal to the process group, recursively killing all children - output, unused_err = process.communicate() - raise subprocess.TimeoutExpired(process.args, timeout, output=output) - return SubprocessResults(returncode=process.returncode, stdout=stdout or b'', stderr=stderr or b'') - - -class TimeoutError(Exception): - pass - - -class Timeout(object): - - def __init__(self, seconds=1, error_message='Timeout'): + with open(os.devnull, "w", encoding="utf-8") as devnull: + logger.debug(f"Executing command: {cmd}") + config = DEFAULT_SUBPROCESS_CONFIG.copy() + config.update(kwargs) + if not check_output: + if omniduct_config.logging_level < 20: + config["stdout"] = None + config["stderr"] = None + else: + config["stdout"] = devnull + config["stderr"] = devnull + timeout = config.pop("timeout", None) + + with subprocess.Popen(cmd, **config) as process: + try: + stdout, stderr = process.communicate(None, timeout=timeout) + returncode = process.returncode + except subprocess.TimeoutExpired as e: + os.killpg( + os.getpgid(process.pid), signal.SIGINT + ) # send signal to the process group, recursively killing all children + output, unused_err = process.communicate() + raise subprocess.TimeoutExpired( + process.args, timeout, output=output + ) from e + return SubprocessResults( + returncode=returncode, stdout=stdout or b"", stderr=stderr or b"" + ) + + +class Timeout: + def __init__(self, seconds=1, error_message="Timeout"): self.seconds = seconds self.error_message = error_message @@ -87,5 +84,6 @@ def __enter__(self): signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) + # pylint: disable-next=redefined-builtin def __exit__(self, type, value, traceback): signal.alarm(0) diff --git a/omniduct/utils/proxies.py b/omniduct/utils/proxies.py index 533f468..6a27cc1 100644 --- a/omniduct/utils/proxies.py +++ b/omniduct/utils/proxies.py @@ -11,7 +11,7 @@ class TreeProxy: ``` """ - __slots__ = ('__tree__', '__nodename__') + __slots__ = ("__tree__", "__nodename__") @classmethod def _for_dict(cls, dct, key_parser=None, name=None): @@ -30,9 +30,11 @@ def __init__(self, tree, name=None): def __getitem__(self, name): if name in self.__tree__: if not isinstance(self.__tree__[name], TreeProxy): - return TreeProxy._for_tree(self.__tree__[name], name=self.__name_of_child(name)) + return TreeProxy._for_tree( + self.__tree__[name], name=self.__name_of_child(name) + ) return self.__tree__[name] - raise KeyError('Invalid child node `{node_name}`.'.format(node_name=name)) + raise KeyError(f"Invalid child node `{name}`.") def __iter__(self): return iter(self.__tree__) @@ -43,16 +45,18 @@ def __len__(self): def __getattr__(self, name): try: return self[name] - except KeyError: - raise AttributeError('Invalid child node `{node_name}`.'.format(node_name=name)) + except KeyError as e: + raise AttributeError(f"Invalid child node `{name}`.") from e def __dir__(self): return list(self.__tree__) def __repr__(self): if 
self.__nodename__:
-            return "<TreeProxy `{}` with {} children>".format(self.__nodename__, len(self.__tree__))
-        return "<TreeProxy with {} children>".format(len(self.__tree__))
+            return (
+                f"<TreeProxy `{self.__nodename__}` with {len(self.__tree__)} children>"
+            )
+        return f"<TreeProxy with {len(self.__tree__)} children>"

     # Helpers

@@ -65,7 +69,9 @@ def __name_of_child(self, child):
     def __dict_to_tree(cls, dct, key_parser):
         tree = {}
         for key, value in dct.items():
-            cls.__add_nested_key_value(tree, keys=key_parser(key, value) if key_parser else [key], value=value)
+            cls.__add_nested_key_value(
+                tree, keys=key_parser(key, value) if key_parser else [key], value=value
+            )
         return tree

     @classmethod
@@ -76,8 +82,6 @@ def __add_nested_key_value(cls, tree, keys, value):
             tree = tree[key]
         if len(tree) and None not in tree:
             raise ValueError(
-                "`TreeProxy` objects can only proxy trees with values only on leaf "
-                "nodes; error encounted while trying to add value to node {}."
-                .format(keys)
+                f"`TreeProxy` objects can only proxy trees with values only on leaf nodes; error encountered while trying to add value to node {keys}."
             )
         tree[None] = value
diff --git a/omniduct/utils/submodules.py b/omniduct/utils/submodules.py
index 2c0c3dd..f221087 100644
--- a/omniduct/utils/submodules.py
+++ b/omniduct/utils/submodules.py
@@ -4,13 +4,13 @@


 def import_submodules(package_name):
-    """ Import all submodules of a module, recursively
+    """Import all submodules of a module, recursively

     :param package_name: Package name
     :type package_name: str
     :rtype: dict[types.ModuleType]
     """
     package = sys.modules[package_name]
     return {
-        name: importlib.import_module(package_name + '.' + name)
+        name: importlib.import_module(package_name + "." + name)
         for loader, name, is_pkg in pkgutil.walk_packages(package.__path__)
     }
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..f733fd6
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,219 @@
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[project]
+name = "omniduct"
+dynamic = ["version"]
+description = "A toolkit providing a uniform interface for connecting to and extracting data from a wide variety of (potentially remote) data stores (including HDFS, Hive, Presto, MySQL, etc)."
+readme = "README.md"
+license = "Apache-2.0"
+authors = [
+    { name = "Matthew Wardrop", email = "mpwardrop@gmail.com" },
+    { name = "Dan Frank", email = "danfrankj@gmail.com" },
+]
+classifiers = [
+    "Development Status :: 5 - Production/Stable",
+    "Environment :: Console",
+    "Environment :: Web Environment",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Information Technology",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: Apache Software License",
+    "Programming Language :: Python :: 3.7",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+requires-python = ">=3.7"
+dependencies = [
+    "decorator",
+    "interface_meta>=1.1.0,<2",
+    "jinja2",
+    "lazy-object-proxy",
+    "packaging",
+    "pandas>=0.20.3",
+    "progressbar2>=3.30.0",
+    "python-dateutil",
+    "pyyaml",
+    "sqlalchemy",
+    "sqlparse",
+    "wrapt",
+]
+
+[project.optional-dependencies]
+all = [
+    "boto3",
+    "coverage",
+    "flake8",
+    "mock",
+    "nose",
+    "paramiko",
+    "pexpect",
+    "pydruid>=0.4.0",
+    "pyexasol",
+    "pyfakefs",
+    "pyhive[hive]>=0.4",
+    "pyhive[presto]>=0.4",
+    "pyspark",
+    "pywebhdfs",
+    "requests",
+    "snowflake-sqlalchemy",
+    "sphinx",
+    "sphinx_autobuild",
+    "sphinx_rtd_theme",
+    "thrift>=0.10.0",
+]
+docs = [
+    "sphinx",
+    "sphinx_autobuild",
+    "sphinx_rtd_theme",
+]
+druid = [
+    "pydruid>=0.4.0",
+]
+exasol = [
+    "pyexasol",
+]
+hiveserver2 = [
+    "pyhive[hive]>=0.4",
+    "thrift>=0.10.0",
+]
+presto = [
+    "pyhive[presto]>=0.4",
+]
+pyspark = [
+    "pyspark",
+]
+rest = [
+    "requests",
+]
+s3 = [
+    "boto3",
+]
+snowflake = [
+    "snowflake-sqlalchemy",
+]
+ssh = [
+    "pexpect",
+]
+ssh_paramiko = [
+    "paramiko",
+    "pexpect",
+]
+test = [
+    "coverage",
+    "flake8",
+    "mock",
+    "nose",
+    "pyfakefs",
+]
+webhdfs = [
+    "pywebhdfs",
+    "requests",
+]
+
+[project.urls]
+Homepage = "https://github.com/airbnb/omniduct"
+
+[tool.hatch.version]
+source = "vcs"
+
+[tool.hatch.build.hooks.vcs]
+version-file = "omniduct/_version_info.py"
+
+# Build target configuration
+
+[tool.hatch.build.targets.sdist]
+include = [
+    "docs",
+    "example_wrapper",
+    "omniduct",
+    "tests",
+    "LICENSE",
+    "README.md",
+    "MANIFEST.in",
+    "pyproject.toml",
+]
+
+# Testing configuration
+
+[tool.hatch.envs.default]
+dependencies = [
+    "mock",
+    "pyfakefs",
+    "pytest",
+    "pytest-cov",
+    "pytest-mock",
+    "requests",
+]
+
+[tool.hatch.envs.default.scripts]
+tests = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=omniduct --cov-report=xml -vv {args:tests}"
+
+[[tool.hatch.envs.test.matrix]]
+python = ["37", "38", "39", "310", "311", "312"]
+
+[tool.hatch.envs.lint]
+detached = true
+dependencies = [
+    "black==23.10.1",
+    "flake8==6.1.0",
+    "flake8-pyproject",
+    "pylint==2.17.4",
+]
+
+[tool.hatch.envs.lint.scripts]
+check = [
+    "flake8 omniduct tests",
+    "pylint omniduct",
+    "black --check omniduct tests",
+]
+format = "black omniduct tests"
+
+# Linter and format configuration
+
+[tool.flake8]
+ignore = [
+    "C901","E203","E501","E712","E722","E731","W503","W504","W601"
+]
+max-complexity = 25
+max-line-length = 160
+import-order-style = "edited"
+application-import-names = "omniduct"
+
+[tool.pylint."MESSAGES CONTROL"]
+disable = [
+    "cyclic-import",
+    "duplicate-code",
+    "eval-used",
+    "fixme",
+    "import-error",
+    "import-outside-toplevel",
+    "invalid-name",
+    "line-too-long",
+    "missing-class-docstring",
"missing-function-docstring", + "missing-module-docstring", + "no-member", + "protected-access", + "redefined-outer-name", + "too-few-public-methods", + "too-many-ancestors", + "too-many-arguments", + "too-many-branches", + "too-many-instance-attributes", + "too-many-lines", + "too-many-locals", + "too-many-public-methods", + "too-many-return-statements", + "too-many-statements", + "ungrouped-imports", + "unnecessary-lambda-assignment", + "unused-argument", + "use-dict-literal", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b17e62e..0000000 --- a/setup.cfg +++ /dev/null @@ -1,6 +0,0 @@ -[bdist_wheel] -universal=1 - -[flake8] -ignore = E501,E712,E722,W503,W504,W601,W606 -max-line-length = 160 diff --git a/setup.py b/setup.py deleted file mode 100644 index e43d02a..0000000 --- a/setup.py +++ /dev/null @@ -1,56 +0,0 @@ -from setuptools import find_packages, setup - -# Extract version information from Omniduct _version.py -version_info = {} -with open('omniduct/_version.py') as version_file: - exec(version_file.read(), version_info) - -# Extract long description from readme -with open('README.md') as readme: - long_description = "" - while True: - line = readme.readline() - if line.startswith('`omniduct`'): - long_description = line - break - long_description += readme.read() - -setup( - # Package metadata - name="omniduct", - versioning='post', - version=version_info['__version__'], - author=version_info['__author__'], - author_email=version_info['__author_email__'], - url="https://github.com/airbnb/omniduct", - description=( - "A toolkit providing a uniform interface for connecting to and " - "extracting data from a wide variety of (potentially remote) data " - "stores (including HDFS, Hive, Presto, MySQL, etc)." - ), - long_description=long_description, - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Console', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'Intended Audience :: Information Technology', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - ], - - # Package details - packages=find_packages(), - include_package_data=True, - - # Dependencies - setup_requires=['setupmeta'], - install_requires=version_info['__dependencies__'], - extras_require=version_info['__optional_dependencies__'] -) diff --git a/tests/databases/test__namespaces.py b/tests/databases/test__namespaces.py index 12a23ff..47461be 100644 --- a/tests/databases/test__namespaces.py +++ b/tests/databases/test__namespaces.py @@ -4,101 +4,87 @@ class TestParseNamespaces: - def test_simple(self): namespace = ParsedNamespaces.from_name( - name='my_db.my_table', - namespaces=['database', 'table'] + name="my_db.my_table", namespaces=["database", "table"] ) - assert namespace.database == 'my_db' - assert namespace.table == 'my_table' - assert namespace.as_dict() == { - 'database': 'my_db', - 'table': 'my_table' - } + assert namespace.database == "my_db" + assert namespace.table == "my_table" + assert namespace.as_dict() == {"database": "my_db", "table": "my_table"} def test_quoted_names(self): namespace = ParsedNamespaces.from_name( - name='`my_db`.`my . 
table`', - namespaces=['catalog', 'database', 'table'], - quote_char='`' + name="`my_db`.`my . table`", + namespaces=["catalog", "database", "table"], + quote_char="`", ) assert namespace.catalog is None - assert namespace.database == 'my_db' - assert namespace.table == 'my . table' + assert namespace.database == "my_db" + assert namespace.table == "my . table" assert namespace.as_dict() == { - 'catalog': None, - 'database': 'my_db', - 'table': 'my . table' + "catalog": None, + "database": "my_db", + "table": "my . table", } def test_separator(self): namespace = ParsedNamespaces.from_name( - name='cat|my_db|my_table', - namespaces=['catalog', 'database', 'table'], - separator='|' + name="cat|my_db|my_table", + namespaces=["catalog", "database", "table"], + separator="|", ) - assert namespace.catalog == 'cat' - assert namespace.database == 'my_db' - assert namespace.table == 'my_table' + assert namespace.catalog == "cat" + assert namespace.database == "my_db" + assert namespace.table == "my_table" assert namespace.as_dict() == { - 'catalog': 'cat', - 'database': 'my_db', - 'table': 'my_table' + "catalog": "cat", + "database": "my_db", + "table": "my_table", } def test_parsing_failure(self): with pytest.raises(ValueError): - ParsedNamespaces.from_name( - name='my_db.my_table', - namespaces=['table'] - ) + ParsedNamespaces.from_name(name="my_db.my_table", namespaces=["table"]) def test_nonexistent_namespace(self): with pytest.raises(AttributeError): - ParsedNamespaces.from_name( - name='my_table', - namespaces=['table'] - ).database + ParsedNamespaces.from_name(name="my_table", namespaces=["table"]).database def test_not_encapsulated(self): - namespace = ParsedNamespaces.from_name('my_db.my_table', ['database', 'table']) - assert namespace.as_dict() == {'database': 'my_db', 'table': 'my_table'} + namespace = ParsedNamespaces.from_name("my_db.my_table", ["database", "table"]) + assert namespace.as_dict() == {"database": "my_db", "table": "my_table"} with pytest.raises(ValueError): - ParsedNamespaces.from_name(namespace, ['schema', 'table']) + ParsedNamespaces.from_name(namespace, ["schema", "table"]) def test_empty(self): - namespace = ParsedNamespaces.from_name("", ['database', 'table']) + namespace = ParsedNamespaces.from_name("", ["database", "table"]) assert bool(namespace) is False assert namespace.database is None assert namespace.table is None - assert str(namespace) == '' - assert repr(namespace) == 'Namespace<>' + assert str(namespace) == "" + assert repr(namespace) == "Namespace<>" def test_parent(self): namespace = ParsedNamespaces.from_name( - name='my_db.my_table', - namespaces=['catalog', 'database', 'table'] + name="my_db.my_table", namespaces=["catalog", "database", "table"] ) assert namespace.parent.name == '"my_db"' - assert namespace.parent.as_dict() == { - 'catalog': None, - 'database': 'my_db' - } + assert namespace.parent.as_dict() == {"catalog": None, "database": "my_db"} def test_casting(self): namespace = ParsedNamespaces.from_name( - name='my_db.my_table', - namespaces=['catalog', 'database', 'table'] + name="my_db.my_table", namespaces=["catalog", "database", "table"] ) assert str(namespace) == '"my_db"."my_table"' assert bool(namespace) is True - assert namespace.__bool__() == namespace.__nonzero__() # Python 2/3 compatibility + assert ( + namespace.__bool__() == namespace.__nonzero__() + ) # Python 2/3 compatibility assert repr(namespace) == 'Namespace<"my_db"."my_table">' diff --git a/tests/databases/test_base.py b/tests/databases/test_base.py index 
f9c5574..05e04a6 100644 --- a/tests/databases/test_base.py +++ b/tests/databases/test_base.py @@ -7,7 +7,6 @@ class DummyDatabaseClient(DatabaseClient): - PROTOCOLS = [] DEFAULT_PORT = None @@ -49,20 +48,20 @@ def _table_props(self, table, **kwargs): raise NotImplementedError -class DummyCursor(object): +class DummyCursor: """ This DBAPI2 compatible cursor wrapped around a Pandas DataFrame """ def __init__(self): self.df = pd.DataFrame( - {'field1': list(range(10)), 'field2': list('abcdefghij')} + {"field1": list(range(10)), "field2": list("abcdefghij")} ) self._df_iter = None @property def df_iter(self): - if not getattr(self, '_df_iter'): + if not getattr(self, "_df_iter"): self._df_iter = (tuple(row) for i, row in self.df.iterrows()) return self._df_iter @@ -70,10 +69,9 @@ def df_iter(self): @property def description(self): - return tuple([ - (name, None, None, None, None, None, None) - for name in self.df.columns - ]) + return tuple( + [(name, None, None, None, None, None, None) for name in self.df.columns] + ) @property def row_count(self): @@ -106,7 +104,6 @@ def setoutputsize(self, size, column=None): class TestDatabaseClient: - @pytest.fixture def db_client(self): return DummyDatabaseClient() @@ -114,8 +111,8 @@ def db_client(self): def test_query(self, db_client): result = db_client.query("DUMMY QUERY") - assert type(result) == pd.DataFrame - assert list(result.columns) == ['field1', 'field2'] + assert type(result) is pd.DataFrame + assert list(result.columns) == ["field1", "field2"] assert all(db_client("DUMMY_QUERY") == result) @@ -126,14 +123,20 @@ def test_multiple_queries(self, mocker, db_client): def test_statement_hash(self, db_client): statement = "DUMMY QUERY" - assert db_client.statement_hash(statement) == hashlib.sha256(statement.encode()).hexdigest() + assert ( + db_client.statement_hash(statement) + == hashlib.sha256(statement.encode()).hexdigest() + ) def test_stream(self, db_client): stream = db_client.stream("DUMMY QUERY") row = next(stream) - assert tuple(row) == (0, 'a') + assert tuple(row) == (0, "a") stream.close() def test_format(self, db_client): - result = db_client.query("DUMMY QUERY", format='csv') - assert result == "field1,field2\r\n0,a\r\n1,b\r\n2,c\r\n3,d\r\n4,e\r\n5,f\r\n6,g\r\n7,h\r\n8,i\r\n9,j\r\n" + result = db_client.query("DUMMY QUERY", format="csv") + assert ( + result + == "field1,field2\r\n0,a\r\n1,b\r\n2,c\r\n3,d\r\n4,e\r\n5,f\r\n6,g\r\n7,h\r\n8,i\r\n9,j\r\n" + ) diff --git a/tests/test_restful.py b/tests/test_restful.py index ea042bd..2f9f264 100644 --- a/tests/test_restful.py +++ b/tests/test_restful.py @@ -5,11 +5,10 @@ class TestRestClient(unittest.TestCase): - - @mock.patch.object(RestClient, 'connect') - @mock.patch('requests.request') + @mock.patch.object(RestClient, "connect") + @mock.patch("requests.request") def test_default_request(self, mock_request, mock_connect): - client = RestClient(server_protocol='http', host='localhost', port=80) - client.request('/') + client = RestClient(server_protocol="http", host="localhost", port=80) + client.request("/") mock_connect.assert_called_with() - mock_request.assert_called_with("get", "http://localhost:80/") + mock_request.assert_called_with("get", "http://localhost:80/", timeout=None) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 7b893be..0000000 --- a/tox.ini +++ /dev/null @@ -1,19 +0,0 @@ -[tox] -envlist = - py27 - py36 - py37 - py38 - -[testenv] -deps= - mock - flake8 - pyfakefs - pytest - pytest-cov - pytest-mock - requests -commands= - pytest --cov omniduct 
--cov-report term-missing tests - flake8 omniduct tests
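
The updated `test_default_request` above pins down the new timeout plumbing from `omniduct/restful/base.py` earlier in this diff: `request` merges `{"timeout": self.default_timeout}` underneath any caller-supplied keyword arguments, so per-call values take precedence over the client default. A sketch of the resulting behaviour, modelled on that test (host and timeout values are placeholders):

```python
from unittest import mock

from omniduct.restful.base import RestClient

with mock.patch.object(RestClient, "connect"), mock.patch("requests.request") as req:
    client = RestClient(host="localhost", port=80, default_timeout=5)

    client.request("/")  # no per-call timeout: the client default applies
    req.assert_called_with("get", "http://localhost:80/", timeout=5)

    client.request("/", timeout=30)  # per-call keyword overrides the default
    req.assert_called_with("get", "http://localhost:80/", timeout=30)
```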