Skip to content

Commit b95b3d2

Browse files
authored
Recursion guard (#134)
* adding recursion_guard argument * fix linting, start on logic * basic recursion implementation working * make recursion guard option-al * more tests * move RecursionGuard, optimise recursion check * tests for recursion across a wrap validator * bump * tweaks
1 parent 771c928 commit b95b3d2

36 files changed

+565
-73
lines changed

Cargo.lock

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ indexmap = "1.8.1"
1919
mimalloc = { version = "0.1.29", default-features = false, optional = true }
2020
speedate = "0.4.1"
2121
ahash = "0.7.6"
22+
nohash-hasher = "0.2.0"
2223

2324
[lib]
2425
name = "_pydantic_core"

src/errors/kinds.rs

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ pub enum ErrorKind {
88
#[strum(message = "Invalid JSON: {parser_error}")]
99
InvalidJson,
1010
// ---------------------
11+
// recursion error
12+
#[strum(message = "Recursion error - cyclic reference detected")]
13+
RecursionLoop,
14+
// ---------------------
1115
// typed dict specific errors
1216
#[strum(message = "Value must be a valid dictionary or instance to extract fields from")]
1317
DictAttributesType,

src/errors/line_error.rs

+22-3
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,27 @@ pub enum ValError<'a> {
1717
InternalErr(PyErr),
1818
}
1919

20+
impl<'a> From<PyErr> for ValError<'a> {
21+
fn from(py_err: PyErr) -> Self {
22+
Self::InternalErr(py_err)
23+
}
24+
}
25+
26+
impl<'a> From<Vec<ValLineError<'a>>> for ValError<'a> {
27+
fn from(line_errors: Vec<ValLineError<'a>>) -> Self {
28+
Self::LineErrors(line_errors)
29+
}
30+
}
31+
2032
// ValError used to implement Error, see #78 for removed code
2133

34+
// TODO, remove and replace with just .into()
2235
pub fn as_internal<'a>(err: PyErr) -> ValError<'a> {
23-
ValError::InternalErr(err)
36+
err.into()
2437
}
2538

2639
pub fn pretty_line_errors(py: Python, line_errors: Vec<ValLineError>) -> String {
27-
let py_line_errors: Vec<PyLineError> = line_errors.into_iter().map(|e| PyLineError::new(py, e)).collect();
40+
let py_line_errors: Vec<PyLineError> = line_errors.into_iter().map(|e| e.into_py(py)).collect();
2841
pretty_py_line_errors(Some(py), py_line_errors.iter())
2942
}
3043

@@ -58,7 +71,7 @@ impl<'a> ValLineError<'a> {
5871
ValLineError {
5972
kind: self.kind,
6073
reverse_location: self.reverse_location,
61-
input_value: InputValue::PyObject(self.input_value.to_object(py)),
74+
input_value: self.input_value.to_object(py).into(),
6275
context: self.context,
6376
}
6477
}
@@ -79,6 +92,12 @@ impl Default for InputValue<'_> {
7992
}
8093
}
8194

95+
impl<'a> From<PyObject> for InputValue<'a> {
96+
fn from(py_object: PyObject) -> Self {
97+
Self::PyObject(py_object)
98+
}
99+
}
100+
82101
impl<'a> ToPyObject for InputValue<'a> {
83102
fn to_object(&self, py: Python) -> PyObject {
84103
match self {

src/errors/validation_exception.rs

+39-10
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use super::location::Location;
1616
use super::ValError;
1717

1818
#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
19-
#[derive(Debug)]
19+
#[derive(Debug, Clone)]
2020
pub struct ValidationError {
2121
line_errors: Vec<PyLineError>,
2222
title: PyObject,
@@ -33,7 +33,7 @@ impl ValidationError {
3333
pub fn from_val_error(py: Python, title: PyObject, error: ValError) -> PyErr {
3434
match error {
3535
ValError::LineErrors(raw_errors) => {
36-
let line_errors: Vec<PyLineError> = raw_errors.into_iter().map(|e| PyLineError::new(py, e)).collect();
36+
let line_errors: Vec<PyLineError> = raw_errors.into_iter().map(|e| e.into_py(py)).collect();
3737
PyErr::new::<ValidationError, _>((line_errors, title))
3838
}
3939
ValError::InternalErr(err) => err,
@@ -61,6 +61,18 @@ impl Error for ValidationError {
6161
}
6262
}
6363

64+
// used to convert a validation error back to ValError for wrap functions
65+
impl<'a> From<ValidationError> for ValError<'a> {
66+
fn from(val_error: ValidationError) -> Self {
67+
val_error
68+
.line_errors
69+
.into_iter()
70+
.map(|e| e.into())
71+
.collect::<Vec<_>>()
72+
.into()
73+
}
74+
}
75+
6476
#[pymethods]
6577
impl ValidationError {
6678
#[new]
@@ -131,19 +143,36 @@ pub struct PyLineError {
131143
context: Context,
132144
}
133145

134-
impl PyLineError {
135-
pub fn new(py: Python, raw_error: ValLineError) -> Self {
146+
impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
147+
fn into_py(self, py: Python<'_>) -> PyLineError {
148+
PyLineError {
149+
kind: self.kind,
150+
location: match self.reverse_location.len() {
151+
0..=1 => self.reverse_location,
152+
_ => self.reverse_location.into_iter().rev().collect(),
153+
},
154+
input_value: self.input_value.to_object(py),
155+
context: self.context,
156+
}
157+
}
158+
}
159+
160+
/// opposite of above, used to extract line errors from a validation error for wrap functions
161+
impl<'a> From<PyLineError> for ValLineError<'a> {
162+
fn from(py_line_error: PyLineError) -> Self {
136163
Self {
137-
kind: raw_error.kind,
138-
location: match raw_error.reverse_location.len() {
139-
0..=1 => raw_error.reverse_location,
140-
_ => raw_error.reverse_location.into_iter().rev().collect(),
164+
kind: py_line_error.kind,
165+
reverse_location: match py_line_error.location.len() {
166+
0..=1 => py_line_error.location,
167+
_ => py_line_error.location.into_iter().rev().collect(),
141168
},
142-
input_value: raw_error.input_value.to_object(py),
143-
context: raw_error.context,
169+
input_value: py_line_error.input_value.into(),
170+
context: py_line_error.context,
144171
}
145172
}
173+
}
146174

175+
impl PyLineError {
147176
pub fn as_dict(&self, py: Python) -> PyResult<PyObject> {
148177
let dict = PyDict::new(py);
149178
dict.set_item("kind", self.kind())?;

src/input/input_abstract.rs

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
1616

1717
fn as_error_value(&'a self) -> InputValue<'a>;
1818

19+
fn identity(&'a self) -> Option<usize> {
20+
None
21+
}
22+
1923
fn is_none(&self) -> bool;
2024

2125
fn strict_str<'data>(&'data self) -> ValResult<EitherString<'data>>;

src/input/input_python.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use std::str::from_utf8;
22

33
use pyo3::exceptions::{PyAttributeError, PyTypeError};
4-
use pyo3::intern;
54
use pyo3::prelude::*;
65
use pyo3::types::{
76
PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString,
87
PyTime, PyTuple, PyType,
98
};
9+
use pyo3::{intern, AsPyPointer};
1010

1111
use crate::errors::location::LocItem;
1212
use crate::errors::{as_internal, context, err_val_error, py_err_string, ErrorKind, InputValue, ValResult};
@@ -36,6 +36,10 @@ impl<'a> Input<'a> for PyAny {
3636
InputValue::PyAny(self)
3737
}
3838

39+
fn identity(&'a self) -> Option<usize> {
40+
Some(self.as_ptr() as usize)
41+
}
42+
3943
fn is_none(&self) -> bool {
4044
self.is_none()
4145
}

src/input/return_enums.rs

+17-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use pyo3::prelude::*;
44
use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple};
55

66
use crate::errors::{ValError, ValLineError, ValResult};
7+
use crate::recursion_guard::RecursionGuard;
78
use crate::validators::{CombinedValidator, Extra, Validator};
89

910
use super::parse_json::{JsonArray, JsonObject};
@@ -41,11 +42,12 @@ macro_rules! build_validate_to_vec {
4142
validator: &'s CombinedValidator,
4243
extra: &Extra,
4344
slots: &'a [CombinedValidator],
45+
recursion_guard: &'s mut RecursionGuard,
4446
) -> ValResult<'a, Vec<PyObject>> {
4547
let mut output: Vec<PyObject> = Vec::with_capacity(length);
4648
let mut errors: Vec<ValLineError> = Vec::new();
4749
for (index, item) in sequence.iter().enumerate() {
48-
match validator.validate(py, item, extra, slots) {
50+
match validator.validate(py, item, extra, slots, recursion_guard) {
4951
Ok(item) => output.push(item),
5052
Err(ValError::LineErrors(line_errors)) => {
5153
errors.extend(
@@ -90,13 +92,22 @@ impl<'a> GenericSequence<'a> {
9092
validator: &'s CombinedValidator,
9193
extra: &Extra,
9294
slots: &'a [CombinedValidator],
95+
recursion_guard: &'s mut RecursionGuard,
9396
) -> ValResult<'a, Vec<PyObject>> {
9497
match self {
95-
Self::List(sequence) => validate_to_vec_list(py, sequence, length, validator, extra, slots),
96-
Self::Tuple(sequence) => validate_to_vec_tuple(py, sequence, length, validator, extra, slots),
97-
Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots),
98-
Self::FrozenSet(sequence) => validate_to_vec_frozenset(py, sequence, length, validator, extra, slots),
99-
Self::JsonArray(sequence) => validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots),
98+
Self::List(sequence) => {
99+
validate_to_vec_list(py, sequence, length, validator, extra, slots, recursion_guard)
100+
}
101+
Self::Tuple(sequence) => {
102+
validate_to_vec_tuple(py, sequence, length, validator, extra, slots, recursion_guard)
103+
}
104+
Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots, recursion_guard),
105+
Self::FrozenSet(sequence) => {
106+
validate_to_vec_frozenset(py, sequence, length, validator, extra, slots, recursion_guard)
107+
}
108+
Self::JsonArray(sequence) => {
109+
validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots, recursion_guard)
110+
}
100111
}
101112
}
102113
}

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
1010
mod build_tools;
1111
mod errors;
1212
mod input;
13+
mod recursion_guard;
1314
mod validators;
1415

1516
// required for benchmarks

src/recursion_guard.rs

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use std::collections::HashSet;
2+
use std::hash::BuildHasherDefault;
3+
4+
use nohash_hasher::NoHashHasher;
5+
6+
/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
7+
/// It's used in `validators/recursive.rs` to detect when a reference is reused within itself.
8+
#[derive(Debug, Clone, Default)]
9+
pub struct RecursionGuard(Option<HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>>>);
10+
11+
impl RecursionGuard {
12+
// insert a new id into the set, return whether the set already had the id in it
13+
pub fn contains_or_insert(&mut self, id: usize) -> bool {
14+
match self.0 {
15+
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
16+
// "If the set did not have this value present, `true` is returned."
17+
Some(ref mut set) => !set.insert(id),
18+
None => {
19+
let mut set: HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>> =
20+
HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default());
21+
set.insert(id);
22+
self.0 = Some(set);
23+
false
24+
}
25+
}
26+
}
27+
28+
pub fn remove(&mut self, id: &usize) {
29+
match self.0 {
30+
Some(ref mut set) => {
31+
set.remove(id);
32+
}
33+
None => unreachable!(),
34+
};
35+
}
36+
}

src/validators/any.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use pyo3::types::PyDict;
33

44
use crate::errors::ValResult;
55
use crate::input::Input;
6+
use crate::recursion_guard::RecursionGuard;
67

78
use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
89

@@ -29,6 +30,7 @@ impl Validator for AnyValidator {
2930
input: &'data impl Input<'data>,
3031
_extra: &Extra,
3132
_slots: &'data [CombinedValidator],
33+
_recursion_guard: &'s mut RecursionGuard,
3234
) -> ValResult<'data, PyObject> {
3335
// Ok(input.clone().into_py(py))
3436
Ok(input.to_object(py))

src/validators/bool.rs

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use pyo3::types::PyDict;
44
use crate::build_tools::is_strict;
55
use crate::errors::ValResult;
66
use crate::input::Input;
7+
use crate::recursion_guard::RecursionGuard;
78

89
use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
910

@@ -33,6 +34,7 @@ impl Validator for BoolValidator {
3334
input: &'data impl Input<'data>,
3435
_extra: &Extra,
3536
_slots: &'data [CombinedValidator],
37+
_recursion_guard: &'s mut RecursionGuard,
3638
) -> ValResult<'data, PyObject> {
3739
// TODO in theory this could be quicker if we used PyBool rather than going to a bool
3840
// and back again, might be worth profiling?
@@ -45,6 +47,7 @@ impl Validator for BoolValidator {
4547
input: &'data impl Input<'data>,
4648
_extra: &Extra,
4749
_slots: &'data [CombinedValidator],
50+
_recursion_guard: &'s mut RecursionGuard,
4851
) -> ValResult<'data, PyObject> {
4952
Ok(input.strict_bool()?.into_py(py))
5053
}
@@ -70,6 +73,7 @@ impl Validator for StrictBoolValidator {
7073
input: &'data impl Input<'data>,
7174
_extra: &Extra,
7275
_slots: &'data [CombinedValidator],
76+
_recursion_guard: &'s mut RecursionGuard,
7377
) -> ValResult<'data, PyObject> {
7478
Ok(input.strict_bool()?.into_py(py))
7579
}

0 commit comments

Comments
 (0)