diff --git a/src/ansible_creator/subcommands/init.py b/src/ansible_creator/subcommands/init.py index 093cdecb..066634fd 100644 --- a/src/ansible_creator/subcommands/init.py +++ b/src/ansible_creator/subcommands/init.py @@ -11,7 +11,7 @@ from ansible_creator.exceptions import CreatorError from ansible_creator.templar import Templar from ansible_creator.types import TemplateData -from ansible_creator.utils import Copier +from ansible_creator.utils import Copier, Walker if TYPE_CHECKING: @@ -125,14 +125,19 @@ def _scaffold(self) -> None: dev_file_name=self.unique_name_in_devfile(), ) - copier = Copier( - resources=[f"{self._project}_project", *self.common_resources], + walker = Walker( + resources=(f"{self._project}_project", *self.common_resources), resource_id=f"{self._project}_project", dest=self._init_path, output=self.output, templar=self._templar, template_data=template_data, ) - copier.copy_containers() + paths = walker.collect_paths() + + copier = Copier( + output=self.output, + ) + copier.copy_containers(paths) self.output.note(f"{self._project} project created at {self._init_path}") diff --git a/src/ansible_creator/utils.py b/src/ansible_creator/utils.py index 1d5d6773..7ce4017f 100644 --- a/src/ansible_creator/utils.py +++ b/src/ansible_creator/utils.py @@ -4,8 +4,10 @@ import copy import os +import shutil from dataclasses import dataclass +from functools import cached_property from importlib import resources as impl_resources from pathlib import Path from typing import TYPE_CHECKING @@ -16,7 +18,6 @@ if TYPE_CHECKING: - from ansible_creator.compat import Traversable from ansible_creator.output import Output from ansible_creator.templar import Templar @@ -65,8 +66,108 @@ def expand_path(path: str) -> Path: @dataclass -class Copier: - """Configuration for the Copier class. +class DestinationFile: + """Container to hold information about a file to be copied. + + Attributes: + source: The path of the original copy. + dest: The path the file will be written to. + content: The templated content to be written to dest. + """ + + source: Traversable + dest: Path + content: str = "" + + def __str__(self) -> str: + """Supports str() on DestinationFile. + + Returns: + A string representation of the destination path. + """ + return str(self.dest) + + @cached_property + def conflict(self) -> str: + """Check for file conflicts. + + Returns: + String describing the file conflict, if any. + """ + if not self.dest.exists(): + return "" + + if self.source.is_file(): + if self.dest.is_file(): + dest_content = self.dest.read_text("utf8") + if self.content != dest_content: + return f"{self.dest} will be overwritten!" + else: + return f"{self.dest} already exists and is a directory!" + + if self.source.is_dir() and not self.dest.is_dir(): + return f"{self.dest} already exists and is a file!" + + return "" + + @cached_property + def needs_write(self) -> bool: + """Check if file needs to be written to. + + Returns: + True if dest differs from source else False. + """ + # Skip files in SKIP_FILES_TYPES and __meta__.yaml + if self.source.is_file() and ( + self.source.name.split(".")[-1] in SKIP_FILES_TYPES + or self.source.name == "__meta__.yml" + ): + return False + + if not self.dest.exists(): + return True + return bool(self.conflict) + + def set_content(self, template_data: TemplateData, templar: Templar | None) -> None: + """Set expected content from source file, templated by templar if necessary. + + Args: + template_data: A dictionary containing current data to render templates with. + templar: An instance of the Templar class. + """ + content = self.source.read_text(encoding="utf-8") + # only render as templates if both of these are provided, + # and original file suffix was j2 + if templar and template_data and self.source.name.endswith("j2"): + content = templar.render_from_content( + template=content, + data=template_data, + ) + self.content = content + + def remove_existing(self) -> None: + """Remove existing files or directories at destination path.""" + if self.dest.is_file(): + self.dest.unlink() + elif self.dest.is_dir(): + shutil.rmtree(self.dest) + + +class FileList(list[DestinationFile]): + """A list subclass holding DestinationFiles with convenience methods.""" + + def has_conflicts(self) -> bool: + """Check if any files have conflicts in the destination. + + Returns: + True if there are any conflicts else False. + """ + return any(path.conflict for path in self) + + +@dataclass +class Walker: + """Configuration for the Walker class. Attributes: resources: List of resource containers to copy. @@ -74,144 +175,126 @@ class Copier: dest: The destination path to copy resources to. output: An instance of the Output class. template_data: A dictionary containing the original data to render templates with. - index: Index of the current resource being copied. resource_root: Root path for the resources. templar: An instance of the Templar class. """ - resources: list[str] + resources: tuple[str, ...] resource_id: str dest: Path output: Output template_data: TemplateData - index: int = 0 resource_root: str = "ansible_creator.resources" templar: Templar | None = None - @property - def resource(self: Copier) -> str: - """Return the current resource being copied.""" - return self.resources[self.index] - - def _recursive_copy( - self: Copier, + def _recursive_walk( + self, root: Traversable, + resource: str, template_data: TemplateData, - ) -> None: - """Recursively traverses a resource container and copies content to destination. + ) -> FileList: + """Recursively traverses a resource container looking for content to copy. Args: root: A traversable object representing root of the container to copy. + resource: The resource being scanned. template_data: A dictionary containing current data to render templates with. + + Returns: + A list of paths to be written to. """ self.output.debug(msg=f"current root set to {root}") + file_list = FileList() for obj in root.iterdir(): - self.each_obj(obj, template_data) + file_list.extend( + self.each_obj( + obj, + resource=resource, + template_data=template_data, + ), + ) + return file_list - def each_obj(self, obj: Traversable, template_data: TemplateData) -> None: + def each_obj( + self, + obj: Traversable, + resource: str, + template_data: TemplateData, + ) -> FileList: """Recursively traverses a resource container and copies content to destination. Args: obj: A traversable object representing the root of the container to copy. + resource: The resource to consult for path names. template_data: A dictionary containing current data to render templates with. + + Returns: + A list of paths. """ # resource names may have a . but directories use / in the path dest_name = str(obj).split( - self.resource.replace(".", "/") + "/", + resource.replace(".", "/") + "/", maxsplit=1, )[-1] - dest_path = self.dest / dest_name - # replace placeholders in destination path with real values for key, val in PATH_REPLACERS.items(): - if key in str(dest_path) and template_data: - str_dest_path = str(dest_path) + if key in dest_name: repl_val = getattr(template_data, val) - dest_path = Path(str_dest_path.replace(key, repl_val)) - - if obj.is_dir(): - if obj.name in SKIP_DIRS: - return - self._recursive_copy_dir(obj=obj, dest_path=dest_path, template_data=template_data) - - elif obj.is_file(): - if obj.name.split(".")[-1] in SKIP_FILES_TYPES or obj.name == "__meta__.yml": - return - self._copy_file( - obj=obj, - dest_path=dest_path, - template_data=template_data, - ) + dest_name = dest_name.replace(key, repl_val) + dest_name = dest_name.removesuffix(".j2") - def _copy_file( - self, - obj: Traversable, - dest_path: Path, - template_data: TemplateData, - ) -> None: - """Copy a file to destination. - - Args: - obj: A traversable object representing the file to copy. - dest_path: The destination path to copy the file to. - template_data: A dictionary containing current data to render templates with. - """ - # remove .j2 suffix at destination - needs_templating = False - if dest_path.suffix == ".j2": - dest_path = dest_path.with_suffix("") - needs_templating = True - dest_file = Path(self.dest) / dest_path - self.output.debug(msg=f"dest file is {dest_file}") - - content = obj.read_text(encoding="utf-8") - # only render as templates if both of these are provided, - # and original file suffix was j2 - if self.templar and template_data and needs_templating: - content = self.templar.render_from_content( - template=content, - data=template_data, - ) - with dest_file.open("w", encoding="utf-8") as df_handle: - df_handle.write(content) - - def _recursive_copy_dir( - self, - obj: Traversable, - dest_path: Path, - template_data: TemplateData, - ) -> None: - """Recursively copy directories to destination. + dest_path = DestinationFile( + dest=self.dest / dest_name, + source=obj, + ) + self.output.debug(f"Looking at {dest_path}") + + if obj.is_file(): + dest_path.set_content(template_data, self.templar) + + if dest_path.needs_write: + # Warn on conflict + conflict_msg = dest_path.conflict + if conflict_msg: + self.output.warning(conflict_msg) + + if obj.is_dir() and obj.name not in SKIP_DIRS: + return FileList( + [ + dest_path, + *self._recursive_walk( + root=obj, + resource=resource, + template_data=template_data, + ), + ], + ) + if obj.is_file(): + return FileList([dest_path]) - Args: - obj: A traversable object representing the directory to copy. - dest_path: The destination path to copy the directory to. - template_data: A dictionary containing current data to render templates with. - """ - dest_path.mkdir(parents=True, exist_ok=True) + if obj.is_dir() and obj.name not in SKIP_DIRS: + return self._recursive_walk(root=obj, resource=resource, template_data=template_data) - # recursively copy the directory - self._recursive_copy( - root=obj, - template_data=template_data, - ) + return FileList() - def _per_container(self: Copier) -> None: - """Copy files and directories from a possibly nested source to a destination. + def _per_container(self, resource: str) -> FileList: + """Generate a list of all paths that will be written to for a particular resource. - :param copier_config: Configuration for the Copier class. + Args: + resource: The resource to search through. - :raises CreatorError: if allow_overwrite is not a list. + Returns: + A list of paths to be written to. """ - msg = f"starting recursive copy with source container '{self.resource}'" + msg = f"starting recursive walk with source container '{resource}'" self.output.debug(msg) # Cast the template data to not pollute the original template_data = copy.deepcopy(self.template_data) # Collect and template any resource specific variables - meta_file = impl_resources.files(f"{self.resource_root}.{self.resource}") / "__meta__.yml" + meta_file = impl_resources.files(f"{self.resource_root}.{resource}") / "__meta__.yml" try: with meta_file.open("r", encoding="utf-8") as meta_fileh: self.output.debug( @@ -235,18 +318,61 @@ def _per_container(self: Copier) -> None: else: setattr(template_data, key, value["value"]) - self._recursive_copy( - root=impl_resources.files(f"{self.resource_root}.{self.resource}"), - template_data=template_data, + return self._recursive_walk( + impl_resources.files(f"{self.resource_root}.{resource}"), + resource, + template_data, ) - def copy_containers( - self: Copier, + def collect_paths(self) -> FileList: + """Determine paths that will be written to. + + Returns: + A list of paths to be written to. + """ + file_list = FileList() + for resource in self.resources: + file_list.extend(self._per_container(resource)) + + return file_list + + +@dataclass +class Copier: + """Configuration for the Copier class. + + Attributes: + output: An instance of the Output class. + """ + + output: Output + + def _copy_file( + self, + dest_path: DestinationFile, ) -> None: + """Copy a file to destination. + + Args: + dest_path: The destination path to copy the file to. + """ + # remove .j2 suffix at destination + self.output.debug(msg=f"Writing to {dest_path}") + + with dest_path.dest.open("w", encoding="utf-8") as df_handle: + df_handle.write(dest_path.content) + + def copy_containers(self: Copier, paths: FileList) -> None: """Copy multiple containers to destination. - :param copier_config: Configuration for the Copier class. + Args: + paths: A list of paths to create in the destination. """ - for i in range(len(self.resources)): - self.index = i - self._per_container() + for path in paths: + path.remove_existing() + + if path.source.is_dir(): + path.dest.mkdir(parents=True, exist_ok=True) + + elif path.source.is_file(): + self._copy_file(path) diff --git a/tests/units/test_utils.py b/tests/units/test_utils.py index 786b7b2b..d639b83e 100644 --- a/tests/units/test_utils.py +++ b/tests/units/test_utils.py @@ -2,11 +2,13 @@ from __future__ import annotations +import shutil + from pathlib import Path from typing import TYPE_CHECKING from ansible_creator.types import TemplateData -from ansible_creator.utils import Copier, expand_path +from ansible_creator.utils import Copier, Walker, expand_path if TYPE_CHECKING: @@ -35,13 +37,98 @@ def test_skip_dirs(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, output: Outp output: Output class object. """ monkeypatch.setattr("ansible_creator.utils.SKIP_DIRS", ["docker"]) - copier = Copier( - resources=["common.devcontainer"], + + walker = Walker( + resources=("common.devcontainer",), resource_id="common.devcontainer", dest=tmp_path, output=output, template_data=TemplateData(), ) - copier.copy_containers() + paths = walker.collect_paths() + + copier = Copier( + output=output, + ) + copier.copy_containers(paths) assert (tmp_path / ".devcontainer" / "podman").exists() assert not (tmp_path / ".devcontainer" / "docker").exists() + + +def test_overwrite(tmp_path: Path, output: Output) -> None: + """Test Copier overwriting existing files. + + Args: + tmp_path: Temporary directory path. + output: Output class object. + """ + walker = Walker( + resources=("common.devcontainer",), + resource_id="common.devcontainer", + dest=tmp_path, + output=output, + template_data=TemplateData(), + ) + paths = walker.collect_paths() + + # We will be manipulating these paths later + base_file = tmp_path / ".devcontainer" / "devcontainer.json" + podman_dir = tmp_path / ".devcontainer" / "podman" + docker_file = tmp_path / ".devcontainer" / "docker" / "devcontainer.json" + + copier = Copier( + output=output, + ) + copier.copy_containers(paths) + base_contents = base_file.read_text() + assert podman_dir.is_dir() + assert docker_file.is_file() + + # Rewrite devcontainer.json + base_file.write_text("This is not what a devcontainer file looks like.") + # Replace podman with a file + shutil.rmtree(podman_dir) + podman_dir.write_text("This is an error") + # Replace docker devcontainer with a directory + docker_file.unlink() + docker_file.mkdir() + + # Re-walk directory to generate warnings, but not make changes + paths = walker.collect_paths() + assert base_file.read_text() != base_contents + assert podman_dir.is_file() + assert docker_file.is_dir() + assert paths.has_conflicts() + + # Re-copy to overwrite structure + copier.copy_containers(paths) + assert base_file.read_text() == base_contents + assert podman_dir.is_dir() + assert docker_file.is_file() + + +def test_skip_repeats(tmp_path: Path, output: Output) -> None: + """Test Copier skipping existing files. + + Args: + tmp_path: Temporary directory path. + output: Output class object. + """ + walker = Walker( + resources=("common.devcontainer",), + resource_id="common.devcontainer", + dest=tmp_path, + output=output, + template_data=TemplateData(), + ) + paths = walker.collect_paths() + assert paths + + copier = Copier( + output=output, + ) + copier.copy_containers(paths) + + # Re-walk directory to generate new path list + paths = walker.collect_paths() + assert not paths