Skip to content

include_field_ids flag in schema_to_pyarrow #789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,15 +469,18 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.fs_by_scheme = lru_cache(self._initialize_fs)


def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata))
def schema_to_pyarrow(
schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))


class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
_metadata: Dict[bytes, bytes]

def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT) -> None:
def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None:
self._metadata = metadata
self._include_field_ids = include_field_ids

def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
return pa.schema(list(struct_result), metadata=self._metadata)
Expand All @@ -486,13 +489,17 @@ def struct(self, _: StructType, field_results: List[pa.DataType]) -> pa.DataType
return pa.struct(field_results)

def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
metadata = {}
if field.doc:
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
if self._include_field_ids:
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)

return pa.field(
name=field.name,
type=field_result,
nullable=field.optional,
metadata={PYARROW_FIELD_DOC_KEY: field.doc, PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}
if field.doc
else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)},
metadata=metadata,
)

def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
Expand Down Expand Up @@ -1130,7 +1137,7 @@ def project_table(
tables = [f.result() for f in completed_futures if f.result()]

if len(tables) < 1:
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema))
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

result = pa.concat_tables(tables)

Expand Down Expand Up @@ -1161,7 +1168,7 @@ def __init__(self, file_schema: Schema):
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
if field.field_type.is_primitive and field.field_type != file_field.field_type:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
Expand All @@ -1188,7 +1195,7 @@ def struct(
field_arrays.append(array)
fields.append(self._construct_field(field, array.type))
elif field.optional:
arrow_type = schema_to_pyarrow(field.field_type)
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
fields.append(self._construct_field(field, arrow_type))
else:
Expand Down
57 changes: 29 additions & 28 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test_deleting_hdfs_file_not_found() -> None:
assert "Cannot delete file, does not exist:" in str(exc_info.value)


def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None:
actual = schema_to_pyarrow(table_schema_nested)
expected = """foo: string
-- field metadata --
Expand Down Expand Up @@ -402,6 +402,30 @@ def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
assert repr(actual) == expected


def test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None:
    """Converting an Iceberg schema with ``include_field_ids=False`` must
    yield a pyarrow schema whose fields carry no metadata.

    Counterpart of the ``include_field_ids=True`` test: the expected repr
    below contains no ``-- field metadata --`` / ``PARQUET:field_id``
    sections on any field, at any nesting level.
    """
    actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False)
    # NOTE(review): the expected text is the exact repr() of the converted
    # pyarrow schema for the nested fixture — structure/nullability only,
    # no per-field metadata entries.
    expected = """foo: string
bar: int32 not null
baz: bool
qux: list<element: string not null> not null
child 0, element: string not null
quux: map<string, map<string, int32>> not null
child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null
child 0, key: string not null
child 1, value: map<string, int32> not null
child 0, entries: struct<key: string not null, value: int32 not null> not null
child 0, key: string not null
child 1, value: int32 not null
location: list<element: struct<latitude: float, longitude: float> not null> not null
child 0, element: struct<latitude: float, longitude: float> not null
child 0, latitude: float
child 1, longitude: float
person: struct<name: string, age: int32 not null>
child 0, name: string
child 1, age: int32 not null"""
    assert repr(actual) == expected


def test_fixed_type_to_pyarrow() -> None:
length = 22
iceberg_type = FixedType(length)
Expand Down Expand Up @@ -945,23 +969,13 @@ def test_projection_add_column(file_int: str) -> None:
== """id: int32
list: list<element: int32>
child 0, element: int32
-- field metadata --
PARQUET:field_id: '21'
map: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
-- field metadata --
PARQUET:field_id: '31'
child 1, value: string
-- field metadata --
PARQUET:field_id: '32'
location: struct<lat: double, lon: double>
child 0, lat: double
-- field metadata --
PARQUET:field_id: '41'
child 1, lon: double
-- field metadata --
PARQUET:field_id: '42'"""
child 1, lon: double"""
)


Expand Down Expand Up @@ -1014,11 +1028,7 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None
== """id: map<int32, string>
child 0, entries: struct<key: int32 not null, value: string> not null
child 0, key: int32 not null
-- field metadata --
PARQUET:field_id: '3'
child 1, value: string
-- field metadata --
PARQUET:field_id: '4'"""
child 1, value: string"""
)


Expand Down Expand Up @@ -1062,12 +1072,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0
assert (
repr(result_table.schema)
== """id: int32
-- field metadata --
PARQUET:field_id: '1'"""
)
assert repr(result_table.schema) == """id: int32"""


def test_projection_filter_renamed_column(file_int: str) -> None:
Expand Down Expand Up @@ -1304,11 +1309,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
repr(result_table.schema)
== """location: struct<lat: double, long: double>
child 0, lat: double
-- field metadata --
PARQUET:field_id: '41'
child 1, long: double
-- field metadata --
PARQUET:field_id: '42'"""
child 1, long: double"""
)


Expand Down