From e1cb0ebe5633fee835cd149768ebf43adfde4f6b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 16:26:37 +0000 Subject: [PATCH] improve performance of recursion guard (#1156) Co-authored-by: David Hewitt Co-authored-by: David Hewitt --- src/recursion_guard.rs | 144 ++++++++++++++++++++++++++++------ src/serializers/extra.rs | 14 ++-- src/validators/definitions.rs | 16 ++-- tests/serializers/test_any.py | 2 +- 4 files changed, 134 insertions(+), 42 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index 453f01a1d..fe5b1bcdd 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -1,4 +1,5 @@ use ahash::AHashSet; +use std::mem::MaybeUninit; type RecursionKey = ( // Identifier for the input object, e.g. the id() of a Python dict @@ -13,56 +14,147 @@ type RecursionKey = ( /// It's used in `validators/definition` to detect when a reference is reused within itself. #[derive(Debug, Clone, Default)] pub struct RecursionGuard { - ids: Option<AHashSet<RecursionKey>>, + ids: RecursionStack, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators - depth: u16, + depth: u8, } // A hard limit to avoid stack overflows when rampant recursion occurs -pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { +pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { // wasm and windows PyPy have very limited stack sizes - 50 + 49 } else if cfg!(any(PyPy, windows)) { // PyPy and Windows in general have more restricted stack space - 100 + 99 } else { 255 }; impl RecursionGuard { - // insert a new id into the set, return whether the set already had the id in it - pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool { - match self.ids { - // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert - // "If the set did not have this value present, 
`true` is returned." - Some(ref mut set) => !set.insert((obj_id, node_id)), - None => { - let mut set: AHashSet<RecursionKey> = AHashSet::with_capacity(10); - set.insert((obj_id, node_id)); - self.ids = Some(set); - false - } - } + // insert a new value + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { + self.ids.insert((obj_id, node_id)) } // see #143 this is used as a backup in case the identity check recursion guard fails #[must_use] + #[cfg(any(target_family = "wasm", windows, PyPy))] pub fn incr_depth(&mut self) -> bool { - self.depth += 1; - self.depth >= RECURSION_GUARD_LIMIT + // use saturating_add as it's faster (since there's no error path) + // and the RECURSION_GUARD_LIMIT check will be hit before it overflows + debug_assert!(RECURSION_GUARD_LIMIT < 255); + self.depth = self.depth.saturating_add(1); + self.depth > RECURSION_GUARD_LIMIT + } + + #[must_use] + #[cfg(not(any(target_family = "wasm", windows, PyPy)))] + pub fn incr_depth(&mut self) -> bool { + debug_assert_eq!(RECURSION_GUARD_LIMIT, 255); + // use checked_add to check if we've hit the limit + if let Some(depth) = self.depth.checked_add(1) { + self.depth = depth; + false + } else { + true + } } pub fn decr_depth(&mut self) { - self.depth -= 1; + // for the same reason as incr_depth, use saturating_sub + self.depth = self.depth.saturating_sub(1); } pub fn remove(&mut self, obj_id: usize, node_id: usize) { - match self.ids { - Some(ref mut set) => { - set.remove(&(obj_id, node_id)); + self.ids.remove(&(obj_id, node_id)); + } +} + +// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower +const ARRAY_SIZE: usize = 16; + +#[derive(Debug, Clone)] +enum RecursionStack { + Array { + data: [MaybeUninit<RecursionKey>; ARRAY_SIZE], + len: usize, + }, + Set(AHashSet<RecursionKey>), +} + +impl Default for RecursionStack { + fn default() -> Self { 
+ Self::Array { + data: std::array::from_fn(|_| MaybeUninit::uninit()), + len: 0, + } + } +} + +impl RecursionStack { + // insert a new value + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, v: RecursionKey) -> bool { + match self { + Self::Array { data, len } => { + if *len < ARRAY_SIZE { + for value in data.iter().take(*len) { + // Safety: reading values within bounds + if unsafe { value.assume_init() } == v { + return false; + } + } + + data[*len].write(v); + *len += 1; + true + } else { + let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1); + for existing in data.iter() { + // Safety: the array is fully initialized + set.insert(unsafe { existing.assume_init() }); + } + let inserted = set.insert(v); + *self = Self::Set(set); + inserted + } + } + // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert + // "If the set did not have this value present, `true` is returned." 
+ Self::Set(set) => set.insert(v), + } + } + + pub fn remove(&mut self, v: &RecursionKey) { + match self { + Self::Array { data, len } => { + *len = len.checked_sub(1).expect("remove from empty recursion guard"); + // Safety: this is reading what was the back of the initialized array + let removed = unsafe { data.get_unchecked_mut(*len) }; + assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert"); + // this should compile away to a noop + unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) } + } + Self::Set(set) => { + set.remove(v); + } + } + } +} + +impl Drop for RecursionStack { + fn drop(&mut self) { + // This should compile away to a noop as RecursionKey doesn't implement Drop, but it seemed + // desirable to leave this in for safety in case that should change in the future + if let Self::Array { data, len } = self { + for value in data.iter_mut().take(*len) { + // Safety: reading values within bounds + unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) }; } - None => unreachable!(), - }; + } } } diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 37307055e..b3978613a 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -346,17 +346,17 @@ pub struct SerRecursionGuard { impl SerRecursionGuard { pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> { - // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert - // "If the set did not have this value present, `true` is returned." 
let id = value.as_ptr() as usize; let mut guard = self.guard.borrow_mut(); - if guard.contains_or_insert(id, def_ref_id) { - Err(PyValueError::new_err("Circular reference detected (id repeated)")) - } else if guard.incr_depth() { - Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) + if guard.insert(id, def_ref_id) { + if guard.incr_depth() { + Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) + } else { + Ok(id) + } } else { - Ok(id) + Err(PyValueError::new_err("Circular reference detected (id repeated)")) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 0b5f78c10..e8c67a690 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -76,10 +76,7 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } else { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } @@ -87,6 +84,9 @@ impl Validator for DefinitionRefValidator { state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output + } else { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } } else { validator.validate(py, input, state) @@ -105,10 +105,7 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't 
remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } else { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } @@ -116,6 +113,9 @@ impl Validator for DefinitionRefValidator { state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output + } else { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } } else { validator.validate_assignment(py, obj, field_name, field_value, state) diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 98ec22c1f..fa6e702fe 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -371,7 +371,7 @@ def fallback_func(obj): f = FoobarCount(0) v = 0 # when recursion is detected and we're in mode python, we just return the value - expected_visits = pydantic_core._pydantic_core._recursion_limit - 1 + expected_visits = pydantic_core._pydantic_core._recursion_limit assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'') with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):