Skip to content

Commit 2e194f3

Browse files
authored
refactor: Change dynamic literals to be separate category (#21849)
1 parent 9b63946 commit 2e194f3

File tree

32 files changed

+1104
-997
lines changed

32 files changed

+1104
-997
lines changed

crates/polars-core/src/datatypes/any_value.rs

+130 -122
Original file line number | Diff line number | Diff line change
@@ -118,31 +118,83 @@ impl Serialize for AnyValue<'_> {
118118
let name = "AnyValue";
119119
match self {
120120
AnyValue::Null => serializer.serialize_unit_variant(name, 0, "Null"),
121+
121122
AnyValue::Int8(v) => serializer.serialize_newtype_variant(name, 1, "Int8", v),
122123
AnyValue::Int16(v) => serializer.serialize_newtype_variant(name, 2, "Int16", v),
123124
AnyValue::Int32(v) => serializer.serialize_newtype_variant(name, 3, "Int32", v),
124125
AnyValue::Int64(v) => serializer.serialize_newtype_variant(name, 4, "Int64", v),
125-
AnyValue::Int128(v) => serializer.serialize_newtype_variant(name, 4, "Int128", v),
126-
AnyValue::UInt8(v) => serializer.serialize_newtype_variant(name, 5, "UInt8", v),
127-
AnyValue::UInt16(v) => serializer.serialize_newtype_variant(name, 6, "UInt16", v),
128-
AnyValue::UInt32(v) => serializer.serialize_newtype_variant(name, 7, "UInt32", v),
129-
AnyValue::UInt64(v) => serializer.serialize_newtype_variant(name, 8, "UInt64", v),
130-
AnyValue::Float32(v) => serializer.serialize_newtype_variant(name, 9, "Float32", v),
131-
AnyValue::Float64(v) => serializer.serialize_newtype_variant(name, 10, "Float64", v),
132-
AnyValue::List(v) => serializer.serialize_newtype_variant(name, 11, "List", v),
133-
AnyValue::Boolean(v) => serializer.serialize_newtype_variant(name, 12, "Bool", v),
134-
// both string variants same number
135-
AnyValue::String(v) => serializer.serialize_newtype_variant(name, 13, "StringOwned", v),
126+
AnyValue::Int128(v) => serializer.serialize_newtype_variant(name, 5, "Int128", v),
127+
AnyValue::UInt8(v) => serializer.serialize_newtype_variant(name, 6, "UInt8", v),
128+
AnyValue::UInt16(v) => serializer.serialize_newtype_variant(name, 7, "UInt16", v),
129+
AnyValue::UInt32(v) => serializer.serialize_newtype_variant(name, 8, "UInt32", v),
130+
AnyValue::UInt64(v) => serializer.serialize_newtype_variant(name, 9, "UInt64", v),
131+
AnyValue::Float32(v) => serializer.serialize_newtype_variant(name, 10, "Float32", v),
132+
AnyValue::Float64(v) => serializer.serialize_newtype_variant(name, 11, "Float64", v),
133+
AnyValue::List(v) => serializer.serialize_newtype_variant(name, 12, "List", v),
134+
AnyValue::Boolean(v) => serializer.serialize_newtype_variant(name, 13, "Bool", v),
135+
136+
// both variants same number
137+
AnyValue::String(v) => serializer.serialize_newtype_variant(name, 14, "StringOwned", v),
136138
AnyValue::StringOwned(v) => {
137-
serializer.serialize_newtype_variant(name, 13, "StringOwned", v.as_str())
139+
serializer.serialize_newtype_variant(name, 14, "StringOwned", v.as_str())
138140
},
139-
AnyValue::Binary(v) => serializer.serialize_newtype_variant(name, 14, "BinaryOwned", v),
141+
142+
// both variants same number
143+
AnyValue::Binary(v) => serializer.serialize_newtype_variant(name, 15, "BinaryOwned", v),
140144
AnyValue::BinaryOwned(v) => {
141-
serializer.serialize_newtype_variant(name, 14, "BinaryOwned", v)
145+
serializer.serialize_newtype_variant(name, 15, "BinaryOwned", v)
146+
},
147+
148+
#[cfg(feature = "dtype-date")]
149+
AnyValue::Date(v) => serializer.serialize_newtype_variant(name, 16, "Date", v),
150+
// both variants same number
151+
#[cfg(feature = "dtype-datetime")]
152+
AnyValue::Datetime(v, tu, tz) => serializer.serialize_newtype_variant(
153+
name,
154+
17,
155+
"DatetimeOwned",
156+
&(*v, *tu, tz.map(|v| v.as_str())),
157+
),
158+
#[cfg(feature = "dtype-datetime")]
159+
AnyValue::DatetimeOwned(v, tu, tz) => serializer.serialize_newtype_variant(
160+
name,
161+
17,
162+
"DatetimeOwned",
163+
&(*v, *tu, tz.as_deref().map(|v| v.as_str())),
164+
),
165+
#[cfg(feature = "dtype-duration")]
166+
AnyValue::Duration(v, tu) => {
167+
serializer.serialize_newtype_variant(name, 18, "Duration", &(*v, *tu))
168+
},
169+
#[cfg(feature = "dtype-time")]
170+
AnyValue::Time(v) => serializer.serialize_newtype_variant(name, 19, "Time", v),
171+
172+
// Not 100% sure how to deal with these.
173+
#[cfg(feature = "dtype-categorical")]
174+
AnyValue::Categorical(..) | AnyValue::CategoricalOwned(..) => Err(
175+
serde::ser::Error::custom("Cannot serialize categorical value."),
176+
),
177+
#[cfg(feature = "dtype-categorical")]
178+
AnyValue::Enum(..) | AnyValue::EnumOwned(..) => {
179+
Err(serde::ser::Error::custom("Cannot serialize enum value."))
180+
},
181+
182+
#[cfg(feature = "dtype-array")]
183+
AnyValue::Array(v, width) => {
184+
serializer.serialize_newtype_variant(name, 22, "Array", &(v, *width))
185+
},
186+
#[cfg(feature = "object")]
187+
AnyValue::Object(_) | AnyValue::ObjectOwned(_) => {
188+
Err(serde::ser::Error::custom("Cannot serialize object value."))
189+
},
190+
#[cfg(feature = "dtype-struct")]
191+
AnyValue::Struct(_, _, _) | AnyValue::StructOwned(_) => {
192+
Err(serde::ser::Error::custom("Cannot serialize struct value."))
193+
},
194+
#[cfg(feature = "dtype-decimal")]
195+
AnyValue::Decimal(v, scale) => {
196+
serializer.serialize_newtype_variant(name, 25, "Decimal", &(*v, *scale))
142197
},
143-
_ => Err(serde::ser::Error::custom(
144-
"Unknown data type. Cannot serialize",
145-
)),
146198
}
147199
}
148200
}
@@ -153,8 +205,18 @@ impl<'a> Deserialize<'a> for AnyValue<'static> {
153205
where
154206
D: Deserializer<'a>,
155207
{
156-
#[repr(u8)]
157-
enum AvField {
208+
macro_rules! define_av_field {
209+
($($variant:ident,)+) => {
210+
#[derive(Deserialize, Serialize)]
211+
enum AvField {
212+
$($variant,)+
213+
}
214+
const VARIANTS: &'static [&'static str] = &[
215+
$(stringify!($variant),)+
216+
];
217+
};
218+
}
219+
define_av_field! {
158220
Null,
159221
Int8,
160222
Int16,
@@ -171,109 +233,17 @@ impl<'a> Deserialize<'a> for AnyValue<'static> {
171233
Bool,
172234
StringOwned,
173235
BinaryOwned,
174-
}
175-
const VARIANTS: &[&str] = &[
176-
"Null",
177-
"UInt8",
178-
"UInt16",
179-
"UInt32",
180-
"UInt64",
181-
"Int8",
182-
"Int16",
183-
"Int32",
184-
"Int64",
185-
"Int128",
186-
"Float32",
187-
"Float64",
188-
"List",
189-
"Boolean",
190-
"StringOwned",
191-
"BinaryOwned",
192-
];
193-
const LAST: u8 = unsafe { std::mem::transmute::<_, u8>(AvField::BinaryOwned) };
194-
195-
struct FieldVisitor;
196-
197-
impl Visitor<'_> for FieldVisitor {
198-
type Value = AvField;
199-
200-
fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
201-
write!(formatter, "an integer between 0-{LAST}")
202-
}
203-
204-
fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E>
205-
where
206-
E: Error,
207-
{
208-
let field: u8 = NumCast::from(v).ok_or_else(|| {
209-
serde::de::Error::invalid_value(
210-
Unexpected::Signed(v),
211-
&"expected value that fits into u8",
212-
)
213-
})?;
214-
215-
// SAFETY:
216-
// we are repr: u8 and check last value that we are in bounds
217-
let field = unsafe {
218-
if field <= LAST {
219-
std::mem::transmute::<u8, AvField>(field)
220-
} else {
221-
return Err(serde::de::Error::invalid_value(
222-
Unexpected::Signed(v),
223-
&"expected value that fits into AnyValue's number of fields",
224-
));
225-
}
226-
};
227-
Ok(field)
228-
}
229-
230-
fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
231-
where
232-
E: Error,
233-
{
234-
self.visit_bytes(v.as_bytes())
235-
}
236-
237-
fn visit_bytes<E>(self, v: &[u8]) -> std::result::Result<Self::Value, E>
238-
where
239-
E: Error,
240-
{
241-
let field = match v {
242-
b"Null" => AvField::Null,
243-
b"Int8" => AvField::Int8,
244-
b"Int16" => AvField::Int16,
245-
b"Int32" => AvField::Int32,
246-
b"Int64" => AvField::Int64,
247-
b"Int128" => AvField::Int128,
248-
b"UInt8" => AvField::UInt8,
249-
b"UInt16" => AvField::UInt16,
250-
b"UInt32" => AvField::UInt32,
251-
b"UInt64" => AvField::UInt64,
252-
b"Float32" => AvField::Float32,
253-
b"Float64" => AvField::Float64,
254-
b"List" => AvField::List,
255-
b"Bool" => AvField::Bool,
256-
b"StringOwned" | b"String" => AvField::StringOwned,
257-
b"BinaryOwned" | b"Binary" => AvField::BinaryOwned,
258-
_ => {
259-
return Err(serde::de::Error::unknown_variant(
260-
&String::from_utf8_lossy(v),
261-
VARIANTS,
262-
));
263-
},
264-
};
265-
Ok(field)
266-
}
267-
}
268-
269-
impl<'a> Deserialize<'a> for AvField {
270-
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
271-
where
272-
D: Deserializer<'a>,
273-
{
274-
deserializer.deserialize_identifier(FieldVisitor)
275-
}
276-
}
236+
Date,
237+
DatetimeOwned,
238+
Duration,
239+
Time,
240+
CategoricalOwned,
241+
EnumOwned,
242+
Array,
243+
Object,
244+
Struct,
245+
Decimal,
246+
};
277247

278248
struct OuterVisitor;
279249

@@ -350,6 +320,44 @@ impl<'a> Deserialize<'a> for AnyValue<'static> {
350320
let value = variant.newtype_variant()?;
351321
AnyValue::BinaryOwned(value)
352322
},
323+
(AvField::Date, variant) => feature_gated!("dtype-date", {
324+
let value = variant.newtype_variant()?;
325+
AnyValue::Date(value)
326+
}),
327+
(AvField::DatetimeOwned, variant) => feature_gated!("dtype-datetime", {
328+
let (value, time_unit, time_zone) = variant.newtype_variant()?;
329+
AnyValue::DatetimeOwned(value, time_unit, time_zone)
330+
}),
331+
(AvField::Duration, variant) => feature_gated!("dtype-duration", {
332+
let (value, time_unit) = variant.newtype_variant()?;
333+
AnyValue::Duration(value, time_unit)
334+
}),
335+
(AvField::Time, variant) => feature_gated!("dtype-time", {
336+
let value = variant.newtype_variant()?;
337+
AnyValue::Time(value)
338+
}),
339+
(AvField::CategoricalOwned, _) => feature_gated!("dtype-categorical", {
340+
return Err(serde::de::Error::custom(
341+
"unable to deserialize categorical",
342+
));
343+
}),
344+
(AvField::EnumOwned, _) => feature_gated!("dtype-categorical", {
345+
return Err(serde::de::Error::custom("unable to deserialize enum"));
346+
}),
347+
(AvField::Array, variant) => feature_gated!("dtype-array", {
348+
let (s, width) = variant.newtype_variant()?;
349+
AnyValue::Array(s, width)
350+
}),
351+
(AvField::Object, _) => feature_gated!("object", {
352+
return Err(serde::de::Error::custom("unable to deserialize object"));
353+
}),
354+
(AvField::Struct, _) => feature_gated!("dtype-struct", {
355+
return Err(serde::de::Error::custom("unable to deserialize struct"));
356+
}),
357+
(AvField::Decimal, variant) => feature_gated!("dtype-decimal", {
358+
let (v, scale) = variant.newtype_variant()?;
359+
AnyValue::Decimal(v, scale)
360+
}),
353361
};
354362
Ok(out)
355363
}

crates/polars-core/src/datatypes/dtype.rs

+3
Original file line number | Diff line number | Diff line change
@@ -534,6 +534,9 @@ impl DataType {
534534
pub fn is_date(&self) -> bool {
535535
matches!(self, DataType::Date)
536536
}
537+
pub fn is_datetime(&self) -> bool {
538+
matches!(self, DataType::Datetime(..))
539+
}
537540

538541
pub fn is_object(&self) -> bool {
539542
#[cfg(feature = "object")]

crates/polars-core/src/datatypes/mod.rs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ use polars_utils::nulls::IsNull;
4545
use polars_utils::total_ord::TotalHash;
4646
pub use schema::SchemaExtPl;
4747
#[cfg(feature = "serde")]
48-
use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor};
48+
use serde::de::{EnumAccess, VariantAccess, Visitor};
4949
#[cfg(any(feature = "serde", feature = "serde-lazy"))]
5050
use serde::{Deserialize, Serialize};
5151
#[cfg(any(feature = "serde", feature = "serde-lazy"))]

crates/polars-core/src/scalar/from.rs

+1
Original file line number | Diff line number | Diff line change
@@ -29,4 +29,5 @@ impl_from! {
2929
(f32, Float32, Float32)
3030
(f64, Float64, Float64)
3131
(PlSmallStr, StringOwned, String)
32+
(Vec<u8>, BinaryOwned, Binary)
3233
}

crates/polars-core/src/scalar/mod.rs

+28 -2
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,15 @@
11
mod from;
2+
mod new;
23
pub mod reduce;
34

5+
use std::hash::Hash;
6+
7+
use polars_error::PolarsResult;
48
use polars_utils::pl_str::PlSmallStr;
59
#[cfg(feature = "serde")]
610
use serde::{Deserialize, Serialize};
711

12+
use crate::chunked_array::cast::CastOptions;
813
use crate::datatypes::{AnyValue, DataType};
914
use crate::prelude::{Column, Series};
1015

@@ -15,6 +20,13 @@ pub struct Scalar {
1520
value: AnyValue<'static>,
1621
}
1722

23+
impl Hash for Scalar {
24+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
25+
self.dtype.hash(state);
26+
self.value.hash_impl(state, true);
27+
}
28+
}
29+
1830
impl Default for Scalar {
1931
fn default() -> Self {
2032
Self {
@@ -26,14 +38,28 @@ impl Default for Scalar {
2638

2739
impl Scalar {
2840
#[inline(always)]
29-
pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
41+
pub const fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
3042
Self { dtype, value }
3143
}
3244

33-
pub fn null(dtype: DataType) -> Self {
45+
pub const fn null(dtype: DataType) -> Self {
3446
Self::new(dtype, AnyValue::Null)
3547
}
3648

49+
pub fn cast_with_options(self, dtype: &DataType, options: CastOptions) -> PolarsResult<Self> {
50+
if self.dtype() == dtype {
51+
return Ok(self);
52+
}
53+
54+
// @Optimize: If we have fully fleshed out casting semantics, we could just specify the
55+
// cast on AnyValue.
56+
let s = self
57+
.into_series(PlSmallStr::from_static("scalar"))
58+
.cast_with_options(dtype, options)?;
59+
let value = s.get(0).unwrap();
60+
Ok(Self::new(s.dtype().clone(), value.into_static()))
61+
}
62+
3763
#[inline(always)]
3864
pub fn is_null(&self) -> bool {
3965
self.value.is_null()

0 commit comments

Comments
 (0)