Skip to content

Commit 37a5e2e

Browse files
authored
Fix group_by for nested columns (#927)
+ Allow select from chain if no 'sys' columns exists
1 parent 1b7bcaf commit 37a5e2e

File tree

5 files changed

+615
-16
lines changed

5 files changed

+615
-16
lines changed

src/datachain/lib/dc.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,7 +1057,7 @@ def distinct(self, arg: str, *args: str) -> "Self": # type: ignore[override]
10571057
def select(self, *args: str, _sys: bool = True) -> "Self":
10581058
"""Select only a specified set of signals."""
10591059
new_schema = self.signals_schema.resolve(*args)
1060-
if _sys:
1060+
if self._sys and _sys:
10611061
new_schema = SignalSchema({"sys": Sys}) | new_schema
10621062
columns = new_schema.db_signals()
10631063
return self._evolve(
@@ -1100,17 +1100,21 @@ def group_by(
11001100
partition_by_columns: list[Column] = []
11011101
signal_columns: list[Column] = []
11021102
schema_fields: dict[str, DataType] = {}
1103+
keep_columns: list[str] = []
11031104

11041105
# validate partition_by columns and add them to the schema
11051106
for col in partition_by:
11061107
if isinstance(col, str):
11071108
col_db_name = ColumnMeta.to_db_name(col)
11081109
col_type = self.signals_schema.get_column_type(col_db_name)
11091110
column = Column(col_db_name, python_to_sql(col_type))
1111+
if col not in keep_columns:
1112+
keep_columns.append(col)
11101113
elif isinstance(col, Function):
11111114
column = col.get_column(self.signals_schema)
11121115
col_db_name = column.name
11131116
col_type = column.type.python_type
1117+
schema_fields[col_db_name] = col_type
11141118
else:
11151119
raise DataChainColumnError(
11161120
col,
@@ -1120,7 +1124,6 @@ def group_by(
11201124
),
11211125
)
11221126
partition_by_columns.append(column)
1123-
schema_fields[col_db_name] = col_type
11241127

11251128
# validate signal columns and add them to the schema
11261129
if not kwargs:
@@ -1135,9 +1138,13 @@ def group_by(
11351138
signal_columns.append(column)
11361139
schema_fields[col_name] = func.get_result_type(self.signals_schema)
11371140

1141+
signal_schema = SignalSchema(schema_fields)
1142+
if keep_columns:
1143+
signal_schema |= self.signals_schema.to_partial(*keep_columns)
1144+
11381145
return self._evolve(
11391146
query=self._query.group_by(signal_columns, partition_by_columns),
1140-
signal_schema=SignalSchema(schema_fields),
1147+
signal_schema=signal_schema,
11411148
)
11421149

11431150
def mutate(self, **kwargs) -> "Self":

src/datachain/lib/signal_schema.py

Lines changed: 128 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
181181
)
182182
return SignalSchema(signals)
183183

184+
@staticmethod
185+
def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
186+
bases: list[tuple[str, str, Optional[str]]] = []
187+
for base in fr.__mro__:
188+
model_store_name = (
189+
ModelStore.get_name(base) if issubclass(base, DataModel) else None
190+
)
191+
bases.append((base.__name__, base.__module__, model_store_name))
192+
return bases
193+
184194
@staticmethod
185195
def _serialize_custom_model(
186196
version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
@@ -198,12 +208,7 @@ def _serialize_custom_model(
198208
assert field_type
199209
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
200210

201-
bases: list[tuple[str, str, Optional[str]]] = []
202-
for type_ in fr.__mro__:
203-
model_store_name = (
204-
ModelStore.get_name(type_) if issubclass(type_, DataModel) else None
205-
)
206-
bases.append((type_.__name__, type_.__module__, model_store_name))
211+
bases = SignalSchema._get_bases(fr)
207212

208213
ct = CustomType(
209214
schema_version=2,
@@ -806,3 +811,120 @@ def _build_tree_for_model(
806811
res[name] = (anno, subtree) # type: ignore[assignment]
807812

808813
return res
814+
815+
def to_partial(self, *columns: str) -> "SignalSchema":
816+
"""
817+
Convert the schema to a partial schema with only the specified columns.
818+
819+
E.g. if original schema is:
820+
821+
```
822+
signal: Foo@v1
823+
name: str
824+
value: float
825+
count: int
826+
```
827+
828+
Then `to_partial("signal.name", "count")` will return a partial schema:
829+
830+
```
831+
signal: FooPartial@v1
832+
name: str
833+
count: int
834+
```
835+
836+
Note that partial schema will have a different name for the custom types
837+
(e.g. `FooPartial@v1` instead of `Foo@v1`) to avoid conflicts
838+
with the original schema.
839+
840+
Args:
841+
*columns (str): The columns to include in the partial schema.
842+
843+
Returns:
844+
SignalSchema: The new partial schema.
845+
"""
846+
serialized = self.serialize()
847+
custom_types = serialized.get("_custom_types", {})
848+
849+
schema: dict[str, Any] = {}
850+
schema_custom_types: dict[str, CustomType] = {}
851+
852+
data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None
853+
854+
signal_partials: dict[str, str] = {}
855+
partial_versions: dict[str, int] = {}
856+
857+
def _type_name_to_partial(signal_name: str, type_name: str) -> str:
858+
if "@" not in type_name:
859+
return type_name
860+
model_name, _ = ModelStore.parse_name_version(type_name)
861+
862+
if signal_name not in signal_partials:
863+
partial_versions.setdefault(model_name, 0)
864+
partial_versions[model_name] += 1
865+
version = partial_versions[model_name]
866+
signal_partials[signal_name] = f"{model_name}Partial{version}"
867+
868+
return signal_partials[signal_name]
869+
870+
for column in columns:
871+
parent_type, parent_type_partial = "", ""
872+
column_parts = column.split(".")
873+
for i, signal in enumerate(column_parts):
874+
if i == 0:
875+
if signal not in serialized:
876+
raise SignalSchemaError(
877+
f"Column {column} not found in the schema"
878+
)
879+
880+
parent_type = serialized[signal]
881+
parent_type_partial = _type_name_to_partial(signal, parent_type)
882+
883+
schema[signal] = parent_type_partial
884+
continue
885+
886+
if parent_type not in custom_types:
887+
raise SignalSchemaError(
888+
f"Custom type {parent_type} not found in the schema"
889+
)
890+
891+
custom_type = custom_types[parent_type]
892+
signal_type = custom_type["fields"].get(signal)
893+
if not signal_type:
894+
raise SignalSchemaError(
895+
f"Field {signal} not found in custom type {parent_type}"
896+
)
897+
898+
partial_type = _type_name_to_partial(
899+
".".join(column_parts[: i + 1]),
900+
signal_type,
901+
)
902+
903+
if parent_type_partial in schema_custom_types:
904+
schema_custom_types[parent_type_partial].fields[signal] = (
905+
partial_type
906+
)
907+
else:
908+
if data_model_bases is None:
909+
data_model_bases = SignalSchema._get_bases(DataModel)
910+
911+
partial_type_name, _ = ModelStore.parse_name_version(partial_type)
912+
schema_custom_types[parent_type_partial] = CustomType(
913+
schema_version=2,
914+
name=partial_type_name,
915+
fields={signal: partial_type},
916+
bases=[
917+
(partial_type_name, "__main__", partial_type),
918+
*data_model_bases,
919+
],
920+
)
921+
922+
parent_type, parent_type_partial = signal_type, partial_type
923+
924+
if schema_custom_types:
925+
schema["_custom_types"] = {
926+
type_name: ct.model_dump()
927+
for type_name, ct in schema_custom_types.items()
928+
}
929+
930+
return SignalSchema.deserialize(schema)

0 commit comments

Comments
 (0)