diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8f7b45f532..c073cbf1c7 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -918,6 +918,15 @@ def refresh(self) -> Table: self.metadata_location = fresh.metadata_location return self + def check_and_refresh_table(self) -> Optional[Table]: + fresh = self.catalog.load_table(self._identifier) + if self.metadata.current_snapshot_id != fresh.metadata.current_snapshot_id: + self.metadata = fresh.metadata + self.io = fresh.io + self.metadata_location = fresh.metadata_location + return fresh + return None + def name(self) -> Identifier: """Return the identifier of this table. diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index c705f3b9fd..be7eb2996c 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -239,7 +239,16 @@ def _summary(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> Summary: truncate_full_table=self._operation == Operation.OVERWRITE, ) + @abstractmethod + def _validate(self) -> None: + pass + def _commit(self) -> UpdatesAndRequirements: + from pyiceberg.table import StagedTable + + if not isinstance(self._transaction._table, StagedTable): + self._validate() + new_manifests = self._manifests() next_sequence_number = self._transaction.table_metadata.next_sequence_number() @@ -435,6 +444,9 @@ def _existing_manifests(self) -> List[ManifestFile]: def _deleted_entries(self) -> List[ManifestEntry]: return self._compute_deletes[1] + def _validate(self) -> None: + return + @property def rewrites_needed(self) -> bool: """Indicate if data files need to be rewritten.""" @@ -474,6 +486,15 @@ def _deleted_entries(self) -> List[ManifestEntry]: """ return [] + def _validate(self) -> None: + refresh_table = self._transaction._table.check_and_refresh_table() + if refresh_table is None: + return + current_snapshot = refresh_table.metadata.current_snapshot() + if current_snapshot is not None and current_snapshot.snapshot_id != self._parent_snapshot_id: + self._parent_snapshot_id = current_snapshot.snapshot_id + self._transaction.table_metadata = refresh_table.metadata + class _MergeAppendFiles(_FastAppendFiles): _target_size_bytes: int @@ -602,6 +623,9 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: else: return [] + def _validate(self) -> None: + return + class UpdateSnapshot: _transaction: Transaction diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 2c6eb4b4ab..c15c3aec36 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -898,3 +898,32 @@ def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_file with pytest.raises(ValueError) as exc_info: tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True) assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_delete_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema) + tbl1.append(arrow_table_with_null) + tbl2 = session_catalog.load_table(identifier) + + # This is allowed + tbl1.delete("string == 'z'") + tbl2.append(arrow_table_with_null) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_conflict_append_append( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.test_conflict" + tbl1 = _create_table(session_catalog, identifier, format_version, schema=arrow_table_with_null.schema) + tbl1.append(arrow_table_with_null) + tbl2 = session_catalog.load_table(identifier) + + tbl1.append(arrow_table_with_null) + tbl2.append(arrow_table_with_null)