|
| 1 | +use std::{ |
| 2 | + convert::{TryFrom, TryInto}, |
| 3 | + mem::size_of, |
| 4 | +}; |
| 5 | + |
| 6 | +use serde::{Deserialize, Serialize}; |
| 7 | + |
| 8 | +use super::{Binary, Error, Result}; |
| 9 | +use crate::{spec::BinarySubtype, Bson, RawBson}; |
| 10 | + |
// Data-type tags. `Vector::from_bytes` compares the first byte of the encoded payload against
// these values, and the conversion to `Binary` writes them back as that first byte.
const INT8: u8 = 0x03;
const FLOAT32: u8 = 0x27;
const PACKED_BIT: u8 = 0x10;
| 14 | + |
| 15 | +/// A vector of numeric values. This type can be converted into a [`Binary`] of subtype |
| 16 | +/// [`BinarySubtype::Vector`]. |
| 17 | +/// |
| 18 | +/// ```rust |
| 19 | +/// let vector = Vector::Int8(vec![0, 1, 2]); |
| 20 | +/// let binary = Binary::from(vector); |
| 21 | +/// ``` |
| 22 | +/// |
| 23 | +/// The `Serialize` and `Deserialize` implementations for `Vector` treat it as a `Binary`. |
| 24 | +/// |
| 25 | +/// ```rust |
| 26 | +/// #[derive(Serialize, Deserialize)] |
| 27 | +/// struct Data { |
| 28 | +/// vector: Vector, |
| 29 | +/// } |
| 30 | +/// |
| 31 | +/// let data = Data { vector: Vector::Int8(vec![0, 1, 2]) }; |
| 32 | +/// let document = bson::to_document(&data); |
| 33 | +/// assert_eq!(document.get("vector").unwrap().element_type(), ElementType::Binary); |
| 34 | +/// |
| 35 | +/// let data = bson::from_document(document); |
| 36 | +/// assert_eq!(data.vector, Vector::Int8(vec![0, 1, 2])); |
| 37 | +/// ``` |
| 38 | +/// |
| 39 | +/// See the |
| 40 | +/// [specification](https://github.com/mongodb/specifications/blob/master/source/bson-binary-vector/bson-binary-vector.md) |
| 41 | +/// for more details. |
| 42 | +#[derive(Clone, Debug, PartialEq)] |
| 43 | +pub enum Vector { |
| 44 | + /// A vector of `i8` values. |
| 45 | + Int8(Vec<i8>), |
| 46 | + |
| 47 | + /// A vector of `f32` values. |
| 48 | + Float32(Vec<f32>), |
| 49 | + |
| 50 | + /// A vector of packed bits. See [`PackedBitVector::new`] for more details. |
| 51 | + PackedBit(PackedBitVector), |
| 52 | +} |
| 53 | + |
/// A vector whose elements are single bits packed eight to a byte. Build one with
/// [`PackedBitVector::new`].
#[derive(Clone, Debug, PartialEq)]
pub struct PackedBitVector {
    /// The packed bytes; each byte carries eight single-bit elements.
    vector: Vec<u8>,
    /// How many bits of the final byte are ignored.
    padding: u8,
}
| 60 | + |
| 61 | +impl PackedBitVector { |
| 62 | + /// Construct a new `PackedBitVector`. Each `u8` value in the provided `vector` represents 8 |
| 63 | + /// single-bit elements in little-endian format. For example, the following vector: |
| 64 | + /// |
| 65 | + /// ```rust |
| 66 | + /// let packed_bits = vec![238, 224]; |
| 67 | + /// let vector = PackedBitVector::new(packed_bits, 0); |
| 68 | + /// ``` |
| 69 | + /// |
| 70 | + /// represents a 16-bit vector containing the following values: |
| 71 | + /// |
| 72 | + /// ``` |
| 73 | + /// [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0] |
| 74 | + /// ``` |
| 75 | + /// |
| 76 | + /// Padding can optionally be specified to ignore a number of least-significant bits in the |
| 77 | + /// final byte. For example, the vector in the previous example with a padding of 4 would |
| 78 | + /// represent a 12-bit vector containing the following values: |
| 79 | + /// |
| 80 | + /// ``` |
| 81 | + /// [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0] |
| 82 | + /// ``` |
| 83 | + /// |
| 84 | + /// Padding must be within 0-7 inclusive. Padding must be 0 or unspecified if the provided |
| 85 | + /// vector is empty. |
| 86 | + pub fn new(vector: Vec<u8>, padding: impl Into<Option<u8>>) -> Result<Self> { |
| 87 | + let padding = padding.into().unwrap_or(0); |
| 88 | + if !(0..8).contains(&padding) { |
| 89 | + return Err(Error::Vector { |
| 90 | + message: format!("padding must be within 0-7 inclusive, got {}", padding), |
| 91 | + }); |
| 92 | + } |
| 93 | + if padding != 0 && vector.is_empty() { |
| 94 | + return Err(Error::Vector { |
| 95 | + message: format!( |
| 96 | + "cannot specify non-zero padding if the provided vector is empty, got {}", |
| 97 | + padding |
| 98 | + ), |
| 99 | + }); |
| 100 | + } |
| 101 | + Ok(Self { vector, padding }) |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +impl Vector { |
| 106 | + /// Construct a [`Vector`] from the given bytes. See the |
| 107 | + /// [specification](https://github.com/mongodb/specifications/blob/master/source/bson-binary-vector/bson-binary-vector.md#specification) |
| 108 | + /// for details on the expected byte format. |
| 109 | + pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> { |
| 110 | + let bytes = bytes.as_ref(); |
| 111 | + |
| 112 | + if bytes.len() < 2 { |
| 113 | + return Err(Error::Vector { |
| 114 | + message: format!( |
| 115 | + "the provided bytes must have a length of at least 2, got {}", |
| 116 | + bytes.len() |
| 117 | + ), |
| 118 | + }); |
| 119 | + } |
| 120 | + |
| 121 | + let d_type = bytes[0]; |
| 122 | + let padding = bytes[1]; |
| 123 | + if d_type != PACKED_BIT && padding != 0 { |
| 124 | + return Err(Error::Vector { |
| 125 | + message: format!( |
| 126 | + "padding can only be specified for a packed bit vector (data type {}), got \ |
| 127 | + type {}", |
| 128 | + PACKED_BIT, d_type |
| 129 | + ), |
| 130 | + }); |
| 131 | + } |
| 132 | + let number_bytes = &bytes[2..]; |
| 133 | + |
| 134 | + match d_type { |
| 135 | + INT8 => { |
| 136 | + let vector = number_bytes |
| 137 | + .iter() |
| 138 | + .map(|n| i8::from_le_bytes([*n])) |
| 139 | + .collect(); |
| 140 | + Ok(Self::Int8(vector)) |
| 141 | + } |
| 142 | + FLOAT32 => { |
| 143 | + const F32_BYTES: usize = size_of::<f32>(); |
| 144 | + |
| 145 | + let mut vector = Vec::new(); |
| 146 | + for chunk in number_bytes.chunks(F32_BYTES) { |
| 147 | + let bytes: [u8; F32_BYTES] = chunk.try_into().map_err(|_| Error::Vector { |
| 148 | + message: format!( |
| 149 | + "f32 vector values must be {} bytes, got {:?}", |
| 150 | + F32_BYTES, chunk, |
| 151 | + ), |
| 152 | + })?; |
| 153 | + vector.push(f32::from_le_bytes(bytes)); |
| 154 | + } |
| 155 | + Ok(Self::Float32(vector)) |
| 156 | + } |
| 157 | + PACKED_BIT => { |
| 158 | + let packed_bit_vector = PackedBitVector::new(number_bytes.to_vec(), padding)?; |
| 159 | + Ok(Self::PackedBit(packed_bit_vector)) |
| 160 | + } |
| 161 | + other => Err(Error::Vector { |
| 162 | + message: format!("unsupported vector data type: {}", other), |
| 163 | + }), |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + fn d_type(&self) -> u8 { |
| 168 | + match self { |
| 169 | + Self::Int8(_) => INT8, |
| 170 | + Self::Float32(_) => FLOAT32, |
| 171 | + Self::PackedBit(_) => PACKED_BIT, |
| 172 | + } |
| 173 | + } |
| 174 | + |
| 175 | + fn padding(&self) -> u8 { |
| 176 | + match self { |
| 177 | + Self::Int8(_) => 0, |
| 178 | + Self::Float32(_) => 0, |
| 179 | + Self::PackedBit(PackedBitVector { padding, .. }) => *padding, |
| 180 | + } |
| 181 | + } |
| 182 | +} |
| 183 | + |
| 184 | +impl From<&Vector> for Binary { |
| 185 | + fn from(vector: &Vector) -> Self { |
| 186 | + let d_type = vector.d_type(); |
| 187 | + let padding = vector.padding(); |
| 188 | + let mut bytes = vec![d_type, padding]; |
| 189 | + |
| 190 | + match vector { |
| 191 | + Vector::Int8(vector) => { |
| 192 | + for n in vector { |
| 193 | + bytes.extend_from_slice(&n.to_le_bytes()); |
| 194 | + } |
| 195 | + } |
| 196 | + Vector::Float32(vector) => { |
| 197 | + for n in vector { |
| 198 | + bytes.extend_from_slice(&n.to_le_bytes()); |
| 199 | + } |
| 200 | + } |
| 201 | + Vector::PackedBit(PackedBitVector { vector, .. }) => { |
| 202 | + for n in vector { |
| 203 | + bytes.extend_from_slice(&n.to_le_bytes()); |
| 204 | + } |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + Self { |
| 209 | + subtype: BinarySubtype::Vector, |
| 210 | + bytes, |
| 211 | + } |
| 212 | + } |
| 213 | +} |
| 214 | + |
| 215 | +impl From<Vector> for Binary { |
| 216 | + fn from(vector: Vector) -> Binary { |
| 217 | + Self::from(&vector) |
| 218 | + } |
| 219 | +} |
| 220 | + |
| 221 | +impl TryFrom<&Binary> for Vector { |
| 222 | + type Error = Error; |
| 223 | + |
| 224 | + fn try_from(binary: &Binary) -> Result<Self> { |
| 225 | + if binary.subtype != BinarySubtype::Vector { |
| 226 | + return Err(Error::Vector { |
| 227 | + message: format!("expected vector binary subtype, got {:?}", binary.subtype), |
| 228 | + }); |
| 229 | + } |
| 230 | + Self::from_bytes(&binary.bytes) |
| 231 | + } |
| 232 | +} |
| 233 | + |
| 234 | +impl TryFrom<Binary> for Vector { |
| 235 | + type Error = Error; |
| 236 | + |
| 237 | + fn try_from(binary: Binary) -> std::result::Result<Self, Self::Error> { |
| 238 | + Self::try_from(&binary) |
| 239 | + } |
| 240 | +} |
| 241 | + |
| 242 | +// Convenience impl to allow passing a Vector directly into the doc! macro. From<Vector> is already |
| 243 | +// implemented by a blanket impl in src/bson.rs. |
| 244 | +impl From<&Vector> for Bson { |
| 245 | + fn from(vector: &Vector) -> Self { |
| 246 | + Self::Binary(Binary::from(vector)) |
| 247 | + } |
| 248 | +} |
| 249 | + |
| 250 | +// Convenience impls to allow passing a Vector directly into the rawdoc! macro |
| 251 | +impl From<&Vector> for RawBson { |
| 252 | + fn from(vector: &Vector) -> Self { |
| 253 | + Self::Binary(Binary::from(vector)) |
| 254 | + } |
| 255 | +} |
| 256 | + |
| 257 | +impl From<Vector> for RawBson { |
| 258 | + fn from(vector: Vector) -> Self { |
| 259 | + Self::from(&vector) |
| 260 | + } |
| 261 | +} |
| 262 | + |
| 263 | +impl Serialize for Vector { |
| 264 | + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> |
| 265 | + where |
| 266 | + S: serde::Serializer, |
| 267 | + { |
| 268 | + let binary = Binary::from(self); |
| 269 | + binary.serialize(serializer) |
| 270 | + } |
| 271 | +} |
| 272 | + |
| 273 | +impl<'de> Deserialize<'de> for Vector { |
| 274 | + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> |
| 275 | + where |
| 276 | + D: serde::Deserializer<'de>, |
| 277 | + { |
| 278 | + let binary = Binary::deserialize(deserializer)?; |
| 279 | + Self::try_from(binary).map_err(serde::de::Error::custom) |
| 280 | + } |
| 281 | +} |
0 commit comments