diff --git a/prdoc/pr_5198.prdoc b/prdoc/pr_5198.prdoc new file mode 100644 index 000000000000..417b0b5a4fd9 --- /dev/null +++ b/prdoc/pr_5198.prdoc @@ -0,0 +1,13 @@ +title: "MQ processor should be transactional" + +doc: + - audience: [Runtime User, Runtime Dev] + description: | + Enforce transactional processing on pallet Message Queue Processor. + + Storage changes that were done while processing a message will now be rolled back + when the processing returns an error. `Ok(false)` will not revert, only `Err(_)`. + +crates: + - name: pallet-message-queue + bump: major \ No newline at end of file diff --git a/substrate/frame/message-queue/src/integration_test.rs b/substrate/frame/message-queue/src/integration_test.rs index 14b8d2217eb2..e4db87d8be7a 100644 --- a/substrate/frame/message-queue/src/integration_test.rs +++ b/substrate/frame/message-queue/src/integration_test.rs @@ -151,6 +151,7 @@ fn stress_test_recursive() { TotalEnqueued::set(TotalEnqueued::get() + enqueued); Enqueued::set(Enqueued::get() + enqueued); Called::set(Called::get() + 1); + Ok(()) })); build_and_execute::(|| { diff --git a/substrate/frame/message-queue/src/lib.rs b/substrate/frame/message-queue/src/lib.rs index 2dbffef7e5a2..48002acb1474 100644 --- a/substrate/frame/message-queue/src/lib.rs +++ b/substrate/frame/message-queue/src/lib.rs @@ -225,7 +225,7 @@ use sp_arithmetic::traits::{BaseArithmetic, Unsigned}; use sp_core::{defer, H256}; use sp_runtime::{ traits::{One, Zero}, - SaturatedConversion, Saturating, + SaturatedConversion, Saturating, TransactionOutcome, }; use sp_weights::WeightMeter; pub use weights::WeightInfo; @@ -1435,6 +1435,8 @@ impl Pallet { /// The base weight of this function needs to be accounted for by the caller. `weight` is the /// remaining weight to process the message. `overweight_limit` is the maximum weight that a /// message can ever consume. Messages above this limit are marked as permanently overweight. + /// This process is also transactional, any form of error that occurs in processing a message + /// causes storage changes to be rolled back. fn process_message_payload( origin: MessageOriginOf, page_index: PageIndex, @@ -1447,7 +1449,27 @@ impl Pallet { use ProcessMessageError::*; let prev_consumed = meter.consumed(); - match T::MessageProcessor::process_message(message, origin.clone(), meter, &mut id) { + let transaction = + storage::with_transaction(|| -> TransactionOutcome> { + let res = + T::MessageProcessor::process_message(message, origin.clone(), meter, &mut id); + match &res { + Ok(_) => TransactionOutcome::Commit(Ok(res)), + Err(_) => TransactionOutcome::Rollback(Ok(res)), + } + }); + + let transaction = match transaction { + Ok(result) => result, + _ => { + defensive!( + "Error occurred processing message, storage changes will be rolled back" + ); + return MessageExecutionStatus::Unprocessable { permanent: true } + }, + }; + + match transaction { Err(Overweight(w)) if w.any_gt(overweight_limit) => { // Permanently overweight. Self::deposit_event(Event::::OverweightEnqueued { diff --git a/substrate/frame/message-queue/src/mock.rs b/substrate/frame/message-queue/src/mock.rs index 26533cc7c330..d3f719c62356 100644 --- a/substrate/frame/message-queue/src/mock.rs +++ b/substrate/frame/message-queue/src/mock.rs @@ -184,8 +184,15 @@ impl ProcessMessage for RecordingMessageProcessor { if meter.try_consume(required).is_ok() { if let Some(p) = message.strip_prefix(&b"callback="[..]) { let s = String::from_utf8(p.to_vec()).expect("Need valid UTF8"); - Callback::get()(&origin, s.parse().expect("Expected an u32")); + if let Err(()) = Callback::get()(&origin, s.parse().expect("Expected an u32")) { + return Err(ProcessMessageError::Corrupt) + } + + if s.contains("000") { + return Ok(false) + } } + let mut m = MessagesProcessed::get(); m.push((message.to_vec(), origin)); MessagesProcessed::set(m); @@ -197,7 +204,7 @@ impl ProcessMessage for RecordingMessageProcessor { } parameter_types! { - pub static Callback: Box = Box::new(|_, _| {}); + pub static Callback: Box Result<(), ()>> = Box::new(|_, _| { Ok(()) }); pub static IgnoreStackOvError: bool = false; } @@ -252,7 +259,9 @@ impl ProcessMessage for CountingMessageProcessor { if meter.try_consume(required).is_ok() { if let Some(p) = message.strip_prefix(&b"callback="[..]) { let s = String::from_utf8(p.to_vec()).expect("Need valid UTF8"); - Callback::get()(&origin, s.parse().expect("Expected an u32")); + if let Err(()) = Callback::get()(&origin, s.parse().expect("Expected an u32")) { + return Err(ProcessMessageError::Corrupt) + } } NumMessagesProcessed::set(NumMessagesProcessed::get() + 1); Ok(true) diff --git a/substrate/frame/message-queue/src/tests.rs b/substrate/frame/message-queue/src/tests.rs index e89fdb8b3208..fac135f135ce 100644 --- a/substrate/frame/message-queue/src/tests.rs +++ b/substrate/frame/message-queue/src/tests.rs @@ -1675,6 +1675,7 @@ fn regression_issue_2319() { build_and_execute::(|| { Callback::set(Box::new(|_, _| { MessageQueue::enqueue_message(mock_helpers::msg("anothermessage"), There); + Ok(()) })); use MessageOrigin::*; @@ -1695,23 +1696,26 @@ fn regression_issue_2319() { #[test] fn recursive_enqueue_works() { build_and_execute::(|| { - Callback::set(Box::new(|o, i| match i { - 0 => { - MessageQueue::enqueue_message(msg(&format!("callback={}", 1)), *o); - }, - 1 => { - for _ in 0..100 { - MessageQueue::enqueue_message(msg(&format!("callback={}", 2)), *o); - } - for i in 0..100 { - MessageQueue::enqueue_message(msg(&format!("callback={}", 3)), i.into()); - } - }, - 2 | 3 => { - MessageQueue::enqueue_message(msg(&format!("callback={}", 4)), *o); - }, - 4 => (), - _ => unreachable!(), + Callback::set(Box::new(|o, i| { + match i { + 0 => { + MessageQueue::enqueue_message(msg(&format!("callback={}", 1)), *o); + }, + 1 => { + for _ in 0..100 { + MessageQueue::enqueue_message(msg(&format!("callback={}", 2)), *o); + } + for i in 0..100 { + MessageQueue::enqueue_message(msg(&format!("callback={}", 3)), i.into()); + } + }, + 2 | 3 => { + MessageQueue::enqueue_message(msg(&format!("callback={}", 4)), *o); + }, + 4 => (), + _ => unreachable!(), + }; + Ok(()) })); MessageQueue::enqueue_message(msg("callback=0"), MessageOrigin::Here); @@ -1735,6 +1739,7 @@ fn recursive_service_is_forbidden() { // This call will fail since it is recursive. But it will not mess up the state. assert_storage_noop!(MessageQueue::service_queues(10.into_weight())); MessageQueue::enqueue_message(msg("m2"), There); + Ok(()) })); for _ in 0..5 { @@ -1778,6 +1783,7 @@ fn recursive_overweight_while_service_is_forbidden() { ), ExecuteOverweightError::RecursiveDisallowed ); + Ok(()) })); MessageQueue::enqueue_message(msg("weight=10"), There); @@ -1800,6 +1806,7 @@ fn recursive_reap_page_is_forbidden() { Callback::set(Box::new(|_, _| { // This call will fail since it is recursive. But it will not mess up the state. assert_noop!(MessageQueue::do_reap_page(&Here, 0), Error::::RecursiveDisallowed); + Ok(()) })); // Create 10 pages more than the stale limit. @@ -1975,3 +1982,55 @@ fn execute_overweight_keeps_stack_ov_message() { System::reset_events(); }); } + +#[test] +fn process_message_error_reverts_storage_changes() { + build_and_execute::(|| { + assert!(!sp_io::storage::exists(b"key"), "Key should not exist"); + + Callback::set(Box::new(|_, _| { + sp_io::storage::set(b"key", b"value"); + Err(()) + })); + + MessageQueue::enqueue_message(msg("callback=0"), MessageOrigin::Here); + MessageQueue::service_queues(10.into_weight()); + + assert!(!sp_io::storage::exists(b"key"), "Key should have been rolled back"); + }); +} + +#[test] +fn process_message_ok_false_keeps_storage_changes() { + build_and_execute::(|| { + assert!(!sp_io::storage::exists(b"key"), "Key should not exist"); + + Callback::set(Box::new(|_, _| { + sp_io::storage::set(b"key", b"value"); + Ok(()) + })); + + // 000 will make it return `Ok(false)` + MessageQueue::enqueue_message(msg("callback=000"), MessageOrigin::Here); + MessageQueue::service_queues(10.into_weight()); + + assert_eq!(sp_io::storage::exists(b"key"), true); + }); +} + +#[test] +fn process_message_ok_true_keeps_storage_changes() { + build_and_execute::(|| { + assert!(!sp_io::storage::exists(b"key"), "Key should not exist"); + + Callback::set(Box::new(|_, _| { + sp_io::storage::set(b"key", b"value"); + Ok(()) + })); + + MessageQueue::enqueue_message(msg("callback=0"), MessageOrigin::Here); + MessageQueue::service_queues(10.into_weight()); + + assert_eq!(sp_io::storage::exists(b"key"), true); + }); +}