Skip to content

Commit f6d4184

Browse files
implement __reduce__ to make SchemaValidator picklable (pydantic#53)
* attempt at implementing __reduce__ * python version * remove loc * back to rust * use array * minimize changes * again use array * remove import * test all protocol versions * Update src/validators/mod.rs Co-authored-by: Samuel Colvin <[email protected]> * Update src/validators/mod.rs Co-authored-by: Samuel Colvin <[email protected]> * pr feedback * pointlessly tweak layout Co-authored-by: Samuel Colvin <[email protected]> Co-authored-by: Samuel Colvin <[email protected]>
1 parent 69098c4 commit f6d4184

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

src/validators/mod.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ mod set;
2828
mod string;
2929
mod union;
3030

31-
#[pyclass]
31+
#[pyclass(module = "pydantic_core._pydantic_core")]
3232
#[derive(Debug, Clone)]
3333
pub struct SchemaValidator {
3434
validator: CombinedValidator,
3535
slots: Vec<CombinedValidator>,
36+
schema: PyObject,
3637
}
3738

3839
#[pymethods]
@@ -50,7 +51,17 @@ impl SchemaValidator {
5051
}
5152
};
5253
let slots = slots_builder.into_slots()?;
53-
Ok(Self { validator, slots })
54+
Ok(Self {
55+
validator,
56+
slots,
57+
schema: schema.into_py(py),
58+
})
59+
}
60+
61+
fn __reduce__(&self, py: Python) -> PyResult<PyObject> {
62+
let args = (self.schema.as_ref(py),);
63+
let cls = Py::new(py, self.to_owned())?.getattr(py, "__class__")?;
64+
Ok((cls, args).into_py(py))
5465
}
5566

5667
fn validate_python(&self, py: Python, input: &PyAny) -> PyResult<PyObject> {

tests/test_build.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pickle
2+
13
import pytest
24

35
from pydantic_core import SchemaError, SchemaValidator
@@ -39,3 +41,13 @@ def test_schema_wrong_type():
3941
assert exc_info.value.args[0] == (
4042
"Schema build error:\n TypeError: 'int' object cannot be converted to 'PyString'"
4143
)
44+
45+
46+
@pytest.mark.parametrize('pickle_protocol', range(1, pickle.HIGHEST_PROTOCOL + 1))
47+
def test_pickle(pickle_protocol: int) -> None:
48+
v1 = SchemaValidator({'type': 'bool'})
49+
assert v1.validate_python('tRuE') is True
50+
p = pickle.dumps(v1, protocol=pickle_protocol)
51+
v2 = pickle.loads(p)
52+
assert v2.validate_python('tRuE') is True
53+
assert repr(v1) == repr(v2)

0 commit comments

Comments
 (0)