Skip to content

Commit 7dc970f

Browse files
authored
fix for model subclass serialization in unions (pydantic#632)
1 parent 726ef5f commit 7dc970f

File tree

3 files changed

+105
-11
lines changed

3 files changed

+105
-11
lines changed

src/serializers/fields.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use pyo3::types::{PyDict, PyString};
66
use ahash::AHashMap;
77
use serde::ser::SerializeMap;
88

9+
use crate::serializers::extra::SerCheck;
910
use crate::PydanticSerializationUnexpectedValue;
1011

1112
use super::computed_fields::ComputedFields;
@@ -75,7 +76,7 @@ fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer
7576
Ok(false)
7677
}
7778

78-
#[derive(Debug, Clone)]
79+
#[derive(Debug, Clone, Eq, PartialEq)]
7980
pub(super) enum FieldsMode {
8081
// typeddict with no extra items
8182
SimpleDict,
@@ -196,10 +197,10 @@ impl TypeSerializer for GeneralFieldsSerializer {
196197
if field.required {
197198
used_req_fields += 1;
198199
}
199-
} else if matches!(self.mode, FieldsMode::TypedDictAllow) {
200+
} else if self.mode == FieldsMode::TypedDictAllow {
200201
let value = infer_to_python(value, next_include, next_exclude, &extra)?;
201202
output_dict.set_item(key, value)?;
202-
} else if extra.check.enabled() {
203+
} else if extra.check == SerCheck::Strict {
203204
return Err(PydanticSerializationUnexpectedValue::new_err(None));
204205
}
205206
}
@@ -281,11 +282,12 @@ impl TypeSerializer for GeneralFieldsSerializer {
281282
map.serialize_entry(&output_key, &s)?;
282283
}
283284
}
284-
} else if matches!(self.mode, FieldsMode::TypedDictAllow) {
285+
} else if self.mode == FieldsMode::TypedDictAllow {
285286
let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?;
286287
let s = SerializeInfer::new(value, next_include, next_exclude, &extra);
287288
map.serialize_entry(&output_key, &s)?
288289
}
290+
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
289291
}
290292
}
291293
// this is used to include `__pydantic_extra__` in serialization on models

src/serializers/type_serializers/model.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,10 @@ fn has_extra(schema: &PyDict, config: Option<&PyDict>) -> PyResult<bool> {
109109

110110
impl ModelSerializer {
111111
fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult<bool> {
112+
let class = self.class.as_ref(value.py());
112113
match extra.check {
113-
SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))),
114-
SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())),
114+
SerCheck::Strict => Ok(value.get_type().is(class)),
115+
SerCheck::Lax => value.is_instance(class),
115116
SerCheck::None => value.hasattr(intern!(value.py(), "__dict__")),
116117
}
117118
}

tests/serializers/test_union.py

+96-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
import json
23
import re
34

@@ -179,17 +180,16 @@ def test_typed_dict_literal():
179180

180181

181182
def test_typed_dict_missing():
182-
"""
183-
TODO, needs tests for each case
184-
"""
185183
s = SchemaSerializer(
186184
core_schema.union_schema(
187185
[
188186
core_schema.typed_dict_schema(dict(foo=core_schema.typed_dict_field(core_schema.int_schema()))),
189187
core_schema.typed_dict_schema(
190188
dict(
191189
foo=core_schema.typed_dict_field(
192-
core_schema.int_schema(serialization=core_schema.format_ser_schema('04d'))
190+
core_schema.int_schema(
191+
serialization=core_schema.format_ser_schema('04d', when_used='always')
192+
)
193193
),
194194
bar=core_schema.typed_dict_field(core_schema.int_schema()),
195195
)
@@ -201,7 +201,8 @@ def test_typed_dict_missing():
201201
assert s.to_python(dict(foo=1)) == {'foo': 1}
202202
assert s.to_python(dict(foo=1), mode='json') == {'foo': 1}
203203
assert s.to_json(dict(foo=1)) == b'{"foo":1}'
204-
assert s.to_python(dict(foo=1, bar=2)) == {'foo': 1, 'bar': 2}
204+
205+
assert s.to_python(dict(foo=1, bar=2)) == {'foo': '0001', 'bar': 2}
205206
assert s.to_python(dict(foo=1, bar=2), mode='json') == {'foo': '0001', 'bar': 2}
206207
assert s.to_json(dict(foo=1, bar=2)) == b'{"foo":"0001","bar":2}'
207208

@@ -269,3 +270,93 @@ def test_typed_dict_different_fields():
269270
assert s.to_python(dict(spam=1, ham=2)) == {'spam': 1, 'ham': 2}
270271
assert s.to_python(dict(spam=1, ham=2), mode='json') == {'spam': 1, 'ham': '0002'}
271272
assert s.to_json(dict(spam=1, ham=2)) == b'{"spam":1,"ham":"0002"}'
273+
274+
275+
def test_dataclass_union():
276+
@dataclasses.dataclass
277+
class BaseUser:
278+
name: str
279+
280+
@dataclasses.dataclass
281+
class User(BaseUser):
282+
surname: str
283+
284+
@dataclasses.dataclass
285+
class DBUser(User):
286+
password_hash: str
287+
288+
@dataclasses.dataclass
289+
class Item:
290+
name: str
291+
price: float
292+
293+
user_schema = core_schema.dataclass_schema(
294+
User,
295+
core_schema.dataclass_args_schema(
296+
'User',
297+
[
298+
core_schema.dataclass_field(name='name', schema=core_schema.str_schema()),
299+
core_schema.dataclass_field(name='surname', schema=core_schema.str_schema()),
300+
],
301+
),
302+
['name', 'surname'],
303+
)
304+
item_schema = core_schema.dataclass_schema(
305+
Item,
306+
core_schema.dataclass_args_schema(
307+
'Item',
308+
[
309+
core_schema.dataclass_field(name='name', schema=core_schema.str_schema()),
310+
core_schema.dataclass_field(name='price', schema=core_schema.float_schema()),
311+
],
312+
),
313+
['name', 'price'],
314+
)
315+
s = SchemaSerializer(core_schema.union_schema([user_schema, item_schema]))
316+
assert s.to_python(User(name='foo', surname='bar')) == {'name': 'foo', 'surname': 'bar'}
317+
assert s.to_python(DBUser(name='foo', surname='bar', password_hash='x')) == {'name': 'foo', 'surname': 'bar'}
318+
assert s.to_json(DBUser(name='foo', surname='bar', password_hash='x')) == b'{"name":"foo","surname":"bar"}'
319+
320+
321+
def test_model_union():
322+
class BaseUser:
323+
def __init__(self, name: str):
324+
self.name = name
325+
326+
class User(BaseUser):
327+
def __init__(self, name: str, surname: str):
328+
super().__init__(name)
329+
self.surname = surname
330+
331+
class DBUser(User):
332+
def __init__(self, name: str, surname: str, password_hash: str):
333+
super().__init__(name, surname)
334+
self.password_hash = password_hash
335+
336+
class Item:
337+
def __init__(self, name: str, price: float):
338+
self.name = name
339+
self.price = price
340+
341+
user_schema = core_schema.model_schema(
342+
User,
343+
core_schema.model_fields_schema(
344+
{
345+
'name': core_schema.model_field(schema=core_schema.str_schema()),
346+
'surname': core_schema.model_field(schema=core_schema.str_schema()),
347+
}
348+
),
349+
)
350+
item_schema = core_schema.model_schema(
351+
Item,
352+
core_schema.model_fields_schema(
353+
{
354+
'name': core_schema.model_field(schema=core_schema.str_schema()),
355+
'price': core_schema.model_field(schema=core_schema.float_schema()),
356+
}
357+
),
358+
)
359+
s = SchemaSerializer(core_schema.union_schema([user_schema, item_schema]))
360+
assert s.to_python(User(name='foo', surname='bar')) == {'name': 'foo', 'surname': 'bar'}
361+
assert s.to_python(DBUser(name='foo', surname='bar', password_hash='x')) == {'name': 'foo', 'surname': 'bar'}
362+
assert s.to_json(DBUser(name='foo', surname='bar', password_hash='x')) == b'{"name":"foo","surname":"bar"}'

0 commit comments

Comments
 (0)