Skip to content

Recursion guard #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ indexmap = "1.8.1"
mimalloc = { version = "0.1.29", default-features = false, optional = true }
speedate = "0.4.1"
ahash = "0.7.6"
nohash-hasher = "0.2.0"

[lib]
name = "_pydantic_core"
Expand Down
4 changes: 4 additions & 0 deletions src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ pub enum ErrorKind {
#[strum(message = "Invalid JSON: {parser_error}")]
InvalidJson,
// ---------------------
// recursion error
#[strum(message = "Recursion error - cyclic reference detected")]
RecursionLoop,
// ---------------------
// typed dict specific errors
#[strum(message = "Value must be a valid dictionary or instance to extract fields from")]
DictAttributesType,
Expand Down
25 changes: 22 additions & 3 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,27 @@ pub enum ValError<'a> {
InternalErr(PyErr),
}

impl<'a> From<PyErr> for ValError<'a> {
fn from(py_err: PyErr) -> Self {
Self::InternalErr(py_err)
}
}

impl<'a> From<Vec<ValLineError<'a>>> for ValError<'a> {
fn from(line_errors: Vec<ValLineError<'a>>) -> Self {
Self::LineErrors(line_errors)
}
}

// ValError used to implement Error, see #78 for removed code

// TODO, remove and replace with just .into()
pub fn as_internal<'a>(err: PyErr) -> ValError<'a> {
ValError::InternalErr(err)
err.into()
}

pub fn pretty_line_errors(py: Python, line_errors: Vec<ValLineError>) -> String {
let py_line_errors: Vec<PyLineError> = line_errors.into_iter().map(|e| PyLineError::new(py, e)).collect();
let py_line_errors: Vec<PyLineError> = line_errors.into_iter().map(|e| e.into_py(py)).collect();
pretty_py_line_errors(Some(py), py_line_errors.iter())
}

Expand Down Expand Up @@ -58,7 +71,7 @@ impl<'a> ValLineError<'a> {
ValLineError {
kind: self.kind,
reverse_location: self.reverse_location,
input_value: InputValue::PyObject(self.input_value.to_object(py)),
input_value: self.input_value.to_object(py).into(),
context: self.context,
}
}
Expand All @@ -79,6 +92,12 @@ impl Default for InputValue<'_> {
}
}

impl<'a> From<PyObject> for InputValue<'a> {
fn from(py_object: PyObject) -> Self {
Self::PyObject(py_object)
}
}

impl<'a> ToPyObject for InputValue<'a> {
fn to_object(&self, py: Python) -> PyObject {
match self {
Expand Down
49 changes: 39 additions & 10 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::location::Location;
use super::ValError;

#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ValidationError {
line_errors: Vec<PyLineError>,
title: PyObject,
Expand All @@ -33,7 +33,7 @@ impl ValidationError {
pub fn from_val_error(py: Python, title: PyObject, error: ValError) -> PyErr {
match error {
ValError::LineErrors(raw_errors) => {
let line_errors: Vec<PyLineError> = raw_errors.into_iter().map(|e| PyLineError::new(py, e)).collect();
let line_errors: Vec<PyLineError> = raw_errors.into_iter().map(|e| e.into_py(py)).collect();
PyErr::new::<ValidationError, _>((line_errors, title))
}
ValError::InternalErr(err) => err,
Expand Down Expand Up @@ -61,6 +61,18 @@ impl Error for ValidationError {
}
}

// used to convert a validation error back to ValError for wrap functions
impl<'a> From<ValidationError> for ValError<'a> {
fn from(val_error: ValidationError) -> Self {
val_error
.line_errors
.into_iter()
.map(|e| e.into())
.collect::<Vec<_>>()
.into()
}
}

#[pymethods]
impl ValidationError {
#[new]
Expand Down Expand Up @@ -131,19 +143,36 @@ pub struct PyLineError {
context: Context,
}

impl PyLineError {
pub fn new(py: Python, raw_error: ValLineError) -> Self {
impl<'a> IntoPy<PyLineError> for ValLineError<'a> {
fn into_py(self, py: Python<'_>) -> PyLineError {
PyLineError {
kind: self.kind,
location: match self.reverse_location.len() {
0..=1 => self.reverse_location,
_ => self.reverse_location.into_iter().rev().collect(),
},
input_value: self.input_value.to_object(py),
context: self.context,
}
}
}

/// opposite of above, used to extract line errors from a validation error for wrap functions
impl<'a> From<PyLineError> for ValLineError<'a> {
fn from(py_line_error: PyLineError) -> Self {
Self {
kind: raw_error.kind,
location: match raw_error.reverse_location.len() {
0..=1 => raw_error.reverse_location,
_ => raw_error.reverse_location.into_iter().rev().collect(),
kind: py_line_error.kind,
reverse_location: match py_line_error.location.len() {
0..=1 => py_line_error.location,
_ => py_line_error.location.into_iter().rev().collect(),
},
input_value: raw_error.input_value.to_object(py),
context: raw_error.context,
input_value: py_line_error.input_value.into(),
context: py_line_error.context,
}
}
}

impl PyLineError {
pub fn as_dict(&self, py: Python) -> PyResult<PyObject> {
let dict = PyDict::new(py);
dict.set_item("kind", self.kind())?;
Expand Down
4 changes: 4 additions & 0 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

fn as_error_value(&'a self) -> InputValue<'a>;

fn identity(&'a self) -> Option<usize> {
None
}

fn is_none(&self) -> bool;

fn strict_str<'data>(&'data self) -> ValResult<EitherString<'data>>;
Expand Down
6 changes: 5 additions & 1 deletion src/input/input_python.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::str::from_utf8;

use pyo3::exceptions::{PyAttributeError, PyTypeError};
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{
PyBool, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet, PyString,
PyTime, PyTuple, PyType,
};
use pyo3::{intern, AsPyPointer};

use crate::errors::location::LocItem;
use crate::errors::{as_internal, context, err_val_error, py_err_string, ErrorKind, InputValue, ValResult};
Expand Down Expand Up @@ -36,6 +36,10 @@ impl<'a> Input<'a> for PyAny {
InputValue::PyAny(self)
}

fn identity(&'a self) -> Option<usize> {
Some(self.as_ptr() as usize)
}

fn is_none(&self) -> bool {
self.is_none()
}
Expand Down
23 changes: 17 additions & 6 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple};

use crate::errors::{ValError, ValLineError, ValResult};
use crate::recursion_guard::RecursionGuard;
use crate::validators::{CombinedValidator, Extra, Validator};

use super::parse_json::{JsonArray, JsonObject};
Expand Down Expand Up @@ -41,11 +42,12 @@ macro_rules! build_validate_to_vec {
validator: &'s CombinedValidator,
extra: &Extra,
slots: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(length);
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item) in sequence.iter().enumerate() {
match validator.validate(py, item, extra, slots) {
match validator.validate(py, item, extra, slots, recursion_guard) {
Ok(item) => output.push(item),
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
Expand Down Expand Up @@ -90,13 +92,22 @@ impl<'a> GenericSequence<'a> {
validator: &'s CombinedValidator,
extra: &Extra,
slots: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, Vec<PyObject>> {
match self {
Self::List(sequence) => validate_to_vec_list(py, sequence, length, validator, extra, slots),
Self::Tuple(sequence) => validate_to_vec_tuple(py, sequence, length, validator, extra, slots),
Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots),
Self::FrozenSet(sequence) => validate_to_vec_frozenset(py, sequence, length, validator, extra, slots),
Self::JsonArray(sequence) => validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots),
Self::List(sequence) => {
validate_to_vec_list(py, sequence, length, validator, extra, slots, recursion_guard)
}
Self::Tuple(sequence) => {
validate_to_vec_tuple(py, sequence, length, validator, extra, slots, recursion_guard)
}
Self::Set(sequence) => validate_to_vec_set(py, sequence, length, validator, extra, slots, recursion_guard),
Self::FrozenSet(sequence) => {
validate_to_vec_frozenset(py, sequence, length, validator, extra, slots, recursion_guard)
}
Self::JsonArray(sequence) => {
validate_to_vec_jsonarray(py, sequence, length, validator, extra, slots, recursion_guard)
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
mod build_tools;
mod errors;
mod input;
mod recursion_guard;
mod validators;

// required for benchmarks
Expand Down
36 changes: 36 additions & 0 deletions src/recursion_guard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::collections::HashSet;
use std::hash::BuildHasherDefault;

use nohash_hasher::NoHashHasher;

/// This is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault.
/// It's used in `validators/recursive.rs` to detect when a reference is reused within itself.
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard(Option<HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>>>);

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, id: usize) -> bool {
match self.0 {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
Some(ref mut set) => !set.insert(id),
None => {
let mut set: HashSet<usize, BuildHasherDefault<NoHashHasher<usize>>> =
HashSet::with_capacity_and_hasher(10, BuildHasherDefault::default());
set.insert(id);
self.0 = Some(set);
false
}
}
}

pub fn remove(&mut self, id: &usize) {
match self.0 {
Some(ref mut set) => {
set.remove(id);
}
None => unreachable!(),
};
}
}
2 changes: 2 additions & 0 deletions src/validators/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pyo3::types::PyDict;

use crate::errors::ValResult;
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;

use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

Expand All @@ -29,6 +30,7 @@ impl Validator for AnyValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
// Ok(input.clone().into_py(py))
Ok(input.to_object(py))
Expand Down
4 changes: 4 additions & 0 deletions src/validators/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::types::PyDict;
use crate::build_tools::is_strict;
use crate::errors::ValResult;
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;

use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};

Expand Down Expand Up @@ -33,6 +34,7 @@ impl Validator for BoolValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
// TODO in theory this could be quicker if we used PyBool rather than going to a bool
// and back again, might be worth profiling?
Expand All @@ -45,6 +47,7 @@ impl Validator for BoolValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
Ok(input.strict_bool()?.into_py(py))
}
Expand All @@ -70,6 +73,7 @@ impl Validator for StrictBoolValidator {
input: &'data impl Input<'data>,
_extra: &Extra,
_slots: &'data [CombinedValidator],
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
Ok(input.strict_bool()?.into_py(py))
}
Expand Down
Loading