diff --git a/yt/config.py b/yt/config.py index 37c8b27ba77..2bd63266ef4 100644 --- a/yt/config.py +++ b/yt/config.py @@ -67,14 +67,22 @@ ) -CONFIG_DIR = os.environ.get( - "XDG_CONFIG_HOME", os.path.join(os.path.expanduser("~"), ".config", "yt") -) -if not os.path.exists(CONFIG_DIR): - try: - os.makedirs(CONFIG_DIR) - except OSError: - warnings.warn("unable to create yt config directory") +def config_dir(): + config_root = os.environ.get( + "XDG_CONFIG_HOME", os.path.join(os.path.expanduser("~"), ".config") + ) + conf_dir = os.path.join(config_root, "yt") + + if not os.path.exists(conf_dir): + try: + os.makedirs(conf_dir) + except OSError: + warnings.warn("unable to create yt config directory") + return conf_dir + + +def old_config_file(): + return os.path.join(config_dir(), "ytrc") class YTConfig: @@ -168,21 +176,20 @@ def write(self, file_handler): @staticmethod def get_global_config_file(): - return os.path.join(CONFIG_DIR, "yt.toml") + return os.path.join(config_dir(), "yt.toml") @staticmethod def get_local_config_file(): return os.path.join(os.path.abspath(os.curdir), "yt.toml") -OLD_CONFIG_FILE = os.path.join(CONFIG_DIR, "ytrc") _global_config_file = YTConfig.get_global_config_file() _local_config_file = YTConfig.get_local_config_file() -if os.path.exists(OLD_CONFIG_FILE): +if os.path.exists(old_config_file()): if os.path.exists(_global_config_file): issue_deprecation_warning( - f"The configuration file {OLD_CONFIG_FILE} is deprecated in " + f"The configuration file {old_config_file()} is deprecated in " f"favor of {_global_config_file}. Currently, both are present. " "Please manually remove the deprecated one to silence " "this warning.", @@ -203,7 +210,7 @@ def get_local_config_file(): stack = inspect.stack() if len(stack) < 2 or stack[-2].function != "importlib_load_entry_point": issue_deprecation_warning( - f"The configuration file {OLD_CONFIG_FILE} is deprecated. " + f"The configuration file {old_config_file()} is deprecated. " f"Please migrate your config to {_global_config_file} by running: " "'yt config migrate'", since="4.0.0", @@ -225,6 +232,10 @@ def get_local_config_file(): ytcfg = YTConfig() ytcfg.update(ytcfg_defaults, metadata={"source": "defaults"}) +# For backward compatibility, do not use these vars internally in yt +CONFIG_DIR = config_dir() +OLD_CONFIG_FILE = old_config_file() + # Try loading the local config first, otherwise fall back to global config if os.path.exists(_local_config_file): ytcfg.read(_local_config_file) diff --git a/yt/fields/tests/test_fields_plugins.py b/yt/fields/tests/test_fields_plugins.py index 85a89a8885b..b15516fe44c 100644 --- a/yt/fields/tests/test_fields_plugins.py +++ b/yt/fields/tests/test_fields_plugins.py @@ -2,7 +2,7 @@ import unittest import yt -from yt.config import CONFIG_DIR, ytcfg +from yt.config import config_dir, ytcfg from yt.testing import assert_raises, fake_random_ds TEST_PLUGIN_FILE = """ @@ -22,14 +22,14 @@ def myfunc(): def setUpModule(): my_plugin_name = ytcfg.get("yt", "plugin_filename") # In the following order if plugin_filename is: an absolute path, located in - # the CONFIG_DIR, located in an obsolete config dir. + # the config_dir(), located in an obsolete config dir. old_config_dir = os.path.join(os.path.expanduser("~"), ".yt") - for base_prefix in ("", CONFIG_DIR, old_config_dir): + for base_prefix in ("", config_dir(), old_config_dir): potential_plugin_file = os.path.join(base_prefix, my_plugin_name) if os.path.isfile(potential_plugin_file): os.rename(potential_plugin_file, potential_plugin_file + ".bak_test") - plugin_file = os.path.join(CONFIG_DIR, my_plugin_name) + plugin_file = os.path.join(config_dir(), my_plugin_name) with open(plugin_file, "w") as fh: fh.write(TEST_PLUGIN_FILE) @@ -39,7 +39,7 @@ def tearDownModule(): my_plugins_fields.clear() my_plugin_name = ytcfg.get("yt", "plugin_filename") - plugin_file = os.path.join(CONFIG_DIR, my_plugin_name) + plugin_file = os.path.join(config_dir(), my_plugin_name) os.remove(plugin_file) @@ -48,32 +48,32 @@ class TestPluginFile(unittest.TestCase): def setUpClass(cls): my_plugin_name = ytcfg.get("yt", "plugin_filename") # In the following order if plugin_filename is: an absolute path, located in - # the CONFIG_DIR, located in an obsolete config dir. + # the config_dir(), located in an obsolete config dir. old_config_dir = os.path.join(os.path.expanduser("~"), ".yt") - for base_prefix in ("", CONFIG_DIR, old_config_dir): + for base_prefix in ("", config_dir(), old_config_dir): potential_plugin_file = os.path.join(base_prefix, my_plugin_name) if os.path.isfile(potential_plugin_file): os.rename(potential_plugin_file, potential_plugin_file + ".bak_test") - plugin_file = os.path.join(CONFIG_DIR, my_plugin_name) + plugin_file = os.path.join(config_dir(), my_plugin_name) with open(plugin_file, "w") as fh: fh.write(TEST_PLUGIN_FILE) @classmethod def tearDownClass(cls): my_plugin_name = ytcfg.get("yt", "plugin_filename") - plugin_file = os.path.join(CONFIG_DIR, my_plugin_name) + plugin_file = os.path.join(config_dir(), my_plugin_name) os.remove(plugin_file) old_config_dir = os.path.join(os.path.expanduser("~"), ".yt") - for base_prefix in ("", CONFIG_DIR, old_config_dir): + for base_prefix in ("", config_dir(), old_config_dir): potential_plugin_file = os.path.join(base_prefix, my_plugin_name) if os.path.isfile(potential_plugin_file + ".bak_test"): os.rename(potential_plugin_file + ".bak_test", potential_plugin_file) del yt.myfunc def testCustomField(self): - plugin_file = os.path.join(CONFIG_DIR, ytcfg.get("yt", "plugin_filename")) + plugin_file = os.path.join(config_dir(), ytcfg.get("yt", "plugin_filename")) msg = f"INFO:yt:Loading plugins from {plugin_file}" with self.assertLogs("yt", level="INFO") as cm: diff --git a/yt/funcs.py b/yt/funcs.py index e5468d16cac..9d83f4bffe6 100644 --- a/yt/funcs.py +++ b/yt/funcs.py @@ -945,7 +945,7 @@ def enable_plugins(plugin_filename=None): file is shared with it. """ import yt - from yt.config import CONFIG_DIR, ytcfg + from yt.config import config_dir, ytcfg from yt.fields.my_plugin_fields import my_plugins_fields if plugin_filename is not None: @@ -959,7 +959,7 @@ def enable_plugins(plugin_filename=None): # - obsolete config dir. my_plugin_name = ytcfg.get("yt", "plugin_filename") old_config_dir = os.path.join(os.path.expanduser("~"), ".yt") - for base_prefix in ("", CONFIG_DIR, old_config_dir): + for base_prefix in ("", config_dir(), old_config_dir): if os.path.isfile(os.path.join(base_prefix, my_plugin_name)): _fn = os.path.join(base_prefix, my_plugin_name) break @@ -971,7 +971,7 @@ def enable_plugins(plugin_filename=None): "Your plugin file is located in a deprecated directory. " "Please move it from %s to %s", os.path.join(old_config_dir, my_plugin_name), - os.path.join(CONFIG_DIR, my_plugin_name), + os.path.join(config_dir(), my_plugin_name), ) mylog.info("Loading plugins from %s", _fn) diff --git a/yt/utilities/configure.py b/yt/utilities/configure.py index e2f661ed11e..5ae73ec1ce4 100644 --- a/yt/utilities/configure.py +++ b/yt/utilities/configure.py @@ -2,7 +2,7 @@ import os import sys -from yt.config import OLD_CONFIG_FILE, YTConfig, ytcfg_defaults +from yt.config import YTConfig, old_config_file, ytcfg_defaults CONFIG = YTConfig() @@ -48,7 +48,7 @@ def write_config(config_file): def migrate_config(): - if not os.path.exists(OLD_CONFIG_FILE): + if not os.path.exists(old_config_file()): print("Old config not found.") sys.exit(1) @@ -56,7 +56,7 @@ def migrate_config(): # Preserve case: # See https://stackoverflow.com/questions/1611799/preserve-case-in-configparser old_config.optionxform = str - old_config.read(OLD_CONFIG_FILE) + old_config.read(old_config_file()) # In order to migrate, we'll convert everything to lowercase, and map that # to the new snake_case convention @@ -93,8 +93,8 @@ def usesCamelCase(key): global_config_file = YTConfig.get_global_config_file() print(f"Writing a new config file to: {global_config_file}") write_config(global_config_file) - print(f"Backing up the old config file: {OLD_CONFIG_FILE}.bak") - os.rename(OLD_CONFIG_FILE, OLD_CONFIG_FILE + ".bak") + print(f"Backing up the old config file: {old_config_file()}.bak") + os.rename(old_config_file(), old_config_file() + ".bak") def rm_config(section, option, config_file): diff --git a/yt/utilities/tests/test_config.py b/yt/utilities/tests/test_config.py index f84a4b0b96b..f06b182b519 100644 --- a/yt/utilities/tests/test_config.py +++ b/yt/utilities/tests/test_config.py @@ -1,16 +1,17 @@ import contextlib import os +import shutil import sys +import tempfile import unittest import unittest.mock as mock from io import StringIO import yt.config import yt.utilities.command_line -from yt.config import OLD_CONFIG_FILE, YTConfig +from yt.config import YTConfig, old_config_file -GLOBAL_CONFIG_FILE = YTConfig.get_global_config_file() -LOCAL_CONFIG_FILE = YTConfig.get_local_config_file() +XDG_CONFIG_HOME = os.environ.get("XDG_CONFIG_HOME") _TEST_PLUGIN = "_test_plugin.py" # NOTE: the normalization of the crazy camel-case will be checked @@ -46,24 +47,18 @@ class SysExitException(Exception): pass -def setUpModule(): - for cfgfile in (GLOBAL_CONFIG_FILE, OLD_CONFIG_FILE, LOCAL_CONFIG_FILE): - if os.path.exists(cfgfile): - os.rename(cfgfile, cfgfile + ".bak_test") - - if cfgfile == GLOBAL_CONFIG_FILE: - yt.utilities.configure.CONFIG = YTConfig() - if not yt.utilities.configure.CONFIG.has_section("yt"): - yt.utilities.configure.CONFIG.add_section("yt") - - -def tearDownModule(): - for cfgfile in (GLOBAL_CONFIG_FILE, OLD_CONFIG_FILE, LOCAL_CONFIG_FILE): - if os.path.exists(cfgfile + ".bak_test"): - os.rename(cfgfile + ".bak_test", cfgfile) +class TestYTConfig(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + os.environ["XDG_CONFIG_HOME"] = self.tmpdir + def tearDown(self): + shutil.rmtree(self.tmpdir) + if XDG_CONFIG_HOME: + os.environ["XDG_CONFIG_HOME"] = XDG_CONFIG_HOME + else: + os.environ.pop("XDG_CONFIG_HOME") -class TestYTConfig(unittest.TestCase): def _runYTConfig(self, args): args = ["yt", "config"] + args retcode = 0 @@ -108,7 +103,7 @@ def testConfigCommands(self): def remove_spaces_and_breaks(s): return "".join(s.split()) - self.assertFalse(os.path.exists(GLOBAL_CONFIG_FILE)) + self.assertFalse(os.path.exists(YTConfig.get_global_config_file())) info = self._runYTConfig(["--help"]) self.assertEqual(info["rc"], 0) @@ -142,12 +137,13 @@ def remove_spaces_and_breaks(s): self._testKeyTypeError("foo.bar", "foo", "bar", expect_error=False) def tearDown(self): - if os.path.exists(GLOBAL_CONFIG_FILE): - os.remove(GLOBAL_CONFIG_FILE) + if os.path.exists(YTConfig.get_global_config_file()): + os.remove(YTConfig.get_global_config_file()) class TestYTConfigGlobalLocal(TestYTConfig): def setUp(self): + super().setUp() with open(YTConfig.get_local_config_file(), mode="w") as f: f.writelines("[yt]\n") with open(YTConfig.get_global_config_file(), mode="w") as f: @@ -164,33 +160,35 @@ def testAmbiguousConfig(self): class TestYTConfigMigration(TestYTConfig): def setUp(self): - if not os.path.exists(os.path.dirname(OLD_CONFIG_FILE)): - os.makedirs(os.path.dirname(OLD_CONFIG_FILE)) + super().setUp() + if not os.path.exists(os.path.dirname(old_config_file())): + os.makedirs(os.path.dirname(old_config_file())) - with open(OLD_CONFIG_FILE, "w") as fh: + with open(old_config_file(), "w") as fh: fh.write(_DUMMY_CFG_INI) - if os.path.exists(GLOBAL_CONFIG_FILE): - os.remove(GLOBAL_CONFIG_FILE) + if os.path.exists(YTConfig.get_global_config_file()): + os.remove(YTConfig.get_global_config_file()) def tearDown(self): - if os.path.exists(GLOBAL_CONFIG_FILE): - os.remove(GLOBAL_CONFIG_FILE) - if os.path.exists(OLD_CONFIG_FILE + ".bak"): - os.remove(OLD_CONFIG_FILE + ".bak") + if os.path.exists(YTConfig.get_global_config_file()): + os.remove(YTConfig.get_global_config_file()) + if os.path.exists(old_config_file() + ".bak"): + os.remove(old_config_file() + ".bak") + super().tearDown() def testConfigMigration(self): - self.assertFalse(os.path.exists(GLOBAL_CONFIG_FILE)) - self.assertTrue(os.path.exists(OLD_CONFIG_FILE)) + self.assertFalse(os.path.exists(YTConfig.get_global_config_file())) + self.assertTrue(os.path.exists(old_config_file())) info = self._runYTConfig(["migrate"]) self.assertEqual(info["rc"], 0) - self.assertTrue(os.path.exists(GLOBAL_CONFIG_FILE)) - self.assertFalse(os.path.exists(OLD_CONFIG_FILE)) - self.assertTrue(os.path.exists(OLD_CONFIG_FILE + ".bak")) + self.assertTrue(os.path.exists(YTConfig.get_global_config_file())) + self.assertFalse(os.path.exists(old_config_file())) + self.assertTrue(os.path.exists(old_config_file() + ".bak")) - with open(GLOBAL_CONFIG_FILE, mode="r") as fh: + with open(YTConfig.get_global_config_file(), mode="r") as fh: new_cfg = fh.read() self.assertEqual(new_cfg, _DUMMY_CFG_TOML)