Skip to content

Commit e61ef57

Browse files
authored
Add include_field_ids flag in schema_to_pyarrow (#789)
* include_field_ids flag
* include_field_ids flag
1 parent 31c6c23 commit e61ef57

File tree

2 files changed

+45
-37
lines changed

2 files changed

+45
-37
lines changed

pyiceberg/io/pyarrow.py

+16-9
Original file line number | Diff line number | Diff line change
@@ -469,15 +469,18 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
469469
self.fs_by_scheme = lru_cache(self._initialize_fs)
470470

471471

472-
def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
473-
return visit(schema, _ConvertToArrowSchema(metadata))
472+
def schema_to_pyarrow(
473+
schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
474+
) -> pa.schema:
475+
return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
474476

475477

476478
class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
477479
_metadata: Dict[bytes, bytes]
478480

479-
def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT) -> None:
481+
def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None:
480482
self._metadata = metadata
483+
self._include_field_ids = include_field_ids
481484

482485
def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
483486
return pa.schema(list(struct_result), metadata=self._metadata)
@@ -486,13 +489,17 @@ def struct(self, _: StructType, field_results: List[pa.DataType]) -> pa.DataType
486489
return pa.struct(field_results)
487490

488491
def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
492+
metadata = {}
493+
if field.doc:
494+
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
495+
if self._include_field_ids:
496+
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
497+
489498
return pa.field(
490499
name=field.name,
491500
type=field_result,
492501
nullable=field.optional,
493-
metadata={PYARROW_FIELD_DOC_KEY: field.doc, PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}
494-
if field.doc
495-
else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)},
502+
metadata=metadata,
496503
)
497504

498505
def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
@@ -1130,7 +1137,7 @@ def project_table(
11301137
tables = [f.result() for f in completed_futures if f.result()]
11311138

11321139
if len(tables) < 1:
1133-
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
1140+
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
11341141

11351142
result = pa.concat_tables(tables)
11361143

@@ -1161,7 +1168,7 @@ def __init__(self, file_schema: Schema):
11611168
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
11621169
file_field = self.file_schema.find_field(field.field_id)
11631170
if field.field_type.is_primitive and field.field_type != file_field.field_type:
1164-
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
1171+
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
11651172
return values
11661173

11671174
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1188,7 +1195,7 @@ def struct(
11881195
field_arrays.append(array)
11891196
fields.append(self._construct_field(field, array.type))
11901197
elif field.optional:
1191-
arrow_type = schema_to_pyarrow(field.field_type)
1198+
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
11921199
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
11931200
fields.append(self._construct_field(field, arrow_type))
11941201
else:

tests/io/test_pyarrow.py

+29-28
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,7 @@ def test_deleting_hdfs_file_not_found() -> None:
344344
assert "Cannot delete file, does not exist:" in str(exc_info.value)
345345

346346

347-
def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
347+
def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None:
348348
actual = schema_to_pyarrow(table_schema_nested)
349349
expected = """foo: string
350350
-- field metadata --
@@ -402,6 +402,30 @@ def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
402402
assert repr(actual) == expected
403403

404404

405+
def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None:
406+
actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
407+
expected = """foo: string
408+
bar: int32 not null
409+
baz: bool
410+
qux: list<element: string not null> not null
411+
child 0, element: string not null
412+
quux: map<string, map<string, int32>> not null
413+
child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null
414+
child 0, key: string not null
415+
child 1, value: map<string, int32> not null
416+
child 0, entries: struct<key: string not null, value: int32 not null> not null
417+
child 0, key: string not null
418+
child 1, value: int32 not null
419+
location: list<element: struct<latitude: float, longitude: float> not null> not null
420+
child 0, element: struct<latitude: float, longitude: float> not null
421+
child 0, latitude: float
422+
child 1, longitude: float
423+
person: struct<name: string, age: int32 not null>
424+
child 0, name: string
425+
child 1, age: int32 not null"""
426+
assert repr(actual) == expected
427+
428+
405429
def test_fixed_type_to_pyarrow() -> None:
406430
length = 22
407431
iceberg_type = FixedType(length)
@@ -945,23 +969,13 @@ def test_projection_add_column(file_int: str) -> None:
945969
== """id: int32
946970
list: list<element: int32>
947971
child 0, element: int32
948-
-- field metadata --
949-
PARQUET:field_id: '21'
950972
map: map<int32, string>
951973
child 0, entries: struct<key: int32 not null, value: string> not null
952974
child 0, key: int32 not null
953-
-- field metadata --
954-
PARQUET:field_id: '31'
955975
child 1, value: string
956-
-- field metadata --
957-
PARQUET:field_id: '32'
958976
location: struct<lat: double, lon: double>
959977
child 0, lat: double
960-
-- field metadata --
961-
PARQUET:field_id: '41'
962-
child 1, lon: double
963-
-- field metadata --
964-
PARQUET:field_id: '42'"""
978+
child 1, lon: double"""
965979
)
966980

967981

@@ -1014,11 +1028,7 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None
10141028
== """id: map<int32, string>
10151029
child 0, entries: struct<key: int32 not null, value: string> not null
10161030
child 0, key: int32 not null
1017-
-- field metadata --
1018-
PARQUET:field_id: '3'
1019-
child 1, value: string
1020-
-- field metadata --
1021-
PARQUET:field_id: '4'"""
1031+
child 1, value: string"""
10221032
)
10231033

10241034

@@ -1062,12 +1072,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
10621072
def test_projection_filter(schema_int: Schema, file_int: str) -> None:
10631073
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
10641074
assert len(result_table.columns[0]) == 0
1065-
assert (
1066-
repr(result_table.schema)
1067-
== """id: int32
1068-
-- field metadata --
1069-
PARQUET:field_id: '1'"""
1070-
)
1075+
assert repr(result_table.schema) == """id: int32"""
10711076

10721077

10731078
def test_projection_filter_renamed_column(file_int: str) -> None:
@@ -1304,11 +1309,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
13041309
repr(result_table.schema)
13051310
== """location: struct<lat: double, long: double>
13061311
child 0, lat: double
1307-
-- field metadata --
1308-
PARQUET:field_id: '41'
1309-
child 1, long: double
1310-
-- field metadata --
1311-
PARQUET:field_id: '42'"""
1312+
child 1, long: double"""
13121313
)
13131314

13141315

0 commit comments

Comments
 (0)