diff --git a/aws_lambda_powertools/metrics/metric.py b/aws_lambda_powertools/metrics/metric.py index 76ff4339dea..94b427738a1 100644 --- a/aws_lambda_powertools/metrics/metric.py +++ b/aws_lambda_powertools/metrics/metric.py @@ -1,7 +1,7 @@ import json import logging from contextlib import contextmanager -from typing import Dict, Optional, Union, Generator +from typing import Dict, Generator, Optional, Union from .base import MetricManager, MetricUnit @@ -61,7 +61,9 @@ def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> N @contextmanager -def single_metric(name: str, unit: MetricUnit, value: float, namespace: Optional[str] = None) -> Generator[SingleMetric, None, None]: +def single_metric( + name: str, unit: MetricUnit, value: float, namespace: Optional[str] = None +) -> Generator[SingleMetric, None, None]: """Context manager to simplify creation of a single metric Example diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index b3b907bc18b..b059a3b2483 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -93,6 +93,8 @@ def get( raise GetParameterError(str(exc)) if transform is not None: + if isinstance(value, bytes): + value = value.decode("utf-8") value = transform_value(value, transform) self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) @@ -100,7 +102,7 @@ def get( return value @abstractmethod - def _get(self, name: str, **sdk_options) -> str: + def _get(self, name: str, **sdk_options) -> Union[str, bytes]: """ Retrieve parameter value from the underlying parameter store """ diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 79b8bfb2fd0..47fc5a0e982 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1503,7 +1503,7 @@ def test_appconf_provider_get_configuration_no_transform(mock_name, config): stubber.activate() try: - value = provider.get(mock_name) + value: str = provider.get(mock_name) str_value = value.decode("utf-8") assert str_value == json.dumps(mock_body_json) stubber.assert_no_pending_responses() @@ -1516,11 +1516,12 @@ def test_appconf_get_app_config_no_transform(monkeypatch, mock_name): Test get_app_config() """ mock_body_json = {"myenvvar1": "Black Panther", "myenvvar2": 3} + mock_body_bytes = str.encode(json.dumps(mock_body_json)) class TestProvider(BaseProvider): - def _get(self, name: str, **kwargs) -> str: + def _get(self, name: str, **kwargs) -> bytes: assert name == mock_name - return json.dumps(mock_body_json).encode("utf-8") + return mock_body_bytes def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: raise NotImplementedError() @@ -1532,6 +1533,30 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: value = parameters.get_app_config(mock_name, environment=environment, application=application) str_value = value.decode("utf-8") assert str_value == json.dumps(mock_body_json) + assert value == mock_body_bytes + + +def test_appconf_get_app_config_transform_json(monkeypatch, mock_name): + """ + Test get_app_config() + """ + mock_body_json = {"myenvvar1": "Black Panther", "myenvvar2": 3} + mock_body_bytes = str.encode(json.dumps(mock_body_json)) + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_body_bytes + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "appconfig", TestProvider()) + + environment = "dev" + application = "myapp" + value = parameters.get_app_config(mock_name, environment=environment, application=application, transform="json") + assert value == mock_body_json def test_appconf_get_app_config_new(monkeypatch, mock_name, mock_value):