Skip to content

Commit

Permalink
Support updating cargo files
Browse files Browse the repository at this point in the history
ghstack-source-id: ffce44ec852e06945baa209c601de469282ea9aa
Pull Request resolved: #168
  • Loading branch information
amyreese committed Feb 16, 2024
1 parent 9f5b849 commit 646d883
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 12 deletions.
80 changes: 76 additions & 4 deletions attribution/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,34 @@

import textwrap
from pathlib import Path
from typing import Any, List, Tuple

import tomlkit
from jinja2 import Template

from .project import Project


class GeneratedFile:
EXPECTS: Tuple[str, ...] = ()
FILENAME: str = "FAKE.md"
TEMPLATE: str = "FAKE FILE, DO NOT COMMIT!"

def __init__(self, project: Project):
def __init__(self, project: Project, **kwargs: str):
self.project = project
assert all(kw in kwargs for kw in self.EXPECTS)
self.kwargs = kwargs
self.filename = project.root / self.FILENAME.format(project=project, **kwargs)

def __eq__(self, other: Any) -> bool:
return ( # noqa E721
type(self) == type(other)
and self.project == other.project
and self.kwargs == other.kwargs
)

def __repr__(self) -> str:
return f"CargoFile({self.project!r}, **{self.kwargs!r})"

def generate(self) -> str:
tags = self.project.tags
Expand All @@ -23,14 +39,14 @@ def generate(self) -> str:
project=self.project,
tags=tags,
len=len,
**self.kwargs,
)
return output

def write(self) -> Path:
content = self.generate()
fpath = Path(self.FILENAME.format(project=self.project))
fpath.write_text(content)
return fpath
self.filename.write_text(content)
return self.filename


class Changelog(GeneratedFile):
Expand Down Expand Up @@ -84,3 +100,59 @@ class VersionFile(GeneratedFile):
__version__ = "{{ project.latest.version }}"
'''


class CargoFile(GeneratedFile):
EXPECTS = ("package_name", "package_dir")
FILENAME = "{package_dir}/Cargo.toml"

def generate(self) -> str:
assert self.filename.is_file()
package_name = self.kwargs["package_name"]

data = tomlkit.loads(self.filename.read_text())
assert "package" in data
package_data: tomlkit.items.Table = data.get("package", tomlkit.table())
assert package_data.get("name", "") == package_name
package_data["version"] = str(self.project.latest.version)
return tomlkit.dumps(data)

def write(self) -> Path:
fn = super().write()
assert fn.name == "Cargo.toml"
lock_file = fn.with_suffix(".lock")
if lock_file.is_file():
package_name = self.kwargs["package_name"]
lock_data = tomlkit.loads(lock_file.read_text())
assert lock_data.get("version", 0) == 3
for package_data in lock_data.get("package", ()):
if package_data.get("name", "") == package_name:
package_data["version"] = str(self.project.latest.version)
lock_file.write_text(tomlkit.dumps(lock_data))

return fn

@classmethod
def search(cls, project: Project, cargo_packages: List[str]) -> List["CargoFile"]:
found_packages: List[Tuple[str, Path]] = []
queue = [project.root]
while queue:
path = queue.pop(0)
if path.is_dir():
queue += list(path.iterdir())
elif path.is_file() and path.name == "Cargo.toml":
cargo_data = tomlkit.loads(path.read_text())
assert "package" in cargo_data
package_data = cargo_data.get("package", tomlkit.table())
package_name = package_data.get("name", "")
if package_name in cargo_packages:
found_packages.append((package_name, path.parent))

return [
CargoFile(
project,
package_name=package_name,
package_dir=package_dir.relative_to(project.root).as_posix(),
)
for package_name, package_dir in found_packages
]
5 changes: 4 additions & 1 deletion attribution/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import tomlkit

from attribution import __version__
from .generate import Changelog, VersionFile
from .generate import CargoFile, Changelog, VersionFile
from .helpers import sh
from .project import Project
from .tag import Tag
Expand Down Expand Up @@ -171,6 +171,9 @@ def tag_release(version: Version, message: Optional[str]) -> None:
if project.config.get("version_file"):
path = VersionFile(project).write()
sh(f"git add {path}")
if cargo_packages := project.config.get("cargo_packages"):
for cargo_file in CargoFile.search(project, cargo_packages):
cargo_file.write()

# update commit and tag
sh("git commit --amend --no-edit")
Expand Down
15 changes: 11 additions & 4 deletions attribution/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
LOG = logging.getLogger(__name__)


@dataclass(eq=False)
@dataclass
class Project:
name: str
package: str
config: Dict[str, Any] = field(default_factory=dict)
_shortlog: Optional[str] = None
_tags: Tags = field(default_factory=list)
root: Path = field(default_factory=Path.cwd)
_shortlog: Optional[str] = field(default=None, compare=False)
_tags: Tags = field(default_factory=list, compare=False)

def __eq__(self, other: Any) -> bool:
if isinstance(other, Project):
Expand Down Expand Up @@ -97,6 +98,7 @@ def load(cls, path: Optional[Path] = None) -> "Project":
name = ""
package = ""
config: Dict[str, Any] = {
"cargo_packages": [],
"ignored_authors": [],
"version_file": True,
"signed_tags": True,
Expand Down Expand Up @@ -133,7 +135,12 @@ def load(cls, path: Optional[Path] = None) -> "Project":
if not package:
package = canonical_namespace(path.name)

return Project(name=name, package=package, config=config)
return Project(
name=name,
package=package,
config=config,
root=path,
)

@classmethod
def pyproject_path(cls, path: Optional[Path] = None) -> Path:
Expand Down
1 change: 1 addition & 0 deletions attribution/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 Amethyst Reese
# Licensed under the MIT license

from .generate import GenerateTest
from .helpers import HelpersTest
from .project import ProjectTest
from .tag import TagTest
124 changes: 124 additions & 0 deletions attribution/tests/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright Amethyst Reese
# Licensed under the MIT license

from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import TestCase

from .. import generate
from ..project import Project
from ..tag import Tag
from ..types import Version

FAKE_CARGO_TOML = """
[package]
name = "fluffy"
version = "1.0"
edition = "2038"
[dependencies]
dog = "3.1"
"""

FAKE_CARGO_LOCK = """
version = 3
[[package]]
name = "dog"
version = "3.1"
[[package]]
name = "fluffy"
version = "1.0"
dependencies = [
"dog",
]
[[package]]
name = "something"
version = "2.0"
"""


class GenerateTest(TestCase):
def test_cargo_file(self):
with TemporaryDirectory() as td:
tdp = Path(td)
(cargo_toml := tdp / "Cargo.toml").write_text(FAKE_CARGO_TOML)
(cargo_lock := tdp / "Cargo.lock").write_text(FAKE_CARGO_LOCK)
(tdp / "subdir" / "whatever").mkdir(parents=True)
(tdp / "subdir" / "whatever" / "Cargo.toml").write_text(
FAKE_CARGO_TOML.replace('name = "fluffy"', 'name = "whatever"')
)

project = Project(
"fluffy",
"fluffy",
root=tdp,
_tags=[
Tag("v2.1.3", Version("2.1.3")),
Tag("v1.0", Version("1.0")),
],
)

with self.subTest("search no cargo_packages"):
self.assertEqual([], generate.CargoFile.search(project, []))

with self.subTest("search fluffy"):
expected = [
generate.CargoFile(project, package_name="fluffy", package_dir="."),
]
result = generate.CargoFile.search(project, ["fluffy"])
self.assertEqual(expected, result)

with self.subTest("search whatever"):
expected = [
generate.CargoFile(
project, package_name="whatever", package_dir="subdir/whatever"
),
]
result = generate.CargoFile.search(project, ["whatever"])
self.assertEqual(expected, result)

with self.subTest("search fluffy and whatever"):
expected = [
generate.CargoFile(project, package_name="fluffy", package_dir="."),
generate.CargoFile(
project, package_name="whatever", package_dir="subdir/whatever"
),
]
result = generate.CargoFile.search(project, ["fluffy", "whatever"])
self.assertEqual(expected, result)

with self.subTest("generate fluffy"):
expected = FAKE_CARGO_TOML.replace(
'version = "1.0"', 'version = "2.1.3"'
)
result = generate.CargoFile.search(
project,
cargo_packages=["fluffy"],
)[0].generate()
self.assertEqual(expected, result)

with self.subTest("write fluffy"):
expected = FAKE_CARGO_TOML.replace(
'version = "1.0"', 'version = "2.1.3"'
)
result = (
generate.CargoFile.search(
project,
cargo_packages=["fluffy"],
)[0]
.write()
.read_text()
)
self.assertEqual(expected, result)

result = cargo_toml.read_text()
self.assertEqual(expected, result)

expected = FAKE_CARGO_LOCK.replace(
'version = "1.0"', 'version = "2.1.3"'
)
result = cargo_lock.read_text()
self.assertEqual(expected, result)
21 changes: 18 additions & 3 deletions attribution/tests/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def test_load(self, cwd_mock):
cwd_mock.return_value = td

with self.subTest("pyproject in cwd"):
cwd_mock.reset_mock()
project = Project.load()
cwd_mock.assert_called()
self.assertEqual(project.name, "fizzbuzz")
Expand All @@ -160,14 +161,15 @@ def test_load(self, cwd_mock):
{
"name": "fizzbuzz",
"package": "fizzbuzz",
"cargo_packages": [],
"ignored_authors": [],
"version_file": True,
"signed_tags": True,
},
)
cwd_mock.reset_mock()

with self.subTest("pyproject in given path"):
cwd_mock.reset_mock()
project = Project.load(td)
cwd_mock.assert_not_called()
self.assertEqual(project.name, "fizzbuzz")
Expand All @@ -176,6 +178,7 @@ def test_load(self, cwd_mock):
{
"name": "fizzbuzz",
"package": "fizzbuzz",
"cargo_packages": [],
"ignored_authors": [],
"version_file": True,
"signed_tags": True,
Expand All @@ -191,6 +194,7 @@ def test_load(self, cwd_mock):
{
"name": "fizzbuzz",
"package": "fizzbuzz",
"cargo_packages": [],
"ignored_authors": [],
"version_file": True,
"signed_tags": True,
Expand All @@ -206,6 +210,7 @@ def test_load(self, cwd_mock):
{
"name": "fizzbuzz",
"package": "fizzbuzz",
"cargo_packages": [],
"ignored_authors": [],
"version_file": False,
"signed_tags": True,
Expand All @@ -219,7 +224,12 @@ def test_load(self, cwd_mock):
self.assertEqual(project.name, td.name)
self.assertEqual(
project.config,
{"ignored_authors": [], "version_file": True, "signed_tags": True},
{
"cargo_packages": [],
"ignored_authors": [],
"version_file": True,
"signed_tags": True,
},
)

with self.subTest("no pyproject"):
Expand All @@ -229,5 +239,10 @@ def test_load(self, cwd_mock):
self.assertEqual(project.name, td.name)
self.assertEqual(
project.config,
{"ignored_authors": [], "version_file": True, "signed_tags": True},
{
"cargo_packages": [],
"ignored_authors": [],
"version_file": True,
"signed_tags": True,
},
)
Loading

0 comments on commit 646d883

Please sign in to comment.