Skip to content

Commit bfcac95

Browse files
authored
[ML][Pipelines] avoid changing base_path in validation (#26851)
* feat: add base path to path field validation error message
* refactor: avoid creating temp folder when not necessary
* feat: skip path validation if with additional includes
* fix: pylint
* fix: fix bandit
* fix: resolve internal environment based on local code
1 parent 5d05ccb commit bfcac95

File tree

14 files changed

+222
-83
lines changed

14 files changed

+222
-83
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/_additional_includes.py

+34-15
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,38 @@
1212
from azure.ai.ml.entities._validation import MutableValidationResult, _ValidationResultBuilder
1313

1414
ADDITIONAL_INCLUDES_SUFFIX = "additional_includes"
15+
PLACEHOLDER_FILE_NAME = "_placeholder_spec.yaml"
1516

1617

1718
class _AdditionalIncludes:
1819
def __init__(self, code_path: Union[None, str], yaml_path: str):
19-
self.__yaml_path = Path(yaml_path)
20+
self.__yaml_path = yaml_path
2021
self.__code_path = code_path
2122

2223
self._tmp_code_path = None
23-
self._includes = None
24-
if self._additional_includes_file_path.is_file():
24+
self.__includes = None
25+
26+
@property
27+
def _includes(self):
28+
if not self._additional_includes_file_path.is_file():
29+
return []
30+
if self.__includes is None:
2531
with open(self._additional_includes_file_path, "r") as f:
2632
lines = f.readlines()
27-
self._includes = [line.strip() for line in lines if len(line.strip()) > 0]
33+
self.__includes = [line.strip() for line in lines if len(line.strip()) > 0]
34+
return self.__includes
35+
36+
@property
37+
def with_includes(self):
38+
return len(self._includes) != 0
2839

2940
@property
3041
def _yaml_path(self) -> Path:
31-
return self.__yaml_path
42+
if self.__yaml_path is None:
43+
# if yaml path is not specified, fall back to a placeholder file name
44+
# under the current working directory; the file itself is not created
45+
return Path.cwd() / PLACEHOLDER_FILE_NAME
46+
return Path(self.__yaml_path)
3247

3348
@property
3449
def _code_path(self) -> Path:
@@ -56,25 +71,28 @@ def _copy(src: Path, dst: Path) -> None:
5671
shutil.copytree(src, dst)
5772

5873
def _validate(self) -> MutableValidationResult:
59-
# pylint: disable=too-many-return-statements
60-
if self._includes is None:
61-
return _ValidationResultBuilder.success()
74+
validation_result = _ValidationResultBuilder.success()
75+
if not self.with_includes:
76+
return validation_result
6277
for additional_include in self._includes:
6378
include_path = self._additional_includes_file_path.parent / additional_include
6479
# if additional include contains unsupported characters, resolve will fail and raise OSError
6580
try:
6681
src_path = include_path.resolve()
6782
except OSError:
6883
error_msg = f"Failed to resolve additional include {additional_include} for {self._yaml_name}."
69-
return _ValidationResultBuilder.from_single_message(error_msg)
84+
validation_result.append_error(message=error_msg)
85+
continue
7086

7187
if not src_path.exists():
7288
error_msg = f"Unable to find additional include {additional_include} for {self._yaml_name}."
73-
return _ValidationResultBuilder.from_single_message(error_msg)
89+
validation_result.append_error(message=error_msg)
90+
continue
7491

7592
if len(src_path.parents) == 0:
7693
error_msg = f"Root directory is not supported for additional includes for {self._yaml_name}."
77-
return _ValidationResultBuilder.from_single_message(error_msg)
94+
validation_result.append_error(message=error_msg)
95+
continue
7896

7997
dst_path = Path(self._code_path) / src_path.name
8098
if dst_path.is_symlink():
@@ -84,11 +102,12 @@ def _validate(self) -> MutableValidationResult:
84102
f"A symbolic link already exists for additional include {additional_include} "
85103
f"for {self._yaml_name}."
86104
)
87-
return _ValidationResultBuilder.from_single_message(error_msg)
105+
validation_result.append_error(message=error_msg)
106+
continue
88107
elif dst_path.exists():
89108
error_msg = f"A file already exists for additional include {additional_include} for {self._yaml_name}."
90-
return _ValidationResultBuilder.from_single_message(error_msg)
91-
return _ValidationResultBuilder.success()
109+
validation_result.append_error(message=error_msg)
110+
return validation_result
92111

93112
def resolve(self) -> None:
94113
"""Resolve code and potential additional includes.
@@ -97,7 +116,7 @@ def resolve(self) -> None:
97116
original real code path; otherwise, create a tmp folder and copy
98117
all files under real code path and additional includes to it.
99118
"""
100-
if self._includes is None:
119+
if not self.with_includes:
101120
return
102121
tmp_folder_path = Path(tempfile.mkdtemp())
103122
# code can be either a file or a folder; as additional includes exist, we need to copy it to a temporary folder

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py

+23-30
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# disable redefined-builtin to use id/type as argument name
66
from contextlib import contextmanager
77
from typing import Dict, Union
8-
import os
98

109
from marshmallow import INCLUDE, Schema
1110

@@ -177,30 +176,24 @@ def _additional_includes(self):
177176
def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
178177
return InternalBaseComponentSchema(context=context)
179178

180-
def _validate(self, raise_error=False) -> MutableValidationResult:
181-
if self._additional_includes is not None and self._additional_includes._validate().passed:
182-
# update source path in case dependency file is in additional_includes
183-
with self._resolve_local_code() as tmp_base_path:
184-
origin_base_path, origin_source_path = self._base_path, self._source_path
185-
186-
try:
187-
self._base_path, self._source_path = \
188-
tmp_base_path, tmp_base_path / os.path.basename(self._source_path)
189-
return super()._validate(raise_error=raise_error)
190-
finally:
191-
self._base_path, self._source_path = origin_base_path, origin_source_path
192-
193-
return super()._validate(raise_error=raise_error)
194-
195179
def _customized_validate(self) -> MutableValidationResult:
196180
validation_result = super(InternalComponent, self)._customized_validate()
181+
if self._additional_includes.with_includes:
182+
validation_result.merge_with(self._additional_includes._validate())
183+
# resolving additional includes & update self._base_path can be dangerous,
184+
# so we just skip path validation if additional_includes is used
185+
# note that there will still be runtime error in submission or execution
186+
skip_path_validation = True
187+
else:
188+
skip_path_validation = False
197189
if isinstance(self.environment, InternalEnvironment):
198190
validation_result.merge_with(
199-
self.environment._validate(self._source_path),
191+
self.environment._validate(
192+
self._base_path,
193+
skip_path_validation=skip_path_validation
194+
),
200195
field_name="environment",
201196
)
202-
if self._additional_includes is not None:
203-
validation_result.merge_with(self._additional_includes._validate())
204197
return validation_result
205198

206199
@classmethod
@@ -215,8 +208,6 @@ def _load_from_rest(cls, obj: ComponentVersionData) -> "InternalComponent":
215208
)
216209

217210
def _to_rest_object(self) -> ComponentVersionData:
218-
if isinstance(self.environment, InternalEnvironment):
219-
self.environment.resolve(self._source_path)
220211
component = convert_ordered_dict_to_dict(self._to_dict())
221212

222213
properties = ComponentVersionDetails(
@@ -232,15 +223,17 @@ def _to_rest_object(self) -> ComponentVersionData:
232223

233224
@contextmanager
234225
def _resolve_local_code(self):
235-
# if `self._source_path` is None, component is not loaded from local yaml and
236-
# no need to resolve
237-
if self._source_path is None:
238-
yield self.code
239-
else:
240-
self._additional_includes.resolve()
241-
# use absolute path in case temp folder & work dir are in different drive
242-
yield self._additional_includes.code.absolute()
243-
self._additional_includes.cleanup()
226+
self._additional_includes.resolve()
227+
228+
# file dependency in code will be read during internal environment resolution
229+
# for example, docker file of the environment may be in additional includes
230+
# and it will be read and then inserted into the environment object during resolution
231+
# so we need to resolve environment based on the temporary code path
232+
if isinstance(self.environment, InternalEnvironment):
233+
self.environment.resolve(self._additional_includes.code)
234+
# use absolute path in case temp folder & work dir are in different drive
235+
yield self._additional_includes.code.absolute()
236+
self._additional_includes.cleanup()
244237

245238
def __call__(self, *args, **kwargs) -> InternalBaseNode: # pylint: disable=useless-super-delegation
246239
return super(InternalComponent, self).__call__(*args, **kwargs)

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/environment.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# ---------------------------------------------------------
44

55
from pathlib import Path
6-
from typing import Dict
6+
from typing import Dict, Union
77

88
from azure.ai.ml._utils.utils import load_yaml
99
from azure.ai.ml.constants._common import FILE_PREFIX
@@ -40,7 +40,7 @@ def __init__(
4040
def _parse_file_path(value: str) -> str:
4141
return value[len(FILE_PREFIX) :] if value.startswith(FILE_PREFIX) else value
4242

43-
def _validate_conda_section(self, source_path: str) -> MutableValidationResult:
43+
def _validate_conda_section(self, base_path: str, skip_path_validation: bool) -> MutableValidationResult:
4444
validation_result = _ValidationResultBuilder.success()
4545
if not self.conda:
4646
return validation_result
@@ -53,56 +53,56 @@ def _validate_conda_section(self, source_path: str) -> MutableValidationResult:
5353
)
5454
if self.conda.get(self.CONDA_DEPENDENCIES_FILE):
5555
conda_dependencies_file = self.conda[self.CONDA_DEPENDENCIES_FILE]
56-
if not (Path(source_path).parent / conda_dependencies_file).is_file():
56+
if not skip_path_validation and not (Path(base_path) / conda_dependencies_file).is_file():
5757
validation_result.append_error(
5858
yaml_path=f"conda.{self.CONDA_DEPENDENCIES_FILE}",
5959
message=f"Cannot find conda dependencies file: {conda_dependencies_file!r}",
6060
)
6161
if self.conda.get(self.PIP_REQUIREMENTS_FILE):
6262
pip_requirements_file = self.conda[self.PIP_REQUIREMENTS_FILE]
63-
if not (Path(source_path).parent / pip_requirements_file).is_file():
63+
if not skip_path_validation and not (Path(base_path) / pip_requirements_file).is_file():
6464
validation_result.append_error(
6565
yaml_path=f"conda.{self.PIP_REQUIREMENTS_FILE}",
6666
message=f"Cannot find pip requirements file: {pip_requirements_file!r}",
6767
)
6868
return validation_result
6969

70-
def _validate_docker_section(self, source_path: str) -> MutableValidationResult:
70+
def _validate_docker_section(self, base_path: str, skip_path_validation: bool) -> MutableValidationResult:
7171
validation_result = _ValidationResultBuilder.success()
7272
if not self.docker:
7373
return validation_result
7474
if not self.docker.get(self.BUILD) or not self.docker[self.BUILD].get(self.DOCKERFILE):
7575
return validation_result
7676
dockerfile_file = self.docker[self.BUILD][self.DOCKERFILE]
7777
dockerfile_file = self._parse_file_path(dockerfile_file)
78-
if not (Path(source_path).parent / dockerfile_file).is_file():
78+
if not skip_path_validation and not (Path(base_path) / dockerfile_file).is_file():
7979
validation_result.append_error(
8080
yaml_path=f"docker.{self.BUILD}.{self.DOCKERFILE}",
8181
message=f"Dockerfile not exists: {dockerfile_file}",
8282
)
8383
return validation_result
8484

85-
def _validate(self, source_path: str) -> MutableValidationResult:
85+
def _validate(self, base_path: str, skip_path_validation: bool = False) -> MutableValidationResult:
8686
validation_result = _ValidationResultBuilder.success()
8787
if self.os is not None and self.os not in {"Linux", "Windows"}:
8888
validation_result.append_error(
8989
yaml_path="os",
9090
message=f"Only support 'Linux' and 'Windows', but got {self.os!r}",
9191
)
92-
validation_result.merge_with(self._validate_conda_section(source_path))
93-
validation_result.merge_with(self._validate_docker_section(source_path))
92+
validation_result.merge_with(self._validate_conda_section(base_path, skip_path_validation))
93+
validation_result.merge_with(self._validate_docker_section(base_path, skip_path_validation))
9494
return validation_result
9595

96-
def _resolve_conda_section(self, source_path: str) -> None:
96+
def _resolve_conda_section(self, base_path: Union[Path, str]) -> None:
9797
if not self.conda:
9898
return
9999
if self.conda.get(self.CONDA_DEPENDENCIES_FILE):
100100
conda_dependencies_file = self.conda.pop(self.CONDA_DEPENDENCIES_FILE)
101-
self.conda[self.CONDA_DEPENDENCIES] = load_yaml(Path(source_path).parent / conda_dependencies_file)
101+
self.conda[self.CONDA_DEPENDENCIES] = load_yaml(Path(base_path) / conda_dependencies_file)
102102
return
103103
if self.conda.get(self.PIP_REQUIREMENTS_FILE):
104104
pip_requirements_file = self.conda.pop(self.PIP_REQUIREMENTS_FILE)
105-
with open(Path(source_path).parent / pip_requirements_file) as f:
105+
with open(Path(base_path) / pip_requirements_file) as f:
106106
pip_requirements = f.read().splitlines()
107107
self.conda = {
108108
self.CONDA_DEPENDENCIES: {
@@ -117,7 +117,7 @@ def _resolve_conda_section(self, source_path: str) -> None:
117117
}
118118
return
119119

120-
def _resolve_docker_section(self, source_path: str) -> None:
120+
def _resolve_docker_section(self, base_path: Union[Path, str]) -> None:
121121
if not self.docker:
122122
return
123123
if not self.docker.get(self.BUILD) or not self.docker[self.BUILD].get(self.DOCKERFILE):
@@ -126,10 +126,10 @@ def _resolve_docker_section(self, source_path: str) -> None:
126126
if not dockerfile_file.startswith(FILE_PREFIX):
127127
return
128128
dockerfile_file = self._parse_file_path(dockerfile_file)
129-
with open(Path(source_path).parent / dockerfile_file, "r") as f:
129+
with open(Path(base_path) / dockerfile_file, "r") as f:
130130
self.docker[self.BUILD][self.DOCKERFILE] = f.read()
131131
return
132132

133-
def resolve(self, source_path: str) -> None:
134-
self._resolve_conda_section(source_path)
135-
self._resolve_docker_section(source_path)
133+
def resolve(self, base_path: Union[Path, str]) -> None:
134+
self._resolve_conda_section(base_path)
135+
self._resolve_docker_section(base_path)

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,26 @@ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
112112
return super(LocalPathField, self)._serialize(value, attr, obj, **kwargs)
113113

114114
def _validate(self, value):
115+
base_path_err_msg = ""
115116
try:
116117
path = Path(value)
117118
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
118119
if not path.is_absolute():
119120
path = base_path / path
120121
path.resolve()
122+
base_path_err_msg = f" Resolved absolute path: {path.absolute()}"
121123
if (self._allow_dir and path.is_dir()) or (self._allow_file and path.is_file()):
122124
return super(LocalPathField, self)._validate(value)
123125
except OSError:
124126
pass
125127
if self._allow_dir and self._allow_file:
126-
raise ValidationError(f"{value} is not a valid path")
127-
if self._allow_dir:
128-
raise ValidationError(f"{value} is not a valid directory")
129-
raise ValidationError(f"{value} is not a valid file")
128+
allow_type = "directory or file"
129+
elif self._allow_dir:
130+
allow_type = "directory"
131+
else:
132+
allow_type = "file"
133+
raise ValidationError(f"Value {value!r} passed is not a valid "
134+
f"{allow_type} path.{base_path_err_msg}")
130135

131136

132137
class SerializeValidatedUrl(fields.Url):

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/component.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -449,20 +449,23 @@ def __call__(self, *args, **kwargs) -> [..., Union["Command", "Parallel"]]:
449449
@contextmanager
450450
def _resolve_local_code(self):
451451
"""Resolve working directory path for the component."""
452-
with tempfile.TemporaryDirectory() as tmp_dir:
453-
if hasattr(self, "code"):
454-
code = getattr(self, "code")
455-
# Hack: when code not specified, we generated a file which contains
456-
# COMPONENT_PLACEHOLDER as code
457-
# This hack was introduced because job does not allow running component without a
458-
# code, and we need to make sure when component updated some field(eg: description),
459-
# the code remains the same. Benefit of using a constant code for all components
460-
# without code is this will generate same code for anonymous components which
461-
# enables component reuse
462-
if code is None:
452+
if hasattr(self, "code"):
453+
code = getattr(self, "code")
454+
# Hack: when code not specified, we generated a file which contains
455+
# COMPONENT_PLACEHOLDER as code
456+
# This hack was introduced because job does not allow running component without a
457+
# code, and we need to make sure when component updated some field(eg: description),
458+
# the code remains the same. Benefit of using a constant code for all components
459+
# without code is this will generate same code for anonymous components which
460+
# enables component reuse
461+
if code is None:
462+
with tempfile.TemporaryDirectory() as tmp_dir:
463463
code = Path(tmp_dir) / COMPONENT_PLACEHOLDER
464464
with open(code, "w") as f:
465465
f.write(COMPONENT_CODE_PLACEHOLDER)
466-
yield code
466+
yield code
467467
else:
468+
yield code
469+
else:
470+
with tempfile.TemporaryDirectory() as tmp_dir:
468471
yield tmp_dir

sdk/ml/azure-ai-ml/tests/component/unittests/test_component_schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def test_serialize_deserialize_default_code(self, mock_machinelearning_client: M
194194
component_entity = load_component_entity_from_yaml(test_path, mock_machinelearning_client)
195195
# make sure default code has been generated with name and version as content
196196
assert component_entity.code
197-
assert COMPONENT_CODE_PLACEHOLDER == component_entity.code
197+
assert component_entity.code == COMPONENT_CODE_PLACEHOLDER
198198

199199
def test_serialize_deserialize_input_output_path(self, mock_machinelearning_client: MLClient):
200200
expected_value_dict = {

0 commit comments

Comments
 (0)