diff --git a/rope/refactor/importutils/__init__.py b/rope/refactor/importutils/__init__.py index ef0f1bac..1e1810e5 100644 --- a/rope/refactor/importutils/__init__.py +++ b/rope/refactor/importutils/__init__.py @@ -8,6 +8,8 @@ import rope.base.codeanalyze import rope.base.evaluate from rope.base import libutils +from rope.base.prefs import get_preferred_import_style +from rope.base.prefs import ImportStyle from rope.base.change import ChangeContents, ChangeSet from rope.refactor import occurrences, rename from rope.refactor.importutils import actions, module_imports @@ -299,6 +301,7 @@ def get_module_imports(project, pymodule): def add_import(project, pymodule, module_name, name=None): + preferred_import_style = get_preferred_import_style(project.prefs) imports = get_module_imports(project, pymodule) candidates = [] names = [] @@ -306,13 +309,15 @@ def add_import(project, pymodule, module_name, name=None): # from mod import name if name is not None: from_import = FromImport(module_name, 0, [(name, None)]) + if preferred_import_style == ImportStyle.from_global: + selected_import = from_import names.append(name) candidates.append(from_import) # from pkg import mod if "." in module_name: pkg, mod = module_name.rsplit(".", 1) from_import = FromImport(pkg, 0, [(mod, None)]) - if project.prefs.get("prefer_module_from_imports"): + if preferred_import_style == ImportStyle.from_module: selected_import = from_import candidates.append(from_import) if name: diff --git a/ropetest/refactor/movetest.py b/ropetest/refactor/movetest.py index 804fa688..1bf07655 100644 --- a/ropetest/refactor/movetest.py +++ b/ropetest/refactor/movetest.py @@ -254,6 +254,75 @@ def a_function(): self.mod3.read(), ) + def test_adding_imports_preferred_import_style_is_normal_import(self) -> None: + self.project.prefs.imports.preferred_import_style = "normal-import" + self.origin_module.write(dedent("""\ + class AClass(object): + pass + def a_function(): + pass + """)) + self.mod3.write(dedent("""\ + import origin_module + a_var = origin_module.AClass() + origin_module.a_function()""")) + # Move to destination_module_in_pkg which is in a different package + self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg) + self.assertEqual( + dedent("""\ + import origin_module + import pkg.destination_module_in_pkg + a_var = pkg.destination_module_in_pkg.AClass() + origin_module.a_function()"""), + self.mod3.read(), + ) + + def test_adding_imports_preferred_import_style_is_from_module(self) -> None: + self.project.prefs.imports.preferred_import_style = "from-module" + self.origin_module.write(dedent("""\ + class AClass(object): + pass + def a_function(): + pass + """)) + self.mod3.write(dedent("""\ + import origin_module + a_var = origin_module.AClass() + origin_module.a_function()""")) + # Move to destination_module_in_pkg which is in a different package + self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg) + self.assertEqual( + dedent("""\ + import origin_module + from pkg import destination_module_in_pkg + a_var = destination_module_in_pkg.AClass() + origin_module.a_function()"""), + self.mod3.read(), + ) + + def test_adding_imports_preferred_import_style_is_from_global(self) -> None: + self.project.prefs.imports.preferred_import_style = "from-global" + self.origin_module.write(dedent("""\ + class AClass(object): + pass + def a_function(): + pass + """)) + self.mod3.write(dedent("""\ + import origin_module + a_var = origin_module.AClass() + origin_module.a_function()""")) + # Move to destination_module_in_pkg which is in a different package + self._move(self.origin_module, self.origin_module.read().index("AClass") + 1, self.destination_module_in_pkg) + self.assertEqual( + dedent("""\ + import origin_module + from pkg.destination_module_in_pkg import AClass + a_var = AClass() + origin_module.a_function()"""), + self.mod3.read(), + ) + def test_adding_imports_noprefer_from_module(self) -> None: self.project.prefs["prefer_module_from_imports"] = False self.origin_module.write(dedent("""\