From eea53dab8a0c4f86e5720fbc9652d56bc19b790c Mon Sep 17 00:00:00 2001 From: jeanluc Date: Mon, 12 Aug 2024 15:16:35 +0200 Subject: [PATCH] feat: Add _copier_conf.operation variable --- copier/main.py | 14 ++++++++++---- tests/conftest.py | 4 ++++ tests/templates.py | 37 +++++++++++++++++++++++++++++++++++++ tests/test_copy.py | 10 ++++++++++ tests/test_recopy.py | 20 ++++++++++++++++++++ tests/test_updatediff.py | 22 ++++++++++++++++++++++ 6 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 tests/templates.py diff --git a/copier/main.py b/copier/main.py index 99916287e..cc67fc044 100644 --- a/copier/main.py +++ b/copier/main.py @@ -60,6 +60,9 @@ _T = TypeVar("_T") +Operation = Literal["copy", "recopy", "update"] + + @dataclass(config=ConfigDict(extra="forbid")) class Worker: """Copier process state manager. @@ -195,6 +198,7 @@ class Worker: unsafe: bool = False skip_answered: bool = False skip_tasks: bool = False + operation: Operation = "copy" answers: AnswersMap = field(default_factory=AnswersMap, init=False) _cleanup_hooks: list[Callable[[], None]] = field(default_factory=list, init=False) @@ -234,7 +238,7 @@ def _cleanup(self) -> None: for method in self._cleanup_hooks: method() - def _check_unsafe(self, mode: Literal["copy", "update"]) -> None: + def _check_unsafe(self, mode: Operation) -> None: """Check whether a template uses unsafe features.""" if self.unsafe: return @@ -846,7 +850,7 @@ def run_recopy(self) -> None: "Cannot recopy because cannot obtain old template references " f"from `{self.subproject.answers_relpath}`." ) - with replace(self, src_path=self.subproject.template.url) as new_worker: + with replace(self, src_path=self.subproject.template.url, operation="recopy") as new_worker: new_worker.run_copy() def run_update(self) -> None: @@ -896,8 +900,10 @@ def run_update(self) -> None: print( f"Updating to template version {self.template.version}", file=sys.stderr ) - self._apply_update() - self._print_message(self.template.message_after_update) + with replace(self, operation="update") as worker: + worker._apply_update() + worker._print_message(worker.template.message_after_update) + self.answers = worker.answers def _apply_update(self) -> None: # noqa: C901 git = get_git() diff --git a/tests/conftest.py b/tests/conftest.py index c5d5ba834..85db092b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,10 @@ from .helpers import Spawn +pytest_plugins = [ + "tests.templates", +] + @pytest.fixture def spawn() -> Spawn: diff --git a/tests/templates.py b/tests/templates.py new file mode 100644 index 000000000..e3ff04d23 --- /dev/null +++ b/tests/templates.py @@ -0,0 +1,37 @@ +import shutil +from pathlib import Path +from typing import Generator + +import pytest + +from .helpers import build_file_tree, git_save + + +@pytest.fixture(params=("copy", "recopy", "update")) +def operation_context_template(tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest) -> Generator[Path, None, None]: + src = tmp_path_factory.mktemp(f"operation_template_{request.param}") + try: + build_file_tree( + { + (src / f"{{% if _copier_conf.operation == '{request.param}' %}}foo{{% endif %}}"): "foo", + (src / "bar"): "bar", + (src / "{{ _copier_conf.answers_file }}.jinja"): "{{ _copier_answers|to_nice_yaml }}", + } + ) + git_save(src, tag="1.0.0") + yield src + finally: + shutil.rmtree(src, ignore_errors=True) + + +@pytest.fixture +def operation_context_template_v2(operation_context_template: Path) -> Path: + conditional_file = next(iter(operation_context_template.glob("*foo*"))) + build_file_tree( + { + conditional_file: "foo_update", + (operation_context_template / "bar"): "bar_update", + } + ) + git_save(operation_context_template, tag="2.0.0") + return operation_context_template diff --git a/tests/test_copy.py b/tests/test_copy.py index edaca4932..d3e240237 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -945,3 +945,13 @@ def test_multiselect_choices_preserve_order( ) copier.run_copy(str(src), dst, data={"q": ["three", "one", "two"]}) assert yaml.safe_load((dst / "q.yml").read_text()) == ["one", "two", "three"] + + +def test_operation_context(tmp_path: Path, operation_context_template: Path) -> None: + run_copy(str(operation_context_template), tmp_path) + conditional_file = tmp_path / "foo" + expected = "_copy" in operation_context_template.name + assert conditional_file.exists() is expected + if expected: + assert conditional_file.read_text() == "foo" + assert (tmp_path / "bar").read_text() == "bar" diff --git a/tests/test_recopy.py b/tests/test_recopy.py index 06257d958..4e3cd5dec 100644 --- a/tests/test_recopy.py +++ b/tests/test_recopy.py @@ -72,3 +72,23 @@ def test_recopy_works_without_replay(tpl: str, tmp_path: Path) -> None: # Recopy run_recopy(tmp_path, skip_answered=True, overwrite=True) assert (tmp_path / "name.txt").read_text() == "This is my name: Mario." + + +def test_operation_context(tmp_path: Path, operation_context_template: Path) -> None: + run_copy(str(operation_context_template), tmp_path) + git_save(tmp_path) + conditional_file = tmp_path / "foo" + expected_copy = "_copy" in operation_context_template.name + expected_recopy = "recopy" in operation_context_template.name + assert conditional_file.exists() is expected_copy + assert (tmp_path / "bar").read_text() == "bar" + if expected_copy: + assert conditional_file.read_text() == "foo" + conditional_file.unlink() + (tmp_path / "bar").write_text("baz") + git_save(tmp_path) + run_recopy(str(tmp_path), overwrite=True) + assert conditional_file.exists() is expected_recopy + if expected_recopy: + assert conditional_file.read_text() == "foo" + assert (tmp_path / "bar").read_text() == "bar" diff --git a/tests/test_updatediff.py b/tests/test_updatediff.py index 9c8db3680..17e396c6c 100644 --- a/tests/test_updatediff.py +++ b/tests/test_updatediff.py @@ -25,6 +25,7 @@ build_file_tree, git, git_init, + git_save, ) @@ -1290,3 +1291,24 @@ def test_update_with_new_file_in_template_and_project_via_migration( >>>>>>> after updating """ ) + + +def test_operation_context(tmp_path: Path, operation_context_template: Path, request: pytest.FixtureRequest) -> None: + run_copy(str(operation_context_template), tmp_path) + conditional_file = tmp_path / "foo" + expected_copy = "_copy" in operation_context_template.name + expected_update = "update" in operation_context_template.name + assert conditional_file.exists() is expected_copy + assert (tmp_path / "bar").read_text() == "bar" + if expected_copy: + assert conditional_file.read_text() == "foo" + git_save(tmp_path) + request.getfixturevalue("operation_context_template_v2") + run_update(str(tmp_path), overwrite=True) + if expected_update: + assert conditional_file.read_text() == "foo_update" + elif expected_copy: + assert conditional_file.read_text() == "foo" + else: + assert not conditional_file.exists() + assert (tmp_path / "bar").read_text() == "bar_update"