Skip to content

Commit a3b4cf0

Browse files
committed
Support default factories taking validated data as an argument
1 parent 1ced3e6 commit a3b4cf0

File tree

4 files changed

+51
-12
lines changed

4 files changed

+51
-12
lines changed

python/pydantic_core/core_schema.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,7 +2375,8 @@ class WithDefaultSchema(TypedDict, total=False):
23752375
type: Required[Literal['default']]
23762376
schema: Required[CoreSchema]
23772377
default: Any
2378-
default_factory: Callable[[], Any]
2378+
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]]
2379+
default_factory_takes_data: bool
23792380
on_error: Literal['raise', 'omit', 'default'] # default: 'raise'
23802381
validate_default: bool # default: False
23812382
strict: bool
@@ -2388,7 +2389,8 @@ def with_default_schema(
23882389
schema: CoreSchema,
23892390
*,
23902391
default: Any = PydanticUndefined,
2391-
default_factory: Callable[[], Any] | None = None,
2392+
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None,
2393+
default_factory_takes_data: bool | None = None,
23922394
on_error: Literal['raise', 'omit', 'default'] | None = None,
23932395
validate_default: bool | None = None,
23942396
strict: bool | None = None,
@@ -2413,7 +2415,8 @@ def with_default_schema(
24132415
Args:
24142416
schema: The schema to add a default value to
24152417
default: The default value to use
2416-
default_factory: A function that returns the default value to use
2418+
default_factory: A callable that returns the default value to use
2419+
default_factory_takes_data: Whether the default factory takes a validated data argument
24172420
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
24182421
validate_default: Whether the default value should be validated
24192422
strict: Whether the underlying schema should be validated with strict mode
@@ -2425,6 +2428,7 @@ def with_default_schema(
24252428
type='default',
24262429
schema=schema,
24272430
default_factory=default_factory,
2431+
default_factory_takes_data=default_factory_takes_data,
24282432
on_error=on_error,
24292433
validate_default=validate_default,
24302434
strict=strict,

src/serializers/type_serializers/with_default.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ impl TypeSerializer for WithDefaultSerializer {
7272
}
7373

7474
fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
75-
self.default.default_value(py)
75+
if let DefaultType::DefaultFactory(_, _takes_data @ true) = self.default {
76+
// We currently don't compute the default if the default factory takes
77+
// the data from other fields.
78+
Ok(None)
79+
} else {
80+
self.default.default_value(
81+
py, &None, // Won't be used.
82+
)
83+
}
7684
}
7785
}

src/validators/with_default.rs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
33
use pyo3::sync::GILOnceCell;
4+
use pyo3::types::PyBool;
45
use pyo3::types::PyDict;
56
use pyo3::types::PyString;
67
use pyo3::PyTraverseError;
@@ -25,7 +26,7 @@ fn get_deepcopy(py: Python) -> PyResult<PyObject> {
2526
pub enum DefaultType {
2627
None,
2728
Default(PyObject),
28-
DefaultFactory(PyObject),
29+
DefaultFactory(PyObject, bool),
2930
}
3031

3132
impl DefaultType {
@@ -37,23 +38,43 @@ impl DefaultType {
3738
) {
3839
(Some(_), Some(_)) => py_schema_err!("'default' and 'default_factory' cannot be used together"),
3940
(Some(default), None) => Ok(Self::Default(default)),
40-
(None, Some(default_factory)) => Ok(Self::DefaultFactory(default_factory)),
41+
(None, Some(default_factory)) => {
42+
let py_default_factory_takes_data: Bound<PyBool> = schema
43+
.get_as(intern!(py, "default_factory_takes_data"))?
44+
.unwrap_or_else(|| PyBool::new_bound(py, false).to_owned());
45+
Ok(Self::DefaultFactory(
46+
default_factory,
47+
py_default_factory_takes_data.is_true(),
48+
))
49+
}
4150
(None, None) => Ok(Self::None),
4251
}
4352
}
4453

45-
pub fn default_value(&self, py: Python) -> PyResult<Option<PyObject>> {
54+
pub fn default_value(&self, py: Python, validated_data: &Option<Bound<PyDict>>) -> PyResult<Option<PyObject>> {
4655
match self {
4756
Self::Default(ref default) => Ok(Some(default.clone_ref(py))),
48-
Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)),
57+
Self::DefaultFactory(ref default_factory, ref takes_data) => {
58+
let result = if *takes_data {
59+
if validated_data.is_none() {
60+
default_factory.call1(py, ({},))
61+
} else {
62+
default_factory.call1(py, (validated_data.as_deref().unwrap(),))
63+
}
64+
} else {
65+
default_factory.call0(py)
66+
};
67+
68+
Ok(Some(result?))
69+
}
4970
Self::None => Ok(None),
5071
}
5172
}
5273
}
5374

5475
impl PyGcTraverse for DefaultType {
5576
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
56-
if let Self::Default(obj) | Self::DefaultFactory(obj) = self {
77+
if let Self::Default(obj) | Self::DefaultFactory(obj, _) = self {
5778
visit.call(obj)?;
5879
}
5980
Ok(())
@@ -163,7 +184,7 @@ impl Validator for WithDefaultValidator {
163184
outer_loc: Option<impl Into<LocItem>>,
164185
state: &mut ValidationState<'_, 'py>,
165186
) -> ValResult<Option<PyObject>> {
166-
match self.default.default_value(py)? {
187+
match self.default.default_value(py, &state.extra().data)? {
167188
Some(stored_dft) => {
168189
let dft: Py<PyAny> = if self.copy_default {
169190
let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap());

tests/validators/test_with_default.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,18 @@ def broken():
227227
v.validate_python('wrong')
228228

229229

230-
def test_factory_type_error():
230+
def test_factory_missing_arg():
231231
def broken(x):
232232
return 7
233233

234234
v = SchemaValidator(
235-
{'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'default', 'default_factory': broken}
235+
{
236+
'type': 'default',
237+
'schema': {'type': 'int'},
238+
'on_error': 'default',
239+
'default_factory': broken,
240+
'default_factory_takes_data': False,
241+
}
236242
)
237243
assert v.validate_python(42) == 42
238244
assert v.validate_python('42') == 42

0 commit comments

Comments
 (0)