diff --git a/CHANGELOG.md b/CHANGELOG.md index ab69be2a..3a7afb92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +## 1.14.3 +* bugfix where `Command` was not aware of default call args when wrapping the module [#559](https://github.com/amoffat/sh/pull/573) + ## 1.14.1 - 10/24/20 * bugfix where setting `_ok_code` to not include 0, but 0 was the exit code [#545](https://github.com/amoffat/sh/pull/545) diff --git a/sh.py b/sh.py index 1baa6bf3..830d9fea 100644 --- a/sh.py +++ b/sh.py @@ -3512,7 +3512,18 @@ def __init__(self, self_module, baked_args=None): # if we set this to None. and 3.3 needs a value for __path__ self.__path__ = [] self.__self_module = self_module - self.__env = Environment(globals(), baked_args=baked_args) + + # Copy the Command class and add any baked call kwargs to it + cls_attrs = Command.__dict__.copy() + if baked_args: + call_args, _ = Command._extract_call_args(baked_args) + cls_attrs['_call_args'] = cls_attrs['_call_args'].copy() + cls_attrs['_call_args'].update(call_args) + command_cls = type(Command.__name__, Command.__bases__, cls_attrs) + globs = globals().copy() + globs[Command.__name__] = command_cls + + self.__env = Environment(globs, baked_args=baked_args) def __getattr__(self, name): return self.__env[name] diff --git a/test.py b/test.py index d9cc1dac..4e9e8e03 100644 --- a/test.py +++ b/test.py @@ -3138,6 +3138,13 @@ def test_reimport_no_interfere(self): _sh.echo("-n", "TEST") self.assertEqual("TEST", out.getvalue()) + def test_command_with_baked_call_args(self): + # Test that sh.Command() knows about baked call args + import sh + _sh = sh(_ok_code=1) + self.assertEqual(sh.Command._call_args['ok_code'], 0) + self.assertEqual(_sh.Command._call_args['ok_code'], 1) + def test_importer_detects_module_name(self): import sh _sh = sh()