Skip to content

Commit 28ed9bd

Browse files
committed
make recursion guard option-al
1 parent 087d040 commit 28ed9bd

File tree

3 files changed

+74
-22
lines changed

3 files changed

+74
-22
lines changed

Cargo.lock

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ indexmap = "1.8.1"
1919
mimalloc = { version = "0.1.29", default-features = false, optional = true }
2020
speedate = "0.4.1"
2121
ahash = "0.7.6"
22+
nohash-hasher = "0.2.0"
2223

2324
[lib]
2425
name = "_pydantic_core"

src/validators/mod.rs

+66-22
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
use ahash::AHashSet;
1+
use std::collections::HashSet;
22
use std::fmt::Debug;
3+
use std::hash::BuildHasherDefault;
34

45
use enum_dispatch::enum_dispatch;
6+
use nohash_hasher::NoHashHasher;
57

68
use pyo3::exceptions::PyRecursionError;
79
use pyo3::prelude::*;
@@ -75,19 +77,24 @@ impl SchemaValidator {
7577
}
7678

7779
pub fn validate_python(&self, py: Python, input: &PyAny) -> PyResult<PyObject> {
78-
let mut recursion_guard = RecursionGuard::new();
79-
let r = self
80-
.validator
81-
.validate(py, input, &Extra::default(), &self.slots, &mut recursion_guard);
80+
let r = self.validator.validate(
81+
py,
82+
input,
83+
&Extra::default(),
84+
&self.slots,
85+
&mut RecursionGuard::default(),
86+
);
8287
r.map_err(|e| self.prepare_validation_err(py, e))
8388
}
8489

8590
pub fn isinstance_python(&self, py: Python, input: &PyAny) -> PyResult<bool> {
86-
let mut recursion_guard = RecursionGuard::new();
87-
match self
88-
.validator
89-
.validate(py, input, &Extra::default(), &self.slots, &mut recursion_guard)
90-
{
91+
match self.validator.validate(
92+
py,
93+
input,
94+
&Extra::default(),
95+
&self.slots,
96+
&mut RecursionGuard::default(),
97+
) {
9198
Ok(_) => Ok(true),
9299
Err(ValError::InternalErr(err)) => Err(err),
93100
_ => Ok(false),
@@ -97,10 +104,13 @@ impl SchemaValidator {
97104
pub fn validate_json(&self, py: Python, input: String) -> PyResult<PyObject> {
98105
match parse_json::<JsonInput>(&input) {
99106
Ok(input) => {
100-
let mut recursion_guard = RecursionGuard::new();
101-
let r = self
102-
.validator
103-
.validate(py, &input, &Extra::default(), &self.slots, &mut recursion_guard);
107+
let r = self.validator.validate(
108+
py,
109+
&input,
110+
&Extra::default(),
111+
&self.slots,
112+
&mut RecursionGuard::default(),
113+
);
104114
r.map_err(|e| self.prepare_validation_err(py, e))
105115
}
106116
Err(e) => {
@@ -118,11 +128,13 @@ impl SchemaValidator {
118128
pub fn isinstance_json(&self, py: Python, input: String) -> PyResult<bool> {
119129
match parse_json::<JsonInput>(&input) {
120130
Ok(input) => {
121-
let mut recursion_guard = RecursionGuard::new();
122-
match self
123-
.validator
124-
.validate(py, &input, &Extra::default(), &self.slots, &mut recursion_guard)
125-
{
131+
match self.validator.validate(
132+
py,
133+
&input,
134+
&Extra::default(),
135+
&self.slots,
136+
&mut RecursionGuard::default(),
137+
) {
126138
Ok(_) => Ok(true),
127139
Err(ValError::InternalErr(err)) => Err(err),
128140
_ => Ok(false),
@@ -137,10 +149,9 @@ impl SchemaValidator {
137149
data: Some(data),
138150
field: Some(field.as_str()),
139151
};
140-
let mut recursion_guard = RecursionGuard::new();
141152
let r = self
142153
.validator
143-
.validate(py, input, &extra, &self.slots, &mut recursion_guard);
154+
.validate(py, input, &extra, &self.slots, &mut RecursionGuard::default());
144155
r.map_err(|e| self.prepare_validation_err(py, e))
145156
}
146157

@@ -353,7 +364,40 @@ pub enum CombinedValidator {
353364
FrozenSet(frozenset::FrozenSetValidator),
354365
}
355366

356-
pub type RecursionGuard = AHashSet<usize>;
367+
#[derive(Debug, Clone, Default)]
368+
pub struct RecursionGuard(Option<HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>>>);
369+
370+
impl RecursionGuard {
371+
pub fn insert(&mut self, id: usize) {
372+
match self.0 {
373+
Some(ref mut set) => {
374+
set.insert(id);
375+
}
376+
None => {
377+
let mut set: HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>> =
378+
HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default());
379+
set.insert(id);
380+
self.0 = Some(set);
381+
}
382+
};
383+
}
384+
385+
pub fn remove(&mut self, id: &usize) {
386+
match self.0 {
387+
Some(ref mut set) => {
388+
set.remove(id);
389+
}
390+
None => unreachable!(),
391+
};
392+
}
393+
394+
pub fn contains(&self, id: &usize) -> bool {
395+
match self.0 {
396+
Some(ref set) => set.contains(id),
397+
None => false,
398+
}
399+
}
400+
}
357401

358402
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
359403
/// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait

0 commit comments

Comments
 (0)