Skip to content

feat(appsync): add Router to allow large resolver composition #776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from abc import ABC
from typing import Any, Callable, Optional, Type, TypeVar

from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
Expand All @@ -9,7 +10,33 @@
AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)


class AppSyncResolver:
class BaseRouter(ABC):
current_event: AppSyncResolverEventT # type: ignore[valid-type]
lambda_context: LambdaContext

def __init__(self):
self._resolvers: dict = {}

def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name

Parameters
----------
type_name : str
Type name
field_name : str
Field name
"""

def register_resolver(func):
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

return register_resolver


class AppSyncResolver(BaseRouter):
"""
AppSync resolver decorator

Expand Down Expand Up @@ -40,29 +67,8 @@ def common_field() -> str:
return str(uuid.uuid4())
"""

current_event: AppSyncResolverEventT # type: ignore[valid-type]
lambda_context: LambdaContext

def __init__(self):
self._resolvers: dict = {}

def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name

Parameters
----------
type_name : str
Type name
field_name : str
Field name
"""

def register_resolver(func):
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

return register_resolver
super().__init__()

def resolve(
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
Expand Down Expand Up @@ -136,10 +142,10 @@ def lambda_handler(event, context):
ValueError
If we could not find a field resolver
"""
self.current_event = data_model(event)
self.lambda_context = context
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
return resolver(**self.current_event.arguments)
BaseRouter.current_event = data_model(event)
BaseRouter.lambda_context = context
resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name)
return resolver(**BaseRouter.current_event.arguments)

def _get_resolver(self, type_name: str, field_name: str) -> Callable:
"""Get resolver for field_name
Expand Down Expand Up @@ -167,3 +173,18 @@ def __call__(
) -> Any:
"""Implicit lambda handler which internally calls `resolve`"""
return self.resolve(event, context, data_model)

def include_resolver(self, resolver: "Resolver") -> None:
"""Adds all resolvers defined in a resolver

Parameters
----------
resolver : Resolver
A resolver containing a dict of field resolvers
"""
self._resolvers.update(resolver._resolvers)


class Resolver(BaseRouter):
def __init__(self):
super().__init__()
27 changes: 27 additions & 0 deletions tests/functional/event_handler/test_appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from aws_lambda_powertools.event_handler import AppSyncResolver
from aws_lambda_powertools.event_handler.appsync import Resolver
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext
from tests.functional.utils import load_event
Expand Down Expand Up @@ -161,3 +162,29 @@ def create_something(id: str): # noqa AA03 VNE003
assert result == "my identifier"

assert app.current_event.country_viewer == "US"


def test_resolver_include_resolver():
# GIVEN
app = AppSyncResolver()
resolver = Resolver()

@resolver.resolver(type_name="Query", field_name="listLocations")
def get_locations(name: str):
return "get_locations#" + name

@app.resolver(field_name="listLocations2")
def get_locations2(name: str):
return "get_locations2#" + name

app.include_resolver(resolver)

# WHEN
mock_event1 = {"typeName": "Query", "fieldName": "listLocations", "arguments": {"name": "value"}}
mock_event2 = {"typeName": "Query", "fieldName": "listLocations2", "arguments": {"name": "value"}}
result1 = app.resolve(mock_event1, LambdaContext())
result2 = app.resolve(mock_event2, LambdaContext())

# THEN
assert result1 == "get_locations#value"
assert result2 == "get_locations2#value"