From 64db44ddaff05fc35c27f33aea827259dc242a60 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 28 Jun 2022 21:19:51 +0100 Subject: [PATCH 1/9] adding recursion_guard argument --- src/input/return_enums.rs | 24 +++++++++++++------ src/validators/any.rs | 3 ++- src/validators/bool.rs | 5 +++- src/validators/bytes.rs | 7 +++++- src/validators/date.rs | 4 +++- src/validators/datetime.rs | 4 +++- src/validators/dict.rs | 20 ++++++++++------ src/validators/float.rs | 7 +++++- src/validators/frozenset.rs | 13 +++++++---- src/validators/function.rs | 16 +++++++++---- src/validators/int.rs | 7 +++++- src/validators/list.rs | 13 +++++++---- src/validators/literal.rs | 7 +++++- src/validators/mod.rs | 44 +++++++++++++++++++++++++++-------- src/validators/model_class.rs | 6 +++-- src/validators/none.rs | 3 ++- src/validators/nullable.rs | 10 +++++--- src/validators/recursive.rs | 8 ++++--- src/validators/set.rs | 13 +++++++---- src/validators/string.rs | 7 +++++- src/validators/time.rs | 4 +++- src/validators/tuple.rs | 22 ++++++++++++------ src/validators/typed_dict.rs | 14 ++++++----- src/validators/union.rs | 11 +++++---- 24 files changed, 195 insertions(+), 77 deletions(-) diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index cf8902f83..7317625bc 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,7 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple}; use crate::errors::{ValError, ValLineError, ValResult}; -use crate::validators::{CombinedValidator, Extra, Validator}; +use crate::validators::{CombinedValidator, Extra, RecursionGuard, Validator}; use super::parse_json::{JsonArray, JsonObject}; @@ -41,11 +41,12 @@ macro_rules! build_validate_to_vec { validator: &'s CombinedValidator, extra: &Extra, slots: &'a [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'a, Vec> { let mut output: Vec = Vec::with_capacity(length); let mut errors: Vec = Vec::new(); for (index, item) in sequence.iter().enumerate() { - match validator.validate(py, item, extra, slots) { + match validator.validate(py, item, extra, slots, recursion_guard) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { errors.extend( @@ -90,13 +91,22 @@ impl<'a> GenericSequence<'a> { validator: &'s CombinedValidator, extra: &Extra, slots: &'a [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'a, Vec> { match self { - Self::List(sequence) => validate_to_vec_list(py, sequence, length, validator, extra, slots), - Self::Tuple(sequence) => validate_to_vec_tuple(py, sequence, length, validator, extra, slots), - Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots), - Self::FrozenSet(sequence) => validate_to_vec_frozenset(py, sequence, length, validator, extra, slots), - Self::JsonArray(sequence) => validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots), + Self::List(sequence) => { + validate_to_vec_list(py, sequence, length, validator, extra, slots, recursion_guard) + } + Self::Tuple(sequence) => { + validate_to_vec_tuple(py, sequence, length, validator, extra, slots, recursion_guard) + } + Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots, recursion_guard), + Self::FrozenSet(sequence) => { + validate_to_vec_frozenset(py, sequence, length, validator, extra, slots, recursion_guard) + } + Self::JsonArray(sequence) => { + validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots, recursion_guard) + } } } } diff --git a/src/validators/any.rs b/src/validators/any.rs index bedb6bd7b..f31d8ed49 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -4,7 +4,7 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; /// This might seem useless, but it's useful in DictValidator to avoid Option a lot #[derive(Debug, Clone)] @@ -29,6 +29,7 @@ impl Validator for AnyValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { // Ok(input.clone().into_py(py)) Ok(input.to_object(py)) diff --git a/src/validators/bool.rs b/src/validators/bool.rs index 9b888aad1..f8197c435 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -5,7 +5,7 @@ use crate::build_tools::is_strict; use crate::errors::ValResult; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct BoolValidator; @@ -33,6 +33,7 @@ impl Validator for BoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? @@ -45,6 +46,7 @@ impl Validator for BoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_bool()?.into_py(py)) } @@ -70,6 +72,7 @@ impl Validator for StrictBoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_bool()?.into_py(py)) } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index c9b3510e7..1e193ed4b 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -5,7 +5,7 @@ use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherBytes, Input}; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct BytesValidator; @@ -36,6 +36,7 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.lax_bytes()?; Ok(either_bytes.into_py(py)) @@ -47,6 +48,7 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.strict_bytes()?; Ok(either_bytes.into_py(py)) @@ -73,6 +75,7 @@ impl Validator for StrictBytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.strict_bytes()?; Ok(either_bytes.into_py(py)) @@ -97,6 +100,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let bytes = match self.strict { true => input.strict_bytes()?, @@ -111,6 +115,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_bytes()?) } diff --git a/src/validators/date.rs b/src/validators/date.rs index 39eb312e8..82bf5d688 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -6,7 +6,7 @@ use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValResult}; use crate::input::{EitherDate, Input}; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct DateValidator { @@ -58,6 +58,7 @@ impl Validator for DateValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let date = match self.strict { true => input.strict_date()?, @@ -80,6 +81,7 @@ impl Validator for DateValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_date()?) } diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 6f3b5465b..53fe709c6 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -6,7 +6,7 @@ use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherDateTime, Input}; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct DateTimeValidator { @@ -58,6 +58,7 @@ impl Validator for DateTimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let date = match self.strict { true => input.strict_datetime()?, @@ -72,6 +73,7 @@ impl Validator for DateTimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_datetime()?) } diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 473035965..039d82e78 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -6,7 +6,7 @@ use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, Va use crate::input::{GenericMapping, Input, JsonObject}; use super::any::AnyValidator; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct DictValidator { @@ -49,12 +49,13 @@ impl Validator for DictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let dict = match self.strict { true => input.strict_dict()?, false => input.lax_dict()?, }; - self._validation_logic(py, input, dict, extra, slots) + self._validation_logic(py, input, dict, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -63,8 +64,9 @@ impl Validator for DictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_dict()?, extra, slots) + self._validation_logic(py, input, input.strict_dict()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -86,6 +88,7 @@ macro_rules! build_validate { dict: &'data $dict_type, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if let Some(min_length) = self.min_items { if dict.len() < min_length { @@ -112,7 +115,7 @@ macro_rules! build_validate { let value_validator = self.value_validator.as_ref(); for (key, value) in dict.iter() { - let output_key = match key_validator.validate(py, key, extra, slots) { + let output_key = match key_validator.validate(py, key, extra, slots, recursion_guard) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -125,7 +128,7 @@ macro_rules! build_validate { } Err(err) => return Err(err), }; - let output_value = match value_validator.validate(py, value, extra, slots) { + let output_value = match value_validator.validate(py, value, extra, slots, recursion_guard) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -160,11 +163,14 @@ impl DictValidator { dict: GenericMapping<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match dict { - GenericMapping::PyDict(py_dict) => self.validate_dict(py, input, py_dict, extra, slots), + GenericMapping::PyDict(py_dict) => self.validate_dict(py, input, py_dict, extra, slots, recursion_guard), GenericMapping::PyGetAttr(_) => unreachable!(), - GenericMapping::JsonObject(json_object) => self.validate_json_object(py, input, json_object, extra, slots), + GenericMapping::JsonObject(json_object) => { + self.validate_json_object(py, input, json_object, extra, slots, recursion_guard) + } } } } diff --git a/src/validators/float.rs b/src/validators/float.rs index 13faeb15a..c5ea4b284 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -5,7 +5,7 @@ use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct FloatValidator; @@ -40,6 +40,7 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_float()?.into_py(py)) } @@ -50,6 +51,7 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_float()?.into_py(py)) } @@ -75,6 +77,7 @@ impl Validator for StrictFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_float()?.into_py(py)) } @@ -101,6 +104,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let float = match self.strict { true => input.strict_float()?, @@ -115,6 +119,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_float()?) } diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index c977c03d3..e01b284dd 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -7,7 +7,9 @@ use crate::input::{GenericSequence, Input}; use super::any::AnyValidator; use super::list::sequence_build_function; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; +use super::{ + build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, +}; #[derive(Debug, Clone)] pub struct FrozenSetValidator { @@ -29,12 +31,13 @@ impl Validator for FrozenSetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let frozenset = match self.strict { true => input.strict_frozenset()?, false => input.lax_frozenset()?, }; - self._validation_logic(py, input, frozenset, extra, slots) + self._validation_logic(py, input, frozenset, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -43,8 +46,9 @@ impl Validator for FrozenSetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_frozenset()?, extra, slots) + self._validation_logic(py, input, input.strict_frozenset()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -60,6 +64,7 @@ impl FrozenSetValidator { list: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = list.generic_len(); if let Some(min_length) = self.min_items { @@ -81,7 +86,7 @@ impl FrozenSetValidator { } } - let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(PyFrozenSet::new(py, &output).map_err(as_internal)?.into_py(py)) } } diff --git a/src/validators/function.rs b/src/validators/function.rs index 094a27cdb..3971a7737 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -6,7 +6,7 @@ use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{context, val_line_error, ErrorKind, ValError, ValResult, ValidationError}; use crate::input::Input; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug)] pub struct FunctionBuilder; @@ -71,6 +71,7 @@ impl Validator for FunctionBeforeValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); let value = self @@ -79,7 +80,7 @@ impl Validator for FunctionBeforeValidator { .map_err(|e| convert_err(py, e, input))?; // maybe there's some way to get the PyAny here and explicitly tell rust it should have lifespan 'a? let new_input: &PyAny = value.as_ref(py); - match self.validator.validate(py, new_input, extra, slots) { + match self.validator.validate(py, new_input, extra, slots, recursion_guard) { Ok(v) => Ok(v), Err(ValError::InternalErr(err)) => Err(ValError::InternalErr(err)), Err(ValError::LineErrors(line_errors)) => { @@ -116,8 +117,9 @@ impl Validator for FunctionAfterValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let v = self.validator.validate(py, input, extra, slots)?; + let v = self.validator.validate(py, input, extra, slots, recursion_guard)?; let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); self.func.call(py, (v,), kwargs).map_err(|e| convert_err(py, e, input)) } @@ -154,6 +156,7 @@ impl Validator for FunctionPlainValidator { input: &'data impl Input<'data>, extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); self.func @@ -182,12 +185,14 @@ impl Validator for FunctionWrapValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let validator_kwarg = ValidatorCallable { validator: self.validator.clone(), slots: slots.to_vec(), data: extra.data.map(|d| d.into_py(py)), field: extra.field.map(|f| f.to_string()), + recursion_guard: recursion_guard.clone(), }; let kwargs = kwargs!( py, @@ -212,17 +217,18 @@ struct ValidatorCallable { slots: Vec, data: Option>, field: Option, + recursion_guard: RecursionGuard, } #[pymethods] impl ValidatorCallable { - fn __call__(&self, py: Python, arg: &PyAny) -> PyResult { + fn __call__(&mut self, py: Python, arg: &PyAny) -> PyResult { let extra = Extra { data: self.data.as_ref().map(|data| data.as_ref(py)), field: self.field.as_deref(), }; self.validator - .validate(py, arg, &extra, &self.slots) + .validate(py, arg, &extra, &self.slots, &mut self.recursion_guard) .map_err(|e| ValidationError::from_val_error(py, "Model".to_object(py), e)) } diff --git a/src/validators/int.rs b/src/validators/int.rs index 21b0718f4..46ab4a718 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -5,7 +5,7 @@ use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct IntValidator; @@ -40,6 +40,7 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_int()?.into_py(py)) } @@ -50,6 +51,7 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_int()?.into_py(py)) } @@ -75,6 +77,7 @@ impl Validator for StrictIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_int()?.into_py(py)) } @@ -101,6 +104,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let int = match self.strict { true => input.strict_int()?, @@ -115,6 +119,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_int()?) } diff --git a/src/validators/list.rs b/src/validators/list.rs index c7bb22513..3718fc0fb 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -6,7 +6,9 @@ use crate::errors::{context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; use super::any::AnyValidator; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; +use super::{ + build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, +}; #[derive(Debug, Clone)] pub struct ListValidator { @@ -50,12 +52,13 @@ impl Validator for ListValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let list = match self.strict { true => input.strict_list()?, false => input.lax_list()?, }; - self._validation_logic(py, input, list, extra, slots) + self._validation_logic(py, input, list, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -64,8 +67,9 @@ impl Validator for ListValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_list()?, extra, slots) + self._validation_logic(py, input, input.strict_list()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -81,6 +85,7 @@ impl ListValidator { list: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = list.generic_len(); if let Some(min_length) = self.min_items { @@ -102,7 +107,7 @@ impl ListValidator { } } - let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(output.into_py(py)) } } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index aee203d05..db9958a49 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -7,7 +7,7 @@ use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug)] pub struct LiteralBuilder; @@ -63,6 +63,7 @@ impl Validator for LiteralSingleStringValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_str = input.strict_str()?; if either_str.as_cow().as_ref() == self.expected.as_str() { @@ -99,6 +100,7 @@ impl Validator for LiteralSingleIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let str = input.strict_int()?; if str == self.expected { @@ -150,6 +152,7 @@ impl Validator for LiteralMultipleStringsValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_str = input.strict_str()?; if self.expected.contains(either_str.as_cow().as_ref()) { @@ -201,6 +204,7 @@ impl Validator for LiteralMultipleIntsValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let int = input.strict_int()?; if self.expected.contains(&int) { @@ -260,6 +264,7 @@ impl Validator for LiteralGeneralValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if !self.expected_int.is_empty() { if let Ok(int) = input.strict_int() { diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 33ac898cb..0276fd492 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -1,3 +1,4 @@ +use ahash::AHashSet; use std::fmt::Debug; use enum_dispatch::enum_dispatch; @@ -74,12 +75,19 @@ impl SchemaValidator { } pub fn validate_python(&self, py: Python, input: &PyAny) -> PyResult { - let r = self.validator.validate(py, input, &Extra::default(), &self.slots); + let mut recursion_guard = RecursionGuard::new(); + let r = self + .validator + .validate(py, input, &Extra::default(), &self.slots, &mut recursion_guard); r.map_err(|e| self.prepare_validation_err(py, e)) } pub fn isinstance_python(&self, py: Python, input: &PyAny) -> PyResult { - match self.validator.validate(py, input, &Extra::default(), &self.slots) { + let mut recursion_guard = RecursionGuard::new(); + match self + .validator + .validate(py, input, &Extra::default(), &self.slots, &mut recursion_guard) + { Ok(_) => Ok(true), Err(ValError::InternalErr(err)) => Err(err), _ => Ok(false), @@ -89,7 +97,10 @@ impl SchemaValidator { pub fn validate_json(&self, py: Python, input: String) -> PyResult { match parse_json::(&input) { Ok(input) => { - let r = self.validator.validate(py, &input, &Extra::default(), &self.slots); + let mut recursion_guard = RecursionGuard::new(); + let r = self + .validator + .validate(py, &input, &Extra::default(), &self.slots, &mut recursion_guard); r.map_err(|e| self.prepare_validation_err(py, e)) } Err(e) => { @@ -106,11 +117,17 @@ impl SchemaValidator { pub fn isinstance_json(&self, py: Python, input: String) -> PyResult { match parse_json::(&input) { - Ok(input) => match self.validator.validate(py, &input, &Extra::default(), &self.slots) { - Ok(_) => Ok(true), - Err(ValError::InternalErr(err)) => Err(err), - _ => Ok(false), - }, + Ok(input) => { + let mut recursion_guard = RecursionGuard::new(); + match self + .validator + .validate(py, &input, &Extra::default(), &self.slots, &mut recursion_guard) + { + Ok(_) => Ok(true), + Err(ValError::InternalErr(err)) => Err(err), + _ => Ok(false), + } + } Err(_) => Ok(false), } } @@ -120,7 +137,10 @@ impl SchemaValidator { data: Some(data), field: Some(field.as_str()), }; - let r = self.validator.validate(py, input, &extra, &self.slots); + let mut recursion_guard = RecursionGuard::new(); + let r = self + .validator + .validate(py, input, &extra, &self.slots, &mut recursion_guard); r.map_err(|e| self.prepare_validation_err(py, e)) } @@ -333,6 +353,8 @@ pub enum CombinedValidator { FrozenSet(frozenset::FrozenSetValidator), } +pub type RecursionGuard = AHashSet; + /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait #[enum_dispatch(CombinedValidator)] @@ -344,6 +366,7 @@ pub trait Validator: Send + Sync + Clone + Debug { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject>; /// This is used in unions for the first pass to see if we have an "exact match", @@ -354,8 +377,9 @@ pub trait Validator: Send + Sync + Clone + Debug { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self.validate(py, input, extra, slots) + self.validate(py, input, extra, slots, recursion_guard) } /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator diff --git a/src/validators/model_class.rs b/src/validators/model_class.rs index f167b5edf..a1aeba4dd 100644 --- a/src/validators/model_class.rs +++ b/src/validators/model_class.rs @@ -11,7 +11,7 @@ use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValResult}; use crate::input::Input; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct ModelClassValidator { @@ -59,6 +59,7 @@ impl Validator for ModelClassValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let class = self.class.as_ref(py); if input.strict_model_check(class)? { @@ -70,7 +71,7 @@ impl Validator for ModelClassValidator { context = context!("class_name" => self.get_name(py)) ) } else { - let output = self.validator.validate(py, input, extra, slots)?; + let output = self.validator.validate(py, input, extra, slots, recursion_guard)?; self.create_class(py, output).map_err(as_internal) } } @@ -81,6 +82,7 @@ impl Validator for ModelClassValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if input.strict_model_check(self.class.as_ref(py))? { Ok(input.to_object(py)) diff --git a/src/validators/none.rs b/src/validators/none.rs index 94b311677..90c8ec97f 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -4,7 +4,7 @@ use pyo3::types::PyDict; use crate::errors::{err_val_error, ErrorKind, ValResult}; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct NoneValidator; @@ -28,6 +28,7 @@ impl Validator for NoneValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 47cef2a29..7aa5e6c13 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -4,7 +4,9 @@ use pyo3::types::PyDict; use crate::build_tools::SchemaDict; use crate::input::Input; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; +use super::{ + build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, +}; #[derive(Debug, Clone)] pub struct NullableValidator { @@ -34,10 +36,11 @@ impl Validator for NullableValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), - false => self.validator.validate(py, input, extra, slots), + false => self.validator.validate(py, input, extra, slots, recursion_guard), } } @@ -47,10 +50,11 @@ impl Validator for NullableValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), - false => self.validator.validate_strict(py, input, extra, slots), + false => self.validator.validate_strict(py, input, extra, slots, recursion_guard), } } diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index 5de7c2607..d44538978 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -5,7 +5,7 @@ use crate::build_tools::SchemaDict; use crate::errors::ValResult; use crate::input::Input; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct RecursiveContainerValidator { @@ -25,9 +25,10 @@ impl Validator for RecursiveContainerValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let validator = unsafe { slots.get_unchecked(self.validator_id) }; - validator.validate(py, input, extra, slots) + validator.validate(py, input, extra, slots, recursion_guard) } fn get_name(&self, _py: Python) -> String { @@ -61,9 +62,10 @@ impl Validator for RecursiveRefValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let validator = unsafe { slots.get_unchecked(self.validator_id) }; - validator.validate(py, input, extra, slots) + validator.validate(py, input, extra, slots, recursion_guard) } fn get_name(&self, _py: Python) -> String { diff --git a/src/validators/set.rs b/src/validators/set.rs index fd40b71d6..672c7bbbf 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -7,7 +7,9 @@ use crate::input::{GenericSequence, Input}; use super::any::AnyValidator; use super::list::sequence_build_function; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; +use super::{ + build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, +}; #[derive(Debug, Clone)] pub struct SetValidator { @@ -29,12 +31,13 @@ impl Validator for SetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let set = match self.strict { true => input.strict_set()?, false => input.lax_set()?, }; - self._validation_logic(py, input, set, extra, slots) + self._validation_logic(py, input, set, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -43,8 +46,9 @@ impl Validator for SetValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_set()?, extra, slots) + self._validation_logic(py, input, input.strict_set()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -60,6 +64,7 @@ impl SetValidator { list: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = list.generic_len(); if let Some(min_length) = self.min_items { @@ -81,7 +86,7 @@ impl SetValidator { } } - let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = list.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(PySet::new(py, &output).map_err(as_internal)?.into_py(py)) } } diff --git a/src/validators/string.rs b/src/validators/string.rs index 678f015b7..a6ff31d14 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -6,7 +6,7 @@ use crate::build_tools::{is_strict, py_error, schema_or_config}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherString, Input}; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct StrValidator; @@ -53,6 +53,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_str()?.into_py(py)) } @@ -63,6 +64,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_str()?.into_py(py)) } @@ -88,6 +90,7 @@ impl Validator for StrictStrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_str()?.into_py(py)) } @@ -115,6 +118,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let str = match self.strict { true => input.strict_str()?, @@ -129,6 +133,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_str()?) } diff --git a/src/validators/time.rs b/src/validators/time.rs index 871d432c9..493bbd9b3 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -6,7 +6,7 @@ use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherTime, Input}; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] pub struct TimeValidator { @@ -58,6 +58,7 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let time = match self.strict { true => input.strict_time()?, @@ -72,6 +73,7 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_time()?) } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index f5f7e329c..e2a5404e8 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -7,7 +7,9 @@ use crate::input::{GenericSequence, Input}; use super::any::AnyValidator; use super::list::sequence_build_function; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; +use super::{ + build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, +}; #[derive(Debug, Clone)] pub struct TupleVarLenValidator { @@ -29,12 +31,13 @@ impl Validator for TupleVarLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let tuple = match self.strict { true => input.strict_tuple()?, false => input.lax_tuple()?, }; - self._validation_logic(py, input, tuple, extra, slots) + self._validation_logic(py, input, tuple, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -43,8 +46,9 @@ impl Validator for TupleVarLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_tuple()?, extra, slots) + self._validation_logic(py, input, input.strict_tuple()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -60,6 +64,7 @@ impl TupleVarLenValidator { tuple: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let length = tuple.generic_len(); if let Some(min_length) = self.min_items { @@ -81,7 +86,7 @@ impl TupleVarLenValidator { } } - let output = tuple.validate_to_vec(py, length, &self.item_validator, extra, slots)?; + let output = tuple.validate_to_vec(py, length, &self.item_validator, extra, slots, recursion_guard)?; Ok(PyTuple::new(py, &output).into_py(py)) } } @@ -124,12 +129,13 @@ impl Validator for TupleFixLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let tuple = match self.strict { true => input.strict_tuple()?, false => input.lax_tuple()?, }; - self._validation_logic(py, input, tuple, extra, slots) + self._validation_logic(py, input, tuple, extra, slots, recursion_guard) } fn validate_strict<'s, 'data>( @@ -138,8 +144,9 @@ impl Validator for TupleFixLenValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - self._validation_logic(py, input, input.strict_tuple()?, extra, slots) + self._validation_logic(py, input, input.strict_tuple()?, extra, slots, recursion_guard) } fn get_name(&self, py: Python) -> String { @@ -161,6 +168,7 @@ impl TupleFixLenValidator { tuple: GenericSequence<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let expected_length = self.items_validators.len(); @@ -181,7 +189,7 @@ impl TupleFixLenValidator { macro_rules! iter { ($sequence:expr) => { for (validator, (index, item)) in self.items_validators.iter().zip($sequence.iter().enumerate()) { - match validator.validate(py, item, extra, slots) { + match validator.validate(py, item, extra, slots, recursion_guard) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { errors.extend( diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 6783bba1a..54122c915 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -12,7 +12,7 @@ use crate::errors::{ use crate::input::{GenericMapping, Input, JsonInput, JsonObject}; use crate::SchemaError; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; #[derive(Debug, Clone)] struct TypedDictField { @@ -127,10 +127,11 @@ impl Validator for TypedDictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if let Some(field) = extra.field { // we're validating assignment, completely different logic - return self.validate_assignment(py, field, input, extra, slots); + return self.validate_assignment(py, field, input, extra, slots, recursion_guard); } let dict = input.typed_dict(self.from_attributes, !self.strict)?; @@ -176,7 +177,7 @@ impl Validator for TypedDictValidator { // extra logic either way used_keys.insert(used_key); } - match field.validator.validate(py, value, &extra, slots) { + match field.validator.validate(py, value, &extra, slots, recursion_guard) { Ok(value) => { output_dict .set_item(&field.name_pystring, value) @@ -243,7 +244,7 @@ impl Validator for TypedDictValidator { } if let Some(ref validator) = self.extra_validator { - match validator.validate(py, value, &extra, slots) { + match validator.validate(py, value, &extra, slots, recursion_guard) { Ok(value) => { output_dict.set_item(py_key, value).map_err(as_internal)?; if let Some(ref mut fs) = fields_set_vec { @@ -298,6 +299,7 @@ impl TypedDictValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> where 'data: 's, @@ -331,11 +333,11 @@ impl TypedDictValidator { }; if let Some(field) = self.fields.iter().find(|f| f.name == field) { - prepare_result(field.validator.validate(py, input, extra, slots)) + prepare_result(field.validator.validate(py, input, extra, slots, recursion_guard)) } else if self.check_extra && !self.forbid_extra { // this is the "allow" case of extra_behavior match self.extra_validator { - Some(ref validator) => prepare_result(validator.validate(py, input, extra, slots)), + Some(ref validator) => prepare_result(validator.validate(py, input, extra, slots, recursion_guard)), None => prepare_tuple(input.to_object(py)), } } else { diff --git a/src/validators/union.rs b/src/validators/union.rs index 2d4f05ebd..6c794688a 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -5,7 +5,9 @@ use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{ValError, ValLineError}; use crate::input::Input; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; +use super::{ + build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, +}; #[derive(Debug, Clone)] pub struct UnionValidator { @@ -38,12 +40,13 @@ impl Validator for UnionValidator { input: &'data impl Input<'data>, extra: &Extra, slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if self.strict { let mut errors: Vec = Vec::with_capacity(self.choices.len()); for validator in &self.choices { - let line_errors = match validator.validate_strict(py, input, extra, slots) { + let line_errors = match validator.validate_strict(py, input, extra, slots, recursion_guard) { Err(ValError::LineErrors(line_errors)) => line_errors, otherwise => return otherwise, }; @@ -61,7 +64,7 @@ impl Validator for UnionValidator { if let Some(res) = self .choices .iter() - .map(|validator| validator.validate_strict(py, input, extra, slots)) + .map(|validator| validator.validate_strict(py, input, extra, slots, recursion_guard)) .find(ValResult::is_ok) { return res; @@ -71,7 +74,7 @@ impl Validator for UnionValidator { // 2nd pass: check if the value can be coerced into one of the Union types for validator in &self.choices { - let line_errors = match validator.validate(py, input, extra, slots) { + let line_errors = match validator.validate(py, input, extra, slots, recursion_guard) { Err(ValError::LineErrors(line_errors)) => line_errors, otherwise => return otherwise, }; From 7e8b175e03f3edc000244bf9b96c5af2fd082a63 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 28 Jun 2022 22:05:23 +0100 Subject: [PATCH 2/9] fix linting, start on logic --- src/errors/kinds.rs | 4 ++++ src/input/input_abstract.rs | 4 ++++ src/input/input_python.rs | 6 +++++- src/validators/any.rs | 2 +- src/validators/bool.rs | 6 +++--- src/validators/bytes.rs | 10 +++++----- src/validators/date.rs | 4 ++-- src/validators/datetime.rs | 4 ++-- src/validators/float.rs | 10 +++++----- src/validators/function.rs | 2 +- src/validators/int.rs | 10 +++++----- src/validators/literal.rs | 10 +++++----- src/validators/model_class.rs | 2 +- src/validators/none.rs | 2 +- src/validators/recursive.rs | 10 +++++++++- src/validators/string.rs | 10 +++++----- src/validators/time.rs | 4 ++-- 17 files changed, 60 insertions(+), 40 deletions(-) diff --git a/src/errors/kinds.rs b/src/errors/kinds.rs index b155391d8..77c9197a3 100644 --- a/src/errors/kinds.rs +++ b/src/errors/kinds.rs @@ -8,6 +8,10 @@ pub enum ErrorKind { #[strum(message = "Invalid JSON: {parser_error}")] InvalidJson, // --------------------- + // recursion error + #[strum(message = "Recursion error - cyclic reference detected")] + RecursionLoop, + // --------------------- // typed dict specific errors #[strum(message = "Value must be a valid dictionary or instance to extract fields from")] DictAttributesType, diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 9eeae3a89..e21d20312 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -16,6 +16,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn as_error_value(&'a self) -> InputValue<'a>; + fn identity(&'a self) -> Option { + None + } + fn is_none(&self) -> bool; fn strict_str<'data>(&'data self) -> ValResult>; diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 2e9bbf508..9b578ec04 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -1,12 +1,12 @@ use std::str::from_utf8; use pyo3::exceptions::{PyAttributeError, PyTypeError}; -use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{ PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType, }; +use pyo3::{intern, AsPyPointer}; use crate::errors::location::LocItem; use crate::errors::{as_internal, context, err_val_error, py_err_string, ErrorKind, InputValue, ValResult}; @@ -36,6 +36,10 @@ impl<'a> Input<'a> for PyAny { InputValue::PyAny(self) } + fn identity(&'a self) -> Option { + Some(self.as_ptr() as isize) + } + fn is_none(&self) -> bool { self.is_none() } diff --git a/src/validators/any.rs b/src/validators/any.rs index f31d8ed49..2217c97e2 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -29,7 +29,7 @@ impl Validator for AnyValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { // Ok(input.clone().into_py(py)) Ok(input.to_object(py)) diff --git a/src/validators/bool.rs b/src/validators/bool.rs index f8197c435..a73ba87c2 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -33,7 +33,7 @@ impl Validator for BoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? @@ -46,7 +46,7 @@ impl Validator for BoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_bool()?.into_py(py)) } @@ -72,7 +72,7 @@ impl Validator for StrictBoolValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_bool()?.into_py(py)) } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 1e193ed4b..651f03e13 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -36,7 +36,7 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.lax_bytes()?; Ok(either_bytes.into_py(py)) @@ -48,7 +48,7 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.strict_bytes()?; Ok(either_bytes.into_py(py)) @@ -75,7 +75,7 @@ impl Validator for StrictBytesValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_bytes = input.strict_bytes()?; Ok(either_bytes.into_py(py)) @@ -100,7 +100,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let bytes = match self.strict { true => input.strict_bytes()?, @@ -115,7 +115,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_bytes()?) } diff --git a/src/validators/date.rs b/src/validators/date.rs index 82bf5d688..73d5cc989 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -58,7 +58,7 @@ impl Validator for DateValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let date = match self.strict { true => input.strict_date()?, @@ -81,7 +81,7 @@ impl Validator for DateValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_date()?) } diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 53fe709c6..b1176857e 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -58,7 +58,7 @@ impl Validator for DateTimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let date = match self.strict { true => input.strict_datetime()?, @@ -73,7 +73,7 @@ impl Validator for DateTimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_datetime()?) } diff --git a/src/validators/float.rs b/src/validators/float.rs index c5ea4b284..139e8fc49 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -40,7 +40,7 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_float()?.into_py(py)) } @@ -51,7 +51,7 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_float()?.into_py(py)) } @@ -77,7 +77,7 @@ impl Validator for StrictFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_float()?.into_py(py)) } @@ -104,7 +104,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let float = match self.strict { true => input.strict_float()?, @@ -119,7 +119,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_float()?) } diff --git a/src/validators/function.rs b/src/validators/function.rs index 3971a7737..23ef8d8b6 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -156,7 +156,7 @@ impl Validator for FunctionPlainValidator { input: &'data impl Input<'data>, extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let kwargs = kwargs!(py, "data" => extra.data, "config" => self.config.as_ref()); self.func diff --git a/src/validators/int.rs b/src/validators/int.rs index 46ab4a718..35663f646 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -40,7 +40,7 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_int()?.into_py(py)) } @@ -51,7 +51,7 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_int()?.into_py(py)) } @@ -77,7 +77,7 @@ impl Validator for StrictIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_int()?.into_py(py)) } @@ -104,7 +104,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let int = match self.strict { true => input.strict_int()?, @@ -119,7 +119,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_int()?) } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index db9958a49..963647851 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -63,7 +63,7 @@ impl Validator for LiteralSingleStringValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_str = input.strict_str()?; if either_str.as_cow().as_ref() == self.expected.as_str() { @@ -100,7 +100,7 @@ impl Validator for LiteralSingleIntValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let str = input.strict_int()?; if str == self.expected { @@ -152,7 +152,7 @@ impl Validator for LiteralMultipleStringsValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let either_str = input.strict_str()?; if self.expected.contains(either_str.as_cow().as_ref()) { @@ -204,7 +204,7 @@ impl Validator for LiteralMultipleIntsValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let int = input.strict_int()?; if self.expected.contains(&int) { @@ -264,7 +264,7 @@ impl Validator for LiteralGeneralValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if !self.expected_int.is_empty() { if let Ok(int) = input.strict_int() { diff --git a/src/validators/model_class.rs b/src/validators/model_class.rs index a1aeba4dd..823eb9bc0 100644 --- a/src/validators/model_class.rs +++ b/src/validators/model_class.rs @@ -82,7 +82,7 @@ impl Validator for ModelClassValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if input.strict_model_check(self.class.as_ref(py))? { Ok(input.to_object(py)) diff --git a/src/validators/none.rs b/src/validators/none.rs index 90c8ec97f..a06e3b148 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -28,7 +28,7 @@ impl Validator for NoneValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { match input.is_none() { true => Ok(py.None()), diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index d44538978..3d870b440 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::SchemaDict; -use crate::errors::ValResult; +use crate::errors::{err_val_error, ErrorKind, ValResult}; use crate::input::Input; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; @@ -64,6 +64,14 @@ impl Validator for RecursiveRefValidator { slots: &'data [CombinedValidator], recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { + if let Some(id) = input.identity() { + eprintln!("---------------"); + dbg!(&recursion_guard, &input, &id); + if recursion_guard.contains(&id) { + return err_val_error!(kind = ErrorKind::RecursionLoop, input_value = input.as_error_value(),); + } + recursion_guard.insert(id); + } let validator = unsafe { slots.get_unchecked(self.validator_id) }; validator.validate(py, input, extra, slots, recursion_guard) } diff --git a/src/validators/string.rs b/src/validators/string.rs index a6ff31d14..72f137b6d 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -53,7 +53,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.lax_str()?.into_py(py)) } @@ -64,7 +64,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_str()?.into_py(py)) } @@ -90,7 +90,7 @@ impl Validator for StrictStrValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { Ok(input.strict_str()?.into_py(py)) } @@ -118,7 +118,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let str = match self.strict { true => input.strict_str()?, @@ -133,7 +133,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self._validation_logic(py, input, input.strict_str()?) } diff --git a/src/validators/time.rs b/src/validators/time.rs index 493bbd9b3..b32a555bf 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -58,7 +58,7 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { let time = match self.strict { true => input.strict_time()?, @@ -73,7 +73,7 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, _extra: &Extra, _slots: &'data [CombinedValidator], - recursion_guard: &'s mut RecursionGuard, + _recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { self.validation_comparison(py, input, input.strict_time()?) } From 49a041f32be0332fdab038d3e2e9be364144c07f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 28 Jun 2022 23:25:01 +0100 Subject: [PATCH 3/9] basic recursion implementation working --- src/input/input_abstract.rs | 2 +- src/input/input_python.rs | 4 ++-- src/validators/mod.rs | 2 +- src/validators/recursive.rs | 37 ++++++++++++++++++++---------- tests/validators/test_recursive.py | 35 ++++++++++++++++++++++++++++ 5 files changed, 64 insertions(+), 16 deletions(-) diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index e21d20312..ecfa77a52 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -16,7 +16,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn as_error_value(&'a self) -> InputValue<'a>; - fn identity(&'a self) -> Option { + fn identity(&'a self) -> Option { None } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 9b578ec04..a4cabe4e3 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -36,8 +36,8 @@ impl<'a> Input<'a> for PyAny { InputValue::PyAny(self) } - fn identity(&'a self) -> Option { - Some(self.as_ptr() as isize) + fn identity(&'a self) -> Option { + Some(self.as_ptr() as usize) } fn is_none(&self) -> bool { diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 0276fd492..8e1482bf2 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -353,7 +353,7 @@ pub enum CombinedValidator { FrozenSet(frozenset::FrozenSetValidator), } -pub type RecursionGuard = AHashSet; +pub type RecursionGuard = AHashSet; /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index 3d870b440..1494e6703 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -27,8 +27,7 @@ impl Validator for RecursiveContainerValidator { slots: &'data [CombinedValidator], recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - let validator = unsafe { slots.get_unchecked(self.validator_id) }; - validator.validate(py, input, extra, slots, recursion_guard) + validate(self.validator_id, py, input, extra, slots, recursion_guard) } fn get_name(&self, _py: Python) -> String { @@ -64,19 +63,33 @@ impl Validator for RecursiveRefValidator { slots: &'data [CombinedValidator], recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { - if let Some(id) = input.identity() { - eprintln!("---------------"); - dbg!(&recursion_guard, &input, &id); - if recursion_guard.contains(&id) { - return err_val_error!(kind = ErrorKind::RecursionLoop, input_value = input.as_error_value(),); - } - recursion_guard.insert(id); - } - let validator = unsafe { slots.get_unchecked(self.validator_id) }; - validator.validate(py, input, extra, slots, recursion_guard) + validate(self.validator_id, py, input, extra, slots, recursion_guard) } fn get_name(&self, _py: Python) -> String { Self::EXPECTED_TYPE.to_string() } } + +fn validate<'s, 'data>( + validator_id: usize, + py: Python<'data>, + input: &'data impl Input<'data>, + extra: &Extra, + slots: &'data [CombinedValidator], + recursion_guard: &'s mut RecursionGuard, +) -> ValResult<'data, PyObject> { + if let Some(id) = input.identity() { + if recursion_guard.contains(&id) { + return err_val_error!(kind = ErrorKind::RecursionLoop, input_value = input.as_error_value()); + } + recursion_guard.insert(id); + let validator = unsafe { slots.get_unchecked(validator_id) }; + let output = validator.validate(py, input, extra, slots, recursion_guard); + recursion_guard.remove(&id); + output + } else { + let validator = unsafe { slots.get_unchecked(validator_id) }; + validator.validate(py, input, extra, slots, recursion_guard) + } +} diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index a4e39f810..7acc181be 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +from dirty_equals import IsPartialDict from pydantic_core import SchemaError, SchemaValidator, ValidationError @@ -231,3 +232,37 @@ def test_outside_parent(): 'tuple1': (1, 1, 'frog'), 'tuple2': (2, 2, 'toad'), } + + +def test_recursion(): + v = SchemaValidator( + { + 'type': 'typed-dict', + 'ref': 'Branch', + 'fields': { + 'name': {'schema': {'type': 'str'}}, + 'branch': { + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}, + 'default': None, + }, + }, + } + ) + assert v.validate_python({'name': 'root'}) == {'name': 'root', 'branch': None} + assert v.validate_python({'name': 'root', 'branch': {'name': 'b1', 'branch': None}}) == { + 'name': 'root', + 'branch': {'name': 'b1', 'branch': None}, + } + + b = {'name': 'recursive'} + b['branch'] = b + with pytest.raises(ValidationError) as exc_info: + assert v.validate_python(b) + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['branch'], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': {'name': 'recursive', 'branch': IsPartialDict(name='recursive')}, + } + ] From 897a5314e5ee69171695dc2d3c3cc48f0ae51fe4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 29 Jun 2022 01:16:30 +0100 Subject: [PATCH 4/9] make recursion guard option-al --- Cargo.lock | 7 ++++ Cargo.toml | 1 + src/validators/mod.rs | 88 ++++++++++++++++++++++++++++++++----------- 3 files changed, 74 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18e99107f..af7d64e63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,6 +155,12 @@ dependencies = [ "libmimalloc-sys", ] +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "once_cell" version = "1.10.0" @@ -203,6 +209,7 @@ dependencies = [ "enum_dispatch", "indexmap", "mimalloc", + "nohash-hasher", "pyo3", "regex", "serde", diff --git a/Cargo.toml b/Cargo.toml index adb8bf1f9..24338b51a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ indexmap = "1.8.1" mimalloc = { version = "0.1.29", default-features = false, optional = true } speedate = "0.4.1" ahash = "0.7.6" +nohash-hasher = "0.2.0" [lib] name = "_pydantic_core" diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 8e1482bf2..40de58261 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -1,7 +1,9 @@ -use ahash::AHashSet; +use std::collections::HashSet; use std::fmt::Debug; +use std::hash::BuildHasherDefault; use enum_dispatch::enum_dispatch; +use nohash_hasher::NoHashHasher; use pyo3::exceptions::PyRecursionError; use pyo3::prelude::*; @@ -75,19 +77,24 @@ impl SchemaValidator { } pub fn validate_python(&self, py: Python, input: &PyAny) -> PyResult { - let mut recursion_guard = RecursionGuard::new(); - let r = self - .validator - .validate(py, input, &Extra::default(), &self.slots, &mut recursion_guard); + let r = self.validator.validate( + py, + input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ); r.map_err(|e| self.prepare_validation_err(py, e)) } pub fn isinstance_python(&self, py: Python, input: &PyAny) -> PyResult { - let mut recursion_guard = RecursionGuard::new(); - match self - .validator - .validate(py, input, &Extra::default(), &self.slots, &mut recursion_guard) - { + match self.validator.validate( + py, + input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ) { Ok(_) => Ok(true), Err(ValError::InternalErr(err)) => Err(err), _ => Ok(false), @@ -97,10 +104,13 @@ impl SchemaValidator { pub fn validate_json(&self, py: Python, input: String) -> PyResult { match parse_json::(&input) { Ok(input) => { - let mut recursion_guard = RecursionGuard::new(); - let r = self - .validator - .validate(py, &input, &Extra::default(), &self.slots, &mut recursion_guard); + let r = self.validator.validate( + py, + &input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ); r.map_err(|e| self.prepare_validation_err(py, e)) } Err(e) => { @@ -118,11 +128,13 @@ impl SchemaValidator { pub fn isinstance_json(&self, py: Python, input: String) -> PyResult { match parse_json::(&input) { Ok(input) => { - let mut recursion_guard = RecursionGuard::new(); - match self - .validator - .validate(py, &input, &Extra::default(), &self.slots, &mut recursion_guard) - { + match self.validator.validate( + py, + &input, + &Extra::default(), + &self.slots, + &mut RecursionGuard::default(), + ) { Ok(_) => Ok(true), Err(ValError::InternalErr(err)) => Err(err), _ => Ok(false), @@ -137,10 +149,9 @@ impl SchemaValidator { data: Some(data), field: Some(field.as_str()), }; - let mut recursion_guard = RecursionGuard::new(); let r = self .validator - .validate(py, input, &extra, &self.slots, &mut recursion_guard); + .validate(py, input, &extra, &self.slots, &mut RecursionGuard::default()); r.map_err(|e| self.prepare_validation_err(py, e)) } @@ -353,7 +364,40 @@ pub enum CombinedValidator { FrozenSet(frozenset::FrozenSetValidator), } -pub type RecursionGuard = AHashSet; +#[derive(Debug, Clone, Default)] +pub struct RecursionGuard(Option>>>); + +impl RecursionGuard { + pub fn insert(&mut self, id: usize) { + match self.0 { + Some(ref mut set) => { + set.insert(id); + } + None => { + let mut set: HashSet>> = + HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default()); + set.insert(id); + self.0 = Some(set); + } + }; + } + + pub fn remove(&mut self, id: &usize) { + match self.0 { + Some(ref mut set) => { + set.remove(id); + } + None => unreachable!(), + }; + } + + pub fn contains(&self, id: &usize) -> bool { + match self.0 { + Some(ref set) => set.contains(id), + None => false, + } + } +} /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait From cc8f137bb038c44e4182f74dc0c0843b58d3061a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 29 Jun 2022 12:43:44 +0100 Subject: [PATCH 5/9] more tests --- tests/validators/test_recursive.py | 99 +++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index 7acc181be..4cd90682a 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -1,10 +1,12 @@ from typing import Optional import pytest -from dirty_equals import IsPartialDict +from dirty_equals import IsList, IsPartialDict from pydantic_core import SchemaError, SchemaValidator, ValidationError +from ..conftest import Err + def test_branch_nullable(): v = SchemaValidator( @@ -234,7 +236,7 @@ def test_outside_parent(): } -def test_recursion(): +def test_recursion_branch(): v = SchemaValidator( { 'type': 'typed-dict', @@ -266,3 +268,96 @@ def test_recursion(): 'input_value': {'name': 'recursive', 'branch': IsPartialDict(name='recursive')}, } ] + + +def test_recursive_list(): + v = SchemaValidator( + {'type': 'list', 'ref': 'the-list', 'items_schema': {'type': 'recursive-ref', 'schema_ref': 'the-list'}} + ) + assert v.validate_python([]) == [] + assert v.validate_python([[]]) == [[]] + + data = list() + data.append(data) + with pytest.raises(ValidationError, match='Recursion error - cyclic reference detected'): + assert v.validate_python(data) + + +@pytest.fixture(scope='module') +def multiple_tuple_schema(): + return SchemaValidator( + { + 'type': 'typed-dict', + 'fields': { + 'f1': { + 'schema': { + 'type': 'tuple-fix-len', + 'ref': 't', + 'items_schema': [ + {'type': 'int'}, + {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 't'}}, + ], + } + }, + 'f2': { + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 't'}}, + 'default': None, + }, + }, + } + ) + + +@pytest.mark.parametrize( + 'input_value,expected', + [ + ({'f1': [1, None]}, {'f1': (1, None), 'f2': None}), + ({'f1': [1, None], 'f2': [2, None]}, {'f1': (1, None), 'f2': (2, None)}), + ( + {'f1': [1, (3, None)], 'f2': [2, (4, (4, (5, None)))]}, + {'f1': (1, (3, None)), 'f2': (2, (4, (4, (5, None))))}, + ), + ({'f1': [1, 2]}, Err(r'f1 -> 1\s+Value must be a valid tuple')), + ( + {'f1': [1, (3, None)], 'f2': [2, (4, (4, (5, 6)))]}, + Err(r'f2 -> 1 -> 1 -> 1 -> 1\s+Value must be a valid tuple'), + ), + ], +) +def test_multiple_tuple_param(multiple_tuple_schema, input_value, expected): + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=expected.message): + multiple_tuple_schema.validate_python(input_value) + # debug(repr(exc_info.value)) + else: + assert multiple_tuple_schema.validate_python(input_value) == expected + + +def test_multiple_tuple_repeat(multiple_tuple_schema): + t = (42, None) + assert multiple_tuple_schema.validate_python({'f1': (1, t), 'f2': (2, t)}) == { + 'f1': (1, (42, None)), + 'f2': (2, (42, None)), + } + + +def test_multiple_tuple_recursion(multiple_tuple_schema): + data = [1] + data.append(data) + with pytest.raises(ValidationError) as exc_info: + multiple_tuple_schema.validate_python({'f1': data, 'f2': data}) + + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['f1', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + { + 'kind': 'recursion_loop', + 'loc': ['f2', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + ] From 76e3556ca3e36d2859a30d9f81f97856ea1c38fe Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 29 Jun 2022 14:22:56 +0100 Subject: [PATCH 6/9] move RecursionGuard, optimise recursion check --- src/input/return_enums.rs | 3 ++- src/lib.rs | 1 + src/recursion_guard.rs | 36 +++++++++++++++++++++++++++ src/validators/any.rs | 3 ++- src/validators/bool.rs | 3 ++- src/validators/bytes.rs | 3 ++- src/validators/date.rs | 3 ++- src/validators/datetime.rs | 3 ++- src/validators/dict.rs | 3 ++- src/validators/float.rs | 3 ++- src/validators/frozenset.rs | 5 ++-- src/validators/function.rs | 3 ++- src/validators/int.rs | 3 ++- src/validators/list.rs | 5 ++-- src/validators/literal.rs | 3 ++- src/validators/mod.rs | 39 +----------------------------- src/validators/model_class.rs | 3 ++- src/validators/none.rs | 3 ++- src/validators/nullable.rs | 5 ++-- src/validators/recursive.rs | 8 +++--- src/validators/set.rs | 5 ++-- src/validators/string.rs | 3 ++- src/validators/time.rs | 3 ++- src/validators/tuple.rs | 5 ++-- src/validators/typed_dict.rs | 3 ++- src/validators/union.rs | 5 ++-- tests/validators/test_recursive.py | 22 +++++++++++++++++ 27 files changed, 109 insertions(+), 75 deletions(-) create mode 100644 src/recursion_guard.rs diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 7317625bc..e1f37ae16 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,7 +4,8 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple}; use crate::errors::{ValError, ValLineError, ValResult}; -use crate::validators::{CombinedValidator, Extra, RecursionGuard, Validator}; +use crate::recursion_guard::RecursionGuard; +use crate::validators::{CombinedValidator, Extra, Validator}; use super::parse_json::{JsonArray, JsonObject}; diff --git a/src/lib.rs b/src/lib.rs index 3e36b6088..2ac4b9c74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; mod build_tools; mod errors; mod input; +mod recursion_guard; mod validators; // required for benchmarks diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs new file mode 100644 index 000000000..3a6b9dd9a --- /dev/null +++ b/src/recursion_guard.rs @@ -0,0 +1,36 @@ +use std::collections::HashSet; +use std::hash::BuildHasherDefault; + +use nohash_hasher::NoHashHasher; + +/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault. +/// It's used in `validators/recursive.rs` to detect when a reference is reused within itself. +#[derive(Debug, Clone, Default)] +pub struct RecursionGuard(Option>>>); + +impl RecursionGuard { + // insert a new id into the set, return whether the set already had the id in it + pub fn contains_or_insert(&mut self, id: usize) -> bool { + match self.0 { + // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert + // "If the set did not have this value present, `true` is returned." + Some(ref mut set) => !set.insert(id), + None => { + let mut set: HashSet>> = + HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default()); + set.insert(id); + self.0 = Some(set); + false + } + } + } + + pub fn remove(&mut self, id: &usize) { + match self.0 { + Some(ref mut set) => { + set.remove(id); + } + None => unreachable!(), + }; + } +} diff --git a/src/validators/any.rs b/src/validators/any.rs index 2217c97e2..38df42197 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -3,8 +3,9 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; /// This might seem useless, but it's useful in DictValidator to avoid Option a lot #[derive(Debug, Clone)] diff --git a/src/validators/bool.rs b/src/validators/bool.rs index a73ba87c2..b425b3cb2 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -4,8 +4,9 @@ use pyo3::types::PyDict; use crate::build_tools::is_strict; use crate::errors::ValResult; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct BoolValidator; diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 651f03e13..8c4f45337 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -4,8 +4,9 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherBytes, Input}; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct BytesValidator; diff --git a/src/validators/date.rs b/src/validators/date.rs index 73d5cc989..927532e1b 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -5,8 +5,9 @@ use speedate::{Date, Time}; use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValResult}; use crate::input::{EitherDate, Input}; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct DateValidator { diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index b1176857e..e62807180 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -5,8 +5,9 @@ use speedate::DateTime; use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherDateTime, Input}; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct DateTimeValidator { diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 039d82e78..edabb2a51 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -4,9 +4,10 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input, JsonObject}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct DictValidator { diff --git a/src/validators/float.rs b/src/validators/float.rs index 139e8fc49..37993c422 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -4,8 +4,9 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct FloatValidator; diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index e01b284dd..b44bf8e10 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -4,12 +4,11 @@ use pyo3::types::{PyDict, PyFrozenSet}; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::list::sequence_build_function; -use super::{ - build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, -}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; #[derive(Debug, Clone)] pub struct FrozenSetValidator { diff --git a/src/validators/function.rs b/src/validators/function.rs index 23ef8d8b6..deefae2eb 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -5,8 +5,9 @@ use pyo3::types::{PyAny, PyDict}; use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{context, val_line_error, ErrorKind, ValError, ValResult, ValidationError}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug)] pub struct FunctionBuilder; diff --git a/src/validators/int.rs b/src/validators/int.rs index 35663f646..758d4f1f3 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -4,8 +4,9 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct IntValidator; diff --git a/src/validators/list.rs b/src/validators/list.rs index 3718fc0fb..61f261240 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -4,11 +4,10 @@ use pyo3::types::PyDict; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; -use super::{ - build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, -}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; #[derive(Debug, Clone)] pub struct ListValidator { diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 963647851..534ddfc54 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -6,8 +6,9 @@ use ahash::AHashSet; use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug)] pub struct LiteralBuilder; diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 40de58261..7799d0e9d 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -1,9 +1,6 @@ -use std::collections::HashSet; use std::fmt::Debug; -use std::hash::BuildHasherDefault; use enum_dispatch::enum_dispatch; -use nohash_hasher::NoHashHasher; use pyo3::exceptions::PyRecursionError; use pyo3::prelude::*; @@ -13,6 +10,7 @@ use serde_json::from_str as parse_json; use crate::build_tools::{py_error, SchemaDict, SchemaError}; use crate::errors::{context, val_line_error, ErrorKind, ValError, ValResult, ValidationError}; use crate::input::{Input, JsonInput}; +use crate::recursion_guard::RecursionGuard; mod any; mod bool; @@ -364,41 +362,6 @@ pub enum CombinedValidator { FrozenSet(frozenset::FrozenSetValidator), } -#[derive(Debug, Clone, Default)] -pub struct RecursionGuard(Option>>>); - -impl RecursionGuard { - pub fn insert(&mut self, id: usize) { - match self.0 { - Some(ref mut set) => { - set.insert(id); - } - None => { - let mut set: HashSet>> = - HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default()); - set.insert(id); - self.0 = Some(set); - } - }; - } - - pub fn remove(&mut self, id: &usize) { - match self.0 { - Some(ref mut set) => { - set.remove(id); - } - None => unreachable!(), - }; - } - - pub fn contains(&self, id: &usize) -> bool { - match self.0 { - Some(ref set) => set.contains(id), - None => false, - } - } -} - /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait #[enum_dispatch(CombinedValidator)] diff --git a/src/validators/model_class.rs b/src/validators/model_class.rs index 823eb9bc0..68eb4bdac 100644 --- a/src/validators/model_class.rs +++ b/src/validators/model_class.rs @@ -10,8 +10,9 @@ use pyo3::{ffi, intern}; use crate::build_tools::{py_error, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValError, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct ModelClassValidator { diff --git a/src/validators/none.rs b/src/validators/none.rs index a06e3b148..6e7aef25b 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -3,8 +3,9 @@ use pyo3::types::PyDict; use crate::errors::{err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct NoneValidator; diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 7aa5e6c13..dd2f05983 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -3,10 +3,9 @@ use pyo3::types::PyDict; use crate::build_tools::SchemaDict; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{ - build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, -}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; #[derive(Debug, Clone)] pub struct NullableValidator { diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index 1494e6703..312eb5bae 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -4,8 +4,9 @@ use pyo3::types::PyDict; use crate::build_tools::SchemaDict; use crate::errors::{err_val_error, ErrorKind, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct RecursiveContainerValidator { @@ -80,10 +81,11 @@ fn validate<'s, 'data>( recursion_guard: &'s mut RecursionGuard, ) -> ValResult<'data, PyObject> { if let Some(id) = input.identity() { - if recursion_guard.contains(&id) { + if recursion_guard.contains_or_insert(id) { + // remove ID in we use recursion_guard again + recursion_guard.remove(&id); return err_val_error!(kind = ErrorKind::RecursionLoop, input_value = input.as_error_value()); } - recursion_guard.insert(id); let validator = unsafe { slots.get_unchecked(validator_id) }; let output = validator.validate(py, input, extra, slots, recursion_guard); recursion_guard.remove(&id); diff --git a/src/validators/set.rs b/src/validators/set.rs index 672c7bbbf..60203d027 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -4,12 +4,11 @@ use pyo3::types::{PyDict, PySet}; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{as_internal, context, err_val_error, ErrorKind}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::list::sequence_build_function; -use super::{ - build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, -}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; #[derive(Debug, Clone)] pub struct SetValidator { diff --git a/src/validators/string.rs b/src/validators/string.rs index 72f137b6d..330be9d53 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -5,8 +5,9 @@ use regex::Regex; use crate::build_tools::{is_strict, py_error, schema_or_config}; use crate::errors::{context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherString, Input}; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct StrValidator; diff --git a/src/validators/time.rs b/src/validators/time.rs index b32a555bf..350927a5c 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -5,8 +5,9 @@ use speedate::Time; use crate::build_tools::{is_strict, SchemaDict, SchemaError}; use crate::errors::{as_internal, context, err_val_error, ErrorKind, ValResult}; use crate::input::{EitherTime, Input}; +use crate::recursion_guard::RecursionGuard; -use super::{BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] pub struct TimeValidator { diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index e2a5404e8..62f563dd0 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -4,12 +4,11 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use crate::build_tools::{is_strict, py_error, SchemaDict}; use crate::errors::{context, err_val_error, ErrorKind, ValError, ValLineError}; use crate::input::{GenericSequence, Input}; +use crate::recursion_guard::RecursionGuard; use super::any::AnyValidator; use super::list::sequence_build_function; -use super::{ - build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, -}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; #[derive(Debug, Clone)] pub struct TupleVarLenValidator { diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 54122c915..6183e7db6 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -10,9 +10,10 @@ use crate::errors::{ as_internal, context, err_val_error, py_err_string, val_line_error, ErrorKind, ValError, ValLineError, ValResult, }; use crate::input::{GenericMapping, Input, JsonInput, JsonObject}; +use crate::recursion_guard::RecursionGuard; use crate::SchemaError; -use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, Validator}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] struct TypedDictField { diff --git a/src/validators/union.rs b/src/validators/union.rs index 6c794688a..581a571d9 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -4,10 +4,9 @@ use pyo3::types::{PyDict, PyList}; use crate::build_tools::{is_strict, SchemaDict}; use crate::errors::{ValError, ValLineError}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; -use super::{ - build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, RecursionGuard, ValResult, Validator, -}; +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, ValResult, Validator}; #[derive(Debug, Clone)] pub struct UnionValidator { diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index 4cd90682a..f25114cd4 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -361,3 +361,25 @@ def test_multiple_tuple_recursion(multiple_tuple_schema): 'input_value': [1, IsList(length=2)], }, ] + + +def test_multiple_tuple_recursion_once(multiple_tuple_schema): + data = [1] + data.append(data) + with pytest.raises(ValidationError) as exc_info: + multiple_tuple_schema.validate_python({'f1': data, 'f2': data}) + + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['f1', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + { + 'kind': 'recursion_loop', + 'loc': ['f2', 1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': [1, IsList(length=2)], + }, + ] From dba8278635dab96e1b3535007559aa6d7570a867 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 29 Jun 2022 15:46:08 +0100 Subject: [PATCH 7/9] tests for recursion across a wrap validator --- src/errors/line_error.rs | 25 +++++++++++-- src/errors/validation_exception.rs | 49 ++++++++++++++++++++----- src/validators/function.rs | 3 ++ tests/validators/test_function.py | 21 ++++++++++- tests/validators/test_recursive.py | 57 ++++++++++++++++++++++++++++- tests/validators/test_typed_dict.py | 2 + 6 files changed, 141 insertions(+), 16 deletions(-) diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index 9594c03f9..8b61ae2c5 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -17,14 +17,27 @@ pub enum ValError<'a> { InternalErr(PyErr), } +impl<'a> From for ValError<'a> { + fn from(py_err: PyErr) -> Self { + Self::InternalErr(py_err) + } +} + +impl<'a> From>> for ValError<'a> { + fn from(line_errors: Vec>) -> Self { + Self::LineErrors(line_errors) + } +} + // ValError used to implement Error, see #78 for removed code +// TODO, remove and replace with just .into() pub fn as_internal<'a>(err: PyErr) -> ValError<'a> { - ValError::InternalErr(err) + err.into() } pub fn pretty_line_errors(py: Python, line_errors: Vec) -> String { - let py_line_errors: Vec = line_errors.into_iter().map(|e| PyLineError::new(py, e)).collect(); + let py_line_errors: Vec = line_errors.into_iter().map(|e| e.into_py(py)).collect(); pretty_py_line_errors(Some(py), py_line_errors.iter()) } @@ -58,7 +71,7 @@ impl<'a> ValLineError<'a> { ValLineError { kind: self.kind, reverse_location: self.reverse_location, - input_value: InputValue::PyObject(self.input_value.to_object(py)), + input_value: self.input_value.to_object(py).into(), context: self.context, } } @@ -79,6 +92,12 @@ impl Default for InputValue<'_> { } } +impl<'a> From for InputValue<'a> { + fn from(py_object: PyObject) -> Self { + Self::PyObject(py_object) + } +} + impl<'a> ToPyObject for InputValue<'a> { fn to_object(&self, py: Python) -> PyObject { match self { diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index d0285ca91..75b825c8c 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -16,7 +16,7 @@ use super::location::Location; use super::ValError; #[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ValidationError { line_errors: Vec, title: PyObject, @@ -33,7 +33,7 @@ impl ValidationError { pub fn from_val_error(py: Python, title: PyObject, error: ValError) -> PyErr { match error { ValError::LineErrors(raw_errors) => { - let line_errors: Vec = raw_errors.into_iter().map(|e| PyLineError::new(py, e)).collect(); + let line_errors: Vec = raw_errors.into_iter().map(|e| e.into_py(py)).collect(); PyErr::new::((line_errors, title)) } ValError::InternalErr(err) => err, @@ -61,6 +61,18 @@ impl Error for ValidationError { } } +// used to convert a validation error back to ValError for wrap functions +impl<'a> From for ValError<'a> { + fn from(val_error: ValidationError) -> Self { + val_error + .line_errors + .into_iter() + .map(|e| e.into()) + .collect::>() + .into() + } +} + #[pymethods] impl ValidationError { #[new] @@ -131,19 +143,36 @@ pub struct PyLineError { context: Context, } -impl PyLineError { - pub fn new(py: Python, raw_error: ValLineError) -> Self { +impl<'a> IntoPy for ValLineError<'a> { + fn into_py(self, py: Python<'_>) -> PyLineError { + PyLineError { + kind: self.kind, + location: match self.reverse_location.len() { + 0..=1 => self.reverse_location, + _ => self.reverse_location.into_iter().rev().collect(), + }, + input_value: self.input_value.to_object(py), + context: self.context, + } + } +} + +/// opposite of above, used extract line errors from a validation error for wrap functions +impl<'a> From for ValLineError<'a> { + fn from(py_line_error: PyLineError) -> Self { Self { - kind: raw_error.kind, - location: match raw_error.reverse_location.len() { - 0..=1 => raw_error.reverse_location, - _ => raw_error.reverse_location.into_iter().rev().collect(), + kind: py_line_error.kind, + reverse_location: match py_line_error.location.len() { + 0..=1 => py_line_error.location, + _ => py_line_error.location.into_iter().rev().collect(), }, - input_value: raw_error.input_value.to_object(py), - context: raw_error.context, + input_value: py_line_error.input_value.into(), + context: py_line_error.context, } } +} +impl PyLineError { pub fn as_dict(&self, py: Python) -> PyResult { let dict = PyDict::new(py); dict.set_item("kind", self.kind())?; diff --git a/src/validators/function.rs b/src/validators/function.rs index deefae2eb..8f743d927 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -259,6 +259,9 @@ fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> Val // Only ValueError and AssertionError are considered as validation errors, // TypeError is now considered as a runtime error to catch errors in function signatures let kind = if err.is_instance_of::(py) { + if let Ok(validation_error) = err.value(py).extract::() { + return validation_error.into(); + } ErrorKind::ValueError } else if err.is_instance_of::(py) { ErrorKind::AssertionError diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 57cafcddf..c4f473954 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -103,9 +103,7 @@ def f(input_value, *, validator, **kwargs): v = SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'wrap', 'function': f, 'schema': 'str'}) - # with pytest.raises(ValidationError) as exc_info: assert v.validate_python('input value') == 'input value Changed' - # print(exc_info.value) def test_function_wrap_repr(): @@ -134,6 +132,25 @@ def test_function_wrap_not_callable(): SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'wrap', 'schema': 'str'}) +def test_wrap_error(): + def f(input_value, *, validator, **kwargs): + return validator(input_value) * 2 + + v = SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'wrap', 'function': f, 'schema': 'int'}) + + assert v.validate_python('42') == 84 + with pytest.raises(ValidationError) as exc_info: + v.validate_python('wrong') + assert exc_info.value.errors() == [ + { + 'kind': 'int_parsing', + 'loc': [], + 'message': 'Value must be a valid integer, unable to parse string as an integer', + 'input_value': 'wrong', + } + ] + + def test_wrong_mode(): with pytest.raises(SchemaError, match='SchemaError: Unexpected function mode "foobar"'): SchemaValidator({'title': 'Test', 'type': 'function', 'mode': 'foobar', 'schema': 'str'}) diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index f25114cd4..38541b01b 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -1,11 +1,12 @@ from typing import Optional import pytest -from dirty_equals import IsList, IsPartialDict +from dirty_equals import AnyThing, HasAttributes, IsList, IsPartialDict from pydantic_core import SchemaError, SchemaValidator, ValidationError from ..conftest import Err +from .test_typed_dict import Cls def test_branch_nullable(): @@ -241,6 +242,7 @@ def test_recursion_branch(): { 'type': 'typed-dict', 'ref': 'Branch', + 'config': {'from_attributes': True}, 'fields': { 'name': {'schema': {'type': 'str'}}, 'branch': { @@ -256,10 +258,15 @@ def test_recursion_branch(): 'branch': {'name': 'b1', 'branch': None}, } + data = Cls(name='root') + data.branch = Cls(name='b1', branch=None) + assert v.validate_python(data) == {'name': 'root', 'branch': {'name': 'b1', 'branch': None}} + b = {'name': 'recursive'} b['branch'] = b with pytest.raises(ValidationError) as exc_info: assert v.validate_python(b) + assert exc_info.value.title == 'recursive-container' assert exc_info.value.errors() == [ { 'kind': 'recursion_loop', @@ -269,6 +276,19 @@ def test_recursion_branch(): } ] + data = Cls(name='root') + data.branch = data + with pytest.raises(ValidationError) as exc_info: + v.validate_python(data) + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': ['branch'], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': HasAttributes(name='root', branch=AnyThing()), + } + ] + def test_recursive_list(): v = SchemaValidator( @@ -383,3 +403,38 @@ def test_multiple_tuple_recursion_once(multiple_tuple_schema): 'input_value': [1, IsList(length=2)], }, ] + + +def test_recursive_wrap(): + def wrap_func(input_value, *, validator, **kwargs): + return validator(input_value) + (42,) + + v = SchemaValidator( + { + 'type': 'function', + 'ref': 'wrapper', + 'mode': 'wrap', + 'function': wrap_func, + 'schema': { + 'type': 'tuple-fix-len', + 'items_schema': [ + {'type': 'int'}, + {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'wrapper'}}, + ], + }, + } + ) + assert v.validate_python((1, None)) == (1, None, 42) + assert v.validate_python((1, (2, (3, None)))) == (1, (2, (3, None, 42), 42), 42) + t = [1] + t.append(t) + with pytest.raises(ValidationError) as exc_info: + v.validate_python(t) + assert exc_info.value.errors() == [ + { + 'kind': 'recursion_loop', + 'loc': [1], + 'message': 'Recursion error - cyclic reference detected', + 'input_value': IsList(positions={0: 1}, length=2), + } + ] diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index a3576835d..3999d3bf5 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -678,6 +678,8 @@ class MyDataclass: (Cls(a=1, b=2, c='ham'), ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})), (dict(a=1, b=2, c='ham'), ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})), (Map(a=1, b=2, c='ham'), ({'a': 1, 'b': 2, 'c': 'ham'}, {'a', 'b', 'c'})), + # using type gives `__module__ == 'builtins'` + (type('Testing', (), {}), Err('[kind=dict_attributes_type,')), ('123', Err('Value must be a valid dictionary or instance to extract fields from [kind=dict_attributes_type,')), ([(1, 2)], Err('kind=dict_attributes_type,')), (((1, 2),), Err('kind=dict_attributes_type,')), From 3e1cddf642144020c671ab73264623cc22e47f8c Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 29 Jun 2022 16:12:42 +0100 Subject: [PATCH 8/9] bump From 6f41b71fdc7e18bcb9b6fa70b85788ffe28d4ad5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 29 Jun 2022 16:23:17 +0100 Subject: [PATCH 9/9] tweaks --- src/errors/validation_exception.rs | 4 ++-- src/validators/recursive.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 75b825c8c..621b4c282 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -68,7 +68,7 @@ impl<'a> From for ValError<'a> { .line_errors .into_iter() .map(|e| e.into()) - .collect::>() + .collect::>() .into() } } @@ -157,7 +157,7 @@ impl<'a> IntoPy for ValLineError<'a> { } } -/// opposite of above, used extract line errors from a validation error for wrap functions +/// opposite of above, used to extract line errors from a validation error for wrap functions impl<'a> From for ValLineError<'a> { fn from(py_line_error: PyLineError) -> Self { Self { diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index 312eb5bae..ba4f3a0eb 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -82,7 +82,7 @@ fn validate<'s, 'data>( ) -> ValResult<'data, PyObject> { if let Some(id) = input.identity() { if recursion_guard.contains_or_insert(id) { - // remove ID in we use recursion_guard again + // remove ID in case we use recursion_guard again recursion_guard.remove(&id); return err_val_error!(kind = ErrorKind::RecursionLoop, input_value = input.as_error_value()); }