1
- from typing import Dict , Any , Optional , Union , List , Callable
1
+ from typing import Dict , Any , Optional , Type , Union , List , Callable
2
2
3
- from spark_pipeline_framework .logger .log_level import LogLevel
4
3
from spark_pipeline_framework .utilities .capture_parameters import capture_parameters
5
4
from pyspark .ml import Transformer
6
5
from pyspark .sql .dataframe import DataFrame
@@ -18,7 +17,7 @@ def __init__(
18
17
self ,
19
18
* ,
20
19
raise_on_exception : Optional [Union [bool , Callable [[DataFrame ], bool ]]] = True ,
21
- error_exception : Optional [ Exception ] = Exception ,
20
+ error_exception : Type [ BaseException ] = BaseException ,
22
21
stages : Union [List [Transformer ], Callable [[], List [Transformer ]]],
23
22
exception_stages : Optional [
24
23
Union [List [Transformer ], Callable [[], List [Transformer ]]]
@@ -50,18 +49,18 @@ def __init__(
50
49
raise_on_exception
51
50
)
52
51
53
- self .error_exception : Optional [ Exception ] = error_exception
52
+ self .error_exception : Type [ BaseException ] = error_exception
54
53
self .stages : Union [List [Transformer ], Callable [[], List [Transformer ]]] = stages
55
- self .exception_stages : Optional [
56
- Union [ List [Transformer ], Callable [[], List [Transformer ] ]]
54
+ self .exception_stages : Union [
55
+ List [Transformer ], Callable [[], List [Transformer ]]
57
56
] = (exception_stages or [])
58
57
59
58
self .loop_id : Optional [str ] = None
60
59
61
60
kwargs = self ._input_kwargs
62
61
self .setParams (** kwargs )
63
62
64
- async def _transform_async (self , df ) :
63
+ async def _transform_async (self , df : DataFrame ) -> DataFrame :
65
64
"""
66
65
Executes the transformation pipeline asynchronously.
67
66
@@ -81,8 +80,8 @@ async def run_pipeline(
81
80
df : DataFrame ,
82
81
stages : Union [List [Transformer ], Callable [[], List [Transformer ]]],
83
82
progress_logger : Optional [ProgressLogger ],
84
- ):
85
- stages : List [ Transformer ] = stages if not callable (stages ) else stages ()
83
+ ) -> None :
84
+ stages = stages if not callable (stages ) else stages ()
86
85
nonlocal stage_name
87
86
88
87
for stage in stages :
@@ -112,7 +111,6 @@ async def run_pipeline(
112
111
progress_logger .write_to_log (
113
112
self .getName () or "FrameworkExceptionHandlerTransformer" ,
114
113
f"Failed while running steps with error: { e } . Run execution steps: { isinstance (e , self .error_exception )} " ,
115
- log_level = LogLevel .INFO ,
116
114
)
117
115
118
116
try :
@@ -140,8 +138,8 @@ def as_dict(self) -> Dict[str, Any]:
140
138
else str (self .stages )
141
139
),
142
140
"exception_stages" : (
143
- [s .as_dict () for s in self .else_stages ] # type: ignore
144
- if self .else_stages and not callable (self .else_stages )
145
- else str (self .else_stages )
141
+ [s .as_dict () for s in self .exception_stages ] # type: ignore
142
+ if self .exception_stages and not callable (self .exception_stages )
143
+ else str (self .exception_stages )
146
144
),
147
145
}
0 commit comments