Skip to content

Support default factories taking validated data as an argument #1491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2375,7 +2375,8 @@ class WithDefaultSchema(TypedDict, total=False):
type: Required[Literal['default']]
schema: Required[CoreSchema]
default: Any
default_factory: Callable[[], Any]
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any]]
default_factory_takes_data: bool
on_error: Literal['raise', 'omit', 'default'] # default: 'raise'
validate_default: bool # default: False
strict: bool
Expand All @@ -2388,7 +2389,8 @@ def with_default_schema(
schema: CoreSchema,
*,
default: Any = PydanticUndefined,
default_factory: Callable[[], Any] | None = None,
default_factory: Union[Callable[[], Any], Callable[[Dict[str, Any]], Any], None] = None,
default_factory_takes_data: bool | None = None,
on_error: Literal['raise', 'omit', 'default'] | None = None,
validate_default: bool | None = None,
strict: bool | None = None,
Expand All @@ -2413,7 +2415,8 @@ def with_default_schema(
Args:
schema: The schema to add a default value to
default: The default value to use
default_factory: A function that returns the default value to use
default_factory: A callable that returns the default value to use
default_factory_takes_data: Whether the default factory takes a validated data argument
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
validate_default: Whether the default value should be validated
strict: Whether the underlying schema should be validated with strict mode
Expand All @@ -2425,6 +2428,7 @@ def with_default_schema(
type='default',
schema=schema,
default_factory=default_factory,
default_factory_takes_data=default_factory_takes_data,
on_error=on_error,
validate_default=validate_default,
strict=strict,
Expand Down
10 changes: 9 additions & 1 deletion src/serializers/type_serializers/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ impl TypeSerializer for WithDefaultSerializer {
}

fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
self.default.default_value(py)
if let DefaultType::DefaultFactory(_, _takes_data @ true) = self.default {
// We currently don't compute the default if the default factory takes
// the data from other fields.
Ok(None)
} else {
self.default.default_value(
py, &None, // Won't be used.
)
}
}
}
29 changes: 23 additions & 6 deletions src/validators/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn get_deepcopy(py: Python) -> PyResult<PyObject> {
pub enum DefaultType {
None,
Default(PyObject),
DefaultFactory(PyObject),
DefaultFactory(PyObject, bool),
}

impl DefaultType {
Expand All @@ -37,23 +37,40 @@ impl DefaultType {
) {
(Some(_), Some(_)) => py_schema_err!("'default' and 'default_factory' cannot be used together"),
(Some(default), None) => Ok(Self::Default(default)),
(None, Some(default_factory)) => Ok(Self::DefaultFactory(default_factory)),
(None, Some(default_factory)) => Ok(Self::DefaultFactory(
default_factory,
schema
.get_as::<bool>(intern!(py, "default_factory_takes_data"))?
.unwrap_or(false),
)),
(None, None) => Ok(Self::None),
}
}

pub fn default_value(&self, py: Python) -> PyResult<Option<PyObject>> {
pub fn default_value(&self, py: Python, validated_data: &Option<Bound<PyDict>>) -> PyResult<Option<PyObject>> {
match self {
Self::Default(ref default) => Ok(Some(default.clone_ref(py))),
Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)),
Self::DefaultFactory(ref default_factory, ref takes_data) => {
let result = if *takes_data {
if validated_data.is_none() {
default_factory.call1(py, ({},))
} else {
default_factory.call1(py, (validated_data.as_deref().unwrap(),))
}
} else {
default_factory.call0(py)
};

Ok(Some(result?))
}
Self::None => Ok(None),
}
}
}

impl PyGcTraverse for DefaultType {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
if let Self::Default(obj) | Self::DefaultFactory(obj) = self {
if let Self::Default(obj) | Self::DefaultFactory(obj, _) = self {
visit.call(obj)?;
}
Ok(())
Expand Down Expand Up @@ -163,7 +180,7 @@ impl Validator for WithDefaultValidator {
outer_loc: Option<impl Into<LocItem>>,
state: &mut ValidationState<'_, 'py>,
) -> ValResult<Option<PyObject>> {
match self.default.default_value(py)? {
match self.default.default_value(py, &state.extra().data)? {
Some(stored_dft) => {
let dft: Py<PyAny> = if self.copy_default {
let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap());
Expand Down
10 changes: 8 additions & 2 deletions tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,18 @@ def broken():
v.validate_python('wrong')


def test_factory_type_error():
def test_factory_missing_arg():
def broken(x):
return 7

v = SchemaValidator(
{'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'default', 'default_factory': broken}
{
'type': 'default',
'schema': {'type': 'int'},
'on_error': 'default',
'default_factory': broken,
'default_factory_takes_data': False,
}
)
assert v.validate_python(42) == 42
assert v.validate_python('42') == 42
Expand Down
Loading