Skip to content

Commit dd6c0d8

Browse files
authored
Move back annotation conversion from mldesigner (#26843)
Signed-off-by: Brynn Yin <[email protected]> Signed-off-by: Brynn Yin <[email protected]>
1 parent 0e104a6 commit dd6c0d8

File tree

1 file changed

+42
-0
lines changed
  • sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs

1 file changed

+42
-0
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def _split(_fields):
210210
inherited_fields = _get_inherited_fields()
211211
# From annotations get field with type
212212
annotations = getattr(cls_or_func, "__annotations__", {})
213+
annotations = _update_io_from_mldesigner(annotations)
213214
annotation_fields = _get_fields(annotations)
214215
# Update fields use class field with defaults from class dict or signature(func).paramters
215216
if not is_func:
@@ -223,6 +224,47 @@ def _split(_fields):
223224
return all_fields
224225

225226

227+
def _update_io_from_mldesigner(annotations: dict) -> dict:
228+
"""This function will translate IOBase from mldesigner package to azure.ml.entities.Input/Output.
229+
230+
This function depend on `mldesigner._input_output._IOBase._to_io_entity_args_dict` to translate Input/Output
231+
instance annotations to IO entities.
232+
This function depend on class names of `mldesigner._input_output` to translate Input/Output class annotations
233+
to IO entities.
234+
"""
235+
from azure.ai.ml import Input, Output
236+
mldesigner_pkg = "mldesigner"
237+
238+
def _is_input_or_output_type(io: type, type_str: str):
239+
"""Return true if type name contains type_str"""
240+
if isinstance(io, type) and io.__module__.startswith(mldesigner_pkg):
241+
if type_str in io.__name__:
242+
return True
243+
return False
244+
245+
result = {}
246+
for key, io in annotations.items():
247+
if isinstance(io, type):
248+
if _is_input_or_output_type(io, "Input"):
249+
# mldesigner.Input -> entities.Input
250+
io = Input
251+
elif _is_input_or_output_type(io, "Output"):
252+
# mldesigner.Output -> entities.Output
253+
io = Output
254+
elif hasattr(io, "_to_io_entity_args_dict"):
255+
try:
256+
if _is_input_or_output_type(type(io), "Input"):
257+
# mldesigner.Input() -> entities.Input()
258+
io = Input(**io._to_io_entity_args_dict())
259+
elif _is_input_or_output_type(type(io), "Output"):
260+
# mldesigner.Output() -> entities.Output()
261+
io = Output(**io._to_io_entity_args_dict())
262+
except BaseException as e:
263+
raise UserErrorException(f"Failed to parse {io} to azure-ai-ml Input/Output: {str(e)}") from e
264+
result[key] = io
265+
return result
266+
267+
226268
def _remove_empty_values(data, ignore_keys=None):
227269
if not isinstance(data, dict):
228270
return data

0 commit comments

Comments
 (0)