Commit 1b9b884

Fokko and sungwy authored
PyArrow: Don't enforce the schema when reading/writing (apache#902)
* PyArrow: Don't enforce the schema

  PyIceberg struggled with the different Arrow types, such as `string` and `large_string`. They represent the same data but are different under the hood. My take is that we should hide these kinds of details from the user as much as possible. Previously we went down the road of passing the Iceberg schema into Arrow, but when doing this, Iceberg has to decide whether it is a large or non-large type. This PR removes passing down the schema in order to let Arrow decide, unless:

  - The type should be evolved
  - In case of re-ordering, we reorder the original types

* WIP
* Reuse Table schema
* Make linter happy
* Squash some bugs
* Thanks Sung!

  Co-authored-by: Sung Yun <[email protected]>

* Moar code moar bugs
* Remove the variables wrt file sizes
* Linting
* Go with large ones for now
* Missed one there!

---------

Co-authored-by: Sung Yun <[email protected]>
1 parent 8f47dfd commit 1b9b884
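
For context on the `string` / `large_string` distinction the commit message refers to: both Arrow types hold the same values, but use 32-bit and 64-bit offsets respectively, so their types and schemas compare as unequal. A minimal PyArrow sketch, illustrative only and not part of this commit:

import pyarrow as pa

# Same values, different physical layout: `string` uses 32-bit offsets,
# `large_string` uses 64-bit offsets, so the types are not considered equal.
small = pa.array(["a", "b"], type=pa.string())
large = pa.array(["a", "b"], type=pa.large_string())
print(small.type == large.type)                      # False
print(small.cast(pa.large_string()).equals(large))   # True once cast to the large variant

# The same applies to schemas built from those types.
print(pa.schema([("foo", pa.string())]) == pa.schema([("foo", pa.large_string())]))  # False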

7 files changed (+156, -58 lines)


pyiceberg/io/pyarrow.py (+45, -28)
@@ -1047,8 +1047,10 @@ def _task_to_record_batches(
 
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
-        # We always use large types in memory as it uses larger offsets
-        # That can chunk more row values into the buffers
+        # With PyArrow 16.0.0 there is an issue with casting record-batches:
+        # https://github.com/apache/arrow/issues/41884
+        # https://github.com/apache/arrow/issues/43183
+        # Would be good to remove this later on
         schema=_pyarrow_schema_ensure_large_types(physical_schema),
         # This will push down the query to Arrow.
         # But in case there are positional deletes, we have to apply them first
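
Note that the scanner above still pins the fragment schema to large types via `_pyarrow_schema_ensure_large_types`, as a workaround for the PyArrow 16.0.0 casting issues linked in the comment. Roughly, such a helper rewrites variable-length types to their 64-bit-offset variants; the sketch below is a simplified, hypothetical stand-in, not the PyIceberg implementation (which also walks nested types):

import pyarrow as pa


def ensure_large_types(schema: pa.Schema) -> pa.Schema:
    # Simplified: only rewrites top-level string/binary fields; the real helper in
    # pyiceberg.io.pyarrow also recurses into lists, maps and structs.
    fields = []
    for field in schema:
        if pa.types.is_string(field.type):
            field = field.with_type(pa.large_string())
        elif pa.types.is_binary(field.type):
            field = field.with_type(pa.large_binary())
        fields.append(field)
    return pa.schema(fields)


print(ensure_large_types(pa.schema([("foo", pa.string()), ("n", pa.int32())])))
# foo: large_string
# n: int32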
@@ -1084,11 +1086,17 @@ def _task_to_table(
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
     name_mapping: Optional[NameMapping] = None,
-) -> pa.Table:
-    batches = _task_to_record_batches(
-        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+) -> Optional[pa.Table]:
+    batches = list(
+        _task_to_record_batches(
+            fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+        )
     )
-    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
+
+    if len(batches) > 0:
+        return pa.Table.from_batches(batches)
+    else:
+        return None
 
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1192,7 +1200,7 @@ def project_table(
     if len(tables) < 1:
         return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
 
-    result = pa.concat_tables(tables)
+    result = pa.concat_tables(tables, promote_options="permissive")
 
     if limit is not None:
        return result.slice(0, limit)
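
Since the per-task tables are no longer forced into one Arrow schema up front, `project_table` now asks `pa.concat_tables` to reconcile them. A small illustration of the intent of `promote_options="permissive"` (assumes a PyArrow version that supports the option, roughly 14.0+):

import pyarrow as pa

t1 = pa.table({"foo": pa.array(["a"], type=pa.string())})
t2 = pa.table({"foo": pa.array(["b"], type=pa.large_string())})

# The default (promote_options="none") requires identical schemas and would raise here;
# permissive promotion unifies string / large_string into a common type instead.
combined = pa.concat_tables([t1, t2], promote_options="permissive")
print(combined.schema)
print(combined.num_rows)  # 2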
@@ -1271,54 +1279,62 @@ def project_batches(
 
 
 def to_requested_schema(
-    requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False
+    requested_schema: Schema,
+    file_schema: Schema,
+    batch: pa.RecordBatch,
+    downcast_ns_timestamp_to_us: bool = False,
+    include_field_ids: bool = False,
 ) -> pa.RecordBatch:
+    # We could re-use some of these visitors
     struct_array = visit_with_partner(
-        requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us), ArrowAccessor(file_schema)
+        requested_schema,
+        batch,
+        ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
+        ArrowAccessor(file_schema),
     )
-
-    arrays = []
-    fields = []
-    for pos, field in enumerate(requested_schema.fields):
-        array = struct_array.field(pos)
-        arrays.append(array)
-        fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_struct_array(struct_array)
 
 
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
     file_schema: Schema
+    _include_field_ids: bool
 
-    def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False):
+    def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
         self.file_schema = file_schema
+        self._include_field_ids = include_field_ids
         self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
+
         if field.field_type.is_primitive:
             if field.field_type != file_field.field_type:
-                return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
-            elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
-                # if file_field and field_type (e.g. String) are the same
-                # but the pyarrow type of the array is different from the expected type
-                # (e.g. string vs larger_string), we want to cast the array to the larger type
-                safe = True
+                return values.cast(
+                    schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
+                )
+            elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
+                # Downcasting of nanoseconds to microseconds
                 if (
                     pa.types.is_timestamp(target_type)
                     and target_type.unit == "us"
                     and pa.types.is_timestamp(values.type)
                     and values.type.unit == "ns"
                 ):
-                    safe = False
-                return values.cast(target_type, safe=safe)
+                    return values.cast(target_type, safe=False)
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
+        metadata = {}
+        if field.doc:
+            metadata[PYARROW_FIELD_DOC_KEY] = field.doc
+        if self._include_field_ids:
+            metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
+
         return pa.field(
             name=field.name,
             type=arrow_type,
             nullable=field.optional,
-            metadata={DOC: field.doc} if field.doc is not None else None,
+            metadata=metadata,
         )
 
     def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
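
The rewritten `to_requested_schema` returns the projected struct array directly as a record batch instead of re-assembling columns and fields by hand. A toy sketch of the `pa.RecordBatch.from_struct_array` call it now relies on (example data only):

import pyarrow as pa

# The child fields of the struct array become the columns of the record batch,
# carrying their names, types and nullability along.
struct_array = pa.StructArray.from_arrays(
    [pa.array([1, 2]), pa.array(["a", "b"])],
    fields=[
        pa.field("id", pa.int64(), nullable=False),
        pa.field("name", pa.string()),
    ],
)
batch = pa.RecordBatch.from_struct_array(struct_array)
print(batch.schema)    # id: int64 not null, name: string
print(batch.num_rows)  # 2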
@@ -1960,14 +1976,15 @@ def write_parquet(task: WriteTask) -> DataFile:
                 file_schema=table_schema,
                 batch=batch,
                 downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
+                include_field_ids=True,
             )
             for batch in task.record_batches
         ]
         arrow_table = pa.Table.from_batches(batches)
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
+            with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:
                 writer.write(arrow_table, row_group_size=row_group_size)
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=writer.writer.metadata,
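
On the write path, batches are now produced with `include_field_ids=True`, so `arrow_table.schema` already carries the Iceberg field IDs as field metadata and can be handed straight to `pq.ParquetWriter`. A hedged sketch of that idea; the metadata key follows the usual `PARQUET:field_id` convention, but the exact constant (`PYARROW_PARQUET_FIELD_ID_KEY`) lives in `pyiceberg.io.pyarrow`, and the output path here is just an example:

import pyarrow as pa
import pyarrow.parquet as pq

# Assumed key name; PyIceberg uses its own constant for this metadata entry.
FIELD_ID_KEY = "PARQUET:field_id"

schema = pa.schema([pa.field("foo", pa.large_string(), nullable=False, metadata={FIELD_ID_KEY: "1"})])
table = pa.Table.from_pylist([{"foo": "normal"}], schema=schema)

# Writing with the table's own schema keeps the field metadata and the exact
# string/large_string choice, instead of re-deriving a file schema separately.
with pq.ParquetWriter("/tmp/example.parquet", schema=table.schema) as writer:
    writer.write_table(table)

# The field id should round-trip via the Arrow schema stored in the Parquet file.
print(pq.read_schema("/tmp/example.parquet").field("foo").metadata)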

pyiceberg/table/__init__.py (+2, -1)
@@ -2053,8 +2053,9 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
 
         from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow
 
+        target_schema = schema_to_pyarrow(self.projection())
         return pa.RecordBatchReader.from_batches(
-            schema_to_pyarrow(self.projection()),
+            target_schema,
             project_batches(
                 self.plan_files(),
                 self.table_metadata,

tests/integration/test_add_files.py (+72, -8)
@@ -18,7 +18,7 @@
 
 import os
 from datetime import date
-from typing import Iterator, Optional
+from typing import Iterator
 
 import pyarrow as pa
 import pyarrow.parquet as pq
@@ -28,7 +28,8 @@
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
-from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.io import FileIO
+from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
 from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform
@@ -107,23 +108,32 @@
 )
 
 
+def _write_parquet(io: FileIO, file_path: str, arrow_schema: pa.Schema, arrow_table: pa.Table) -> None:
+    fo = io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=arrow_schema) as writer:
+            writer.write_table(arrow_table)
+
+
 def _create_table(
-    session_catalog: Catalog, identifier: str, format_version: int, partition_spec: Optional[PartitionSpec] = None
+    session_catalog: Catalog,
+    identifier: str,
+    format_version: int,
+    partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+    schema: Schema = TABLE_SCHEMA,
 ) -> Table:
     try:
         session_catalog.drop_table(identifier=identifier)
     except NoSuchTableError:
         pass
 
-    tbl = session_catalog.create_table(
+    return session_catalog.create_table(
         identifier=identifier,
-        schema=TABLE_SCHEMA,
+        schema=schema,
         properties={"format-version": str(format_version)},
-        partition_spec=partition_spec if partition_spec else PartitionSpec(),
+        partition_spec=partition_spec,
     )
 
-    return tbl
-
 
 @pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
 def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
@@ -454,6 +464,60 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat
 
 
 @pytest.mark.integration
+def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.unpartitioned_with_large_types{format_version}"
+
+    iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=True))
+    arrow_schema = pa.schema([
+        pa.field("foo", pa.string(), nullable=False),
+    ])
+    arrow_schema_large = pa.schema([
+        pa.field("foo", pa.large_string(), nullable=False),
+    ])
+
+    tbl = _create_table(session_catalog, identifier, format_version, schema=iceberg_schema)
+
+    file_path = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-0.parquet"
+    _write_parquet(
+        tbl.io,
+        file_path,
+        arrow_schema,
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "normal",
+                }
+            ],
+            schema=arrow_schema,
+        ),
+    )
+
+    tbl.add_files([file_path])
+
+    table_schema = tbl.scan().to_arrow().schema
+    assert table_schema == arrow_schema_large
+
+    file_path_large = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-1.parquet"
+    _write_parquet(
+        tbl.io,
+        file_path_large,
+        arrow_schema_large,
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "normal",
+                }
+            ],
+            schema=arrow_schema_large,
+        ),
+    )
+
+    tbl.add_files([file_path_large])
+
+    table_schema = tbl.scan().to_arrow().schema
+    assert table_schema == arrow_schema_large
+
+
 def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
     nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

tests/integration/test_deletes.py (+1, -1)
@@ -291,7 +291,7 @@ def test_partitioned_table_positional_deletes_sequence_number(spark: SparkSessio
     assert snapshots[2].summary == Summary(
         Operation.OVERWRITE,
         **{
-            "added-files-size": "1145",
+            "added-files-size": snapshots[2].summary["total-files-size"],
             "added-data-files": "1",
             "added-records": "2",
             "changed-partition-count": "1",

tests/integration/test_inspect_table.py (+6, -3)
@@ -110,22 +110,25 @@ def test_inspect_snapshots(
     for manifest_list in df["manifest_list"]:
         assert manifest_list.as_py().startswith("s3://")
 
+    file_size = int(next(value for key, value in df["summary"][0].as_py() if key == "added-files-size"))
+    assert file_size > 0
+
     # Append
     assert df["summary"][0].as_py() == [
-        ("added-files-size", "5459"),
+        ("added-files-size", str(file_size)),
         ("added-data-files", "1"),
         ("added-records", "3"),
         ("total-data-files", "1"),
         ("total-delete-files", "0"),
         ("total-records", "3"),
-        ("total-files-size", "5459"),
+        ("total-files-size", str(file_size)),
         ("total-position-deletes", "0"),
         ("total-equality-deletes", "0"),
     ]
 
     # Delete
     assert df["summary"][1].as_py() == [
-        ("removed-files-size", "5459"),
+        ("removed-files-size", str(file_size)),
         ("deleted-data-files", "1"),
         ("deleted-records", "3"),
         ("total-data-files", "0"),

tests/integration/test_writes/test_partitioned_writes.py (+8, -4)
@@ -252,28 +252,32 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro
     assert operations == ["append", "append"]
 
     summaries = [row.summary for row in rows]
+
+    file_size = int(summaries[0]["added-files-size"])
+    assert file_size > 0
+
     assert summaries[0] == {
         "changed-partition-count": "3",
         "added-data-files": "3",
-        "added-files-size": "15029",
+        "added-files-size": str(file_size),
         "added-records": "3",
         "total-data-files": "3",
         "total-delete-files": "0",
         "total-equality-deletes": "0",
-        "total-files-size": "15029",
+        "total-files-size": str(file_size),
        "total-position-deletes": "0",
        "total-records": "3",
     }
 
     assert summaries[1] == {
         "changed-partition-count": "3",
         "added-data-files": "3",
-        "added-files-size": "15029",
+        "added-files-size": str(file_size),
         "added-records": "3",
         "total-data-files": "6",
         "total-delete-files": "0",
         "total-equality-deletes": "0",
-        "total-files-size": "30058",
+        "total-files-size": str(file_size * 2),
         "total-position-deletes": "0",
         "total-records": "6",
     }
