
[Feature] Add Support for Distributed Write #1751

Open
andormarkus opened this issue Mar 3, 2025 · 5 comments

@andormarkus

andormarkus commented Mar 3, 2025

Feature Request / Improvement

Problem Statement

A key problem in distributed Iceberg systems is that commit processes can block each other when multiple workers try to update table metadata simultaneously: each commit must atomically replace the table's current metadata, so concurrent committers conflict and have to retry. This contention creates a severe performance bottleneck that limits throughput, particularly in high-volume ingestion scenarios.

Use Case

In our distributed architecture:

  1. Process A writes Parquet files in Iceberg-compatible format
  2. Simple string identifiers (file paths) need to be passed between systems
  3. Process B takes these strings and commits the files to make them visible in queries

This pattern is especially useful for high-concurrency ingestion, where many writers produce data for the same Iceberg table simultaneously but the commit process is centralized and coordinated. Centralizing commits matters because, in distributed environments, concurrent commits block each other and become a significant bottleneck at high throughput.

Detailed Workflow

Our workflow involves:

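import uuid

from pyiceberg.catalog import load_catalog
from pyiceberg.io.pyarrow import _dataframe_to_data_files  # private API

catalog = load_catalog("default")  # catalog name/config is deployment-specific

# Process A: write data files, but don't commit them
# (pa_df is a pyarrow.Table matching the table schema)
table = catalog.load_table("iceberg.table")
data_files = list(
    _dataframe_to_data_files(table_metadata=table.metadata, write_uuid=uuid.uuid4(), df=pa_df, io=table.io)
)

# `queue` stands in for the messaging system; DataFile objects must first be
# serialized to bytes/strings, which is the open problem described below
queue.send(data_files)

# Process B: commit processor (runs separately)
table = catalog.load_table("iceberg.table")
data_files = queue.receive()
with table.transaction() as trx:
    with trx.update_snapshot().fast_append() as update_snapshot:
        for data_file in data_files:
            update_snapshot.append_data_file(data_file)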

This separation of write and commit operations provides several advantages:

  • Improved throughput by parallelizing write operations across multiple workers
  • Reduced lock contention since metadata commits (which require locks) are centralized
  • Better failure handling: failed writes don't affect the table state
  • Controlled transaction timing: commits can be batched or scheduled optimally
  • Elimination of commit process blocking: with centralized commits, distributed writers no longer block each other during metadata updates, which is otherwise a major performance bottleneck
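The batching advantage can be sketched as a commit loop that drains several queued batches of DataFile objects and commits them in a single fast-append snapshot. This is a minimal sketch under stated assumptions: messages arrive on a standard-library `queue.Queue` and have already been deserialized into DataFile objects, and `table` is any object exposing the `transaction()` / `update_snapshot().fast_append()` / `append_data_file()` chain used in the workflow above. The helper names `drain` and `commit_batch` and their parameters are hypothetical.

```python
import queue
from typing import Any, List


def drain(q: "queue.Queue[List[Any]]", max_batches: int, wait_seconds: float) -> List[Any]:
    """Collect up to `max_batches` queued lists of DataFiles into one flat list.

    Blocks for at most `wait_seconds` for the first message, then takes
    whatever else is already queued without waiting.
    """
    data_files: List[Any] = []
    try:
        data_files.extend(q.get(timeout=wait_seconds))
    except queue.Empty:
        return data_files
    for _ in range(max_batches - 1):
        try:
            data_files.extend(q.get_nowait())
        except queue.Empty:
            break
    return data_files


def commit_batch(table: Any, data_files: List[Any]) -> int:
    """Append all data files in one snapshot; returns how many were appended."""
    if not data_files:
        return 0  # nothing to commit, so no empty snapshot is created
    with table.transaction() as trx:
        with trx.update_snapshot().fast_append() as update_snapshot:
            for data_file in data_files:
                update_snapshot.append_data_file(data_file)
    return len(data_files)
```

With this shape, N writer batches cost one metadata commit instead of N. A production loop would still need retry handling for commit conflicts and should acknowledge queue messages only after the commit succeeds.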

Current Limitations

  • Serializing DataFile objects between processes is challenging
  • We've implemented custom serialization with compression (gzip, zlib); it works, but it requires long, complex code
  • Using jsonpickle also presented significant problems

Proposed Solution

We're seeking a robust way to handle distributed writes, potentially with:

  1. Add serialization/deserialization methods to the DataFile class
  2. Support Avro for efficient serialization of DataFile objects (potentially smaller than other approaches)
  3. Better integration with append_data_file API
  4. OR a more accessible way to use the ManifestFile functionality that's already implemented in PyIceberg

Ideally, the solution would:

  • Handle schema evolution gracefully (unlike the current add_files approach, which has issues when the schema changes)
  • Work efficiently with minimal overhead for large-scale concurrent processing
  • Provide simple primitives that can be used in distributed systems without requiring complex serialization
  • Follow patterns similar to those used in the Java implementation where appropriate
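Because the encoded form of a DataFile depends on the table's partition spec, one way to approach the schema-evolution requirement is to tag each message with the table state it was written against and let the committer reject stale messages before calling `append_data_file`. A minimal sketch: the JSON envelope layout and the names `make_envelope`, `check_envelope` and `StaleMessageError` are hypothetical. On the committer side, a PyIceberg table exposes the current values as `table.format_version`, `table.spec().spec_id` and `table.schema().schema_id`.

```python
import base64
import json
from typing import Any, Dict


class StaleMessageError(Exception):
    """Raised when a message was written against a different spec/schema."""


def make_envelope(payload: bytes, format_version: int, spec_id: int, schema_id: int) -> str:
    """Wrap an encoded DataFile batch with the IDs it was encoded against."""
    return json.dumps({
        "format_version": format_version,
        "spec_id": spec_id,
        "schema_id": schema_id,
        "payload": base64.b64encode(payload).decode("ascii"),
    })


def check_envelope(message: str, format_version: int, spec_id: int, schema_id: int) -> bytes:
    """Return the payload if the message matches the table's current IDs."""
    envelope: Dict[str, Any] = json.loads(message)
    expected = {"format_version": format_version, "spec_id": spec_id, "schema_id": schema_id}
    # Map each mismatched key to a (written, current) pair
    mismatched = {k: (envelope.get(k), v) for k, v in expected.items() if envelope.get(k) != v}
    if mismatched:
        raise StaleMessageError(f"message written against different table state: {mismatched}")
    return base64.b64decode(envelope["payload"])
```

This is a deliberately strict policy: a file written before a compatible schema change may still be safe to commit, so whether a `schema_id` mismatch should re-encode, re-route, or drop the message is a design decision for the committer.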

Alternative Approaches Tried

  • We've implemented a custom serialization/deserialization function with compression
  • We explored the approach in Check write snapshot compatibility #1678, but found that it created too many commits, which became a performance bottleneck

Related PRs/Issues

We're looking for guidance on the best approach to solve this distributed writing pattern while maintaining performance and schema compatibility.

@andormarkus
Author

Hi @Fokko

Based on the source code, writing DataFiles to a manifest (Avro) file can be achieved like this:

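import uuid

from pyiceberg.manifest import ManifestEntry, ManifestEntryStatus, write_manifest

# `table` is the loaded table, `datafiles` the DataFile objects from the writer,
# and `snapshot_id` the snapshot ID recorded in the manifest entries
manifest_path = f"temp-manifest-{uuid.uuid4()}.avro"
output_file = table.io.new_output(manifest_path)

# Write all data files to the manifest
with write_manifest(table.format_version, table.spec(), table.schema(), output_file, snapshot_id) as writer:
    for datafile in datafiles:
        writer.add(ManifestEntry(status=ManifestEntryStatus.ADDED, snapshot_id=snapshot_id, data_file=datafile))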

I don't see how we can get the DataFile back from the manifest / Avro file. I need your guidance here.

@Fokko
Contributor

Fokko commented Mar 4, 2025

@andormarkus Sure thing, does the following help?

from io import BytesIO

from pyiceberg.avro.decoder_fast import CythonBinaryDecoder
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.resolver import construct_writer, resolve_reader
from pyiceberg.manifest import DATA_FILE_TYPE, DEFAULT_READ_VERSION, DataFile, DataFileContent, FileFormat
from pyiceberg.typedef import Record


def test_serialize():
    data_file = DataFile(
        content=DataFileContent.DATA,
        file_path="s3://some-path/some-file.parquet",
        file_format=FileFormat.PARQUET,
        partition=Record(),
        record_count=131327,
        file_size_in_bytes=220669226,
        column_sizes={1: 220661854},
        value_counts={1: 131327},
        null_value_counts={1: 0},
        nan_value_counts={},
        lower_bounds={1: b"aaaaaaaaaaaaaaaa"},
        upper_bounds={1: b"zzzzzzzzzzzzzzzz"},
        key_metadata=b"\xde\xad\xbe\xef",
        split_offsets=[4, 133697593],
        equality_ids=[],
        sort_order_id=4,
    )

    # Encode
    output = BytesIO()
    encoder = BinaryEncoder(output)
    schema = DATA_FILE_TYPE[DEFAULT_READ_VERSION]
    construct_writer(file_schema=schema).write(encoder, data_file)

    # Decode
    decoder = CythonBinaryDecoder(output.getvalue())
    result = resolve_reader(
        schema,
        schema,
        read_types={-1: DataFile},
    ).read(decoder)

    assert result.file_path == "s3://some-path/some-file.parquet"

@andormarkus
Author

Hi @Fokko

Thank you so much for the code snippet.

I extended the test and ran into the following problems with partitioned tables (non-partitioned tables pass the test):

  1. Serialization/deserialization: partition info is not preserved
Found differences: {'partition': {'_position_to_field_name': {'expected': ('portal_id', 'timestamp_day'), 'actual': ()}, 'attr.portal_id': {'expected': 9, 'actual': None}, 'attr.timestamp_day': {'expected': 20240301, 'actual': None}}}
  2. Append fails due to missing partition info
self = Record[], pos = 0

    def __getitem__(self, pos: int) -> Any:
        """Fetch a value from a Record."""
>       return self.__getattribute__(self._position_to_field_name[pos])
E       IndexError: tuple index out of range

I could not figure out in which step the partition info is lost.

Source code:

import os
import uuid
from typing import Any, Dict

import pytest
import tempfile
import pyarrow as pa
from io import BytesIO

from pyiceberg.io.pyarrow import _dataframe_to_data_files
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, DoubleType, LongType, BinaryType
from pyiceberg.partitioning import PartitionSpec, PartitionField
from pyiceberg.transforms import IdentityTransform
from pyiceberg.avro.decoder_fast import CythonBinaryDecoder
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.resolver import construct_writer, resolve_reader
from pyiceberg.manifest import DATA_FILE_TYPE, DEFAULT_READ_VERSION, DataFile
from pyiceberg.typedef import Record
from pyiceberg.catalog import load_catalog


def get_schema():
    # Define schema with partitioned fields
    return Schema(
        NestedField(1, "city", StringType(), required=False),
        NestedField(2, "lat", DoubleType(), required=False),
        NestedField(3, "long", DoubleType(), required=False),
        NestedField(4, "portal_id", LongType(), required=False),
        NestedField(5, "timestamp_day", LongType(), required=False),
        NestedField(6, "binary_data", BinaryType(), required=False),
    )


def get_partition_spec():
    # Define partition spec (portal_id, timestamp_day)
    return PartitionSpec(
        spec_id=0,
        fields=[
            PartitionField(
                source_id=4,
                field_id=1000,
                transform=IdentityTransform(),
                name="portal_id"
            ),
            PartitionField(
                source_id=5,
                field_id=1001,
                transform=IdentityTransform(),
                name="timestamp_day"
            ),
        ]
    )


def get_empty_partition_spec():
    # Define empty partition spec for non-partitioned tests
    return PartitionSpec(spec_id=0, fields=[])


def get_sample_data():
    # Create sample data with binary field
    return pa.Table.from_pylist([
        {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029, "portal_id": 9,
         "timestamp_day": 20240301, "binary_data": b"Amsterdam data"},
        {"city": "San Francisco", "lat": 37.773972, "long": -122.431297, "portal_id": 9,
         "timestamp_day": 20240301, "binary_data": b"San Francisco data"},
        {"city": "Drachten", "lat": 53.11254, "long": 6.0989, "portal_id": 10,
         "timestamp_day": 20240302, "binary_data": b"Drachten data"},
        {"city": "Paris", "lat": 48.864716, "long": 2.349014, "portal_id": 10,
         "timestamp_day": 20240302, "binary_data": b"Paris data"},
    ])


def compare_datafiles(expected: Any, actual: Any) -> Dict[str, Any]:
    """
    Compare two DataFile objects and return differences.
    Returns empty dict if they're identical, otherwise returns the differences.

    Args:
        expected: First DataFile object to compare
        actual: Second DataFile object to compare

    Returns:
        Dictionary of differences, empty if objects are identical

    Raises:
        TypeError: If either argument is not a DataFile
    """
    # Input validation - make sure both are actually DataFile objects
    if not isinstance(expected, DataFile):
        raise TypeError(f"First argument must be a DataFile, got {type(expected)} instead")
    if not isinstance(actual, DataFile):
        raise TypeError(f"Second argument must be a DataFile, got {type(actual)} instead")

    differences = {}

    # Compare all slots from DataFile
    for slot in expected.__class__.__slots__:
        if slot == "_struct":  # Skip internal struct field
            continue

        if hasattr(expected, slot) and hasattr(actual, slot):
            expected_value = getattr(expected, slot)
            actual_value = getattr(actual, slot)

            # Special handling for different types
            if isinstance(expected_value, Record) and isinstance(actual_value, Record):
                # Enhanced comparison for Record objects (especially partition)
                record_differences = {}

                # Check structure (_position_to_field_name)
                if hasattr(expected_value, "_position_to_field_name") and hasattr(actual_value,
                                                                                  "_position_to_field_name"):
                    orig_fields = expected_value._position_to_field_name
                    des_fields = actual_value._position_to_field_name

                    if orig_fields != des_fields:
                        record_differences["_position_to_field_name"] = {
                            "expected": orig_fields,
                            "actual": des_fields
                        }

                # Check data values
                if hasattr(expected_value, "_data") and hasattr(actual_value, "_data"):
                    # Ensure both _data attributes are tuples/lists and have values
                    orig_data = expected_value._data if expected_value._data else ()
                    des_data = actual_value._data if actual_value._data else ()

                    # Check if one is empty but the other isn't
                    if bool(orig_data) != bool(des_data):
                        record_differences["_data_presence"] = {
                            "expected": "present" if orig_data else "empty",
                            "actual": "present" if des_data else "empty"
                        }

                    # Compare content if both exist
                    if orig_data and des_data:
                        if len(orig_data) != len(des_data):
                            record_differences["_data_length"] = {
                                "expected": len(orig_data),
                                "actual": len(des_data)
                            }
                        else:
                            # Compare each item
                            for i, (orig_item, des_item) in enumerate(zip(orig_data, des_data)):
                                if orig_item != des_item:
                                    record_differences[f"_data[{i}]"] = {
                                        "expected": orig_item,
                                        "actual": des_item
                                    }

                # Additional check: Try to access fields directly as attributes
                if hasattr(expected_value, "_position_to_field_name"):
                    for field_name in expected_value._position_to_field_name:
                        orig_attr = getattr(expected_value, field_name, None)
                        des_attr = getattr(actual_value, field_name, None)

                        if orig_attr != des_attr:
                            record_differences[f"attr.{field_name}"] = {
                                "expected": orig_attr,
                                "actual": des_attr
                            }

                # If any differences were found in the record
                if record_differences:
                    differences[slot] = record_differences

            elif isinstance(expected_value, dict) and isinstance(actual_value, dict):
                # Compare dictionaries (like lower_bounds, upper_bounds)
                if set(expected_value.keys()) != set(actual_value.keys()):
                    differences[f"{slot}.keys"] = {
                        "expected": set(expected_value.keys()),
                        "actual": set(actual_value.keys())
                    }

                # Compare values
                for key in expected_value:
                    if key in actual_value:
                        if expected_value[key] != actual_value[key]:
                            differences[f"{slot}[{key}]"] = {
                                "expected": expected_value[key],
                                "actual": actual_value[key]
                            }
            elif expected_value != actual_value:
                differences[slot] = {
                    "expected": expected_value,
                    "actual": actual_value
                }

    return differences


class TestIcebergBase:
    """Base class for Iceberg tests with shared methods"""

    def serialize_and_deserialize(self, sample_data_file):
        """Helper method to serialize and deserialize a DataFile"""
        # Encode
        output = BytesIO()
        encoder = BinaryEncoder(output)
        schema = DATA_FILE_TYPE[DEFAULT_READ_VERSION]
        construct_writer(file_schema=schema).write(encoder, sample_data_file)

        output = output.getvalue()

        # Decode
        decoder = CythonBinaryDecoder(output)
        actual_data_file = resolve_reader(
            schema,
            schema,
            read_types={-1: DataFile},
        ).read(decoder)

        return actual_data_file

    def append_data_file(self, table, data_file):
        """Helper method to append a DataFile to a table"""
        with table.transaction() as trx:
            with trx.update_snapshot().fast_append() as update_snapshot:
                update_snapshot.append_data_file(data_file)


@pytest.fixture(scope="class")
def iceberg_setup_with_partition():
    """Create a temporary Iceberg table with partitioning"""

    # Set data files output directory
    temp_dir = tempfile.mkdtemp()
    os.environ["PYICEBERG_PARQUET_OUTPUT"] = temp_dir

    # Create a catalog and schema
    catalog = load_catalog("catalog", type="in-memory")
    catalog.create_namespace("default")

    # Create table with partitioning
    table = catalog.create_table(
        identifier="default.cities_with_partition",
        schema=get_schema(),
        partition_spec=get_partition_spec()
    )

    # Create sample data with binary field
    data = get_sample_data()

    data_files = list(_dataframe_to_data_files(
        table_metadata=table.metadata, write_uuid=uuid.uuid4(), df=data, io=table.io))

    yield {
        "catalog": catalog,
        "table": table,
        "sample_data_file": data_files[0],
        "data_files": data_files
    }

    # Cleanup after all tests
    if os.path.exists(temp_dir):
        for root, dirs, files in os.walk(temp_dir, topdown=False):
            for file in files:
                os.remove(os.path.join(root, file))
            for dir in dirs:
                os.rmdir(os.path.join(root, dir))
        os.rmdir(temp_dir)


@pytest.fixture(scope="class")
def iceberg_setup_no_partition():
    """Create a temporary Iceberg table without partitioning"""

    # Set data files output directory
    temp_dir = tempfile.mkdtemp()
    os.environ["PYICEBERG_PARQUET_OUTPUT"] = temp_dir

    # Create a catalog and schema
    catalog = load_catalog("catalog", type="in-memory")
    catalog.create_namespace("default")

    # Create table without partitioning
    table = catalog.create_table(
        identifier="default.cities_no_partition",
        schema=get_schema(),
        partition_spec=get_empty_partition_spec()
    )

    # Create sample data with binary field
    data = get_sample_data()

    data_files = list(_dataframe_to_data_files(
        table_metadata=table.metadata, write_uuid=uuid.uuid4(), df=data, io=table.io))

    yield {
        "catalog": catalog,
        "table": table,
        "sample_data_file": data_files[0],
        "data_files": data_files
    }

    # Cleanup after all tests
    if os.path.exists(temp_dir):
        for root, dirs, files in os.walk(temp_dir, topdown=False):
            for file in files:
                os.remove(os.path.join(root, file))
            for dir in dirs:
                os.rmdir(os.path.join(root, dir))
        os.rmdir(temp_dir)


class TestIcebergWithPartition(TestIcebergBase):
    """Tests for Iceberg operations with partition"""

    @pytest.fixture(autouse=True)
    def setup(self, iceberg_setup_with_partition):
        """Setup for all tests in this class"""
        self.setup_data = iceberg_setup_with_partition
        self.sample_data_file = self.setup_data["sample_data_file"]
        self.data_files = self.setup_data["data_files"]
        self.table = self.setup_data["table"]

    def test_serialize(self):
        """Test serializing and deserializing DataFile with partition"""
        actual_data_file = self.serialize_and_deserialize(self.sample_data_file)

        differences = compare_datafiles(self.sample_data_file, actual_data_file)
        assert not differences, f"Found differences: {differences}"

    def test_fast_append_working(self):
        """Test fast append with native DataFile with partition"""
        self.append_data_file(self.table, self.data_files[0])

    def test_fast_append_with_avro(self):
        """Test fast append with Avro deserialized DataFile with partition"""
        actual_data_file = self.serialize_and_deserialize(self.sample_data_file)
        self.append_data_file(self.table, actual_data_file)


class TestIcebergNoPartition(TestIcebergBase):
    """Tests for Iceberg operations without partition"""

    @pytest.fixture(autouse=True)
    def setup(self, iceberg_setup_no_partition):
        """Setup for all tests in this class"""
        self.setup_data = iceberg_setup_no_partition
        self.sample_data_file = self.setup_data["sample_data_file"]
        self.data_files = self.setup_data["data_files"]
        self.table = self.setup_data["table"]

    def test_serialize(self):
        """Test serializing and deserializing DataFile without partition"""
        actual_data_file = self.serialize_and_deserialize(self.sample_data_file)

        differences = compare_datafiles(self.sample_data_file, actual_data_file)
        assert not differences, f"Found differences: {differences}"

    def test_fast_append_working(self):
        """Test fast append with native DataFile without partition"""
        self.append_data_file(self.table, self.data_files[0])

    def test_fast_append_with_avro(self):
        """Test fast append with Avro deserialized DataFile without partition"""
        actual_data_file = self.serialize_and_deserialize(self.sample_data_file)
        self.append_data_file(self.table, actual_data_file)

@andormarkus
Author

Hi @Fokko,

I'd like to share a working example that demonstrates how to serialize and deserialize DataFiles for both partitioned and non-partitioned tables:

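from io import BytesIO

from pyiceberg.avro.decoder_fast import CythonBinaryDecoder
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.resolver import construct_writer, resolve_reader
from pyiceberg.manifest import DataFile, data_file_with_partition

# Build the DataFile schema for this table, including its partition struct
partition_type = table.spec().partition_type(schema=table.schema())
schema = data_file_with_partition(format_version=table.format_version, partition_type=partition_type)

# Encode
output = BytesIO()
encoder = BinaryEncoder(output)
construct_writer(file_schema=schema).write(encoder, data_file)
payload = output.getvalue()

# Decode
decoder = CythonBinaryDecoder(payload)
actual_data_file = resolve_reader(schema, schema, read_types={-1: DataFile}).read(decoder)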
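Why this fixes the partition loss can be illustrated with a toy model (plain Python, not PyIceberg): a schema-driven encoder writes only the fields its schema lists. In PyIceberg's manifest module, `DATA_FILE_TYPE[DEFAULT_READ_VERSION]` describes `partition` as an empty struct, so no partition values are ever written, whereas `data_file_with_partition` fills that struct in from the table's partition spec. The names `ToySchema`, `encode` and `decode` are illustrative only.

```python
from typing import Any, Dict, List, Tuple

# Toy model: a schema is a list of (field_name, sub_schema) pairs, where
# sub_schema is None for a leaf value or another schema for a nested struct.
ToySchema = List[Tuple[str, Any]]


def encode(schema: ToySchema, record: Dict[str, Any]) -> List[Any]:
    """Flatten `record` in schema order, emitting only fields the schema knows about."""
    out: List[Any] = []
    for name, sub in schema:
        value = record[name]
        out.extend(encode(sub, value) if sub is not None else [value])
    return out


def decode(schema: ToySchema, values: List[Any]) -> Dict[str, Any]:
    """Rebuild a nested record by consuming `values` in schema order."""
    it = iter(values)

    def read(s: ToySchema) -> Dict[str, Any]:
        return {name: read(sub) if sub is not None else next(it) for name, sub in s}

    return read(schema)
```

With a generic schema whose `partition` struct is empty, a round trip returns `partition` as an empty struct; with the partition-aware schema, the partition values survive.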

I believe we should expand the documentation so the community can benefit from this work. However, I'm not sure which section would be most appropriate since this is an advanced topic that doesn't require code changes.

Additionally, I think we should create a public wrapper around pyiceberg.io.pyarrow._dataframe_to_data_files since achieving distributed writes currently requires using a private method.

Full unit test:

import os
import uuid
from typing import Any, Dict

import pytest
import tempfile
import pyarrow as pa
from io import BytesIO

from pyiceberg.io.pyarrow import _dataframe_to_data_files
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, DoubleType, LongType, BinaryType
from pyiceberg.partitioning import PartitionSpec, PartitionField
from pyiceberg.transforms import IdentityTransform
from pyiceberg.avro.decoder_fast import CythonBinaryDecoder
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.resolver import construct_writer, resolve_reader
from pyiceberg.manifest import DATA_FILE_TYPE, DEFAULT_READ_VERSION, DataFile, data_file_with_partition
from pyiceberg.typedef import Record
from pyiceberg.catalog import load_catalog


def get_schema():
    # Define schema with partitioned fields
    return Schema(
        NestedField(1, "city", StringType(), required=False),
        NestedField(2, "lat", DoubleType(), required=False),
        NestedField(3, "long", DoubleType(), required=False),
        NestedField(4, "portal_id", LongType(), required=False),
        NestedField(5, "timestamp_day", LongType(), required=False),
        NestedField(6, "binary_data", BinaryType(), required=False),
    )


def get_partition_spec():
    # Define partition spec (portal_id, timestamp_day)
    return PartitionSpec(
        spec_id=0,
        fields=[
            PartitionField(
                source_id=4,
                field_id=1000,
                transform=IdentityTransform(),
                name="portal_id"
            ),
            PartitionField(
                source_id=5,
                field_id=1001,
                transform=IdentityTransform(),
                name="timestamp_day"
            ),
        ]
    )


def get_empty_partition_spec():
    # Define empty partition spec for non-partitioned tests
    return PartitionSpec(spec_id=0, fields=[])


def get_sample_data():
    # Create sample data with binary field
    return pa.Table.from_pylist([
        {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029, "portal_id": 9,
         "timestamp_day": 20240301, "binary_data": b"Amsterdam data"},
        {"city": "San Francisco", "lat": 37.773972, "long": -122.431297, "portal_id": 9,
         "timestamp_day": 20240301, "binary_data": b"San Francisco data"},
        {"city": "Drachten", "lat": 53.11254, "long": 6.0989, "portal_id": 10,
         "timestamp_day": 20240302, "binary_data": b"Drachten data"},
        {"city": "Paris", "lat": 48.864716, "long": 2.349014, "portal_id": 10,
         "timestamp_day": 20240302, "binary_data": b"Paris data"},
    ])


def compare_datafiles(expected: Any, actual: Any) -> Dict[str, Any]:
    """
    Compare two DataFile objects and return differences.
    Returns empty dict if they're identical, otherwise returns the differences.

    Args:
        expected: First DataFile object to compare
        actual: Second DataFile object to compare

    Returns:
        Dictionary of differences, empty if objects are identical

    Raises:
        TypeError: If either argument is not a DataFile
    """
    # Input validation - make sure both are actually DataFile objects
    if not isinstance(expected, DataFile):
        raise TypeError(f"First argument must be a DataFile, got {type(expected)} instead")
    if not isinstance(actual, DataFile):
        raise TypeError(f"Second argument must be a DataFile, got {type(actual)} instead")

    differences = {}

    # Compare all slots from DataFile
    for slot in expected.__class__.__slots__:
        if slot == "_struct":  # Skip internal struct field
            continue

        if hasattr(expected, slot) and hasattr(actual, slot):
            expected_value = getattr(expected, slot)
            actual_value = getattr(actual, slot)

            # Special handling for different types
            if isinstance(expected_value, Record) and isinstance(actual_value, Record):
                # Enhanced comparison for Record objects (especially partition)
                record_differences = {}

                # Check structure (_position_to_field_name)
                if hasattr(expected_value, "_position_to_field_name") and hasattr(actual_value,
                                                                                  "_position_to_field_name"):
                    orig_fields = expected_value._position_to_field_name
                    des_fields = actual_value._position_to_field_name

                    if orig_fields != des_fields:
                        record_differences["_position_to_field_name"] = {
                            "expected": orig_fields,
                            "actual": des_fields
                        }

                # Check data values
                if hasattr(expected_value, "_data") and hasattr(actual_value, "_data"):
                    # Ensure both _data attributes are tuples/lists and have values
                    orig_data = expected_value._data if expected_value._data else ()
                    des_data = actual_value._data if actual_value._data else ()

                    # Check if one is empty but the other isn't
                    if bool(orig_data) != bool(des_data):
                        record_differences["_data_presence"] = {
                            "expected": "present" if orig_data else "empty",
                            "actual": "present" if des_data else "empty"
                        }

                    # Compare content if both exist
                    if orig_data and des_data:
                        if len(orig_data) != len(des_data):
                            record_differences["_data_length"] = {
                                "expected": len(orig_data),
                                "actual": len(des_data)
                            }
                        else:
                            # Compare each item
                            for i, (orig_item, des_item) in enumerate(zip(orig_data, des_data)):
                                if orig_item != des_item:
                                    record_differences[f"_data[{i}]"] = {
                                        "expected": orig_item,
                                        "actual": des_item
                                    }

                # Additional check: Try to access fields directly as attributes
                if hasattr(expected_value, "_position_to_field_name"):
                    for field_name in expected_value._position_to_field_name:
                        orig_attr = getattr(expected_value, field_name, None)
                        des_attr = getattr(actual_value, field_name, None)

                        if orig_attr != des_attr:
                            record_differences[f"attr.{field_name}"] = {
                                "expected": orig_attr,
                                "actual": des_attr
                            }

                # If any differences were found in the record
                if record_differences:
                    differences[slot] = record_differences

            elif isinstance(expected_value, dict) and isinstance(actual_value, dict):
                # Compare dictionaries (like lower_bounds, upper_bounds)
                if set(expected_value.keys()) != set(actual_value.keys()):
                    differences[f"{slot}.keys"] = {
                        "expected": set(expected_value.keys()),
                        "actual": set(actual_value.keys())
                    }

                # Compare values
                for key in expected_value:
                    if key in actual_value:
                        if expected_value[key] != actual_value[key]:
                            differences[f"{slot}[{key}]"] = {
                                "expected": expected_value[key],
                                "actual": actual_value[key]
                            }
            elif expected_value != actual_value:
                differences[slot] = {
                    "expected": expected_value,
                    "actual": actual_value
                }

    return differences


class TestIcebergBase:
    """Base class for Iceberg tests with shared methods"""

    def serialize_and_deserialize(self, table, data_file):
        """Helper method to serialize and deserialize a DataFile"""
        output = BytesIO()

        # Build the DataFile schema for this table, including its partition struct
        partition_type = table.spec().partition_type(schema=table.schema())
        schema = data_file_with_partition(format_version=table.format_version, partition_type=partition_type)

        # Encode
        encoder = BinaryEncoder(output)
        construct_writer(file_schema=schema).write(encoder, data_file)
        output = output.getvalue()

        # Decode
        decoder = CythonBinaryDecoder(output)
        actual_data_file = resolve_reader(schema, schema, read_types={-1: DataFile}).read(decoder)

        return actual_data_file

    def append_data_file(self, table, data_file):
        """Helper method to append a DataFile to a table"""
        with table.transaction() as trx:
            with trx.update_snapshot().fast_append() as update_snapshot:
                update_snapshot.append_data_file(data_file)


    @staticmethod
    def serialize_datafile_to_avro_file(datafile: DataFile, file_path: str) -> None:
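        """
        Serialize a DataFile with Avro binary encoding and write it to disk.

        Only the raw datum is written (no Avro container-file header or embedded
        schema), so the reader must decode it with the same schema. The schema used
        here, DATA_FILE_TYPE[DEFAULT_READ_VERSION], has an empty partition struct,
        so partition values are not written; for partitioned tables, build the
        schema with data_file_with_partition() as serialize_and_deserialize does.

        Args:
            datafile: The DataFile object to serialize
            file_path: The path where the file should be written
        """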
        schema = DATA_FILE_TYPE[DEFAULT_READ_VERSION]

        # Open a file for binary writing
        with open(file_path, 'wb') as file_output:
            encoder = BinaryEncoder(file_output)
            construct_writer(file_schema=schema).write(encoder, datafile)


@pytest.fixture(scope="class")
def iceberg_setup_with_partition():
    """Create a temporary Iceberg table with partitioning"""

    # Set data files output directory
    temp_dir = tempfile.mkdtemp()
    os.environ["PYICEBERG_PARQUET_OUTPUT"] = temp_dir

    # Create an in-memory catalog and a namespace
    catalog = load_catalog("catalog", type="in-memory")
    catalog.create_namespace("default")

    # Create table with partitioning
    table = catalog.create_table(
        identifier="default.cities_with_partition",
        schema=get_schema(),
        partition_spec=get_partition_spec()
    )

    # Create sample data with binary field
    data = get_sample_data()

    data_files = list(_dataframe_to_data_files(
        table_metadata=table.metadata, write_uuid=uuid.uuid4(), df=data, io=table.io))

    yield {
        "catalog": catalog,
        "table": table,
        "sample_data_file": data_files[0],
        "data_files": data_files
    }

    # Cleanup after all tests
    if os.path.exists(temp_dir):
        for root, dirs, files in os.walk(temp_dir, topdown=False):
            for file in files:
                os.remove(os.path.join(root, file))
            for dir in dirs:
                os.rmdir(os.path.join(root, dir))
        os.rmdir(temp_dir)


@pytest.fixture(scope="class")
def iceberg_setup_no_partition():
    """Create a temporary Iceberg table without partitioning"""

    # Set data files output directory
    temp_dir = tempfile.mkdtemp()
    os.environ["PYICEBERG_PARQUET_OUTPUT"] = temp_dir

    # Create an in-memory catalog and a namespace
    catalog = load_catalog("catalog", type="in-memory")
    catalog.create_namespace("default")

    # Create table without partitioning
    table = catalog.create_table(
        identifier="default.cities_no_partition",
        schema=get_schema(),
        partition_spec=get_empty_partition_spec()
    )

    # Create sample data with binary field
    data = get_sample_data()

    data_files = list(_dataframe_to_data_files(
        table_metadata=table.metadata, write_uuid=uuid.uuid4(), df=data, io=table.io))

    yield {
        "catalog": catalog,
        "table": table,
        "sample_data_file": data_files[0],
        "data_files": data_files
    }

    # Cleanup after all tests
    if os.path.exists(temp_dir):
        for root, dirs, files in os.walk(temp_dir, topdown=False):
            for file in files:
                os.remove(os.path.join(root, file))
            for dir in dirs:
                os.rmdir(os.path.join(root, dir))
        os.rmdir(temp_dir)


class TestIcebergWithPartition(TestIcebergBase):
    """Tests for Iceberg operations with partition"""

    @pytest.fixture(autouse=True)
    def setup(self, iceberg_setup_with_partition):
        """Setup for all tests in this class"""
        self.setup_data = iceberg_setup_with_partition
        self.sample_data_file = self.setup_data["sample_data_file"]
        self.data_files = self.setup_data["data_files"]
        self.table = self.setup_data["table"]

    def test_serialize(self):
        """Test serializing and deserializing DataFile with partition"""
        actual_data_file = self.serialize_and_deserialize(self.table, self.sample_data_file)

        differences = compare_datafiles(self.sample_data_file, actual_data_file)
        assert not differences, f"Found differences: {differences}"

    def test_fast_append_working(self):
        """Test fast append with native DataFile with partition"""
        self.append_data_file(self.table, self.data_files[0])

    def test_fast_append_with_avro(self):
        """Test fast append with Avro deserialized DataFile with partition"""
        actual_data_file = self.serialize_and_deserialize(self.table, self.sample_data_file)
        self.append_data_file(self.table, actual_data_file)


class TestIcebergNoPartition(TestIcebergBase):
    """Tests for Iceberg operations without partition"""

    @pytest.fixture(autouse=True)
    def setup(self, iceberg_setup_no_partition):
        """Setup for all tests in this class"""
        self.setup_data = iceberg_setup_no_partition
        self.sample_data_file = self.setup_data["sample_data_file"]
        self.data_files = self.setup_data["data_files"]
        self.table = self.setup_data["table"]

    def test_serialize(self):
        """Test serializing and deserializing DataFile without partition"""
        actual_data_file = self.serialize_and_deserialize(self.table, self.sample_data_file)

        differences = compare_datafiles(self.sample_data_file, actual_data_file)
        assert not differences, f"Found differences: {differences}"

    def test_fast_append_working(self):
        """Test fast append with native DataFile without partition"""
        self.append_data_file(self.table, self.data_files[0])

    def test_fast_append_with_avro(self):
        """Test fast append with Avro deserialized DataFile without partition"""
        actual_data_file = self.serialize_and_deserialize(self.table, self.sample_data_file)
        self.append_data_file(self.table, actual_data_file)
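
The tests above commit each DataFile with its own fast_append. For the workflow described in this issue, where many writers produce files but commits contend on table metadata, the coordinating process (Process B) can instead drain everything the writers have queued and commit it as one batch. The sketch below shows only that coordination pattern with the Python standard library, under stated assumptions: queue.Queue stands in for the real message queue, file-path strings stand in for serialized DataFiles, and drain_and_commit, writer and the commit_batch callback are hypothetical names, not pyiceberg APIs. In a real Process B, commit_batch is where a single transaction calling append_data_file once per batch item inside one fast_append would go.

```python
import queue
import threading
from typing import Callable, List


def drain_and_commit(
    work_queue: "queue.Queue[str]",
    commit_batch: Callable[[List[str]], None],
    max_batch: int = 100,
    timeout: float = 0.5,
) -> int:
    """Hand queued payloads to one commit call per batch; return the batch count.

    Returns once the queue has stayed empty for `timeout` seconds.
    """
    batches = 0
    while True:
        try:
            # Wait briefly for the first item of the next batch.
            first = work_queue.get(timeout=timeout)
        except queue.Empty:
            return batches
        batch = [first]
        # Take whatever else is already waiting, up to max_batch items.
        while len(batch) < max_batch:
            try:
                batch.append(work_queue.get_nowait())
            except queue.Empty:
                break
        commit_batch(batch)  # hypothetical: one metadata commit for the whole batch
        batches += 1


def writer(worker_id: int, work_queue: "queue.Queue[str]") -> None:
    # Hypothetical stand-in for "Process A": write files, then enqueue their identifiers.
    for j in range(3):
        work_queue.put(f"worker-{worker_id}/file-{j}.parquet")


q: "queue.Queue[str]" = queue.Queue()
threads = [threading.Thread(target=writer, args=(i, q)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

committed: List[List[str]] = []
n = drain_and_commit(q, committed.append, timeout=0.1)
print(n, len(committed[0]))  # 1 12
```

Here max_batch and timeout trade commit latency against the number of metadata commits; a real committer would also need a policy for a batch whose commit fails.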

@Fokko
Contributor

Fokko commented Mar 12, 2025

Hey @andormarkus, thanks for sharing, that looks great! I'm all in favor of supporting this. Very much looking forward to the PR.

Should we support __bytes__ to return the Avro-encoded bytes?
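
A minimal sketch of that idea, using only the standard library: __bytes__ returns the object's Avro binary encoding, and a companion from_bytes classmethod decodes it with the same schema. MiniDataFile and from_bytes are hypothetical and hold just two fields, file_path (an Avro string) and record_count (an Avro long), encoded by hand following the Avro specification's zigzag-varint rules. A real version on pyiceberg's DataFile would presumably reuse BinaryEncoder and the partition-aware schema from data_file_with_partition instead.

```python
from __future__ import annotations

from dataclasses import dataclass


def _write_long(n: int, out: bytearray) -> None:
    # Avro long: zigzag-map signed to unsigned, then emit 7-bit groups, low bits first.
    z = (n << 1) ^ (n >> 63)
    while z & ~0x7F:
        out.append((z & 0x7F) | 0x80)  # high bit set: more bytes follow
        z >>= 7
    out.append(z)


def _read_long(data: bytes, pos: int) -> tuple[int, int]:
    # Inverse of _write_long; returns (value, position after the value).
    shift = z = 0
    while True:
        b = data[pos]
        pos += 1
        z |= (b & 0x7F) << shift
        if not b & 0x80:
            return (z >> 1) ^ -(z & 1), pos
        shift += 7


@dataclass(frozen=True)
class MiniDataFile:
    """Hypothetical two-field stand-in for a DataFile."""

    file_path: str
    record_count: int

    def __bytes__(self) -> bytes:
        out = bytearray()
        # Record fields are written in schema order, with no field names or tags.
        path = self.file_path.encode("utf-8")
        _write_long(len(path), out)  # Avro string = long length + UTF-8 bytes
        out += path
        _write_long(self.record_count, out)
        return bytes(out)

    @classmethod
    def from_bytes(cls, data: bytes) -> MiniDataFile:
        length, pos = _read_long(data, 0)
        path = data[pos:pos + length].decode("utf-8")
        count, _ = _read_long(data, pos + length)
        return cls(path, count)


df = MiniDataFile("s3://bucket/data/00000-0.parquet", 1000)
payload = bytes(df)
assert MiniDataFile.from_bytes(payload) == df
print(len(payload))  # 35
```

One design question this raises: the partitioned DataFile schema depends on the table's partition spec, which a DataFile object may not carry as a type, so a bare __bytes__ on a partitioned DataFile would need that partition type from somewhere (for example, passed in or looked up by spec_id).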
