
Commit dceedfa

sungwy and Fokko authored
Check if schema is compatible in add_files API (apache#907)
Co-authored-by: Fokko Driesprong <[email protected]>
1 parent aceed2a commit dceedfa
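
A quick, hedged sketch of what this change means for callers (the catalog name, table identifier, and file path below are illustrative assumptions, not taken from the diff): `add_files` now validates each Parquet file's schema against the Iceberg table schema and raises a `ValueError` describing the mismatch.

# Hedged usage sketch of the behaviour added by this commit; the catalog,
# table name, and file path are hypothetical.
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")             # assumes a configured catalog
tbl = catalog.load_table("default.my_table")  # hypothetical table

try:
    # add_files now checks the Parquet file schema against the Iceberg table schema
    tbl.add_files(file_paths=["s3://warehouse/default/my_table/data.parquet"])
except ValueError as err:
    # On mismatch, the message contains a field-by-field comparison table,
    # e.g. "Mismatch in fields: ..." or, for extra columns,
    # "PyArrow table contains more columns: ... (hint, use union_by_name)."
    print(err)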

File tree

5 files changed: +211 -164 lines changed


pyiceberg/io/pyarrow.py

+45
@@ -2032,6 +2032,49 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
     return bin_packed_record_batches
 
 
+def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
+    """
+    Check if the `table_schema` is compatible with `other_schema`.
+
+    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
+    name_mapping = table_schema.name_mapping
+    try:
+        task_schema = pyarrow_to_schema(
+            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        )
+    except ValueError as e:
+        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+        raise ValueError(
+            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
+        ) from e
+
+    if table_schema.as_struct() != task_schema.as_struct():
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        console = Console(record=True)
+
+        rich_table = RichTable(show_header=True, header_style="bold")
+        rich_table.add_column("")
+        rich_table.add_column("Table field")
+        rich_table.add_column("Dataframe field")
+
+        for lhs in table_schema.fields:
+            try:
+                rhs = task_schema.find_field(lhs.field_id)
+                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
+            except ValueError:
+                rich_table.add_row("❌", str(lhs), "Missing")
+
+        console.print(rich_table)
+        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+
+
 def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
     for file_path in file_paths:
         input_file = io.new_input(file_path)
@@ -2043,6 +2086,8 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
                 f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
             )
         schema = table_metadata.schema()
+        _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
+
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=parquet_metadata,
             stats_columns=compute_statistics_plan(schema, table_metadata.properties),
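
The hunks above add `_check_schema_compatible` to pyiceberg/io/pyarrow.py. Below is a minimal, hedged sketch of calling it directly; the schema objects are constructed ad hoc for illustration (in the codebase the helper is private and is invoked by `add_files`, `append`, and `overwrite`).

# Minimal sketch: one compatible and one incompatible schema pair.
import pyarrow as pa

from pyiceberg.io.pyarrow import _check_schema_compatible
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField, StringType

iceberg_schema = Schema(
    NestedField(1, "foo", StringType(), required=False),
    NestedField(2, "bar", IntegerType(), required=True),
)

# Compatible: names, nullability, and types line up, so this returns silently.
_check_schema_compatible(
    iceberg_schema,
    pa.schema([pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.int32(), nullable=False)]),
)

# Incompatible: "bar" is a string instead of an int, so this raises ValueError
# and prints the field-by-field Rich comparison table.
_check_schema_compatible(
    iceberg_schema,
    pa.schema([pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.string(), nullable=False)]),
)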

pyiceberg/table/__init__.py

+10 -52
@@ -73,7 +73,7 @@
     manifest_evaluator,
 )
 from pyiceberg.io import FileIO, OutputFile, load_file_io
-from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
+from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
 from pyiceberg.manifest import (
     POSITIONAL_DELETE_SCHEMA,
     DataFile,
@@ -166,54 +166,8 @@
 
 ALWAYS_TRUE = AlwaysTrue()
 TABLE_ROOT_ID = -1
-DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
 _JAVA_LONG_MAX = 9223372036854775807
-
-
-def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
-    """
-    Check if the `table_schema` is compatible with `other_schema`.
-
-    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
-
-    Raises:
-        ValueError: If the schemas are not compatible.
-    """
-    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
-
-    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-    name_mapping = table_schema.name_mapping
-    try:
-        task_schema = pyarrow_to_schema(
-            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
-        )
-    except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
-        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
-        raise ValueError(
-            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
-        ) from e
-
-    if table_schema.as_struct() != task_schema.as_struct():
-        from rich.console import Console
-        from rich.table import Table as RichTable
-
-        console = Console(record=True)
-
-        rich_table = RichTable(show_header=True, header_style="bold")
-        rich_table.add_column("")
-        rich_table.add_column("Table field")
-        rich_table.add_column("Dataframe field")
-
-        for lhs in table_schema.fields:
-            try:
-                rhs = task_schema.find_field(lhs.field_id)
-                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
-            except ValueError:
-                rich_table.add_row("❌", str(lhs), "Missing")
-
-        console.print(rich_table)
-        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
 
 
 class TableProperties:
@@ -526,8 +480,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
             raise ValueError(
                 f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
-
-        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+        _check_schema_compatible(
+            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        )
         # cast if the two schemas are compatible but not equal
         table_arrow_schema = self._table.schema().as_arrow()
         if table_arrow_schema != df.schema:
@@ -585,8 +541,10 @@ def overwrite(
             raise ValueError(
                 f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
            )
-
-        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+        _check_schema_compatible(
+            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        )
         # cast if the two schemas are compatible but not equal
         table_arrow_schema = self._table.schema().as_arrow()
         if table_arrow_schema != df.schema:

tests/integration/test_add_files.py

+65 -20
@@ -17,6 +17,7 @@
 # pylint:disable=redefined-outer-name
 
 import os
+import re
 from datetime import date
 from typing import Iterator
 
@@ -463,6 +464,57 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat
     assert summary["snapshot_prop_a"] == "test_prop_a"
 
 
+@pytest.mark.integration
+def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.table_schema_mismatch_fails_v{format_version}"
+
+    tbl = _create_table(session_catalog, identifier, format_version)
+    WRONG_SCHEMA = pa.schema([
+        ("foo", pa.bool_()),
+        ("bar", pa.string()),
+        ("baz", pa.string()),  # should be integer
+        ("qux", pa.date32()),
+    ])
+    file_path = f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet"
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=WRONG_SCHEMA) as writer:
+            writer.write_table(
+                pa.Table.from_pylist(
+                    [
+                        {
+                            "foo": True,
+                            "bar": "bar_string",
+                            "baz": "123",
+                            "qux": date(2024, 3, 7),
+                        },
+                        {
+                            "foo": True,
+                            "bar": "bar_string",
+                            "baz": "124",
+                            "qux": date(2024, 3, 7),
+                        },
+                    ],
+                    schema=WRONG_SCHEMA,
+                )
+            )
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
+│ ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
+│ ❌ │ 3: baz: optional int     │ 3: baz: optional string  │
+│ ✅ │ 4: qux: optional date    │ 4: qux: optional date    │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        tbl.add_files(file_paths=[file_path])
+
+
 @pytest.mark.integration
 def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.unpartitioned_with_large_types{format_version}"
@@ -518,7 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
     assert table_schema == arrow_schema_large
 
 
-def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
+def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
     nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))
 
     nanoseconds_schema = pa.schema([
@@ -549,25 +601,18 @@ def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_versi
         partition_spec=PartitionSpec(),
     )
 
-    file_paths = [f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet" for i in range(5)]
+    file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
     # write parquet files
-    for file_path in file_paths:
-        fo = tbl.io.new_output(file_path)
-        with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
-                writer.write_table(arrow_table)
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
+            writer.write_table(arrow_table)
 
     # add the parquet files as data files
-    tbl.add_files(file_paths=file_paths)
-
-    assert tbl.scan().to_arrow() == pa.concat_tables(
-        [
-            arrow_table.cast(
-                pa.schema([
-                    ("quux", pa.timestamp("us", tz="UTC")),
-                ]),
-                safe=False,
-            )
-        ]
-        * 5
-    )
+    with pytest.raises(
+        TypeError,
+        match=re.escape(
+            "Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
+        ),
+    ):
+        tbl.add_files(file_paths=[file_path])
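
The renamed test above asserts that `add_files` now fails on 'ns'-precision timestamps unless downcasting is enabled. Below is a hedged sketch of enabling the `downcast-ns-timestamp-to-us-on-write` property; the environment-variable spelling is an assumption based on how pyiceberg's Config maps `PYICEBERG_*` variables to dashed property names, and the property can equally be set in `~/.pyiceberg.yaml`.

import os

# Assumed env-var spelling for the "downcast-ns-timestamp-to-us-on-write" property;
# set it before pyiceberg reads its configuration.
os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "true"

# With the property unset (as in test_add_files_with_timestamp_tz_ns_fails),
# add_files on a Parquet file written with pa.timestamp("ns", tz="UTC") raises:
#   TypeError: Iceberg does not yet support 'ns' timestamp precision. Use
#   'downcast-ns-timestamp-to-us-on-write' configuration property to automatically
#   downcast 'ns' to 'us' on write.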

tests/io/test_pyarrow.py

+91
@@ -60,6 +60,7 @@
     PyArrowFile,
     PyArrowFileIO,
     StatsAggregator,
+    _check_schema_compatible,
     _ConvertToArrowSchema,
     _determine_partitions,
     _primitive_to_physical,
@@ -1722,6 +1723,96 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
     assert len(list(bin_packed)) == 5
 
 
+def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.decimal128(18, 6), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = r"""Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field                 ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
+│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
+└────┴──────────────────────────┴─────────────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
+│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
+│ ❌ │ 2: bar: required int     │ Missing                  │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+        pa.field("new_field", pa.date32(), nullable=True),
+    ))
+
+    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
+
+
 def test_partition_for_demo() -> None:
     test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
     test_schema = Schema(

0 commit comments