diff --git a/cookie_composer/cc_overrides.py b/cookie_composer/cc_overrides.py new file mode 100644 index 0000000..df82f77 --- /dev/null +++ b/cookie_composer/cc_overrides.py @@ -0,0 +1,151 @@ +"""This overrides the default cookie cutter environment.""" +from typing import Any + +import json + +from cookiecutter.environment import StrictEnvironment +from cookiecutter.exceptions import UndefinedVariableInTemplate +from cookiecutter.prompt import ( + prompt_choice_for_config, + read_user_dict, + read_user_variable, + render_variable, +) +from jinja2 import UndefinedError +from jinja2.ext import Extension + +from cookie_composer.data_merge import Context + + +def jsonify_context(value: Any) -> dict: + """Convert a ``Context`` to a dict.""" + if isinstance(value, Context): + return value.flatten() + + raise TypeError() + + +class JsonifyContextExtension(Extension): + """Jinja2 extension to convert a Python object to JSON.""" + + def __init__(self, environment): + """Initialize the extension with the given environment.""" + super().__init__(environment) + + def jsonify(obj): + return json.dumps(obj, sort_keys=True, indent=4, default=jsonify_context) + + environment.filters["jsonify"] = jsonify + + +class CustomStrictEnvironment(StrictEnvironment): + """ + Create strict Jinja2 environment. + + Jinja2 environment will raise error on undefined variable in template-rendering context. + + Does not expect all the context to be under the `cookiecutter` key. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if "cookiecutter.extensions.JsonifyExtension" in self.extensions: + del self.extensions["cookiecutter.extensions.JsonifyExtension"] + self.add_extension("cookie_composer.cc_overrides.JsonifyContextExtension") + + def _read_extensions(self, context) -> list[str]: + """ + Return list of extensions as str to be passed on to the Jinja2 env. + + If context does not contain the relevant info, return an empty + list instead. + + Args: + context: A ``dict`` possibly containing the ``_extensions`` key + + Returns: + List of extensions as str to be passed on to the Jinja2 env + """ + return [str(ext) for ext in context.get("_extensions", [])] + + +def prompt_for_config(prompts: dict, existing_config: Context, no_input=False) -> Context: + """ + Prompt user to enter a new config using an existing config as a basis. + + Will not prompt for configurations already in the existing configuration. + + Prompts can refer to items in the existing config. + + Args: + prompts: A dictionary of configuration prompts and default values + existing_config: An existing configuration to use as a basis + no_input: If ``True`` Don't prompt the user at command line for manual configuration + + Raises: + UndefinedVariableInTemplate: If a variable in a prompt defaults is not in the context + + Returns: + A new configuration context + """ + import copy + + # Make sure we have a fresh layer to populate + if existing_config.is_empty: + context = existing_config + else: + context = existing_config.new_child() + + env = CustomStrictEnvironment(context=existing_config) + + # First pass: Handle simple and raw variables, plus choices. + # These must be done first because the dictionaries keys and + # values might refer to them. + for key, raw in prompts.items(): + if key.startswith("_") and not key.startswith("__"): + context[key] = raw + continue + elif key.startswith("__"): + context[key] = render_variable(env, raw, context) + continue + elif key in context: + context[key] = copy.deepcopy(context.parents[key]) + continue + + try: + if isinstance(raw, list): + # We are dealing with a choice variable + val = prompt_choice_for_config(context, env, key, raw, no_input) + context[key] = val + elif not isinstance(raw, dict): + # We are dealing with a regular variable + val = render_variable(env, raw, context) + + if not no_input: + val = read_user_variable(key, val) + + context[key] = val + except UndefinedError as err: + msg = f"Unable to render variable '{key}'" + raise UndefinedVariableInTemplate(msg, err, context) from err + + # Second pass; handle the dictionaries. + for key, raw in prompts.items(): + # Skip private type dicts not ot be rendered. + if key.startswith("_") and not key.startswith("__"): + continue + + try: + if isinstance(raw, dict): + # We are dealing with a dict variable + val = render_variable(env, raw, context) + + if not no_input and not key.startswith("__"): + val = read_user_dict(key, val) + + context[key] = val + except UndefinedError as err: + msg = f"Unable to render variable '{key}'" + raise UndefinedVariableInTemplate(msg, err, context) + + return context diff --git a/cookie_composer/data_merge.py b/cookie_composer/data_merge.py index bbf325f..fcb4266 100644 --- a/cookie_composer/data_merge.py +++ b/cookie_composer/data_merge.py @@ -2,6 +2,7 @@ from typing import Any, Iterable import copy +from collections import ChainMap, OrderedDict from functools import reduce @@ -10,7 +11,7 @@ def deep_merge(*dicts) -> dict: Merges dicts deeply. Args: - dicts: List of dicts to merge with the first one the base + dicts: List of dicts to merge with the first one as the base Returns: dict: The merged dict @@ -60,28 +61,25 @@ def comprehensive_merge(*args) -> Any: Returns: The merged data - - Raises: - ValueError: If the values are not of the same type """ + dict_types = (dict, OrderedDict) + iterable_types = (list, set, tuple) def merge_into(d1, d2): - if type(d1) != type(d2): - raise ValueError(f"Cannot merge {type(d2)} into {type(d1)}.") + if isinstance(d1, dict_types) and isinstance(d2, dict_types): + if isinstance(d1, OrderedDict) or isinstance(d2, OrderedDict): + d1 = OrderedDict(d1) + d2 = OrderedDict(d2) - if isinstance(d1, list): + for key in d2: + d1[key] = merge_into(d1[key], d2[key]) if key in d1 else copy.deepcopy(d2[key]) + return d1 + elif isinstance(d1, list) and isinstance(d2, iterable_types): return list(merge_iterables(d1, d2)) - elif isinstance(d1, set): + elif isinstance(d1, set) and isinstance(d2, iterable_types): return merge_iterables(d1, d2) - elif isinstance(d1, tuple): + elif isinstance(d1, tuple) and isinstance(d2, iterable_types): return tuple(merge_iterables(d1, d2)) - elif isinstance(d1, dict): - for key in d2: - if key in d1: - d1[key] = merge_into(d1[key], d2[key]) - else: - d1[key] = copy.deepcopy(d2[key]) - return d1 else: return copy.deepcopy(d2) @@ -91,7 +89,20 @@ def merge_into(d1, d2): return reduce(merge_into, args, tuple()) elif isinstance(args[0], set): return reduce(merge_into, args, set()) - elif isinstance(args[0], dict): + elif isinstance(args[0], dict_types): return reduce(merge_into, args, {}) else: return reduce(merge_into, args) + + +class Context(ChainMap): + """Provides merging and convenence functions for managing contexts.""" + + @property + def is_empty(self) -> bool: + """The context has only one mapping and it is empty.""" + return len(self.maps) == 1 and len(self.maps[0]) == 0 + + def flatten(self) -> dict: + """Comprehensively merge all the maps into a single mapping.""" + return reduce(comprehensive_merge, self.maps, {}) diff --git a/cookie_composer/layers.py b/cookie_composer/layers.py index 6899f60..b91209c 100644 --- a/cookie_composer/layers.py +++ b/cookie_composer/layers.py @@ -1,31 +1,38 @@ """Layer management.""" -from typing import List, Mapping, Optional +from typing import List, Optional -import logging import os import shutil import tempfile from enum import Enum from pathlib import Path +import structlog + +# from ._vendor.cookiecutter.config import get_user_config +# from ._vendor.cookiecutter.generate import generate_context, generate_files +# from ._vendor.cookiecutter.prompt import prompt_for_config +# from ._vendor.cookiecutter.repository import determine_repo_dir +# from ._vendor.cookiecutter.utils import rmtree +from cookiecutter.config import get_user_config +from cookiecutter.generate import generate_context, generate_files +from cookiecutter.repository import determine_repo_dir +from cookiecutter.utils import rmtree + +from cookie_composer.cc_overrides import prompt_for_config from cookie_composer.composition import ( DO_NOT_MERGE, LayerConfig, RenderedLayer, get_merge_strategy, ) -from cookie_composer.data_merge import comprehensive_merge +from cookie_composer.data_merge import Context from cookie_composer.matching import matches_any_glob from cookie_composer.merge_files import MERGE_FUNCTIONS -from ._vendor.cookiecutter.config import get_user_config -from ._vendor.cookiecutter.generate import generate_context, generate_files -from ._vendor.cookiecutter.prompt import prompt_for_config -from ._vendor.cookiecutter.repository import determine_repo_dir -from ._vendor.cookiecutter.utils import rmtree from .git_commands import get_latest_template_commit -logger = logging.getLogger(__name__) +logger = structlog.get_logger(__name__) class WriteStrategy(Enum): @@ -84,7 +91,7 @@ def get_write_strategy(origin: Path, destination: Path, rendered_layer: Rendered def render_layer( - layer_config: LayerConfig, render_dir: Path, full_context: Mapping = None, accept_hooks: bool = True + layer_config: LayerConfig, render_dir: Path, full_context: Optional[Context] = None, accept_hooks: bool = True ) -> RenderedLayer: """ Process one layer of the template composition. @@ -100,36 +107,25 @@ def render_layer( Returns: The rendered layer information """ - config_dict = get_user_config(config_file=None, default_config=False) - + full_context = full_context or Context() + user_config = get_user_config(config_file=None, default_config=False) repo_dir, cleanup = determine_repo_dir( template=layer_config.template, - abbreviations=config_dict["abbreviations"], - clone_to_dir=config_dict["cookiecutters_dir"], + abbreviations=user_config["abbreviations"], + clone_to_dir=user_config["cookiecutters_dir"], checkout=layer_config.commit or layer_config.checkout, no_input=layer_config.no_input, password=layer_config.password, directory=layer_config.directory, ) - # _copy_without_render is template-specific and fails if overridden - # So we are going to remove it from the "defaults" when generating the context - config_dict["default_context"].pop("_copy_without_render", None) - if full_context and "_copy_without_render" in full_context: - del full_context["_copy_without_render"] - - context = generate_context( - context_file=Path(repo_dir) / "cookiecutter.json", - default_context=config_dict["default_context"], - extra_context=full_context, - ) - context["cookiecutter"] = prompt_for_config(context, layer_config.no_input) + context = get_layer_context(layer_config, repo_dir, user_config, full_context) layer_config.commit = latest_commit = get_latest_template_commit(repo_dir) # call cookiecutter's generate files function generate_files( repo_dir=repo_dir, - context=context, + context={"cookiecutter": context.flatten()}, overwrite_if_exists=False, output_dir=str(render_dir), accept_hooks=accept_hooks, @@ -138,7 +134,7 @@ def render_layer( rendered_layer = RenderedLayer( layer=layer_config, location=render_dir, - new_context=context["cookiecutter"], + new_context=context.maps[0], latest_commit=latest_commit, ) @@ -148,6 +144,39 @@ def render_layer( return rendered_layer +def get_layer_context( + layer_config: LayerConfig, repo_dir: str, user_config: dict, full_context: Optional[Context] = None +) -> Context: + """ + Get the context for a layer pre-rendering values using previous layers contexts as defaults. + + Args: + layer_config: The configuration for this layer + repo_dir: The directory containing the template's ``cookiecutter.json`` file + user_config: The user's cookiecutter configuration + full_context: A full context from previous layers. + + Returns: + The context for rendering the layer + """ + full_context = full_context or Context() + + # _copy_without_render is template-specific and fails if overridden + # So we are going to remove it from the "defaults" when generating the context + user_config["default_context"].pop("_copy_without_render", None) + # if full_context and "_copy_without_render" in full_context: + # del full_context["_copy_without_render"] + + # This pulls in the template context and overrides the values with the user config defaults + # and the defaults specified in the layer. + prompts = generate_context( + context_file=Path(repo_dir) / "cookiecutter.json", + default_context=user_config["default_context"], + extra_context=layer_config.context or {}, + ) + return prompt_for_config(prompts["cookiecutter"], full_context, layer_config.no_input) + + def render_layers( layers: List[LayerConfig], destination: Path, @@ -168,21 +197,18 @@ def render_layers( Returns: A list of the rendered layer information """ - full_context = initial_context or {} + full_context = Context(initial_context) if initial_context else Context() rendered_layers = [] for layer_config in layers: layer_config.no_input = True if no_input else layer_config.no_input - if layer_config.context: - full_context = comprehensive_merge(full_context, layer_config.context) - with tempfile.TemporaryDirectory() as render_dir: rendered_layer = render_layer(layer_config, render_dir, full_context, accept_hooks) merge_layers(destination, rendered_layer) rendered_layer.layer.commit = rendered_layer.latest_commit rendered_layer.layer.context = rendered_layer.new_context rendered_layers.append(rendered_layer) - full_context = comprehensive_merge(full_context, rendered_layer.new_context) + full_context = full_context.new_child(rendered_layer.new_context) return rendered_layers diff --git a/tests/fixtures/multi-template.yaml b/tests/fixtures/multi-template.yaml index de9e9e6..d81aefa 100644 --- a/tests/fixtures/multi-template.yaml +++ b/tests/fixtures/multi-template.yaml @@ -1,3 +1,5 @@ template: tests/fixtures/template1 --- template: tests/fixtures/template2 +context: + project_slug: "{{ cookiecutter.repo_slug }}" diff --git a/tests/fixtures/rendered_composition.yaml b/tests/fixtures/rendered_composition.yaml new file mode 100644 index 0000000..6129975 --- /dev/null +++ b/tests/fixtures/rendered_composition.yaml @@ -0,0 +1,95 @@ +checkout: null +commit: 3391f60471939dd412e93e9853e3328e7a5a8c44 +context: + _copy_without_render: + - .github/**/*.jinja + _dev_requirements: + bump2version: '>=1.0.1' + generate-changelog: '>=0.7.6' + git-fame: '>=1.12.2' + pip-tools: '' + _prod_requirements: + environs: '>=9.3.5' + _test_requirements: + black: '>=19.10b0' + coverage: '>=6.1.2' + flake8: '>=4.0.1' + pre-commit: '>=2.15.0' + pytest: '>=6.0.0' + pytest-cov: '>=3.0.0' + author: Who am I? + email: whoami@existential-crisis.doom + friendly_name: My Test Project + github_user: whoami + project_name: my-test-project + project_short_description: '' + version: 0.1.0 +directory: cookiecutter-boilerplate +merge_strategies: + '*.json': comprehensive + '*.yaml': comprehensive + '*.yml': comprehensive +no_input: false +overwrite: [] +overwrite_exclude: [] +password: null +skip_generation: [] +skip_hooks: false +skip_if_file_exists: true +template: https://github.com/coordt/cookiecomposer-templates +--- +checkout: null +commit: 3391f60471939dd412e93e9853e3328e7a5a8c44 +context: + friendly_name: My Test Project + project_name: my-test-project + project_short_description: '' + project_slug: my_test_project + version: 0.1.0 +directory: cookiecutter-package +merge_strategies: + '*.json': comprehensive + '*.yaml': comprehensive + '*.yml': comprehensive +no_input: false +overwrite: [] +overwrite_exclude: [] +password: null +skip_generation: [] +skip_hooks: false +skip_if_file_exists: true +template: https://github.com/coordt/cookiecomposer-templates +--- +checkout: null +commit: 3391f60471939dd412e93e9853e3328e7a5a8c44 +context: + _copy_without_render: + - docsrc/**/*.rst + _dev_requirements: {} + _docs_requirements: + Sphinx: '>=4.3.0' + furo: '' + ghp-import: '' + linkify-it-py: '' + myst-parser: '' + sphinx-autodoc-typehints: '' + sphinx-click: '' + sphinx-copybutton: '' + _prod_requirements: {} + friendly_name: My Test Project + github_user: whoami + project_name: my-test-project + project_slug: my_test_project +directory: cookiecutter-docs +merge_strategies: + '*.json': comprehensive + '*.yaml': comprehensive + '*.yml': comprehensive +no_input: false +overwrite: [] +overwrite_exclude: [] +password: null +skip_generation: [] +skip_hooks: false +skip_if_file_exists: true +template: https://github.com/coordt/cookiecomposer-templates diff --git a/tests/fixtures/template1/cookiecutter.json b/tests/fixtures/template1/cookiecutter.json index 23a4e95..3b3c834 100644 --- a/tests/fixtures/template1/cookiecutter.json +++ b/tests/fixtures/template1/cookiecutter.json @@ -1,5 +1,6 @@ { "project_name": "Fake Project Template", "repo_name": "{{ cookiecutter.project_name|lower|replace(' ', '-') }}", + "repo_slug": "{{ cookiecutter.project_name|lower|replace(' ', '-') }}", "_requirements": {"foo": "", "bar": ">=5.0.0"} } diff --git a/tests/fixtures/template2/cookiecutter.json b/tests/fixtures/template2/cookiecutter.json index 9653f87..e3db872 100644 --- a/tests/fixtures/template2/cookiecutter.json +++ b/tests/fixtures/template2/cookiecutter.json @@ -1,5 +1,7 @@ { "project_name": "Fake Project Template", "repo_name": "{{ cookiecutter.project_name|lower|replace(' ', '-') }}", - "_requirements": {"bar": ">=5.0.0", "baz": ""} + "project_slug": "{{ cookiecutter.project_name|lower|replace(' ', '-') }}", + "_requirements": {"bar": ">=5.0.0", "baz": ""}, + "lower_project_name": "{{ cookiecutter.project_name|lower }}" } diff --git a/tests/test_composition.py b/tests/test_composition.py index 337ffd6..2916f69 100644 --- a/tests/test_composition.py +++ b/tests/test_composition.py @@ -6,9 +6,6 @@ from cookie_composer.composition import LayerConfig from cookie_composer.exceptions import MissingCompositionFileError -# TODO: test bad info in the template composition -# TODO: test missing info in the template composition - def test_multiple_templates(fixtures_path): filepath = fixtures_path / "multi-template.yaml" diff --git a/tests/test_layers.py b/tests/test_layers.py index 7c3a217..b385dde 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -2,8 +2,11 @@ import json import os import shutil +from collections import OrderedDict from pathlib import Path +from cookiecutter.config import get_user_config + from cookie_composer import layers from cookie_composer.composition import ( DO_NOT_MERGE, @@ -12,7 +15,7 @@ RenderedLayer, read_composition, ) -from cookie_composer.data_merge import comprehensive_merge +from cookie_composer.data_merge import Context, comprehensive_merge from cookie_composer.git_commands import get_latest_template_commit @@ -22,6 +25,7 @@ def test_render_layer(fixtures_path, tmp_path): rendered_layer = layers.render_layer(layer_conf, tmp_path) expected_context = json.loads(Path(fixtures_path / "template1/cookiecutter.json").read_text()) expected_context["repo_name"] = "fake-project-template" + expected_context["repo_slug"] = "fake-project-template" expected = RenderedLayer( layer=layer_conf, location=tmp_path, @@ -154,15 +158,17 @@ def test_merge_layers(tmp_path, fixtures_path): def test_render_layers(fixtures_path, tmp_path): """Render layers generates a list of rendered layer objects.""" - filepath = fixtures_path / "multi-template.yaml" - comp = read_composition(filepath) + tmpl_layers = [ + LayerConfig(template=str(fixtures_path / "template1")), + LayerConfig(template=str(fixtures_path / "template2")), + ] context1 = json.loads((fixtures_path / "template1" / "cookiecutter.json").read_text()) context2 = json.loads((fixtures_path / "template2" / "cookiecutter.json").read_text()) full_context = comprehensive_merge(context1, context2) - rendered_layers = layers.render_layers(comp.layers, tmp_path, full_context, no_input=True) + rendered_layers = layers.render_layers(tmpl_layers, tmp_path, None, no_input=True) rendered_project = tmp_path / rendered_layers[0].rendered_name - rendered_items = set([item.name for item in os.scandir(rendered_project)]) + rendered_items = {item.name for item in os.scandir(rendered_project)} assert rendered_items == {"ABOUT.md", "README.md", "requirements.txt"} @@ -188,6 +194,7 @@ def test_render_layer_git_template(fixtures_path, tmp_path): rendered_layer = layers.render_layer(layer_conf, render_dir) expected_context = json.loads(Path(fixtures_path / "template1/cookiecutter.json").read_text()) expected_context["repo_name"] = "fake-project-template" + expected_context["repo_slug"] = "fake-project-template" expected = RenderedLayer( layer=layer_conf, location=render_dir, @@ -198,3 +205,54 @@ def test_render_layer_git_template(fixtures_path, tmp_path): assert rendered_layer.latest_commit == latest_sha assert rendered_layer.layer.commit == latest_sha assert {x.name for x in Path(render_dir / "fake-project-template").iterdir()} == {"README.md", "requirements.txt"} + + +def test_get_layer_context(fixtures_path): + repo_dir = str(fixtures_path / "template1") + layer_conf = LayerConfig(template=repo_dir, no_input=True) + user_config = get_user_config(config_file=None, default_config=False) + + context = layers.get_layer_context(layer_conf, repo_dir, user_config) + assert context == Context( + dict( + [ + ("project_name", "Fake Project Template"), + ("repo_name", "fake-project-template"), + ("repo_slug", "fake-project-template"), + ("_requirements", OrderedDict([("foo", ""), ("bar", ">=5.0.0")])), + ] + ) + ) + + +def test_get_layer_context_with_extra(fixtures_path): + repo_dir = str(fixtures_path / "template2") + layer_conf = LayerConfig( + template=repo_dir, context={"project_slug": "{{ cookiecutter.repo_slug }}"}, no_input=True + ) + user_config = get_user_config(config_file=None, default_config=False) + full_context = Context( + { + "project_name": "Fake Project Template2", + "repo_name": "fake-project-template2", + "repo_slug": "fake-project-template-two", + "_requirements": {"foo": "", "bar": ">=5.0.0"}, + } + ) + context = layers.get_layer_context(layer_conf, repo_dir, user_config, full_context) + + assert context == Context( + { + "project_name": "Fake Project Template2", + "repo_name": "fake-project-template2", + "project_slug": "fake-project-template-two", + "_requirements": OrderedDict([("bar", ">=5.0.0"), ("baz", "")]), + "lower_project_name": "fake project template2", + }, + { + "project_name": "Fake Project Template2", + "repo_name": "fake-project-template2", + "repo_slug": "fake-project-template-two", + "_requirements": {"foo": "", "bar": ">=5.0.0"}, + }, + ) diff --git a/tests/test_merge_files_helpers.py b/tests/test_merge_files_helpers.py index d4ebbf7..a2a4a42 100644 --- a/tests/test_merge_files_helpers.py +++ b/tests/test_merge_files_helpers.py @@ -1,6 +1,8 @@ """Test the merge_files.helpers functions.""" from typing import Any +from collections import OrderedDict + import pytest from pytest import param @@ -55,6 +57,16 @@ def test_deepmerge(dict_list: list, expected: dict): id="dict 2 iterable merges with dict 1 iterable", ), param([1, 2], 2, id="scalar 2 overwrites scalar 1"), + param( + [OrderedDict({"first": 1, "second": 2}), {"second": "two", "third": 3}], + OrderedDict({"first": 1, "second": "two", "third": 3}), + id="dict into ordered dict", + ), + param( + [{"first": 1, "second": 2}, OrderedDict({"third": 3})], + OrderedDict({"first": 1, "second": 2, "third": 3}), + id="ordered dict into dict", + ), ], ) def test_comprehensive_merge(args: list, expected: Any): @@ -64,9 +76,35 @@ def test_comprehensive_merge(args: list, expected: Any): assert data_merge.comprehensive_merge(*args) == expected -def test_comprehensive_merge_bad_types(): - """ - Make sure it raises an error if the types are not the same. - """ - with pytest.raises(ValueError): - data_merge.comprehensive_merge([1, 2], (2, 3)) +def test_context_flatten(): + """Should return a merged dict.""" + context = data_merge.Context( + { + "project_name": "Fake Project Template2", + "repo_name": "fake-project-template2", + "project_slug": "fake-project-template-two", + "_requirements": OrderedDict([("bar", ">=5.0.0"), ("baz", "")]), + "lower_project_name": "fake project template2", + }, + { + "project_name": "Fake Project Template2", + "repo_name": "fake-project-template2", + "repo_slug": "fake-project-template-two", + "_requirements": {"foo": "", "bar": ">=5.0.0"}, + }, + ) + expected = { + "project_name": "Fake Project Template2", + "repo_name": "fake-project-template2", + "project_slug": "fake-project-template-two", + "repo_slug": "fake-project-template-two", + "_requirements": OrderedDict( + [ + ("bar", ">=5.0.0"), + ("baz", ""), + ("foo", ""), + ] + ), + "lower_project_name": "fake project template2", + } + assert context.flatten() == expected diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..262b9c9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,52 @@ +"""Tests for the utils module.""" +import pytest +from pytest import param + +from cookie_composer import utils +from cookie_composer.composition import read_rendered_composition + + +def test_get_context_for_layer(fixtures_path): + """Return a context for a given layer.""" + rendered_comp = read_rendered_composition(fixtures_path / "rendered_composition.yaml") + + result1 = utils.get_context_for_layer(rendered_comp, 0) + assert result1 == rendered_comp.layers[0].new_context + + result2 = utils.get_context_for_layer(rendered_comp, 1) + assert "project_slug" in result2 + assert len(result2) == len(result1) + 1 + + result3 = utils.get_context_for_layer(rendered_comp, 2) + assert "_docs_requirements" in result3 + assert len(result3) == len(result2) + 1 + + result4 = utils.get_context_for_layer(rendered_comp) + assert result4 == result3 + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + param("/path/to/template/", "template", id="local directory"), + param("/path/to/composition.yaml", "composition", id="local composition"), + param("https://example.com/path/to/template", "template", id="remote directory"), + param("https://example.com/path/to/composition.yaml", "composition", id="remote composition"), + ], +) +def test_get_template_name(value, expected): + """The template name should be the base name of the path.""" + assert utils.get_template_name(value) == expected + + +@pytest.mark.parametrize( + ["bad_value"], + [ + ("https://example.com",), + ("",), + ], +) +def test_get_template_name_errors(bad_value): + """It should raise errors""" + with pytest.raises(ValueError): + utils.get_template_name(bad_value)