diff --git a/aws_lambda_powertools/utilities/batch/__init__.py b/aws_lambda_powertools/utilities/batch/__init__.py
index 584342e5fd0..463f6f7fbff 100644
--- a/aws_lambda_powertools/utilities/batch/__init__.py
+++ b/aws_lambda_powertools/utilities/batch/__init__.py
@@ -8,11 +8,11 @@
     BasePartialProcessor,
     BatchProcessor,
     EventType,
-    ExceptionInfo,
     FailureResponse,
     SuccessResponse,
     batch_processor,
 )
+from aws_lambda_powertools.utilities.batch.exceptions import ExceptionInfo
 from aws_lambda_powertools.utilities.batch.sqs import PartialSQSProcessor, sqs_batch_processor
 
 __all__ = (
diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py
index 02eb00ffaed..d8fdc2d85f2 100644
--- a/aws_lambda_powertools/utilities/batch/base.py
+++ b/aws_lambda_powertools/utilities/batch/base.py
@@ -8,11 +8,10 @@
 import sys
 from abc import ABC, abstractmethod
 from enum import Enum
-from types import TracebackType
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, overload
 
 from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
-from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError
+from aws_lambda_powertools.utilities.batch.exceptions import BatchProcessingError, ExceptionInfo
 from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord
 from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord
 from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
@@ -30,8 +29,6 @@ class EventType(Enum):
 # type specifics
 #
 has_pydantic = "pydantic" in sys.modules
-ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType]
-OptExcInfo = Union[ExceptionInfo, Tuple[None, None, None]]
 
 # For IntelliSense and Mypy to work, we need to account for possible SQS, Kinesis and DynamoDB subclasses
 # We need them as subclasses as we must access their message ID or sequence number metadata via dot notation
@@ -61,7 +58,7 @@ class BasePartialProcessor(ABC):
     def __init__(self):
         self.success_messages: List[BatchEventTypes] = []
         self.fail_messages: List[BatchEventTypes] = []
-        self.exceptions: List = []
+        self.exceptions: List[ExceptionInfo] = []
 
     @abstractmethod
     def _prepare(self):
@@ -132,7 +129,7 @@ def success_handler(self, record, result: Any) -> SuccessResponse:
         self.success_messages.append(record)
         return entry
 
-    def failure_handler(self, record, exception: OptExcInfo) -> FailureResponse:
+    def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
         """
         Keeps track of batch records that failed processing
 
@@ -140,7 +137,7 @@ def failure_handler(self, record, exception: OptExcInfo) -> FailureResponse:
         ----------
         record: Any
             record that failed processing
-        exception: OptExcInfo
+        exception: ExceptionInfo
             Exception information containing type, value, and traceback (sys.exc_info())
 
         Returns
@@ -411,32 +408,28 @@ def _get_messages_to_report(self) -> Dict[str, str]:
     def _collect_sqs_failures(self):
         if self.model:
             return {"itemIdentifier": msg.messageId for msg in self.fail_messages}
-        else:
-            return {"itemIdentifier": msg.message_id for msg in self.fail_messages}
+        return {"itemIdentifier": msg.message_id for msg in self.fail_messages}
 
     def _collect_kinesis_failures(self):
         if self.model:
             # Pydantic model uses int but Lambda poller expects str
             return {"itemIdentifier": msg.kinesis.sequenceNumber for msg in self.fail_messages}
-        else:
-            return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
+        return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
 
     def _collect_dynamodb_failures(self):
         if self.model:
             return {"itemIdentifier": msg.dynamodb.SequenceNumber for msg in self.fail_messages}
-        else:
-            return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
+        return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
 
     @overload
     def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels":
-        ...
+        ...  # pragma: no cover
 
     @overload
     def _to_batch_type(self, record: dict, event_type: EventType) -> EventSourceDataClassTypes:
-        ...
+        ...  # pragma: no cover
 
     def _to_batch_type(self, record: dict, event_type: EventType, model: Optional["BatchTypeModels"] = None):
         if model is not None:
             return model.parse_obj(record)
-        else:
-            return self._DATA_CLASS_MAPPING[event_type](record)
+        return self._DATA_CLASS_MAPPING[event_type](record)
diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py
index fe51433a5d6..dc4ca300c7c 100644
--- a/aws_lambda_powertools/utilities/batch/exceptions.py
+++ b/aws_lambda_powertools/utilities/batch/exceptions.py
@@ -2,11 +2,14 @@
 Batch processing exceptions
 """
 import traceback
-from typing import Optional, Tuple
+from types import TracebackType
+from typing import List, Optional, Tuple, Type
+
+ExceptionInfo = Tuple[Type[BaseException], BaseException, TracebackType]
 
 
 class BaseBatchProcessingError(Exception):
-    def __init__(self, msg="", child_exceptions=()):
+    def __init__(self, msg="", child_exceptions: Optional[List[ExceptionInfo]] = None):
         super().__init__(msg)
         self.msg = msg
         self.child_exceptions = child_exceptions
@@ -24,7 +27,7 @@ def format_exceptions(self, parent_exception_str):
 class SQSBatchProcessingError(BaseBatchProcessingError):
     """When at least one message within a batch could not be processed"""
 
-    def __init__(self, msg="", child_exceptions: Optional[Tuple[Exception]] = None):
+    def __init__(self, msg="", child_exceptions: Optional[List[ExceptionInfo]] = None):
         super().__init__(msg, child_exceptions)
 
     # Overriding this method so we can output all child exception tracebacks when we raise this exception to prevent
@@ -37,7 +40,7 @@ def __str__(self):
 class BatchProcessingError(BaseBatchProcessingError):
     """When all batch records failed to be processed"""
 
-    def __init__(self, msg="", child_exceptions: Optional[Tuple[Exception]] = None):
+    def __init__(self, msg="", child_exceptions: Optional[List[ExceptionInfo]] = None):
         super().__init__(msg, child_exceptions)
 
     def __str__(self):
diff --git a/aws_lambda_powertools/utilities/parser/parser.py b/aws_lambda_powertools/utilities/parser/parser.py
index a12f163f8a6..ef939cd11f7 100644
--- a/aws_lambda_powertools/utilities/parser/parser.py
+++ b/aws_lambda_powertools/utilities/parser/parser.py
@@ -86,12 +86,12 @@ def handler(event: Order, context: LambdaContext):
 
 @overload
 def parse(event: Dict[str, Any], model: Type[Model]) -> Model:
-    ...
+    ...  # pragma: no cover
 
 
 @overload
 def parse(event: Dict[str, Any], model: Type[Model], envelope: Type[Envelope]) -> EnvelopeModel:
-    ...
+    ...  # pragma: no cover
 
 
 def parse(event: Dict[str, Any], model: Type[Model], envelope: Optional[Type[Envelope]] = None):
diff --git a/pyproject.toml b/pyproject.toml
index 532d78d0051..c3b87d4f093 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -85,6 +85,9 @@ exclude_lines = [
     # Don't complain if non-runnable code isn't run:
     "if 0:",
     "if __name__ == .__main__.:",
+
+    # Ignore type function overload
+    "@overload",
 ]
 
 [tool.isort]
diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py
index a8cf652d8a0..51e142bfa55 100644
--- a/tests/functional/idempotency/test_idempotency.py
+++ b/tests/functional/idempotency/test_idempotency.py
@@ -1057,3 +1057,15 @@ def two(data):
     assert one(data=mock_event) == "one"
     assert two(data=mock_event) == "two"
     assert len(persistence_store.table.method_calls) == 4
+
+
+def test_invalid_dynamodb_persistence_layer():
+    # Scenario constructing a DynamoDBPersistenceLayer with a key_attr matching sort_key_attr should fail
+    with pytest.raises(ValueError) as ve:
+        DynamoDBPersistenceLayer(
+            table_name="Foo",
+            key_attr="id",
+            sort_key_attr="id",
+        )
+    # and raise a ValueError
+    assert str(ve.value) == "key_attr [id] and sort_key_attr [id] cannot be the same!"
diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py
index cd6fc67ea15..3728af3111d 100644
--- a/tests/functional/test_utilities_batch.py
+++ b/tests/functional/test_utilities_batch.py
@@ -832,5 +832,7 @@ def lambda_handler(event, context):
         return processor.response()
 
     # WHEN/THEN
-    with pytest.raises(BatchProcessingError):
+    with pytest.raises(BatchProcessingError) as e:
         lambda_handler(event, {})
+    ret = str(e)
+    assert ret is not None