diff --git a/.evergreen/config.yml b/.evergreen/config.yml index f1f871f2..8fe823de 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -92,16 +92,6 @@ functions: ${PREPARE_SHELL} .evergreen/run-tests-u2i.sh - "run serde tests": - - command: shell.exec - type: test - params: - shell: bash - working_dir: "src" - script: | - ${PREPARE_SHELL} - .evergreen/run-tests-serde.sh - "run decimal128 tests": - command: shell.exec type: test @@ -170,10 +160,6 @@ tasks: commands: - func: "run u2i tests" - - name: "test-serde" - commands: - - func: "run serde tests" - - name: "test-decimal128" commands: - func: "run decimal128 tests" @@ -211,7 +197,6 @@ buildvariants: tasks: - name: "test" - name: "test-u2i" - - name: "test-serde" - name: "test-decimal128" - matrix_name: "compile only" diff --git a/.evergreen/run-tests-decimal128.sh b/.evergreen/run-tests-decimal128.sh index 10a30b01..51f62cbf 100755 --- a/.evergreen/run-tests-decimal128.sh +++ b/.evergreen/run-tests-decimal128.sh @@ -4,3 +4,6 @@ set -o errexit . ~/.cargo/env RUST_BACKTRACE=1 cargo test --features decimal128 + +cd serde-tests +RUST_BACKTRACE=1 cargo test --features decimal128 diff --git a/.evergreen/run-tests-u2i.sh b/.evergreen/run-tests-u2i.sh index 373adb93..a3aba15c 100755 --- a/.evergreen/run-tests-u2i.sh +++ b/.evergreen/run-tests-u2i.sh @@ -4,3 +4,6 @@ set -o errexit . ~/.cargo/env RUST_BACKTRACE=1 cargo test --features u2i + +cd serde-tests +RUST_BACKTRACE=1 cargo test --features u2i diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 71e4230f..6d56e8f6 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -5,3 +5,6 @@ set -o errexit . 
~/.cargo/env RUST_BACKTRACE=1 cargo test RUST_BACKTRACE=1 cargo test --features chrono-0_4,uuid-0_8 + +cd serde-tests +RUST_BACKTRACE=1 cargo test diff --git a/serde-tests/Cargo.toml b/serde-tests/Cargo.toml index 62be8bd2..45bdfce3 100644 --- a/serde-tests/Cargo.toml +++ b/serde-tests/Cargo.toml @@ -4,10 +4,15 @@ version = "0.1.0" authors = ["Kevin Yeh "] edition = "2018" +[features] +u2i = ["bson/u2i"] +decimal128 = ["bson/decimal128"] + [dependencies] -bson = { path = "..", features = ["decimal128"] } +bson = { path = ".." } serde = { version = "1.0", features = ["derive"] } pretty_assertions = "0.6.1" +hex = "0.4.2" [lib] name = "serde_tests" diff --git a/serde-tests/test.rs b/serde-tests/test.rs index d8e2df01..6e46dddd 100644 --- a/serde-tests/test.rs +++ b/serde-tests/test.rs @@ -14,6 +14,8 @@ use std::{ collections::{BTreeMap, HashSet}, }; +#[cfg(feature = "decimal128")] +use bson::Decimal128; use bson::{ doc, oid::ObjectId, @@ -21,7 +23,6 @@ use bson::{ Binary, Bson, DateTime, - Decimal128, Deserializer, Document, JavaScriptCodeWithScope, @@ -37,6 +38,7 @@ use bson::{ /// - deserializing a `T` from the raw BSON version of `expected_doc` produces `expected_value` /// - deserializing a `Document` from the raw BSON version of `expected_doc` produces /// `expected_doc` +/// - `bson::to_writer` and `Document::to_writer` produce the same result given the same input fn run_test(expected_value: &T, expected_doc: &Document, description: &str) where T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug, @@ -46,6 +48,16 @@ where .to_writer(&mut expected_bytes) .expect(description); + let expected_bytes_serde = bson::to_vec(&expected_value).expect(description); + assert_eq!(expected_bytes_serde, expected_bytes, "{}", description); + + let expected_bytes_from_doc_serde = bson::to_vec(&expected_doc).expect(description); + assert_eq!( + expected_bytes_from_doc_serde, expected_bytes, + "{}", + description + ); + let serialized_doc = 
bson::to_document(&expected_value).expect(description); assert_eq!(&serialized_doc, expected_doc, "{}", description); assert_eq!( @@ -702,7 +714,7 @@ fn all_types() { undefined: Bson, code: Bson, code_w_scope: JavaScriptCodeWithScope, - decimal: Decimal128, + decimal: Bson, symbol: Bson, min_key: Bson, max_key: Bson, @@ -737,6 +749,16 @@ fn all_types() { let oid = ObjectId::new(); let subdoc = doc! { "k": true, "b": { "hello": "world" } }; + #[cfg(not(feature = "decimal128"))] + let decimal = { + let bytes = hex::decode("18000000136400D0070000000000000000000000003A3000").unwrap(); + let d = Document::from_reader(bytes.as_slice()).unwrap(); + d.get("d").unwrap().clone() + }; + + #[cfg(feature = "decimal128")] + let decimal = Bson::Decimal128(Decimal128::from_str("2.000")); + let doc = doc! { "x": 1, "y": 2_i64, @@ -758,7 +780,7 @@ fn all_types() { "undefined": Bson::Undefined, "code": code.clone(), "code_w_scope": code_w_scope.clone(), - "decimal": Bson::Decimal128(Decimal128::from_i32(5)), + "decimal": decimal.clone(), "symbol": Bson::Symbol("ok".to_string()), "min_key": Bson::MinKey, "max_key": Bson::MaxKey, @@ -789,7 +811,7 @@ fn all_types() { undefined: Bson::Undefined, code, code_w_scope, - decimal: Decimal128::from_i32(5), + decimal, symbol: Bson::Symbol("ok".to_string()), min_key: Bson::MinKey, max_key: Bson::MaxKey, @@ -851,3 +873,83 @@ fn borrowed() { bson::from_slice(bson.as_slice()).expect("deserialization should succeed"); assert_eq!(deserialized, v); } + +#[cfg(feature = "u2i")] +#[test] +fn u2i() { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct Foo { + u_8: u8, + u_16: u16, + u_32: u32, + u_32_max: u32, + u_64: u64, + i_64_max: u64, + } + + let v = Foo { + u_8: 15, + u_16: 123, + u_32: 1234, + u_32_max: u32::MAX, + u_64: 12345, + i_64_max: i64::MAX as u64, + }; + + let expected = doc! 
{ + "u_8": 15_i32, + "u_16": 123_i32, + "u_32": 1234_i64, + "u_32_max": u32::MAX as i64, + "u_64": 12345_i64, + "i_64_max": i64::MAX as u64, + }; + + run_test(&v, &expected, "u2i - valid"); + + #[derive(Serialize, Debug)] + struct TooBig { + u_64: u64, + } + let v = TooBig { + u_64: i64::MAX as u64 + 1, + }; + bson::to_document(&v).unwrap_err(); + bson::to_vec(&v).unwrap_err(); +} + +#[cfg(not(feature = "u2i"))] +#[test] +fn unsigned() { + #[derive(Serialize, Debug)] + struct U8 { + v: u8, + } + let v = U8 { v: 1 }; + bson::to_document(&v).unwrap_err(); + bson::to_vec(&v).unwrap_err(); + + #[derive(Serialize, Debug)] + struct U16 { + v: u16, + } + let v = U16 { v: 1 }; + bson::to_document(&v).unwrap_err(); + bson::to_vec(&v).unwrap_err(); + + #[derive(Serialize, Debug)] + struct U32 { + v: u32, + } + let v = U32 { v: 1 }; + bson::to_document(&v).unwrap_err(); + bson::to_vec(&v).unwrap_err(); + + #[derive(Serialize, Debug)] + struct U64 { + v: u64, + } + let v = U64 { v: 1 }; + bson::to_document(&v).unwrap_err(); + bson::to_vec(&v).unwrap_err(); +} diff --git a/src/bson.rs b/src/bson.rs index 03b8bee0..c70fdb94 100644 --- a/src/bson.rs +++ b/src/bson.rs @@ -22,7 +22,7 @@ //! 
BSON definition use std::{ - convert::TryFrom, + convert::{TryFrom, TryInto}, fmt::{self, Debug, Display, Formatter}, }; @@ -698,6 +698,20 @@ impl Bson { } } + ["$numberDecimalBytes"] => { + if let Ok(bytes) = doc.get_binary_generic("$numberDecimalBytes") { + if let Ok(b) = bytes.clone().try_into() { + #[cfg(not(feature = "decimal128"))] + return Bson::Decimal128(Decimal128 { bytes: b }); + + #[cfg(feature = "decimal128")] + unsafe { + return Bson::Decimal128(Decimal128::from_raw_bytes_le(b)); + } + } + } + } + ["$binary"] => { if let Some(binary) = Binary::from_extended_doc(&doc) { return Bson::Binary(binary); diff --git a/src/extjson/models.rs b/src/extjson/models.rs index 551d5813..79065c19 100644 --- a/src/extjson/models.rs +++ b/src/extjson/models.rs @@ -4,6 +4,7 @@ use chrono::Utc; use serde::{ de::{Error, Unexpected}, Deserialize, + Serialize, }; use crate::{extjson, oid, spec::BinarySubtype, Bson}; @@ -27,7 +28,7 @@ impl Int32 { } } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub(crate) struct Int64 { #[serde(rename = "$numberLong")] @@ -72,7 +73,7 @@ impl Double { } } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct ObjectId { #[serde(rename = "$oid")] @@ -86,6 +87,12 @@ impl ObjectId { } } +impl From for ObjectId { + fn from(id: crate::oid::ObjectId) -> Self { + Self { oid: id.to_hex() } + } +} + #[derive(Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct Symbol { @@ -100,7 +107,7 @@ pub(crate) struct Regex { body: RegexBody, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct RegexBody { pub(crate) pattern: String, @@ -127,7 +134,7 @@ pub(crate) struct Binary { pub(crate) body: BinaryBody, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub(crate) struct BinaryBody { pub(crate) base64: String, @@ -207,10 +214,13 @@ pub(crate) struct Timestamp { 
body: TimestampBody, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct TimestampBody { + #[serde(serialize_with = "crate::serde_helpers::serialize_u32_as_i64")] pub(crate) t: u32, + + #[serde(serialize_with = "crate::serde_helpers::serialize_u32_as_i64")] pub(crate) i: u32, } @@ -223,20 +233,28 @@ impl Timestamp { } } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct DateTime { #[serde(rename = "$date")] pub(crate) body: DateTimeBody, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize)] #[serde(untagged)] pub(crate) enum DateTimeBody { Canonical(Int64), Relaxed(String), } +impl DateTimeBody { + pub(crate) fn from_millis(m: i64) -> Self { + DateTimeBody::Canonical(Int64 { + value: m.to_string(), + }) + } +} + impl DateTime { pub(crate) fn parse(self) -> extjson::de::Result { match self.body { @@ -307,7 +325,7 @@ pub(crate) struct DbPointer { body: DbPointerBody, } -#[derive(Deserialize)] +#[derive(Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct DbPointerBody { #[serde(rename = "$ref")] diff --git a/src/lib.rs b/src/lib.rs index 5b9f64eb..a0ae75ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,7 +198,7 @@ pub use self::{ Deserializer, }, decimal128::Decimal128, - ser::{to_bson, to_document, Serializer}, + ser::{to_bson, to_document, to_vec, Serializer}, }; #[macro_use] diff --git a/src/ser/mod.rs b/src/ser/mod.rs index a2669007..f18dd155 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -22,6 +22,7 @@ //! 
Serializer mod error; +mod raw; mod serde; pub use self::{ @@ -35,9 +36,10 @@ use std::{io::Write, iter::FromIterator, mem}; use crate::decimal128::Decimal128; use crate::{ bson::{Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex}, + de::MAX_BSON_SIZE, spec::BinarySubtype, }; -use ::serde::Serialize; +use ::serde::{ser::Error as SerdeError, Serialize}; fn write_string(writer: &mut W, s: &str) -> Result<()> { writer.write_all(&(s.len() as i32 + 1).to_le_bytes())?; @@ -83,6 +85,31 @@ fn write_f128(writer: &mut W, val: Decimal128) -> Result<()> writer.write_all(&raw).map_err(From::from) } +#[inline] +fn write_binary(mut writer: W, bytes: &[u8], subtype: BinarySubtype) -> Result<()> { + let len = if let BinarySubtype::BinaryOld = subtype { + bytes.len() + 4 + } else { + bytes.len() + }; + + if len > MAX_BSON_SIZE as usize { + return Err(Error::custom(format!( + "binary length {} exceeded maximum size", + bytes.len() + ))); + } + + write_i32(&mut writer, len as i32)?; + writer.write_all(&[subtype.into()])?; + + if let BinarySubtype::BinaryOld = subtype { + write_i32(&mut writer, len as i32 - 4)?; + }; + + writer.write_all(bytes).map_err(From::from) +} + fn serialize_array(writer: &mut W, arr: &[Bson]) -> Result<()> { let mut buf = Vec::new(); for (key, val) in arr.iter().enumerate() { @@ -141,22 +168,7 @@ pub(crate) fn serialize_bson( Bson::Int32(v) => write_i32(writer, v), Bson::Int64(v) => write_i64(writer, v), Bson::Timestamp(ts) => write_i64(writer, ts.to_le_i64()), - Bson::Binary(Binary { subtype, ref bytes }) => { - let len = if let BinarySubtype::BinaryOld = subtype { - bytes.len() + 4 - } else { - bytes.len() - }; - - write_i32(writer, len as i32)?; - writer.write_all(&[subtype.into()])?; - - if let BinarySubtype::BinaryOld = subtype { - write_i32(writer, len as i32 - 4)?; - }; - - writer.write_all(bytes).map_err(From::from) - } + Bson::Binary(Binary { subtype, ref bytes }) => write_binary(writer, bytes, subtype), Bson::DateTime(ref v) => 
write_i64(writer, v.timestamp_millis()), Bson::Null => Ok(()), Bson::Symbol(ref v) => write_string(writer, v), @@ -204,3 +216,14 @@ where }), } } + +/// Serialize the given `T` as a BSON byte vector. +#[inline] +pub fn to_vec(value: &T) -> Result> +where + T: Serialize, +{ + let mut serializer = raw::Serializer::new(); + value.serialize(&mut serializer)?; + Ok(serializer.into_vec()) +} diff --git a/src/ser/raw/document_serializer.rs b/src/ser/raw/document_serializer.rs new file mode 100644 index 00000000..ba2447f0 --- /dev/null +++ b/src/ser/raw/document_serializer.rs @@ -0,0 +1,361 @@ +use serde::{ser::Impossible, Serialize}; + +use crate::{ + ser::{write_cstring, write_i32, Error, Result}, + to_bson, + Bson, +}; + +use super::Serializer; + +pub(crate) struct DocumentSerializationResult<'a> { + pub(crate) root_serializer: &'a mut Serializer, +} + +/// Serializer used to serialize document or array bodies. +pub(crate) struct DocumentSerializer<'a> { + root_serializer: &'a mut Serializer, + num_keys_serialized: usize, + start: usize, +} + +impl<'a> DocumentSerializer<'a> { + pub(crate) fn start(rs: &'a mut Serializer) -> crate::ser::Result { + let start = rs.bytes.len(); + write_i32(&mut rs.bytes, 0)?; + Ok(Self { + root_serializer: rs, + num_keys_serialized: 0, + start, + }) + } + + fn serialize_doc_key(&mut self, key: &T) -> Result<()> + where + T: serde::Serialize + ?Sized, + { + // push a dummy element type for now, will update this once we serialize the value + self.root_serializer.reserve_element_type(); + key.serialize(KeySerializer { + root_serializer: &mut *self.root_serializer, + })?; + + self.num_keys_serialized += 1; + Ok(()) + } + + pub(crate) fn end_doc(self) -> crate::ser::Result> { + self.root_serializer.bytes.push(0); + let length = (self.root_serializer.bytes.len() - self.start) as i32; + self.root_serializer.replace_i32(self.start, length); + Ok(DocumentSerializationResult { + root_serializer: self.root_serializer, + }) + } +} + +impl<'a> 
serde::ser::SerializeSeq for DocumentSerializer<'a> { + type Ok = (); + type Error = Error; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize, + { + self.serialize_doc_key(&self.num_keys_serialized.to_string())?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl<'a> serde::ser::SerializeMap for DocumentSerializer<'a> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: serde::Serialize, + { + self.serialize_doc_key(key) + } + + #[inline] + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize, + { + value.serialize(&mut *self.root_serializer) + } + + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl<'a> serde::ser::SerializeStruct for DocumentSerializer<'a> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: serde::Serialize, + { + self.serialize_doc_key(key)?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl<'a> serde::ser::SerializeTuple for DocumentSerializer<'a> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize, + { + self.serialize_doc_key(&self.num_keys_serialized.to_string())?; + value.serialize(&mut *self.root_serializer) + } + + #[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +impl<'a> serde::ser::SerializeTupleStruct for DocumentSerializer<'a> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: serde::Serialize, + { + self.serialize_doc_key(&self.num_keys_serialized.to_string())?; + value.serialize(&mut *self.root_serializer) + } + + 
#[inline] + fn end(self) -> Result { + self.end_doc().map(|_| ()) + } +} + +/// Serializer used specifically for serializing document keys. +struct KeySerializer<'a> { + root_serializer: &'a mut Serializer, +} + +impl<'a> KeySerializer<'a> { + fn invalid_key(v: T) -> Error { + Error::InvalidDocumentKey(to_bson(&v).unwrap_or(Bson::Null)) + } +} + +impl<'a> serde::Serializer for KeySerializer<'a> { + type Ok = (); + + type Error = Error; + + type SerializeSeq = Impossible<(), Error>; + type SerializeTuple = Impossible<(), Error>; + type SerializeTupleStruct = Impossible<(), Error>; + type SerializeTupleVariant = Impossible<(), Error>; + type SerializeMap = Impossible<(), Error>; + type SerializeStruct = Impossible<(), Error>; + type SerializeStructVariant = Impossible<(), Error>; + + #[inline] + fn serialize_bool(self, v: bool) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i8(self, v: i8) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i16(self, v: i16) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i32(self, v: i32) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_i64(self, v: i64) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u8(self, v: u8) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u16(self, v: u16) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u32(self, v: u32) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_u64(self, v: u64) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_f32(self, v: f32) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_f64(self, v: f64) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_char(self, v: char) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_str(self, v: &str) -> Result { + write_cstring(&mut self.root_serializer.bytes, v) + } + 
+ #[inline] + fn serialize_bytes(self, v: &[u8]) -> Result { + Err(Self::invalid_key(v)) + } + + #[inline] + fn serialize_none(self) -> Result { + Err(Self::invalid_key(Bson::Null)) + } + + #[inline] + fn serialize_some(self, value: &T) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + #[inline] + fn serialize_unit(self) -> Result { + Err(Self::invalid_key(Bson::Null)) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Err(Self::invalid_key(Bson::Null)) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + Err(Self::invalid_key(value)) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + Err(Self::invalid_key(Bson::Array(vec![]))) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result { + Err(Self::invalid_key(Bson::Array(vec![]))) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Self::invalid_key(Bson::Array(vec![]))) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Self::invalid_key(Bson::Document(doc! 
{}))) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(Self::invalid_key(Bson::Document(doc! {}))) + } +} diff --git a/src/ser/raw/mod.rs b/src/ser/raw/mod.rs new file mode 100644 index 00000000..e38d5ac2 --- /dev/null +++ b/src/ser/raw/mod.rs @@ -0,0 +1,485 @@ +mod document_serializer; +mod value_serializer; + +use serde::{ + ser::{Error as SerdeError, SerializeMap, SerializeStruct}, + Serialize, +}; + +use self::value_serializer::{ValueSerializer, ValueType}; + +use super::{write_binary, write_cstring, write_f64, write_i32, write_i64, write_string}; +use crate::{ + ser::{Error, Result}, + spec::{BinarySubtype, ElementType}, +}; +use document_serializer::DocumentSerializer; + +/// Serializer used to convert a type `T` into raw BSON bytes. +pub(crate) struct Serializer { + bytes: Vec, + + /// The index into `bytes` where the current element type will need to be stored. + /// This needs to be set retroactively because in BSON, the element type comes before the key, + /// but in serde, the serializer learns of the type after serializing the key. + type_index: usize, +} + +impl Serializer { + pub(crate) fn new() -> Self { + Self { + bytes: Vec::new(), + type_index: 0, + } + } + + /// Convert this serializer into the vec of the serialized bytes. + pub(crate) fn into_vec(self) -> Vec { + self.bytes + } + + /// Reserve a spot for the element type to be set retroactively via `update_element_type`. + #[inline] + fn reserve_element_type(&mut self) { + self.type_index = self.bytes.len(); // record index + self.bytes.push(0); // push temporary placeholder + } + + /// Retroactively set the element type of the most recently serialized element. 
+ #[inline] + fn update_element_type(&mut self, t: ElementType) -> Result<()> { + if self.type_index == 0 { + if matches!(t, ElementType::EmbeddedDocument) { + // don't need to set the element type for the top level document + return Ok(()); + } else { + return Err(Error::custom(format!( + "attempted to encode a non-document type at the top level: {:?}", + t + ))); + } + } + + self.bytes[self.type_index] = t as u8; + Ok(()) + } + + /// Replace an i32 value at the given index with the given value. + #[inline] + fn replace_i32(&mut self, at: usize, with: i32) { + self.bytes + .splice(at..at + 4, with.to_le_bytes().iter().cloned()); + } +} + +impl<'a> serde::Serializer for &'a mut Serializer { + type Ok = (); + type Error = Error; + + type SerializeSeq = DocumentSerializer<'a>; + type SerializeTuple = DocumentSerializer<'a>; + type SerializeTupleStruct = DocumentSerializer<'a>; + type SerializeTupleVariant = VariantSerializer<'a>; + type SerializeMap = DocumentSerializer<'a>; + type SerializeStruct = StructSerializer<'a>; + type SerializeStructVariant = VariantSerializer<'a>; + + #[inline] + fn serialize_bool(self, v: bool) -> Result { + self.update_element_type(ElementType::Boolean)?; + self.bytes.push(if v { 1 } else { 0 }); + Ok(()) + } + + #[inline] + fn serialize_i8(self, v: i8) -> Result { + self.serialize_i32(v.into()) + } + + #[inline] + fn serialize_i16(self, v: i16) -> Result { + self.serialize_i32(v.into()) + } + + #[inline] + fn serialize_i32(self, v: i32) -> Result { + self.update_element_type(ElementType::Int32)?; + write_i32(&mut self.bytes, v)?; + Ok(()) + } + + #[inline] + fn serialize_i64(self, v: i64) -> Result { + self.update_element_type(ElementType::Int64)?; + write_i64(&mut self.bytes, v)?; + Ok(()) + } + + #[inline] + fn serialize_u8(self, v: u8) -> Result { + #[cfg(feature = "u2i")] + { + self.serialize_i32(v.into()) + } + + #[cfg(not(feature = "u2i"))] + Err(Error::UnsupportedUnsignedInteger(v.into())) + } + + #[inline] + fn 
serialize_u16(self, v: u16) -> Result { + #[cfg(feature = "u2i")] + { + self.serialize_i32(v.into()) + } + + #[cfg(not(feature = "u2i"))] + Err(Error::UnsupportedUnsignedInteger(v.into())) + } + + #[inline] + fn serialize_u32(self, v: u32) -> Result { + #[cfg(feature = "u2i")] + { + self.serialize_i64(v.into()) + } + + #[cfg(not(feature = "u2i"))] + Err(Error::UnsupportedUnsignedInteger(v.into())) + } + + #[inline] + fn serialize_u64(self, v: u64) -> Result { + #[cfg(feature = "u2i")] + { + use std::convert::TryFrom; + + match i64::try_from(v) { + Ok(ivalue) => self.serialize_i64(ivalue), + Err(_) => Err(Error::UnsignedIntegerExceededRange(v)), + } + } + + #[cfg(not(feature = "u2i"))] + Err(Error::UnsupportedUnsignedInteger(v)) + } + + #[inline] + fn serialize_f32(self, v: f32) -> Result { + self.serialize_f64(v.into()) + } + + #[inline] + fn serialize_f64(self, v: f64) -> Result { + self.update_element_type(ElementType::Double)?; + write_f64(&mut self.bytes, v) + } + + #[inline] + fn serialize_char(self, v: char) -> Result { + let mut s = String::new(); + s.push(v); + self.serialize_str(&s) + } + + #[inline] + fn serialize_str(self, v: &str) -> Result { + self.update_element_type(ElementType::String)?; + write_string(&mut self.bytes, v) + } + + #[inline] + fn serialize_bytes(self, v: &[u8]) -> Result { + self.update_element_type(ElementType::Binary)?; + write_binary(&mut self.bytes, v, BinarySubtype::Generic)?; + Ok(()) + } + + #[inline] + fn serialize_none(self) -> Result { + self.update_element_type(ElementType::Null)?; + Ok(()) + } + + #[inline] + fn serialize_some(self, value: &T) -> Result + where + T: serde::Serialize, + { + value.serialize(self) + } + + #[inline] + fn serialize_unit(self) -> Result { + self.serialize_none() + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) 
-> Result { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + where + T: serde::Serialize, + { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: serde::Serialize, + { + self.update_element_type(ElementType::EmbeddedDocument)?; + let mut d = DocumentSerializer::start(&mut *self)?; + d.serialize_entry(variant, value)?; + d.end_doc()?; + Ok(()) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result { + self.update_element_type(ElementType::Array)?; + DocumentSerializer::start(&mut *self) + } + + #[inline] + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + #[inline] + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_seq(Some(len)) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + self.update_element_type(ElementType::EmbeddedDocument)?; + VariantSerializer::start(&mut *self, variant, VariantInnerType::Tuple) + } + + #[inline] + fn serialize_map(self, _len: Option) -> Result { + self.update_element_type(ElementType::EmbeddedDocument)?; + DocumentSerializer::start(&mut *self) + } + + #[inline] + fn serialize_struct(self, name: &'static str, _len: usize) -> Result { + let value_type = match name { + "$oid" => Some(ValueType::ObjectId), + "$date" => Some(ValueType::DateTime), + "$binary" => Some(ValueType::Binary), + "$timestamp" => Some(ValueType::Timestamp), + "$minKey" => Some(ValueType::MinKey), + "$maxKey" => Some(ValueType::MaxKey), + "$code" => Some(ValueType::JavaScriptCode), + "$codeWithScope" => Some(ValueType::JavaScriptCodeWithScope), + "$symbol" => Some(ValueType::Symbol), + "$undefined" => Some(ValueType::Undefined), + 
"$regularExpression" => Some(ValueType::RegularExpression), + "$dbPointer" => Some(ValueType::DbPointer), + "$numberDecimal" => Some(ValueType::Decimal128), + _ => None, + }; + + self.update_element_type( + value_type + .map(Into::into) + .unwrap_or(ElementType::EmbeddedDocument), + )?; + match value_type { + Some(vt) => Ok(StructSerializer::Value(ValueSerializer::new(self, vt))), + None => Ok(StructSerializer::Document(DocumentSerializer::start(self)?)), + } + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + self.update_element_type(ElementType::EmbeddedDocument)?; + VariantSerializer::start(&mut *self, variant, VariantInnerType::Struct) + } +} + +pub(crate) enum StructSerializer<'a> { + /// Serialize a BSON value currently represented in serde as a struct (e.g. ObjectId) + Value(ValueSerializer<'a>), + + /// Serialize the struct as a document. + Document(DocumentSerializer<'a>), +} + +impl<'a> SerializeStruct for StructSerializer<'a> { + type Ok = (); + type Error = Error; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: Serialize, + { + match self { + StructSerializer::Value(ref mut v) => (&mut *v).serialize_field(key, value), + StructSerializer::Document(d) => d.serialize_field(key, value), + } + } + + #[inline] + fn end(self) -> Result { + match self { + StructSerializer::Document(d) => SerializeStruct::end(d), + StructSerializer::Value(mut v) => v.end(), + } + } +} + +enum VariantInnerType { + Tuple, + Struct, +} + +/// Serializer used for enum variants, including both tuple (e.g. Foo::Bar(1, 2, 3)) and +/// struct (e.g. Foo::Bar { a: 1 }). +pub(crate) struct VariantSerializer<'a> { + root_serializer: &'a mut Serializer, + + /// Variants are serialized as documents of the form `{ : }`, + /// and `doc_start` indicates the index at which the outer document begins. 
+ doc_start: usize, + + /// `inner_start` indicates the index at which the inner document or array begins. + inner_start: usize, + + /// How many elements have been serialized in the inner document / array so far. + num_elements_serialized: usize, +} + +impl<'a> VariantSerializer<'a> { + fn start( + rs: &'a mut Serializer, + variant: &'static str, + inner_type: VariantInnerType, + ) -> Result { + let doc_start = rs.bytes.len(); + // write placeholder length for document, will be updated at end + write_i32(&mut rs.bytes, 0)?; + + let inner = match inner_type { + VariantInnerType::Struct => ElementType::EmbeddedDocument, + VariantInnerType::Tuple => ElementType::Array, + }; + rs.bytes.push(inner as u8); + write_cstring(&mut rs.bytes, variant)?; + let inner_start = rs.bytes.len(); + // write placeholder length for inner, will be updated at end + write_i32(&mut rs.bytes, 0)?; + + Ok(Self { + root_serializer: rs, + num_elements_serialized: 0, + doc_start, + inner_start, + }) + } + + #[inline] + fn serialize_element(&mut self, k: &str, v: &T) -> Result<()> + where + T: Serialize + ?Sized, + { + self.root_serializer.reserve_element_type(); + write_cstring(&mut self.root_serializer.bytes, k)?; + v.serialize(&mut *self.root_serializer)?; + + self.num_elements_serialized += 1; + Ok(()) + } + + #[inline] + fn end_both(self) -> Result<()> { + // null byte for the inner + self.root_serializer.bytes.push(0); + let arr_length = (self.root_serializer.bytes.len() - self.inner_start) as i32; + self.root_serializer + .replace_i32(self.inner_start, arr_length); + + // null byte for document + self.root_serializer.bytes.push(0); + let doc_length = (self.root_serializer.bytes.len() - self.doc_start) as i32; + self.root_serializer.replace_i32(self.doc_start, doc_length); + Ok(()) + } +} + +impl<'a> serde::ser::SerializeTupleVariant for VariantSerializer<'a> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: 
Serialize, + { + self.serialize_element(format!("{}", self.num_elements_serialized).as_str(), value) + } + + #[inline] + fn end(self) -> Result { + self.end_both() + } +} + +impl<'a> serde::ser::SerializeStructVariant for VariantSerializer<'a> { + type Ok = (); + + type Error = Error; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: Serialize, + { + self.serialize_element(key, value) + } + + #[inline] + fn end(self) -> Result { + self.end_both() + } +} diff --git a/src/ser/raw/value_serializer.rs b/src/ser/raw/value_serializer.rs new file mode 100644 index 00000000..ace276b1 --- /dev/null +++ b/src/ser/raw/value_serializer.rs @@ -0,0 +1,574 @@ +use std::{convert::TryFrom, io::Write}; + +use serde::{ + ser::{Error as SerdeError, Impossible, SerializeMap, SerializeStruct}, + Serialize, +}; + +#[cfg(feature = "decimal128")] +use crate::Decimal128; +use crate::{ + oid::ObjectId, + ser::{write_binary, write_cstring, write_i32, write_i64, write_string, Error, Result}, + spec::{BinarySubtype, ElementType}, +}; + +use super::{document_serializer::DocumentSerializer, Serializer}; + +/// A serializer used specifically for serializing the serde-data-model form of a BSON type (e.g. +/// `Binary`) to raw bytes. +pub(crate) struct ValueSerializer<'a> { + root_serializer: &'a mut Serializer, + state: SerializationStep, +} + +/// State machine used to track which step in the serialization of a given type the serializer is +/// currently on. 
+#[derive(Debug)] +enum SerializationStep { + Oid, + + DateTime, + DateTimeNumberLong, + + Binary, + BinaryBase64, + BinarySubType { base64: String }, + + Symbol, + + RegEx, + RegExPattern, + RegExOptions, + + Timestamp, + TimestampTime, + TimestampIncrement { time: i64 }, + + DbPointer, + DbPointerRef, + DbPointerId, + + Code, + + CodeWithScopeCode, + CodeWithScopeScope { code: String }, + + MinKey, + + MaxKey, + + Undefined, + + Decimal128, + Decimal128Value, + + Done, +} + +/// Enum of BSON "value" types that this serializer can serialize. +#[derive(Debug, Clone, Copy)] +pub(super) enum ValueType { + DateTime, + Binary, + ObjectId, + Symbol, + RegularExpression, + Timestamp, + DbPointer, + JavaScriptCode, + JavaScriptCodeWithScope, + MinKey, + MaxKey, + Decimal128, + Undefined, +} + +impl From for ElementType { + fn from(vt: ValueType) -> Self { + match vt { + ValueType::Binary => ElementType::Binary, + ValueType::DateTime => ElementType::DateTime, + ValueType::DbPointer => ElementType::DbPointer, + ValueType::Decimal128 => ElementType::Decimal128, + ValueType::Symbol => ElementType::Symbol, + ValueType::RegularExpression => ElementType::RegularExpression, + ValueType::Timestamp => ElementType::Timestamp, + ValueType::JavaScriptCode => ElementType::JavaScriptCode, + ValueType::JavaScriptCodeWithScope => ElementType::JavaScriptCodeWithScope, + ValueType::MaxKey => ElementType::MaxKey, + ValueType::MinKey => ElementType::MinKey, + ValueType::Undefined => ElementType::Undefined, + ValueType::ObjectId => ElementType::ObjectId, + } + } +} + +impl<'a> ValueSerializer<'a> { + pub(super) fn new(rs: &'a mut Serializer, value_type: ValueType) -> Self { + let state = match value_type { + ValueType::DateTime => SerializationStep::DateTime, + ValueType::Binary => SerializationStep::Binary, + ValueType::ObjectId => SerializationStep::Oid, + ValueType::Symbol => SerializationStep::Symbol, + ValueType::RegularExpression => SerializationStep::RegEx, + ValueType::Timestamp => 
SerializationStep::Timestamp,
+            ValueType::DbPointer => SerializationStep::DbPointer,
+            ValueType::JavaScriptCode => SerializationStep::Code,
+            ValueType::JavaScriptCodeWithScope => SerializationStep::CodeWithScopeCode,
+            ValueType::MinKey => SerializationStep::MinKey,
+            ValueType::MaxKey => SerializationStep::MaxKey,
+            ValueType::Decimal128 => SerializationStep::Decimal128,
+            ValueType::Undefined => SerializationStep::Undefined,
+        };
+        Self {
+            root_serializer: rs,
+            state,
+        }
+    }
+
+    /// Builds the error returned when a serde primitive arrives that the current step cannot
+    /// accept; the message names both the primitive and the step.
+    fn invalid_step(&self, primitive_type: &'static str) -> Error {
+        Error::custom(format!(
+            "cannot serialize {} at step {:?}",
+            primitive_type, self.state
+        ))
+    }
+}
+
+impl<'a, 'b> serde::Serializer for &'b mut ValueSerializer<'a> {
+    type Ok = ();
+    type Error = Error;
+
+    type SerializeSeq = Impossible<(), Error>;
+    type SerializeTuple = Impossible<(), Error>;
+    type SerializeTupleStruct = Impossible<(), Error>;
+    type SerializeTupleVariant = Impossible<(), Error>;
+    type SerializeMap = CodeWithScopeSerializer<'b>;
+    type SerializeStruct = Self;
+    type SerializeStructVariant = Impossible<(), Error>;
+
+    #[inline]
+    fn serialize_bool(self, _v: bool) -> Result<Self::Ok> {
+        Err(self.invalid_step("bool"))
+    }
+
+    #[inline]
+    fn serialize_i8(self, _v: i8) -> Result<Self::Ok> {
+        Err(self.invalid_step("i8"))
+    }
+
+    #[inline]
+    fn serialize_i16(self, _v: i16) -> Result<Self::Ok> {
+        Err(self.invalid_step("i16"))
+    }
+
+    #[inline]
+    fn serialize_i32(self, _v: i32) -> Result<Self::Ok> {
+        Err(self.invalid_step("i32"))
+    }
+
+    /// Accepts the two halves of a timestamp: the time is stashed in the state, and once the
+    /// increment arrives both are range-checked as `u32` and written.
+    #[inline]
+    fn serialize_i64(self, v: i64) -> Result<Self::Ok> {
+        match self.state {
+            SerializationStep::TimestampTime => {
+                self.state = SerializationStep::TimestampIncrement { time: v };
+                Ok(())
+            }
+            SerializationStep::TimestampIncrement { time } => {
+                let t = u32::try_from(time).map_err(Error::custom)?;
+                let i = u32::try_from(v).map_err(Error::custom)?;
+
+                // The increment is written before the time; the `as i32` casts only
+                // reinterpret the already range-checked `u32` bits.
+                write_i32(&mut self.root_serializer.bytes, i as i32)?;
+                write_i32(&mut self.root_serializer.bytes, t as i32)?;
+                Ok(())
+            }
+            _ =>
Err(self.invalid_step("i64")),
+        }
+    }
+
+    #[inline]
+    fn serialize_u8(self, _v: u8) -> Result<Self::Ok> {
+        Err(self.invalid_step("u8"))
+    }
+
+    #[inline]
+    fn serialize_u16(self, _v: u16) -> Result<Self::Ok> {
+        Err(self.invalid_step("u16"))
+    }
+
+    #[inline]
+    fn serialize_u32(self, _v: u32) -> Result<Self::Ok> {
+        Err(self.invalid_step("u32"))
+    }
+
+    #[inline]
+    fn serialize_u64(self, _v: u64) -> Result<Self::Ok> {
+        Err(self.invalid_step("u64"))
+    }
+
+    #[inline]
+    fn serialize_f32(self, _v: f32) -> Result<Self::Ok> {
+        Err(self.invalid_step("f32"))
+    }
+
+    #[inline]
+    fn serialize_f64(self, _v: f64) -> Result<Self::Ok> {
+        Err(self.invalid_step("f64"))
+    }
+
+    #[inline]
+    fn serialize_char(self, _v: char) -> Result<Self::Ok> {
+        Err(self.invalid_step("char"))
+    }
+
+    /// Handles every step whose payload arrives as a string. Depending on the current step the
+    /// string is parsed and written immediately, or stashed in the state until the companion
+    /// field arrives. Steps that cannot take a string produce an error.
+    fn serialize_str(self, v: &str) -> Result<Self::Ok> {
+        match &self.state {
+            SerializationStep::DateTimeNumberLong => {
+                let millis: i64 = v.parse().map_err(Error::custom)?;
+                write_i64(&mut self.root_serializer.bytes, millis)?;
+            }
+            SerializationStep::Oid => {
+                let oid = ObjectId::parse_str(v).map_err(Error::custom)?;
+                self.root_serializer.bytes.write_all(&oid.bytes())?;
+            }
+            SerializationStep::BinaryBase64 => {
+                self.state = SerializationStep::BinarySubType {
+                    base64: v.to_string(),
+                };
+            }
+            SerializationStep::BinarySubType { base64 } => {
+                let subtype_byte = hex::decode(v).map_err(Error::custom)?;
+                // The subtype must decode to exactly one byte. Indexing `[0]` would panic on an
+                // empty string (which decodes to zero bytes) and silently drop any extra bytes.
+                let subtype: BinarySubtype = match subtype_byte.as_slice() {
+                    [byte] => (*byte).into(),
+                    _ => {
+                        return Err(Error::custom(format!(
+                            "expected binary subtype to be one hex-encoded byte, got \"{}\"",
+                            v
+                        )))
+                    }
+                };
+
+                let bytes = base64::decode(base64.as_str()).map_err(Error::custom)?;
+
+                write_binary(&mut self.root_serializer.bytes, bytes.as_slice(), subtype)?;
+            }
+            SerializationStep::Symbol | SerializationStep::DbPointerRef => {
+                write_string(&mut self.root_serializer.bytes, v)?;
+            }
+            SerializationStep::RegExPattern | SerializationStep::RegExOptions => {
+                write_cstring(&mut self.root_serializer.bytes, v)?;
+            }
+            SerializationStep::Code => {
+                write_string(&mut self.root_serializer.bytes, v)?;
+            }
+            SerializationStep::CodeWithScopeCode => {
+                self.state = SerializationStep::CodeWithScopeScope {
+                    code: v.to_string(),
+                };
+            }
+            #[cfg(feature = "decimal128")]
+            SerializationStep::Decimal128Value => {
+                let d = Decimal128::from_str(v);
+                self.root_serializer.bytes.write_all(&d.to_raw_bytes_le())?;
+            }
+            s => {
+                return Err(Error::custom(format!(
+                    "can't serialize string for step {:?}",
+                    s
+                )))
+            }
+        }
+        Ok(())
+    }
+
+    /// Accepts the raw bytes of a decimal128 value; every other step rejects bytes.
+    #[inline]
+    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok> {
+        match self.state {
+            SerializationStep::Decimal128Value => {
+                // A BSON decimal128 is exactly 16 bytes; writing any other length would
+                // corrupt the surrounding document.
+                if v.len() != 16 {
+                    return Err(Error::custom(format!(
+                        "expected 16 bytes for a decimal128 value, got {}",
+                        v.len()
+                    )));
+                }
+                self.root_serializer.bytes.write_all(v)?;
+                Ok(())
+            }
+            _ => Err(self.invalid_step("&[u8]")),
+        }
+    }
+
+    #[inline]
+    fn serialize_none(self) -> Result<Self::Ok> {
+        Err(self.invalid_step("none"))
+    }
+
+    #[inline]
+    fn serialize_some<T: ?Sized>(self, _value: &T) -> Result<Self::Ok>
+    where
+        T: Serialize,
+    {
+        Err(self.invalid_step("some"))
+    }
+
+    #[inline]
+    fn serialize_unit(self) -> Result<Self::Ok> {
+        Err(self.invalid_step("unit"))
+    }
+
+    #[inline]
+    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok> {
+        Err(self.invalid_step("unit_struct"))
+    }
+
+    #[inline]
+    fn serialize_unit_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+    ) -> Result<Self::Ok> {
+        Err(self.invalid_step("unit_variant"))
+    }
+
+    #[inline]
+    fn serialize_newtype_struct<T: ?Sized>(
+        self,
+        _name: &'static str,
+        _value: &T,
+    ) -> Result<Self::Ok>
+    where
+        T: Serialize,
+    {
+        Err(self.invalid_step("newtype_struct"))
+    }
+
+    #[inline]
+    fn serialize_newtype_variant<T: ?Sized>(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _value: &T,
+    ) -> Result<Self::Ok>
+    where
+        T: Serialize,
+    {
+        Err(self.invalid_step("newtype_variant"))
+    }
+
+    #[inline]
+    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
+        Err(self.invalid_step("seq"))
+    }
+
+    #[inline]
+    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
+        Err(self.invalid_step("tuple"))
+    }
+
+    #[inline]
+    fn serialize_tuple_struct(
+        self,
+        _name: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleStruct> {
+        Err(self.invalid_step("tuple_struct"))
+    }
+
+    #[inline]
+    fn serialize_tuple_variant(
+        self,
+        _name: &'static
str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeTupleVariant> {
+        Err(self.invalid_step("tuple_variant"))
+    }
+
+    /// A map is only accepted as the scope of a code-with-scope value; the code string stashed
+    /// in the state is handed to the map serializer, which writes it ahead of the scope.
+    #[inline]
+    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
+        match self.state {
+            SerializationStep::CodeWithScopeScope { ref code } => {
+                CodeWithScopeSerializer::start(code.as_str(), self.root_serializer)
+            }
+            _ => Err(self.invalid_step("map")),
+        }
+    }
+
+    /// Structs are serialized by this same value; the keys are validated field by field in
+    /// `SerializeStruct::serialize_field`.
+    #[inline]
+    fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
+        Ok(self)
+    }
+
+    #[inline]
+    fn serialize_struct_variant(
+        self,
+        _name: &'static str,
+        _variant_index: u32,
+        _variant: &'static str,
+        _len: usize,
+    ) -> Result<Self::SerializeStructVariant> {
+        Err(self.invalid_step("struct_variant"))
+    }
+}
+
+impl<'a, 'b> SerializeStruct for &'b mut ValueSerializer<'a> {
+    type Ok = ();
+    type Error = Error;
+
+    /// Advances the state machine by one field: the (current step, key) pair selects the
+    /// transition, and any pair that is not listed is reported as an error.
+    fn serialize_field<T: ?Sized>(&mut self, key: &'static str, value: &T) -> Result<()>
+    where
+        T: Serialize,
+    {
+        match (&self.state, key) {
+            (SerializationStep::DateTime, "$date") => {
+                self.state = SerializationStep::DateTimeNumberLong;
+                value.serialize(&mut **self)?;
+            }
+            (SerializationStep::DateTimeNumberLong, "$numberLong") => {
+                value.serialize(&mut **self)?;
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::Oid, "$oid") => {
+                value.serialize(&mut **self)?;
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::Binary, "$binary") => {
+                self.state = SerializationStep::BinaryBase64;
+                value.serialize(&mut **self)?;
+            }
+            (SerializationStep::BinaryBase64, "base64") => {
+                // state is updated in serialize
+                value.serialize(&mut **self)?;
+            }
+            (SerializationStep::BinarySubType { ..
}, "subType") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Symbol, "$symbol") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::RegEx, "$regularExpression") => { + self.state = SerializationStep::RegExPattern; + value.serialize(&mut **self)?; + } + (SerializationStep::RegExPattern, "pattern") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::RegExOptions; + } + (SerializationStep::RegExOptions, "options") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::Timestamp, "$timestamp") => { + self.state = SerializationStep::TimestampTime; + value.serialize(&mut **self)?; + } + (SerializationStep::TimestampTime, "t") => { + // state is updated in serialize + value.serialize(&mut **self)?; + } + (SerializationStep::TimestampIncrement { .. }, "i") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::DbPointer, "$dbPointer") => { + self.state = SerializationStep::DbPointerRef; + value.serialize(&mut **self)?; + } + (SerializationStep::DbPointerRef, "$ref") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::DbPointerId; + } + (SerializationStep::DbPointerId, "$id") => { + self.state = SerializationStep::Oid; + value.serialize(&mut **self)?; + } + (SerializationStep::Code, "$code") => { + value.serialize(&mut **self)?; + self.state = SerializationStep::Done; + } + (SerializationStep::CodeWithScopeCode, "$code") => { + // state is updated in serialize + value.serialize(&mut **self)?; + } + (SerializationStep::CodeWithScopeScope { .. 
}, "$scope") => {
+                value.serialize(&mut **self)?;
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::MinKey, "$minKey") => {
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::MaxKey, "$maxKey") => {
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::Undefined, "$undefined") => {
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::Decimal128, "$numberDecimal")
+            | (SerializationStep::Decimal128, "$numberDecimalBytes") => {
+                self.state = SerializationStep::Decimal128Value;
+                value.serialize(&mut **self)?;
+                // Finish like every other type so that a stray extra key is rejected instead
+                // of writing a second payload.
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::Decimal128Value, "$numberDecimal") => {
+                value.serialize(&mut **self)?;
+                self.state = SerializationStep::Done;
+            }
+            (SerializationStep::Done, k) => {
+                return Err(Error::custom(format!(
+                    "expected to end serialization of type, got extra key \"{}\"",
+                    k
+                )));
+            }
+            (state, k) => {
+                return Err(Error::custom(format!(
+                    "mismatched serialization step and next key: {:?} + \"{}\"",
+                    state, k
+                )));
+            }
+        }
+
+        Ok(())
+    }
+
+    #[inline]
+    fn end(self) -> Result<Self::Ok> {
+        Ok(())
+    }
+}
+
+/// Serializes the scope of a code-with-scope value, and patches the total length of the value
+/// once the scope document is complete.
+pub(crate) struct CodeWithScopeSerializer<'a> {
+    /// Index of the placeholder for the total length of the value.
+    start: usize,
+    doc: DocumentSerializer<'a>,
+}
+
+impl<'a> CodeWithScopeSerializer<'a> {
+    /// Writes a placeholder for the total length, then the code string, then begins the scope
+    /// document.
+    #[inline]
+    fn start(code: &str, rs: &'a mut Serializer) -> Result<Self> {
+        let start = rs.bytes.len();
+        write_i32(&mut rs.bytes, 0)?; // placeholder length
+        write_string(&mut rs.bytes, code)?;
+
+        let doc = DocumentSerializer::start(rs)?;
+        Ok(Self { start, doc })
+    }
+}
+
+impl<'a> SerializeMap for CodeWithScopeSerializer<'a> {
+    type Ok = ();
+    type Error = Error;
+
+    #[inline]
+    fn serialize_key<T: ?Sized>(&mut self, key: &T) -> Result<()>
+    where
+        T: Serialize,
+    {
+        self.doc.serialize_key(key)
+    }
+
+    #[inline]
+    fn serialize_value<T: ?Sized>(&mut self, value: &T) -> Result<()>
+    where
+        T: Serialize,
+    {
+        self.doc.serialize_value(value)
+    }
+
+    #[inline]
+    fn end(self) -> Result<Self::Ok> {
+        let result = self.doc.end_doc()?;
+
+        let total_len = (result.root_serializer.bytes.len() -
self.start) as i32; + result.root_serializer.replace_i32(self.start, total_len); + Ok(()) + } +} diff --git a/src/ser/serde.rs b/src/ser/serde.rs index e903b1e2..04f61ba0 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -9,12 +9,15 @@ use serde::ser::{ SerializeTupleStruct, SerializeTupleVariant, }; +#[cfg(not(feature = "decimal128"))] +use serde_bytes::Bytes; #[cfg(feature = "decimal128")] use crate::decimal128::Decimal128; use crate::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, datetime::DateTime, + extjson, oid::ObjectId, spec::BinarySubtype, }; @@ -27,8 +30,8 @@ impl Serialize for ObjectId { where S: serde::ser::Serializer, { - let mut ser = serializer.serialize_map(Some(1))?; - ser.serialize_entry("$oid", &self.to_string())?; + let mut ser = serializer.serialize_struct("$oid", 1)?; + ser.serialize_field("$oid", &self.to_string())?; ser.end() } } @@ -53,22 +56,54 @@ impl Serialize for Bson { where S: ser::Serializer, { - match *self { - Bson::Double(v) => serializer.serialize_f64(v), - Bson::String(ref v) => serializer.serialize_str(v), - Bson::Array(ref v) => v.serialize(serializer), - Bson::Document(ref v) => v.serialize(serializer), - Bson::Boolean(v) => serializer.serialize_bool(v), + match self { + Bson::Double(v) => serializer.serialize_f64(*v), + Bson::String(v) => serializer.serialize_str(v), + Bson::Array(v) => v.serialize(serializer), + Bson::Document(v) => v.serialize(serializer), + Bson::Boolean(v) => serializer.serialize_bool(*v), Bson::Null => serializer.serialize_unit(), - Bson::Int32(v) => serializer.serialize_i32(v), - Bson::Int64(v) => serializer.serialize_i64(v), - Bson::Binary(Binary { - subtype: BinarySubtype::Generic, - ref bytes, - }) => serializer.serialize_bytes(bytes), - _ => { - let doc = self.clone().into_extended_document(); - doc.serialize(serializer) + Bson::Int32(v) => serializer.serialize_i32(*v), + Bson::Int64(v) => serializer.serialize_i64(*v), + Bson::ObjectId(oid) => 
oid.serialize(serializer), + Bson::DateTime(dt) => dt.serialize(serializer), + Bson::Binary(b) => b.serialize(serializer), + Bson::JavaScriptCode(c) => { + let mut state = serializer.serialize_struct("$code", 1)?; + state.serialize_field("$code", c)?; + state.end() + } + Bson::JavaScriptCodeWithScope(code_w_scope) => code_w_scope.serialize(serializer), + Bson::DbPointer(dbp) => dbp.serialize(serializer), + Bson::Symbol(s) => { + let mut state = serializer.serialize_struct("$symbol", 1)?; + state.serialize_field("$symbol", s)?; + state.end() + } + Bson::RegularExpression(re) => re.serialize(serializer), + Bson::Timestamp(t) => t.serialize(serializer), + #[cfg(not(feature = "decimal128"))] + Bson::Decimal128(d) => { + let mut state = serializer.serialize_struct("$numberDecimal", 1)?; + state.serialize_field("$numberDecimalBytes", Bytes::new(&d.bytes))?; + state.end() + } + #[cfg(feature = "decimal128")] + Bson::Decimal128(d) => d.serialize(serializer), + Bson::Undefined => { + let mut state = serializer.serialize_struct("$undefined", 1)?; + state.serialize_field("$undefined", &true)?; + state.end() + } + Bson::MaxKey => { + let mut state = serializer.serialize_struct("$maxKey", 1)?; + state.serialize_field("$maxKey", &1)?; + state.end() + } + Bson::MinKey => { + let mut state = serializer.serialize_struct("$minKey", 1)?; + state.serialize_field("$minKey", &1)?; + state.end() } } } @@ -506,8 +541,13 @@ impl Serialize for Timestamp { where S: ser::Serializer, { - let value = Bson::Timestamp(*self); - value.serialize(serializer) + let mut state = serializer.serialize_struct("$timestamp", 1)?; + let body = extjson::models::TimestampBody { + t: self.time, + i: self.increment, + }; + state.serialize_field("$timestamp", &body)?; + state.end() } } @@ -517,8 +557,13 @@ impl Serialize for Regex { where S: ser::Serializer, { - let value = Bson::RegularExpression(self.clone()); - value.serialize(serializer) + let mut state = serializer.serialize_struct("$regularExpression", 1)?; 
+ let body = extjson::models::RegexBody { + pattern: self.pattern.clone(), + options: self.options.clone(), + }; + state.serialize_field("$regularExpression", &body)?; + state.end() } } @@ -528,8 +573,10 @@ impl Serialize for JavaScriptCodeWithScope { where S: ser::Serializer, { - let value = Bson::JavaScriptCodeWithScope(self.clone()); - value.serialize(serializer) + let mut state = serializer.serialize_struct("$codeWithScope", 2)?; + state.serialize_field("$code", &self.code)?; + state.serialize_field("$scope", &self.scope)?; + state.end() } } @@ -539,8 +586,17 @@ impl Serialize for Binary { where S: ser::Serializer, { - let value = Bson::Binary(self.clone()); - value.serialize(serializer) + if let BinarySubtype::Generic = self.subtype { + serializer.serialize_bytes(self.bytes.as_slice()) + } else { + let mut state = serializer.serialize_struct("$binary", 1)?; + let body = extjson::models::BinaryBody { + base64: base64::encode(self.bytes.as_slice()), + subtype: hex::encode([self.subtype.into()]), + }; + state.serialize_field("$binary", &body)?; + state.end() + } } } @@ -551,8 +607,12 @@ impl Serialize for Decimal128 { where S: ser::Serializer, { - let value = Bson::Decimal128(self.clone()); - value.serialize(serializer) + let mut state = serializer.serialize_struct("$numberDecimal", 1)?; + state.serialize_field( + "$numberDecimalBytes", + serde_bytes::Bytes::new(&self.to_raw_bytes_le()), + )?; + state.end() } } @@ -562,9 +622,10 @@ impl Serialize for DateTime { where S: ser::Serializer, { - // Cloning a `DateTime` is extremely cheap - let value = Bson::DateTime(*self); - value.serialize(serializer) + let mut state = serializer.serialize_struct("$date", 1)?; + let body = extjson::models::DateTimeBody::from_millis(self.timestamp_millis()); + state.serialize_field("$date", &body)?; + state.end() } } @@ -574,7 +635,12 @@ impl Serialize for DbPointer { where S: ser::Serializer, { - let value = Bson::DbPointer(self.clone()); - value.serialize(serializer) + let mut 
state = serializer.serialize_struct("$dbPointer", 1)?; + let body = extjson::models::DbPointerBody { + ref_ns: self.namespace.clone(), + id: self.id.into(), + }; + state.serialize_field("$dbPointer", &body)?; + state.end() } } diff --git a/src/tests/spec/corpus.rs b/src/tests/spec/corpus.rs index a1a574b8..4623f5dc 100644 --- a/src/tests/spec/corpus.rs +++ b/src/tests/spec/corpus.rs @@ -61,48 +61,77 @@ fn run_test(test: TestFile) { let canonical_bson = hex::decode(&valid.canonical_bson).expect(&description); - let bson_to_native_cb = + // these four cover the four ways to create a `Document` from the provided BSON. + let documentfromreader_cb = Document::from_reader(canonical_bson.as_slice()).expect(&description); - let bson_to_native_cb_serde: Document = + let fromreader_cb: Document = crate::from_reader(canonical_bson.as_slice()).expect(&description); - let native_to_native_cb_serde: Document = - crate::from_document(bson_to_native_cb.clone()).expect(&description); + let fromdocument_documentfromreader_cb: Document = + crate::from_document(documentfromreader_cb.clone()).expect(&description); - let mut native_to_bson_bson_to_native_cb = Vec::new(); - bson_to_native_cb - .to_writer(&mut native_to_bson_bson_to_native_cb) + let todocument_documentfromreader_cb: Document = + crate::to_document(&documentfromreader_cb).expect(&description); + + // These cover the ways to serialize those `Documents` back to BSON. 
+ let mut documenttowriter_documentfromreader_cb = Vec::new(); + documentfromreader_cb + .to_writer(&mut documenttowriter_documentfromreader_cb) + .expect(&description); + + let mut documenttowriter_fromreader_cb = Vec::new(); + fromreader_cb + .to_writer(&mut documenttowriter_fromreader_cb) .expect(&description); - let mut native_to_bson_bson_to_native_cb_serde = Vec::new(); - bson_to_native_cb_serde - .to_writer(&mut native_to_bson_bson_to_native_cb_serde) + let mut documenttowriter_fromdocument_documentfromreader_cb = Vec::new(); + fromdocument_documentfromreader_cb + .to_writer(&mut documenttowriter_fromdocument_documentfromreader_cb) .expect(&description); - let mut native_to_bson_native_to_native_cb_serde = Vec::new(); - native_to_native_cb_serde - .to_writer(&mut native_to_bson_native_to_native_cb_serde) + let mut documenttowriter_todocument_documentfromreader_cb = Vec::new(); + todocument_documentfromreader_cb + .to_writer(&mut documenttowriter_todocument_documentfromreader_cb) .expect(&description); + let tovec_documentfromreader_cb = + crate::to_vec(&documentfromreader_cb).expect(&description); + // native_to_bson( bson_to_native(cB) ) = cB + // now we ensure the hex for all 5 are equivalent to the canonical BSON provided by the + // test. 
assert_eq!( - hex::encode(native_to_bson_bson_to_native_cb).to_lowercase(), + hex::encode(documenttowriter_documentfromreader_cb).to_lowercase(), valid.canonical_bson.to_lowercase(), "{}", description, ); assert_eq!( - hex::encode(native_to_bson_bson_to_native_cb_serde).to_lowercase(), + hex::encode(documenttowriter_fromreader_cb).to_lowercase(), valid.canonical_bson.to_lowercase(), "{}", description, ); assert_eq!( - hex::encode(native_to_bson_native_to_native_cb_serde).to_lowercase(), + hex::encode(documenttowriter_fromdocument_documentfromreader_cb).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + + assert_eq!( + hex::encode(documenttowriter_todocument_documentfromreader_cb).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + + assert_eq!( + hex::encode(tovec_documentfromreader_cb).to_lowercase(), valid.canonical_bson.to_lowercase(), "{}", description, @@ -110,14 +139,16 @@ fn run_test(test: TestFile) { // NaN == NaN is false, so we skip document comparisons that contain NaN if !description.to_ascii_lowercase().contains("nan") && !description.contains("decq541") { + assert_eq!(documentfromreader_cb, fromreader_cb, "{}", description); + assert_eq!( - bson_to_native_cb, bson_to_native_cb_serde, + documentfromreader_cb, fromdocument_documentfromreader_cb, "{}", description ); assert_eq!( - bson_to_native_cb, native_to_native_cb_serde, + documentfromreader_cb, todocument_documentfromreader_cb, "{}", description ); @@ -156,7 +187,7 @@ fn run_test(test: TestFile) { // NaN == NaN is false, so we skip document comparisons that contain NaN if !description.contains("NaN") { assert_eq!( - bson_to_native_db_serde, bson_to_native_cb, + bson_to_native_db_serde, documentfromreader_cb, "{}", description ); @@ -201,7 +232,7 @@ fn run_test(test: TestFile) { // TODO RUST-36: Enable decimal128 tests. 
if test.bson_type != "0x13" { assert_eq!( - Bson::Document(bson_to_native_cb.clone()).into_canonical_extjson(), + Bson::Document(documentfromreader_cb.clone()).into_canonical_extjson(), cej_updated_float, "{}", description @@ -214,7 +245,7 @@ fn run_test(test: TestFile) { let rej: serde_json::Value = serde_json::from_str(relaxed_extjson).expect(&description); assert_eq!( - Bson::Document(bson_to_native_cb.clone()).into_relaxed_extjson(), + Bson::Document(documentfromreader_cb.clone()).into_relaxed_extjson(), rej, "{}", description