diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py
index eb9258f43..0ca693f66 100644
--- a/src/datachain/dataset.py
+++ b/src/datachain/dataset.py
@@ -107,24 +107,21 @@ def parse(
         dataset_version: Optional[str],
         dataset_version_created_at: Optional[datetime],
     ) -> Optional["DatasetDependency"]:
-        from datachain.client import Client
-        from datachain.lib.listing import is_listing_dataset, listing_uri_from_name
+        from datachain.lib.listing import is_listing_dataset
 
         if not dataset_id:
             return None
         assert dataset_name is not None
 
-        dependency_type = DatasetDependencyType.DATASET
-        dependency_name = dataset_name
-
-        if is_listing_dataset(dataset_name):
-            dependency_type = DatasetDependencyType.STORAGE  # type: ignore[arg-type]
-            dependency_name, _ = Client.parse_url(listing_uri_from_name(dataset_name))
 
         return cls(
             id,
-            dependency_type,
-            dependency_name,
+            (
+                DatasetDependencyType.STORAGE
+                if is_listing_dataset(dataset_name)
+                else DatasetDependencyType.DATASET
+            ),
+            dataset_name,
             (
                 dataset_version  # type: ignore[arg-type]
                 if dataset_version
diff --git a/src/datachain/delta.py b/src/datachain/delta.py
new file mode 100644
index 000000000..22465c25c
--- /dev/null
+++ b/src/datachain/delta.py
@@ -0,0 +1,119 @@
+from collections.abc import Sequence
+from copy import copy
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
+
+import datachain
+from datachain.dataset import DatasetDependency
+from datachain.error import DatasetNotFoundError
+
+if TYPE_CHECKING:
+    from typing_extensions import Concatenate, ParamSpec
+
+    from datachain.lib.dc import DataChain
+
+    P = ParamSpec("P")
+
+
+T = TypeVar("T", bound="DataChain")
+
+
+def delta_disabled(
+    method: "Callable[Concatenate[T, P], T]",
+) -> "Callable[Concatenate[T, P], T]":
+    """
+    Decorator that disables DataChain methods (e.g. `.agg()` or `.union()`)
+    for delta updates. It raises `NotImplementedError` if the chain the
+    decorated method is called on is marked as delta.
+    """
+
+    @wraps(method)
+    def _inner(self: T, *args: "P.args", **kwargs: "P.kwargs") -> T:
+        if self.delta:
+            raise NotImplementedError(
+                f"Delta update cannot be used with {method.__name__}"
+            )
+        return method(self, *args, **kwargs)
+
+    return _inner
+
+
+def _append_steps(dc: "DataChain", other: "DataChain"):
+    """Returns a cloned chain with the steps of the other chain appended.
+    Steps are all the modification methods applied to a chain, e.g. filters
+    and mappers.
+    """
+    dc = dc.clone()
+    dc._query.steps += other._query.steps.copy()
+    dc.signals_schema = other.signals_schema
+    return dc
+
+
+def delta_update(
+    dc: "DataChain",
+    name: str,
+    on: Union[str, Sequence[str]],
+    right_on: Optional[Union[str, Sequence[str]]] = None,
+    compare: Optional[Union[str, Sequence[str]]] = None,
+) -> tuple[Optional["DataChain"], Optional[list[DatasetDependency]], bool]:
+    """
+    Creates a new chain that consists of the latest version of the current delta
+    dataset plus the diff from the source, with all needed modifications applied.
+    This way we don't need to re-calculate the whole chain from the source again
+    (i.e. re-apply all the DataChain methods like filters, mappers, generators
+    etc.); only the diff part is processed, which is very important for
+    performance.
+
+    Note that delta update currently works only if there is a single direct
+    dependency.
+    """
+    catalog = dc.session.catalog
+    dc._query.apply_listing_pre_step()
+
+    try:
+        latest_version = catalog.get_dataset(name).latest_version
+    except DatasetNotFoundError:
+        # first creation of the delta dataset
+        return None, None, True
+
+    dependencies = catalog.get_dataset_dependencies(
+        name, latest_version, indirect=False
+    )
+
+    dep = dependencies[0]
+    if not dep:
+        # the starting dataset (e.g. a listing) was removed, so we fall back to
+        # normal dataset creation, as on the first run
+        return None, None, True
+
+    source_ds_name = dep.name
+    source_ds_version = dep.version
+    source_ds_latest_version = catalog.get_dataset(source_ds_name).latest_version
+    dependencies = copy(dependencies)
+    dependencies = [d for d in dependencies if d is not None]  # filter out removed dep
+    dependencies[0].version = source_ds_latest_version  # type: ignore[union-attr]
+
+    source_dc = datachain.read_dataset(source_ds_name, source_ds_version)
+    source_dc_latest = datachain.read_dataset(source_ds_name, source_ds_latest_version)
+
+    diff = source_dc_latest.compare(source_dc, on=on, compare=compare, deleted=False)
+    # we append all the steps from the original chain to the diff, e.g. filters
+    # and mappers
+    diff = _append_steps(diff, dc)
+
+    # persist to avoid re-calculating the diff multiple times
+    diff = diff.persist()
+
+    if diff.empty:
+        return None, None, False
+
+    # merge the diff with the latest version of the dataset
+    delta_chain = (
+        datachain.read_dataset(name, latest_version)
+        .compare(
+            diff,
+            on=right_on or on,
+            added=True,
+            modified=False,
+            deleted=False,
+        )
+        .union(diff)
+    )
+
+    return delta_chain, dependencies, True  # type: ignore[return-value]
diff --git a/src/datachain/diff/__init__.py b/src/datachain/diff/__init__.py
index 37a8415a7..93451a66d 100644
--- a/src/datachain/diff/__init__.py
+++ b/src/datachain/diff/__init__.py
@@ -30,7 +30,7 @@ class CompareStatus(str, Enum):
     SAME = "S"
 
 
-def _compare(  # noqa: C901
+def _compare(  # noqa: C901, PLR0912
     left: "DataChain",
     right: "DataChain",
     on: Union[str, Sequence[str]],
@@ -77,14 +77,16 @@ def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
     cols_select = list(left.signals_schema.clone_without_sys_signals().values.keys())
 
     # getting correct on and right_on column names
+    on_ = on
     on = left.signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
-    right_on = right.signals_schema.resolve(*(right_on or on)).db_signals()  # type: ignore[assignment]
+    right_on = right.signals_schema.resolve(*(right_on or on_)).db_signals()  # type: ignore[assignment]
 
     # getting correct compare and right_compare column names if they are defined
     if compare:
+        compare_ = compare
         compare = left.signals_schema.resolve(*compare).db_signals()  # type: ignore[assignment]
         right_compare = right.signals_schema.resolve(
-            *(right_compare or compare)
+            *(right_compare or compare_)
         ).db_signals()  # type: ignore[assignment]
     elif not compare and len(cols) != len(right_cols):
         # here we will mark all rows that are not added or deleted as modified since
@@ -155,7 +157,11 @@ def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
     if status_col:
         cols_select.append(diff_col)
 
-    dc_diff = dc_diff.select(*cols_select)
+    if not dc_diff._sys:
+        # TODO workaround when sys signal is not available in diff
+        dc_diff = dc_diff.settings(sys=True).select(*cols_select).settings(sys=False)
+    else:
+        dc_diff = dc_diff.select(*cols_select)
 
     # final schema is schema from the left chain with status column added if needed
     dc_diff.signals_schema = (
diff
--git a/src/datachain/lib/dc/datachain.py b/src/datachain/lib/dc/datachain.py index 0d2ec2b99..08315ed03 100644 --- a/src/datachain/lib/dc/datachain.py +++ b/src/datachain/lib/dc/datachain.py @@ -25,6 +25,7 @@ from datachain import semver from datachain.dataset import DatasetRecord +from datachain.delta import delta_disabled, delta_update from datachain.func import literal from datachain.func.base import Function from datachain.func.func import Func @@ -72,6 +73,9 @@ P = ParamSpec("P") +T = TypeVar("T", bound="DataChain") + + class DataChain: """DataChain - a data structure for batch data processing and evaluation. @@ -164,6 +168,7 @@ def __init__( self.signals_schema = signal_schema self._setup: dict = setup or {} self._sys = _sys + self._delta = False def __repr__(self) -> str: """Return a string representation of the chain.""" @@ -177,6 +182,32 @@ def __repr__(self) -> str: self.print_schema(file=file) return file.getvalue() + def _as_delta( + self, + on: Optional[Union[str, Sequence[str]]] = None, + right_on: Optional[Union[str, Sequence[str]]] = None, + compare: Optional[Union[str, Sequence[str]]] = None, + ) -> "Self": + """Marks this chain as delta, which means special delta process will be + called on saving dataset for optimization""" + if on is None: + raise ValueError("'delta on' fields must be defined") + self._delta = True + self._delta_on = on + self._delta_result_on = right_on + self._delta_compare = compare + return self + + @property + def empty(self) -> bool: + """Returns True if chain has zero number of rows""" + return not bool(self.count()) + + @property + def delta(self) -> bool: + """Returns True if this chain is ran in "delta" update mode""" + return self._delta + @property def schema(self) -> dict[str, DataType]: """Get schema of the chain.""" @@ -254,9 +285,17 @@ def _evolve( signal_schema = copy.deepcopy(self.signals_schema) if _sys is None: _sys = self._sys - return type(self)( + chain = type(self)( query, settings, signal_schema=signal_schema, setup=self._setup, _sys=_sys ) + if self.delta: + chain = chain._as_delta( + on=self._delta_on, + right_on=self._delta_result_on, + compare=self._delta_compare, + ) + + return chain def settings( self, @@ -463,7 +502,7 @@ def save( # type: ignore[override] attrs: Optional[list[str]] = None, update_version: Optional[str] = "patch", **kwargs, - ) -> "Self": + ) -> "DataChain": """Save to a Dataset. It returns the chain itself. Parameters: @@ -490,6 +529,35 @@ def save( # type: ignore[override] ) schema = self.signals_schema.clone_without_sys_signals().serialize() + if self.delta and name: + delta_ds, dependencies, has_changes = delta_update( + self, + name, + on=self._delta_on, + right_on=self._delta_result_on, + compare=self._delta_compare, + ) + + if delta_ds: + return self._evolve( + query=delta_ds._query.save( + name=name, + version=version, + feature_schema=schema, + dependencies=dependencies, + **kwargs, + ) + ) + + if not has_changes: + # sources have not been changed so new version of resulting dataset + # would be the same as previous one. To avoid duplicating exact + # datasets, we won't create new version of it and we will return + # current latest version instead. 
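+                # local import: .datasets constructs DataChain objects, so
+                # importing it at module level would be circular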
+ from .datasets import read_dataset + + return read_dataset(name, **kwargs) + return self._evolve( query=self._query.save( name=name, @@ -615,6 +683,7 @@ def gen( signal_schema=udf_obj.output, ) + @delta_disabled def agg( self, func: Optional[Callable] = None, @@ -768,6 +837,7 @@ def order_by(self, *args, descending: bool = False) -> "Self": return self._evolve(query=self._query.order_by(*args)) + @delta_disabled def distinct(self, arg: str, *args: str) -> "Self": # type: ignore[override] """Removes duplicate rows based on uniqueness of some input column(s) i.e if rows are found with the same value of input column(s), only one @@ -802,6 +872,7 @@ def select_except(self, *args: str) -> "Self": query=self._query.select(*columns), signal_schema=new_schema ) + @delta_disabled # type: ignore[arg-type] def group_by( self, *, @@ -1160,6 +1231,7 @@ def remove_file_signals(self) -> "Self": schema = self.signals_schema.clone_without_file_signals() return self.select(*schema.values.keys()) + @delta_disabled def merge( self, right_ds: "DataChain", @@ -1268,6 +1340,7 @@ def _resolve( return ds + @delta_disabled def union(self, other: "Self") -> "Self": """Return the set union of the two datasets. diff --git a/src/datachain/lib/dc/datasets.py b/src/datachain/lib/dc/datasets.py index c7c093435..d4a82e513 100644 --- a/src/datachain/lib/dc/datasets.py +++ b/src/datachain/lib/dc/datasets.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import TYPE_CHECKING, Optional, Union, get_origin, get_type_hints from datachain.error import DatasetVersionNotFoundError @@ -27,6 +28,10 @@ def read_dataset( session: Optional[Session] = None, settings: Optional[dict] = None, fallback_to_studio: bool = True, + delta: Optional[bool] = False, + delta_on: Optional[Union[str, Sequence[str]]] = None, + delta_result_on: Optional[Union[str, Sequence[str]]] = None, + delta_compare: Optional[Union[str, Sequence[str]]] = None, ) -> "DataChain": """Get data from a saved Dataset. It returns the chain itself. If dataset or version is not found locally, it will try to pull it from Studio. @@ -38,6 +43,36 @@ def read_dataset( settings : Settings to use for the chain. fallback_to_studio : Try to pull dataset from Studio if not found locally. Default is True. + delta: If set to True, we optimize the creation of new dataset versions by + calculating the diff between the latest version of this storage and the + version used to create the most recent version of the resulting chain + dataset (the one specified in `.save()`). We then run the "diff" chain + using only the diff data, rather than the entire storage data, and merge + that diff chain with the latest version of the resulting dataset to create + a new version. This approach avoids applying modifications to all records + from storage every time, which can be an expensive operation. + The diff is calculated using the `DataChain.compare()` method, which + compares the `delta_on` fields to find matches and checks the compare + fields to determine if a record has changed. Note that this process only + considers added and modified records in storage; deleted records are not + removed from the new dataset version. + This calculation is based on the difference between the current version + of the source and the version used to create the dataset. + delta_on: A list of fields that uniquely identify rows in the source. + If two rows have the same values, they are considered the same (e.g., they + could be different versions of the same row in a versioned source). 
+ This is used in the delta update to calculate the diff. + delta_result_on: A list of fields in the resulting dataset that correspond + to the `delta_on` fields from the source. + This is needed to identify rows that have changed in the source but are + already present in the current version of the resulting dataset, in order + to avoid including outdated versions of those rows in the new dataset. + We retain only the latest versions of rows to prevent duplication. + There is no need to define this if the `delta_on` fields are present in + the final dataset and have not been renamed. + delta_compare: A list of fields used to check if the same row has been modified + in the new version of the source. + If not defined, all fields except those defined in delta_on will be used. Example: ```py @@ -113,7 +148,12 @@ def read_dataset( signals_schema |= SignalSchema.deserialize(query.feature_schema) else: signals_schema |= SignalSchema.from_column_types(query.column_types or {}) - return DataChain(query, _settings, signals_schema) + chain = DataChain(query, _settings, signals_schema) + if delta: + chain = chain._as_delta( + on=delta_on, right_on=delta_result_on, compare=delta_compare + ) + return chain def datasets( diff --git a/src/datachain/lib/dc/storage.py b/src/datachain/lib/dc/storage.py index 7c172fba2..246bf7093 100644 --- a/src/datachain/lib/dc/storage.py +++ b/src/datachain/lib/dc/storage.py @@ -1,4 +1,6 @@ import os.path +from collections.abc import Sequence +from functools import reduce from typing import ( TYPE_CHECKING, Optional, @@ -32,6 +34,10 @@ def read_storage( column: str = "file", update: bool = False, anon: bool = False, + delta: Optional[bool] = False, + delta_on: Optional[Union[str, Sequence[str]]] = None, + delta_result_on: Optional[Union[str, Sequence[str]]] = None, + delta_compare: Optional[Union[str, Sequence[str]]] = None, client_config: Optional[dict] = None, ) -> "DataChain": """Get data from storage(s) as a list of file with all file attributes. @@ -47,6 +53,36 @@ def read_storage( update : force storage reindexing. Default is False. anon : If True, we will treat cloud bucket as public one client_config : Optional client configuration for the storage client. + delta: If set to True, we optimize the creation of new dataset versions by + calculating the diff between the latest version of this storage and the + version used to create the most recent version of the resulting chain + dataset (the one specified in `.save()`). We then run the "diff" chain + using only the diff data, rather than the entire storage data, and merge + that diff chain with the latest version of the resulting dataset to create + a new version. This approach avoids applying modifications to all records + from storage every time, which can be an expensive operation. + The diff is calculated using the `DataChain.compare()` method, which + compares the `delta_on` fields to find matches and checks the compare + fields to determine if a record has changed. Note that this process only + considers added and modified records in storage; deleted records are not + removed from the new dataset version. + This calculation is based on the difference between the current version + of the source and the version used to create the dataset. + delta_on: A list of fields that uniquely identify rows in the source. + If two rows have the same values, they are considered the same (e.g., they + could be different versions of the same row in a versioned source). + This is used in the delta update to calculate the diff. 
+ delta_result_on: A list of fields in the resulting dataset that correspond + to the `delta_on` fields from the source. + This is needed to identify rows that have changed in the source but are + already present in the current version of the resulting dataset, in order + to avoid including outdated versions of those rows in the new dataset. + We retain only the latest versions of rows to prevent duplication. + There is no need to define this if the `delta_on` fields are present in + the final dataset and have not been renamed. + delta_compare: A list of fields used to check if the same row has been modified + in the new version of the source. + If not defined, all fields except those defined in `delta_on` will be used. Returns: DataChain: A DataChain object containing the file information. @@ -106,7 +142,7 @@ def read_storage( if not uris: raise ValueError("No URIs provided") - storage_chain = None + chains = [] listed_ds_name = set() file_values = [] @@ -151,11 +187,11 @@ def lst_fn(ds_name, lst_uri): lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri) ) - chain = ls(dc, list_path, recursive=recursive, column=column) - - storage_chain = storage_chain.union(chain) if storage_chain else chain + chains.append(ls(dc, list_path, recursive=recursive, column=column)) listed_ds_name.add(list_ds_name) + storage_chain = None if not chains else reduce(lambda x, y: x.union(y), chains) + if file_values: file_chain = read_values( session=session, @@ -170,4 +206,8 @@ def lst_fn(ds_name, lst_uri): assert storage_chain is not None + if delta: + storage_chain = storage_chain._as_delta( + on=delta_on, right_on=delta_result_on, compare=delta_compare + ) return storage_chain diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 7425d17a0..b011949cf 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -461,14 +461,13 @@ def row_to_objs(self, row: Sequence[Any]) -> list[DataValue]: pos += 1 return objs - def contains_file(self) -> bool: - for type_ in self.values.values(): - if (fr := ModelStore.to_pydantic(type_)) is not None and issubclass( + def get_file_signal(self) -> Optional[str]: + for signal_name, signal_type in self.values.items(): + if (fr := ModelStore.to_pydantic(signal_type)) is not None and issubclass( fr, File ): - return True - - return False + return signal_name + return None def slice( self, @@ -705,6 +704,13 @@ def merge( return SignalSchema(self.values | schema_right) + def append(self, right: "SignalSchema") -> "SignalSchema": + missing_schema = { + key: right.values[key] + for key in [k for k in right.values if k not in self.values] + } + return SignalSchema(self.values | missing_schema) + def get_signals(self, target_type: type[DataModel]) -> Iterator[str]: for path, type_, has_subtree, _ in self.get_flat_tree(): if has_subtree and issubclass(type_, target_type): diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index e88566ae6..177b3667d 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -41,7 +41,7 @@ partition_col_names, partition_columns, ) -from datachain.dataset import DATASET_PREFIX, DatasetStatus, RowDict +from datachain.dataset import DATASET_PREFIX, DatasetDependency, DatasetStatus, RowDict from datachain.error import DatasetNotFoundError, QueryScriptCancelError from datachain.func.base import Function from datachain.lib.listing import is_listing_dataset, listing_dataset_expired @@ -166,11 +166,13 @@ def apply( @frozen 
class QueryStep: + """A query that returns all rows from specific dataset version""" + catalog: "Catalog" dataset_name: str dataset_version: str - def apply(self): + def apply(self) -> "StepResult": def q(*columns): return sqlalchemy.select(*columns) @@ -1127,9 +1129,14 @@ def __init__( self.version = version if is_listing_dataset(name): - # not setting query step yet as listing dataset might not exist at - # this point - self.list_ds_name = name + if version: + # this listing dataset should already be listed as we specify + # exact version + self._set_starting_step(self.catalog.get_dataset(name)) + else: + # not setting query step yet as listing dataset might not exist at + # this point + self.list_ds_name = name elif fallback_to_studio and is_token_set(): self._set_starting_step( self.catalog.get_dataset_with_remote_fallback(name, version) @@ -1205,11 +1212,8 @@ def set_listing_fn(self, fn: Callable) -> None: """Setting listing function to be run if needed""" self.listing_fn = fn - def apply_steps(self) -> QueryGenerator: - """ - Apply the steps in the query and return the resulting - sqlalchemy.SelectBase. - """ + def apply_listing_pre_step(self) -> None: + """Runs listing pre-step if needed""" if self.list_ds_name and not self.starting_step: listing_ds = None try: @@ -1225,6 +1229,13 @@ def apply_steps(self) -> QueryGenerator: # at this point we know what is our starting listing dataset name self._set_starting_step(listing_ds) # type: ignore [arg-type] + def apply_steps(self) -> QueryGenerator: + """ + Apply the steps in the query and return the resulting + sqlalchemy.SelectBase. + """ + self.apply_listing_pre_step() + query = self.clone() index = os.getenv("DATACHAIN_QUERY_CHUNK_INDEX", self._chunk_index) @@ -1687,6 +1698,7 @@ def save( name: Optional[str] = None, version: Optional[str] = None, feature_schema: Optional[dict] = None, + dependencies: Optional[list[DatasetDependency]] = None, description: Optional[str] = None, attrs: Optional[list[str]] = None, update_version: Optional[str] = "patch", @@ -1742,6 +1754,9 @@ def save( ) self.catalog.update_dataset_version_with_warehouse_info(dataset, version) + if dependencies: + # overriding dependencies + self.dependencies = {(dep.name, dep.version) for dep in dependencies} self._add_dependencies(dataset, version) # type: ignore [arg-type] finally: self.cleanup() diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 3bbf9f314..1fd99c358 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -230,15 +230,13 @@ def test_read_storage_dependencies(cloud_test_catalog, cloud_type): ctc = cloud_test_catalog src_uri = ctc.src_uri uri = f"{src_uri}/cats" + dep_name, _, _ = parse_listing_uri(uri, ctc.catalog.client_config) ds_name = "dep" dc.read_storage(uri, session=ctc.session).save(ds_name) dependencies = ctc.session.catalog.get_dataset_dependencies(ds_name, "1.0.0") assert len(dependencies) == 1 assert dependencies[0].type == DatasetDependencyType.STORAGE - if cloud_type == "file": - assert dependencies[0].name == uri - else: - assert dependencies[0].name == src_uri + assert dependencies[0].name == dep_name @pytest.mark.parametrize("use_cache", [True, False]) diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py index 6738bc75c..37eecc8d4 100644 --- a/tests/func/test_dataset_query.py +++ b/tests/func/test_dataset_query.py @@ -10,6 +10,7 @@ from datachain.error import ( DatasetVersionNotFoundError, ) +from datachain.lib.listing import parse_listing_uri from 
datachain.query import C, DatasetQuery, Object, Stream from datachain.sql.functions import path as pathfunc from datachain.sql.types import String @@ -964,6 +965,9 @@ def test_dataset_dependencies_one_storage_as_dependency( ds_name = uuid.uuid4().hex catalog = cloud_test_catalog.catalog listing = catalog.listings()[0] + dep_name, _, _ = parse_listing_uri( + cloud_test_catalog.src_uri, catalog.client_config + ) DatasetQuery(cats_dataset.name, catalog=catalog).save(ds_name) @@ -976,7 +980,7 @@ def test_dataset_dependencies_one_storage_as_dependency( { "id": ANY, "type": DatasetDependencyType.STORAGE, - "name": cloud_test_catalog.src_uri, + "name": dep_name, "version": "1.0.0", "created_at": listing.created_at, "dependencies": [], @@ -992,6 +996,10 @@ def test_dataset_dependencies_one_registered_dataset_as_dependency( catalog = cloud_test_catalog.catalog listing = catalog.listings()[0] + dep_name, _, _ = parse_listing_uri( + cloud_test_catalog.src_uri, catalog.client_config + ) + DatasetQuery(name=dogs_dataset.name, catalog=catalog).save(ds_name) expected = [ @@ -1010,7 +1018,7 @@ def test_dataset_dependencies_one_registered_dataset_as_dependency( { "id": ANY, "type": DatasetDependencyType.STORAGE, - "name": cloud_test_catalog.src_uri, + "name": dep_name, "version": "1.0.0", "created_at": listing.created_at, "dependencies": [], @@ -1036,6 +1044,9 @@ def test_dataset_dependencies_multiple_direct_dataset_dependencies( ds_name = uuid.uuid4().hex catalog = cloud_test_catalog.catalog listing = catalog.listings()[0] + dep_name, _, _ = parse_listing_uri( + cloud_test_catalog.src_uri, catalog.client_config + ) dogs = DatasetQuery(name=dogs_dataset.name, version="1.0.0", catalog=catalog) cats = DatasetQuery(name=cats_dataset.name, version="1.0.0", catalog=catalog) @@ -1048,7 +1059,7 @@ def test_dataset_dependencies_multiple_direct_dataset_dependencies( storage_depenedncy = { "id": ANY, "type": DatasetDependencyType.STORAGE, - "name": cloud_test_catalog.src_uri, + "name": dep_name, "version": "1.0.0", "created_at": listing.created_at, "dependencies": [], @@ -1105,6 +1116,9 @@ def test_dataset_dependencies_multiple_union( ds_name = uuid.uuid4().hex catalog = cloud_test_catalog.catalog listing = catalog.listings()[0] + dep_name, _, _ = parse_listing_uri( + cloud_test_catalog.src_uri, catalog.client_config + ) dogs = DatasetQuery(name=dogs_dataset.name, version="1.0.0", catalog=catalog) cats = DatasetQuery(name=cats_dataset.name, version="1.0.0", catalog=catalog) @@ -1115,7 +1129,7 @@ def test_dataset_dependencies_multiple_union( storage_depenedncy = { "id": ANY, "type": DatasetDependencyType.STORAGE, - "name": cloud_test_catalog.src_uri, + "name": dep_name, "version": "1.0.0", "created_at": listing.created_at, "dependencies": [], diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py index 260ad96d5..fcfb18307 100644 --- a/tests/func/test_datasets.py +++ b/tests/func/test_datasets.py @@ -666,6 +666,7 @@ def test_dataset_storage_dependencies(cloud_test_catalog, cloud_type, indirect): session = ctc.session catalog = session.catalog uri = cloud_test_catalog.src_uri + dep_name, _, _ = parse_listing_uri(ctc.src_uri, catalog.client_config) ds_name = "some_ds" dc.read_storage(uri, session=session).save(ds_name) @@ -680,7 +681,7 @@ def test_dataset_storage_dependencies(cloud_test_catalog, cloud_type, indirect): { "id": ANY, "type": DatasetDependencyType.STORAGE, - "name": uri, + "name": dep_name, "version": "1.0.0", "created_at": lst_dataset.get_version("1.0.0").created_at, "dependencies": [], 
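The new functional tests that follow exercise the delta flow end to end. For orientation, here is the user-facing pattern they rely on, as a minimal sketch (the storage URI and dataset name are hypothetical; the `delta*` parameters are the ones introduced in this diff):

import datachain as dc
from datachain.lib.dc import C

# First save: "reports" does not exist yet, so the whole chain is computed and
# stored as version 1.0.0. On later saves, only rows that were added or modified
# in the source since the previous version are processed and merged in.
(
    dc.read_storage(
        "s3://my-bucket/reports/",  # hypothetical source URI
        delta=True,
        delta_on=["file.source", "file.path"],        # row identity in the source
        delta_compare=["file.version", "file.etag"],  # detects modified rows
    )
    .filter(C("file.path").glob("*.csv"))
    .save("reports")  # hypothetical resulting dataset
)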
diff --git a/tests/func/test_delta.py b/tests/func/test_delta.py new file mode 100644 index 000000000..9d6525525 --- /dev/null +++ b/tests/func/test_delta.py @@ -0,0 +1,383 @@ +import os + +import pytest +import regex as re +from PIL import Image + +import datachain as dc +from datachain import func +from datachain.error import DatasetVersionNotFoundError +from datachain.lib.dc import C +from datachain.lib.file import File, ImageFile + + +def _get_dependencies(catalog, name, version) -> list[tuple[str, str]]: + return sorted( + [ + (d.name, d.version) + for d in catalog.get_dataset_dependencies(name, version, indirect=False) + ] + ) + + +def test_delta_update_from_dataset(test_session, tmp_dir, tmp_path): + catalog = test_session.catalog + starting_ds_name = "starting_ds" + ds_name = "delta_ds" + + images = [ + {"name": "img1.jpg", "data": Image.new(mode="RGB", size=(64, 64))}, + {"name": "img2.jpg", "data": Image.new(mode="RGB", size=(128, 128))}, + {"name": "img3.jpg", "data": Image.new(mode="RGB", size=(64, 64))}, + {"name": "img4.jpg", "data": Image.new(mode="RGB", size=(128, 128))}, + ] + + def create_image_dataset(ds_name, images): + dc.read_values( + file=[ + ImageFile(path=img["name"], source=f"file://{tmp_path}") + for img in images + ], + session=test_session, + ).save(ds_name) + + def create_delta_dataset(ds_name): + dc.read_dataset( + starting_ds_name, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + delta_result_on=["file.source", "file.path"], + delta_compare=["file.version", "file.etag"], + ).save(ds_name) + + # first version of starting dataset + create_image_dataset(starting_ds_name, images[:2]) + # first version of delta dataset + create_delta_dataset(ds_name) + assert _get_dependencies(catalog, ds_name, "1.0.0") == [(starting_ds_name, "1.0.0")] + # second version of starting dataset + create_image_dataset(starting_ds_name, images[2:]) + # second version of delta dataset + create_delta_dataset(ds_name) + assert _get_dependencies(catalog, ds_name, "1.0.1") == [(starting_ds_name, "1.0.1")] + + assert list( + dc.read_dataset(ds_name, version="1.0.0") + .order_by("file.path") + .collect("file.path") + ) == [ + "img1.jpg", + "img2.jpg", + ] + + assert list( + dc.read_dataset(ds_name, version="1.0.1") + .order_by("file.path") + .collect("file.path") + ) == [ + "img1.jpg", + "img2.jpg", + "img3.jpg", + "img4.jpg", + ] + + create_delta_dataset(ds_name) + + +def test_delta_update_from_storage(test_session, tmp_dir, tmp_path): + ds_name = "delta_ds" + path = tmp_dir.as_uri() + tmp_dir = tmp_dir / "images" + os.mkdir(tmp_dir) + + images = [ + { + "name": f"img{i}.{'jpg' if i % 2 == 0 else 'png'}", + "data": Image.new(mode="RGB", size=((i + 1) * 10, (i + 1) * 10)), + } + for i in range(20) + ] + + # save only half of the images for now + for img in images[:10]: + img["data"].save(tmp_dir / img["name"]) + + def create_delta_dataset(): + def my_embedding(file: File) -> list[float]: + return [0.5, 0.5] + + def get_index(file: File) -> int: + r = r".+\/img(\d+)\.jpg" + return int(re.search(r, file.path).group(1)) # type: ignore[union-attr] + + ( + dc.read_storage( + path, + update=True, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + delta_result_on=["file.source", "file.path"], + delta_compare=["file.version", "file.etag"], + ) + .filter(C("file.path").glob("*.jpg")) + .map(emb=my_embedding) + .mutate(dist=func.cosine_distance("emb", (0.1, 0.2))) + .map(index=get_index) + .filter(C("index") > 3) + .save(ds_name) + ) + + # 
first version of delta dataset + create_delta_dataset() + + # remember old etags for later comparison to prove modified images are also taken + # into consideration on delta update + etags = { + r[0]: r[1].etag + for r in dc.read_dataset(ds_name, version="1.0.0").collect("index", "file") + } + + # remove last couple of images to simulate modification since we will re-create it + for img in images[5:10]: + os.remove(tmp_dir / img["name"]) + + # save other half of images and the ones that are removed above + for img in images[5:]: + img["data"].save(tmp_dir / img["name"]) + + # remove first 5 images to check that deleted rows are not taken into consideration + for img in images[0:5]: + os.remove(tmp_dir / img["name"]) + + # second version of delta dataset + create_delta_dataset() + + assert list( + dc.read_dataset(ds_name, version="1.0.0") + .order_by("file.path") + .collect("file.path") + ) == [ + "images/img4.jpg", + "images/img6.jpg", + "images/img8.jpg", + ] + + assert list( + dc.read_dataset(ds_name, version="1.0.1") + .order_by("file.path") + .collect("file.path") + ) == [ + "images/img10.jpg", + "images/img12.jpg", + "images/img14.jpg", + "images/img16.jpg", + "images/img18.jpg", + "images/img4.jpg", + "images/img6.jpg", + "images/img8.jpg", + ] + + # check that we have newest versions for modified rows since etags are mtime + # and modified rows etags should be bigger than the old ones + assert ( + next( + dc.read_dataset(ds_name, version="1.0.1") + .filter(C("index") == 6) + .order_by("file.path", "file.etag") + .collect("file.etag") + ) + > etags[6] + ) + + +def test_delta_update_check_num_calls(test_session, tmp_dir, tmp_path, capsys): + ds_name = "delta_ds" + path = tmp_dir.as_uri() + tmp_dir = tmp_dir / "images" + os.mkdir(tmp_dir) + map_print = "In map" + + images = [ + { + "name": f"img{i}.jpg", + "data": Image.new(mode="RGB", size=((i + 1) * 10, (i + 1) * 10)), + } + for i in range(20) + ] + + # save only half of the images for now + for img in images[:10]: + img["data"].save(tmp_dir / img["name"]) + + def create_delta_dataset(): + def get_index(file: File) -> int: + print(map_print) # needed to count number of map calls + r = r".+\/img(\d+)\.jpg" + return int(re.search(r, file.path).group(1)) # type: ignore[union-attr] + + ( + dc.read_storage( + path, + update=True, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + delta_result_on=["file.source", "file.path"], + delta_compare=["file.version", "file.etag"], + ) + .map(index=get_index) + .save(ds_name) + ) + + # first version of delta dataset + create_delta_dataset() + # save other half of images + for img in images[10:]: + img["data"].save(tmp_dir / img["name"]) + # second version of delta dataset + create_delta_dataset() + + captured = capsys.readouterr() + # assert captured.out == "Garbage collecting 2 tables.\n" + assert captured.out == "\n".join([map_print] * 20) + "\n" + + +def test_delta_update_no_diff(test_session, tmp_dir, tmp_path): + ds_name = "delta_ds" + path = tmp_dir.as_uri() + tmp_dir = tmp_dir / "images" + os.mkdir(tmp_dir) + + images = [ + {"name": f"img{i}.jpg", "data": Image.new(mode="RGB", size=(64, 128))} + for i in range(10) + ] + + for img in images: + img["data"].save(tmp_dir / img["name"]) + + def create_delta_dataset(): + def get_index(file: File) -> int: + r = r".+\/img(\d+)\.jpg" + return int(re.search(r, file.path).group(1)) # type: ignore[union-attr] + + ( + dc.read_storage( + path, + update=True, + session=test_session, + delta=True, + delta_on=["file.source", 
"file.path"], + delta_compare=["file.version", "file.etag"], + ) + .filter(C("file.path").glob("*.jpg")) + .map(index=get_index) + .filter(C("index") > 5) + .save(ds_name) + ) + + create_delta_dataset() + create_delta_dataset() + + assert list( + dc.read_dataset(ds_name, version="1.0.0") + .order_by("file.path") + .collect("file.path") + ) == [ + "images/img6.jpg", + "images/img7.jpg", + "images/img8.jpg", + "images/img9.jpg", + ] + + with pytest.raises(DatasetVersionNotFoundError) as exc_info: + dc.read_dataset(ds_name, version="1.0.1") + + assert str(exc_info.value) == f"Dataset {ds_name} does not have version 1.0.1" + + +@pytest.fixture +def file_dataset(test_session): + return dc.read_values( + file=[ + File(path="a.jpg", source="s3://bucket"), + File(path="b.jpg", source="s3://bucket"), + ], + session=test_session, + ).save("file_ds") + + +def test_delta_update_union(test_session, file_dataset): + dc.read_values(num=[10, 20], session=test_session).save("numbers") + + with pytest.raises(NotImplementedError) as excinfo: + ( + dc.read_dataset( + file_dataset.name, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + ).union(dc.read_dataset("numbers"), session=test_session) + ) + + assert str(excinfo.value) == "Delta update cannot be used with union" + + +def test_delta_update_merge(test_session, file_dataset): + dc.read_values(num=[10, 20], session=test_session).save("numbers") + + with pytest.raises(NotImplementedError) as excinfo: + ( + dc.read_dataset( + file_dataset.name, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + ).merge(dc.read_dataset("numbers"), on="id", session=test_session) + ) + + assert str(excinfo.value) == "Delta update cannot be used with merge" + + +def test_delta_update_distinct(test_session, file_dataset): + with pytest.raises(NotImplementedError) as excinfo: + ( + dc.read_dataset( + file_dataset.name, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + ).distinct("file.path") + ) + + assert str(excinfo.value) == "Delta update cannot be used with distinct" + + +def test_delta_update_group_by(test_session, file_dataset): + with pytest.raises(NotImplementedError) as excinfo: + ( + dc.read_dataset( + file_dataset.name, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + ).group_by(cnt=func.count(), partition_by="file.path") + ) + + assert str(excinfo.value) == "Delta update cannot be used with group_by" + + +def test_delta_update_agg(test_session, file_dataset): + with pytest.raises(NotImplementedError) as excinfo: + ( + dc.read_dataset( + file_dataset.name, + session=test_session, + delta=True, + delta_on=["file.source", "file.path"], + ).agg(cnt=func.count(), partition_by="file.path") + ) + + assert str(excinfo.value) == "Delta update cannot be used with agg" diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 6b5a46f0d..0a4c62c5a 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -274,6 +274,10 @@ def test_read_record_empty_chain_without_schema(test_session): ) +def test_empty(test_session): + assert dc.read_records([], schema=None, session=test_session).empty is True + + def test_empty_chain_skip_udf_run(test_session): # Test that UDF is not called for empty chain with patch.object(UDFAdapter, "run") as mock_udf_run: diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 4912670eb..2eb7fb769 100644 --- 
a/tests/unit/lib/test_signal_schema.py
+++ b/tests/unit/lib/test_signal_schema.py
@@ -1041,17 +1041,6 @@ def test_get_flatten_hidden_fields(schema, hidden_fields):
     assert SignalSchema.get_flatten_hidden_fields(schema_serialized) == hidden_fields
 
 
-@pytest.mark.parametrize(
-    "schema,result",
-    [
-        ({"name": str, "value": int}, False),
-        ({"name": str, "age": float, "f": File}, True),
-    ],
-)
-def test_contains_file(schema, result):
-    assert SignalSchema(schema).contains_file() is result
-
-
 def test_slice():
     schema = {"name": str, "age": float, "address": str}
     setup_values = {"init": lambda: 37}
@@ -1316,3 +1305,14 @@ class Custom(DataModel):
         "f": "FilePartial1@v1",
         "custom": "CustomPartial1@v1",
     }
+
+
+def test_get_file_signal():
+    assert SignalSchema({"name": str, "f": File}).get_file_signal() == "f"
+    assert SignalSchema({"name": str}).get_file_signal() is None
+
+
+def test_append():
+    s1 = SignalSchema({"name": str, "f": File})
+    s2 = SignalSchema({"name": str, "f": File, "age": int})
+    assert s1.append(s2).values == {"name": str, "f": File, "age": int}
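Beyond storages, the same delta machinery applies when reading from a saved dataset, as exercised by test_delta_update_from_dataset above. A minimal sketch with hypothetical dataset names; per the read_dataset docstring, `delta_result_on` may be omitted when the `delta_on` fields reach the resulting dataset unrenamed:

import datachain as dc

# Each save of "derived" re-runs the chain only over rows of "source_ds" that
# were added or modified since the previous run, then unions that diff with the
# latest saved "derived" version (see delta_update() in src/datachain/delta.py).
dc.read_dataset(
    "source_ds",  # hypothetical upstream dataset
    delta=True,
    delta_on=["file.source", "file.path"],
    delta_result_on=["file.source", "file.path"],
    delta_compare=["file.version", "file.etag"],
).save("derived")  # hypothetical resulting dataset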