diff --git a/serde-tests/Cargo.toml b/serde-tests/Cargo.toml index d5bbe9e7..93da0365 100644 --- a/serde-tests/Cargo.toml +++ b/serde-tests/Cargo.toml @@ -13,6 +13,11 @@ serde = { version = "1.0", features = ["derive"] } pretty_assertions = "0.6.1" hex = "0.4.2" +[dev-dependencies] +serde_json = "1" +rmp-serde = "0.15" +base64 = "0.13.0" + [lib] name = "serde_tests" path = "lib.rs" diff --git a/serde-tests/test.rs b/serde-tests/test.rs index 12b2356c..2db16982 100644 --- a/serde-tests/test.rs +++ b/serde-tests/test.rs @@ -21,11 +21,21 @@ use bson::{ Binary, Bson, DateTime, + Decimal128, Deserializer, Document, JavaScriptCodeWithScope, + RawArray, + RawBinary, + RawBson, + RawDbPointer, + RawDocument, + RawDocumentBuf, + RawJavaScriptCodeWithScope, + RawRegex, Regex, Timestamp, + Uuid, }; /// Verifies the following: @@ -112,6 +122,18 @@ where ); } +/// Verifies the following: +/// - Deserializing a `T` from the provided bytes does not error +/// - Serializing the `T` back to bytes produces the input. 
+fn run_raw_round_trip_test<'de, T>(bytes: &'de [u8], description: &str) +where + T: Deserialize<'de> + Serialize + std::fmt::Debug, +{ + let t: T = bson::from_slice(bytes).expect(description); + let vec = bson::to_vec(&t).expect(description); + assert_eq!(vec.as_slice(), bytes); +} + #[test] fn smoke() { #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -683,135 +705,411 @@ fn empty_array() { } #[test] -fn all_types() { - #[derive(Debug, Deserialize, Serialize, PartialEq)] - struct Bar { - a: i32, - b: i32, +fn raw_doc_buf() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo { + d: RawDocumentBuf, } - #[derive(Debug, Deserialize, Serialize, PartialEq)] - struct Foo { - x: i32, - y: i64, - s: String, - array: Vec, - bson: Bson, - oid: ObjectId, - null: Option<()>, - subdoc: Document, - b: bool, - d: f64, - binary: Binary, - binary_old: Binary, - binary_other: Binary, - date: DateTime, - regex: Regex, - ts: Timestamp, - i: Bar, - undefined: Bson, - code: Bson, - code_w_scope: JavaScriptCodeWithScope, - decimal: Bson, - symbol: Bson, - min_key: Bson, - max_key: Bson, + let bytes = bson::to_vec(&doc! 
{ + "d": { + "a": 12, + "b": 5.5, + "c": [1, true, "ok"], + "d": { "a": "b" }, + "e": ObjectId::new(), + } + }) + .expect("raw_doc_buf"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_doc_buf"); +} + +#[test] +fn raw_doc() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + d: &'a RawDocument, } - let binary = Binary { - bytes: vec![36, 36, 36], - subtype: BinarySubtype::Generic, - }; - let binary_old = Binary { - bytes: vec![36, 36, 36], - subtype: BinarySubtype::BinaryOld, - }; - let binary_other = Binary { - bytes: vec![36, 36, 36], - subtype: BinarySubtype::UserDefined(0x81), - }; - let date = DateTime::now(); - let regex = Regex { - pattern: "hello".to_string(), - options: "x".to_string(), - }; - let timestamp = Timestamp { - time: 123, - increment: 456, - }; - let code = Bson::JavaScriptCode("console.log(1)".to_string()); - let code_w_scope = JavaScriptCodeWithScope { - code: "console.log(a)".to_string(), - scope: doc! { "a": 1 }, + let bytes = bson::to_vec(&doc! { + "d": { + "a": 12, + "b": 5.5, + "c": [1, true, "ok"], + "d": { "a": "b" }, + "e": ObjectId::new(), + } + }) + .expect("raw doc"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_doc"); +} + +#[test] +fn raw_array() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + d: &'a RawArray, + } + + let bytes = bson::to_vec(&doc! { + "d": [1, true, { "ok": 1 }, [ "sub", "array" ], Uuid::new()] + }) + .expect("raw_array"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_array"); +} + +#[test] +fn raw_binary() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + generic: RawBinary<'a>, + + #[serde(borrow)] + old: RawBinary<'a>, + + #[serde(borrow)] + uuid: RawBinary<'a>, + + #[serde(borrow)] + other: RawBinary<'a>, + } + + let bytes = bson::to_vec(&doc! 
{ + "generic": Binary { + bytes: vec![1, 2, 3, 4, 5], + subtype: BinarySubtype::Generic, + }, + "old": Binary { + bytes: vec![1, 2, 3], + subtype: BinarySubtype::BinaryOld, + }, + "uuid": Uuid::new(), + "other": Binary { + bytes: vec![1u8; 100], + subtype: BinarySubtype::UserDefined(100), + } + }) + .expect("raw_binary"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_binary"); +} + +#[test] +fn raw_regex() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + r: RawRegex<'a>, + } + + let bytes = bson::to_vec(&doc! { + "r": Regex { + pattern: "a[b-c]d".to_string(), + options: "ab".to_string(), + }, + }) + .expect("raw_regex"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_regex"); +} + +#[test] +fn raw_code_w_scope() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + r: RawJavaScriptCodeWithScope<'a>, + } + + let bytes = bson::to_vec(&doc! { + "r": JavaScriptCodeWithScope { + code: "console.log(x)".to_string(), + scope: doc! 
{ "x": 1 }, + }, + }) + .expect("raw_code_w_scope"); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_code_w_scope"); +} + +#[test] +fn raw_db_pointer() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo<'a> { + #[serde(borrow)] + a: RawDbPointer<'a>, + } + + // From the "DBpointer" bson corpus test + let bytes = hex::decode("1A0000000C610002000000620056E1FC72E0C917E9C471416100").unwrap(); + + run_raw_round_trip_test::(bytes.as_slice(), "raw_db_pointer"); +} + +#[derive(Debug, Deserialize, Serialize, PartialEq)] +struct SubDoc { + a: i32, + b: i32, +} + +#[derive(Debug, Deserialize, Serialize, PartialEq)] +struct AllTypes { + x: i32, + y: i64, + s: String, + array: Vec, + bson: Bson, + oid: ObjectId, + null: Option<()>, + subdoc: Document, + b: bool, + d: f64, + binary: Binary, + binary_old: Binary, + binary_other: Binary, + date: DateTime, + regex: Regex, + ts: Timestamp, + i: SubDoc, + undefined: Bson, + code: Bson, + code_w_scope: JavaScriptCodeWithScope, + decimal: Decimal128, + symbol: Bson, + min_key: Bson, + max_key: Bson, +} + +impl AllTypes { + fn fixtures() -> (Self, Document) { + let binary = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::Generic, + }; + let binary_old = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::BinaryOld, + }; + let binary_other = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::UserDefined(0x81), + }; + let date = DateTime::now(); + let regex = Regex { + pattern: "hello".to_string(), + options: "x".to_string(), + }; + let timestamp = Timestamp { + time: 123, + increment: 456, + }; + let code = Bson::JavaScriptCode("console.log(1)".to_string()); + let code_w_scope = JavaScriptCodeWithScope { + code: "console.log(a)".to_string(), + scope: doc! { "a": 1 }, + }; + let oid = ObjectId::new(); + let subdoc = doc! 
{ "k": true, "b": { "hello": "world" } }; + + let decimal = { + let bytes = hex::decode("18000000136400D0070000000000000000000000003A3000").unwrap(); + let d = Document::from_reader(bytes.as_slice()).unwrap(); + match d.get("d") { + Some(Bson::Decimal128(d)) => *d, + c => panic!("expected decimal128, got {:?}", c), + } + }; + + let doc = doc! { + "x": 1, + "y": 2_i64, + "s": "oke", + "array": [ true, "oke", { "12": 24 } ], + "bson": 1234.5, + "oid": oid, + "null": Bson::Null, + "subdoc": subdoc.clone(), + "b": true, + "d": 12.5, + "binary": binary.clone(), + "binary_old": binary_old.clone(), + "binary_other": binary_other.clone(), + "date": date, + "regex": regex.clone(), + "ts": timestamp, + "i": { "a": 300, "b": 12345 }, + "undefined": Bson::Undefined, + "code": code.clone(), + "code_w_scope": code_w_scope.clone(), + "decimal": Bson::Decimal128(decimal), + "symbol": Bson::Symbol("ok".to_string()), + "min_key": Bson::MinKey, + "max_key": Bson::MaxKey, + }; + + let v = AllTypes { + x: 1, + y: 2, + s: "oke".to_string(), + array: vec![ + Bson::Boolean(true), + Bson::String("oke".to_string()), + Bson::Document(doc! { "12": 24 }), + ], + bson: Bson::Double(1234.5), + oid, + null: None, + subdoc, + b: true, + d: 12.5, + binary, + binary_old, + binary_other, + date, + regex, + ts: timestamp, + i: SubDoc { a: 300, b: 12345 }, + undefined: Bson::Undefined, + code, + code_w_scope, + decimal, + symbol: Bson::Symbol("ok".to_string()), + min_key: Bson::MinKey, + max_key: Bson::MaxKey, + }; + + (v, doc) + } +} + +#[test] +fn all_types() { + let (v, doc) = AllTypes::fixtures(); + + run_test(&v, &doc, "all types"); +} + +#[test] +fn all_types_json() { + let (mut v, _) = AllTypes::fixtures(); + + let code = match v.code { + Bson::JavaScriptCode(ref c) => c.clone(), + c => panic!("expected code, found {:?}", c), }; - let oid = ObjectId::new(); - let subdoc = doc! 
{ "k": true, "b": { "hello": "world" } }; - let decimal = { - let bytes = hex::decode("18000000136400D0070000000000000000000000003A3000").unwrap(); - let d = Document::from_reader(bytes.as_slice()).unwrap(); - d.get("d").unwrap().clone() + let code_w_scope = JavaScriptCodeWithScope { + code: "hello world".to_string(), + scope: doc! { "x": 1 }, }; + let scope_json = serde_json::json!({ "x": 1 }); + v.code_w_scope = code_w_scope.clone(); - let doc = doc! { + let json = serde_json::json!({ "x": 1, - "y": 2_i64, + "y": 2, "s": "oke", - "array": [ true, "oke", { "12": 24 } ], + "array": vec![ + serde_json::json!(true), + serde_json::json!("oke".to_string()), + serde_json::json!({ "12": 24 }), + ], "bson": 1234.5, - "oid": oid, - "null": Bson::Null, - "subdoc": subdoc.clone(), + "oid": { "$oid": v.oid.to_hex() }, + "null": serde_json::Value::Null, + "subdoc": { "k": true, "b": { "hello": "world" } }, "b": true, "d": 12.5, - "binary": binary.clone(), - "binary_old": binary_old.clone(), - "binary_other": binary_other.clone(), - "date": date, - "regex": regex.clone(), - "ts": timestamp, - "i": { "a": 300, "b": 12345 }, - "undefined": Bson::Undefined, - "code": code.clone(), - "code_w_scope": code_w_scope.clone(), - "decimal": decimal.clone(), - "symbol": Bson::Symbol("ok".to_string()), - "min_key": Bson::MinKey, - "max_key": Bson::MaxKey, - }; + "binary": v.binary.bytes, + "binary_old": { "$binary": { "base64": base64::encode(&v.binary_old.bytes), "subType": "02" } }, + "binary_other": { "$binary": { "base64": base64::encode(&v.binary_old.bytes), "subType": "81" } }, + "date": { "$date": { "$numberLong": v.date.timestamp_millis().to_string() } }, + "regex": { "$regularExpression": { "pattern": v.regex.pattern, "options": v.regex.options } }, + "ts": { "$timestamp": { "t": 123, "i": 456 } }, + "i": { "a": v.i.a, "b": v.i.b }, + "undefined": { "$undefined": true }, + "code": { "$code": code }, + "code_w_scope": { "$code": code_w_scope.code, "$scope": scope_json }, + 
"decimal": { "$numberDecimalBytes": v.decimal.bytes() }, + "symbol": { "$symbol": "ok" }, + "min_key": { "$minKey": 1 }, + "max_key": { "$maxKey": 1 }, + }); + + assert_eq!(serde_json::to_value(&v).unwrap(), json); +} - let v = Foo { - x: 1, - y: 2, - s: "oke".to_string(), - array: vec![ - Bson::Boolean(true), - Bson::String("oke".to_string()), - Bson::Document(doc! { "12": 24 }), - ], - bson: Bson::Double(1234.5), - oid, - null: None, - subdoc, - b: true, - d: 12.5, - binary, - binary_old, - binary_other, - date, - regex, - ts: timestamp, - i: Bar { a: 300, b: 12345 }, - undefined: Bson::Undefined, - code, - code_w_scope, - decimal, - symbol: Bson::Symbol("ok".to_string()), - min_key: Bson::MinKey, - max_key: Bson::MaxKey, +#[test] +fn all_types_rmp() { + let (v, _) = AllTypes::fixtures(); + let serialized = rmp_serde::to_vec_named(&v).unwrap(); + let back: AllTypes = rmp_serde::from_slice(&serialized).unwrap(); + + assert_eq!(back, v); +} + +#[test] +fn all_raw_types_rmp() { + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct AllRawTypes<'a> { + #[serde(borrow)] + bson: RawBson<'a>, + #[serde(borrow)] + document: &'a RawDocument, + #[serde(borrow)] + array: &'a RawArray, + buf: RawDocumentBuf, + #[serde(borrow)] + binary: RawBinary<'a>, + #[serde(borrow)] + code_w_scope: RawJavaScriptCodeWithScope<'a>, + #[serde(borrow)] + regex: RawRegex<'a>, + } + + let doc_bytes = bson::to_vec(&doc! { + "bson": "some string", + "array": [1, 2, 3], + "binary": Binary { bytes: vec![1, 2, 3], subtype: BinarySubtype::Generic }, + "binary_old": Binary { bytes: vec![1, 2, 3], subtype: BinarySubtype::BinaryOld }, + "code_w_scope": JavaScriptCodeWithScope { + code: "ok".to_string(), + scope: doc! 
{ "x": 1 }, + }, + "regex": Regex { + pattern: "pattern".to_string(), + options: "opt".to_string() + } + }) + .unwrap(); + let doc_buf = RawDocumentBuf::new(doc_bytes).unwrap(); + let document = &doc_buf; + let array = document.get_array("array").unwrap(); + + let v = AllRawTypes { + bson: document.get("bson").unwrap().unwrap(), + array, + document, + buf: doc_buf.clone(), + binary: document.get_binary("binary").unwrap(), + code_w_scope: document + .get("code_w_scope") + .unwrap() + .unwrap() + .as_javascript_with_scope() + .unwrap(), + regex: document.get_regex("regex").unwrap(), }; + let serialized = rmp_serde::to_vec_named(&v).unwrap(); + let back: AllRawTypes = rmp_serde::from_slice(&serialized).unwrap(); - run_test(&v, &doc, "all types"); + assert_eq!(back, v); } #[test] @@ -910,3 +1208,35 @@ fn u2i() { bson::to_document(&v).unwrap_err(); bson::to_vec(&v).unwrap_err(); } + +#[test] +fn hint_cleared() { + #[derive(Debug, Serialize, Deserialize)] + struct Foo<'a> { + #[serde(borrow)] + doc: &'a RawDocument, + #[serde(borrow)] + binary: RawBinary<'a>, + } + + let binary_value = Binary { + bytes: vec![1, 2, 3, 4], + subtype: BinarySubtype::Generic, + }; + + let doc_value = doc! { + "binary": binary_value.clone() + }; + + let bytes = bson::to_vec(&doc_value).unwrap(); + + let doc = RawDocument::new(&bytes).unwrap(); + let binary = doc.get_binary("binary").unwrap(); + + let f = Foo { doc, binary }; + + let serialized_bytes = bson::to_vec(&f).unwrap(); + let round_doc: Document = bson::from_slice(&serialized_bytes).unwrap(); + + assert_eq!(round_doc, doc! 
{ "doc": doc_value, "binary": binary_value }); +} diff --git a/src/de/mod.rs b/src/de/mod.rs index 8a20a342..c577c891 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -35,6 +35,7 @@ use std::io::Read; use crate::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, oid::{self, ObjectId}, + raw::RawBinary, ser::write_i32, spec::{self, BinarySubtype}, Decimal128, @@ -45,7 +46,10 @@ use ::serde::{ Deserialize, }; -pub(crate) use self::serde::BsonVisitor; +pub(crate) use self::serde::{convert_unsigned_to_signed_raw, BsonVisitor}; + +#[cfg(test)] +pub(crate) use self::raw::Deserializer as RawDeserializer; pub(crate) const MAX_BSON_SIZE: i32 = 16 * 1024 * 1024; pub(crate) const MIN_BSON_DOCUMENT_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator @@ -266,7 +270,7 @@ pub(crate) fn deserialize_bson_kvp( impl Binary { pub(crate) fn from_reader(mut reader: R) -> Result { - let len = read_i32(&mut reader)?; + let mut len = read_i32(&mut reader)?; if !(0..=MAX_BSON_SIZE).contains(&len) { return Err(Error::invalid_length( len as usize, @@ -274,20 +278,6 @@ impl Binary { )); } let subtype = BinarySubtype::from(read_u8(&mut reader)?); - Self::from_reader_with_len_and_payload(reader, len, subtype) - } - - pub(crate) fn from_reader_with_len_and_payload( - mut reader: R, - mut len: i32, - subtype: BinarySubtype, - ) -> Result { - if !(0..=MAX_BSON_SIZE).contains(&len) { - return Err(Error::invalid_length( - len as usize, - &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), - )); - } // Skip length data in old binary. 
if let BinarySubtype::BinaryOld = subtype { @@ -317,6 +307,50 @@ impl Binary { } } +impl<'a> RawBinary<'a> { + pub(crate) fn from_slice_with_len_and_payload( + mut bytes: &'a [u8], + mut len: i32, + subtype: BinarySubtype, + ) -> Result { + if !(0..=MAX_BSON_SIZE).contains(&len) { + return Err(Error::invalid_length( + len as usize, + &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), + )); + } else if len as usize > bytes.len() { + return Err(Error::invalid_length( + len as usize, + &format!( + "binary length {} exceeds buffer length {}", + len, + bytes.len() + ) + .as_str(), + )); + } + + // Skip length data in old binary. + if let BinarySubtype::BinaryOld = subtype { + let data_len = read_i32(&mut bytes)?; + + if data_len + 4 != len { + return Err(Error::invalid_length( + data_len as usize, + &"0x02 length did not match top level binary length", + )); + } + + len -= 4; + } + + Ok(Self { + bytes: &bytes[0..len as usize], + subtype, + }) + } +} + impl DbPointer { pub(crate) fn from_reader(mut reader: R) -> Result { let ns = read_string(&mut reader, false)?; diff --git a/src/de/raw.rs b/src/de/raw.rs index c48ba3c7..48a20de4 100644 --- a/src/de/raw.rs +++ b/src/de/raw.rs @@ -1,25 +1,25 @@ use std::{ borrow::Cow, + convert::TryInto, io::{ErrorKind, Read}, + sync::Arc, }; use serde::{ - de::{EnumAccess, Error as SerdeError, IntoDeserializer, VariantAccess}, + de::{EnumAccess, Error as SerdeError, IntoDeserializer, MapAccess, VariantAccess}, forward_to_deserialize_any, Deserializer as SerdeDeserializer, }; use crate::{ oid::ObjectId, + raw::{RawBinary, RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, - Binary, Bson, DateTime, - DbPointer, Decimal128, - JavaScriptCodeWithScope, - Regex, + RawDocument, Timestamp, }; @@ -34,9 +34,26 @@ use super::{ Error, Result, MAX_BSON_SIZE, + MIN_CODE_WITH_SCOPE_SIZE, }; use crate::de::serde::MapDeserializer; +/// Hint provided to the 
deserializer via `deserialize_newtype_struct` as to the type of thing +/// being deserialized. +#[derive(Debug, Clone, Copy)] +enum DeserializerHint { + /// No hint provided, deserialize normally. + None, + + /// The type being deserialized expects the BSON to contain a binary value with the provided + /// subtype. This is currently used to deserialize `bson::Uuid` values. + BinarySubtype(BinarySubtype), + + /// The type being deserialized is raw BSON, meaning no allocations should occur as part of + /// deserializing and everything should be visited via borrowing or `Copy`. + RawBson, +} + /// Deserializer used to parse and deserialize raw BSON bytes. pub(crate) struct Deserializer<'de> { bytes: BsonBuf<'de>, @@ -50,6 +67,13 @@ pub(crate) struct Deserializer<'de> { current_type: ElementType, } +/// Enum used to determine what the type of document being deserialized is in +/// `Deserializer::deserialize_document`. +enum DocumentType { + Array, + EmbeddedDocument, +} + impl<'de> Deserializer<'de> { pub(crate) fn new(buf: &'de [u8], utf8_lossy: bool) -> Self { Self { @@ -95,13 +119,69 @@ impl<'de> Deserializer<'de> { self.bytes.read_str() } - fn deserialize_document_key(&mut self) -> Result> { + /// Read a null-terminated C style string from the underlying BSON. + /// + /// If utf8_lossy, this will be an owned string if invalid UTF-8 is encountered in the string, + /// otherwise it will be borrowed. + fn deserialize_cstr(&mut self) -> Result> { self.bytes.read_cstr() } + /// Read an ObjectId from the underlying BSON. + /// + /// If hinted to use raw BSON, the bytes of the ObjectId will be visited. + /// Otherwise, a map in the shape of the extended JSON format of an ObjectId will be visited. 
+ fn deserialize_objectid(&mut self, visitor: V, hint: DeserializerHint) -> Result + where + V: serde::de::Visitor<'de>, + { + let oid = ObjectId::from_reader(&mut self.bytes)?; + visitor.visit_map(ObjectIdAccess::new(oid, hint)) + } + + /// Read a document from the underlying BSON, whether it's an array or an actual document. + /// + /// If hinted to use raw BSON, the bytes themselves will be visited using a special newtype + /// name. Otherwise, the key-value pairs will be accessed in order, either as part of a + /// `MapAccess` for documents or a `SeqAccess` for arrays. + fn deserialize_document( + &mut self, + visitor: V, + hint: DeserializerHint, + document_type: DocumentType, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + let is_array = match document_type { + DocumentType::Array => true, + DocumentType::EmbeddedDocument => false, + }; + + match hint { + DeserializerHint::RawBson => { + let mut len = self.bytes.slice(4)?; + let len = read_i32(&mut len)?; + + let doc = RawDocument::new(self.bytes.read_slice(len as usize)?) + .map_err(Error::custom)?; + + let access = if is_array { + RawDocumentAccess::for_array(doc) + } else { + RawDocumentAccess::new(doc) + }; + + visitor.visit_map(access) + } + _ if is_array => self.access_document(|access| visitor.visit_seq(access)), + _ => self.access_document(|access| visitor.visit_map(access)), + } + } + /// Construct a `DocumentAccess` and pass it into the provided closure, returning the /// result of the closure if no other errors are encountered. 
- fn deserialize_document(&mut self, f: F) -> Result + fn access_document(&mut self, f: F) -> Result where F: FnOnce(DocumentAccess<'_, 'de>) -> Result, { @@ -132,15 +212,13 @@ impl<'de> Deserializer<'de> { Ok(Some(element_type)) } - fn deserialize_next( - &mut self, - visitor: V, - binary_subtype_hint: Option, - ) -> Result + /// Deserialize the next element in the BSON, using the type of the element along with the + /// provided hint to determine how to visit the data. + fn deserialize_next(&mut self, visitor: V, hint: DeserializerHint) -> Result where V: serde::de::Visitor<'de>, { - if let Some(expected_st) = binary_subtype_hint { + if let DeserializerHint::BinarySubtype(expected_st) = hint { if self.current_type != ElementType::Binary { return Err(Error::custom(format!( "expected Binary with subtype {:?}, instead got {:?}", @@ -159,14 +237,11 @@ impl<'de> Deserializer<'de> { }, ElementType::Boolean => visitor.visit_bool(read_bool(&mut self.bytes)?), ElementType::Null => visitor.visit_unit(), - ElementType::ObjectId => { - let oid = ObjectId::from_reader(&mut self.bytes)?; - visitor.visit_map(ObjectIdAccess::new(oid)) - } + ElementType::ObjectId => self.deserialize_objectid(visitor, hint), ElementType::EmbeddedDocument => { - self.deserialize_document(|access| visitor.visit_map(access)) + self.deserialize_document(visitor, hint, DocumentType::EmbeddedDocument) } - ElementType::Array => self.deserialize_document(|access| visitor.visit_seq(access)), + ElementType::Array => self.deserialize_document(visitor, hint, DocumentType::Array), ElementType::Binary => { let len = read_i32(&mut self.bytes)?; if !(0..=MAX_BSON_SIZE).contains(&len) { @@ -177,7 +252,7 @@ impl<'de> Deserializer<'de> { } let subtype = BinarySubtype::from(read_u8(&mut self.bytes)?); - if let Some(expected_subtype) = binary_subtype_hint { + if let DeserializerHint::BinarySubtype(expected_subtype) = hint { if subtype != expected_subtype { return Err(Error::custom(format!( "expected binary subtype 
{:?} instead got {:?}", @@ -191,12 +266,12 @@ impl<'de> Deserializer<'de> { visitor.visit_borrowed_bytes(self.bytes.read_slice(len as usize)?) } _ => { - let binary = Binary::from_reader_with_len_and_payload( - &mut self.bytes, + let binary = RawBinary::from_slice_with_len_and_payload( + self.bytes.read_slice(len as usize)?, len, subtype, )?; - let mut d = BinaryDeserializer::new(binary); + let mut d = BinaryDeserializer::new(binary, hint); visitor.visit_map(BinaryAccess { deserializer: &mut d, }) @@ -204,45 +279,92 @@ impl<'de> Deserializer<'de> { } } ElementType::Undefined => { - let doc = Bson::Undefined.into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + visitor.visit_map(RawBsonAccess::new("$undefined", BsonContent::Boolean(true))) } ElementType::DateTime => { let dti = read_i64(&mut self.bytes)?; let dt = DateTime::from_millis(dti); - let mut d = DateTimeDeserializer::new(dt); + let mut d = DateTimeDeserializer::new(dt, hint); visitor.visit_map(DateTimeAccess { deserializer: &mut d, }) } ElementType::RegularExpression => { - let doc = Bson::RegularExpression(Regex::from_reader(&mut self.bytes)?) - .into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + let mut de = RegexDeserializer::new(&mut *self); + visitor.visit_map(RegexAccess::new(&mut de)) } ElementType::DbPointer => { - let doc = Bson::DbPointer(DbPointer::from_reader(&mut self.bytes)?) 
- .into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + let mut de = DbPointerDeserializer::new(&mut *self, hint); + visitor.visit_map(DbPointerAccess::new(&mut de)) } ElementType::JavaScriptCode => { let utf8_lossy = self.bytes.utf8_lossy; - let code = read_string(&mut self.bytes, utf8_lossy)?; - let doc = Bson::JavaScriptCode(code).into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + + match hint { + DeserializerHint::RawBson => visitor.visit_map(RawBsonAccess::new( + "$code", + BsonContent::Str(self.bytes.read_borrowed_str()?), + )), + _ => { + let code = read_string(&mut self.bytes, utf8_lossy)?; + let doc = Bson::JavaScriptCode(code).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + } } ElementType::JavaScriptCodeWithScope => { - let utf8_lossy = self.bytes.utf8_lossy; - let code_w_scope = - JavaScriptCodeWithScope::from_reader(&mut self.bytes, utf8_lossy)?; - let doc = Bson::JavaScriptCodeWithScope(code_w_scope).into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + let len = read_i32(&mut self.bytes)?; + + if len < MIN_CODE_WITH_SCOPE_SIZE { + return Err(SerdeError::invalid_length( + len.try_into().unwrap_or(0), + &format!( + "CodeWithScope to be at least {} bytes", + MIN_CODE_WITH_SCOPE_SIZE + ) + .as_str(), + )); + } else if (self.bytes.bytes_remaining() as i32) < len - 4 { + return Err(SerdeError::invalid_length( + len.try_into().unwrap_or(0), + &format!( + "CodeWithScope to be at most {} bytes", + self.bytes.bytes_remaining() + ) + .as_str(), + )); + } + + let mut de = CodeWithScopeDeserializer::new(&mut *self, hint, len - 4); + let out = visitor.visit_map(CodeWithScopeAccess::new(&mut de)); + + if de.length_remaining != 0 { + return Err(SerdeError::invalid_length( + len.try_into().unwrap_or(0), + &format!( + "CodeWithScope length {} bytes greater than actual length", + de.length_remaining + ) + .as_str(), + )); + } + + out } ElementType::Symbol => { let 
utf8_lossy = self.bytes.utf8_lossy; - let symbol = read_string(&mut self.bytes, utf8_lossy)?; - let doc = Bson::Symbol(symbol).into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + + match hint { + DeserializerHint::RawBson => visitor.visit_map(RawBsonAccess::new( + "$symbol", + BsonContent::Str(self.bytes.read_borrowed_str()?), + )), + _ => { + let symbol = read_string(&mut self.bytes, utf8_lossy)?; + let doc = Bson::Symbol(symbol).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + } } ElementType::Timestamp => { let ts = Timestamp::from_reader(&mut self.bytes)?; @@ -256,12 +378,10 @@ impl<'de> Deserializer<'de> { visitor.visit_map(Decimal128Access::new(d128)) } ElementType::MaxKey => { - let doc = Bson::MaxKey.into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + visitor.visit_map(RawBsonAccess::new("$maxKey", BsonContent::Int32(1))) } ElementType::MinKey => { - let doc = Bson::MinKey.into_extended_document(); - visitor.visit_map(MapDeserializer::new(doc)) + visitor.visit_map(RawBsonAccess::new("$minKey", BsonContent::Int32(1))) } } } @@ -275,7 +395,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: serde::de::Visitor<'de>, { - self.deserialize_next(visitor, None) + self.deserialize_next(visitor, DeserializerHint::None) } #[inline] @@ -301,7 +421,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { match self.current_type { ElementType::String => visitor.visit_enum(self.deserialize_str()?.into_deserializer()), ElementType::EmbeddedDocument => { - self.deserialize_document(|access| visitor.visit_enum(access)) + self.access_document(|access| visitor.visit_enum(access)) } t => Err(Error::custom(format!("expected enum, instead got {:?}", t))), } @@ -321,10 +441,33 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: serde::de::Visitor<'de>, { - if name == UUID_NEWTYPE_NAME { - 
self.deserialize_next(visitor, Some(BinarySubtype::Uuid)) - } else { - visitor.visit_newtype_struct(self) + match name { + UUID_NEWTYPE_NAME => self.deserialize_next( + visitor, + DeserializerHint::BinarySubtype(BinarySubtype::Uuid), + ), + RAW_BSON_NEWTYPE => self.deserialize_next(visitor, DeserializerHint::RawBson), + RAW_DOCUMENT_NEWTYPE => { + if self.current_type != ElementType::EmbeddedDocument { + return Err(serde::de::Error::custom(format!( + "expected raw document, instead got {:?}", + self.current_type + ))); + } + + self.deserialize_next(visitor, DeserializerHint::RawBson) + } + RAW_ARRAY_NEWTYPE => { + if self.current_type != ElementType::Array { + return Err(serde::de::Error::custom(format!( + "expected raw array, instead got {:?}", + self.current_type + ))); + } + + self.deserialize_next(visitor, DeserializerHint::RawBson) + } + _ => visitor.visit_newtype_struct(self), } } @@ -428,7 +571,7 @@ impl<'d, 'de> serde::de::SeqAccess<'de> for DocumentAccess<'d, 'de> { if self.read_next_type()?.is_none() { return Ok(None); } - let _index = self.read(|s| s.root_deserializer.deserialize_document_key())?; + let _index = self.read(|s| s.root_deserializer.deserialize_cstr())?; self.read_next_value(seed).map(Some) } } @@ -498,13 +641,17 @@ impl<'d, 'de> serde::de::Deserializer<'de> for DocumentKeyDeserializer<'d, 'de> where V: serde::de::Visitor<'de>, { - let s = self.root_deserializer.deserialize_document_key()?; + let s = self.root_deserializer.deserialize_cstr()?; match s { Cow::Borrowed(b) => visitor.visit_borrowed_str(b), Cow::Owned(string) => visitor.visit_string(string), } } + fn is_human_readable(&self) -> bool { + false + } + forward_to_deserialize_any! 
{ bool char str bytes byte_buf option unit unit_struct string identifier newtype_struct seq tuple tuple_struct struct map enum @@ -527,6 +674,99 @@ impl<'de> serde::de::Deserializer<'de> for FieldDeserializer { visitor.visit_borrowed_str(self.field_name) } + fn is_human_readable(&self) -> bool { + false + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +/// A `MapAccess` used to deserialize entire documents as chunks of bytes without deserializing +/// the individual key/value pairs. +struct RawDocumentAccess<'d> { + deserializer: RawDocumentDeserializer<'d>, + + /// Whether the first key has been deserialized yet or not. + deserialized_first: bool, + + /// Whether or not this document being deserialized is for an array or not. + array: bool, +} + +impl<'de> RawDocumentAccess<'de> { + fn new(doc: &'de RawDocument) -> Self { + Self { + deserializer: RawDocumentDeserializer { raw_doc: doc }, + deserialized_first: false, + array: false, + } + } + + fn for_array(doc: &'de RawDocument) -> Self { + Self { + deserializer: RawDocumentDeserializer { raw_doc: doc }, + deserialized_first: false, + array: true, + } + } +} + +impl<'de> serde::de::MapAccess<'de> for RawDocumentAccess<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if !self.deserialized_first { + self.deserialized_first = true; + + // the newtype name will indicate to the `RawBson` enum that the incoming + // bytes are meant to be treated as a document or array instead of a binary value. 
+ seed.deserialize(FieldDeserializer { + field_name: if self.array { + RAW_ARRAY_NEWTYPE + } else { + RAW_DOCUMENT_NEWTYPE + }, + }) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(self.deserializer) + } +} + +#[derive(Clone, Copy)] +struct RawDocumentDeserializer<'a> { + raw_doc: &'a RawDocument, +} + +impl<'de> serde::de::Deserializer<'de> for RawDocumentDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.raw_doc.as_bytes()) + } + + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -537,13 +777,15 @@ impl<'de> serde::de::Deserializer<'de> for FieldDeserializer { struct ObjectIdAccess { oid: ObjectId, visited: bool, + hint: DeserializerHint, } impl ObjectIdAccess { - fn new(oid: ObjectId) -> Self { + fn new(oid: ObjectId, hint: DeserializerHint) -> Self { Self { oid, visited: false, + hint, } } } @@ -567,11 +809,17 @@ impl<'de> serde::de::MapAccess<'de> for ObjectIdAccess { where V: serde::de::DeserializeSeed<'de>, { - seed.deserialize(ObjectIdDeserializer(self.oid)) + seed.deserialize(ObjectIdDeserializer { + oid: self.oid, + hint: self.hint, + }) } } -struct ObjectIdDeserializer(ObjectId); +struct ObjectIdDeserializer { + oid: ObjectId, + hint: DeserializerHint, +} impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { type Error = Error; @@ -580,7 +828,15 @@ impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { where V: serde::de::Visitor<'de>, { - visitor.visit_string(self.0.to_hex()) + // save an allocation when deserializing to raw bson + match self.hint { + DeserializerHint::RawBson => visitor.visit_bytes(&self.oid.bytes()), + _ => 
visitor.visit_string(self.oid.to_hex()), + } + } + + fn is_human_readable(&self) -> bool { + false } serde::forward_to_deserialize_any! { @@ -641,6 +897,10 @@ impl<'de> serde::de::Deserializer<'de> for Decimal128Deserializer { visitor.visit_bytes(&self.0.bytes) } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -732,6 +992,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut TimestampDeserializer { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -739,12 +1003,13 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut TimestampDeserializer { } } -enum DateTimeDeserializationStage { - TopLevel, - NumberLong, - Done, -} - +/// A `MapAccess` providing access to a BSON datetime being deserialized. +/// +/// If hinted to be raw BSON, this deserializes the serde data model equivalent +/// of { "$date": }. +/// +/// Otherwise, this deserializes the serde data model equivalent of +/// { "$date": { "$numberLong": } }. 
struct DateTimeAccess<'d> { deserializer: &'d mut DateTimeDeserializer, } @@ -782,13 +1047,21 @@ impl<'de, 'd> serde::de::MapAccess<'de> for DateTimeAccess<'d> { struct DateTimeDeserializer { dt: DateTime, stage: DateTimeDeserializationStage, + hint: DeserializerHint, +} + +enum DateTimeDeserializationStage { + TopLevel, + NumberLong, + Done, } impl DateTimeDeserializer { - fn new(dt: DateTime) -> Self { + fn new(dt: DateTime, hint: DeserializerHint) -> Self { Self { dt, stage: DateTimeDeserializationStage::TopLevel, + hint, } } } @@ -801,12 +1074,18 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { V: serde::de::Visitor<'de>, { match self.stage { - DateTimeDeserializationStage::TopLevel => { - self.stage = DateTimeDeserializationStage::NumberLong; - visitor.visit_map(DateTimeAccess { - deserializer: &mut self, - }) - } + DateTimeDeserializationStage::TopLevel => match self.hint { + DeserializerHint::RawBson => { + self.stage = DateTimeDeserializationStage::Done; + visitor.visit_i64(self.dt.timestamp_millis()) + } + _ => { + self.stage = DateTimeDeserializationStage::NumberLong; + visitor.visit_map(DateTimeAccess { + deserializer: &mut self, + }) + } + }, DateTimeDeserializationStage::NumberLong => { self.stage = DateTimeDeserializationStage::Done; visitor.visit_string(self.dt.timestamp_millis().to_string()) @@ -817,6 +1096,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -824,35 +1107,35 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { } } -struct BinaryAccess<'d> { - deserializer: &'d mut BinaryDeserializer, +/// A `MapAccess` providing access to a BSON binary being deserialized. 
+/// +/// If hinted to be raw BSON, this deserializes the serde data model equivalent +/// of { "$binary": { "subType": , "bytes": } }. +/// +/// Otherwise, this deserializes the serde data model equivalent of +/// { "$binary": { "subType": , "base64": } }. +struct BinaryAccess<'d, 'de> { + deserializer: &'d mut BinaryDeserializer<'de>, } -impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d> { +impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d, 'de> { type Error = Error; fn next_key_seed(&mut self, seed: K) -> Result> where K: serde::de::DeserializeSeed<'de>, { - match self.deserializer.stage { - BinaryDeserializationStage::TopLevel => seed - .deserialize(FieldDeserializer { - field_name: "$binary", - }) - .map(Some), - BinaryDeserializationStage::Subtype => seed - .deserialize(FieldDeserializer { - field_name: "subType", - }) - .map(Some), - BinaryDeserializationStage::Bytes => seed - .deserialize(FieldDeserializer { - field_name: "base64", - }) - .map(Some), - BinaryDeserializationStage::Done => Ok(None), - } + let field_name = match self.deserializer.stage { + BinaryDeserializationStage::TopLevel => "$binary", + BinaryDeserializationStage::Subtype => "subType", + BinaryDeserializationStage::Bytes => match self.deserializer.hint { + DeserializerHint::RawBson => "bytes", + _ => "base64", + }, + BinaryDeserializationStage::Done => return Ok(None), + }; + + seed.deserialize(FieldDeserializer { field_name }).map(Some) } fn next_value_seed(&mut self, seed: V) -> Result @@ -863,21 +1146,23 @@ impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d> { } } -struct BinaryDeserializer { - binary: Binary, +struct BinaryDeserializer<'a> { + binary: RawBinary<'a>, + hint: DeserializerHint, stage: BinaryDeserializationStage, } -impl BinaryDeserializer { - fn new(binary: Binary) -> Self { +impl<'a> BinaryDeserializer<'a> { + fn new(binary: RawBinary<'a>, hint: DeserializerHint) -> Self { Self { binary, + hint, stage: BinaryDeserializationStage::TopLevel, } 
} } -impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer { +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer<'de> { type Error = Error; fn deserialize_any(mut self, visitor: V) -> Result @@ -893,11 +1178,17 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer { } BinaryDeserializationStage::Subtype => { self.stage = BinaryDeserializationStage::Bytes; - visitor.visit_string(hex::encode([u8::from(self.binary.subtype)])) + match self.hint { + DeserializerHint::RawBson => visitor.visit_u8(self.binary.subtype().into()), + _ => visitor.visit_string(hex::encode([u8::from(self.binary.subtype)])), + } } BinaryDeserializationStage::Bytes => { self.stage = BinaryDeserializationStage::Done; - visitor.visit_string(base64::encode(self.binary.bytes.as_slice())) + match self.hint { + DeserializerHint::RawBson => visitor.visit_borrowed_bytes(self.binary.bytes), + _ => visitor.visit_string(base64::encode(self.binary.bytes)), + } } BinaryDeserializationStage::Done => { Err(Error::custom("Binary fully deserialized already")) @@ -905,6 +1196,10 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer { } } + fn is_human_readable(&self) -> bool { + false + } + serde::forward_to_deserialize_any! { bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq bytes byte_buf map struct option unit newtype_struct @@ -919,6 +1214,449 @@ enum BinaryDeserializationStage { Done, } +/// A `MapAccess` providing access to a BSON code with scope being deserialized. +/// +/// If hinted to be raw BSON, this deserializes the serde data model equivalent +/// of { "$code": , "$scope": <&RawDocument> } }. +/// +/// Otherwise, this deserializes the serde data model equivalent of +/// { "$code": "$scope": }. 
+struct CodeWithScopeAccess<'de, 'd, 'a> { + deserializer: &'a mut CodeWithScopeDeserializer<'de, 'd>, +} + +impl<'de, 'd, 'a> CodeWithScopeAccess<'de, 'd, 'a> { + fn new(deserializer: &'a mut CodeWithScopeDeserializer<'de, 'd>) -> Self { + Self { deserializer } + } +} + +impl<'de, 'd, 'a> serde::de::MapAccess<'de> for CodeWithScopeAccess<'de, 'd, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + match self.deserializer.stage { + CodeWithScopeDeserializationStage::Code => seed + .deserialize(FieldDeserializer { + field_name: "$code", + }) + .map(Some), + CodeWithScopeDeserializationStage::Scope => seed + .deserialize(FieldDeserializer { + field_name: "$scope", + }) + .map(Some), + CodeWithScopeDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct CodeWithScopeDeserializer<'de, 'a> { + root_deserializer: &'a mut Deserializer<'de>, + stage: CodeWithScopeDeserializationStage, + hint: DeserializerHint, + length_remaining: i32, +} + +impl<'de, 'a> CodeWithScopeDeserializer<'de, 'a> { + fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint, len: i32) -> Self { + Self { + root_deserializer, + stage: CodeWithScopeDeserializationStage::Code, + hint, + length_remaining: len, + } + } + + /// Executes a closure that reads from the BSON bytes and returns an error if the number of + /// bytes read exceeds length_remaining. + /// + /// A mutable reference to this `CodeWithScopeDeserializer` is passed into the closure. 
+ fn read(&mut self, f: F) -> Result + where + F: FnOnce(&mut Self) -> Result, + { + let start_bytes = self.root_deserializer.bytes.bytes_read(); + let out = f(self); + let bytes_read = self.root_deserializer.bytes.bytes_read() - start_bytes; + self.length_remaining -= bytes_read as i32; + + if self.length_remaining < 0 { + return Err(Error::custom("length of CodeWithScope too short")); + } + out + } +} + +impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut CodeWithScopeDeserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + CodeWithScopeDeserializationStage::Code => { + self.stage = CodeWithScopeDeserializationStage::Scope; + match self.read(|s| s.root_deserializer.deserialize_str())? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_string(s), + } + } + CodeWithScopeDeserializationStage::Scope => { + self.stage = CodeWithScopeDeserializationStage::Done; + self.read(|s| { + s.root_deserializer.deserialize_document( + visitor, + s.hint, + DocumentType::EmbeddedDocument, + ) + }) + } + CodeWithScopeDeserializationStage::Done => Err(Error::custom( + "JavaScriptCodeWithScope fully deserialized already", + )), + } + } + + fn is_human_readable(&self) -> bool { + false + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +#[derive(Debug)] +enum CodeWithScopeDeserializationStage { + Code, + Scope, + Done, +} + +/// A `MapAccess` providing access to a BSON DB pointer being deserialized. +/// +/// Regardless of the hint, this deserializes the serde data model equivalent +/// of { "$dbPointer": { "$ref": , "$id": } }. 
+struct DbPointerAccess<'de, 'd, 'a> { + deserializer: &'a mut DbPointerDeserializer<'de, 'd>, +} + +impl<'de, 'd, 'a> DbPointerAccess<'de, 'd, 'a> { + fn new(deserializer: &'a mut DbPointerDeserializer<'de, 'd>) -> Self { + Self { deserializer } + } +} + +impl<'de, 'd, 'a> serde::de::MapAccess<'de> for DbPointerAccess<'de, 'd, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + match self.deserializer.stage { + DbPointerDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$dbPointer", + }) + .map(Some), + DbPointerDeserializationStage::Namespace => seed + .deserialize(FieldDeserializer { field_name: "$ref" }) + .map(Some), + DbPointerDeserializationStage::Id => seed + .deserialize(FieldDeserializer { field_name: "$id" }) + .map(Some), + DbPointerDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct DbPointerDeserializer<'de, 'a> { + root_deserializer: &'a mut Deserializer<'de>, + stage: DbPointerDeserializationStage, + hint: DeserializerHint, +} + +impl<'de, 'a> DbPointerDeserializer<'de, 'a> { + fn new(root_deserializer: &'a mut Deserializer<'de>, hint: DeserializerHint) -> Self { + Self { + root_deserializer, + stage: DbPointerDeserializationStage::TopLevel, + hint, + } + } +} + +impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut DbPointerDeserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + DbPointerDeserializationStage::TopLevel => { + self.stage = DbPointerDeserializationStage::Namespace; + visitor.visit_map(DbPointerAccess::new(self)) + } + DbPointerDeserializationStage::Namespace => { + self.stage = DbPointerDeserializationStage::Id; + match 
self.root_deserializer.deserialize_str()? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_string(s), + } + } + DbPointerDeserializationStage::Id => { + self.stage = DbPointerDeserializationStage::Done; + self.root_deserializer + .deserialize_objectid(visitor, self.hint) + } + DbPointerDeserializationStage::Done => { + Err(Error::custom("DbPointer fully deserialized already")) + } + } + } + + fn is_human_readable(&self) -> bool { + false + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +#[derive(Debug)] +enum DbPointerDeserializationStage { + TopLevel, + Namespace, + Id, + Done, +} + +/// A `MapAccess` providing access to a BSON regular expression being deserialized. +/// +/// Regardless of the hint, this deserializes the serde data model equivalent +/// of { "$regularExpression": { "pattern": , "options": } }. 
+struct RegexAccess<'de, 'd, 'a> { + deserializer: &'a mut RegexDeserializer<'de, 'd>, +} + +impl<'de, 'd, 'a> RegexAccess<'de, 'd, 'a> { + fn new(deserializer: &'a mut RegexDeserializer<'de, 'd>) -> Self { + Self { deserializer } + } +} + +impl<'de, 'd, 'a> serde::de::MapAccess<'de> for RegexAccess<'de, 'd, 'a> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + match self.deserializer.stage { + RegexDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$regularExpression", + }) + .map(Some), + RegexDeserializationStage::Pattern => seed + .deserialize(FieldDeserializer { + field_name: "pattern", + }) + .map(Some), + RegexDeserializationStage::Options => seed + .deserialize(FieldDeserializer { + field_name: "options", + }) + .map(Some), + RegexDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct RegexDeserializer<'de, 'a> { + root_deserializer: &'a mut Deserializer<'de>, + stage: RegexDeserializationStage, +} + +impl<'de, 'a> RegexDeserializer<'de, 'a> { + fn new(root_deserializer: &'a mut Deserializer<'de>) -> Self { + Self { + root_deserializer, + stage: RegexDeserializationStage::TopLevel, + } + } +} + +impl<'de, 'a, 'b> serde::de::Deserializer<'de> for &'b mut RegexDeserializer<'de, 'a> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + RegexDeserializationStage::TopLevel => { + self.stage.advance(); + visitor.visit_map(RegexAccess::new(self)) + } + RegexDeserializationStage::Pattern | RegexDeserializationStage::Options => { + self.stage.advance(); + match self.root_deserializer.deserialize_cstr()? 
{ + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_string(s), + } + } + RegexDeserializationStage::Done => { + Err(Error::custom("DbPointer fully deserialized already")) + } + } + } + + fn is_human_readable(&self) -> bool { + false + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +#[derive(Debug)] +enum RegexDeserializationStage { + TopLevel, + Pattern, + Options, + Done, +} + +impl RegexDeserializationStage { + fn advance(&mut self) { + *self = match self { + RegexDeserializationStage::TopLevel => RegexDeserializationStage::Pattern, + RegexDeserializationStage::Pattern => RegexDeserializationStage::Options, + RegexDeserializationStage::Options => RegexDeserializationStage::Done, + RegexDeserializationStage::Done => RegexDeserializationStage::Done, + } + } +} + +/// Helper access struct for visiting the extended JSON model of simple BSON types. +/// e.g. Symbol, Timestamp, etc. +struct RawBsonAccess<'a> { + key: &'static str, + value: BsonContent<'a>, + first: bool, +} + +/// Enum value representing some cached BSON data needed to represent a given +/// BSON type's extended JSON model. 
+#[derive(Debug, Clone, Copy)] +enum BsonContent<'a> { + Str(&'a str), + Int32(i32), + Boolean(bool), +} + +impl<'a> RawBsonAccess<'a> { + fn new(key: &'static str, value: BsonContent<'a>) -> Self { + Self { + key, + value, + first: true, + } + } +} + +impl<'de> MapAccess<'de> for RawBsonAccess<'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.first { + self.first = false; + seed.deserialize(FieldDeserializer { + field_name: self.key, + }) + .map(Some) + } else { + Ok(None) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(RawBsonDeserializer { value: self.value }) + } +} + +struct RawBsonDeserializer<'a> { + value: BsonContent<'a>, +} + +impl<'de, 'a> serde::de::Deserializer<'de> for RawBsonDeserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.value { + BsonContent::Boolean(b) => visitor.visit_bool(b), + BsonContent::Str(s) => visitor.visit_borrowed_str(s), + BsonContent::Int32(i) => visitor.visit_i32(i), + } + } + + fn is_human_readable(&self) -> bool { + false + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + /// Struct wrapping a slice of BSON bytes. struct BsonBuf<'a> { bytes: &'a [u8], @@ -951,6 +1689,10 @@ impl<'a> BsonBuf<'a> { self.index } + fn bytes_remaining(&self) -> usize { + self.bytes.len() - self.bytes_read() + } + /// Verify the index has not run out of bounds. fn index_check(&self) -> std::io::Result<()> { if self.index >= self.bytes.len() { @@ -960,9 +1702,11 @@ impl<'a> BsonBuf<'a> { } /// Get the string starting at the provided index and ending at the buffer's current index. 
- fn str(&mut self, start: usize) -> Result> { + /// + /// Can optionally override the global UTF-8 lossy setting to ensure bytes are not allocated. + fn str(&mut self, start: usize, utf8_lossy_override: Option) -> Result> { let bytes = &self.bytes[start..self.index]; - let s = if self.utf8_lossy { + let s = if utf8_lossy_override.unwrap_or(self.utf8_lossy) { String::from_utf8_lossy(bytes) } else { Cow::Borrowed(std::str::from_utf8(bytes).map_err(Error::custom)?) @@ -991,15 +1735,10 @@ impl<'a> BsonBuf<'a> { self.index_check()?; - self.str(start) + self.str(start, None) } - /// Attempts to read a null-terminated UTF-8 string from the data. - /// - /// If invalid UTF-8 is encountered, the unicode replacement character will be inserted in place - /// of the offending data, resulting in an owned `String`. Otherwise, the data will be - /// borrowed as-is. - fn read_str(&mut self) -> Result> { + fn _advance_to_len_encoded_str(&mut self) -> Result { let len = read_i32(self)?; let start = self.index; @@ -1014,13 +1753,41 @@ impl<'a> BsonBuf<'a> { self.index += (len - 1) as usize; self.index_check()?; - self.str(start) + Ok(start) + } + + /// Attempts to read a null-terminated UTF-8 string from the data. + /// + /// If invalid UTF-8 is encountered, the unicode replacement character will be inserted in place + /// of the offending data, resulting in an owned `String`. Otherwise, the data will be + /// borrowed as-is. + fn read_str(&mut self) -> Result> { + let start = self._advance_to_len_encoded_str()?; + self.str(start, None) + } + + /// Attempts to read a null-terminated UTF-8 string from the data. + fn read_borrowed_str(&mut self) -> Result<&'a str> { + let start = self._advance_to_len_encoded_str()?; + match self.str(start, Some(false))? 
{ + Cow::Borrowed(s) => Ok(s), + Cow::Owned(_) => panic!("should have errored when encountering invalid UTF-8"), + } + } + + fn slice(&self, length: usize) -> Result<&'a [u8]> { + if self.index + length > self.bytes.len() { + return Err(Error::Io(Arc::new( + std::io::ErrorKind::UnexpectedEof.into(), + ))); + } + + Ok(&self.bytes[self.index..(self.index + length)]) } fn read_slice(&mut self, length: usize) -> Result<&'a [u8]> { - let start = self.index; + let slice = self.slice(length)?; self.index += length; - self.index_check()?; - Ok(&self.bytes[start..self.index]) + Ok(slice) } } diff --git a/src/de/serde.rs b/src/de/serde.rs index 507391d0..23fc5e8d 100644 --- a/src/de/serde.rs +++ b/src/de/serde.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, convert::{TryFrom, TryInto}, fmt, vec, @@ -24,6 +25,7 @@ use crate::{ datetime::DateTime, document::{Document, IntoIter}, oid::ObjectId, + raw::RawBson, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Decimal128, @@ -264,15 +266,68 @@ impl<'de> Visitor<'de> for BsonVisitor { while let Some(k) = visitor.next_key::()? 
{ match k.as_str() { "$oid" => { - let hex: String = visitor.next_value()?; - return Ok(Bson::ObjectId(ObjectId::parse_str(hex.as_str()).map_err( - |_| { - V::Error::invalid_value( - Unexpected::Str(&hex), - &"24-character, big-endian hex string", - ) - }, - )?)); + enum BytesOrHex<'a> { + Bytes([u8; 12]), + Hex(Cow<'a, str>), + } + + impl<'a, 'de: 'a> Deserialize<'de> for BytesOrHex<'a> { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct BytesOrHexVisitor; + + impl<'de> Visitor<'de> for BytesOrHexVisitor { + type Value = BytesOrHex<'de>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "hexstring or byte array") + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(BytesOrHex::Hex(Cow::Owned(v.to_string()))) + } + + fn visit_borrowed_str( + self, + v: &'de str, + ) -> Result + where + E: Error, + { + Ok(BytesOrHex::Hex(Cow::Borrowed(v))) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + Ok(BytesOrHex::Bytes(v.try_into().map_err(Error::custom)?)) + } + } + + deserializer.deserialize_any(BytesOrHexVisitor) + } + } + + let bytes_or_hex: BytesOrHex = visitor.next_value()?; + match bytes_or_hex { + BytesOrHex::Bytes(b) => return Ok(Bson::ObjectId(ObjectId::from_bytes(b))), + BytesOrHex::Hex(hex) => { + return Ok(Bson::ObjectId(ObjectId::parse_str(&hex).map_err( + |_| { + V::Error::invalid_value( + Unexpected::Str(&hex), + &"24-character, big-endian hex string", + ) + }, + )?)); + } + } } "$symbol" => { let string: String = visitor.next_value()?; @@ -368,10 +423,7 @@ impl<'de> Visitor<'de> for BsonVisitor { "$regularExpression" => { let re = visitor.next_value::()?; - return Ok(Bson::RegularExpression(Regex { - pattern: re.pattern, - options: re.options, - })); + return Ok(Bson::RegularExpression(Regex::new(re.pattern, re.options))); } "$dbPointer" => { @@ -421,16 +473,12 @@ impl<'de> Visitor<'de> for BsonVisitor { "$numberDecimalBytes" 
=> { let bytes = visitor.next_value::()?; - let arr = bytes.into_vec().try_into().map_err(|v: Vec| { - Error::custom(format!( - "expected decimal128 as byte buffer, instead got buffer of length {}", - v.len() - )) - })?; - return Ok(Bson::Decimal128(Decimal128 { bytes: arr })); + return Ok(Bson::Decimal128(Decimal128::deserialize_from_slice( + &bytes, + )?)); } - _ => { + k => { let v = visitor.next_value::()?; doc.insert(k, v); } @@ -471,14 +519,16 @@ impl<'de> Visitor<'de> for BsonVisitor { } } -fn convert_unsigned_to_signed(value: u64) -> Result -where - E: Error, -{ +enum BsonInteger { + Int32(i32), + Int64(i64), +} + +fn _convert_unsigned(value: u64) -> Result { if let Ok(int32) = i32::try_from(value) { - Ok(Bson::Int32(int32)) + Ok(BsonInteger::Int32(int32)) } else if let Ok(int64) = i64::try_from(value) { - Ok(Bson::Int64(int64)) + Ok(BsonInteger::Int64(int64)) } else { Err(Error::custom(format!( "cannot represent {} as a signed number", @@ -487,6 +537,28 @@ where } } +fn convert_unsigned_to_signed(value: u64) -> Result +where + E: Error, +{ + let bi = _convert_unsigned(value)?; + match bi { + BsonInteger::Int32(i) => Ok(Bson::Int32(i)), + BsonInteger::Int64(i) => Ok(Bson::Int64(i)), + } +} + +pub(crate) fn convert_unsigned_to_signed_raw<'a, E>(value: u64) -> Result, E> +where + E: Error, +{ + let bi = _convert_unsigned(value)?; + match bi { + BsonInteger::Int32(i) => Ok(RawBson::Int32(i)), + BsonInteger::Int64(i) => Ok(RawBson::Int64(i)), + } +} + /// Serde Deserializer pub struct Deserializer { value: Option, @@ -1010,9 +1082,12 @@ impl<'de> Deserialize<'de> for Binary { where D: de::Deserializer<'de>, { - match Bson::deserialize(deserializer)? { + match deserializer.deserialize_byte_buf(BsonVisitor)? 
{ Bson::Binary(binary) => Ok(binary), - _ => Err(D::Error::custom("expecting Binary")), + d => Err(D::Error::custom(format!( + "expecting Binary but got {:?} instead", + d + ))), } } } @@ -1024,7 +1099,10 @@ impl<'de> Deserialize<'de> for Decimal128 { { match Bson::deserialize(deserializer)? { Bson::Decimal128(d128) => Ok(d128), - _ => Err(D::Error::custom("expecting Decimal128")), + o => Err(D::Error::custom(format!( + "expecting Decimal128, got {:?}", + o + ))), } } } diff --git a/src/decimal128.rs b/src/decimal128.rs index c217bb5d..533b10dd 100644 --- a/src/decimal128.rs +++ b/src/decimal128.rs @@ -1,6 +1,6 @@ //! [BSON Decimal128](https://github.com/mongodb/specifications/blob/master/source/bson-decimal128/decimal128.rst) data type representation -use std::fmt; +use std::{convert::TryInto, fmt}; /// Struct representing a BSON Decimal128 type. /// @@ -22,6 +22,13 @@ impl Decimal128 { pub fn bytes(&self) -> [u8; 128 / 8] { self.bytes } + + pub(crate) fn deserialize_from_slice( + bytes: &[u8], + ) -> std::result::Result { + let arr: [u8; 128 / 8] = bytes.try_into().map_err(E::custom)?; + Ok(Decimal128 { bytes: arr }) + } } impl fmt::Debug for Decimal128 { diff --git a/src/lib.rs b/src/lib.rs index d9379dca..be043012 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -271,16 +271,14 @@ pub use self::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, datetime::DateTime, de::{ - from_bson, - from_document, - from_reader, - from_reader_utf8_lossy, - from_slice, - from_slice_utf8_lossy, - Deserializer, + from_bson, from_document, from_reader, from_reader_utf8_lossy, from_slice, + from_slice_utf8_lossy, Deserializer, }, decimal128::Decimal128, - raw::{RawDocument, RawDocumentBuf, RawArray}, + raw::{ + RawArray, RawBinary, RawBson, RawDbPointer, RawDocument, RawDocumentBuf, RawJavaScriptCodeWithScope, + RawRegex, + }, ser::{to_bson, to_document, to_vec, Serializer}, uuid::{Uuid, UuidRepresentation}, }; diff --git a/src/raw/array.rs 
b/src/raw/array.rs index 684a4a4c..665c5632 100644 --- a/src/raw/array.rs +++ b/src/raw/array.rs @@ -1,5 +1,7 @@ use std::convert::TryFrom; +use serde::{ser::SerializeSeq, Deserialize, Serialize}; + use super::{ error::{ValueAccessError, ValueAccessErrorKind, ValueAccessResult}, Error, @@ -10,7 +12,14 @@ use super::{ RawRegex, Result, }; -use crate::{oid::ObjectId, spec::ElementType, Bson, DateTime, Timestamp}; +use crate::{ + oid::ObjectId, + raw::{RawBsonVisitor, RAW_ARRAY_NEWTYPE}, + spec::{BinarySubtype, ElementType}, + Bson, + DateTime, + Timestamp, +}; /// A slice of a BSON document containing a BSON array value (akin to [`std::str`]). This can be /// retrieved from a [`RawDocument`] via [`RawDocument::get`]. @@ -240,3 +249,51 @@ impl<'a> Iterator for RawArrayIter<'a> { } } } + +impl<'de: 'a, 'a> Deserialize<'de> for &'a RawArray { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match deserializer.deserialize_newtype_struct(RAW_ARRAY_NEWTYPE, RawBsonVisitor)? 
{ + RawBson::Array(d) => Ok(d), + RawBson::Binary(b) if b.subtype == BinarySubtype::Generic => { + let doc = RawDocument::new(b.bytes).map_err(serde::de::Error::custom)?; + Ok(RawArray::from_doc(doc)) + } + b => Err(serde::de::Error::custom(format!( + "expected raw array reference, instead got {:?}", + b + ))), + } + } +} + +impl<'a> Serialize for &'a RawArray { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + struct SeqSerializer<'a>(&'a RawArray); + + impl<'a> Serialize for SeqSerializer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + if serializer.is_human_readable() { + let mut seq = serializer.serialize_seq(None)?; + for v in self.0 { + let v = v.map_err(serde::ser::Error::custom)?; + seq.serialize_element(&v)?; + } + seq.end() + } else { + serializer.serialize_bytes(self.0.as_bytes()) + } + } + } + + serializer.serialize_newtype_struct(RAW_ARRAY_NEWTYPE, &SeqSerializer(self)) + } +} diff --git a/src/raw/bson.rs b/src/raw/bson.rs index 05ae4e19..9529a1e7 100644 --- a/src/raw/bson.rs +++ b/src/raw/bson.rs @@ -1,10 +1,17 @@ use std::convert::{TryFrom, TryInto}; +use serde::{de::Visitor, ser::SerializeStruct, Deserialize, Serialize}; +use serde_bytes::{ByteBuf, Bytes}; + use super::{Error, RawArray, RawDocument, Result}; use crate::{ + de::convert_unsigned_to_signed_raw, + extjson, oid::{self, ObjectId}, + raw::{RAW_ARRAY_NEWTYPE, RAW_BSON_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, spec::{BinarySubtype, ElementType}, Bson, + DateTime, DbPointer, Decimal128, Timestamp, @@ -239,6 +246,307 @@ impl<'a> RawBson<'a> { } } +/// A visitor used to deserialize types backed by raw BSON. 
+pub(crate) struct RawBsonVisitor; + +impl<'de> Visitor<'de> for RawBsonVisitor { + type Value = RawBson<'de>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a raw BSON reference") + } + + fn visit_borrowed_str(self, v: &'de str) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::String(v)) + } + + fn visit_borrowed_bytes(self, bytes: &'de [u8]) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Binary(RawBinary { + bytes, + subtype: BinarySubtype::Generic, + })) + } + + fn visit_i8(self, v: i8) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v.into())) + } + + fn visit_i16(self, v: i16) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v.into())) + } + + fn visit_i32(self, v: i32) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int32(v)) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Int64(v)) + } + + fn visit_u8(self, value: u8) -> std::result::Result + where + E: serde::de::Error, + { + convert_unsigned_to_signed_raw(value.into()) + } + + fn visit_u16(self, value: u16) -> std::result::Result + where + E: serde::de::Error, + { + convert_unsigned_to_signed_raw(value.into()) + } + + fn visit_u32(self, value: u32) -> std::result::Result + where + E: serde::de::Error, + { + convert_unsigned_to_signed_raw(value.into()) + } + + fn visit_u64(self, value: u64) -> std::result::Result + where + E: serde::de::Error, + { + convert_unsigned_to_signed_raw(value) + } + + fn visit_bool(self, v: bool) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Boolean(v)) + } + + fn visit_f64(self, v: f64) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Double(v)) + } + + fn visit_none(self) -> std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Null) + } + + fn visit_unit(self) -> 
std::result::Result + where + E: serde::de::Error, + { + Ok(RawBson::Null) + } + + fn visit_newtype_struct(self, deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_map(self, mut map: A) -> std::result::Result + where + A: serde::de::MapAccess<'de>, + { + let k = map + .next_key::<&str>()? + .ok_or_else(|| serde::de::Error::custom("expected a key when deserializing RawBson"))?; + match k { + "$oid" => { + let oid: ObjectId = map.next_value()?; + Ok(RawBson::ObjectId(oid)) + } + "$symbol" => { + let s: &str = map.next_value()?; + Ok(RawBson::Symbol(s)) + } + "$numberDecimalBytes" => { + let bytes = map.next_value::()?; + return Ok(RawBson::Decimal128(Decimal128::deserialize_from_slice( + &bytes, + )?)); + } + "$regularExpression" => { + #[derive(Debug, Deserialize)] + struct BorrowedRegexBody<'a> { + pattern: &'a str, + + options: &'a str, + } + let body: BorrowedRegexBody = map.next_value()?; + Ok(RawBson::RegularExpression(RawRegex { + pattern: body.pattern, + options: body.options, + })) + } + "$undefined" => { + let _: bool = map.next_value()?; + Ok(RawBson::Undefined) + } + "$binary" => { + #[derive(Debug, Deserialize)] + struct BorrowedBinaryBody<'a> { + bytes: &'a [u8], + + #[serde(rename = "subType")] + subtype: u8, + } + + let v = map.next_value::()?; + + Ok(RawBson::Binary(RawBinary { + bytes: v.bytes, + subtype: v.subtype.into(), + })) + } + "$date" => { + let v = map.next_value::()?; + Ok(RawBson::DateTime(DateTime::from_millis(v))) + } + "$timestamp" => { + let v = map.next_value::()?; + Ok(RawBson::Timestamp(Timestamp { + time: v.t, + increment: v.i, + })) + } + "$minKey" => { + let _ = map.next_value::()?; + Ok(RawBson::MinKey) + } + "$maxKey" => { + let _ = map.next_value::()?; + Ok(RawBson::MaxKey) + } + "$code" => { + let code = map.next_value::<&str>()?; + if let Some(key) = map.next_key::<&str>()? 
{ + if key == "$scope" { + let scope = map.next_value::<&RawDocument>()?; + Ok(RawBson::JavaScriptCodeWithScope( + RawJavaScriptCodeWithScope { code, scope }, + )) + } else { + Err(serde::de::Error::unknown_field(key, &["$scope"])) + } + } else { + Ok(RawBson::JavaScriptCode(code)) + } + } + "$dbPointer" => { + #[derive(Deserialize)] + struct BorrowedDbPointerBody<'a> { + #[serde(rename = "$ref")] + ns: &'a str, + + #[serde(rename = "$id")] + id: ObjectId, + } + + let body: BorrowedDbPointerBody = map.next_value()?; + Ok(RawBson::DbPointer(RawDbPointer { + namespace: body.ns, + id: body.id, + })) + } + RAW_DOCUMENT_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::new(bson).map_err(serde::de::Error::custom)?; + Ok(RawBson::Document(doc)) + } + RAW_ARRAY_NEWTYPE => { + let bson = map.next_value::<&[u8]>()?; + let doc = RawDocument::new(bson).map_err(serde::de::Error::custom)?; + Ok(RawBson::Array(RawArray::from_doc(doc))) + } + k => Err(serde::de::Error::custom(format!( + "can't deserialize RawBson from map, key={}", + k + ))), + } + } +} + +impl<'de: 'a, 'a> Deserialize<'de> for RawBson<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_newtype_struct(RAW_BSON_NEWTYPE, RawBsonVisitor) + } +} + +impl<'a> Serialize for RawBson<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + match self { + RawBson::Double(v) => serializer.serialize_f64(*v), + RawBson::String(v) => serializer.serialize_str(v), + RawBson::Array(v) => v.serialize(serializer), + RawBson::Document(v) => v.serialize(serializer), + RawBson::Boolean(v) => serializer.serialize_bool(*v), + RawBson::Null => serializer.serialize_unit(), + RawBson::Int32(v) => serializer.serialize_i32(*v), + RawBson::Int64(v) => serializer.serialize_i64(*v), + RawBson::ObjectId(oid) => oid.serialize(serializer), + RawBson::DateTime(dt) => dt.serialize(serializer), + 
RawBson::Binary(b) => b.serialize(serializer), + RawBson::JavaScriptCode(c) => { + let mut state = serializer.serialize_struct("$code", 1)?; + state.serialize_field("$code", c)?; + state.end() + } + RawBson::JavaScriptCodeWithScope(code_w_scope) => code_w_scope.serialize(serializer), + RawBson::DbPointer(dbp) => dbp.serialize(serializer), + RawBson::Symbol(s) => { + let mut state = serializer.serialize_struct("$symbol", 1)?; + state.serialize_field("$symbol", s)?; + state.end() + } + RawBson::RegularExpression(re) => re.serialize(serializer), + RawBson::Timestamp(t) => t.serialize(serializer), + RawBson::Decimal128(d) => d.serialize(serializer), + RawBson::Undefined => { + let mut state = serializer.serialize_struct("$undefined", 1)?; + state.serialize_field("$undefined", &true)?; + state.end() + } + RawBson::MaxKey => { + let mut state = serializer.serialize_struct("$maxKey", 1)?; + state.serialize_field("$maxKey", &1)?; + state.end() + } + RawBson::MinKey => { + let mut state = serializer.serialize_struct("$minKey", 1)?; + state.serialize_field("$minKey", &1)?; + state.end() + } + } + } +} + impl<'a> TryFrom> for Bson { type Error = Error; @@ -302,8 +610,8 @@ impl<'a> TryFrom> for Bson { /// A BSON binary value referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawBinary<'a> { - pub(super) subtype: BinarySubtype, - pub(super) bytes: &'a [u8], + pub(crate) subtype: BinarySubtype, + pub(crate) bytes: &'a [u8], } impl<'a> RawBinary<'a> { @@ -318,11 +626,61 @@ impl<'a> RawBinary<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawBinary<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? 
{ + RawBson::Binary(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected binary, but got {:?} instead", + c + ))), + } + } +} + +impl<'a> Serialize for RawBinary<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + if let BinarySubtype::Generic = self.subtype { + serializer.serialize_bytes(self.bytes) + } else if !serializer.is_human_readable() { + #[derive(Serialize)] + struct BorrowedBinary<'a> { + bytes: &'a Bytes, + + #[serde(rename = "subType")] + subtype: u8, + } + + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = BorrowedBinary { + bytes: Bytes::new(self.bytes), + subtype: self.subtype.into(), + }; + state.serialize_field("$binary", &body)?; + state.end() + } else { + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = extjson::models::BinaryBody { + base64: base64::encode(self.bytes), + subtype: hex::encode([self.subtype.into()]), + }; + state.serialize_field("$binary", &body)?; + state.end() + } + } +} + /// A BSON regex referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawRegex<'a> { - pub(super) pattern: &'a str, - pub(super) options: &'a str, + pub(crate) pattern: &'a str, + pub(crate) options: &'a str, } impl<'a> RawRegex<'a> { @@ -337,10 +695,47 @@ impl<'a> RawRegex<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawRegex<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? 
{ + RawBson::RegularExpression(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected Regex, but got {:?} instead", + c + ))), + } + } +} + +impl<'a> Serialize for RawRegex<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + struct BorrowedRegexBody<'a> { + pattern: &'a str, + options: &'a str, + } + + let mut state = serializer.serialize_struct("$regularExpression", 1)?; + let body = BorrowedRegexBody { + pattern: self.pattern, + options: self.options, + }; + state.serialize_field("$regularExpression", &body)?; + state.end() + } +} + /// A BSON "code with scope" value referencing raw bytes stored elsewhere. #[derive(Clone, Copy, Debug, PartialEq)] pub struct RawJavaScriptCodeWithScope<'a> { pub(crate) code: &'a str, + pub(crate) scope: &'a RawDocument, } @@ -356,9 +751,75 @@ impl<'a> RawJavaScriptCodeWithScope<'a> { } } +impl<'de: 'a, 'a> Deserialize<'de> for RawJavaScriptCodeWithScope<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::JavaScriptCodeWithScope(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected CodeWithScope, but got {:?} instead", + c + ))), + } + } +} + +impl<'a> Serialize for RawJavaScriptCodeWithScope<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let mut state = serializer.serialize_struct("$codeWithScope", 2)?; + state.serialize_field("$code", &self.code)?; + state.serialize_field("$scope", &self.scope)?; + state.end() + } +} + /// A BSON DB pointer value referencing raw bytes stored elesewhere. 
#[derive(Debug, Clone, Copy, PartialEq)] pub struct RawDbPointer<'a> { pub(crate) namespace: &'a str, pub(crate) id: ObjectId, } + +impl<'de: 'a, 'a> Deserialize<'de> for RawDbPointer<'a> { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match RawBson::deserialize(deserializer)? { + RawBson::DbPointer(b) => Ok(b), + c => Err(serde::de::Error::custom(format!( + "expected DbPointer, but got {:?} instead", + c + ))), + } + } +} + +impl<'a> Serialize for RawDbPointer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + struct BorrowedDbPointerBody<'a> { + #[serde(rename = "$ref")] + ref_ns: &'a str, + + #[serde(rename = "$id")] + id: ObjectId, + } + + let mut state = serializer.serialize_struct("$dbPointer", 1)?; + let body = BorrowedDbPointerBody { + ref_ns: self.namespace, + id: self.id, + }; + state.serialize_field("$dbPointer", &body)?; + state.end() + } +} diff --git a/src/raw/document.rs b/src/raw/document.rs index e2141bcc..053b64d0 100644 --- a/src/raw/document.rs +++ b/src/raw/document.rs @@ -3,7 +3,14 @@ use std::{ convert::{TryFrom, TryInto}, }; -use crate::{raw::error::ErrorKind, DateTime, Timestamp}; +use serde::{ser::SerializeMap, Deserialize, Serialize}; + +use crate::{ + raw::{error::ErrorKind, RawBsonVisitor, RAW_DOCUMENT_NEWTYPE}, + spec::BinarySubtype, + DateTime, + Timestamp, +}; use super::{ error::{ValueAccessError, ValueAccessErrorKind, ValueAccessResult}, @@ -482,6 +489,57 @@ impl RawDocument { } } +impl<'de: 'a, 'a> Deserialize<'de> for &'a RawDocument { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + match deserializer.deserialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, RawBsonVisitor)? { + RawBson::Document(d) => Ok(d), + + // For non-BSON formats, RawDocument gets serialized as bytes, so we need to deserialize + // from them here too. 
For BSON, the deserializer will return an error if it + // sees the RAW_DOCUMENT_NEWTYPE but the next type isn't a document. + RawBson::Binary(b) if b.subtype == BinarySubtype::Generic => { + RawDocument::new(b.bytes).map_err(serde::de::Error::custom) + } + + o => Err(serde::de::Error::custom(format!( + "expected raw document reference, instead got {:?}", + o + ))), + } + } +} + +impl<'a> Serialize for &'a RawDocument { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + struct KvpSerializer<'a>(&'a RawDocument); + + impl<'a> Serialize for KvpSerializer<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + if serializer.is_human_readable() { + let mut map = serializer.serialize_map(None)?; + for kvp in self.0 { + let (k, v) = kvp.map_err(serde::ser::Error::custom)?; + map.serialize_entry(k, &v)?; + } + map.end() + } else { + serializer.serialize_bytes(self.0.as_bytes()) + } + } + } + serializer.serialize_newtype_struct(RAW_DOCUMENT_NEWTYPE, &KvpSerializer(self)) + } +} + impl std::fmt::Debug for RawDocument { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RawDocument") diff --git a/src/raw/document_buf.rs b/src/raw/document_buf.rs index 019a6e25..f1c216d3 100644 --- a/src/raw/document_buf.rs +++ b/src/raw/document_buf.rs @@ -4,6 +4,8 @@ use std::{ ops::Deref, }; +use serde::{Deserialize, Serialize}; + use crate::Document; use super::{Error, ErrorKind, Iter, RawBson, RawDocument, Result}; @@ -139,6 +141,27 @@ impl RawDocumentBuf { } } +impl<'de> Deserialize<'de> for RawDocumentBuf { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + // TODO: RUST-1045 implement visit_map to deserialize from arbitrary maps. 
+ let doc: &'de RawDocument = Deserialize::deserialize(deserializer)?; + Ok(doc.to_owned()) + } +} + +impl Serialize for RawDocumentBuf { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let doc: &RawDocument = self.deref(); + doc.serialize(serializer) + } +} + impl std::fmt::Debug for RawDocumentBuf { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RawDocumentBuf") diff --git a/src/raw/mod.rs b/src/raw/mod.rs index 59f36595..89c16ea1 100644 --- a/src/raw/mod.rs +++ b/src/raw/mod.rs @@ -134,6 +134,17 @@ pub use self::{ iter::Iter, }; +pub(crate) use self::bson::RawBsonVisitor; + +/// Special newtype name indicating that the type being (de)serialized is a raw BSON document. +pub(crate) const RAW_DOCUMENT_NEWTYPE: &str = "$__private__bson_RawDocument"; + +/// Special newtype name indicating that the type being (de)serialized is a raw BSON array. +pub(crate) const RAW_ARRAY_NEWTYPE: &str = "$__private__bson_RawArray"; + +/// Special newtype name indicating that the type being (de)serialized is a raw BSON value. +pub(crate) const RAW_BSON_NEWTYPE: &str = "$__private__bson_RawBson"; + /// Given a u8 slice, return an i32 calculated from the first four bytes in /// little endian order. 
fn f64_from_slice(val: &[u8]) -> Result { diff --git a/src/raw/test/props.rs b/src/raw/test/props.rs index 850dcade..6f0157d2 100644 --- a/src/raw/test/props.rs +++ b/src/raw/test/props.rs @@ -22,11 +22,7 @@ pub(crate) fn arbitrary_bson() -> impl Strategy { any::().prop_map(Bson::Int32), any::().prop_map(Bson::Int64), any::<(String, String)>().prop_map(|(pattern, options)| { - let mut chars: Vec<_> = options.chars().collect(); - chars.sort_unstable(); - - let options: String = chars.into_iter().collect(); - Bson::RegularExpression(Regex { pattern, options }) + Bson::RegularExpression(Regex::new(pattern, options)) }), any::<[u8; 12]>().prop_map(|bytes| Bson::ObjectId(crate::oid::ObjectId::from_bytes(bytes))), (arbitrary_binary_subtype(), any::>()).prop_map(|(subtype, bytes)| { diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs index eef34538..bc178f76 100644 --- a/src/ser/raw/mod.rs +++ b/src/ser/raw/mod.rs @@ -1,6 +1,8 @@ mod document_serializer; mod value_serializer; +use std::io::Write; + use serde::{ ser::{Error as SerdeError, SerializeMap, SerializeStruct}, Serialize, @@ -10,6 +12,7 @@ use self::value_serializer::{ValueSerializer, ValueType}; use super::{write_binary, write_cstring, write_f64, write_i32, write_i64, write_string}; use crate::{ + raw::{RAW_ARRAY_NEWTYPE, RAW_DOCUMENT_NEWTYPE}, ser::{Error, Result}, spec::{BinarySubtype, ElementType}, uuid::UUID_NEWTYPE_NAME, @@ -25,9 +28,30 @@ pub(crate) struct Serializer { /// but in serde, the serializer learns of the type after serializing the key. type_index: usize, - /// Whether the binary value about to be serialized is a UUID or not. - /// This is indicated by serializing a newtype with name UUID_NEWTYPE_NAME; - is_uuid: bool, + /// Hint provided by the type being serialized. + hint: SerializerHint, +} + +/// Various bits of information that the serialized type can provide to the serializer to +/// inform the purpose of the next serialization step. 
+#[derive(Debug, Clone, Copy)] +enum SerializerHint { + None, + + /// The next call to `serialize_bytes` is for the purposes of serializing a UUID. + Uuid, + + /// The next call to `serialize_bytes` is for the purposes of serializing a raw document. + RawDocument, + + /// The next call to `serialize_bytes` is for the purposes of serializing a raw array. + RawArray, +} + +impl SerializerHint { + fn take(&mut self) -> SerializerHint { + std::mem::replace(self, SerializerHint::None) + } } impl Serializer { @@ -35,7 +59,7 @@ impl Serializer { Self { bytes: Vec::new(), type_index: 0, - is_uuid: false, + hint: SerializerHint::None, } } @@ -176,16 +200,27 @@ impl<'a> serde::Serializer for &'a mut Serializer { #[inline] fn serialize_bytes(self, v: &[u8]) -> Result { - self.update_element_type(ElementType::Binary)?; + match self.hint.take() { + SerializerHint::RawDocument => { + self.update_element_type(ElementType::EmbeddedDocument)?; + self.bytes.write_all(v)?; + } + SerializerHint::RawArray => { + self.update_element_type(ElementType::Array)?; + self.bytes.write_all(v)?; + } + hint => { + self.update_element_type(ElementType::Binary)?; - let subtype = if self.is_uuid { - self.is_uuid = false; - BinarySubtype::Uuid - } else { - BinarySubtype::Generic - }; + let subtype = if matches!(hint, SerializerHint::Uuid) { + BinarySubtype::Uuid + } else { + BinarySubtype::Generic + }; - write_binary(&mut self.bytes, v, subtype)?; + write_binary(&mut self.bytes, v, subtype)?; + } + }; Ok(()) } @@ -232,8 +267,11 @@ impl<'a> serde::Serializer for &'a mut Serializer { where T: serde::Serialize, { - if name == UUID_NEWTYPE_NAME { - self.is_uuid = true; + match name { + UUID_NEWTYPE_NAME => self.hint = SerializerHint::Uuid, + RAW_DOCUMENT_NEWTYPE => self.hint = SerializerHint::RawDocument, + RAW_ARRAY_NEWTYPE => self.hint = SerializerHint::RawArray, + _ => {} } value.serialize(self) } diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs index 9af76549..49880a85 
100644 --- a/src/ser/raw/value_serializer.rs +++ b/src/ser/raw/value_serializer.rs @@ -7,6 +7,7 @@ use serde::{ use crate::{ oid::ObjectId, + raw::RAW_DOCUMENT_NEWTYPE, ser::{write_binary, write_cstring, write_i32, write_i64, write_string, Error, Result}, spec::{BinarySubtype, ElementType}, }; @@ -30,8 +31,15 @@ enum SerializationStep { DateTimeNumberLong, Binary, - BinaryBase64, - BinarySubType { base64: String }, + /// This step can either transition to the raw or base64 steps depending + /// on whether a string or bytes are serialized. + BinaryBytes, + BinarySubType { + base64: String, + }, + RawBinarySubType { + bytes: Vec, + }, Symbol, @@ -41,7 +49,9 @@ enum SerializationStep { Timestamp, TimestampTime, - TimestampIncrement { time: i64 }, + TimestampIncrement { + time: i64, + }, DbPointer, DbPointerRef, @@ -50,7 +60,10 @@ enum SerializationStep { Code, CodeWithScopeCode, - CodeWithScopeScope { code: String }, + CodeWithScopeScope { + code: String, + raw: bool, + }, MinKey, @@ -185,8 +198,15 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { } #[inline] - fn serialize_u8(self, _v: u8) -> Result { - Err(self.invalid_step("u8")) + fn serialize_u8(self, v: u8) -> Result { + match self.state { + SerializationStep::RawBinarySubType { ref bytes } => { + write_binary(&mut self.root_serializer.bytes, bytes.as_slice(), v.into())?; + self.state = SerializationStep::Done; + Ok(()) + } + _ => Err(self.invalid_step("u8")), + } } #[inline] @@ -229,7 +249,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { let oid = ObjectId::parse_str(v).map_err(Error::custom)?; self.root_serializer.bytes.write_all(&oid.bytes())?; } - SerializationStep::BinaryBase64 => { + SerializationStep::BinaryBytes => { self.state = SerializationStep::BinarySubType { base64: v.to_string(), }; @@ -245,15 +265,23 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { SerializationStep::Symbol | SerializationStep::DbPointerRef => { write_string(&mut 
self.root_serializer.bytes, v)?; } - SerializationStep::RegExPattern | SerializationStep::RegExOptions => { + SerializationStep::RegExPattern => { write_cstring(&mut self.root_serializer.bytes, v)?; } + SerializationStep::RegExOptions => { + let mut chars: Vec<_> = v.chars().collect(); + chars.sort_unstable(); + + let sorted = chars.into_iter().collect::(); + write_cstring(&mut self.root_serializer.bytes, sorted.as_str())?; + } SerializationStep::Code => { write_string(&mut self.root_serializer.bytes, v)?; } SerializationStep::CodeWithScopeCode => { self.state = SerializationStep::CodeWithScopeScope { code: v.to_string(), + raw: false, }; } s => { @@ -273,6 +301,18 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { self.root_serializer.bytes.write_all(v)?; Ok(()) } + SerializationStep::BinaryBytes => { + self.state = SerializationStep::RawBinarySubType { bytes: v.to_vec() }; + Ok(()) + } + SerializationStep::CodeWithScopeScope { ref code, raw } if raw => { + let len = 4 + 4 + code.len() as i32 + 1 + v.len() as i32; + write_i32(&mut self.root_serializer.bytes, len)?; + write_string(&mut self.root_serializer.bytes, code)?; + self.root_serializer.bytes.write_all(v)?; + self.state = SerializationStep::Done; + Ok(()) + } _ => Err(self.invalid_step("&[u8]")), } } @@ -311,15 +351,23 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { } #[inline] - fn serialize_newtype_struct( - self, - _name: &'static str, - _value: &T, - ) -> Result + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result where T: Serialize, { - Err(self.invalid_step("newtype_struct")) + match (&mut self.state, name) { + ( + SerializationStep::CodeWithScopeScope { + code: _, + ref mut raw, + }, + RAW_DOCUMENT_NEWTYPE, + ) => { + *raw = true; + value.serialize(self) + } + _ => Err(self.invalid_step("newtype_struct")), + } } #[inline] @@ -338,7 +386,7 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { #[inline] fn serialize_seq(self, 
_len: Option) -> Result { - Err(self.invalid_step("newtype_seq")) + Err(self.invalid_step("seq")) } #[inline] @@ -369,10 +417,10 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { #[inline] fn serialize_map(self, _len: Option) -> Result { match self.state { - SerializationStep::CodeWithScopeScope { ref code } => { + SerializationStep::CodeWithScopeScope { ref code, raw } if !raw => { CodeWithScopeSerializer::start(code.as_str(), self.root_serializer) } - _ => Err(self.invalid_step("tuple_map")), + _ => Err(self.invalid_step("map")), } } @@ -391,6 +439,10 @@ impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> { ) -> Result { Err(self.invalid_step("struct_variant")) } + + fn is_human_readable(&self) -> bool { + false + } } impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { @@ -415,13 +467,17 @@ impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> { self.state = SerializationStep::Done; } (SerializationStep::Binary, "$binary") => { - self.state = SerializationStep::BinaryBase64; + self.state = SerializationStep::BinaryBytes; value.serialize(&mut **self)?; } - (SerializationStep::BinaryBase64, "base64") => { + (SerializationStep::BinaryBytes, key) if key == "bytes" || key == "base64" => { // state is updated in serialize value.serialize(&mut **self)?; } + (SerializationStep::RawBinarySubType { .. }, "subType") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } (SerializationStep::BinarySubType { .. 
}, "subType") => { value.serialize(&mut **self)?; self.state = SerializationStep::Done; diff --git a/src/ser/serde.rs b/src/ser/serde.rs index f5ff7797..046978d8 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -17,6 +17,7 @@ use crate::{ datetime::DateTime, extjson, oid::ObjectId, + raw::{RawDbPointer, RawRegex}, spec::BinarySubtype, uuid::UUID_NEWTYPE_NAME, Binary, @@ -546,13 +547,11 @@ impl Serialize for Regex { where S: ser::Serializer, { - let mut state = serializer.serialize_struct("$regularExpression", 1)?; - let body = extjson::models::RegexBody { - pattern: self.pattern.clone(), - options: self.options.clone(), + let raw = RawRegex { + pattern: self.pattern.as_str(), + options: self.options.as_str(), }; - state.serialize_field("$regularExpression", &body)?; - state.end() + raw.serialize(serializer) } } @@ -620,12 +619,10 @@ impl Serialize for DbPointer { where S: ser::Serializer, { - let mut state = serializer.serialize_struct("$dbPointer", 1)?; - let body = extjson::models::DbPointerBody { - ref_ns: self.namespace.clone(), - id: self.id.into(), + let raw = RawDbPointer { + namespace: self.namespace.as_str(), + id: self.id, }; - state.serialize_field("$dbPointer", &body)?; - state.end() + raw.serialize(serializer) } } diff --git a/src/tests/spec/corpus.rs b/src/tests/spec/corpus.rs index 9c6844dd..b4329fb1 100644 --- a/src/tests/spec/corpus.rs +++ b/src/tests/spec/corpus.rs @@ -1,11 +1,17 @@ use std::{ convert::{TryFrom, TryInto}, + marker::PhantomData, str::FromStr, }; -use crate::{raw::RawDocument, tests::LOCK, Bson, Document}; +use crate::{ + raw::{RawBson, RawDocument}, + tests::LOCK, + Bson, + Document, +}; use pretty_assertions::assert_eq; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; use super::run_spec_test; @@ -59,6 +65,34 @@ struct ParseError { string: String, } +struct FieldVisitor<'a, T>(&'a str, PhantomData); + +impl<'de, 'a, T> serde::de::Visitor<'de> for FieldVisitor<'a, T> +where + T: Deserialize<'de>, +{ + type 
Value = T; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "expecting RawBson at field {}", self.0) + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + while let Some((k, v)) = map.next_entry::()? { + if k.as_str() == self.0 { + return Ok(v); + } + } + Err(serde::de::Error::custom(format!( + "missing field: {}", + self.0 + ))) + } +} + fn run_test(test: TestFile) { let _guard = LOCK.run_concurrently(); for valid in test.valid { @@ -79,11 +113,19 @@ fn run_test(test: TestFile) { let todocument_documentfromreader_cb: Document = crate::to_document(&documentfromreader_cb).expect(&description); - let document_from_raw_document: Document = RawDocument::new(canonical_bson.as_slice()) + let canonical_raw_document = + RawDocument::new(canonical_bson.as_slice()).expect(&description); + let document_from_raw_document: Document = + canonical_raw_document.try_into().expect(&description); + + let canonical_raw_bson_from_slice = crate::from_slice::(canonical_bson.as_slice()) .expect(&description) - .try_into() + .as_document() .expect(&description); + let canonical_raw_document_from_slice = + crate::from_slice::<&RawDocument>(canonical_bson.as_slice()).expect(&description); + // These cover the ways to serialize those `Documents` back to BSON. let mut documenttowriter_documentfromreader_cb = Vec::new(); documentfromreader_cb @@ -113,6 +155,62 @@ fn run_test(test: TestFile) { .to_writer(&mut documenttowriter_document_from_raw_document) .expect(&description); + // Serialize the raw versions "back" to BSON also. 
+ let tovec_rawdocument = crate::to_vec(&canonical_raw_document).expect(&description); + let tovec_rawdocument_from_slice = + crate::to_vec(&canonical_raw_document_from_slice).expect(&description); + let tovec_rawbson = crate::to_vec(&canonical_raw_bson_from_slice).expect(&description); + + // test Bson / RawBson field deserialization + if let Some(ref test_key) = test.test_key { + // skip regex tests that don't have the value at the test key + if !description.contains("$regex query operator") { + // deserialize the field from raw Bytes into a RawBson + let mut deserializer_raw = + crate::de::RawDeserializer::new(canonical_bson.as_slice(), false); + let raw_bson_field = deserializer_raw + .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) + .expect(&description); + // convert to an owned Bson and put into a Document + let bson: Bson = raw_bson_field.try_into().expect(&description); + let from_raw_doc = doc! { + test_key: bson + }; + + // deserialize the field from raw Bytes into a Bson + let mut deserializer_value = + crate::de::RawDeserializer::new(canonical_bson.as_slice(), false); + let bson_field = deserializer_value + .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) + .expect(&description); + // put into a Document + let from_value_doc = doc! { + test_key: bson_field, + }; + + // deserialize the field from a Bson into a Bson + let deserializer_value_value = + crate::Deserializer::new(Bson::Document(documentfromreader_cb.clone())); + let bson_field = deserializer_value_value + .deserialize_any(FieldVisitor(test_key.as_str(), PhantomData::)) + .expect(&description); + // put into a Document + let from_value_value_doc = doc! 
{ + test_key: bson_field, + }; + + // convert back into raw BSON for comparison with canonical BSON + let from_raw_vec = crate::to_vec(&from_raw_doc).expect(&description); + let from_value_vec = crate::to_vec(&from_value_doc).expect(&description); + let from_value_value_vec = + crate::to_vec(&from_value_value_doc).expect(&description); + + assert_eq!(from_raw_vec, canonical_bson, "{}", description); + assert_eq!(from_value_vec, canonical_bson, "{}", description); + assert_eq!(from_value_value_vec, canonical_bson, "{}", description); + } + } + // native_to_bson( bson_to_native(cB) ) = cB // now we ensure the hex for all 5 are equivalent to the canonical BSON provided by the @@ -159,6 +257,20 @@ fn run_test(test: TestFile) { description, ); + assert_eq!(tovec_rawdocument, tovec_rawbson, "{}", description); + assert_eq!( + tovec_rawdocument, tovec_rawdocument_from_slice, + "{}", + description + ); + + assert_eq!( + hex::encode(tovec_rawdocument).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + // NaN == NaN is false, so we skip document comparisons that contain NaN if !description.to_ascii_lowercase().contains("nan") && !description.contains("decq541") { assert_eq!(documentfromreader_cb, fromreader_cb, "{}", description);