@@ -210,6 +210,7 @@ def _split(_fields):
210
210
inherited_fields = _get_inherited_fields ()
211
211
# From annotations get field with type
212
212
annotations = getattr (cls_or_func , "__annotations__" , {})
213
+ annotations = _update_io_from_mldesigner (annotations )
213
214
annotation_fields = _get_fields (annotations )
214
215
# Update fields use class field with defaults from class dict or signature(func).paramters
215
216
if not is_func :
@@ -223,6 +224,47 @@ def _split(_fields):
223
224
return all_fields
224
225
225
226
227
+ def _update_io_from_mldesigner (annotations : dict ) -> dict :
228
+ """This function will translate IOBase from mldesigner package to azure.ml.entities.Input/Output.
229
+
230
+ This function depend on `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output
231
+ instance annotations to IO entities.
232
+ This function depend on class names of `mldesigner._input_output` to translate Input/Output class annotations
233
+ to IO entities.
234
+ """
235
+ from azure .ai .ml import Input , Output
236
+ mldesigner_pkg = "mldesigner"
237
+
238
+ def _is_input_or_output_type (io : type , type_str : str ):
239
+ """Return true if type name contains type_str"""
240
+ if isinstance (io , type ) and io .__module__ .startswith (mldesigner_pkg ):
241
+ if type_str in io .__name__ :
242
+ return True
243
+ return False
244
+
245
+ result = {}
246
+ for key , io in annotations .items ():
247
+ if isinstance (io , type ):
248
+ if _is_input_or_output_type (io , "Input" ):
249
+ # mldesigner.Input -> entities.Input
250
+ io = Input
251
+ elif _is_input_or_output_type (io , "Output" ):
252
+ # mldesigner.Output -> entities.Output
253
+ io = Output
254
+ elif hasattr (io , "_to_io_entity_args_dict" ):
255
+ try :
256
+ if _is_input_or_output_type (type (io ), "Input" ):
257
+ # mldesigner.Input() -> entities.Input()
258
+ io = Input (** io ._to_io_entity_args_dict ())
259
+ elif _is_input_or_output_type (type (io ), "Output" ):
260
+ # mldesigner.Output() -> entities.Output()
261
+ io = Output (** io ._to_io_entity_args_dict ())
262
+ except BaseException as e :
263
+ raise UserErrorException (f"Failed to parse { io } to azure-ai-ml Input/Output: { str (e )} " ) from e
264
+ result [key ] = io
265
+ return result
266
+
267
+
226
268
def _remove_empty_values (data , ignore_keys = None ):
227
269
if not isinstance (data , dict ):
228
270
return data
0 commit comments