diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py index 244aca97c3cb..6e1def679c85 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py @@ -210,6 +210,7 @@ def _split(_fields): inherited_fields = _get_inherited_fields() # From annotations get field with type annotations = getattr(cls_or_func, "__annotations__", {}) + annotations = _update_io_from_mldesigner(annotations) annotation_fields = _get_fields(annotations) # Update fields use class field with defaults from class dict or signature(func).paramters if not is_func: @@ -223,6 +224,47 @@ def _split(_fields): return all_fields +def _update_io_from_mldesigner(annotations: dict) -> dict: + """This function will translate IOBase from mldesigner package to azure.ml.entities.Input/Output. + + This function depend on `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output + instance annotations to IO entities. + This function depend on class names of `mldesigner._input_output` to translate Input/Output class annotations + to IO entities. + """ + from azure.ai.ml import Input, Output + mldesigner_pkg = "mldesigner" + + def _is_input_or_output_type(io: type, type_str: str): + """Return true if type name contains type_str""" + if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg): + if type_str in io.__name__: + return True + return False + + result = {} + for key, io in annotations.items(): + if isinstance(io, type): + if _is_input_or_output_type(io, "Input"): + # mldesigner.Input -> entities.Input + io = Input + elif _is_input_or_output_type(io, "Output"): + # mldesigner.Output -> entities.Output + io = Output + elif hasattr(io, "_to_io_entity_args_dict"): + try: + if _is_input_or_output_type(type(io), "Input"): + # mldesigner.Input() -> entities.Input() + io = Input(**io._to_io_entity_args_dict()) + elif _is_input_or_output_type(type(io), "Output"): + # mldesigner.Output() -> entities.Output() + io = Output(**io._to_io_entity_args_dict()) + except BaseException as e: + raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e + result[key] = io + return result + + def _remove_empty_values(data, ignore_keys=None): if not isinstance(data, dict): return data