diff --git a/ic-canister-runtime/src/stub/mod.rs b/ic-canister-runtime/src/stub/mod.rs index 846c75f..4b92894 100644 --- a/ic-canister-runtime/src/stub/mod.rs +++ b/ic-canister-runtime/src/stub/mod.rs @@ -26,7 +26,7 @@ use std::{collections::VecDeque, sync::Mutex}; /// let runtime = StubRuntime::new() /// .add_stub_response(1_u64) /// .add_stub_response("two") -/// .add_stub_response(Some(3_u128)); +/// .add_stub_error(IcError::CallPerformFailed); /// /// let result_1: Result = runtime /// .update_call(PRINCIPAL, METHOD, ARGS, 0) @@ -39,17 +39,17 @@ use std::{collections::VecDeque, sync::Mutex}; /// assert_eq!(result_2, Ok("two".to_string())); /// /// let result_3: Result, IcError> = runtime -/// .update_call(PRINCIPAL, METHOD, ARGS, 0) +/// .query_call(PRINCIPAL, METHOD, ARGS) /// .await; -/// assert_eq!(result_3, Ok(Some -/// (3_u128))); +/// assert_eq!(result_3, Err(IcError::CallPerformFailed)); /// # Ok(()) /// # } /// ``` #[derive(Debug, Default, Clone)] pub struct StubRuntime { // Use a mutex so that this struct is Send and Sync - call_results: Arc>>>, + #[allow(clippy::type_complexity)] + call_results: Arc, IcError>>>>, } impl StubRuntime { @@ -63,7 +63,16 @@ impl StubRuntime { /// Panics if the stub response cannot be encoded using Candid. pub fn add_stub_response(self, stub_response: Out) -> Self { let result = Encode!(&stub_response).expect("Failed to encode Candid stub response"); - self.call_results.try_lock().unwrap().push_back(result); + self.call_results.try_lock().unwrap().push_back(Ok(result)); + self + } + + /// Mutate the [`StubRuntime`] instance to add the given stub error. + pub fn add_stub_error(self, stub_error: impl Into) -> Self { + self.call_results + .try_lock() + .unwrap() + .push_back(Err(stub_error.into())); self } @@ -71,13 +80,12 @@ impl StubRuntime { where Out: CandidType + DeserializeOwned, { - let bytes = self - .call_results + self.call_results .try_lock() .unwrap() .pop_front() - .unwrap_or_else(|| panic!("No available call response")); - Ok(Decode!(&bytes, Out).expect("Failed to decode Candid stub response")) + .unwrap_or_else(|| panic!("No available call response")) + .map(|bytes| Decode!(&bytes, Out).expect("Failed to decode Candid stub response")) } } diff --git a/ic-canister-runtime/src/stub/tests.rs b/ic-canister-runtime/src/stub/tests.rs index 75ee108..e2de1f1 100644 --- a/ic-canister-runtime/src/stub/tests.rs +++ b/ic-canister-runtime/src/stub/tests.rs @@ -1,5 +1,6 @@ use crate::{IcError, Runtime, StubRuntime}; use candid::{CandidType, Principal}; +use ic_error_types::RejectCode; use serde::Deserialize; const DEFAULT_PRINCIPAL: Principal = Principal::from_slice(&[0x9d, 0xf7, 0x01]); @@ -38,6 +39,18 @@ async fn should_return_single_stub_response() { assert_eq!(result, Ok(expected)); } +#[tokio::test] +async fn should_return_single_stub_error() { + let expected = IcError::CallPerformFailed; + let runtime = StubRuntime::new().add_stub_error(expected.clone()); + + let result: Result = runtime + .update_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS, 0) + .await; + + assert_eq!(result, Err(expected)); +} + #[tokio::test] async fn should_return_multiple_stub_responses() { let expected1 = MultiResult::Consistent("Hello, world!".to_string()); @@ -46,25 +59,32 @@ async fn should_return_multiple_stub_responses() { "Goodbye, world!".to_string(), ]); let expected3 = 0_u128; + let expected4 = IcError::CallRejected { + code: RejectCode::SysFatal, + message: "Fatal error!".to_string(), + }; let runtime = StubRuntime::new() .add_stub_response(expected1.clone()) .add_stub_response(expected2.clone()) - .add_stub_response(expected3); + .add_stub_response(expected3) + .add_stub_error(expected4.clone()); let result1: Result = runtime .update_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS, 0) .await; assert_eq!(result1, Ok(expected1)); - let result2: Result = runtime - .update_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS, 0) + .query_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS) .await; assert_eq!(result2, Ok(expected2)); - let result3: Result = runtime - .update_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS, 0) + .query_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS) .await; assert_eq!(result3, Ok(expected3)); + let result4: Result = runtime + .update_call(DEFAULT_PRINCIPAL, DEFAULT_METHOD, DEFAULT_ARGS, 0) + .await; + assert_eq!(result4, Err(expected4)); } #[tokio::test]