diff --git a/frame/support/procedural/src/lib.rs b/frame/support/procedural/src/lib.rs index bf87bd552fd06..701ede60b064b 100644 --- a/frame/support/procedural/src/lib.rs +++ b/frame/support/procedural/src/lib.rs @@ -325,3 +325,8 @@ pub fn construct_runtime(input: TokenStream) -> TokenStream { pub fn transactional(attr: TokenStream, input: TokenStream) -> TokenStream { transactional::transactional(attr, input).unwrap_or_else(|e| e.to_compile_error().into()) } + +#[proc_macro_attribute] +pub fn require_transactional(attr: TokenStream, input: TokenStream) -> TokenStream { + transactional::require_transactional(attr, input).unwrap_or_else(|e| e.to_compile_error().into()) +} diff --git a/frame/support/procedural/src/transactional.rs b/frame/support/procedural/src/transactional.rs index fbd0c9ca0b3c4..688d8cfa067f3 100644 --- a/frame/support/procedural/src/transactional.rs +++ b/frame/support/procedural/src/transactional.rs @@ -41,3 +41,18 @@ pub fn transactional(_attr: TokenStream, input: TokenStream) -> Result Result { + let ItemFn { attrs, vis, sig, block } = syn::parse(input)?; + + let crate_ = generate_crate_access_2018()?; + let output = quote! { + #(#attrs)* + #vis #sig { + #crate_::storage::require_transaction(); + #block + } + }; + + Ok(output.into()) +} diff --git a/frame/support/src/lib.rs b/frame/support/src/lib.rs index bdbdfc04a31f9..889ff15066ba5 100644 --- a/frame/support/src/lib.rs +++ b/frame/support/src/lib.rs @@ -267,7 +267,40 @@ macro_rules! ord_parameter_types { } #[doc(inline)] -pub use frame_support_procedural::{decl_storage, construct_runtime, transactional}; +pub use frame_support_procedural::{ + decl_storage, construct_runtime, transactional +}; + +/// Assert the annotated function is executed within a storage transaction. +/// +/// The assertion is enabled for native execution and when `debug_assertions` are enabled. +/// +/// # Example +/// +/// ``` +/// # use frame_support::{ +/// # require_transactional, transactional, dispatch::DispatchResult +/// # }; +/// +/// #[require_transactional] +/// fn update_all(value: u32) -> DispatchResult { +/// // Update multiple storages. +/// // Return `Err` to indicate should revert. +/// Ok(()) +/// } +/// +/// #[transactional] +/// fn safe_update(value: u32) -> DispatchResult { +/// // This is safe +/// update_all(value) +/// } +/// +/// fn unsafe_update(value: u32) -> DispatchResult { +/// // this may panic if unsafe_update is not called within a storage transaction +/// update_all(value) +/// } +/// ``` +pub use frame_support_procedural::require_transactional; /// Return Err of the expression: `return Err($expression);`. /// diff --git a/frame/support/src/storage/mod.rs b/frame/support/src/storage/mod.rs index 5ee144c79c4db..97c1eabe6d39d 100644 --- a/frame/support/src/storage/mod.rs +++ b/frame/support/src/storage/mod.rs @@ -30,6 +30,57 @@ pub mod child; pub mod generator; pub mod migration; +#[cfg(all(feature = "std", any(test, debug_assertions)))] +mod debug_helper { + use std::cell::RefCell; + + thread_local! { + static TRANSACTION_LEVEL: RefCell = RefCell::new(0); + } + + pub fn require_transaction() { + let level = TRANSACTION_LEVEL.with(|v| *v.borrow()); + if level == 0 { + panic!("Require transaction not called within with_transaction"); + } + } + + pub struct TransactionLevelGuard; + + impl Drop for TransactionLevelGuard { + fn drop(&mut self) { + TRANSACTION_LEVEL.with(|v| *v.borrow_mut() -= 1); + } + } + + /// Increments the transaction level. + /// + /// Returns a guard that when dropped decrements the transaction level automatically. + pub fn inc_transaction_level() -> TransactionLevelGuard { + TRANSACTION_LEVEL.with(|v| { + let mut val = v.borrow_mut(); + *val += 1; + if *val > 10 { + crate::debug::warn!( + "Detected with_transaction with nest level {}. Nested usage of with_transaction is not recommended.", + *val + ); + } + }); + + TransactionLevelGuard + } +} + +/// Assert this method is called within a storage transaction. +/// This will **panic** if is not called within a storage transaction. +/// +/// This assertion is enabled for native execution and when `debug_assertions` are enabled. +pub fn require_transaction() { + #[cfg(all(feature = "std", any(test, debug_assertions)))] + debug_helper::require_transaction(); +} + /// Execute the supplied function in a new storage transaction. /// /// All changes to storage performed by the supplied function are discarded if the returned @@ -43,6 +94,10 @@ pub fn with_transaction(f: impl FnOnce() -> TransactionOutcome) -> R { use TransactionOutcome::*; start_transaction(); + + #[cfg(all(feature = "std", any(test, debug_assertions)))] + let _guard = debug_helper::inc_transaction_level(); + match f() { Commit(res) => { commit_transaction(); res }, Rollback(res) => { rollback_transaction(); res }, @@ -732,4 +787,27 @@ mod test { assert_eq!(Digest::decode(&mut &value[..]).unwrap(), expected); }); } + + #[test] + #[should_panic(expected = "Require transaction not called within with_transaction")] + fn require_transaction_should_panic() { + TestExternalities::default().execute_with(|| { + require_transaction(); + }); + } + + #[test] + fn require_transaction_should_not_panic_in_with_transaction() { + TestExternalities::default().execute_with(|| { + with_transaction(|| { + require_transaction(); + TransactionOutcome::Commit(()) + }); + + with_transaction(|| { + require_transaction(); + TransactionOutcome::Rollback(()) + }); + }); + } } diff --git a/frame/support/test/tests/storage_transaction.rs b/frame/support/test/tests/storage_transaction.rs index a7e4a75c27fcb..e305940ee68e2 100644 --- a/frame/support/test/tests/storage_transaction.rs +++ b/frame/support/test/tests/storage_transaction.rs @@ -17,7 +17,9 @@ use codec::{Encode, Decode, EncodeLike}; use frame_support::{ - assert_ok, assert_noop, dispatch::{DispatchError, DispatchResult}, transactional, StorageMap, StorageValue, + assert_ok, assert_noop, transactional, + StorageMap, StorageValue, + dispatch::{DispatchError, DispatchResult}, storage::{with_transaction, TransactionOutcome::*}, }; use sp_io::TestExternalities;