From 2a80d0126a76827843be7ecc3d11c2e9a37fe12c Mon Sep 17 00:00:00 2001 From: jeanluc Date: Mon, 12 Aug 2024 15:16:35 +0200 Subject: [PATCH] feat: Add _operation variable --- copier/main.py | 56 +++++++++++++++++++++++++-- copier/types.py | 7 ++++ docs/configuring.md | 17 ++++++++ docs/creating.md | 10 +++++ tests/test_context.py | 90 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 176 insertions(+), 4 deletions(-) create mode 100644 tests/test_context.py diff --git a/copier/main.py b/copier/main.py index bb514fba0..10255477f 100644 --- a/copier/main.py +++ b/copier/main.py @@ -6,9 +6,10 @@ import subprocess import sys from contextlib import suppress +from contextvars import ContextVar from dataclasses import asdict, field, replace from filecmp import dircmp -from functools import cached_property, partial +from functools import cached_property, partial, wraps from itertools import chain from pathlib import Path from shutil import rmtree @@ -60,6 +61,8 @@ MISSING, AnyByStrDict, JSONSerializable, + Operation, + ParamSpec, RelativePath, StrOrPath, ) @@ -67,6 +70,29 @@ from .vcs import get_git _T = TypeVar("_T") +_P = ParamSpec("_P") + +_operation: ContextVar[Operation] = ContextVar("_operation") + + +def as_operation(value: Operation) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Decorator to set the current operation context, if not defined already. + + This value is used to template specific configuration options. + """ + + def _decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: + @wraps(func) + def _wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + token = _operation.set(_operation.get(value)) + try: + return func(*args, **kwargs) + finally: + _operation.reset(token) + + return _wrapper + + return _decorator @dataclass(config=ConfigDict(extra="forbid")) @@ -243,7 +269,7 @@ def _cleanup(self) -> None: for method in self._cleanup_hooks: method() - def _check_unsafe(self, mode: Literal["copy", "update"]) -> None: + def _check_unsafe(self, mode: Operation) -> None: """Check whether a template uses unsafe features.""" if self.unsafe: return @@ -296,8 +322,10 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None: Arguments: tasks: The list of tasks to run. """ + operation = _operation.get() for i, task in enumerate(tasks): extra_context = {f"_{k}": v for k, v in task.extra_vars.items()} + extra_context["_operation"] = operation if not cast_to_bool(self._render_value(task.condition, extra_context)): continue @@ -327,7 +355,7 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None: / Path(self._render_string(str(task.working_directory), extra_context)) ).absolute() - extra_env = {k.upper(): str(v) for k, v in task.extra_vars.items()} + extra_env = {k[1:].upper(): str(v) for k, v in extra_context.items()} with local.cwd(working_directory), local.env(**extra_env): subprocess.run(task_cmd, shell=use_shell, check=True, env=local.env) @@ -588,7 +616,14 @@ def _pathjoin( @cached_property def match_exclude(self) -> Callable[[Path], bool]: """Get a callable to match paths against all exclusions.""" - return self._path_matcher(self.all_exclusions) + # Include the current operation in the rendering context. + # Note: This method is a cached property, it needs to be regenerated + # when reusing an instance in different contexts. + extra_context = {"_operation": _operation.get()} + return self._path_matcher( + self._render_string(exclusion, extra_context=extra_context) + for exclusion in self.all_exclusions + ) @cached_property def match_skip(self) -> Callable[[Path], bool]: @@ -818,6 +853,7 @@ def template_copy_root(self) -> Path: return self.template.local_abspath / subdir # Main operations + @as_operation("copy") def run_copy(self) -> None: """Generate a subproject from zero, ignoring what was in the folder. @@ -828,6 +864,11 @@ def run_copy(self) -> None: See [generating a project][generating-a-project]. """ + with suppress(AttributeError): + # We might have switched operation context, ensure the cached property + # is regenerated to re-render templates. + del self.match_exclude + self._check_unsafe("copy") self._print_message(self.template.message_before_copy) self._ask() @@ -854,6 +895,7 @@ def run_copy(self) -> None: # TODO Unify printing tools print("") # padding space + @as_operation("copy") def run_recopy(self) -> None: """Update a subproject, keeping answers but discarding evolution.""" if self.subproject.template is None: @@ -864,6 +906,7 @@ def run_recopy(self) -> None: with replace(self, src_path=self.subproject.template.url) as new_worker: new_worker.run_copy() + @as_operation("update") def run_update(self) -> None: """Update a subproject that was already generated. @@ -911,6 +954,11 @@ def run_update(self) -> None: print( f"Updating to template version {self.template.version}", file=sys.stderr ) + with suppress(AttributeError): + # We might have switched operation context, ensure the cached property + # is regenerated to re-render templates. + del self.match_exclude + self._apply_update() self._print_message(self.template.message_after_update) diff --git a/copier/types.py b/copier/types.py index be9a5de91..57cf2a067 100644 --- a/copier/types.py +++ b/copier/types.py @@ -1,5 +1,6 @@ """Complex types, annotations, validators.""" +import sys from pathlib import Path from typing import ( Annotated, @@ -16,6 +17,11 @@ from pydantic import AfterValidator +if sys.version_info >= (3, 10): + from typing import ParamSpec as ParamSpec +else: + from typing_extensions import ParamSpec as ParamSpec + # simple types StrOrPath = Union[str, Path] AnyByStrDict = Dict[str, Any] @@ -35,6 +41,7 @@ Env = Mapping[str, str] MissingType = NewType("MissingType", object) MISSING = MissingType(object()) +Operation = Literal["copy", "update"] # Validators diff --git a/docs/configuring.md b/docs/configuring.md index 636f04ea3..9f5a4d375 100644 --- a/docs/configuring.md +++ b/docs/configuring.md @@ -893,6 +893,18 @@ to know available options. The CLI option can be passed several times to add several patterns. +Each pattern can be templated using Jinja. + +!!! example + + Templating `exclude` patterns using `_operation` allows to have files + that are rendered once during `copy`, but are never updated: + + ```yaml + _exclude: + - "{% if _operation == 'update' -%}src/*_example.py{% endif %}" + ``` + !!! info When you define this parameter in `copier.yml`, it will **replace** the default @@ -1351,6 +1363,8 @@ configuring `secret: true` in the [advanced prompt format][advanced-prompt-forma exist, but always be present. If they do not exist in a project during an `update` operation, they will be recreated. +Each pattern can be templated using Jinja. + !!! example For example, it can be used if your project generates a password the 1st time and @@ -1501,6 +1515,9 @@ other items not present. - [invoke, end-process, "--full-conf={{ _copier_conf|to_json }}"] # Your script can be run by the same Python environment used to run Copier - ["{{ _copier_python }}", task.py] + # Run a command during the initial copy operation only, excluding updates + - command: ["{{ _copier_python }}", task.py] + when: "{{ _operation == 'copy' }}" # OS-specific task (supported values are "linux", "macos", "windows" and `None`) - command: rm {{ name_of_the_project }}/README.md when: "{{ _copier_conf.os in ['linux', 'macos'] }}" diff --git a/docs/creating.md b/docs/creating.md index d0415f377..bc43dd828 100644 --- a/docs/creating.md +++ b/docs/creating.md @@ -125,6 +125,16 @@ The absolute path of the Python interpreter running Copier. The name of the project root directory. +## Variables (context-dependent) + +Some variables are only available in select contexts: + +### `_operation` + +The current operation, either `"copy"` or `"update"`. + +Availability: [`exclude`](configuring.md#exclude), [`tasks`](configuring.md#tasks) + ## Variables (context-specific) Some rendering contexts provide variables unique to them: diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 000000000..5479c7db1 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,90 @@ +import json +from pathlib import Path + +import pytest +from plumbum import local + +import copier + +from .helpers import build_file_tree, git_save + + +def test_exclude_templating_with_operation( + tmp_path_factory: pytest.TempPathFactory, +) -> None: + """ + Ensure it's possible to create one-off boilerplate files that are not + managed during updates via `_exclude` using the `_operation` context variable. + """ + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + + template = "{% if _operation == 'update' %}copy-only{% endif %}" + with local.cwd(src): + build_file_tree( + { + "copier.yml": f'_exclude:\n - "{template}"', + "{{ _copier_conf.answers_file }}.jinja": "{{ _copier_answers|to_yaml }}", + "copy-only": "foo", + "copy-and-update": "foo", + } + ) + git_save(tag="1.0.0") + build_file_tree( + { + "copy-only": "bar", + "copy-and-update": "bar", + } + ) + git_save(tag="2.0.0") + copy_only = dst / "copy-only" + copy_and_update = dst / "copy-and-update" + + copier.run_copy(str(src), dst, defaults=True, overwrite=True, vcs_ref="1.0.0") + for file in (copy_only, copy_and_update): + assert file.exists() + assert file.read_text() == "foo" + + with local.cwd(dst): + git_save() + + copier.run_update(str(dst), overwrite=True) + assert copy_only.read_text() == "foo" + assert copy_and_update.read_text() == "bar" + + +def test_task_templating_with_operation( + tmp_path_factory: pytest.TempPathFactory, tmp_path: Path +) -> None: + """ + Ensure that it is possible to define tasks that are only executed when copying. + """ + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + # Use a file outside the Copier working directories to ensure accurate tracking + task_counter = tmp_path / "task_calls.txt" + with local.cwd(src): + build_file_tree( + { + "copier.yml": ( + f"""\ + _tasks: + - command: echo {{{{ _operation }}}} >> {json.dumps(str(task_counter))} + when: "{{{{ _operation == 'copy' }}}}" + """ + ), + "{{ _copier_conf.answers_file }}.jinja": "{{ _copier_answers|to_yaml }}", + } + ) + git_save(tag="1.0.0") + + copier.run_copy(str(src), dst, defaults=True, overwrite=True, unsafe=True) + assert task_counter.exists() + assert len(task_counter.read_text().splitlines()) == 1 + + with local.cwd(dst): + git_save() + + copier.run_recopy(dst, defaults=True, overwrite=True, unsafe=True) + assert len(task_counter.read_text().splitlines()) == 2 + + copier.run_update(dst, defaults=True, overwrite=True, unsafe=True) + assert len(task_counter.read_text().splitlines()) == 2