Skip to content

Commit 4939f0b

Browse files
committed
Add object id handling and deserialization
1 parent 0a1dfd5 commit 4939f0b

File tree

3 files changed

+301
-44
lines changed

3 files changed

+301
-44
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ edition = "2018"
88

99
[dependencies]
1010
serde = "^1.0.101"
11+
bson = "^0.14.0"
1112

1213
[dev-dependencies]
13-
bson = "^0.14.0"
1414
serde = { version = "^1.0.101", features = ["derive"] }

src/de.rs

Lines changed: 265 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,13 @@ impl<'a, 'de: 'a> Deserializer<'de> for &'a mut BsonDeserializer<'de> {
115115
type Error = Error;
116116

117117
fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
118+
println!("deserialize any");
118119
match self.bson {
119120
BsonFragment::Key(_) => self.deserialize_identifier(visitor),
120121
BsonFragment::Element(bson) => match bson.bsontype {
121122
BsonType::String => self.deserialize_str(visitor),
122123
BsonType::Document => self.deserialize_map(visitor),
124+
BsonType::ObjectId => self.deserialize_map(visitor),
123125
},
124126
}
125127
}
@@ -179,7 +181,9 @@ impl<'a, 'de: 'a> Deserializer<'de> for &'a mut BsonDeserializer<'de> {
179181
println!("deserialize string");
180182

181183
match &self.bson {
182-
BsonFragment::Element(bson) => visitor.visit_str(bson.as_str().ok_or(Error::MalformedDocument)?),
184+
BsonFragment::Element(bson) => {
185+
visitor.visit_str(bson.as_str().ok_or(Error::MalformedDocument)?)
186+
}
183187
BsonFragment::Key(k) => visitor.visit_str(k),
184188
}
185189
}
@@ -215,12 +219,16 @@ impl<'a, 'de: 'a> Deserializer<'de> for &'a mut BsonDeserializer<'de> {
215219
}
216220

217221
fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
218-
self.bson.apply_element(|bson| {
219-
println!("deserialize map");
220-
let doc = bson.as_doc().ok_or(Error::MalformedDocument)?;
221-
let mapper = BsonDocumentMap::new(doc.into_iter());
222-
Ok(visitor.visit_map(mapper)?)
223-
})
222+
self.bson.apply_element(|bson| match bson.bsontype {
223+
BsonType::Document => {
224+
println!("deserialize map with type: {:?}", bson.bsontype);
225+
let doc = bson.as_doc().ok_or(Error::MalformedDocument)?;
226+
let mapper = BsonDocumentMap::new(doc.into_iter());
227+
Ok(visitor.visit_map(mapper)?)
228+
}
229+
BsonType::ObjectId => visitor.visit_map(object_id::ObjectIdDeserializer::new(*bson)),
230+
_ => Err(Error::MalformedDocument),
231+
})
224232
}
225233

226234
fn deserialize_tuple<V: Visitor<'de>>(
@@ -308,50 +316,285 @@ impl<'de> MapAccess<'de> for BsonDocumentMap<'de> {
308316
}
309317
}
310318

319+
pub mod object_id {
320+
// ObjectId handling
321+
322+
use serde::de::{DeserializeSeed, Deserializer, MapAccess, Visitor};
323+
324+
use super::Error;
325+
use crate::{BsonType, RawBson};
326+
327+
pub static FIELD: &str = "$__bson_object_id";
328+
pub static NAME: &str = "$__bson_ObjectId";
329+
330+
pub(super) struct ObjectIdDeserializer<'de> {
331+
bson: RawBson<'de>,
332+
visited: bool,
333+
}
334+
335+
impl<'de> ObjectIdDeserializer<'de> {
336+
pub(super) fn new(bson: RawBson<'de>) -> ObjectIdDeserializer<'de> {
337+
ObjectIdDeserializer {
338+
bson,
339+
visited: false,
340+
}
341+
}
342+
}
343+
344+
impl<'de> MapAccess<'de> for ObjectIdDeserializer<'de> {
345+
type Error = Error;
346+
347+
fn next_key_seed<K>(
348+
&mut self,
349+
seed: K,
350+
) -> Result<Option<<K as DeserializeSeed<'de>>::Value>, Self::Error>
351+
where
352+
K: DeserializeSeed<'de>,
353+
{
354+
println!("object id next key");
355+
if self.visited {
356+
Ok(None)
357+
} else {
358+
self.visited = true;
359+
seed.deserialize(ObjectIdFieldDeserializer).map(Some)
360+
}
361+
}
362+
363+
fn next_value_seed<V>(
364+
&mut self,
365+
seed: V,
366+
) -> Result<<V as DeserializeSeed<'de>>::Value, Self::Error>
367+
where
368+
V: DeserializeSeed<'de>,
369+
{
370+
println!("object id next value");
371+
372+
seed.deserialize(ObjectIdValueDeserializer(self.bson))
373+
}
374+
}
375+
376+
struct ObjectIdFieldDeserializer;
377+
378+
impl<'de> Deserializer<'de> for ObjectIdFieldDeserializer {
379+
type Error = Error;
380+
381+
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
382+
where
383+
V: Visitor<'de>,
384+
{
385+
visitor.visit_borrowed_str(FIELD)
386+
}
387+
388+
serde::forward_to_deserialize_any!(
389+
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
390+
bytes byte_buf map struct option unit newtype_struct
391+
ignored_any unit_struct tuple_struct tuple enum identifier
392+
);
393+
}
394+
395+
struct ObjectIdValueDeserializer<'de>(RawBson<'de>);
396+
397+
impl<'de> ObjectIdValueDeserializer<'de> {
398+
fn new(bson: RawBson<'de>) -> ObjectIdValueDeserializer<'de> {
399+
ObjectIdValueDeserializer(bson)
400+
}
401+
}
402+
403+
impl<'de> Deserializer<'de> for ObjectIdValueDeserializer<'de> {
404+
type Error = Error;
405+
406+
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
407+
where
408+
V: Visitor<'de>,
409+
{
410+
match self.0.bsontype {
411+
BsonType::ObjectId => visitor.visit_borrowed_bytes(self.0.data),
412+
_ => Err(Error::MalformedDocument),
413+
}
414+
}
415+
416+
serde::forward_to_deserialize_any!(
417+
bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
418+
bytes byte_buf map struct option unit newtype_struct
419+
ignored_any unit_struct tuple_struct tuple enum identifier
420+
);
421+
}
422+
}
423+
311424
#[cfg(test)]
312-
mod tests{
425+
mod tests {
313426
use std::collections::HashMap;
427+
314428
use bson::{bson, doc, encode_document};
429+
use serde::Deserialize;
315430

316-
use serde::{Deserialize};
317431
use super::from_bytes;
318432

433+
mod object_id {
434+
use std::fmt;
435+
436+
use serde::de::{Deserialize, Deserializer, MapAccess, Visitor};
437+
use serde::export::fmt::Error;
438+
use serde::export::Formatter;
439+
440+
use crate::de::object_id;
441+
442+
#[derive(Debug)]
443+
pub(super) struct ObjectId(Vec<u8>);
444+
445+
impl ObjectId {
446+
pub(super) fn from_bytes(bytes: Vec<u8>) -> Option<ObjectId> {
447+
if bytes.len() == 12 {
448+
Some(ObjectId(bytes))
449+
} else {
450+
None
451+
}
452+
}
453+
}
454+
455+
impl<'de> Deserialize<'de> for ObjectId {
456+
fn deserialize<D>(deserializer: D) -> Result<ObjectId, D::Error>
457+
where
458+
D: Deserializer<'de>,
459+
{
460+
struct ObjectIdVisitor;
461+
462+
impl<'de> Visitor<'de> for ObjectIdVisitor {
463+
type Value = ObjectId;
464+
465+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
466+
formatter.write_str("a bson objectid")
467+
}
468+
469+
fn visit_map<V>(self, mut visitor: V) -> Result<ObjectId, V::Error>
470+
where
471+
V: MapAccess<'de>,
472+
{
473+
let value = visitor.next_key::<ObjectIdKey>()?;
474+
if value.is_none() {
475+
return Err(serde::de::Error::custom(
476+
"No ObjectIdKey not found in synthesized struct",
477+
));
478+
}
479+
let v: ObjectIdFromBytes = visitor.next_value()?;
480+
Ok(v.0)
481+
}
482+
}
483+
484+
static FIELDS: [&str; 1] = [object_id::FIELD];
485+
deserializer.deserialize_struct(object_id::NAME, &FIELDS, ObjectIdVisitor)
486+
}
487+
}
488+
489+
struct ObjectIdKey;
490+
491+
impl<'de> Deserialize<'de> for ObjectIdKey {
492+
fn deserialize<D>(deserializer: D) -> Result<ObjectIdKey, D::Error>
493+
where
494+
D: Deserializer<'de>,
495+
{
496+
struct FieldVisitor;
497+
498+
impl<'de> Visitor<'de> for FieldVisitor {
499+
type Value = ();
500+
501+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
502+
formatter.write_str("a valid object id field")
503+
}
504+
505+
fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<(), E> {
506+
if s == super::super::object_id::FIELD {
507+
Ok(())
508+
} else {
509+
Err(serde::de::Error::custom(
510+
"field was not $__bson_object_id in synthesized object id struct",
511+
))
512+
}
513+
}
514+
}
515+
516+
deserializer.deserialize_identifier(FieldVisitor)?;
517+
Ok(ObjectIdKey)
518+
}
519+
}
520+
521+
struct ObjectIdFromBytes(ObjectId);
522+
523+
impl<'de> Deserialize<'de> for ObjectIdFromBytes {
524+
fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
525+
where
526+
D: Deserializer<'de>,
527+
{
528+
struct FieldVisitor;
529+
530+
impl<'de> Visitor<'de> for FieldVisitor {
531+
type Value = ObjectIdFromBytes;
532+
533+
fn expecting(&self, formatter: &mut Formatter<'_>) -> Result<(), Error> {
534+
formatter.write_str("an object id of twelve bytes")
535+
}
536+
537+
fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
538+
ObjectId::from_bytes(v.to_vec())
539+
.map(ObjectIdFromBytes)
540+
.ok_or(serde::de::Error::custom("invalid object id"))
541+
}
542+
}
543+
544+
deserializer.deserialize_bytes(FieldVisitor)
545+
}
546+
}
547+
}
548+
319549
#[derive(Deserialize)]
320550
struct Person<'a> {
551+
#[serde(rename = "_id")]
552+
id: object_id::ObjectId,
321553
first_name: &'a str,
322554
last_name: String,
323555
}
324556

325557
#[test]
326558
fn deserialize_struct() {
327559
let mut docbytes = Vec::new();
328-
encode_document(&mut docbytes, &doc! {
329-
"first_name": "Edward",
330-
"last_name": "Teach",
331-
}).expect("could not encode document");
332-
let p: Person = super::from_bytes(&docbytes).expect("could not decode into Person struct");
560+
encode_document(
561+
&mut docbytes,
562+
&doc! {
563+
"_id": bson::oid::ObjectId::with_string("abcdefabcdefabcdefabcdef").unwrap(),
564+
"first_name": "Edward",
565+
"last_name": "Teach",
566+
},
567+
)
568+
.expect("could not encode document");
569+
let p: Person = from_bytes(&docbytes).expect("could not decode into Person struct");
333570
assert_eq!(p.first_name, "Edward");
334571
assert_eq!(p.last_name, "Teach");
335572
}
573+
336574
#[test]
337575
fn deserialize_map() {
338576
let mut docbytes = Vec::new();
339-
encode_document(&mut docbytes, &doc! {
340-
"this": "that",
341-
"three": "four",
342-
"keymaster": "gatekeeper",
343-
}).expect("could not encode document");
344-
let map: HashMap<&str, &str> = super::from_bytes(&docbytes).expect("could not decode into HashMap<&str, &str>");
577+
encode_document(
578+
&mut docbytes,
579+
&doc! {
580+
"this": "that",
581+
"three": "four",
582+
"keymaster": "gatekeeper",
583+
},
584+
)
585+
.expect("could not encode document");
586+
let map: HashMap<&str, &str> =
587+
from_bytes(&docbytes).expect("could not decode into HashMap<&str, &str>");
345588
assert_eq!(map.len(), 3);
346589
assert_eq!(*map.get("this").expect("key not found"), "that");
347590
assert_eq!(*map.get("three").expect("key not found"), "four");
348591
assert_eq!(*map.get("keymaster").expect("key not found"), "gatekeeper");
349592

350-
let map: HashMap<String, String> = super::from_bytes(&docbytes).expect("could not decode into HashMap<String, String>");
593+
let map: HashMap<String, String> =
594+
from_bytes(&docbytes).expect("could not decode into HashMap<String, String>");
351595
assert_eq!(map.len(), 3);
352596
assert_eq!(map.get("this").expect("key not found"), "that");
353597
assert_eq!(map.get("three").expect("key not found"), "four");
354598
assert_eq!(map.get("keymaster").expect("key not found"), "gatekeeper");
355599
}
356-
357600
}

0 commit comments

Comments
 (0)