diff --git a/snapcraft/commands/registries.py b/snapcraft/commands/registries.py index 95fe349df2..d71d5b7613 100644 --- a/snapcraft/commands/registries.py +++ b/snapcraft/commands/registries.py @@ -82,6 +82,8 @@ class StoreEditRegistriesCommand(craft_application.commands.AppCommand): If the registries set does not exist, then a new registries set will be created. + If a key name is not provided, the default key is used. + The account ID of the authenticated account can be determined with the ``snapcraft whoami`` command. @@ -100,10 +102,14 @@ def fill_parser(self, parser: "argparse.ArgumentParser") -> None: parser.add_argument( "name", metavar="name", help="Name of the registries set to edit" ) + parser.add_argument( + "--key-name", metavar="key-name", help="Key used to sign the registries set" + ) @override def run(self, parsed_args: "argparse.Namespace"): self._services.registries.edit_assertion( name=parsed_args.name, account_id=parsed_args.account_id, + key_name=parsed_args.key_name, ) diff --git a/snapcraft/errors.py b/snapcraft/errors.py index 408c2fcb02..49c88f32ee 100644 --- a/snapcraft/errors.py +++ b/snapcraft/errors.py @@ -173,3 +173,10 @@ def __init__(self, message: str, *, resolution: str) -> None: resolution=resolution, docs_url="https://snapcraft.io/docs/snapcraft-authentication", ) + + +class SnapcraftAssertionError(SnapcraftError): + """Error raised when an assertion (validation or registries set) is invalid. + + Not to be confused with Python's built-in AssertionError. + """ diff --git a/snapcraft/models/assertions.py b/snapcraft/models/assertions.py index d557ff9092..efe6a30a4e 100644 --- a/snapcraft/models/assertions.py +++ b/snapcraft/models/assertions.py @@ -16,13 +16,44 @@ """Assertion models.""" -from typing import Literal +import numbers +from collections import abc +from typing import Any, Literal import pydantic from craft_application import models from typing_extensions import Self +def cast_dict_scalars_to_strings(data: dict) -> dict: + """Cast all scalars in a dictionary to strings. + + Supported scalar types are str, bool, and numbers. + """ + return {_to_string(key): _to_string(value) for key, value in data.items()} + + +def _to_string(data: Any) -> Any: + """Recurse through nested dicts and lists and cast scalar values to strings. + + Supported scalar types are str, bool, and numbers. + """ + # check for a string first, as it is the most common scenario + if isinstance(data, str): + return data + + if isinstance(data, abc.Mapping): + return {_to_string(key): _to_string(value) for key, value in data.items()} + + if isinstance(data, abc.Collection): + return [_to_string(i) for i in data] + + if isinstance(data, (numbers.Number, bool)): + return str(data) + + return data + + class Registry(models.CraftBaseModel): """Access and data definitions for a specific facet of a snap or system.""" @@ -52,7 +83,6 @@ class EditableRegistryAssertion(models.CraftBaseModel): """Issuer of the registry assertion and owner of the signing key.""" name: str - summary: str | None = None revision: int | None = 0 views: dict[str, Rules] @@ -61,6 +91,10 @@ class EditableRegistryAssertion(models.CraftBaseModel): body: str | None = None """A JSON schema that defines the storage structure.""" + def marshal_scalars_as_strings(self) -> dict[str, Any]: + """Marshal the model where all scalars are represented as strings.""" + return cast_dict_scalars_to_strings(self.marshal()) + class RegistryAssertion(EditableRegistryAssertion): """A full registries assertion containing editable and non-editable fields.""" diff --git a/snapcraft/services/assertions.py b/snapcraft/services/assertions.py index b3650b7883..c5aa03e876 100644 --- a/snapcraft/services/assertions.py +++ b/snapcraft/services/assertions.py @@ -33,6 +33,7 @@ from craft_application.errors import CraftValidationError from craft_application.services import base from craft_application.util import safe_yaml_load +from craft_store.errors import StoreServerError from typing_extensions import override from snapcraft import const, errors, models, store, utils @@ -68,6 +69,24 @@ def _get_assertions(self, name: str | None = None) -> list[models.Assertion]: :returns: A list of assertions. """ + @abc.abstractmethod + def _build_assertion(self, assertion: models.EditableAssertion) -> models.Assertion: + """Build an assertion from an editable assertion. + + :param assertion: The editable assertion to build. + + :returns: The built assertion. + """ + + @abc.abstractmethod + def _post_assertion(self, assertion_data: bytes) -> models.Assertion: + """Post an assertion to the store. + + :param assertion_data: A signed assertion represented as bytes. + + :returns: The published assertion. + """ + @abc.abstractmethod def _normalize_assertions( self, assertions: list[models.Assertion] @@ -102,6 +121,15 @@ def _generate_yaml_from_template(self, name: str, account_id: str) -> str: :returns: A multi-line yaml string. """ + @abc.abstractmethod + def _get_success_message(self, assertion: models.Assertion) -> str: + """Create a message after an assertion has been successfully posted. + + :param assertion: The published assertion. + + :returns: The success message to log. + """ + def list_assertions(self, *, output_format: str, name: str | None = None) -> None: """List assertions from the store. @@ -150,6 +178,7 @@ def _edit_yaml_file(self, filepath: pathlib.Path) -> models.EditableAssertion: :returns: The edited assertion. """ + craft_cli.emit.progress(f"Editing {self._assertion_name}.") while True: craft_cli.emit.debug(f"Using {self._editor_cmd} to edit file.") with craft_cli.emit.pause(): @@ -161,8 +190,9 @@ def _edit_yaml_file(self, filepath: pathlib.Path) -> models.EditableAssertion: data=data, # filepath is only shown for pydantic errors and snapcraft should # not expose the temp file name - filepath=pathlib.Path(self._assertion_name.replace(" ", "-")), + filepath=pathlib.Path(self._assertion_name), ) + craft_cli.emit.progress(f"Edited {self._assertion_name}.") return edited_assertion except (yaml.YAMLError, CraftValidationError) as err: craft_cli.emit.message(f"{err!s}") @@ -178,12 +208,12 @@ def _get_yaml_data(self, name: str, account_id: str) -> str: if assertions := self._get_assertions(name=name): yaml_data = self._generate_yaml_from_model(assertions[0]) + craft_cli.emit.progress( + f"Retrieved {self._assertion_name} '{name}' from the store.", + ) else: craft_cli.emit.progress( - f"Creating a new {self._assertion_name} because no existing " - f"{self._assertion_name} named '{name}' was found for the " - "authenticated account.", - permanent=True, + f"Could not find an existing {self._assertion_name} named '{name}'.", ) yaml_data = self._generate_yaml_from_template( name=name, account_id=account_id @@ -204,30 +234,83 @@ def _remove_temp_file(filepath: pathlib.Path) -> None: craft_cli.emit.trace(f"Removing temporary file '{filepath}'.") filepath.unlink() - def edit_assertion(self, *, name: str, account_id: str) -> None: + @staticmethod + def _sign_assertion(assertion: models.Assertion, key_name: str | None) -> bytes: + """Sign an assertion with `snap sign`. + + :param assertion: The assertion to sign. + :param key_name: Name of the key to sign the assertion. + + :returns: A signed assertion represented as bytes. + """ + craft_cli.emit.progress("Signing assertion.") + cmdline = ["snap", "sign"] + if key_name: + cmdline += ["-k", key_name] + + # snapd expects a json string where all scalars are strings + unsigned_assertion = json.dumps(assertion.marshal_scalars_as_strings()) + + try: + # pause the emitter for passphrase prompts + with craft_cli.emit.pause(): + signed_assertion = subprocess.check_output( + cmdline, input=unsigned_assertion.encode() + ) + except subprocess.CalledProcessError as sign_error: + raise errors.SnapcraftAssertionError( + "Failed to sign assertion" + ) from sign_error + + craft_cli.emit.progress("Signed assertion.") + craft_cli.emit.trace(f"Signed assertion: {signed_assertion.decode()}") + return signed_assertion + + def edit_assertion( + self, *, name: str, account_id: str, key_name: str | None = None + ) -> None: """Edit, sign and upload an assertion. If the assertion does not exist, a new assertion is created from a template. :param name: The name of the assertion to edit. :param account_id: The account ID associated with the registries set. + :param key_name: Name of the key to sign the assertion. """ yaml_data = self._get_yaml_data(name=name, account_id=account_id) yaml_file = self._write_to_file(yaml_data) original_assertion = self._editable_assertion_class.unmarshal( safe_yaml_load(io.StringIO(yaml_data)) ) - edited_assertion = self._edit_yaml_file(yaml_file) - if edited_assertion == original_assertion: - craft_cli.emit.message("No changes made.") + try: + while True: + try: + edited_assertion = self._edit_yaml_file(yaml_file) + if edited_assertion == original_assertion: + craft_cli.emit.message("No changes made.") + break + + craft_cli.emit.progress(f"Building {self._assertion_name}.") + built_assertion = self._build_assertion(edited_assertion) + craft_cli.emit.progress(f"Built {self._assertion_name}.") + + signed_assertion = self._sign_assertion(built_assertion, key_name) + published_assertion = self._post_assertion(signed_assertion) + craft_cli.emit.message( + self._get_success_message(published_assertion) + ) + break + except ( + StoreServerError, + errors.SnapcraftAssertionError, + ) as assertion_error: + craft_cli.emit.message(str(assertion_error)) + if not utils.confirm_with_user( + f"Do you wish to amend the {self._assertion_name}?" + ): + raise errors.SnapcraftError( + "operation aborted" + ) from assertion_error + finally: self._remove_temp_file(yaml_file) - return - - # TODO: build, sign, and push assertion (#5018) - - self._remove_temp_file(yaml_file) - craft_cli.emit.message(f"Successfully edited {self._assertion_name} {name!r}.") - raise errors.FeatureNotImplemented( - f"Building, signing and uploading {self._assertion_name} is not implemented.", - ) diff --git a/snapcraft/services/registries.py b/snapcraft/services/registries.py index e6cd785c16..ce4de0c18f 100644 --- a/snapcraft/services/registries.py +++ b/snapcraft/services/registries.py @@ -31,7 +31,6 @@ """\ account-id: {account_id} name: {set_name} - # summary: {summary} # The revision for this registries set # revision: {revision} {views} @@ -85,6 +84,14 @@ def _editable_assertion_class(self) -> type[models.EditableAssertion]: def _get_assertions(self, name: str | None = None) -> list[models.Assertion]: return self._store_client.list_registries(name=name) + @override + def _build_assertion(self, assertion: models.EditableAssertion) -> models.Assertion: + return self._store_client.build_registries(registries=assertion) + + @override + def _post_assertion(self, assertion_data: bytes) -> models.Assertion: + return self._store_client.post_registries(registries_data=assertion_data) + @override def _normalize_assertions( self, assertions: list[models.Assertion] @@ -110,7 +117,6 @@ def _generate_yaml_from_model(self, assertion: models.Assertion) -> str: {"views": assertion.marshal().get("views")}, default_flow_style=False ), body=dump_yaml({"body": assertion.body}, default_flow_style=False), - summary=assertion.summary, set_name=assertion.name, revision=assertion.revision, ) @@ -121,7 +127,10 @@ def _generate_yaml_from_template(self, name: str, account_id: str) -> str: account_id=account_id, views=_REGISTRY_SETS_VIEWS_TEMPLATE, body=_REGISTRY_SETS_BODY_TEMPLATE, - summary="A brief summary of the registries set", set_name=name, revision=1, ) + + @override + def _get_success_message(self, assertion: models.Assertion) -> str: + return f"Successfully created revision {assertion.revision!r} for {assertion.name!r}." diff --git a/snapcraft/store/client.py b/snapcraft/store/client.py index dbde3a4daf..6170c5d8e2 100644 --- a/snapcraft/store/client.py +++ b/snapcraft/store/client.py @@ -23,7 +23,9 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, cast import craft_store +import pydantic import requests +from craft_application.util.error_formatting import format_pydantic_errors from craft_cli import emit from overrides import overrides @@ -495,6 +497,22 @@ def list_revisions(self, snap_name: str) -> Revisions: return Revisions.unmarshal(response.json()) + @staticmethod + def _unmarshal_registries_set(registries_data) -> models.RegistryAssertion: + """Unmarshal a registries set. + + :raises StoreAssertionError: If the registries set cannot be unmarshalled. + """ + try: + return models.RegistryAssertion.unmarshal(registries_data) + except pydantic.ValidationError as err: + raise errors.SnapcraftAssertionError( + message="Received invalid registries set from the store", + # this is an unexpected failure that the user can't fix, so hide + # the response in the details + details=f"{format_pydantic_errors(err.errors(), file_name='registries set')}", + ) from err + def list_registries( self, *, name: str | None = None ) -> list[models.RegistryAssertion]: @@ -518,16 +536,75 @@ def list_registries( registry_assertions = [] if assertions := response.json().get("assertions"): for assertion_data in assertions: - emit.debug(f"Parsing assertion: {assertion_data}") # move body into model - assertion_data["headers"]["body"] = assertion_data["body"] - assertion = models.RegistryAssertion.unmarshal( - assertion_data["headers"] - ) + assertion_data["headers"]["body"] = assertion_data.get("body") + + assertion = self._unmarshal_registries_set(assertion_data["headers"]) registry_assertions.append(assertion) + emit.debug(f"Parsed registries set: {assertion.model_dump_json()}") return registry_assertions + def build_registries( + self, *, registries: models.EditableRegistryAssertion + ) -> models.RegistryAssertion: + """Build a registries set. + + Sends an edited registries set to the store, which validates the data, + populates additional fields, and returns the registries set. + + :param registries: The registries set to build. + + :returns: The built registries set. + """ + response = self.request( + "POST", + f"{self._base_url}/api/v2/registries/build-assertion", + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + json=registries.marshal(), + ) + + assertion = self._unmarshal_registries_set(response.json()) + emit.debug(f"Built registries set: {assertion.model_dump_json()}") + return assertion + + def post_registries(self, *, registries_data: bytes) -> models.RegistryAssertion: + """Send a registries set to be published. + + :param registries_data: A signed registries set represented as bytes. + + :returns: The published assertion. + """ + response = self.request( + "POST", + f"{self._base_url}/api/v2/registries", + headers={ + "Accept": "application/json", + "Content-Type": "application/x.ubuntu.assertion", + }, + data=registries_data, + ) + + assertions = response.json().get("assertions") + + if not assertions or len(assertions) != 1: + raise errors.SnapcraftAssertionError( + message="Received invalid registries set from the store", + # this is an unexpected failure that the user can't fix, so hide + # the response in the details + details=f"Received data: {assertions}", + ) + + # move body into model + assertions[0]["headers"]["body"] = assertions[0]["body"] + + assertion = self._unmarshal_registries_set(assertions[0]["headers"]) + emit.debug(f"Published registries set: {assertion.model_dump_json()}") + return assertion + class OnPremStoreClientCLI(LegacyStoreClientCLI): """On Premises Store Client command line interface.""" diff --git a/tests/unit/commands/test_registries.py b/tests/unit/commands/test_registries.py index c4f34dac92..f206fe3866 100644 --- a/tests/unit/commands/test_registries.py +++ b/tests/unit/commands/test_registries.py @@ -53,7 +53,6 @@ def test_list_registries(mocker, mock_list_assertions, output_format, name): @pytest.mark.parametrize("name", [None, "test"]) def test_list_registries_default_format(mocker, mock_list_assertions, name): """Default format is 'table'.""" - """Test `snapcraft list-registries`.""" cmd = ["snapcraft", "list-registries"] if name: cmd.extend(["--name", name]) @@ -65,15 +64,18 @@ def test_list_registries_default_format(mocker, mock_list_assertions, name): mock_list_assertions.assert_called_once_with(name=name, output_format="table") +@pytest.mark.parametrize("key_name", [None, "test-key"]) @pytest.mark.usefixtures("memory_keyring") -def test_edit_registries(mocker, mock_edit_assertion): +def test_edit_registries(key_name, mocker, mock_edit_assertion): """Test `snapcraft edit-registries`.""" cmd = ["snapcraft", "edit-registries", "test-account-id", "test-name"] + if key_name: + cmd.extend(["--key-name", key_name]) mocker.patch.object(sys, "argv", cmd) app = application.create_app() app.run() mock_edit_assertion.assert_called_once_with( - name="test-name", account_id="test-account-id" + name="test-name", account_id="test-account-id", key_name=key_name ) diff --git a/tests/unit/models/test_assertions.py b/tests/unit/models/test_assertions.py index a6fed23d95..16c6552347 100644 --- a/tests/unit/models/test_assertions.py +++ b/tests/unit/models/test_assertions.py @@ -17,7 +17,42 @@ """Tests for Assertion models.""" +import pytest + from snapcraft.models import EditableRegistryAssertion, Registry, RegistryAssertion +from snapcraft.models.assertions import cast_dict_scalars_to_strings + + +@pytest.mark.parametrize( + ("input_dict", "expected_dict"), + [ + pytest.param({}, {}, id="empty"), + pytest.param( + {False: False, True: True}, + {"False": "False", "True": "True"}, + id="boolean values", + ), + pytest.param( + {0: 0, None: None, "dict": {}, "list": [], "str": ""}, + ({"0": "0", None: None, "dict": {}, "list": [], "str": ""}), + id="none-like values", + ), + pytest.param( + {10: 10, 20.0: 20.0, "30": "30", True: True}, + {"10": "10", "20.0": "20.0", "30": "30", "True": "True"}, + id="scalar values", + ), + pytest.param( + {"foo": {"bar": [1, 2.0], "baz": {"qux": True}}}, + {"foo": {"bar": ["1", "2.0"], "baz": {"qux": "True"}}}, + id="nested data structures", + ), + ], +) +def test_cast_dict_scalars_to_strings(input_dict, expected_dict): + actual = cast_dict_scalars_to_strings(input_dict) + + assert actual == expected_dict def test_registry_defaults(check): @@ -73,11 +108,34 @@ def test_editable_registry_assertion_defaults(check): } ) - check.is_none(assertion.summary) check.equal(assertion.revision, 0) check.is_none(assertion.body) +def test_editable_registry_assertion_marshal_as_str(): + """Cast all scalars to string when marshalling.""" + assertion = EditableRegistryAssertion.unmarshal( + { + "account_id": "test-account-id", + "name": "test-registry", + "revision": 10, + "views": { + "wifi-setup": { + "rules": [ + { + "storage": "wifi.ssids", + } + ] + } + }, + } + ) + + assertion_dict = assertion.marshal_scalars_as_strings() + + assert assertion_dict["revision"] == "10" + + def test_registry_assertion_defaults(check): """Test default values of the RegistryAssertion model.""" assertion = RegistryAssertion.unmarshal( @@ -104,5 +162,31 @@ def test_registry_assertion_defaults(check): check.is_none(assertion.body) check.is_none(assertion.body_length) check.is_none(assertion.sign_key_sha3_384) - check.is_none(assertion.summary) check.equal(assertion.revision, 0) + + +def test_registry_assertion_marshal_as_str(): + """Cast all scalars to strings when marshalling.""" + assertion = RegistryAssertion.unmarshal( + { + "account_id": "test-account-id", + "authority_id": "test-authority-id", + "name": "test-registry", + "revision": 10, + "timestamp": "2024-01-01T10:20:30Z", + "type": "registry", + "views": { + "wifi-setup": { + "rules": [ + { + "storage": "wifi.ssids", + } + ] + } + }, + } + ) + + assertion_dict = assertion.marshal_scalars_as_strings() + + assert assertion_dict["revision"] == "10" diff --git a/tests/unit/services/test_assertions.py b/tests/unit/services/test_assertions.py index 73b15bc92e..56da17d207 100644 --- a/tests/unit/services/test_assertions.py +++ b/tests/unit/services/test_assertions.py @@ -16,15 +16,19 @@ """Tests for the abstract assertions service.""" +import json +import tempfile import textwrap from typing import Any from unittest import mock +import craft_store.errors import pytest from craft_application.models import CraftBaseModel from typing_extensions import override from snapcraft import const, errors +from tests.unit.store.utils import FakeResponse @pytest.fixture(autouse=True) @@ -48,8 +52,8 @@ def mock_confirm_with_user(mocker, request): @pytest.fixture -def mock_subprocess_run(mocker, tmp_path, request): - """Mock the subprocess.run function to write data to a file. +def write_text(mocker, tmp_path, request): + """Mock the subprocess.run function to write fake data to a temp assertion file. :param request: A list of strings to write to a file. Each time the subprocess.run function is called, the last string in the list will be written to the file @@ -65,12 +69,49 @@ def side_effect(*args, **kwargs): return subprocess_mock +@pytest.fixture +def fake_sign_assertion(mocker): + def _fake_sign(cmdline, input): # noqa: A002 (builtin-argument-shadowing) + return input + b"-signed" + + mock_subprocess = mocker.patch("subprocess.check_output") + mock_subprocess.side_effect = _fake_sign + return mock_subprocess + + +@pytest.fixture(autouse=True) +def mock_named_temporary_file(mocker, tmp_path): + _mock_tempfile = mocker.patch( + "tempfile.NamedTemporaryFile", spec=tempfile.NamedTemporaryFile + ) + _mock_tempfile.return_value.__enter__.return_value.name = str( + tmp_path / "assertion-file" + ) + yield _mock_tempfile.return_value + + +FAKE_STORE_ERROR = craft_store.errors.StoreServerError( + response=FakeResponse( + content=json.dumps( + {"error_list": [{"code": "bad assertion", "message": "bad assertion"}]} + ), + status_code=400, + ) +) + + class FakeAssertion(CraftBaseModel): """Fake assertion model.""" test_field_1: str test_field_2: int + def marshal_scalars_as_strings(self): + return { + "test_field_1": self.test_field_1, + "test_field_2": str(self.test_field_2), + } + @pytest.fixture def fake_assertion_service(default_factory): @@ -99,6 +140,21 @@ def _get_assertions( # type: ignore[override] FakeAssertion(test_field_1="test-value-2", test_field_2=100), ] + @override + def _build_assertion( # type: ignore[override] + self, assertion: FakeAssertion + ) -> FakeAssertion: + assertion.test_field_1 = assertion.test_field_1 + "-built" + return assertion + + @override + def _post_assertion( # type: ignore[override] + self, assertion_data: bytes + ) -> FakeAssertion: + return FakeAssertion( + test_field_1="test-published-assertion", test_field_2=0 + ) + @override def _normalize_assertions( # type: ignore[override] self, assertions: list[FakeAssertion] @@ -133,19 +189,13 @@ def _generate_yaml_from_template(self, name: str, account_id: str) -> str: """ ) - return FakeAssertionService(app=APP_METADATA, services=default_factory) - + @override + def _get_success_message( # type: ignore[override] + self, assertion: FakeAssertion + ) -> str: + return "Success." -@pytest.fixture -def fake_edit_yaml_file(mocker, fake_assertion_service): - """Apply a fake edit to a yaml file.""" - return mocker.patch.object( - fake_assertion_service, - "_edit_yaml_file", - return_value=FakeAssertion( - test_field_1="test-value-1-UPDATED", test_field_2=999 - ), - ) + return FakeAssertionService(app=APP_METADATA, services=default_factory) def test_list_assertions_table(fake_assertion_service, emitter): @@ -199,62 +249,213 @@ def test_list_assertions_unknown_format(fake_assertion_service): ) +@pytest.mark.parametrize( + "write_text", + [["test-field-1: test-value-1-edited\ntest-field-2: 999"]], + indirect=True, +) +@pytest.mark.usefixtures("fake_sign_assertion") def test_edit_assertions_changes_made( - fake_edit_yaml_file, fake_assertion_service, emitter + fake_assertion_service, + emitter, + mocker, + tmp_path, + write_text, ): """Edit an assertion and make a valid change.""" - expected = "Building, signing and uploading fake assertion is not implemented" - fake_assertion_service.setup() + expected_assertion = ( + b'{"test_field_1": "test-value-1-edited-built", "test_field_2": "999"}-signed' + ) + mock_post_assertion = mocker.spy(fake_assertion_service, "_post_assertion") - with pytest.raises(errors.FeatureNotImplemented, match=expected): - fake_assertion_service.edit_assertion( - name="test-registry", account_id="test-account-id" - ) + fake_assertion_service.setup() + fake_assertion_service.edit_assertion( + name="test-registry", account_id="test-account-id", key_name="test-key" + ) - emitter.assert_message("Successfully edited fake assertion 'test-registry'.") + mock_post_assertion.assert_called_once_with(expected_assertion) + emitter.assert_trace(f"Signed assertion: {expected_assertion.decode()}") + emitter.assert_message("Success.") +@pytest.mark.parametrize( + "write_text", + [["test-field-1: test-value-1\ntest-field-2: 0"]], + indirect=True, +) def test_edit_assertions_no_changes_made( - fake_edit_yaml_file, fake_assertion_service, emitter, mocker + fake_assertion_service, emitter, tmp_path, write_text ): """Edit an assertion but make no changes to the data.""" + fake_assertion_service.setup() + fake_assertion_service.edit_assertion( + name="test-registry", account_id="test-account-id" + ) + + emitter.assert_message("No changes made.") + assert not (tmp_path / "assertion-file").exists() + + +@pytest.mark.parametrize( + "write_text", + [ + [ + "test-field-1: test-value-1-edited-edited\ntest-field-2: 999", + "test-field-1: test-value-1-edited\ntest-field-2: 999", + ], + ], + indirect=True, +) +@pytest.mark.parametrize("mock_confirm_with_user", [True], indirect=True) +@pytest.mark.parametrize( + "error", [FAKE_STORE_ERROR, errors.SnapcraftAssertionError("bad assertion")] +) +@pytest.mark.usefixtures("fake_sign_assertion") +def test_edit_assertions_build_assertion_error( + error, + fake_assertion_service, + emitter, + mock_confirm_with_user, + write_text, + mocker, + tmp_path, +): + """Receive an error while building an assertion, then re-edit and post the assertion.""" + expected_assertion = b'{"test_field_1": "test-value-1-edited-edited-built", "test_field_2": "999"}-signed' + mock_post_assertion = mocker.spy(fake_assertion_service, "_post_assertion") mocker.patch.object( fake_assertion_service, - "_edit_yaml_file", - # make no changes to the fake assertion - return_value=FakeAssertion(test_field_1="test-value-1", test_field_2=0), + "_build_assertion", + side_effect=[ + error, + FakeAssertion( + test_field_1="test-value-1-edited-edited-built", test_field_2=999 + ), + ], ) + fake_assertion_service.setup() + fake_assertion_service.edit_assertion( + name="test-registry", account_id="test-account-id", key_name="test-key" + ) + assert mock_confirm_with_user.mock_calls == [ + mock.call("Do you wish to amend the fake assertion?") + ] + assert mock_post_assertion.mock_calls == [mock.call(expected_assertion)] + emitter.assert_trace(f"Signed assertion: {expected_assertion.decode()}") + emitter.assert_message("Success.") + assert not (tmp_path / "assertion-file").exists() + + +@pytest.mark.parametrize( + "write_text", + [ + [ + "test-field-1: test-value-1-edited-edited\ntest-field-2: 999", + "test-field-1: test-value-1-edited\ntest-field-2: 999", + ], + ], + indirect=True, +) +@pytest.mark.parametrize("mock_confirm_with_user", [True], indirect=True) +@pytest.mark.usefixtures("fake_sign_assertion") +def test_edit_assertions_sign_assertion_error( + fake_assertion_service, + emitter, + mock_confirm_with_user, + write_text, + mocker, + tmp_path, +): + """Receive an error while signing an assertion, then re-edit and post the assertion.""" + expected_assertion = b'{"test_field_1": "test-value-1-edited-edited-built", "test_field_2": "999"}-signed' + mock_post_assertion = mocker.spy(fake_assertion_service, "_post_assertion") + mocker.patch.object( + fake_assertion_service, + "_sign_assertion", + side_effect=[ + errors.SnapcraftAssertionError("bad assertion"), + expected_assertion, + ], + ) + + fake_assertion_service.setup() fake_assertion_service.edit_assertion( - name="test-registry", account_id="test-account-id" + name="test-registry", account_id="test-account-id", key_name="test-key" ) - emitter.assert_message("No changes made.") + assert mock_confirm_with_user.mock_calls == [ + mock.call("Do you wish to amend the fake assertion?") + ] + assert mock_post_assertion.mock_calls == [mock.call(expected_assertion)] + emitter.assert_message("Success.") + assert not (tmp_path / "assertion-file").exists() -@pytest.mark.parametrize("editor", [None, "faux-vi"]) @pytest.mark.parametrize( - "mock_subprocess_run", + "write_text", [ [ - textwrap.dedent( - """\ - test-field-1: test-value-1-UPDATED - test-field-2: 999 - """ - ), + "test-field-1: test-value-1-edited-edited\ntest-field-2: 999", + "test-field-1: test-value-1-edited\ntest-field-2: 999", ], ], indirect=True, ) @pytest.mark.parametrize("mock_confirm_with_user", [True], indirect=True) +@pytest.mark.parametrize( + "error", [FAKE_STORE_ERROR, errors.SnapcraftAssertionError("bad assertion")] +) +@pytest.mark.usefixtures("fake_sign_assertion") +def test_edit_assertions_post_assertion_error( + error, + fake_assertion_service, + emitter, + mock_confirm_with_user, + write_text, + mocker, + tmp_path, +): + """Receive an error while processing an assertion, then re-edit and post the assertion.""" + expected_first_assertion = ( + b'{"test_field_1": "test-value-1-edited-built", "test_field_2": "999"}-signed' + ) + expected_second_assertion = b'{"test_field_1": "test-value-1-edited-edited-built", "test_field_2": "999"}-signed' + mock_post_assertion = mocker.patch.object( + fake_assertion_service, "_post_assertion", side_effect=[error, None] + ) + + fake_assertion_service.setup() + fake_assertion_service.edit_assertion( + name="test-registry", account_id="test-account-id", key_name="test-key" + ) + + assert mock_confirm_with_user.mock_calls == [ + mock.call("Do you wish to amend the fake assertion?") + ] + assert mock_post_assertion.mock_calls == [ + mock.call(expected_first_assertion), + mock.call(expected_second_assertion), + ] + emitter.assert_trace(f"Signed assertion: {expected_second_assertion.decode()}") + emitter.assert_message("Success.") + assert not (tmp_path / "assertion-file").exists() + + +@pytest.mark.parametrize("editor", [None, "faux-vi"]) +@pytest.mark.parametrize( + "write_text", + [["test-field-1: test-value-1-edited\ntest-field-2: 999"]], + indirect=True, +) +@pytest.mark.parametrize("mock_confirm_with_user", [True], indirect=True) def test_edit_yaml_file( editor, fake_assertion_service, tmp_path, mock_confirm_with_user, - mock_subprocess_run, + write_text, monkeypatch, ): """Successfully edit a yaml file with the correct editor.""" @@ -271,50 +472,26 @@ def test_edit_yaml_file( edited_assertion = fake_assertion_service._edit_yaml_file(tmp_file) assert edited_assertion == FakeAssertion( - test_field_1="test-value-1-UPDATED", test_field_2=999 + test_field_1="test-value-1-edited", test_field_2=999 ) mock_confirm_with_user.assert_not_called() - assert mock_subprocess_run.mock_calls == [ - mock.call([expected_editor, tmp_file], check=True) - ] + assert write_text.mock_calls == [mock.call([expected_editor, tmp_file], check=True)] @pytest.mark.parametrize( - "mock_subprocess_run", + "write_text", [ pytest.param( [ - textwrap.dedent( - """\ - test-field-1: test-value-1-UPDATED - test-field-2: 999 - """ - ), - textwrap.dedent( - """\ - bad yaml {{ - test-field-1: test-value-1 - test-field-2: 0 - """ - ), + "test-field-1: test-value-1-edited\ntest-field-2: 999", + "bad yaml {{\ntest-field-1: test-value-1\ntest-field-2: 0", ], id="invalid yaml syntax", ), pytest.param( [ - textwrap.dedent( - """\ - test-field-1: test-value-1-UPDATED - test-field-2: 999 - """ - ), - textwrap.dedent( - """\ - extra-field: not-allowed - test-field-1: [wrong data type] - test-field-2: 0 - """ - ), + "test-field-1: test-value-1-edited\ntest-field-2: 999", + "extra-field: not-allowed\ntest-field-1: [wrong data type]\ntest-field-2: 0", ], id="invalid pydantic data", ), @@ -326,7 +503,7 @@ def test_edit_yaml_file_error_retry( fake_assertion_service, tmp_path, mock_confirm_with_user, - mock_subprocess_run, + write_text, ): """Edit a yaml file but encounter an error and retry.""" tmp_file = tmp_path / "assertion-file" @@ -335,30 +512,17 @@ def test_edit_yaml_file_error_retry( edited_assertion = fake_assertion_service._edit_yaml_file(tmp_file) assert edited_assertion == FakeAssertion( - test_field_1="test-value-1-UPDATED", test_field_2=999 + test_field_1="test-value-1-edited", test_field_2=999 ) assert mock_confirm_with_user.mock_calls == [ mock.call("Do you wish to amend the fake assertion?") ] - assert ( - mock_subprocess_run.mock_calls - == [mock.call(["faux-vi", tmp_file], check=True)] * 2 - ) + assert write_text.mock_calls == [mock.call(["faux-vi", tmp_file], check=True)] * 2 @pytest.mark.parametrize( - "mock_subprocess_run", - [ - [ - textwrap.dedent( - """\ - bad yaml {{ - test-field-1: test-value-1 - test-field-2: 0 - """ - ), - ], - ], + "write_text", + [["bad yaml {{\ntest-field-1: test-value-1\ntest-field-2: 0"]], indirect=True, ) @pytest.mark.parametrize("mock_confirm_with_user", [False], indirect=True) @@ -366,7 +530,7 @@ def test_edit_error_no_retry( fake_assertion_service, tmp_path, mock_confirm_with_user, - mock_subprocess_run, + write_text, ): """Edit a yaml file and encounter an error but do not retry.""" tmp_file = tmp_path / "assertion-file" @@ -378,6 +542,4 @@ def test_edit_error_no_retry( assert mock_confirm_with_user.mock_calls == [ mock.call("Do you wish to amend the fake assertion?") ] - assert mock_subprocess_run.mock_calls == [ - mock.call(["faux-vi", tmp_file], check=True) - ] + assert write_text.mock_calls == [mock.call(["faux-vi", tmp_file], check=True)] diff --git a/tests/unit/services/test_registries.py b/tests/unit/services/test_registries.py index f791cceb9a..2305f13413 100644 --- a/tests/unit/services/test_registries.py +++ b/tests/unit/services/test_registries.py @@ -17,8 +17,9 @@ """Tests for the registries service.""" import textwrap +from unittest import mock -from snapcraft.models import EditableRegistryAssertion +from snapcraft.models import EditableRegistryAssertion, RegistryAssertion def test_registries_service_type(registries_service): @@ -37,6 +38,24 @@ def test_get_assertions(registries_service): ) +def test_build_assertion(registries_service): + mock_assertion = mock.Mock(spec=RegistryAssertion) + + registries_service._build_assertion(mock_assertion) + + registries_service._store_client.build_registries.assert_called_once_with( + registries=mock_assertion + ) + + +def test_post_assertions(registries_service): + registries_service._post_assertion(b"test-assertion-data") + + registries_service._store_client.post_registries.assert_called_once_with( + registries_data=b"test-assertion-data" + ) + + def test_normalize_assertions_empty(registries_service, check): headers, registries = registries_service._normalize_assertions([]) @@ -71,7 +90,6 @@ def test_normalize_assertions(fake_registry_assertion, registries_service, check def test_generate_yaml_from_model(fake_registry_assertion, registries_service): assertion = fake_registry_assertion( - summary="test-summary", revision="10", views={ "wifi-setup": { @@ -102,7 +120,6 @@ def test_generate_yaml_from_model(fake_registry_assertion, registries_service): """\ account-id: test-account-id name: test-registry - # summary: test-summary # The revision for this registries set # revision: 10 views: @@ -129,3 +146,11 @@ def test_generate_yaml_from_model(fake_registry_assertion, registries_service): """ ) + + +def test_get_success_message(fake_registry_assertion, registries_service): + message = registries_service._get_success_message( + fake_registry_assertion(revision=10) + ) + + assert message == "Successfully created revision 10 for 'test-registry'." diff --git a/tests/unit/store/test_client.py b/tests/unit/store/test_client.py index fd3bf73921..6fc14ae9f0 100644 --- a/tests/unit/store/test_client.py +++ b/tests/unit/store/test_client.py @@ -17,7 +17,7 @@ import json import textwrap import time -from unittest.mock import ANY, call +from unittest.mock import ANY, Mock, call import craft_store import pytest @@ -180,7 +180,7 @@ def list_registries_payload(): "account-id": "test-account-id", "authority-id": "test-authority-id", "body-length": "92", - "name": "test-registry", + "name": "test-registries", "revision": "9", "sign-key-sha3-384": "test-sign-key", "timestamp": "2024-01-01T10:20:30Z", @@ -203,6 +203,62 @@ def list_registries_payload(): } +@pytest.fixture +def build_registries_payload(): + return { + "account_id": "test-account-id", + "authority_id": "test-authority-id", + "name": "test-registries", + "revision": "10", + "views": { + "wifi-setup": { + "rules": [ + { + "request": "ssids", + "storage": "wifi.ssids", + "access": "read-write", + } + ] + } + }, + "body": '{\n "storage": {\n "schema": {\n "wifi": {\n "values": "any"\n }\n }\n }\n}', + "type": "registry", + "timestamp": "2024-01-01T10:20:30Z", + } + + +@pytest.fixture +def post_registries_payload(): + return { + "assertions": [ + { + "headers": { + "account-id": "test-account-id", + "authority-id": "test-authority-id", + "body-length": "92", + "name": "test-registries", + "revision": "10", + "sign-key-sha3-384": "test-key", + "timestamp": "2024-01-01T10:20:30Z", + "type": "registry", + "views": { + "wifi-setup": { + "rules": [ + { + "access": "read", + "request": "ssids", + "storage": "wifi.ssids", + } + ] + } + }, + }, + "body": '{\n "storage": {\n "schema": {\n "wifi": {\n "values": "any"\n }\n }\n }\n}', + } + ] + } + + #################### # User Agent Tests # #################### @@ -1067,7 +1123,7 @@ def test_list_revisions(fake_client, list_revisions_payload): @pytest.mark.parametrize("name", [None, "test-registry"]) def test_list_registries(name, fake_client, list_registries_payload, check): - """Test the registries endpoint.""" + """Test the list registries endpoint.""" fake_client.request.return_value = FakeResponse( status_code=200, content=json.dumps(list_registries_payload).encode() ) @@ -1098,7 +1154,7 @@ def test_list_registries(name, fake_client, list_registries_payload, check): def test_list_registries_empty(fake_client, check): - """Test the registries endpoint with no registries returned.""" + """Test the list registries endpoint with no registries returned.""" fake_client.request.return_value = FakeResponse( status_code=200, content=json.dumps({"assertions": []}).encode() ) @@ -1121,6 +1177,140 @@ def test_list_registries_empty(fake_client, check): ) +def test_list_registries_unmarshal_error(fake_client, list_registries_payload): + """Raise an error if the response cannot be unmarshalled.""" + list_registries_payload["assertions"][0]["headers"].pop("name") + fake_client.request.return_value = FakeResponse( + status_code=200, content=json.dumps(list_registries_payload).encode() + ) + + with pytest.raises(errors.SnapcraftAssertionError) as raised: + client.StoreClientCLI().list_registries() + + assert str(raised.value) == "Received invalid registries set from the store" + assert raised.value.details == ( + "Bad registries set content:\n" + "- field 'name' required in top-level configuration" + ) + + +#################### +# Build Registries # +#################### + + +def test_build_registries(fake_client, build_registries_payload): + """Test the build registries endpoint.""" + mock_registries = Mock(spec=models.RegistryAssertion) + expected_registries = models.RegistryAssertion(**build_registries_payload) + fake_client.request.return_value = FakeResponse( + status_code=200, content=json.dumps(build_registries_payload).encode() + ) + + registries_set = client.StoreClientCLI().build_registries( + registries=mock_registries + ) + + assert registries_set == expected_registries + assert fake_client.request.mock_calls == [ + call( + "POST", + "https://dashboard.snapcraft.io/api/v2/registries/build-assertion", + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + json=mock_registries.marshal(), + ) + ] + + +def test_build_registries_unmarshal_error(fake_client, build_registries_payload): + """Raise an error if the response cannot be unmarshalled.""" + mock_registries = Mock(spec=models.RegistryAssertion) + build_registries_payload.pop("name") + fake_client.request.return_value = FakeResponse( + status_code=200, content=json.dumps(build_registries_payload).encode() + ) + + with pytest.raises(errors.SnapcraftAssertionError) as raised: + client.StoreClientCLI().build_registries(registries=mock_registries) + + assert str(raised.value) == "Received invalid registries set from the store" + assert raised.value.details == ( + "Bad registries set content:\n" + "- field 'name' required in top-level configuration" + ) + + +################### +# Post Registries # +################### + + +def test_post_registries(fake_client, post_registries_payload): + """Test the post registries endpoint.""" + expected_registries = models.RegistryAssertion( + **post_registries_payload["assertions"][0]["headers"], + body=post_registries_payload["assertions"][0]["body"], + ) + fake_client.request.return_value = FakeResponse( + status_code=200, content=json.dumps(post_registries_payload).encode() + ) + + registries_set = client.StoreClientCLI().post_registries( + registries_data=b"test-data" + ) + + assert registries_set == expected_registries + assert fake_client.request.mock_calls == [ + call( + "POST", + "https://dashboard.snapcraft.io/api/v2/registries", + headers={ + "Accept": "application/json", + "Content-Type": "application/x.ubuntu.assertion", + }, + data=b"test-data", + ) + ] + + +@pytest.mark.parametrize("num_assertions", [0, 2]) +def test_post_registries_wrong_payload_error( + num_assertions, fake_client, post_registries_payload +): + """Error if the wrong number of assertions are returned.""" + post_registries_payload["assertions"] = ( + post_registries_payload["assertions"] * num_assertions + ) + fake_client.request.return_value = FakeResponse( + status_code=200, content=json.dumps(post_registries_payload).encode() + ) + + with pytest.raises(errors.SnapcraftAssertionError) as raised: + client.StoreClientCLI().post_registries(registries_data=b"test-data") + + assert str(raised.value) == "Received invalid registries set from the store" + + +def test_post_registries_unmarshal_error(fake_client, post_registries_payload): + """Raise an error if the response cannot be unmarshalled.""" + post_registries_payload["assertions"][0]["headers"].pop("name") + fake_client.request.return_value = FakeResponse( + status_code=200, content=json.dumps(post_registries_payload).encode() + ) + + with pytest.raises(errors.SnapcraftAssertionError) as raised: + client.StoreClientCLI().post_registries(registries_data=b"test-data") + + assert str(raised.value) == "Received invalid registries set from the store" + assert raised.value.details == ( + "Bad registries set content:\n" + "- field 'name' required in top-level configuration" + ) + + ######################## # OnPremStoreClientCLI # ########################