Skip to content

Commit

Permalink
Tune cast of config keys on read/CLI set
Browse files Browse the repository at this point in the history
  • Loading branch information
khaeru committed Feb 8, 2022
1 parent e2fdc77 commit 6ec7c4a
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions ixmp/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import copy
from dataclasses import asdict, dataclass, field, fields, make_dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -74,9 +74,7 @@ def _platform_default():
class BaseValues:
"""Base class for storing configuration values."""

platform: Dict[str, Union[str, Dict[str, Any]]] = field(
default_factory=_platform_default
)
platform: dict = field(default_factory=_platform_default)

def __getitem__(self, name):
return getattr(self, name.replace(" ", "_"))
Expand Down Expand Up @@ -127,17 +125,16 @@ def keys(self) -> Tuple[str, ...]:
return tuple(map(lambda f: f.name.replace("_", " "), fields(self)))

def set(self, name: str, value: Any, strict: bool = True):
f = self.get_field(name)
if strict and f is None:
raise KeyError(name)

# Retrieve the type for `name`; or None if unregistered
f = self.get_field(name) or None
type_ = getattr(f, "type", None)

try:
# Attempt to cast to the correct type, if any
if type_ and strict:
value = type_(value)
except TypeError:
# `strict` but `name` is not registered; tried to call None(value)
raise KeyError(name)
value = type_(value) if type_ else value
except Exception:
raise TypeError(
f"expected {type_} for {repr(name)}; got {repr(value)} ({type(value)})"
Expand Down Expand Up @@ -244,13 +241,18 @@ def read(self):
contents = config_path.read_text()

try:
# Parse JSON and set values; _strict=False tolerates unregistered values
for key, value in json.loads(contents).items():
self.set(key, value, _strict=False)
data = json.loads(contents)
except json.JSONDecodeError:
print(config_path, contents)
raise

# Parse JSON and set values
for key, value in data.items():
try:
self.set(key, value, _strict=True) # Cast type for registered keys
except KeyError:
self.set(key, value, _strict=False) # Tolerate unregistered keys

# Public methods

def get(self, name: str) -> Any:
Expand Down

0 comments on commit 6ec7c4a

Please sign in to comment.