Skip to content

feat: refresh table when committing to support concurrent appends #1885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,15 @@ def refresh(self) -> Table:
self.metadata_location = fresh.metadata_location
return self

def check_and_refresh_table(self) -> Optional[Table]:
fresh = self.catalog.load_table(self._identifier)
if self.metadata.current_snapshot_id != fresh.metadata.current_snapshot_id:
self.metadata = fresh.metadata
self.io = fresh.io
self.metadata_location = fresh.metadata_location
return fresh
return None

def name(self) -> Identifier:
"""Return the identifier of this table.

Expand Down
24 changes: 24 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,16 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary:
truncate_full_table=self._operation == Operation.OVERWRITE,
)

@abstractmethod
def _validate(self) -> None:
pass

def _commit(self) -> UpdatesAndRequirements:
from pyiceberg.table import StagedTable

if not isinstance(self._transaction._table, StagedTable):
self._validate()

new_manifests = self._manifests()
next_sequence_number = self._transaction.table_metadata.next_sequence_number()

Expand Down Expand Up @@ -435,6 +444,9 @@ def _existing_manifests(self) -> List[ManifestFile]:
def _deleted_entries(self) -> List[ManifestEntry]:
return self._compute_deletes[1]

def _validate(self) -> None:
return

@property
def rewrites_needed(self) -> bool:
"""Indicate if data files need to be rewritten."""
Expand Down Expand Up @@ -474,6 +486,15 @@ def _deleted_entries(self) -> List[ManifestEntry]:
"""
return []

def _validate(self) -> None:
refresh_table = self._transaction._table.check_and_refresh_table()
if refresh_table is None:
return
current_snapshot = refresh_table.metadata.current_snapshot()
if current_snapshot is not None and current_snapshot.snapshot_id != self._parent_snapshot_id:
self._parent_snapshot_id = current_snapshot.snapshot_id
self._transaction.table_metadata = refresh_table.metadata


class _MergeAppendFiles(_FastAppendFiles):
_target_size_bytes: int
Expand Down Expand Up @@ -602,6 +623,9 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
else:
return []

def _validate(self) -> None:
return


class UpdateSnapshot:
_transaction: Transaction
Expand Down
29 changes: 29 additions & 0 deletions tests/integration/test_add_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,3 +898,32 @@ def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_file
with pytest.raises(ValueError) as exc_info:
tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True)
assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_conflict_delete_append(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
identifier = "default.test_conflict"
tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
tbl1.append(arrow_table_with_null)
tbl2 = session_catalog.load_table(identifier)

# This is allowed
tbl1.delete("string == 'z'")
tbl2.append(arrow_table_with_null)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_conflict_append_append(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
identifier = "default.test_conflict"
tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema)
tbl1.append(arrow_table_with_null)
tbl2 = session_catalog.load_table(identifier)

tbl1.append(arrow_table_with_null)
tbl2.append(arrow_table_with_null)