diff --git a/crates/wasmparser/src/limits.rs b/crates/wasmparser/src/limits.rs index 953981c02b..094002e2c4 100644 --- a/crates/wasmparser/src/limits.rs +++ b/crates/wasmparser/src/limits.rs @@ -27,7 +27,7 @@ pub const MAX_WASM_ELEMENT_SEGMENTS: usize = 100_000; pub const MAX_WASM_DATA_SEGMENTS: usize = 100_000; pub const MAX_WASM_STRING_SIZE: usize = 100_000; pub const MAX_WASM_FUNCTION_SIZE: usize = 128 * 1024; -pub const MAX_WASM_FUNCTION_LOCALS: usize = 50000; +pub const MAX_WASM_FUNCTION_LOCALS: u32 = 50000; pub const MAX_WASM_FUNCTION_PARAMS: usize = 1000; pub const MAX_WASM_FUNCTION_RETURNS: usize = 1000; pub const _MAX_WASM_TABLE_SIZE: usize = 10_000_000; diff --git a/crates/wasmparser/src/validator/operators.rs b/crates/wasmparser/src/validator/operators.rs index e36dfc1f4e..1a599681ca 100644 --- a/crates/wasmparser/src/validator/operators.rs +++ b/crates/wasmparser/src/validator/operators.rs @@ -32,6 +32,7 @@ use crate::{ }; use crate::{prelude::*, CompositeInnerType, Ordering}; use core::ops::{Deref, DerefMut}; +use core::{cmp, iter}; #[cfg(feature = "simd")] mod simd; @@ -159,7 +160,7 @@ impl LocalInits { // No science was performed in the creation of this number, feel free to change // it if you so like. -const MAX_LOCALS_TO_TRACK: usize = 50; +const MAX_LOCALS_TO_TRACK: u32 = 50; pub(super) struct Locals { // Total number of locals in the function. @@ -181,7 +182,7 @@ pub(super) struct Locals { // `local.{get,set,tee}`. We do a binary search for the index desired, and // it either lies in a "hole" where the maximum index is specified later, // or it's at the end of the list meaning it's out of bounds. - all: Vec<(u32, ValType)>, + uncached: Vec<(u32, ValType)>, } /// A Wasm control flow block on the control flow stack during Wasm validation. @@ -218,7 +219,7 @@ pub struct OperatorValidatorAllocations { operands: Vec, local_inits: LocalInits, locals_first: Vec, - locals_all: Vec<(u32, ValType)>, + locals_uncached: Vec<(u32, ValType)>, } /// Type storage within the validator. @@ -323,7 +324,7 @@ impl OperatorValidator { operands, local_inits, locals_first, - locals_all, + locals_uncached, } = allocs; debug_assert!(popped_types_tmp.is_empty()); debug_assert!(control.is_empty()); @@ -331,12 +332,12 @@ impl OperatorValidator { debug_assert!(local_inits.is_empty()); debug_assert!(local_inits.is_empty()); debug_assert!(locals_first.is_empty()); - debug_assert!(locals_all.is_empty()); + debug_assert!(locals_uncached.is_empty()); OperatorValidator { locals: Locals { num_locals: 0, first: locals_first, - all: locals_all, + uncached: locals_uncached, }, local_inits, features: *features, @@ -521,7 +522,7 @@ impl OperatorValidator { self.local_inits }, locals_first: clear(self.locals.first), - locals_all: clear(self.locals.all), + locals_uncached: clear(self.locals.uncached), } } @@ -4262,20 +4263,23 @@ impl Locals { /// definition is unsuccessful in case the amount of total variables /// after definition exceeds the allowed maximum number. fn define(&mut self, count: u32, ty: ValType) -> bool { + if count == 0 { + return true; + } + let vacant_first = MAX_LOCALS_TO_TRACK.saturating_sub(self.num_locals); match self.num_locals.checked_add(count) { - Some(n) => self.num_locals = n, + Some(num_locals) if num_locals > MAX_WASM_FUNCTION_LOCALS => return false, None => return false, + Some(num_locals) => self.num_locals = num_locals, + }; + let push_to_first = cmp::min(vacant_first, count); + self.first + .extend(iter::repeat(ty).take(push_to_first as usize)); + let num_uncached = count - push_to_first; + if num_uncached > 0 { + let max_uncached_idx = self.num_locals - 1; + self.uncached.push((max_uncached_idx, ty)); } - if self.num_locals > (MAX_WASM_FUNCTION_LOCALS as u32) { - return false; - } - for _ in 0..count { - if self.first.len() >= MAX_LOCALS_TO_TRACK { - break; - } - self.first.push(ty); - } - self.all.push((self.num_locals - 1, ty)); true } @@ -4294,17 +4298,17 @@ impl Locals { } fn get_bsearch(&self, idx: u32) -> Option { - match self.all.binary_search_by_key(&idx, |(idx, _)| *idx) { + match self.uncached.binary_search_by_key(&idx, |(idx, _)| *idx) { // If this index would be inserted at the end of the list, then the // index is out of bounds and we return an error. - Err(i) if i == self.all.len() => None, + Err(i) if i == self.uncached.len() => None, // If `Ok` is returned we found the index exactly, or if `Err` is // returned the position is the one which is the least index // greater that `idx`, which is still the type of `idx` according // to our "compressed" representation. In both cases we access the // list at index `i`. - Ok(i) | Err(i) => Some(self.all[i].1), + Ok(i) | Err(i) => Some(self.uncached[i].1), } } }