Skip to content

Commit 158ab9a

Browse files
add timedelta type (#137)
* add timedelta type * fix lint single quotes * fix lint single quotes * fix review * merge main branch * add timedelta testcase * fix timedelta positive * fix pytimedelta_as_timedelta * remove dbg * use approx comparison for total_seconds * uprev speedate to limit durations Co-authored-by: Samuel Colvin <[email protected]>
1 parent 606da17 commit 158ab9a

File tree

13 files changed

+583
-18
lines changed

13 files changed

+583
-18
lines changed

Cargo.lock

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ enum_dispatch = "0.3.8"
1717
serde = "1.0.137"
1818
indexmap = "1.8.1"
1919
mimalloc = { version = "0.1.29", default-features = false, optional = true }
20-
speedate = "0.4.1"
20+
speedate = "0.6.0"
2121
ahash = "0.7.6"
2222
nohash-hasher = "0.2.0"
2323

pydantic_core/_types.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import sys
4-
from datetime import date, datetime, time
4+
from datetime import date, datetime, time, timedelta
55
from typing import Any, Callable, Dict, List, Type, Union
66

77
if sys.version_info < (3, 11):
@@ -211,6 +211,16 @@ class DatetimeSchema(TypedDict, total=False):
211211
ref: str
212212

213213

214+
class TimedeltaSchema(TypedDict, total=False):
215+
type: Required[Literal['timedelta']]
216+
strict: bool
217+
le: timedelta
218+
ge: timedelta
219+
lt: timedelta
220+
gt: timedelta
221+
ref: str
222+
223+
214224
class TupleFixLenSchema(TypedDict, total=False):
215225
type: Required[Literal['tuple-fix-len']]
216226
items_schema: Required[List[Schema]]
@@ -287,6 +297,7 @@ class CallableSchema(TypedDict):
287297
DateSchema,
288298
TimeSchema,
289299
DatetimeSchema,
300+
TimedeltaSchema,
290301
IsInstanceSchema,
291302
CallableSchema,
292303
]

src/errors/kinds.rs

+8
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ pub enum ErrorKind {
190190
)]
191191
DateTimeObjectInvalid { error: String },
192192
// ---------------------
193+
// timedelta errors
194+
#[strum(message = "Value must be a valid timedelta")]
195+
TimeDeltaType,
196+
#[strum(message = "Value must be a valid timedelta, {error}")]
197+
TimeDeltaParsing { error: &'static str },
198+
// ---------------------
193199
// frozenset errors
194200
#[strum(message = "Value must be a valid frozenset")]
195201
FrozenSetType,
@@ -275,6 +281,7 @@ impl ErrorKind {
275281
Self::TimeParsing { error } => render!(template, error),
276282
Self::DateTimeParsing { error } => render!(template, error),
277283
Self::DateTimeObjectInvalid { error } => render!(template, error),
284+
Self::TimeDeltaParsing { error } => render!(template, error),
278285
Self::IsInstanceOf { class } => render!(template, class),
279286
_ => template.to_string(),
280287
}
@@ -321,6 +328,7 @@ impl ErrorKind {
321328
Self::TimeParsing { error } => py_dict!(py, error),
322329
Self::DateTimeParsing { error } => py_dict!(py, error),
323330
Self::DateTimeObjectInvalid { error } => py_dict!(py, error),
331+
Self::TimeDeltaParsing { error } => py_dict!(py, error),
324332
Self::IsInstanceOf { class } => py_dict!(py, class),
325333
_ => Ok(None),
326334
}

src/input/datetime.rs

+101-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
3-
use pyo3::types::{PyDate, PyDateTime, PyDelta, PyTime, PyTzInfo};
4-
use speedate::{Date, DateTime, Time};
3+
use pyo3::types::{PyDate, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTzInfo};
4+
use speedate::{Date, DateTime, Duration, Time};
55
use strum::EnumMessage;
66

7-
use super::Input;
87
use crate::errors::{ErrorKind, ValError, ValResult};
98

9+
use super::Input;
10+
1011
pub enum EitherDate<'a> {
1112
Raw(Date),
1213
Py(&'a PyDate),
@@ -69,6 +70,73 @@ impl<'a> From<&'a PyTime> for EitherTime<'a> {
6970
}
7071
}
7172

73+
#[derive(Debug, Clone)]
74+
pub enum EitherTimedelta<'a> {
75+
Raw(Duration),
76+
Py(&'a PyDelta),
77+
}
78+
79+
impl<'a> From<Duration> for EitherTimedelta<'a> {
80+
fn from(timedelta: Duration) -> Self {
81+
Self::Raw(timedelta)
82+
}
83+
}
84+
85+
impl<'a> From<&'a PyDelta> for EitherTimedelta<'a> {
86+
fn from(timedelta: &'a PyDelta) -> Self {
87+
Self::Py(timedelta)
88+
}
89+
}
90+
91+
pub fn pytimedelta_as_timedelta(py_timedelta: &PyDelta) -> Duration {
92+
// see https://docs.python.org/3/c-api/datetime.html#c.PyDateTime_DELTA_GET_DAYS
93+
// days can be negative, but seconds and microseconds are always positive.
94+
let mut days = py_timedelta.get_days(); // -999999999 to 999999999
95+
let mut seconds = py_timedelta.get_seconds(); // 0 through 86399
96+
let mut microseconds = py_timedelta.get_microseconds(); // 0 through 999999
97+
let positive = days >= 0;
98+
if !positive {
99+
// negative timedelta, we need to adjust values to match duration logic
100+
if microseconds != 0 {
101+
seconds += 1;
102+
microseconds = (microseconds - 1_000_000).abs();
103+
}
104+
if seconds != 0 {
105+
days += 1;
106+
seconds = (seconds - 86400).abs();
107+
}
108+
days = days.abs();
109+
}
110+
// we can safely "unwrap" since the methods above guarantee values are in the correct ranges.
111+
Duration::new(positive, days as u32, seconds as u32, microseconds as u32).unwrap()
112+
}
113+
114+
impl<'a> EitherTimedelta<'a> {
115+
pub fn as_raw(&self) -> Duration {
116+
match self {
117+
Self::Raw(timedelta) => timedelta.clone(),
118+
Self::Py(py_timedelta) => pytimedelta_as_timedelta(py_timedelta),
119+
}
120+
}
121+
122+
pub fn try_into_py(self, py: Python<'_>) -> PyResult<PyObject> {
123+
let timedelta = match self {
124+
Self::Py(timedelta) => Ok(timedelta),
125+
Self::Raw(duration) => {
126+
let sign = if duration.positive { 1 } else { -1 };
127+
PyDelta::new(
128+
py,
129+
sign * duration.day as i32,
130+
sign * duration.second as i32,
131+
sign * duration.microsecond as i32,
132+
true,
133+
)
134+
}
135+
}?;
136+
Ok(timedelta.into_py(py))
137+
}
138+
}
139+
72140
macro_rules! pytime_as_time {
73141
($py_time:expr) => {
74142
speedate::Time {
@@ -287,6 +355,36 @@ pub fn float_as_time<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult
287355
int_as_time(input, timestamp.floor() as i64, microseconds.round() as u32)
288356
}
289357

358+
pub fn bytes_as_timedelta<'a, 'b>(input: &'a impl Input<'a>, bytes: &'b [u8]) -> ValResult<'a, EitherTimedelta<'a>> {
359+
match Duration::parse_bytes(bytes) {
360+
Ok(dt) => Ok(dt.into()),
361+
Err(err) => Err(ValError::new(
362+
ErrorKind::TimeDeltaParsing {
363+
error: err.get_documentation().unwrap_or_default(),
364+
},
365+
input,
366+
)),
367+
}
368+
}
369+
370+
pub fn int_as_duration(total_seconds: i64) -> Duration {
371+
let positive = total_seconds >= 0;
372+
let total_seconds = total_seconds.unsigned_abs();
373+
// we can safely unwrap here since we've guaranteed seconds and microseconds can't cause overflow
374+
let days = (total_seconds / 86400) as u32;
375+
let seconds = (total_seconds % 86400) as u32;
376+
Duration::new(positive, days, seconds, 0).unwrap()
377+
}
378+
379+
pub fn float_as_duration(total_seconds: f64) -> Duration {
380+
let positive = total_seconds >= 0_f64;
381+
let total_seconds = total_seconds.abs();
382+
let microsecond = total_seconds.fract() * 1_000_000.0;
383+
let days = (total_seconds / 86400f64) as u32;
384+
let seconds = total_seconds as u64 % 86400;
385+
Duration::new(positive, days, seconds as u32, microsecond.round() as u32).unwrap()
386+
}
387+
290388
#[pyclass(module = "pydantic_core._pydantic_core", extends = PyTzInfo)]
291389
#[derive(Debug, Clone)]
292390
struct TzInfo {

src/input/input_abstract.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use pyo3::types::PyType;
66
use crate::errors::{InputValue, LocItem, ValResult};
77
use crate::input::datetime::EitherTime;
88

9-
use super::datetime::{EitherDate, EitherDateTime};
9+
use super::datetime::{EitherDate, EitherDateTime, EitherTimedelta};
1010
use super::return_enums::{EitherBytes, EitherString};
1111
use super::{GenericMapping, GenericSequence};
1212

@@ -99,6 +99,12 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
9999
self.strict_tuple()
100100
}
101101

102+
fn strict_timedelta(&self) -> ValResult<EitherTimedelta>;
103+
104+
fn lax_timedelta(&self) -> ValResult<EitherTimedelta> {
105+
self.strict_timedelta()
106+
}
107+
102108
fn is_instance(&self, _class: &PyType) -> PyResult<bool> {
103109
Ok(false)
104110
}

src/input/input_json.rs

+23-3
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ use pyo3::types::PyType;
33
use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult};
44

55
use super::datetime::{
6-
bytes_as_date, bytes_as_datetime, bytes_as_time, float_as_datetime, float_as_time, int_as_datetime, int_as_time,
7-
EitherDate, EitherDateTime, EitherTime,
6+
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
7+
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
88
};
99
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_int};
10-
use super::{EitherBytes, EitherString, GenericMapping, GenericSequence, Input, JsonInput};
10+
use super::{EitherBytes, EitherString, EitherTimedelta, GenericMapping, GenericSequence, Input, JsonInput};
1111

1212
impl<'a> Input<'a> for JsonInput {
1313
/// This is required by since JSON object keys are always strings, I don't think it can be called
@@ -197,6 +197,22 @@ impl<'a> Input<'a> for JsonInput {
197197
_ => Err(ValError::new(ErrorKind::TupleType, self)),
198198
}
199199
}
200+
201+
fn strict_timedelta(&self) -> ValResult<EitherTimedelta> {
202+
match self {
203+
JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes()),
204+
_ => Err(ValError::new(ErrorKind::TimeDeltaType, self)),
205+
}
206+
}
207+
208+
fn lax_timedelta(&self) -> ValResult<EitherTimedelta> {
209+
match self {
210+
JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes()),
211+
JsonInput::Int(v) => Ok(int_as_duration(*v).into()),
212+
JsonInput::Float(v) => Ok(float_as_duration(*v).into()),
213+
_ => Err(ValError::new(ErrorKind::TimeDeltaType, self)),
214+
}
215+
}
200216
}
201217

202218
/// Required for Dict keys so the string can behave like an Input
@@ -298,4 +314,8 @@ impl<'a> Input<'a> for String {
298314
fn strict_tuple<'data>(&'data self) -> ValResult<GenericSequence<'data>> {
299315
Err(ValError::new(ErrorKind::TupleType, self))
300316
}
317+
318+
fn strict_timedelta(&self) -> ValResult<super::EitherTimedelta> {
319+
bytes_as_timedelta(self, self.as_bytes())
320+
}
301321
}

src/input/input_python.rs

+30-5
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,20 @@ use std::str::from_utf8;
33
use pyo3::exceptions::{PyAttributeError, PyTypeError};
44
use pyo3::prelude::*;
55
use pyo3::types::{
6-
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDict, PyFrozenSet, PyInt, PyList, PyMapping, PySequence, PySet,
7-
PyString, PyTime, PyTuple, PyType,
6+
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyList, PyMapping,
7+
PySequence, PySet, PyString, PyTime, PyTuple, PyType,
88
};
99
use pyo3::{intern, AsPyPointer};
1010

1111
use crate::errors::{as_internal, py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
1212

1313
use super::datetime::{
14-
bytes_as_date, bytes_as_datetime, bytes_as_time, date_as_datetime, float_as_datetime, float_as_time,
15-
int_as_datetime, int_as_time, EitherDate, EitherDateTime, EitherTime,
14+
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
15+
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime,
16+
EitherTime,
1617
};
1718
use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_int};
18-
use super::{repr_string, EitherBytes, EitherString, GenericMapping, GenericSequence, Input};
19+
use super::{repr_string, EitherBytes, EitherString, EitherTimedelta, GenericMapping, GenericSequence, Input};
1920

2021
impl<'a> Input<'a> for PyAny {
2122
fn as_loc_item(&'a self) -> LocItem {
@@ -391,6 +392,30 @@ impl<'a> Input<'a> for PyAny {
391392
}
392393
}
393394

395+
fn strict_timedelta(&self) -> ValResult<EitherTimedelta> {
396+
if let Ok(dt) = self.cast_as::<PyDelta>() {
397+
Ok(dt.into())
398+
} else {
399+
Err(ValError::new(ErrorKind::TimeDeltaType, self))
400+
}
401+
}
402+
403+
fn lax_timedelta(&self) -> ValResult<EitherTimedelta> {
404+
if let Ok(dt) = self.cast_as::<PyDelta>() {
405+
Ok(dt.into())
406+
} else if let Ok(py_str) = self.cast_as::<PyString>() {
407+
bytes_as_timedelta(self, py_str.to_string_lossy().as_bytes())
408+
} else if let Ok(py_bytes) = self.cast_as::<PyBytes>() {
409+
bytes_as_timedelta(self, py_bytes.as_bytes())
410+
} else if let Ok(int) = self.extract::<i64>() {
411+
Ok(int_as_duration(int).into())
412+
} else if let Ok(float) = self.extract::<f64>() {
413+
Ok(float_as_duration(float).into())
414+
} else {
415+
Err(ValError::new(ErrorKind::TimeDeltaType, self))
416+
}
417+
}
418+
394419
fn is_instance(&self, class: &PyType) -> PyResult<bool> {
395420
self.is_instance(class)
396421
}

src/input/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mod parse_json;
88
mod return_enums;
99
mod shared;
1010

11-
pub use datetime::{EitherDate, EitherDateTime, EitherTime};
11+
pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1212
pub use input_abstract::Input;
1313
pub use parse_json::{JsonInput, JsonObject};
1414
pub use return_enums::{EitherBytes, EitherString, GenericMapping, GenericSequence};

src/validators/mod.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ mod recursive;
3232
mod set;
3333
mod string;
3434
mod time;
35+
mod timedelta;
3536
mod tuple;
3637
mod typed_dict;
3738
mod union;
38-
3939
#[pyclass(module = "pydantic_core._pydantic_core")]
4040
#[derive(Debug, Clone)]
4141
pub struct SchemaValidator {
@@ -292,6 +292,8 @@ pub fn build_validator<'a>(
292292
datetime::DateTimeValidator,
293293
// frozensets
294294
frozenset::FrozenSetValidator,
295+
// timedelta
296+
timedelta::TimeDeltaValidator,
295297
// introspection types
296298
is_instance::IsInstanceValidator,
297299
callable::CallableValidator,
@@ -374,6 +376,8 @@ pub enum CombinedValidator {
374376
Datetime(datetime::DateTimeValidator),
375377
// frozensets
376378
FrozenSet(frozenset::FrozenSetValidator),
379+
// timedelta
380+
Timedelta(timedelta::TimeDeltaValidator),
377381
// introspection types
378382
IsInstance(is_instance::IsInstanceValidator),
379383
Callable(callable::CallableValidator),

0 commit comments

Comments
 (0)