@@ -181,6 +181,16 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
181
181
)
182
182
return SignalSchema (signals )
183
183
184
@staticmethod
def _get_bases(fr: type) -> list[tuple[str, str, Optional[str]]]:
    """
    Describe every class found in the MRO of `fr`.

    Args:
        fr (type): The class whose `__mro__` is walked.

    Returns:
        list[tuple[str, str, Optional[str]]]: One
            `(class name, module name, model store name)` tuple per MRO
            entry, in MRO order. The third element is the result of
            `ModelStore.get_name` for subclasses of `DataModel` and `None`
            for every other class.
    """
    return [
        (
            klass.__name__,
            klass.__module__,
            ModelStore.get_name(klass) if issubclass(klass, DataModel) else None,
        )
        for klass in fr.__mro__
    ]
193
+
184
194
@staticmethod
185
195
def _serialize_custom_model (
186
196
version_name : str , fr : type [BaseModel ], custom_types : dict [str , Any ]
@@ -198,12 +208,7 @@ def _serialize_custom_model(
198
208
assert field_type
199
209
fields [field_name ] = SignalSchema ._serialize_type (field_type , custom_types )
200
210
201
- bases : list [tuple [str , str , Optional [str ]]] = []
202
- for type_ in fr .__mro__ :
203
- model_store_name = (
204
- ModelStore .get_name (type_ ) if issubclass (type_ , DataModel ) else None
205
- )
206
- bases .append ((type_ .__name__ , type_ .__module__ , model_store_name ))
211
+ bases = SignalSchema ._get_bases (fr )
207
212
208
213
ct = CustomType (
209
214
schema_version = 2 ,
@@ -806,3 +811,120 @@ def _build_tree_for_model(
806
811
res [name ] = (anno , subtree ) # type: ignore[assignment]
807
812
808
813
return res
814
+
815
def to_partial(self, *columns: str) -> "SignalSchema":
    """
    Convert the schema to a partial schema with only the specified columns.

    E.g. if original schema is:

    ```
    signal: Foo@v1
        name: str
        value: float
    count: int
    ```

    Then `to_partial("signal.name", "count")` will build a partial schema
    shaped like:

    ```
    signal: FooPartial1
        name: str
    count: int
    ```

    Custom types in the partial schema get a new name of the form
    `<ModelName>Partial<N>` (where `N` is a per-model counter, one per
    distinct dotted signal path) to avoid conflicts with the original
    schema. The partial schema is assembled in serialized form and then
    passed through `SignalSchema.deserialize`.

    Args:
        *columns (str): The columns to include in the partial schema.
            Nested fields are addressed with dots, e.g. `"signal.name"`.

    Returns:
        SignalSchema: The new partial schema.

    Raises:
        SignalSchemaError: If the first component of a column is not in
            the serialized schema, if a parent type is not among the
            serialized custom types, or if a field is missing (or has an
            empty type) in its parent custom type.
    """
    serialized = self.serialize()
    custom_types = serialized.get("_custom_types", {})

    schema: dict[str, Any] = {}
    schema_custom_types: dict[str, CustomType] = {}

    # Computed lazily, the first time a partial custom type is created.
    data_model_bases: Optional[list[tuple[str, str, Optional[str]]]] = None

    # Dotted signal path -> partial type name already assigned to it.
    signal_partials: dict[str, str] = {}
    # Model name -> how many partial names were generated for it so far.
    partial_versions: dict[str, int] = {}

    def _type_name_to_partial(signal_name: str, type_name: str) -> str:
        """Return the partial type name for `signal_name`, creating it once."""
        # Type names without "@" are passed through unchanged.
        if "@" not in type_name:
            return type_name
        model_name, _ = ModelStore.parse_name_version(type_name)

        if signal_name not in signal_partials:
            version = partial_versions.get(model_name, 0) + 1
            partial_versions[model_name] = version
            signal_partials[signal_name] = f"{model_name}Partial{version}"

        return signal_partials[signal_name]

    for column in columns:
        parent_type, parent_type_partial = "", ""
        column_parts = column.split(".")
        for i, signal in enumerate(column_parts):
            if i == 0:
                # Top-level component: must be a key of the serialized schema.
                if signal not in serialized:
                    raise SignalSchemaError(
                        f"Column {column} not found in the schema"
                    )

                parent_type = serialized[signal]
                parent_type_partial = _type_name_to_partial(signal, parent_type)

                # NOTE(review): when the column has no nested part and its
                # type contains "@", no entry is added to
                # `schema_custom_types` for the partial name stored here —
                # confirm `deserialize` copes with that.
                schema[signal] = parent_type_partial
                continue

            if parent_type not in custom_types:
                raise SignalSchemaError(
                    f"Custom type {parent_type} not found in the schema"
                )

            custom_type = custom_types[parent_type]
            signal_type = custom_type["fields"].get(signal)
            if not signal_type:
                raise SignalSchemaError(
                    f"Field {signal} not found in custom type {parent_type}"
                )

            # The partial name is keyed by the dotted path up to this field.
            partial_type = _type_name_to_partial(
                ".".join(column_parts[: i + 1]),
                signal_type,
            )

            if parent_type_partial in schema_custom_types:
                # Parent partial already exists: just register one more field.
                schema_custom_types[parent_type_partial].fields[signal] = (
                    partial_type
                )
            else:
                if data_model_bases is None:
                    data_model_bases = SignalSchema._get_bases(DataModel)

                # NOTE(review): `name` and the first base are derived from
                # the field's type (`partial_type`), not from
                # `parent_type_partial` under which the entry is stored —
                # confirm this is intended.
                partial_type_name, _ = ModelStore.parse_name_version(partial_type)
                schema_custom_types[parent_type_partial] = CustomType(
                    schema_version=2,
                    name=partial_type_name,
                    fields={signal: partial_type},
                    bases=[
                        (partial_type_name, "__main__", partial_type),
                        *data_model_bases,
                    ],
                )

            # Descend: this field becomes the parent of the next component.
            parent_type, parent_type_partial = signal_type, partial_type

    if schema_custom_types:
        schema["_custom_types"] = {
            type_name: ct.model_dump()
            for type_name, ct in schema_custom_types.items()
        }

    return SignalSchema.deserialize(schema)
0 commit comments