Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:napalm-automation/napalm-base in…
Browse files Browse the repository at this point in the history
…to develop
  • Loading branch information
bewing committed Jul 19, 2017
2 parents 7d48685 + 83d51ae commit 7e782fc
Show file tree
Hide file tree
Showing 19 changed files with 443 additions and 17 deletions.
7 changes: 5 additions & 2 deletions napalm_base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# NAPALM base
from napalm_base.base import NetworkDriver
from napalm_base.exceptions import ModuleImportError
from napalm_base.mock import MockDriver
from napalm_base.utils import py23_compat

try:
Expand All @@ -49,7 +50,7 @@
]


def get_network_driver(module_name):
def get_network_driver(module_name, prepend=True):
"""
Searches for a class derived form the base NAPALM class NetworkDriver in a specific library.
The library name must repect the following pattern: napalm_[DEVICE_OS].
Expand Down Expand Up @@ -81,6 +82,8 @@ def get_network_driver(module_name):
napalm_base.exceptions.ModuleImportError: Cannot import "napalm_wrong". Is the library \
installed?
"""
if module_name == "mock":
return MockDriver

if not (isinstance(module_name, py23_compat.string_types) and len(module_name) > 0):
raise ModuleImportError('Please provide a valid driver name.')
Expand All @@ -91,7 +94,7 @@ def get_network_driver(module_name):
# Try to not raise error when users requests IOS-XR for e.g.
module_install_name = module_name.replace('-', '')
# Can also request using napalm_[SOMETHING]
if 'napalm_' not in module_install_name:
if 'napalm_' not in module_install_name and prepend is True:
module_install_name = 'napalm_{name}'.format(name=module_install_name)
module = importlib.import_module(module_install_name)
except ImportError:
Expand Down
11 changes: 5 additions & 6 deletions napalm_base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ def __init__(self, hostname, username, password, timeout=60, optional_args=None)
def __enter__(self):
try:
self.open()
except: # noqa
except Exception: # noqa
exc_info = sys.exc_info()
self.__raise_clean_exception(exc_info[0], exc_info[1], exc_info[2])
return self.__raise_clean_exception(exc_info[0], exc_info[1], exc_info[2])
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.close()
if exc_type is not None:
self.__raise_clean_exception(exc_type, exc_value, exc_traceback)
return self.__raise_clean_exception(exc_type, exc_value, exc_traceback)

def __del__(self):
"""
Expand All @@ -68,7 +68,7 @@ def __del__(self):
try:
if self.is_alive()["is_alive"]:
self.close()
except NotImplementedError:
except Exception:
pass

@staticmethod
Expand All @@ -90,8 +90,7 @@ def __raise_clean_exception(exc_type, exc_value, exc_traceback):
"https://github.com/napalm-automation/napalm/issues\n"
"Don't forget to include this traceback.")
print(epilog)
# Traceback should already be attached to exception; no need to re-attach
raise exc_value
return False

def open(self):
"""
Expand Down
199 changes: 199 additions & 0 deletions napalm_base/mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright 2017 Dravetech AB. All rights reserved.
#
# The contents of this file are licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with the
# License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

# Python3 support
from __future__ import print_function
from __future__ import unicode_literals

from napalm_base.base import NetworkDriver
import napalm_base.exceptions

import inspect
import json
import os
import re


from pydoc import locate


def raise_exception(result):
exc = locate(result["exception"])
if exc:
raise exc(*result.get("args", []), **result.get("kwargs", {}))
else:
raise TypeError("Couldn't resolve exception {}", result["exception"])


def is_mocked_method(method):
mocked_methods = []
if method.startswith("get_") or method in mocked_methods:
return True
return False


def mocked_method(path, name, count):
parent_method = getattr(NetworkDriver, name)
parent_method_args = inspect.getargspec(parent_method)
modifier = 0 if 'self' not in parent_method_args.args else 1

def _mocked_method(*args, **kwargs):
# Check len(args)
if len(args) + len(kwargs) + modifier > len(parent_method_args.args):
raise TypeError(
"{}: expected at most {} arguments, got {}".format(
name, len(parent_method_args.args), len(args) + modifier))

# Check kwargs
unexpected = [x for x in kwargs if x not in parent_method_args.args]
if unexpected:
raise TypeError("{} got an unexpected keyword argument '{}'".format(name,
unexpected[0]))
return mocked_data(path, name, count)

return _mocked_method


def mocked_data(path, name, count):
filename = "{}.{}".format(os.path.join(path, name), count)
try:
with open(filename) as f:
result = json.loads(f.read())
except IOError:
raise NotImplementedError("You can provide mocked data in {}".format(filename))

if "exception" in result:
raise_exception(result)
else:
return result


class MockDevice(object):

def __init__(self, parent, profile):
self.parent = parent
self.profile = profile

def run_commands(self, commands):
"""Only useful for EOS"""
if "eos" in self.profile:
return self.parent.cli(commands).values()[0]
else:
raise AttributeError("MockedDriver instance has not attribute '_rpc'")


class MockDriver(NetworkDriver):

def __init__(self, hostname, username, password, timeout=60, optional_args=None):
"""
Supported optional_args:
* path(str) - path to where the mocked files are located
* profile(list) - List of profiles to assign
"""
self.hostname = hostname
self.username = username
self.password = password
self.path = optional_args["path"]
self.profile = optional_args.get("profile", [])

self.opened = False
self.calls = {}
self.device = MockDevice(self, self.profile)

# None no action, True load_merge, False load_replace
self.merge = None
self.filename = None
self.config = None

def _count_calls(self, name):
current_count = self.calls.get(name, 0)
self.calls[name] = current_count + 1
return self.calls[name]

def _raise_if_closed(self):
if not self.opened:
raise napalm_base.exceptions.ConnectionClosedException("connection closed")

def open(self):
self.opened = True

def close(self):
self.opened = False

def is_alive(self):
return {"is_alive": self.opened}

def cli(self, commands):
count = self._count_calls("cli")
result = {}
regexp = re.compile('[^a-zA-Z0-9]+')
for i, c in enumerate(commands):
sanitized = re.sub(regexp, '_', c)
name = "cli.{}.{}".format(count, sanitized)
filename = "{}.{}".format(os.path.join(self.path, name), i)
with open(filename, 'r') as f:
result[c] = f.read()
return result

def load_merge_candidate(self, filename=None, config=None):
count = self._count_calls("load_merge_candidate")
self._raise_if_closed()
self.merge = True
self.filename = filename
self.config = config
mocked_data(self.path, "load_merge_candidate", count)

def load_replace_candidate(self, filename=None, config=None):
count = self._count_calls("load_replace_candidate")
self._raise_if_closed()
self.merge = False
self.filename = filename
self.config = config
mocked_data(self.path, "load_replace_candidate", count)

def compare_config(self, filename=None, config=None):
count = self._count_calls("compare_config")
self._raise_if_closed()
return mocked_data(self.path, "compare_config", count)["diff"]

def commit_config(self):
count = self._count_calls("commit_config")
self._raise_if_closed()
self.merge = None
self.filename = None
self.config = None
mocked_data(self.path, "commit_config", count)

def discard_config(self):
count = self._count_calls("commit_config")
self._raise_if_closed()
self.merge = None
self.filename = None
self.config = None
mocked_data(self.path, "discard_config", count)

def _rpc(self, get):
"""This one is only useful for junos."""
if "junos" in self.profile:
return self.cli([get]).values()[0]
else:
raise AttributeError("MockedDriver instance has not attribute '_rpc'")

def __getattribute__(self, name):
if is_mocked_method(name):
self._raise_if_closed()
count = self._count_calls(name)
return mocked_method(self.path, name, count)
else:
return object.__getattribute__(self, name)
24 changes: 20 additions & 4 deletions napalm_base/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import re


# We put it here to compile it only once
numeric_compare_regex = re.compile("^(<|>|<=|>=|==|!=)(\d+(\.\d+){0,1})$")


def _get_validation_file(validation_file):
try:
with open(validation_file, 'r') as stream:
Expand Down Expand Up @@ -139,10 +143,22 @@ def _compare_getter(src, dst):

def compare_numeric(src_num, dst_num):
"""Compare numerical values. You can use '<%d','>%d'."""
complies = eval(str(dst_num)+src_num)
if not isinstance(complies, bool):
return False
return complies
dst_num = float(dst_num)

match = numeric_compare_regex.match(src_num)
if not match:
error = "Failed numeric comparison. Collected: {}. Expected: {}".format(dst_num, src_num)
raise ValueError(error)

operand = {
"<": "__lt__",
">": "__gt__",
">=": "__ge__",
"<=": "__le__",
"==": "__eq__",
"!=": "__ne__",
}
return getattr(dst_num, operand[match.group(1)])(float(match.group(2)))


def empty_tree(input_list):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name="napalm-base",
version='0.24.1',
version='0.24.3',
packages=find_packages(),
author="David Barroso, Kirk Byers, Mircea Ulinic",
author_email="dbarrosop@dravetech.com, ping@mirceaulinic.net, ktbyers@twb-tech.com",
Expand Down
8 changes: 4 additions & 4 deletions test/unit/TestGetNetworkDriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
@ddt
class TestGetNetworkDriver(unittest.TestCase):
"""Test the method get_network_driver."""
network_drivers = ('eos', 'fortios', 'ios', 'iosxr', 'IOS-XR', 'junos', 'ros', 'nxos',
'pluribus', 'panos', 'vyos')
network_drivers = ('eos', 'napalm_eos', 'fortios', 'ios', 'iosxr', 'IOS-XR', 'junos', 'ros',
'nxos', 'pluribus', 'panos', 'vyos')

@data(*network_drivers)
def test_get_network_driver(self, driver):
"""Check that we can get the desired driver and is instance of NetworkDriver."""
self.assertTrue(issubclass(get_network_driver(driver), NetworkDriver))

@data('fake', 'network', 'driver')
@data('fake', 'network', 'driver', 'sys', 1)
def test_get_wrong_network_driver(self, driver):
"""Check that inexisting driver throws ModuleImportError."""
self.assertRaises(ModuleImportError, get_network_driver, driver)
self.assertRaises(ModuleImportError, get_network_driver, driver, prepend=False)
Loading

0 comments on commit 7e782fc

Please sign in to comment.