diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 2d51d544d..f68b9af0d 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -2375,7 +2375,8 @@ class WithDefaultSchema(TypedDict, total=False): type: Required[Literal['default']] schema: Required[CoreSchema] default: Any - default_factory: Callable[[], Any] + default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]] + default_factory_takes_data: bool on_error: Literal['raise', 'omit', 'default'] # default: 'raise' validate_default: bool # default: False strict: bool @@ -2388,7 +2389,8 @@ def with_default_schema( schema: CoreSchema, *, default: Any = PydanticUndefined, - default_factory: Callable[[], Any] | None = None, + default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None, + default_factory_takes_data: bool | None = None, on_error: Literal['raise', 'omit', 'default'] | None = None, validate_default: bool | None = None, strict: bool | None = None, @@ -2413,7 +2415,8 @@ def with_default_schema( Args: schema: The schema to add a default value to default: The default value to use - default_factory: A function that returns the default value to use + default_factory: A callable that returns the default value to use + default_factory_takes_data: Whether the default factory takes a validated data argument on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default' validate_default: Whether the default value should be validated strict: Whether the underlying schema should be validated with strict mode @@ -2425,6 +2428,7 @@ def with_default_schema( type='default', schema=schema, default_factory=default_factory, + default_factory_takes_data=default_factory_takes_data, on_error=on_error, validate_default=validate_default, strict=strict, diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index 6d510ce55..7a4225d26 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -72,6 +72,14 @@ impl TypeSerializer for WithDefaultSerializer { } fn get_default(&self, py: Python) -> PyResult> { - self.default.default_value(py) + if let DefaultType::DefaultFactory(_, _takes_data @ true) = self.default { + // We currently don't compute the default if the default factory takes + // the data from other fields. + Ok(None) + } else { + self.default.default_value( + py, &None, // Won't be used. + ) + } } } diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index e03830e77..565c31060 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -25,7 +25,7 @@ fn get_deepcopy(py: Python) -> PyResult { pub enum DefaultType { None, Default(PyObject), - DefaultFactory(PyObject), + DefaultFactory(PyObject, bool), } impl DefaultType { @@ -37,15 +37,32 @@ impl DefaultType { ) { (Some(_), Some(_)) => py_schema_err!("'default' and 'default_factory' cannot be used together"), (Some(default), None) => Ok(Self::Default(default)), - (None, Some(default_factory)) => Ok(Self::DefaultFactory(default_factory)), + (None, Some(default_factory)) => Ok(Self::DefaultFactory( + default_factory, + schema + .get_as::(intern!(py, "default_factory_takes_data"))? + .unwrap_or(false), + )), (None, None) => Ok(Self::None), } } - pub fn default_value(&self, py: Python) -> PyResult> { + pub fn default_value(&self, py: Python, validated_data: &Option>) -> PyResult> { match self { Self::Default(ref default) => Ok(Some(default.clone_ref(py))), - Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)), + Self::DefaultFactory(ref default_factory, ref takes_data) => { + let result = if *takes_data { + if validated_data.is_none() { + default_factory.call1(py, ({},)) + } else { + default_factory.call1(py, (validated_data.as_deref().unwrap(),)) + } + } else { + default_factory.call0(py) + }; + + Ok(Some(result?)) + } Self::None => Ok(None), } } @@ -53,7 +70,7 @@ impl DefaultType { impl PyGcTraverse for DefaultType { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { - if let Self::Default(obj) | Self::DefaultFactory(obj) = self { + if let Self::Default(obj) | Self::DefaultFactory(obj, _) = self { visit.call(obj)?; } Ok(()) @@ -163,7 +180,7 @@ impl Validator for WithDefaultValidator { outer_loc: Option>, state: &mut ValidationState<'_, 'py>, ) -> ValResult> { - match self.default.default_value(py)? { + match self.default.default_value(py, &state.extra().data)? { Some(stored_dft) => { let dft: Py = if self.copy_default { let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap()); diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 189a555bb..105241140 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -227,12 +227,18 @@ def broken(): v.validate_python('wrong') -def test_factory_type_error(): +def test_factory_missing_arg(): def broken(x): return 7 v = SchemaValidator( - {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'default', 'default_factory': broken} + { + 'type': 'default', + 'schema': {'type': 'int'}, + 'on_error': 'default', + 'default_factory': broken, + 'default_factory_takes_data': False, + } ) assert v.validate_python(42) == 42 assert v.validate_python('42') == 42