diff --git a/napalm_base/__init__.py b/napalm_base/__init__.py index 3b09d4d4..22ced908 100644 --- a/napalm_base/__init__.py +++ b/napalm_base/__init__.py @@ -49,7 +49,7 @@ ] -def get_network_driver(module_name): +def get_network_driver(module_name, prepend=True): """ Searches for a class derived form the base NAPALM class NetworkDriver in a specific library. The library name must repect the following pattern: napalm_[DEVICE_OS]. @@ -91,7 +91,7 @@ def get_network_driver(module_name): # Try to not raise error when users requests IOS-XR for e.g. module_install_name = module_name.replace('-', '') # Can also request using napalm_[SOMETHING] - if 'napalm_' not in module_install_name: + if 'napalm_' not in module_install_name and prepend is True: module_install_name = 'napalm_{name}'.format(name=module_install_name) module = importlib.import_module(module_install_name) except ImportError: diff --git a/test/unit/TestGetNetworkDriver.py b/test/unit/TestGetNetworkDriver.py index 20527d1a..88d14d8a 100644 --- a/test/unit/TestGetNetworkDriver.py +++ b/test/unit/TestGetNetworkDriver.py @@ -13,15 +13,15 @@ @ddt class TestGetNetworkDriver(unittest.TestCase): """Test the method get_network_driver.""" - network_drivers = ('eos', 'fortios', 'ios', 'iosxr', 'IOS-XR', 'junos', 'ros', 'nxos', - 'pluribus', 'panos', 'vyos') + network_drivers = ('eos', 'napalm_eos', 'fortios', 'ios', 'iosxr', 'IOS-XR', 'junos', 'ros', + 'nxos', 'pluribus', 'panos', 'vyos') @data(*network_drivers) def test_get_network_driver(self, driver): """Check that we can get the desired driver and is instance of NetworkDriver.""" self.assertTrue(issubclass(get_network_driver(driver), NetworkDriver)) - @data('fake', 'network', 'driver') + @data('fake', 'network', 'driver', 'sys', 1) def test_get_wrong_network_driver(self, driver): """Check that inexisting driver throws ModuleImportError.""" - self.assertRaises(ModuleImportError, get_network_driver, driver) + self.assertRaises(ModuleImportError, get_network_driver, driver, prepend=False)