Skip to content

Commit a6a58de

Browse files
committed
Support default factories taking validated data as an argument
1 parent 1ced3e6 commit a6a58de

File tree

4 files changed

+47
-12
lines changed

4 files changed

+47
-12
lines changed

python/pydantic_core/core_schema.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,7 +2375,8 @@ class WithDefaultSchema(TypedDict, total=False):
23752375
type: Required[Literal['default']]
23762376
schema: Required[CoreSchema]
23772377
default: Any
2378-
default_factory: Callable[[], Any]
2378+
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]]
2379+
default_factory_takes_data: bool
23792380
on_error: Literal['raise', 'omit', 'default'] # default: 'raise'
23802381
validate_default: bool # default: False
23812382
strict: bool
@@ -2388,7 +2389,8 @@ def with_default_schema(
23882389
schema: CoreSchema,
23892390
*,
23902391
default: Any = PydanticUndefined,
2391-
default_factory: Callable[[], Any] | None = None,
2392+
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None,
2393+
default_factory_takes_data: bool | None = None,
23922394
on_error: Literal['raise', 'omit', 'default'] | None = None,
23932395
validate_default: bool | None = None,
23942396
strict: bool | None = None,
@@ -2413,7 +2415,8 @@ def with_default_schema(
24132415
Args:
24142416
schema: The schema to add a default value to
24152417
default: The default value to use
2416-
default_factory: A function that returns the default value to use
2418+
default_factory: A callable that returns the default value to use
2419+
default_factory_takes_data: Whether the default factory takes a validated data argument
24172420
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
24182421
validate_default: Whether the default value should be validated
24192422
strict: Whether the underlying schema should be validated with strict mode
@@ -2425,6 +2428,7 @@ def with_default_schema(
24252428
type='default',
24262429
schema=schema,
24272430
default_factory=default_factory,
2431+
default_factory_takes_data=default_factory_takes_data,
24282432
on_error=on_error,
24292433
validate_default=validate_default,
24302434
strict=strict,

src/serializers/type_serializers/with_default.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ impl TypeSerializer for WithDefaultSerializer {
7272
}
7373

7474
fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
75-
self.default.default_value(py)
75+
if let DefaultType::DefaultFactory(_, _takes_data @ true) = self.default {
76+
// We currently don't compute the default if the default factory takes
77+
// the data from other fields.
78+
Ok(None)
79+
} else {
80+
self.default.default_value(
81+
py, &None, // Won't be used.
82+
)
83+
}
7684
}
7785
}

src/validators/with_default.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fn get_deepcopy(py: Python) -> PyResult<PyObject> {
2525
pub enum DefaultType {
2626
None,
2727
Default(PyObject),
28-
DefaultFactory(PyObject),
28+
DefaultFactory(PyObject, bool),
2929
}
3030

3131
impl DefaultType {
@@ -37,23 +37,40 @@ impl DefaultType {
3737
) {
3838
(Some(_), Some(_)) => py_schema_err!("'default' and 'default_factory' cannot be used together"),
3939
(Some(default), None) => Ok(Self::Default(default)),
40-
(None, Some(default_factory)) => Ok(Self::DefaultFactory(default_factory)),
40+
(None, Some(default_factory)) => Ok(Self::DefaultFactory(
41+
default_factory,
42+
schema
43+
.get_as::<bool>(intern!(py, "default_factory_takes_data"))?
44+
.unwrap_or(false),
45+
)),
4146
(None, None) => Ok(Self::None),
4247
}
4348
}
4449

45-
pub fn default_value(&self, py: Python) -> PyResult<Option<PyObject>> {
50+
pub fn default_value(&self, py: Python, validated_data: &Option<Bound<PyDict>>) -> PyResult<Option<PyObject>> {
4651
match self {
4752
Self::Default(ref default) => Ok(Some(default.clone_ref(py))),
48-
Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)),
53+
Self::DefaultFactory(ref default_factory, ref takes_data) => {
54+
let result = if *takes_data {
55+
if validated_data.is_none() {
56+
default_factory.call1(py, ({},))
57+
} else {
58+
default_factory.call1(py, (validated_data.as_deref().unwrap(),))
59+
}
60+
} else {
61+
default_factory.call0(py)
62+
};
63+
64+
Ok(Some(result?))
65+
}
4966
Self::None => Ok(None),
5067
}
5168
}
5269
}
5370

5471
impl PyGcTraverse for DefaultType {
5572
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
56-
if let Self::Default(obj) | Self::DefaultFactory(obj) = self {
73+
if let Self::Default(obj) | Self::DefaultFactory(obj, _) = self {
5774
visit.call(obj)?;
5875
}
5976
Ok(())
@@ -163,7 +180,7 @@ impl Validator for WithDefaultValidator {
163180
outer_loc: Option<impl Into<LocItem>>,
164181
state: &mut ValidationState<'_, 'py>,
165182
) -> ValResult<Option<PyObject>> {
166-
match self.default.default_value(py)? {
183+
match self.default.default_value(py, &state.extra().data)? {
167184
Some(stored_dft) => {
168185
let dft: Py<PyAny> = if self.copy_default {
169186
let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap());

tests/validators/test_with_default.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,18 @@ def broken():
227227
v.validate_python('wrong')
228228

229229

230-
def test_factory_type_error():
230+
def test_factory_missing_arg():
231231
def broken(x):
232232
return 7
233233

234234
v = SchemaValidator(
235-
{'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'default', 'default_factory': broken}
235+
{
236+
'type': 'default',
237+
'schema': {'type': 'int'},
238+
'on_error': 'default',
239+
'default_factory': broken,
240+
'default_factory_takes_data': False,
241+
}
236242
)
237243
assert v.validate_python(42) == 42
238244
assert v.validate_python('42') == 42

0 commit comments

Comments
 (0)