diff --git a/build.rs b/build.rs index 1abd7456b4..854778873e 100644 --- a/build.rs +++ b/build.rs @@ -31,6 +31,7 @@ fn main() { // Parse each rust file with syn and run the linting suite on it in parallel rust_files.par_iter().for_each_with(tx.clone(), |tx, file| { + let is_test = file.display().to_string().contains("test"); let Ok(content) = fs::read_to_string(file) else { return; }; @@ -63,6 +64,10 @@ fn main() { track_lint(ForbidKeysRemoveCall::lint(&parsed_file)); track_lint(RequireFreezeStruct::lint(&parsed_file)); track_lint(RequireExplicitPalletIndex::lint(&parsed_file)); + + if is_test { + track_lint(ForbidSaturatingMath::lint(&parsed_file)); + } }); // Collect and print all errors after the parallel processing is done diff --git a/pallets/subtensor/src/tests/epoch.rs b/pallets/subtensor/src/tests/epoch.rs index 4c24769c27..93ca1d5c82 100644 --- a/pallets/subtensor/src/tests/epoch.rs +++ b/pallets/subtensor/src/tests/epoch.rs @@ -2285,19 +2285,19 @@ fn test_compute_alpha_values() { // exp_val = exp(0.0 - 1.0 * 0.1) = exp(-0.1) // alpha[0] = 1 / (1 + exp(-0.1)) ~ 0.9048374180359595 let exp_val_0 = I32F32::from_num(0.9048374180359595); - let expected_alpha_0 = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val_0); + let expected_alpha_0 = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val_0); // For consensus[1] = 0.5: // exp_val = exp(0.0 - 1.0 * 0.5) = exp(-0.5) // alpha[1] = 1 / (1 + exp(-0.5)) ~ 0.6065306597126334 let exp_val_1 = I32F32::from_num(0.6065306597126334); - let expected_alpha_1 = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val_1); + let expected_alpha_1 = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val_1); // For consensus[2] = 0.9: // exp_val = exp(0.0 - 1.0 * 0.9) = exp(-0.9) // alpha[2] = 1 / (1 + exp(-0.9)) ~ 0.4065696597405991 let exp_val_2 = I32F32::from_num(0.4065696597405991); - let expected_alpha_2 = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val_2); + let expected_alpha_2 = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val_2); // Define an epsilon for approximate equality checks. let epsilon = I32F32::from_num(1e-6); @@ -2329,13 +2329,13 @@ fn test_compute_alpha_values_256_miners() { for (i, &c) in consensus.iter().enumerate() { // Use saturating subtraction and multiplication - let exponent = b.saturating_sub(a.saturating_mul(c)); + let exponent = b - (a * c); // Use safe_exp instead of exp let exp_val = safe_exp(exponent); // Use saturating addition and division - let expected_alpha = I32F32::from_num(1.0) / I32F32::from_num(1.0).saturating_add(exp_val); + let expected_alpha = I32F32::from_num(1.0) / (I32F32::from_num(1.0) + exp_val); // Assert that the computed alpha values match the expected values within the epsilon. assert_approx_eq(alpha[i], expected_alpha, epsilon); diff --git a/support/linting/src/forbid_saturating_math.rs b/support/linting/src/forbid_saturating_math.rs new file mode 100644 index 0000000000..9ad5385b36 --- /dev/null +++ b/support/linting/src/forbid_saturating_math.rs @@ -0,0 +1,113 @@ +use super::*; +use syn::{Expr, ExprCall, ExprMethodCall, ExprPath, File, Path, spanned::Spanned, visit::Visit}; + +pub struct ForbidSaturatingMath; + +impl Lint for ForbidSaturatingMath { + fn lint(source: &File) -> Result { + let mut visitor = SaturatingMathBanVisitor::default(); + visitor.visit_file(source); + + if visitor.errors.is_empty() { + Ok(()) + } else { + Err(visitor.errors) + } + } +} + +#[derive(Default)] +struct SaturatingMathBanVisitor { + errors: Vec, +} + +impl<'ast> Visit<'ast> for SaturatingMathBanVisitor { + fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) { + let ExprMethodCall { method, .. } = node; + + if method.to_string().starts_with("saturating_") { + let msg = "Safe math is banned to encourage tests to panic"; + self.errors.push(syn::Error::new(method.span(), msg)); + } + } + + fn visit_expr_call(&mut self, node: &'ast ExprCall) { + let ExprCall { func, .. } = node; + + if is_saturating_math_call(func) { + let msg = "Safe math is banned to encourage tests to panic"; + self.errors.push(syn::Error::new(node.func.span(), msg)); + } + } +} + +fn is_saturating_math_call(func: &Expr) -> bool { + let Expr::Path(ExprPath { + path: Path { segments: path, .. }, + .. + }) = func + else { + return false; + }; + + path.last() + .is_some_and(|seg| seg.ident.to_string().starts_with("saturating_")) +} + +#[cfg(test)] +mod tests { + use super::*; + use quote::quote; + + fn lint(input: proc_macro2::TokenStream) -> Result { + let mut visitor = SaturatingMathBanVisitor::default(); + let expr: syn::Expr = syn::parse2(input).expect("should be a valid expression"); + + match &expr { + syn::Expr::MethodCall(call) => visitor.visit_expr_method_call(call), + syn::Expr::Call(call) => visitor.visit_expr_call(call), + _ => panic!("should be a valid method call or function call"), + } + + if visitor.errors.is_empty() { + Ok(()) + } else { + Err(visitor.errors) + } + } + + #[test] + fn test_saturating_forbidden() { + let input = quote! { stake.saturating_add(alpha) }; + assert!(lint(input).is_err()); + let input = quote! { alpha_price.saturating_mul(float_alpha_block_emission) }; + assert!(lint(input).is_err()); + let input = quote! { alpha_out_i.saturating_sub(root_alpha) }; + assert!(lint(input).is_err()); + } + + #[test] + fn test_saturating_ufcs_forbidden() { + let input = quote! { SaturatingAdd::saturating_add(stake, alpha) }; + assert!(lint(input).is_err()); + let input = quote! { core::num::SaturatingAdd::saturating_add(stake, alpha) }; + assert!(lint(input).is_err()); + let input = + quote! { SaturatingMul::saturating_mul(alpha_price, float_alpha_block_emission) }; + assert!(lint(input).is_err()); + let input = quote! { core::num::SaturatingMul::saturating_mul(alpha_price, float_alpha_block_emission) }; + assert!(lint(input).is_err()); + let input = quote! { SaturatingSub::saturating_sub(alpha_out_i, root_alpha) }; + assert!(lint(input).is_err()); + let input = quote! { core::num::SaturatingSub::saturating_sub(alpha_out_i, root_alpha) }; + assert!(lint(input).is_err()); + } + + #[test] + fn test_saturating_to_from_num_forbidden() { + let input = quote! { I96F32::saturating_from_num(u64::MAX) }; + assert!(lint(input).is_err()); + let input = quote! { remaining_emission.saturating_to_num::() }; + assert!(lint(input).is_err()); + } +} diff --git a/support/linting/src/lib.rs b/support/linting/src/lib.rs index a65466e6a5..864d4a5572 100644 --- a/support/linting/src/lib.rs +++ b/support/linting/src/lib.rs @@ -3,10 +3,12 @@ pub use lint::*; mod forbid_as_primitive; mod forbid_keys_remove; +mod forbid_saturating_math; mod pallet_index; mod require_freeze_struct; pub use forbid_as_primitive::ForbidAsPrimitiveConversion; pub use forbid_keys_remove::ForbidKeysRemoveCall; +pub use forbid_saturating_math::ForbidSaturatingMath; pub use pallet_index::RequireExplicitPalletIndex; pub use require_freeze_struct::RequireFreezeStruct;