
Commit 32e8f88

sungwy and Fokko authored
support PyArrow timestamptz with Etc/UTC (apache#910)
Co-authored-by: Fokko Driesprong <[email protected]>
1 parent f6d56e9 commit 32e8f88

File tree

7 files changed: +218 -86 lines changed
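Summary of the change: PyArrow can spell the UTC timezone several ways. Before this commit, only "UTC" and "+00:00" were recognized when mapping an Arrow timestamp to Iceberg's timestamptz; columns tagged "Etc/UTC" or "Z" were rejected. The new UTC_ALIASES set accepts all four spellings. A minimal standalone sketch of the rule (the maps_to_timestamptz helper is illustrative, not part of the pyiceberg API):

import pyarrow as pa

# Mirrors the UTC_ALIASES constant introduced in pyiceberg/io/pyarrow.py.
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

def maps_to_timestamptz(dtype: pa.DataType) -> bool:
    # An Arrow timestamp maps to Iceberg timestamptz when its tz
    # is any recognized spelling of UTC.
    return pa.types.is_timestamp(dtype) and dtype.tz in UTC_ALIASES

assert maps_to_timestamptz(pa.timestamp("us", tz="Etc/UTC"))
assert maps_to_timestamptz(pa.timestamp("ns", tz="Z"))
assert not maps_to_timestamptz(pa.timestamp("us"))  # tz-naive -> plain timestamp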

pyiceberg/io/pyarrow.py (+35 -16)

@@ -174,6 +174,7 @@
 MAP_KEY_NAME = "key"
 MAP_VALUE_NAME = "value"
 DOC = "doc"
+UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

 T = TypeVar("T")

@@ -937,7 +938,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
         else:
             raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

-        if primitive.tz == "UTC" or primitive.tz == "+00:00":
+        if primitive.tz in UTC_ALIASES:
             return TimestamptzType()
         elif primitive.tz is None:
             return TimestampType()

@@ -1073,7 +1074,7 @@ def _task_to_record_batches(
             arrow_table = pa.Table.from_batches([batch])
             arrow_table = arrow_table.filter(pyarrow_filter)
             batch = arrow_table.to_batches()[0]
-        yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
+        yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
         current_index += len(batch)

@@ -1278,7 +1279,7 @@ def project_batches(
         total_row_count += len(batch)


-def to_requested_schema(
+def _to_requested_schema(
     requested_schema: Schema,
     file_schema: Schema,
     batch: pa.RecordBatch,

@@ -1296,31 +1297,49 @@ def to_requested_schema(


 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
-    file_schema: Schema
+    _file_schema: Schema
     _include_field_ids: bool
+    _downcast_ns_timestamp_to_us: bool

     def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
-        self.file_schema = file_schema
+        self._file_schema = file_schema
         self._include_field_ids = include_field_ids
-        self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
+        self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
-        file_field = self.file_schema.find_field(field.field_id)
+        file_field = self._file_schema.find_field(field.field_id)

         if field.field_type.is_primitive:
             if field.field_type != file_field.field_type:
                 return values.cast(
                     schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
                 )
             elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
-                # Downcasting of nanoseconds to microseconds
-                if (
-                    pa.types.is_timestamp(target_type)
-                    and target_type.unit == "us"
-                    and pa.types.is_timestamp(values.type)
-                    and values.type.unit == "ns"
-                ):
-                    return values.cast(target_type, safe=False)
+                if field.field_type == TimestampType():
+                    # Downcasting of nanoseconds to microseconds
+                    if (
+                        pa.types.is_timestamp(target_type)
+                        and not target_type.tz
+                        and pa.types.is_timestamp(values.type)
+                        and not values.type.tz
+                    ):
+                        if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
+                            return values.cast(target_type, safe=False)
+                        elif target_type.unit == "us" and values.type.unit in {"s", "ms"}:
+                            return values.cast(target_type)
+                    raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
+                elif field.field_type == TimestamptzType():
+                    if (
+                        pa.types.is_timestamp(target_type)
+                        and target_type.tz == "UTC"
+                        and pa.types.is_timestamp(values.type)
+                        and values.type.tz in UTC_ALIASES
+                    ):
+                        if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
+                            return values.cast(target_type, safe=False)
+                        elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
+                            return values.cast(target_type)
+                    raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
         return values

     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:

@@ -1970,7 +1989,7 @@ def write_parquet(task: WriteTask) -> DataFile:

     downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
     batches = [
-        to_requested_schema(
+        _to_requested_schema(
             requested_schema=file_schema,
             file_schema=table_schema,
             batch=batch,
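The rewritten _cast_if_needed encodes two rules: coarser units (s, ms) can always be widened to microseconds, while ns is truncated to us only when the caller opted in via downcast_ns_timestamp_to_us (Arrow requires safe=False for that lossy cast). A standalone sketch of the same Arrow casts, outside pyiceberg:

import pyarrow as pa

# Widening seconds/milliseconds to microseconds never loses data, and Arrow
# treats a tz change between UTC spellings as metadata-only, so this cast
# also normalizes "Etc/UTC" to "UTC".
ms = pa.array([1_700_000_000_000], type=pa.timestamp("ms", tz="Etc/UTC"))
print(ms.cast(pa.timestamp("us", tz="UTC")))

# Nanoseconds -> microseconds truncates, so Arrow refuses it under the
# default safe=True; pyiceberg passes safe=False only when the user has
# enabled downcast_ns_timestamp_to_us.
ns = pa.array([1_700_000_000_000_000_001], type=pa.timestamp("ns", tz="UTC"))
print(ns.cast(pa.timestamp("us", tz="UTC"), safe=False))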

pyiceberg/table/__init__.py (-8)

@@ -484,10 +484,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)

         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,

@@ -545,10 +541,6 @@ def overwrite(
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)

         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
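With _to_requested_schema now normalizing units and UTC aliases inside the Parquet write path, the upfront df.cast in append/overwrite is redundant and is dropped. A hedged usage sketch; the catalog name and table identifier below are made up for illustration:

import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table names, assuming a configured catalog.
catalog = load_catalog("default")
table = catalog.load_table("demo.events")

# A timestamptz column spelled "Etc/UTC" is now accepted as-is; the write
# path casts it to microsecond-precision UTC, so no manual df.cast is needed.
df = pa.table({"ts": pa.array([1_000_000], type=pa.timestamp("us", tz="Etc/UTC"))})
table.append(df)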

tests/conftest.py (+114 -2)

@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":


 @pytest.fixture(scope="session")
-def arrow_table_date_timestamps_schema() -> Schema:
-    """Pyarrow table Schema with only date, timestamp and timestamptz values."""
+def table_date_timestamps_schema() -> Schema:
+    """Iceberg table Schema with only date, timestamp and timestamptz values."""
     return Schema(
         NestedField(field_id=1, name="date", field_type=DateType(), required=False),
         NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
         NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
     )
+
+
+@pytest.fixture(scope="session")
+def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
+    """Pyarrow Schema with all supported timestamp types."""
+    import pyarrow as pa
+
+    return pa.schema([
+        ("timestamp_s", pa.timestamp(unit="s")),
+        ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
+        ("timestamp_ms", pa.timestamp(unit="ms")),
+        ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
+        ("timestamp_us", pa.timestamp(unit="us")),
+        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ns", pa.timestamp(unit="ns")),
+        ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
+        ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
+        ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
+        ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
+    ])
+
+
+@pytest.fixture(scope="session")
+def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with all supported timestamp types."""
+    import pandas as pd
+    import pyarrow as pa
+
+    test_data = pd.DataFrame({
+        "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_s": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_ms": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_us": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_ns": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
+        ],
+        "timestamptz_ns": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_us_etc_utc": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_ns_z": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"),
+        ],
+        "timestamptz_s_0000": [
+            datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
+        ],
+    })
+    return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema":
+    """Pyarrow Schema with all microseconds timestamp."""
+    import pyarrow as pa
+
+    return pa.schema([
+        ("timestamp_s", pa.timestamp(unit="us")),
+        ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ms", pa.timestamp(unit="us")),
+        ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_us", pa.timestamp(unit="us")),
+        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ns", pa.timestamp(unit="us")),
+        ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
+    ])
+
+
+@pytest.fixture(scope="session")
+def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
+    """Iceberg table Schema with all microseconds timestamp and timestamptz values."""
+    return Schema(
+        NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False),
+        NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False),
+        NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False),
+        NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False),
+        NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
+        NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
+    )
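These fixtures pair an input table covering every unit and UTC-alias combination with the microsecond-UTC schema the write path should normalize it to. A minimal sketch of how a test might assert that pairing directly in Arrow (the test name is illustrative, not from this commit):

import pyarrow as pa

def test_all_precisions_cast_to_us(
    arrow_table_with_all_timestamp_precisions: "pa.Table",
    arrow_table_schema_with_all_microseconds_timestamp_precisions: "pa.Schema",
) -> None:
    # safe=False permits the lossy ns -> us truncation, mirroring the
    # downcast_ns_timestamp_to_us write option.
    casted = arrow_table_with_all_timestamp_precisions.cast(
        arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False
    )
    assert casted.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions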

tests/integration/test_add_files.py (+1)

@@ -570,6 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
     assert table_schema == arrow_schema_large


+@pytest.mark.integration
 def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
     nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))
tests/integration/test_writes/test_partitioned_writes.py (+7 -7)

@@ -461,15 +461,15 @@ def test_append_transform_partition_verify_partitions_count(
     session_catalog: Catalog,
     spark: SparkSession,
     arrow_table_date_timestamps: pa.Table,
-    arrow_table_date_timestamps_schema: Schema,
+    table_date_timestamps_schema: Schema,
     transform: Transform[Any, Any],
     expected_partitions: Set[Any],
     format_version: int,
 ) -> None:
     # Given
     part_col = "timestamptz"
     identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
-    nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+    nested_field = table_date_timestamps_schema.find_field(part_col)
     partition_spec = PartitionSpec(
         PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
     )

@@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count(
         properties={"format-version": str(format_version)},
         data=[arrow_table_date_timestamps],
         partition_spec=partition_spec,
-        schema=arrow_table_date_timestamps_schema,
+        schema=table_date_timestamps_schema,
     )

     # Then

@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
     session_catalog: Catalog,
     spark: SparkSession,
     arrow_table_date_timestamps: pa.Table,
-    arrow_table_date_timestamps_schema: Schema,
+    table_date_timestamps_schema: Schema,
     format_version: int,
 ) -> None:
     # Given
     identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
     partition_spec = PartitionSpec(
         PartitionField(
-            source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            source_id=table_date_timestamps_schema.find_field("date").field_id,
             field_id=1001,
             transform=YearTransform(),
             name="date_year",
         ),
         PartitionField(
-            source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
             field_id=1000,
             transform=HourTransform(),
             name="timestamptz_hour",

@@ -537,7 +537,7 @@ def test_append_multiple_partitions(
         properties={"format-version": str(format_version)},
         data=[arrow_table_date_timestamps],
         partition_spec=partition_spec,
-        schema=arrow_table_date_timestamps_schema,
+        schema=table_date_timestamps_schema,
     )

     # Then
