
Commit edfe8fd

DFP-3901: Added FrameworkExceptionHandlerTransformer
1 parent a1257fd commit edfe8fd

File tree

6 files changed: +204 -0 lines changed

spark_pipeline_framework/transformers/framework_exception_handler/__init__.py

Whitespace-only changes.

spark_pipeline_framework/transformers/framework_exception_handler/v1/__init__.py

Whitespace-only changes.
spark_pipeline_framework/transformers/framework_exception_handler/v1/framework_exception_handler.py

@@ -0,0 +1,147 @@

from typing import Dict, Any, Optional, Type, Union, List, Callable

from spark_pipeline_framework.utilities.capture_parameters import capture_parameters
from pyspark.ml import Transformer
from pyspark.sql.dataframe import DataFrame
from spark_pipeline_framework.logger.yarn_logger import get_logger
from spark_pipeline_framework.progress_logger.progress_logger import ProgressLogger
from spark_pipeline_framework.transformers.framework_transformer.v1.framework_transformer import (
    FrameworkTransformer,
)


class FrameworkExceptionHandlerTransformer(FrameworkTransformer):
    # noinspection PyUnusedLocal
    @capture_parameters
    def __init__(
        self,
        *,
        raise_on_exception: Optional[Union[bool, Callable[[DataFrame], bool]]] = True,
        error_exception: Type[BaseException] = BaseException,
        stages: Union[List[Transformer], Callable[[], List[Transformer]]],
        exception_stages: Optional[
            Union[List[Transformer], Callable[[], List[Transformer]]]
        ] = None,
        name: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        progress_logger: Optional[ProgressLogger] = None,
    ):
        """
        Executes a sequence of stages (transformers) and, in case of an exception, executes a separate
        sequence of exception-handling stages.

        :param raise_on_exception: Whether to re-raise exceptions after the handler runs; may be a
            bool or a callable that receives the incoming DataFrame and returns a bool.
        :param error_exception: The exception type to catch.
        :param stages: The primary sequence of transformers to execute.
        :param exception_stages: Stages to execute if an error occurs.
        :param name: Name of the transformer.
        :param parameters: Additional parameters.
        :param progress_logger: Logger instance for tracking execution.
        """
        super().__init__(
            name=name, parameters=parameters, progress_logger=progress_logger
        )

        self.logger = get_logger(__name__)

        self.raise_on_exception: Optional[Union[bool, Callable[[DataFrame], bool]]] = (
            raise_on_exception
        )

        self.error_exception: Type[BaseException] = error_exception
        self.stages: Union[List[Transformer], Callable[[], List[Transformer]]] = stages
        self.exception_stages: Union[
            List[Transformer], Callable[[], List[Transformer]]
        ] = (exception_stages or [])

        self.loop_id: Optional[str] = None

        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    async def _transform_async(self, df: DataFrame) -> DataFrame:
        """
        Executes the transformation pipeline asynchronously.

        - Runs `stages` normally.
        - If an exception occurs, logs the error and executes `exception_stages` if provided.
        - Optionally re-raises the exception based on `raise_on_exception`.
        """
        progress_logger: Optional[ProgressLogger] = self.getProgressLogger()
        stage_name = ""
        raise_on_exception = (
            self.raise_on_exception
            if not callable(self.raise_on_exception)
            else self.raise_on_exception(df)
        )

        async def run_pipeline(
            df: DataFrame,
            stages: Union[List[Transformer], Callable[[], List[Transformer]]],
            progress_logger: Optional[ProgressLogger],
        ) -> None:
            stages = stages if not callable(stages) else stages()
            nonlocal stage_name

            for stage in stages:
                stage_name = (
                    stage.getName()
                    if hasattr(stage, "getName")
                    else stage.__class__.__name__
                )
                if progress_logger:
                    progress_logger.start_mlflow_run(
                        run_name=stage_name, is_nested=True
                    )
                if hasattr(stage, "set_loop_id"):
                    stage.set_loop_id(self.loop_id)
                df = (
                    await stage.transform_async(df)
                    if hasattr(stage, "transform_async")
                    else stage.transform(df)
                )
                if progress_logger:
                    progress_logger.end_mlflow_run()

        try:
            await run_pipeline(df, self.stages, progress_logger)
        except Exception as e:
            if progress_logger:
                progress_logger.write_to_log(
                    self.getName() or "FrameworkExceptionHandlerTransformer",
                    f"Failed while running steps in stage: {stage_name}. "
                    f"Running exception stages: {isinstance(e, self.error_exception)}",
                )
            # Assigning to a new variable since stage_name will be updated while running exception stages
            failed_stage_name = stage_name

            try:
                if isinstance(e, self.error_exception):
                    await run_pipeline(df, self.exception_stages, progress_logger)
            except Exception as err:
                err.args = (f"In Exception Stage ({stage_name})", *err.args)
                raise err

            # Raise the error if `raise_on_exception` is True or if an exception other than
            # `self.error_exception` was thrown.
            if raise_on_exception or not isinstance(e, self.error_exception):
                e.args = (f"In Stage ({failed_stage_name})", *e.args)
                raise e

        return df

    def as_dict(self) -> Dict[str, Any]:
        return {
            **(super().as_dict()),
            "raise_on_exception": self.raise_on_exception,
            "stages": (
                [s.as_dict() for s in self.stages]  # type: ignore
                if not callable(self.stages)
                else str(self.stages)
            ),
            "exception_stages": (
                [s.as_dict() for s in self.exception_stages]  # type: ignore
                if self.exception_stages and not callable(self.exception_stages)
                else str(self.exception_stages)
            ),
        }
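
Usage note (a sketch, not part of this commit): the transformer is configured entirely through its constructor. In the hypothetical wiring below, the view names, file paths, and `input_df` are placeholders; it shows the callable form of `raise_on_exception`, which receives the incoming DataFrame and returns a bool.

from spark_pipeline_framework.transformers.framework_csv_loader.v1.framework_csv_loader import (
    FrameworkCsvLoader,
)
from spark_pipeline_framework.transformers.framework_exception_handler.v1.framework_exception_handler import (
    FrameworkExceptionHandlerTransformer,
)

# Hypothetical wiring: try the primary CSV; on failure, load a backup copy instead.
handler = FrameworkExceptionHandlerTransformer(
    name="Load members with fallback",
    stages=[
        FrameworkCsvLoader(view="members", file_path="s3://bucket/members.csv")
    ],
    exception_stages=[
        FrameworkCsvLoader(view="members", file_path="s3://bucket/members_backup.csv")
    ],
    # Callable form: only re-raise the failure when the incoming DataFrame is non-empty.
    raise_on_exception=lambda df: df.count() > 0,
)
result_df = handler.transform(input_df)  # input_df: whatever DataFrame the pipeline passes in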

spark_pipeline_framework/transformers/framework_exception_handler/v1/test/__init__.py

Whitespace-only changes.
spark_pipeline_framework/transformers/framework_exception_handler/v1/test/primary_care_protocol.csv

@@ -0,0 +1,4 @@

NPI,Practice ID - Scheduling Dept,Last Name,First Name,Primary Speciality,Primary Specialty Group,Protocol,Minimum Age,Maximum Age,Workers Compensation,Auto Accident,New Patients,Video Visits
1111111111,RBB-MRGR,BOB,JAMES,Internal Medicine,Internal Medicine,Primary Care,18,No Limit,Y,Y,Y,Y
1111111111,RBB-CAG,BILL,JAMES,Internal Medicine,Internal Medicine,Primary Care,18,No Limit,Y,Y,Y,Y
1111111112,RBB-CAG,BILL,JAMES,Internal Medicine,Internal Medicine,Primary Care,18,No Limit,Y,Y,Y,Y
@@ -0,0 +1,53 @@

from pathlib import Path

from pyspark.sql import SparkSession, DataFrame

from spark_pipeline_framework.progress_logger.progress_logger import ProgressLogger
from spark_pipeline_framework.transformers.framework_csv_loader.v1.framework_csv_loader import (
    FrameworkCsvLoader,
)
from spark_pipeline_framework.transformers.framework_exception_handler.v1.framework_exception_handler import (
    FrameworkExceptionHandlerTransformer,
)

from spark_pipeline_framework.utilities.spark_data_frame_helpers import (
    create_empty_dataframe,
)


def test_framework_exception_handle(spark_session: SparkSession) -> None:
    # create a dataframe with the test data
    data_dir: Path = Path(__file__).parent.joinpath("./")
    df: DataFrame = create_empty_dataframe(spark_session=spark_session)
    invalid_view: str = "invalid_view"
    valid_view = "view"

    with ProgressLogger() as progress_logger:
        FrameworkExceptionHandlerTransformer(
            name="Exception Handler Test",
            stages=[
                # A step that tries to load a non-existent CSV file (should fail)
                FrameworkCsvLoader(
                    view=invalid_view,
                    file_path=data_dir.joinpath("invalid_location.csv"),
                    clean_column_names=False,
                )
            ],
            exception_stages=[
                FrameworkCsvLoader(
                    view=valid_view,
                    file_path=data_dir.joinpath("primary_care_protocol.csv"),
                    clean_column_names=False,
                )
            ],
            raise_on_exception=False,
            progress_logger=progress_logger,
        ).transform(df)
        result_df: DataFrame = spark_session.table(valid_view)

        # Assert that the exception-handling stage has successfully run
        assert result_df.count() == 3

        # Verify that the invalid view was NOT created, confirming that the original stage failed
        assert not spark_session.catalog.tableExists(invalid_view)
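
A complementary check, sketched here and not included in the commit (it reuses the imports from the test above): with the default `raise_on_exception=True` and no `exception_stages`, the same failing loader should propagate its exception out of `transform`, with the failing stage name prepended to the exception's args.

import pytest


def test_framework_exception_handler_raises(spark_session: SparkSession) -> None:
    df: DataFrame = create_empty_dataframe(spark_session=spark_session)
    with pytest.raises(Exception):
        FrameworkExceptionHandlerTransformer(
            name="Exception Handler Raise Test",
            stages=[
                # Same hypothetical bad path as above; with no exception_stages given,
                # the handler list is empty and the original error is re-raised.
                FrameworkCsvLoader(
                    view="invalid_view",
                    file_path=Path(__file__).parent.joinpath("invalid_location.csv"),
                    clean_column_names=False,
                )
            ],
        ).transform(df)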
