
Commit 301e336

Cast 's', 'ms' and 'ns' PyArrow timestamp to 'us' precision on write (#848)
1 parent 3f574d3 commit 301e336

7 files changed: +224 -45 lines changed


mkdocs/docs/configuration.md

+5 -1
@@ -305,4 +305,8 @@ PyIceberg uses multiple threads to parallelize operations. The number of workers
 
 # Backward Compatibility
 
-Previous versions of Java (`<1.4.0`) implementations incorrectly assume the optional attribute `current-snapshot-id` to be a required attribute in TableMetadata. This means that if `current-snapshot-id` is missing in the metadata file (e.g. on table creation), the application will throw an exception without being able to load the table. This assumption has been corrected in more recent Iceberg versions. However, it is possible to force PyIceberg to create a table with a metadata file that will be compatible with previous versions. This can be configured by setting the `legacy-current-snapshot-id` entry as "True" in the configuration file, or by setting the `PYICEBERG_LEGACY_CURRENT_SNAPSHOT_ID` environment variable. Refer to the [PR discussion](https://github.com/apache/iceberg-python/pull/473) for more details on the issue
+Previous versions of Java (`<1.4.0`) implementations incorrectly assume the optional attribute `current-snapshot-id` to be a required attribute in TableMetadata. This means that if `current-snapshot-id` is missing in the metadata file (e.g. on table creation), the application will throw an exception without being able to load the table. This assumption has been corrected in more recent Iceberg versions. However, it is possible to force PyIceberg to create a table with a metadata file that will be compatible with previous versions. This can be configured by setting the `legacy-current-snapshot-id` property as "True" in the configuration file, or by setting the `PYICEBERG_LEGACY_CURRENT_SNAPSHOT_ID` environment variable. Refer to the [PR discussion](https://github.com/apache/iceberg-python/pull/473) for more details on the issue
+
+# Nanoseconds Support
+
+PyIceberg currently only supports up to microsecond precision in its TimestampType. PyArrow timestamp types in 's' and 'ms' will be upcast automatically to 'us' precision timestamps on write. Timestamps in 'ns' precision can also be downcast automatically on write if desired. This can be configured by setting the `downcast-ns-timestamp-to-us-on-write` property as "True" in the configuration file, or by setting the `PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE` environment variable. Refer to the [nanoseconds timestamp proposal document](https://docs.google.com/document/d/1bE1DcEGNzZAMiVJSZ0X1wElKLNkT9kRkk0hDlfkXzvU/edit#heading=h.ibflcctc9i1d) for more details on the long term roadmap for nanoseconds support
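
To make the documented behaviour concrete, here is a minimal write-side sketch (not part of the commit; the catalog, table, and column names are hypothetical, and it assumes a catalog is already configured). 's' and 'ms' timestamps are upcast silently; 'ns' is only accepted because the flag is set first:

```python
import os

# Assumption: enable the documented flag before PyIceberg reads its configuration.
os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "True"

import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")           # hypothetical catalog name
tbl = catalog.load_table("default.events")  # hypothetical table with a timestamptz column

# 'ns' values are truncated to microseconds on write because the flag above is set;
# 's' and 'ms' values would be upcast to 'us' automatically with no flag required.
data = pa.table({"event_ts": pa.array([1_615_967_687_249_846_175], type=pa.timestamp("ns", tz="UTC"))})
tbl.append(data)
```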

pyiceberg/catalog/__init__.py

+5 -1
@@ -49,6 +49,7 @@
 from pyiceberg.schema import Schema
 from pyiceberg.serializers import ToOutputFile
 from pyiceberg.table import (
+    DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE,
     CommitTableRequest,
     CommitTableResponse,
     CreateTableTransaction,
@@ -675,8 +676,11 @@ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
 
         from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow
 
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         if isinstance(schema, pa.Schema):
-            schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())  # type: ignore
+            schema: Schema = visit_pyarrow(  # type: ignore
+                schema, _ConvertToIcebergWithoutIDs(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+            )
         return schema
     except ModuleNotFoundError:
         pass
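
The change above threads the flag into schema conversion when a PyArrow schema is passed to `create_table`. A small sketch of the converted path, using only names that appear in this diff (the example Arrow schema itself is made up):

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

# Hypothetical Arrow schema with a nanosecond timestamp column.
arrow_schema = pa.schema([("event_time", pa.timestamp("ns", tz="UTC"))])

# Without downcast_ns_timestamp_to_us=True this raises TypeError (see the
# _ConvertToIceberg.primitive change below); with it, the field converts to an
# Iceberg timestamptz, which is microsecond precision.
iceberg_schema = visit_pyarrow(arrow_schema, _ConvertToIcebergWithoutIDs(downcast_ns_timestamp_to_us=True))
print(iceberg_schema)
```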

pyiceberg/io/pyarrow.py

+62 -21
@@ -154,6 +154,7 @@
     UUIDType,
 )
 from pyiceberg.utils.concurrent import ExecutorFactory
+from pyiceberg.utils.config import Config
 from pyiceberg.utils.datetime import millis_to_datetime
 from pyiceberg.utils.singleton import Singleton
 from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
@@ -470,7 +471,9 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
 
 
 def schema_to_pyarrow(
-    schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
+    schema: Union[Schema, IcebergType],
+    metadata: Dict[bytes, bytes] = EMPTY_DICT,
+    include_field_ids: bool = True,
 ) -> pa.schema:
     return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
 
@@ -663,21 +666,23 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start
     return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)
 
 
-def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
+def pyarrow_to_schema(
+    schema: pa.Schema, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False
+) -> Schema:
     has_ids = visit_pyarrow(schema, _HasIds())
     if has_ids:
-        visitor = _ConvertToIceberg()
+        visitor = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
     elif name_mapping is not None:
-        visitor = _ConvertToIceberg(name_mapping=name_mapping)
+        visitor = _ConvertToIceberg(name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
     else:
         raise ValueError(
             "Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
         )
     return visit_pyarrow(schema, visitor)
 
 
-def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
-    return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
+def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
+    return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us))
 
 
 def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
@@ -849,9 +854,10 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
     _field_names: List[str]
     _name_mapping: Optional[NameMapping]
 
-    def __init__(self, name_mapping: Optional[NameMapping] = None) -> None:
+    def __init__(self, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False) -> None:
         self._field_names = []
         self._name_mapping = name_mapping
+        self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _field_id(self, field: pa.Field) -> int:
         if self._name_mapping:
@@ -918,11 +924,24 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
             return TimeType()
         elif pa.types.is_timestamp(primitive):
             primitive = cast(pa.TimestampType, primitive)
-            if primitive.unit == "us":
-                if primitive.tz == "UTC" or primitive.tz == "+00:00":
-                    return TimestamptzType()
-                elif primitive.tz is None:
-                    return TimestampType()
+            if primitive.unit in ("s", "ms", "us"):
+                # Supported types, will be upcast automatically to 'us'
+                pass
+            elif primitive.unit == "ns":
+                if self._downcast_ns_timestamp_to_us:
+                    logger.warning("Iceberg does not yet support 'ns' timestamp precision. Downcasting to 'us'.")
+                else:
+                    raise TypeError(
+                        "Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
+                    )
+            else:
+                raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")
+
+            if primitive.tz == "UTC" or primitive.tz == "+00:00":
+                return TimestamptzType()
+            elif primitive.tz is None:
+                return TimestampType()
+
         elif pa.types.is_binary(primitive) or pa.types.is_large_binary(primitive):
             return BinaryType()
         elif pa.types.is_fixed_size_binary(primitive):
@@ -1010,8 +1029,11 @@ def _task_to_record_batches(
     with fs.open_input_file(path) as fin:
         fragment = arrow_format.make_fragment(fin)
         physical_schema = fragment.physical_schema
-        file_schema = pyarrow_to_schema(physical_schema, name_mapping)
-
+        # In V1 and V2 table formats, we only support Timestamp 'us' in Iceberg Schema
+        # Hence it is reasonable to always cast 'ns' timestamp to 'us' on read.
+        # When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
+        # the table format version.
+        file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
         pyarrow_filter = None
         if bound_row_filter is not AlwaysTrue():
             translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
@@ -1049,7 +1071,7 @@ def _task_to_record_batches(
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
                     batch = arrow_table.to_batches()[0]
-            yield to_requested_schema(projected_schema, file_project_schema, batch)
+            yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
             current_index += len(batch)
 
 
@@ -1248,8 +1270,12 @@ def project_batches(
             total_row_count += len(batch)
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
-    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def to_requested_schema(
+    requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False
+) -> pa.RecordBatch:
+    struct_array = visit_with_partner(
+        requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us), ArrowAccessor(file_schema)
+    )
 
     arrays = []
     fields = []
@@ -1263,8 +1289,9 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
     file_schema: Schema
 
-    def __init__(self, file_schema: Schema):
+    def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False):
         self.file_schema = file_schema
+        self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
@@ -1275,7 +1302,15 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
                 # if file_field and field_type (e.g. String) are the same
                 # but the pyarrow type of the array is different from the expected type
                 # (e.g. string vs larger_string), we want to cast the array to the larger type
-                return values.cast(target_type)
+                safe = True
+                if (
+                    pa.types.is_timestamp(target_type)
+                    and target_type.unit == "us"
+                    and pa.types.is_timestamp(values.type)
+                    and values.type.unit == "ns"
+                ):
+                    safe = False
+                return values.cast(target_type, safe=safe)
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1899,7 +1934,7 @@ def data_file_statistics_from_parquet_metadata(
 
 
 def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
-    from pyiceberg.table import PropertyUtil, TableProperties
+    from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties
 
     parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
     row_group_size = PropertyUtil.property_as_int(
@@ -1918,8 +1953,14 @@ def write_parquet(task: WriteTask) -> DataFile:
         else:
             file_schema = table_schema
 
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         batches = [
-            to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+            to_requested_schema(
+                requested_schema=file_schema,
+                file_schema=table_schema,
+                batch=batch,
+                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
+            )
             for batch in task.record_batches
         ]
         arrow_table = pa.Table.from_batches(batches)
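
The `safe=False` in `_cast_if_needed` is what actually allows the lossy 'ns' to 'us' cast at projection time. A standalone PyArrow illustration of why the default safe cast is not enough (not part of the commit):

```python
import pyarrow as pa

ns_array = pa.array([1_615_967_687_249_846_175], type=pa.timestamp("ns", tz="UTC"))
us_type = pa.timestamp("us", tz="UTC")

try:
    ns_array.cast(us_type)  # safe cast (default) refuses to drop sub-microsecond digits
except pa.ArrowInvalid as err:
    print(f"safe cast rejected: {err}")

print(ns_array.cast(us_type, safe=False))  # truncates to 2021-03-17 07:54:47.249846+00:00
```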

pyiceberg/table/__init__.py

+7 -3
@@ -146,6 +146,7 @@
     transform_dict_value_to_str,
 )
 from pyiceberg.utils.concurrent import ExecutorFactory
+from pyiceberg.utils.config import Config
 from pyiceberg.utils.datetime import datetime_to_millis
 from pyiceberg.utils.deprecated import deprecated
 from pyiceberg.utils.singleton import _convert_to_hashable_type
@@ -161,7 +162,7 @@
 
 ALWAYS_TRUE = AlwaysTrue()
 TABLE_ROOT_ID = -1
-
+DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
 _JAVA_LONG_MAX = 9223372036854775807
 
 
@@ -176,11 +177,14 @@ def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") ->
     """
     from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
 
+    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
     name_mapping = table_schema.name_mapping
     try:
-        task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
+        task_schema = pyarrow_to_schema(
+            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        )
     except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema)
+        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
         additional_names = set(other_schema.column_names) - set(table_schema.column_names)
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."

tests/integration/test_add_files.py

+59
@@ -16,13 +16,15 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 
+import os
 from datetime import date
 from typing import Iterator, Optional
 
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
 from pyspark.sql import SparkSession
+from pytest_mock.plugin import MockerFixture
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
@@ -36,6 +38,7 @@
     IntegerType,
     NestedField,
     StringType,
+    TimestamptzType,
 )
 
 TABLE_SCHEMA = Schema(
@@ -448,3 +451,59 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat
 
     assert "snapshot_prop_a" in summary
     assert summary["snapshot_prop_a"] == "test_prop_a"
+
+
+@pytest.mark.integration
+def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
+    nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))
+
+    nanoseconds_schema = pa.schema([
+        ("quux", pa.timestamp("ns", tz="UTC")),
+    ])
+
+    arrow_table = pa.Table.from_pylist(
+        [
+            {
+                "quux": 1615967687249846175,  # 2021-03-17 07:54:47.249846175
+            }
+        ],
+        schema=nanoseconds_schema,
+    )
+    mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})
+
+    identifier = f"default.timestamptz_ns_added{format_version}"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=nanoseconds_schema_iceberg,
+        properties={"format-version": str(format_version)},
+        partition_spec=PartitionSpec(),
+    )
+
+    file_paths = [f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet" for i in range(5)]
+    # write parquet files
+    for file_path in file_paths:
+        fo = tbl.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
+                writer.write_table(arrow_table)
+
+    # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths)
+
+    assert tbl.scan().to_arrow() == pa.concat_tables(
+        [
+            arrow_table.cast(
+                pa.schema([
+                    ("quux", pa.timestamp("us", tz="UTC")),
+                ]),
+                safe=False,
+            )
+        ]
+        * 5
+    )
