
Commit b8c5bb7

Support Table.to_arrow_batch_reader (#786)
* _task_to_table to _task_to_record_batches
* to_arrow_batches
* tests
* fix
* fix
* deletes
* batch reader
* merge main
* adopt review feedback
1 parent 2182060 commit b8c5bb7

File tree: 4 files changed, +269 −39 lines

Diff for: mkdocs/docs/api.md (+9)
@@ -1003,6 +1003,15 @@ tpep_dropoff_datetime: [[2021-04-01 00:47:59.000000,...,2021-05-01 00:14:47.0000

 This will only pull in the files that might contain matching rows.

+One can also return a PyArrow RecordBatchReader, if reading one record batch at a time is preferred:
+
+```python
+table.scan(
+    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
+    selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
+).to_arrow_batch_reader()
+```
+
 ### Pandas

 <!-- prettier-ignore-start -->
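As a rough usage sketch (not part of the committed docs change), the reader returned by the snippet above can be consumed lazily; `table` and `GreaterThanOrEqual` are assumed to be set up as in the surrounding docs example:

```python
# Sketch only: iterate the scan one pyarrow.RecordBatch at a time
# instead of materializing the whole table with to_arrow().
reader = table.scan(
    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
    selected_fields=("VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime"),
).to_arrow_batch_reader()

for batch in reader:  # pa.RecordBatchReader is iterable
    print(batch.num_rows)
```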

Diff for: pyiceberg/io/pyarrow.py (+116 −39)
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedArray]:
     }


-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)


 def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = None) -> Schema:
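A small standalone sketch of the index arithmetic the reworked helper performs; the numbers below are made up for illustration. Given globally indexed delete positions and a batch covering rows [start_index, end_index), it returns the surviving row indices rebased to the batch, ready to be passed to `RecordBatch.take`:

```python
import numpy as np

# Hypothetical values: global row positions 5 and 7 are deleted,
# and the current batch covers file rows 4..8 (start_index=4, end_index=8).
deletes = np.array([5, 7])
start_index, end_index = 4, 8

surviving = np.subtract(
    np.setdiff1d(np.arange(start_index, end_index), deletes, assume_unique=False),
    start_index,
)
print(surviving)  # [0 2] -> batch-local indices of the surviving rows 4 and 6
```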
@@ -995,17 +995,16 @@ def _field_id(self, field: pa.Field) -> int:
         return -1


-def _task_to_table(
+def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
     projected_schema: Schema,
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    limit: Optional[int] = None,
     name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
@@ -1035,36 +1034,39 @@ def _task_to_table(
             columns=[col.name for col in file_project_schema.columns],
         )

-        if positional_deletes:
-            # Create the mask of indices that we're interested in
-            indices = _combine_positional_deletes(positional_deletes, fragment.count_rows())
-
-            if limit:
-                if pyarrow_filter is not None:
-                    # In case of the filter, we don't exactly know how many rows
-                    # we need to fetch upfront, can be optimized in the future:
-                    # https://github.com/apache/arrow/issues/35301
-                    arrow_table = fragment_scanner.take(indices)
-                    arrow_table = arrow_table.filter(pyarrow_filter)
-                    arrow_table = arrow_table.slice(0, limit)
-                else:
-                    arrow_table = fragment_scanner.take(indices[0:limit])
-            else:
-                arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
+                batch = batch.take(indices)
             # Apply the user filter
             if pyarrow_filter is not None:
+                # we need to switch back and forth between RecordBatch and Table
+                # as Expression filter isn't yet supported in RecordBatch
+                # https://github.com/apache/arrow/issues/39220
+                arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-        else:
-            # If there are no deletes, we can just take the head
-            # and the user-filter is already applied
-            if limit:
-                arrow_table = fragment_scanner.head(limit)
-            else:
-                arrow_table = fragment_scanner.to_table()
+                batch = arrow_table.to_batches()[0]
+            yield to_requested_schema(projected_schema, file_project_schema, batch)
+            current_index += len(batch)

-        if len(arrow_table) < 1:
-            return None
-        return to_requested_schema(projected_schema, file_project_schema, arrow_table)
+
+def _task_to_table(
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    name_mapping: Optional[NameMapping] = None,
+) -> pa.Table:
+    batches = _task_to_record_batches(
+        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+    )
+    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))


 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
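The comment in the hunk above points at the Arrow limitation that drives the RecordBatch/Table round-trip: expression filters are not yet supported on a bare RecordBatch. A minimal sketch of that workaround, using a hypothetical batch and filter expression:

```python
import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical data and filter, mirroring the round-trip in _task_to_record_batches.
batch = pa.RecordBatch.from_pydict({"trip_distance": [1.0, 12.5, 30.0]})
pyarrow_filter = pc.field("trip_distance") >= 10.0  # a dataset Expression

table = pa.Table.from_batches([batch])  # RecordBatch -> Table, since Expression
table = table.filter(pyarrow_filter)    # filters only work on Table for now
batch = table.to_batches()[0]           # back to a single RecordBatch
print(batch.num_rows)  # 2
```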
@@ -1143,7 +1145,6 @@ def project_table(
             projected_field_ids,
             deletes_per_file.get(task.file.file_path),
             case_sensitive,
-            limit,
             table_metadata.name_mapping(),
         )
         for task in tasks
@@ -1177,16 +1178,86 @@ def project_table(
     return result


-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def project_batches(
+    tasks: Iterable[FileScanTask],
+    table_metadata: TableMetadata,
+    io: FileIO,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
+) -> Iterator[pa.RecordBatch]:
+    """Resolve the right columns based on the identifier.
+
+    Args:
+        tasks (Iterable[FileScanTask]): A URI or a path to a local file.
+        table_metadata (TableMetadata): The table metadata of the table that's being queried
+        io (FileIO): A FileIO to open streams to the object store
+        row_filter (BooleanExpression): The expression for filtering rows.
+        projected_schema (Schema): The output schema.
+        case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records.
+
+    Raises:
+        ResolveError: When an incompatible query is done.
+    """
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
+
+    bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+
+    total_row_count = 0
+
+    for task in tasks:
+        batches = _task_to_record_batches(
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            table_metadata.name_mapping(),
+        )
+        for batch in batches:
+            if limit is not None:
+                if total_row_count + len(batch) >= limit:
+                    yield batch.slice(0, limit - total_row_count)
+                    break
+            yield batch
+            total_row_count += len(batch)
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
+    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))

     arrays = []
     fields = []
     for pos, field in enumerate(requested_schema.fields):
         array = struct_array.field(pos)
         arrays.append(array)
         fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))


 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
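The limit handling in `project_batches` above stops the stream once enough rows have been yielded, slicing the final batch instead of materializing everything. A small self-contained sketch of that slicing logic (the helper name and data are made up for illustration):

```python
from typing import Iterator

import pyarrow as pa


def take_limited(batches: Iterator[pa.RecordBatch], limit: int) -> Iterator[pa.RecordBatch]:
    # Mirrors the loop in project_batches: yield whole batches until the
    # limit would be crossed, then yield a final partial slice and stop.
    total_row_count = 0
    for batch in batches:
        if total_row_count + len(batch) >= limit:
            yield batch.slice(0, limit - total_row_count)
            break
        yield batch
        total_row_count += len(batch)


batches = [pa.RecordBatch.from_pydict({"x": list(range(4))}) for _ in range(3)]
print(sum(b.num_rows for b in take_limited(iter(batches), limit=6)))  # 6
```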
@@ -1293,8 +1364,10 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st

             if isinstance(partner_struct, pa.StructArray):
                 return partner_struct.field(name)
-            elif isinstance(partner_struct, pa.Table):
-                return partner_struct.column(name).combine_chunks()
+            elif isinstance(partner_struct, pa.RecordBatch):
+                return partner_struct.column(name)
+            else:
+                raise ValueError(f"Cannot find {name} in expected partner_struct type {type(partner_struct)}")

         return None
@@ -1831,15 +1904,19 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT

     def write_parquet(task: WriteTask) -> DataFile:
         table_schema = task.schema
-        arrow_table = pa.Table.from_batches(task.record_batches)
+
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
             file_schema = sanitized_schema
         else:
             file_schema = table_schema

-        arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
+        batches = [
+            to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+            for batch in task.record_batches
+        ]
+        arrow_table = pa.Table.from_batches(batches)
         file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
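`write_parquet` now runs `to_requested_schema` per record batch and only then assembles the Arrow table. A toy sketch of that shape, with a hypothetical per-batch transform standing in for `to_requested_schema`:

```python
import pyarrow as pa


def lowercase_names(batch: pa.RecordBatch) -> pa.RecordBatch:
    # Hypothetical stand-in for to_requested_schema: rebuild the batch
    # under sanitized (here: lower-cased) field names.
    fields = [pa.field(f.name.lower(), f.type, f.nullable) for f in batch.schema]
    return pa.RecordBatch.from_arrays(list(batch.columns), schema=pa.schema(fields))


record_batches = [
    pa.RecordBatch.from_pydict({"VendorID": [1, 2]}),
    pa.RecordBatch.from_pydict({"VendorID": [3]}),
]
arrow_table = pa.Table.from_batches([lowercase_names(b) for b in record_batches])
print(arrow_table.column_names)  # ['vendorid']
```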

Diff for: pyiceberg/table/__init__.py (+18)
@@ -1878,6 +1878,24 @@ def to_arrow(self) -> pa.Table:
             limit=self.limit,
         )

+    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+        import pyarrow as pa
+
+        from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow
+
+        return pa.RecordBatchReader.from_batches(
+            schema_to_pyarrow(self.projection()),
+            project_batches(
+                self.plan_files(),
+                self.table_metadata,
+                self.io,
+                self.row_filter,
+                self.projection(),
+                case_sensitive=self.case_sensitive,
+                limit=self.limit,
+            ),
+        )
+
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
         return self.to_arrow().to_pandas(**kwargs)
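A hedged end-to-end usage sketch for the new method; the catalog name "default" and the table identifier "default.taxis" are assumptions for illustration, not from the commit:

```python
import pyarrow.compute as pc

from pyiceberg.catalog import load_catalog

# Assumed catalog/table names; substitute your own.
tbl = load_catalog("default").load_table("default.taxis")

# Stream the scan one RecordBatch at a time instead of materializing it with to_arrow().
reader = tbl.scan(selected_fields=("trip_distance",)).to_arrow_batch_reader()

total = 0.0
for batch in reader:
    value = pc.sum(batch.column("trip_distance")).as_py()
    total += value or 0.0  # sum of an empty/all-null batch is null
print(total)
```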
