Skip to content

Commit

Permalink
Fixed issue with relative paths in local files.
Browse files Browse the repository at this point in the history
You can pass a local_path to use to resolve relative local paths for I/O and template determination.
  • Loading branch information
coordt committed Sep 28, 2023
1 parent c72790f commit 7f1b354
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 33 deletions.
48 changes: 41 additions & 7 deletions cookie_composer/templates/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from typing import Optional, Tuple
from urllib.parse import urlparse

from cookiecutter.config import get_user_config

from cookie_composer.templates.git_repo import template_repo_from_git
from cookie_composer.templates.types import Locality, TemplateFormat, TemplateRepo
from cookie_composer.templates.zipfile_repo import template_repo_from_zipfile


def identify_repo(url: str) -> Tuple[TemplateFormat, Locality]:
def identify_repo(url: str, local_path: Optional[Path] = None) -> Tuple[TemplateFormat, Locality]:
"""Identify the repo format and locality from the URL."""
parsed_url = urlparse(url)
locality = Locality.LOCAL if parsed_url.scheme in {"", "file"} else Locality.REMOTE
Expand All @@ -17,7 +19,7 @@ def identify_repo(url: str) -> Tuple[TemplateFormat, Locality]:
return TemplateFormat.ZIP, locality

if locality == Locality.LOCAL:
git_path = Path(parsed_url.path).joinpath(".git")
git_path = resolve_local_path(parsed_url.path, local_path).joinpath(".git")
is_git = git_path.exists() and git_path.is_dir()

return TemplateFormat.GIT if is_git else TemplateFormat.PLAIN, locality
Expand All @@ -29,22 +31,54 @@ def identify_repo(url: str) -> Tuple[TemplateFormat, Locality]:


def get_template_repo(
url: str, cache_dir: Path, checkout: Optional[str] = None, password: Optional[str] = None
url: str, local_path: Optional[Path] = None, checkout: Optional[str] = None, password: Optional[str] = None
) -> TemplateRepo:
"""Get the template repo for a URL."""
tmpl_format, locality = identify_repo(url)
"""
Get a template repository from a URL.
Args:
url: The string from the template field in the composition file.
local_path: Used to resolve local paths.
checkout: The branch, tag or commit to check out after git clone
password: The password to use if template is a password-protected Zip archive.
Returns:
A :class:`TemplateRepo` object.
"""
user_config = get_user_config()
tmpl_format, locality = identify_repo(url, local_path)

if locality == Locality.LOCAL:
cache_dir = resolve_local_path(url, local_path)
else:
cache_dir = Path(user_config["cookiecutters_dir"])

if tmpl_format == TemplateFormat.ZIP:
return template_repo_from_zipfile(url, locality, cache_dir, password=password)
elif tmpl_format == TemplateFormat.GIT:
return template_repo_from_git(url, locality, cache_dir, checkout=checkout)
else:
dir_path = Path(url).expanduser().resolve()
return TemplateRepo(
source=url,
cached_source=Path(dir_path),
cached_source=cache_dir,
format=TemplateFormat.PLAIN,
locality=Locality.LOCAL,
checkout=None,
password=None,
)


def resolve_local_path(url: str, local_path: Optional[Path] = None) -> Path:
"""
Resolve a local path.
Args:
url: The string from the template field in the composition file.
local_path: An optional path to resolve the URL against.
Returns:
The resolved path.
"""
if local_path is None:
return Path(url).expanduser().resolve()
return local_path.joinpath(url).expanduser().resolve()
86 changes: 60 additions & 26 deletions tests/templates/test_source.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
"""Tests for the templates.source module."""
from typing import Tuple

from cookiecutter.config import get_user_config

from cookie_composer.templates.types import TemplateFormat, Locality, TemplateRepo
from cookie_composer.templates.source import identify_repo, get_template_repo

import pytest
from pathlib import Path

CACHE_DIR = Path(get_user_config()["cookiecutters_dir"])


@pytest.mark.parametrize(
"url, expected",
"url, local_path, expected",
[
("https://example.com/repo.zip", (TemplateFormat.ZIP, Locality.REMOTE)),
("file:///path/to/repo.zip", (TemplateFormat.ZIP, Locality.LOCAL)),
("/path/to/repo.zip", (TemplateFormat.ZIP, Locality.LOCAL)),
("https://github.com/user/repo.git", (TemplateFormat.GIT, Locality.REMOTE)),
("git+https://github.com/user/repo.git", (TemplateFormat.GIT, Locality.REMOTE)),
("https://git.example.com/user/repo", (TemplateFormat.GIT, Locality.REMOTE)),
("https://example.com/repo.zip", None, (TemplateFormat.ZIP, Locality.REMOTE)),
("file:///path/to/repo.zip", None, (TemplateFormat.ZIP, Locality.LOCAL)),
("/path/to/repo.zip", None, (TemplateFormat.ZIP, Locality.LOCAL)),
("https://github.com/user/repo.git", None, (TemplateFormat.GIT, Locality.REMOTE)),
("git+https://github.com/user/repo.git", None, (TemplateFormat.GIT, Locality.REMOTE)),
("https://git.example.com/user/repo", None, (TemplateFormat.GIT, Locality.REMOTE)),
(".", Path("/path/to/local/"), (TemplateFormat.PLAIN, Locality.LOCAL)),
("../template", Path("/path/to/template"), (TemplateFormat.PLAIN, Locality.LOCAL)),
],
)
def test_identify_repo(url, expected):
def test_identify_repo(url, local_path, expected):
"""Test that the repo format and locality are identified correctly."""
result = identify_repo(url)
result = identify_repo(url, local_path)
assert result == expected


Expand All @@ -44,23 +50,35 @@ def test_local_plain_repo(mocker):


@pytest.mark.parametrize(
"url, tmpl_format, locality, expected_args",
"url, local_path, cache_dir, tmpl_format, locality, expected_args",
[
(
"https://example.com/repo.zip",
None,
CACHE_DIR,
TemplateFormat.ZIP,
Locality.REMOTE,
{"password": None},
),
(
"/path/to/local/repo.zip",
None,
Path("/path/to/local/repo.zip"),
TemplateFormat.ZIP,
Locality.LOCAL,
{"password": None},
),
(
"repo.zip",
Path("/path/to/local/"),
Path("/path/to/local/repo.zip"),
TemplateFormat.ZIP,
Locality.LOCAL,
{"password": None},
),
],
)
def test_get_template_repo(mocker, url, tmpl_format, locality, expected_args):
def test_get_template_repo(mocker, url, local_path, cache_dir, tmpl_format, locality, expected_args):
# Mock identify_repo to return tmpl_format and locality
mocker.patch("cookie_composer.templates.source.identify_repo", return_value=(tmpl_format, locality))

Expand All @@ -69,27 +87,42 @@ def test_get_template_repo(mocker, url, tmpl_format, locality, expected_args):
"cookie_composer.templates.source.template_repo_from_zipfile", return_value=mocker.Mock(spec=TemplateRepo)
)

# Call the function
cache_dir = Path("/cache")
result = get_template_repo(url, cache_dir)
result = get_template_repo(url, local_path=local_path)

# Check which function was called and with which arguments
mocked_func.assert_called_once_with(url, locality, cache_dir, **expected_args)


@pytest.mark.parametrize(
"url, tmpl_format, locality, expected_args",
"url, local_path, cache_dir, tmpl_format, locality, expected_args",
[
(
"https://github.com/user/repo.git",
None,
CACHE_DIR,
TemplateFormat.GIT,
Locality.REMOTE,
{"checkout": None},
),
("/path/to/local/git/repo", TemplateFormat.GIT, Locality.LOCAL, {"checkout": None}),
(
"/path/to/local/git/repo",
None,
Path("/path/to/local/git/repo"),
TemplateFormat.GIT,
Locality.LOCAL,
{"checkout": None},
),
(
"repo",
Path("/path/to/local/git/"),
Path("/path/to/local/git/repo"),
TemplateFormat.GIT,
Locality.LOCAL,
{"checkout": None},
),
],
)
def test_get_template_repo_git(mocker, url, tmpl_format, locality, expected_args):
def test_get_template_repo_git(mocker, url, local_path, cache_dir, tmpl_format, locality, expected_args):
# Mock identify_repo to return tmpl_format and locality
mocker.patch("cookie_composer.templates.source.identify_repo", return_value=(tmpl_format, locality))

Expand All @@ -99,28 +132,29 @@ def test_get_template_repo_git(mocker, url, tmpl_format, locality, expected_args
)

# Call the function
cache_dir = Path("/cache")
result = get_template_repo(url, cache_dir)
result = get_template_repo(url, local_path=local_path)
mocked_func.assert_called_once_with(url, locality, cache_dir, **expected_args)


@pytest.mark.parametrize(
"url, tmpl_format, locality",
"url, local_path, tmpl_format, locality",
[
("/path/to/local/plain/repo", TemplateFormat.PLAIN, Locality.LOCAL),
("/path/to/local/plain/repo", None, TemplateFormat.PLAIN, Locality.LOCAL),
("repo", Path("/path/to/local/plain/"), TemplateFormat.PLAIN, Locality.LOCAL),
],
)
def test_get_template_repo_plain(mocker, url, tmpl_format, locality):
def test_get_template_repo_plain(mocker, url, local_path, tmpl_format, locality):
# Mock identify_repo to return tmpl_format and locality
mocker.patch("cookie_composer.templates.source.identify_repo", return_value=(tmpl_format, locality))

# Mock the relevant return function to just return a dummy TemplateRepo
# Call the function
cache_dir = Path("/cache")
result = get_template_repo(url, cache_dir)
result = get_template_repo(url, local_path=local_path)

assert isinstance(result, TemplateRepo)
assert result.source == url
if local_path is None:
assert result.cached_source == Path(url).expanduser().resolve()
else:
assert result.cached_source == local_path.joinpath(url).expanduser().resolve()
assert result.format == TemplateFormat.PLAIN
assert result.locality == Locality.LOCAL
assert result.checkout is None
Expand Down

0 comments on commit 7f1b354

Please sign in to comment.