Commit a353180

DFP-3901: Added FrameworkExceptionHandlerTransformer

1 parent a1257fd

File tree

6 files changed, +204 -0 lines changed

spark_pipeline_framework/transformers/framework_exception_handler/__init__.py

Whitespace-only changes.

spark_pipeline_framework/transformers/framework_exception_handler/v1/__init__.py

Whitespace-only changes.
spark_pipeline_framework/transformers/framework_exception_handler/v1/framework_exception_handler.py
@@ -0,0 +1,147 @@
from typing import Dict, Any, Optional, Union, List, Callable, Type

from spark_pipeline_framework.logger.log_level import LogLevel
from spark_pipeline_framework.utilities.capture_parameters import capture_parameters
from pyspark.ml import Transformer
from pyspark.sql.dataframe import DataFrame
from spark_pipeline_framework.logger.yarn_logger import get_logger
from spark_pipeline_framework.progress_logger.progress_logger import ProgressLogger
from spark_pipeline_framework.transformers.framework_transformer.v1.framework_transformer import (
    FrameworkTransformer,
)


class FrameworkExceptionHandlerTransformer(FrameworkTransformer):
    # noinspection PyUnusedLocal
    @capture_parameters
    def __init__(
        self,
        *,
        raise_on_exception: Optional[Union[bool, Callable[[DataFrame], bool]]] = True,
        error_exception: Type[Exception] = Exception,
        stages: Union[List[Transformer], Callable[[], List[Transformer]]],
        exception_stages: Optional[
            Union[List[Transformer], Callable[[], List[Transformer]]]
        ] = None,
        name: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        progress_logger: Optional[ProgressLogger] = None,
    ):
        """
        Executes a sequence of stages (transformers) and, in case of an exception, executes a separate
        sequence of exception-handling stages.

        :param raise_on_exception: Whether to re-raise exceptions after handling. May be a bool or a
            callable that receives the incoming DataFrame and returns a bool.
        :param error_exception: The exception type to catch; other exception types are always re-raised.
        :param stages: The primary sequence of transformers to execute.
        :param exception_stages: Stages to execute if an error occurs.
        :param name: Name of the transformer.
        :param parameters: Additional parameters.
        :param progress_logger: Logger instance for tracking execution.
        """
        super().__init__(
            name=name, parameters=parameters, progress_logger=progress_logger
        )

        self.logger = get_logger(__name__)

        self.raise_on_exception: Optional[Union[bool, Callable[[DataFrame], bool]]] = (
            raise_on_exception
        )

        self.error_exception: Type[Exception] = error_exception
        self.stages: Union[List[Transformer], Callable[[], List[Transformer]]] = stages
        self.exception_stages: Union[
            List[Transformer], Callable[[], List[Transformer]]
        ] = (exception_stages or [])

        self.loop_id: Optional[str] = None

        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    async def _transform_async(self, df: DataFrame) -> DataFrame:
        """
        Executes the transformation pipeline asynchronously.

        - Runs `stages` normally.
        - If an exception occurs, logs the error and executes `exception_stages` if provided.
        - Optionally re-raises the exception based on `raise_on_exception`.
        """
        progress_logger: Optional[ProgressLogger] = self.getProgressLogger()
        stage_name = ""
        raise_on_exception = (
            self.raise_on_exception
            if not callable(self.raise_on_exception)
            else self.raise_on_exception(df)
        )

        async def run_pipeline(
            df: DataFrame,
            stages: Union[List[Transformer], Callable[[], List[Transformer]]],
            progress_logger: Optional[ProgressLogger],
        ) -> None:
            # resolve the stages lazily if a callable was provided
            stage_list: List[Transformer] = stages if not callable(stages) else stages()
            nonlocal stage_name

            for stage in stage_list:
                stage_name = (
                    stage.getName()
                    if hasattr(stage, "getName")
                    else stage.__class__.__name__
                )
                if progress_logger:
                    progress_logger.start_mlflow_run(
                        run_name=stage_name, is_nested=True
                    )
                if hasattr(stage, "set_loop_id"):
                    stage.set_loop_id(self.loop_id)
                df = (
                    await stage.transform_async(df)
                    if hasattr(stage, "transform_async")
                    else stage.transform(df)
                )
                if progress_logger:
                    progress_logger.end_mlflow_run()

        try:
            await run_pipeline(df, self.stages, progress_logger)
        except Exception as e:
            if progress_logger:
                progress_logger.write_to_log(
                    self.getName() or "FrameworkExceptionHandlerTransformer",
                    f"Failed while running steps with error: {e}. "
                    f"Running exception stages: {isinstance(e, self.error_exception)}",
                    log_level=LogLevel.INFO,
                )

            try:
                if isinstance(e, self.error_exception):
                    await run_pipeline(df, self.exception_stages, progress_logger)
            except Exception as err:
                err.args = (f"In Exception Stage ({stage_name})", *err.args)
                raise err

            # Re-raise if `raise_on_exception` is True or if an exception other than `self.error_exception` was thrown.
            if raise_on_exception or not isinstance(e, self.error_exception):
                e.args = (f"In Stage ({stage_name})", *e.args)
                raise e

        return df

    def as_dict(self) -> Dict[str, Any]:
        return {
            **(super().as_dict()),
            "raise_on_exception": self.raise_on_exception,
            "stages": (
                [s.as_dict() for s in self.stages]  # type: ignore
                if not callable(self.stages)
                else str(self.stages)
            ),
            "exception_stages": (
                [s.as_dict() for s in self.exception_stages]  # type: ignore
                if self.exception_stages and not callable(self.exception_stages)
                else str(self.exception_stages)
            ),
        }
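For orientation, a minimal usage sketch of the new transformer (not part of the commit): the primary `stages` run first; if one raises, `exception_stages` run as a fallback, and with `raise_on_exception=False` the caught error is swallowed once the fallback succeeds. `FrameworkCsvLoader` is the same loader used in the test below; the view name and file paths here are hypothetical.

from spark_pipeline_framework.transformers.framework_csv_loader.v1.framework_csv_loader import (
    FrameworkCsvLoader,
)

# Hypothetical fallback pipeline: try the primary extract, fall back to a backup copy.
load_with_fallback = FrameworkExceptionHandlerTransformer(
    name="Load people with fallback",
    stages=[FrameworkCsvLoader(view="people", file_path="/data/people.csv")],
    exception_stages=[
        FrameworkCsvLoader(view="people", file_path="/backup/people.csv")
    ],
    raise_on_exception=False,  # swallow the original error after the fallback runs
)
df = load_with_fallback.transform(df)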

spark_pipeline_framework/transformers/framework_exception_handler/v1/test/__init__.py

Whitespace-only changes.
spark_pipeline_framework/transformers/framework_exception_handler/v1/test/primary_care_protocol.csv
@@ -0,0 +1,4 @@
NPI,Practice ID - Scheduling Dept,Last Name,First Name,Primary Speciality,Primary Specialty Group,Protocol,Minimum Age,Maximum Age,Workers Compensation,Auto Accident,New Patients,Video Visits
1111111111,RBB-MRGR,BOB,JAMES,Internal Medicine,Internal Medicine,Primary Care,18,No Limit,Y,Y,Y,Y
1111111111,RBB-CAG,BILL,JAMES,Internal Medicine,Internal Medicine,Primary Care,18,No Limit,Y,Y,Y,Y
1111111112,RBB-CAG,BILL,JAMES,Internal Medicine,Internal Medicine,Primary Care,18,No Limit,Y,Y,Y,Y
spark_pipeline_framework/transformers/framework_exception_handler/v1/test/test_framework_exception_handler.py
@@ -0,0 +1,53 @@
from pathlib import Path

from pyspark.sql import SparkSession, DataFrame

from spark_pipeline_framework.progress_logger.progress_logger import ProgressLogger
from spark_pipeline_framework.transformers.framework_csv_loader.v1.framework_csv_loader import (
    FrameworkCsvLoader,
)
from spark_pipeline_framework.transformers.framework_exception_handler.v1.framework_exception_handler import (
    FrameworkExceptionHandlerTransformer,
)

from spark_pipeline_framework.utilities.spark_data_frame_helpers import (
    create_empty_dataframe,
)


def test_framework_exception_handle(spark_session: SparkSession) -> None:
    # create an empty dataframe to run the pipeline against
    data_dir: Path = Path(__file__).parent.joinpath("./")
    df: DataFrame = create_empty_dataframe(spark_session=spark_session)
    invalid_view: str = "invalid_view"
    valid_view = "view"

    with ProgressLogger() as progress_logger:
        FrameworkExceptionHandlerTransformer(
            name="Exception Handler Test",
            stages=[
                # A step that tries to load a non-existent CSV file (should fail)
                FrameworkCsvLoader(
                    view=invalid_view,
                    file_path=data_dir.joinpath("invalid_location.csv"),
                    clean_column_names=False,
                )
            ],
            exception_stages=[
                FrameworkCsvLoader(
                    view=valid_view,
                    file_path=data_dir.joinpath("primary_care_protocol.csv"),
                    clean_column_names=False,
                )
            ],
            raise_on_exception=False,
            progress_logger=progress_logger,
        ).transform(df)
        result_df: DataFrame = spark_session.table(valid_view)

        # Assert that the exception-handling stage has successfully run
        assert result_df.count() == 3

        # Verify that the invalid view was NOT created, confirming that the original stage failed
        assert not spark_session.catalog.tableExists(invalid_view)
