Skip to content

Commit

Permalink
Make config related paths dynamical. Fixes yt-project#3104
Browse files Browse the repository at this point in the history
  • Loading branch information
Xarthisius committed Feb 27, 2021
1 parent 10221b0 commit 49bfe83
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 69 deletions.
32 changes: 19 additions & 13 deletions yt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,21 @@
)


CONFIG_DIR = os.environ.get(
"XDG_CONFIG_HOME", os.path.join(os.path.expanduser("~"), ".config", "yt")
)
if not os.path.exists(CONFIG_DIR):
try:
os.makedirs(CONFIG_DIR)
except OSError:
warnings.warn("unable to create yt config directory")
def config_dir():
conf_dir = os.environ.get(
"XDG_CONFIG_HOME", os.path.join(os.path.expanduser("~"), ".config", "yt")
)

if not os.path.exists(conf_dir):
try:
os.makedirs(conf_dir)
except OSError:
warnings.warn("unable to create yt config directory")
return conf_dir


def old_config_file():
return os.path.join(config_dir(), "ytrc")


class YTConfig:
Expand Down Expand Up @@ -168,21 +175,20 @@ def write(self, file_handler):

@staticmethod
def get_global_config_file():
return os.path.join(CONFIG_DIR, "yt.toml")
return os.path.join(config_dir(), "yt.toml")

@staticmethod
def get_local_config_file():
return os.path.join(os.path.abspath(os.curdir), "yt.toml")


OLD_CONFIG_FILE = os.path.join(CONFIG_DIR, "ytrc")
_global_config_file = YTConfig.get_global_config_file()
_local_config_file = YTConfig.get_local_config_file()

if os.path.exists(OLD_CONFIG_FILE):
if os.path.exists(old_config_file()):
if os.path.exists(_global_config_file):
issue_deprecation_warning(
f"The configuration file {OLD_CONFIG_FILE} is deprecated in "
f"The configuration file {old_config_file()} is deprecated in "
f"favor of {_global_config_file}. Currently, both are present. "
"Please manually remove the deprecated one to silence "
"this warning.",
Expand All @@ -203,7 +209,7 @@ def get_local_config_file():
stack = inspect.stack()
if len(stack) < 2 or stack[-2].function != "importlib_load_entry_point":
issue_deprecation_warning(
f"The configuration file {OLD_CONFIG_FILE} is deprecated. "
f"The configuration file {old_config_file()} is deprecated. "
f"Please migrate your config to {_global_config_file} by running: "
"'yt config migrate'",
since="4.0.0",
Expand Down
22 changes: 11 additions & 11 deletions yt/fields/tests/test_fields_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest

import yt
from yt.config import CONFIG_DIR, ytcfg
from yt.config import config_dir, ytcfg
from yt.testing import assert_raises, fake_random_ds

TEST_PLUGIN_FILE = """
Expand All @@ -22,14 +22,14 @@ def myfunc():
def setUpModule():
my_plugin_name = ytcfg.get("yt", "plugin_filename")
# In the following order if plugin_filename is: an absolute path, located in
# the CONFIG_DIR, located in an obsolete config dir.
# the config_dir(), located in an obsolete config dir.
old_config_dir = os.path.join(os.path.expanduser("~"), ".yt")
for base_prefix in ("", CONFIG_DIR, old_config_dir):
for base_prefix in ("", config_dir(), old_config_dir):
potential_plugin_file = os.path.join(base_prefix, my_plugin_name)
if os.path.isfile(potential_plugin_file):
os.rename(potential_plugin_file, potential_plugin_file + ".bak_test")

plugin_file = os.path.join(CONFIG_DIR, my_plugin_name)
plugin_file = os.path.join(config_dir(), my_plugin_name)
with open(plugin_file, "w") as fh:
fh.write(TEST_PLUGIN_FILE)

Expand All @@ -39,7 +39,7 @@ def tearDownModule():

my_plugins_fields.clear()
my_plugin_name = ytcfg.get("yt", "plugin_filename")
plugin_file = os.path.join(CONFIG_DIR, my_plugin_name)
plugin_file = os.path.join(config_dir(), my_plugin_name)
os.remove(plugin_file)


Expand All @@ -48,32 +48,32 @@ class TestPluginFile(unittest.TestCase):
def setUpClass(cls):
my_plugin_name = ytcfg.get("yt", "plugin_filename")
# In the following order if plugin_filename is: an absolute path, located in
# the CONFIG_DIR, located in an obsolete config dir.
# the config_dir(), located in an obsolete config dir.
old_config_dir = os.path.join(os.path.expanduser("~"), ".yt")
for base_prefix in ("", CONFIG_DIR, old_config_dir):
for base_prefix in ("", config_dir(), old_config_dir):
potential_plugin_file = os.path.join(base_prefix, my_plugin_name)
if os.path.isfile(potential_plugin_file):
os.rename(potential_plugin_file, potential_plugin_file + ".bak_test")

plugin_file = os.path.join(CONFIG_DIR, my_plugin_name)
plugin_file = os.path.join(config_dir(), my_plugin_name)
with open(plugin_file, "w") as fh:
fh.write(TEST_PLUGIN_FILE)

@classmethod
def tearDownClass(cls):
my_plugin_name = ytcfg.get("yt", "plugin_filename")
plugin_file = os.path.join(CONFIG_DIR, my_plugin_name)
plugin_file = os.path.join(config_dir(), my_plugin_name)
os.remove(plugin_file)

old_config_dir = os.path.join(os.path.expanduser("~"), ".yt")
for base_prefix in ("", CONFIG_DIR, old_config_dir):
for base_prefix in ("", config_dir(), old_config_dir):
potential_plugin_file = os.path.join(base_prefix, my_plugin_name)
if os.path.isfile(potential_plugin_file + ".bak_test"):
os.rename(potential_plugin_file + ".bak_test", potential_plugin_file)
del yt.myfunc

def testCustomField(self):
plugin_file = os.path.join(CONFIG_DIR, ytcfg.get("yt", "plugin_filename"))
plugin_file = os.path.join(config_dir(), ytcfg.get("yt", "plugin_filename"))
msg = f"INFO:yt:Loading plugins from {plugin_file}"

with self.assertLogs("yt", level="INFO") as cm:
Expand Down
6 changes: 3 additions & 3 deletions yt/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def enable_plugins(plugin_filename=None):
file is shared with it.
"""
import yt
from yt.config import CONFIG_DIR, ytcfg
from yt.config import config_dir, ytcfg
from yt.fields.my_plugin_fields import my_plugins_fields

if plugin_filename is not None:
Expand All @@ -959,7 +959,7 @@ def enable_plugins(plugin_filename=None):
# - obsolete config dir.
my_plugin_name = ytcfg.get("yt", "plugin_filename")
old_config_dir = os.path.join(os.path.expanduser("~"), ".yt")
for base_prefix in ("", CONFIG_DIR, old_config_dir):
for base_prefix in ("", config_dir(), old_config_dir):
if os.path.isfile(os.path.join(base_prefix, my_plugin_name)):
_fn = os.path.join(base_prefix, my_plugin_name)
break
Expand All @@ -971,7 +971,7 @@ def enable_plugins(plugin_filename=None):
"Your plugin file is located in a deprecated directory. "
"Please move it from %s to %s",
os.path.join(old_config_dir, my_plugin_name),
os.path.join(CONFIG_DIR, my_plugin_name),
os.path.join(config_dir(), my_plugin_name),
)

mylog.info("Loading plugins from %s", _fn)
Expand Down
10 changes: 5 additions & 5 deletions yt/utilities/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sys

from yt.config import OLD_CONFIG_FILE, YTConfig, ytcfg_defaults
from yt.config import YTConfig, old_config_file, ytcfg_defaults

CONFIG = YTConfig()

Expand Down Expand Up @@ -48,15 +48,15 @@ def write_config(config_file):


def migrate_config():
if not os.path.exists(OLD_CONFIG_FILE):
if not os.path.exists(old_config_file()):
print("Old config not found.")
sys.exit(1)

old_config = configparser.RawConfigParser()
# Preserve case:
# See https://stackoverflow.com/questions/1611799/preserve-case-in-configparser
old_config.optionxform = str
old_config.read(OLD_CONFIG_FILE)
old_config.read(old_config_file())

# In order to migrate, we'll convert everything to lowercase, and map that
# to the new snake_case convention
Expand Down Expand Up @@ -93,8 +93,8 @@ def usesCamelCase(key):
global_config_file = YTConfig.get_global_config_file()
print(f"Writing a new config file to: {global_config_file}")
write_config(global_config_file)
print(f"Backing up the old config file: {OLD_CONFIG_FILE}.bak")
os.rename(OLD_CONFIG_FILE, OLD_CONFIG_FILE + ".bak")
print(f"Backing up the old config file: {old_config_file()}.bak")
os.rename(old_config_file(), old_config_file() + ".bak")


def rm_config(section, option, config_file):
Expand Down
72 changes: 35 additions & 37 deletions yt/utilities/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import contextlib
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
from io import StringIO

import yt.config
import yt.utilities.command_line
from yt.config import OLD_CONFIG_FILE, YTConfig
from yt.config import YTConfig, old_config_file

GLOBAL_CONFIG_FILE = YTConfig.get_global_config_file()
LOCAL_CONFIG_FILE = YTConfig.get_local_config_file()
XDG_CONFIG_HOME = os.environ.get("XDG_CONFIG_HOME")

_TEST_PLUGIN = "_test_plugin.py"
# NOTE: the normalization of the crazy camel-case will be checked
Expand Down Expand Up @@ -46,24 +47,18 @@ class SysExitException(Exception):
pass


def setUpModule():
for cfgfile in (GLOBAL_CONFIG_FILE, OLD_CONFIG_FILE, LOCAL_CONFIG_FILE):
if os.path.exists(cfgfile):
os.rename(cfgfile, cfgfile + ".bak_test")

if cfgfile == GLOBAL_CONFIG_FILE:
yt.utilities.configure.CONFIG = YTConfig()
if not yt.utilities.configure.CONFIG.has_section("yt"):
yt.utilities.configure.CONFIG.add_section("yt")


def tearDownModule():
for cfgfile in (GLOBAL_CONFIG_FILE, OLD_CONFIG_FILE, LOCAL_CONFIG_FILE):
if os.path.exists(cfgfile + ".bak_test"):
os.rename(cfgfile + ".bak_test", cfgfile)
class TestYTConfig(unittest.TestCase):
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
os.environ["XDG_CONFIG_HOME"] = self.tmpdir

def tearDown(self):
shutil.rmtree(self.tmpdir)
if XDG_CONFIG_HOME:
os.environ["XDG_CONFIG_HOME"] = XDG_CONFIG_HOME
else:
os.environ.pop("XDG_CONFIG_HOME")

class TestYTConfig(unittest.TestCase):
def _runYTConfig(self, args):
args = ["yt", "config"] + args
retcode = 0
Expand Down Expand Up @@ -108,7 +103,7 @@ def testConfigCommands(self):
def remove_spaces_and_breaks(s):
return "".join(s.split())

self.assertFalse(os.path.exists(GLOBAL_CONFIG_FILE))
self.assertFalse(os.path.exists(YTConfig.get_global_config_file()))

info = self._runYTConfig(["--help"])
self.assertEqual(info["rc"], 0)
Expand Down Expand Up @@ -142,12 +137,13 @@ def remove_spaces_and_breaks(s):
self._testKeyTypeError("foo.bar", "foo", "bar", expect_error=False)

def tearDown(self):
if os.path.exists(GLOBAL_CONFIG_FILE):
os.remove(GLOBAL_CONFIG_FILE)
if os.path.exists(YTConfig.get_global_config_file()):
os.remove(YTConfig.get_global_config_file())


class TestYTConfigGlobalLocal(TestYTConfig):
def setUp(self):
super().setUp()
with open(YTConfig.get_local_config_file(), mode="w") as f:
f.writelines("[yt]\n")
with open(YTConfig.get_global_config_file(), mode="w") as f:
Expand All @@ -164,33 +160,35 @@ def testAmbiguousConfig(self):

class TestYTConfigMigration(TestYTConfig):
def setUp(self):
if not os.path.exists(os.path.dirname(OLD_CONFIG_FILE)):
os.makedirs(os.path.dirname(OLD_CONFIG_FILE))
super().setUp()
if not os.path.exists(os.path.dirname(old_config_file())):
os.makedirs(os.path.dirname(old_config_file()))

with open(OLD_CONFIG_FILE, "w") as fh:
with open(old_config_file(), "w") as fh:
fh.write(_DUMMY_CFG_INI)

if os.path.exists(GLOBAL_CONFIG_FILE):
os.remove(GLOBAL_CONFIG_FILE)
if os.path.exists(YTConfig.get_global_config_file()):
os.remove(YTConfig.get_global_config_file())

def tearDown(self):
if os.path.exists(GLOBAL_CONFIG_FILE):
os.remove(GLOBAL_CONFIG_FILE)
if os.path.exists(OLD_CONFIG_FILE + ".bak"):
os.remove(OLD_CONFIG_FILE + ".bak")
if os.path.exists(YTConfig.get_global_config_file()):
os.remove(YTConfig.get_global_config_file())
if os.path.exists(old_config_file() + ".bak"):
os.remove(old_config_file() + ".bak")
super().tearDown()

def testConfigMigration(self):
self.assertFalse(os.path.exists(GLOBAL_CONFIG_FILE))
self.assertTrue(os.path.exists(OLD_CONFIG_FILE))
self.assertFalse(os.path.exists(YTConfig.get_global_config_file()))
self.assertTrue(os.path.exists(old_config_file()))

info = self._runYTConfig(["migrate"])
self.assertEqual(info["rc"], 0)

self.assertTrue(os.path.exists(GLOBAL_CONFIG_FILE))
self.assertFalse(os.path.exists(OLD_CONFIG_FILE))
self.assertTrue(os.path.exists(OLD_CONFIG_FILE + ".bak"))
self.assertTrue(os.path.exists(YTConfig.get_global_config_file()))
self.assertFalse(os.path.exists(old_config_file()))
self.assertTrue(os.path.exists(old_config_file() + ".bak"))

with open(GLOBAL_CONFIG_FILE, mode="r") as fh:
with open(YTConfig.get_global_config_file(), mode="r") as fh:
new_cfg = fh.read()

self.assertEqual(new_cfg, _DUMMY_CFG_TOML)

0 comments on commit 49bfe83

Please sign in to comment.