Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic contract mocking #61

Merged
merged 26 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/rust-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ jobs:
- name: Run tests for examples
shell: bash
run: |
# todo: use loop xD
deuszx marked this conversation as resolved.
Show resolved Hide resolved

pushd examples/flipper
cargo contract build --release
cargo test --release
Expand All @@ -73,3 +75,8 @@ jobs:
cargo contract build --release
cargo test --release
popd

pushd examples/mocking
cargo contract build --release
cargo test --release
popd
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ exclude = [
"examples/counter",
"examples/flipper",
"examples/cross-contract-call-tracing",
"examples/mocking",
]

[workspace.package]
Expand All @@ -19,7 +20,7 @@ homepage = "https://github.com/Cardinal-Cryptography/drink"
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/Cardinal-Cryptography/drink"
version = "0.4.1"
version = "0.5.0"

[workspace.dependencies]
anyhow = { version = "1.0.71" }
Expand All @@ -40,7 +41,7 @@ frame-metadata = { version = "16.0.0" }
frame-support = { version = "23.0.0" }
frame-system = { version = "23.0.0" }
pallet-balances = { version = "23.0.0" }
pallet-contracts = { package = "pallet-contracts-for-drink", version = "22.0.0" }
pallet-contracts = { package = "pallet-contracts-for-drink", version = "22.0.1" }
pallet-contracts-primitives = { version = "26.0.0" }
pallet-timestamp = { version = "22.0.0" }
sp-core = { version = "23.0.0" }
Expand All @@ -50,4 +51,4 @@ sp-runtime-interface = { version = "19.0.0" }

# Local dependencies

drink = { version = "0.4.1", path = "drink" }
drink = { version = "0.5.0", path = "drink" }
2 changes: 0 additions & 2 deletions drink/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ sp-runtime-interface = { workspace = true }

scale-info = { workspace = true }
thiserror = { workspace = true }

[dev-dependencies]
wat = { workspace = true }

[features]
Expand Down
15 changes: 0 additions & 15 deletions drink/src/error.rs

This file was deleted.

43 changes: 43 additions & 0 deletions drink/src/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//! Module gathering common error and result types.

use thiserror::Error;

/// Main error type for the drink crate.
#[derive(Error, Debug)]
pub enum Error {
/// Externalities could not be initialized.
#[error("Failed to build storage: {0}")]
StorageBuilding(String),
/// Block couldn't have been initialized.
#[error("Failed to initialize block: {0}")]
BlockInitialize(String),
/// Block couldn't have been finalized.
#[error("Failed to finalize block: {0}")]
BlockFinalize(String),
}

/// Every contract message wraps its return value in `Result<T, LangResult>`. This is the error
/// type.
///
/// Copied from ink primitives.
#[non_exhaustive]
#[repr(u32)]
#[derive(
Debug,
Copy,
Clone,
PartialEq,
Eq,
parity_scale_codec::Encode,
parity_scale_codec::Decode,
scale_info::TypeInfo,
Error,
)]
pub enum LangError {
/// Failed to read execution input for the dispatchable.
#[error("Failed to read execution input for the dispatchable.")]
CouldNotReadInput = 1u32,
}

/// The `Result` type for ink! messages.
pub type MessageResult<T> = Result<T, LangError>;
91 changes: 84 additions & 7 deletions drink/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,37 @@

pub mod chain_api;
pub mod contract_api;
mod error;
pub mod errors;
mod mock;
pub mod runtime;
#[cfg(feature = "session")]
pub mod session;
use std::marker::PhantomData;

pub use error::Error;
use std::{
marker::PhantomData,
sync::{Arc, Mutex},
};

pub use errors::Error;
use frame_support::sp_runtime::{traits::One, BuildStorage};
pub use frame_support::{
sp_runtime::{AccountId32, DispatchError},
weights::Weight,
};
use frame_system::{pallet_prelude::BlockNumberFor, EventRecord, GenesisConfig};
pub use mock::{mock_message, ContractMock, MessageMock, MockedCallResult, MockingApi, Selector};
use pallet_contracts::debug::ExecResult;
use pallet_contracts_primitives::{ExecReturnValue, ReturnFlags};
use parity_scale_codec::{Decode, Encode};
use sp_io::TestExternalities;

use crate::{
pallet_contracts_debugging::DebugExt,
runtime::{pallet_contracts_debugging::NoopDebugExt, *},
mock::MockRegistry,
pallet_contracts_debugging::{InterceptingExt, TracingExt},
runtime::{
pallet_contracts_debugging::{InterceptingExtT, NoopExt},
*,
},
};

/// Main result type for the drink crate.
Expand All @@ -35,6 +48,8 @@ pub type EventRecordOf<T> =
/// A sandboxed runtime.
pub struct Sandbox<R: Runtime> {
externalities: TestExternalities,
mock_registry: Arc<Mutex<MockRegistry<AccountIdFor<R>>>>,
deuszx marked this conversation as resolved.
Show resolved Hide resolved
mock_counter: usize,
_phantom: PhantomData<R>,
}

Expand All @@ -57,6 +72,8 @@ impl<R: Runtime> Sandbox<R> {

let mut sandbox = Self {
externalities: TestExternalities::new(storage),
mock_registry: Arc::new(Mutex::new(MockRegistry::new())),
mock_counter: 0,
_phantom: PhantomData,
};

Expand All @@ -68,7 +85,9 @@ impl<R: Runtime> Sandbox<R> {
.map_err(Error::BlockInitialize)?;

// We register a noop debug extension by default.
sandbox.override_debug_handle(DebugExt(Box::new(NoopDebugExt {})));
sandbox.override_debug_handle(TracingExt(Box::new(NoopExt {})));

sandbox.setup_mock_extension();

Ok(sandbox)
}
Expand All @@ -77,7 +96,65 @@ impl<R: Runtime> Sandbox<R> {
///
/// By default, a new `Sandbox` instance is created with a noop debug extension. This method
/// allows to override it with a custom debug extension.
pub fn override_debug_handle(&mut self, d: DebugExt) {
pub fn override_debug_handle(&mut self, d: TracingExt) {
self.externalities.register_extension(d);
}

/// Registers the extension for intercepting calls to contracts.
fn setup_mock_extension(&mut self) {
self.externalities
.register_extension(InterceptingExt(Box::new(MockingExtension {
mock_registry: Arc::clone(&self.mock_registry),
})));
}
}

/// Runtime extension enabling contract call interception.
struct MockingExtension<AccountId: Ord> {
/// Mock registry, shared with the sandbox.
///
/// Potentially the runtime is executed in parallel and thus we need to wrap the registry in
/// `Arc<Mutex>` instead of `Rc<RefCell>`.
mock_registry: Arc<Mutex<MockRegistry<AccountId>>>,
}

impl<AccountId: Ord + Decode> InterceptingExtT for MockingExtension<AccountId> {
fn intercept_call(
&self,
contract_address: Vec<u8>,
_is_call: bool,
input_data: Vec<u8>,
) -> Vec<u8> {
let contract_address = Decode::decode(&mut &contract_address[..])
deuszx marked this conversation as resolved.
Show resolved Hide resolved
.expect("Contract address should be decodable");

match self
.mock_registry
.lock()
.expect("Should be able to acquire registry")
.get(&contract_address)
{
// There is no mock registered for this address, so we return `None` to indicate that
// the call should be executed normally.
None => None::<()>.encode(),
// We intercept the call and return the result of the mock.
Some(mock) => {
let (selector, call_data) = input_data.split_at(4);
let selector: Selector = selector
.try_into()
.expect("Input data should contain at least selector bytes");

let result = mock
.call(selector, call_data.to_vec())
.expect("TODO: let the user define the fallback mechanism");
deuszx marked this conversation as resolved.
Show resolved Hide resolved

let result: ExecResult = Ok(ExecReturnValue {
flags: ReturnFlags::empty(),
deuszx marked this conversation as resolved.
Show resolved Hide resolved
data: result,
});

Some(result).encode()
}
}
}
}
36 changes: 36 additions & 0 deletions drink/src/mock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
mod contract;
mod error;
mod mocking_api;

use std::collections::BTreeMap;

pub use contract::{mock_message, ContractMock, MessageMock, Selector};
use error::MockingError;
pub use mocking_api::MockingApi;

/// Untyped result of a mocked call.
pub type MockedCallResult = Result<Vec<u8>, MockingError>;

/// A registry of mocked contracts.
pub(crate) struct MockRegistry<AccountId: Ord> {
mocked_contracts: BTreeMap<AccountId, ContractMock>,
}

impl<AccountId: Ord> MockRegistry<AccountId> {
/// Creates a new registry.
pub fn new() -> Self {
Self {
mocked_contracts: BTreeMap::new(),
}
}

/// Registers `mock` for `address`.
pub fn register(&mut self, address: AccountId, mock: ContractMock) {
self.mocked_contracts.insert(address, mock);
deuszx marked this conversation as resolved.
Show resolved Hide resolved
}

/// Returns the mock for `address`, if any.
pub fn get(&self, address: &AccountId) -> Option<&ContractMock> {
self.mocked_contracts.get(address)
}
}
64 changes: 64 additions & 0 deletions drink/src/mock/contract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::collections::BTreeMap;

use parity_scale_codec::{Decode, Encode};

use crate::{
errors::LangError,
mock::{error::MockingError, MockedCallResult},
};

/// Alias for a 4-byte selector.
pub type Selector = [u8; 4];
/// An untyped message mock.
///
/// Notice that in the end, we cannot operate on specific argument/return types. Rust won't let us
/// have a collection of differently typed closures. Fortunately, we can assume that all types are
/// en/decodable, so we can use `Vec<u8>` as a common denominator.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧠

pub type MessageMock = Box<dyn Fn(Vec<u8>) -> MockedCallResult + Send + Sync>;

/// A contract mock.
pub struct ContractMock {
messages: BTreeMap<Selector, MessageMock>,
}

impl ContractMock {
/// Creates a new mock without any message.
pub fn new() -> Self {
Self {
messages: BTreeMap::new(),
}
}

/// Adds a message mock.
pub fn with_message(mut self, selector: Selector, message: MessageMock) -> Self {
self.messages.insert(selector, message);
self
}

/// Try to call a message mock. Returns an error if there is no message mock for `selector`.
pub fn call(&self, selector: Selector, input: Vec<u8>) -> MockedCallResult {
match self.messages.get(&selector) {
None => Err(MockingError::MessageNotFound(selector)),
Some(message) => message(input),
}
}
}

impl Default for ContractMock {
fn default() -> Self {
Self::new()
}
}

/// A helper function to create a message mock out of a typed closure.
///
/// In particular, it takes care of decoding the input and encoding the output. Also, wraps the
/// return value in a `Result`, which is normally done implicitly by ink!.
pub fn mock_message<Args: Decode, Ret: Encode, Body: Fn(Args) -> Ret + Send + Sync + 'static>(
body: Body,
) -> MessageMock {
Box::new(move |encoded_input| {
let input = Decode::decode(&mut &*encoded_input).map_err(MockingError::ArgumentDecoding)?;
Ok(Ok::<Ret, LangError>(body(input)).encode())
})
}
Loading
Loading