diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 700a81300b..00749c6d1b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -621,7 +621,9 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
         if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed:
             warnings.warn("Delete operation did not match any records")
 
-    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def add_files(
+        self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
+    ) -> None:
         """
         Shorthand API for adding files as data files to the table transaction.
 
@@ -630,7 +632,21 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] =
 
         Raises:
             FileNotFoundError: If the file does not exist.
+            ValueError: If file_paths contains duplicate files.
+            ValueError: If file_paths contains files that are already referenced by the table.
         """
+        if len(file_paths) != len(set(file_paths)):
+            raise ValueError("File paths must be unique")
+
+        if check_duplicate_files:
+            import pyarrow.compute as pc
+
+            expr = pc.field("file_path").isin(file_paths)
+            referenced_files = [file["file_path"] for file in self._table.inspect.files().filter(expr).to_pylist()]
+
+            if referenced_files:
+                raise ValueError(f"Cannot add files that are already referenced by table, files: {', '.join(referenced_files)}")
+
         if self.table_metadata.name_mapping() is None:
             self.set_properties(**{
                 TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()
@@ -1632,7 +1648,9 @@ def delete(
         with self.transaction() as tx:
             tx.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties)
 
-    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def add_files(
+        self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
+    ) -> None:
         """
         Shorthand API for adding files as data files to the table.
 
@@ -1643,7 +1661,9 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] =
             FileNotFoundError: If the file does not exist.
""" with self.transaction() as tx: - tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties) + tx.add_files( + file_paths=file_paths, snapshot_properties=snapshot_properties, check_duplicate_files=check_duplicate_files + ) def update_spec(self, case_sensitive: bool = True) -> UpdateSpec: return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive) @@ -2270,7 +2290,8 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema: visit_with_partner( Catalog._convert_schema_if_needed(new_schema), -1, - UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore + UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), + # type: ignore PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive), ) return self diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 3703a9e0b6..85e626edf4 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -732,3 +732,98 @@ def test_add_files_subset_of_schema(spark: SparkSession, session_catalog: Catalo for column in written_arrow_table.column_names: for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): assert left == right + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_with_duplicate_files_in_file_paths(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.test_table_duplicate_add_files_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + file_path = "s3://warehouse/default/unpartitioned/v{format_version}/test-1.parquet" + file_paths = [file_path, file_path] + + # add the parquet files as data files + with pytest.raises(ValueError) as exc_info: + tbl.add_files(file_paths=file_paths) + assert "File paths must be unique" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_that_referenced_by_current_snapshot( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.test_table_add_referenced_file_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)] + + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table(ARROW_TABLE) + + # add the parquet files as data files + tbl.add_files(file_paths=file_paths) + existing_files_in_table = tbl.inspect.files().to_pylist().pop()["file_path"] + + with pytest.raises(ValueError) as exc_info: + tbl.add_files(file_paths=[existing_files_in_table]) + assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_files_false( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.test_table_add_referenced_file_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = 
[f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table(ARROW_TABLE) + + # add the parquet files as data files + tbl.add_files(file_paths=file_paths) + existing_files_in_table = tbl.inspect.files().to_pylist().pop()["file_path"] + tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=False) + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + assert [row.added_data_files_count for row in rows] == [5, 1, 5] + assert [row.existing_data_files_count for row in rows] == [0, 0, 0] + assert [row.deleted_data_files_count for row in rows] == [0, 0, 0] + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_add_files_that_referenced_by_current_snapshot_with_check_duplicate_files_true( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.test_table_add_referenced_file_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table(ARROW_TABLE) + + # add the parquet files as data files + tbl.add_files(file_paths=file_paths) + existing_files_in_table = tbl.inspect.files().to_pylist().pop()["file_path"] + with pytest.raises(ValueError) as exc_info: + tbl.add_files(file_paths=[existing_files_in_table], check_duplicate_files=True) + assert f"Cannot add files that are already referenced by table, files: {existing_files_in_table}" in str(exc_info.value)