Skip to content

Commit

Permalink
feat: Add _copier_conf.operation variable
Browse files Browse the repository at this point in the history
  • Loading branch information
lkubb committed Aug 13, 2024
1 parent 0315674 commit eea53da
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 4 deletions.
14 changes: 10 additions & 4 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
_T = TypeVar("_T")


Operation = Literal["copy", "recopy", "update"]


@dataclass(config=ConfigDict(extra="forbid"))
class Worker:
"""Copier process state manager.
Expand Down Expand Up @@ -195,6 +198,7 @@ class Worker:
unsafe: bool = False
skip_answered: bool = False
skip_tasks: bool = False
operation: Operation = "copy"

answers: AnswersMap = field(default_factory=AnswersMap, init=False)
_cleanup_hooks: list[Callable[[], None]] = field(default_factory=list, init=False)
Expand Down Expand Up @@ -234,7 +238,7 @@ def _cleanup(self) -> None:
for method in self._cleanup_hooks:
method()

def _check_unsafe(self, mode: Literal["copy", "update"]) -> None:
def _check_unsafe(self, mode: Operation) -> None:
"""Check whether a template uses unsafe features."""
if self.unsafe:
return
Expand Down Expand Up @@ -846,7 +850,7 @@ def run_recopy(self) -> None:
"Cannot recopy because cannot obtain old template references "
f"from `{self.subproject.answers_relpath}`."
)
with replace(self, src_path=self.subproject.template.url) as new_worker:
with replace(self, src_path=self.subproject.template.url, operation="recopy") as new_worker:
new_worker.run_copy()

def run_update(self) -> None:
Expand Down Expand Up @@ -896,8 +900,10 @@ def run_update(self) -> None:
print(
f"Updating to template version {self.template.version}", file=sys.stderr
)
self._apply_update()
self._print_message(self.template.message_after_update)
with replace(self, operation="update") as worker:
worker._apply_update()
worker._print_message(worker.template.message_after_update)
self.answers = worker.answers

def _apply_update(self) -> None: # noqa: C901
git = get_git()
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

from .helpers import Spawn

pytest_plugins = [
"tests.templates",
]


@pytest.fixture
def spawn() -> Spawn:
Expand Down
37 changes: 37 additions & 0 deletions tests/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import shutil
from pathlib import Path
from typing import Generator

import pytest

from .helpers import build_file_tree, git_save


@pytest.fixture(params=("copy", "recopy", "update"))
def operation_context_template(tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest) -> Generator[Path, None, None]:
src = tmp_path_factory.mktemp(f"operation_template_{request.param}")
try:
build_file_tree(
{
(src / f"{{% if _copier_conf.operation == '{request.param}' %}}foo{{% endif %}}"): "foo",
(src / "bar"): "bar",
(src / "{{ _copier_conf.answers_file }}.jinja"): "{{ _copier_answers|to_nice_yaml }}",
}
)
git_save(src, tag="1.0.0")
yield src
finally:
shutil.rmtree(src, ignore_errors=True)


@pytest.fixture
def operation_context_template_v2(operation_context_template: Path) -> Path:
conditional_file = next(iter(operation_context_template.glob("*foo*")))
build_file_tree(
{
conditional_file: "foo_update",
(operation_context_template / "bar"): "bar_update",
}
)
git_save(operation_context_template, tag="2.0.0")
return operation_context_template
10 changes: 10 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,3 +945,13 @@ def test_multiselect_choices_preserve_order(
)
copier.run_copy(str(src), dst, data={"q": ["three", "one", "two"]})
assert yaml.safe_load((dst / "q.yml").read_text()) == ["one", "two", "three"]


def test_operation_context(tmp_path: Path, operation_context_template: Path) -> None:
run_copy(str(operation_context_template), tmp_path)
conditional_file = tmp_path / "foo"
expected = "_copy" in operation_context_template.name
assert conditional_file.exists() is expected
if expected:
assert conditional_file.read_text() == "foo"
assert (tmp_path / "bar").read_text() == "bar"
20 changes: 20 additions & 0 deletions tests/test_recopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,23 @@ def test_recopy_works_without_replay(tpl: str, tmp_path: Path) -> None:
# Recopy
run_recopy(tmp_path, skip_answered=True, overwrite=True)
assert (tmp_path / "name.txt").read_text() == "This is my name: Mario."


def test_operation_context(tmp_path: Path, operation_context_template: Path) -> None:
run_copy(str(operation_context_template), tmp_path)
git_save(tmp_path)
conditional_file = tmp_path / "foo"
expected_copy = "_copy" in operation_context_template.name
expected_recopy = "recopy" in operation_context_template.name
assert conditional_file.exists() is expected_copy
assert (tmp_path / "bar").read_text() == "bar"
if expected_copy:
assert conditional_file.read_text() == "foo"
conditional_file.unlink()
(tmp_path / "bar").write_text("baz")
git_save(tmp_path)
run_recopy(str(tmp_path), overwrite=True)
assert conditional_file.exists() is expected_recopy
if expected_recopy:
assert conditional_file.read_text() == "foo"
assert (tmp_path / "bar").read_text() == "bar"
22 changes: 22 additions & 0 deletions tests/test_updatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
build_file_tree,
git,
git_init,
git_save,
)


Expand Down Expand Up @@ -1290,3 +1291,24 @@ def test_update_with_new_file_in_template_and_project_via_migration(
>>>>>>> after updating
"""
)


def test_operation_context(tmp_path: Path, operation_context_template: Path, request: pytest.FixtureRequest) -> None:
run_copy(str(operation_context_template), tmp_path)
conditional_file = tmp_path / "foo"
expected_copy = "_copy" in operation_context_template.name
expected_update = "update" in operation_context_template.name
assert conditional_file.exists() is expected_copy
assert (tmp_path / "bar").read_text() == "bar"
if expected_copy:
assert conditional_file.read_text() == "foo"
git_save(tmp_path)
request.getfixturevalue("operation_context_template_v2")
run_update(str(tmp_path), overwrite=True)
if expected_update:
assert conditional_file.read_text() == "foo_update"
elif expected_copy:
assert conditional_file.read_text() == "foo"
else:
assert not conditional_file.exists()
assert (tmp_path / "bar").read_text() == "bar_update"

0 comments on commit eea53da

Please sign in to comment.