Skip to content

feat(bedrock_agent): add new Amazon Bedrock Agents Functions Resolver #6564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws_lambda_powertools/event_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver, BedrockResponse
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
from aws_lambda_powertools.event_handler.lambda_function_url import (
LambdaFunctionUrlResolver,
Expand All @@ -26,9 +27,11 @@
"ALBResolver",
"ApiGatewayResolver",
"BedrockAgentResolver",
"BedrockAgentFunctionResolver",
"CORSConfig",
"LambdaFunctionUrlResolver",
"Response",
"BedrockResponse",
"VPCLatticeResolver",
"VPCLatticeV2Resolver",
]
204 changes: 204 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable

from enum import Enum

from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent


class ResponseState(Enum):
FAILURE = "FAILURE"
REPROMPT = "REPROMPT"


class BedrockResponse:
"""Response class for Bedrock Agent Functions

Parameters
----------
body : Any, optional
Response body
session_attributes : dict[str, str] | None
Session attributes to include in the response
prompt_session_attributes : dict[str, str] | None
Prompt session attributes to include in the response
status_code : int
Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE)

Examples
--------
```python
@app.tool(description="Function that uses session attributes")
def test_function():
return BedrockResponse(
body="Hello",
session_attributes={"userId": "123"},
prompt_session_attributes={"lastAction": "login"}
)
```
"""

def __init__(
self,
body: Any = None,
session_attributes: dict[str, str] | None = None,
prompt_session_attributes: dict[str, str] | None = None,
knowledge_bases: list[dict[str, Any]] | None = None,
status_code: int = 200,
) -> None:
self.body = body
self.session_attributes = session_attributes
self.prompt_session_attributes = prompt_session_attributes
self.knowledge_bases = knowledge_bases
self.status_code = status_code


class BedrockFunctionsResponseBuilder:
"""
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
when using Bedrock Agent Functions.

Since the payload format is different from the standard API Gateway Proxy event,
we override the build method.
"""

def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None:
self.result = result
self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code

def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
if isinstance(self.result, BedrockResponse):
body = self.result.body
session_attributes = self.result.session_attributes
prompt_session_attributes = self.result.prompt_session_attributes
knowledge_bases = self.result.knowledge_bases
else:
body = self.result
session_attributes = None
prompt_session_attributes = None
knowledge_bases = None

response: dict[str, Any] = {
"messageVersion": "1.0",
"response": {
"actionGroup": event.action_group,
"function": event.function,
"functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}},
},
}

# Add responseState if it's an error
if self.status_code >= 400:
response["response"]["functionResponse"]["responseState"] = (
ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value
)

# Add session attributes if provided in response or maintain from input
response.update(
{
"sessionAttributes": session_attributes or event.session_attributes or {},
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {},
},
)

# Add knowledge bases configuration if provided
if knowledge_bases:
response["knowledgeBasesConfiguration"] = knowledge_bases

return response


class BedrockAgentFunctionResolver:
"""Bedrock Agent Function resolver that handles function definitions

Examples
--------
```python
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver

app = BedrockAgentFunctionResolver()

@app.tool(description="Gets the current UTC time")
def get_current_time():
from datetime import datetime
return datetime.utcnow().isoformat()

def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

def __init__(self) -> None:
self._tools: dict[str, dict[str, Any]] = {}
self.current_event: BedrockAgentFunctionEvent | None = None
self._response_builder_class = BedrockFunctionsResponseBuilder

def tool(
self,
description: str | None = None,
name: str | None = None,
) -> Callable:
"""Decorator to register a tool function

Parameters
----------
description : str | None
Description of what the tool does
name : str | None
Custom name for the tool. If not provided, uses the function name
"""

def decorator(func: Callable) -> Callable:
if not description:
raise ValueError("Tool description is required")

function_name = name or func.__name__
if function_name in self._tools:
raise ValueError(f"Tool '{function_name}' already registered")

self._tools[function_name] = {
"function": func,
"description": description,
}
return func

return decorator

def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]:
"""Resolves the function call from Bedrock Agent event"""
try:
self.current_event = BedrockAgentFunctionEvent(event)
return self._resolve()
except KeyError as e:
raise ValueError(f"Missing required field: {str(e)}")

Check warning on line 178 in aws_lambda_powertools/event_handler/bedrock_agent_function.py

View check run for this annotation

Codecov / codecov/patch

aws_lambda_powertools/event_handler/bedrock_agent_function.py#L178

Added line #L178 was not covered by tests

def _resolve(self) -> dict[str, Any]:
"""Internal resolution logic"""
if self.current_event is None:
raise ValueError("No event to process")

Check warning on line 183 in aws_lambda_powertools/event_handler/bedrock_agent_function.py

View check run for this annotation

Codecov / codecov/patch

aws_lambda_powertools/event_handler/bedrock_agent_function.py#L183

Added line #L183 was not covered by tests

function_name = self.current_event.function

if function_name not in self._tools:
return BedrockFunctionsResponseBuilder(
BedrockResponse(
body=f"Function not found: {function_name}",
status_code=400, # Using 400 to trigger REPROMPT
),
).build(self.current_event)

try:
result = self._tools[function_name]["function"]()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A critical part of this agent's workflow is passing parameters to functions. These parameters contain values ​​that will be useful for the function in question. In this discussion, we talk about this:

Support for injecting parameters into the function signature.

Can you pls add this support?

return BedrockFunctionsResponseBuilder(result).build(self.current_event)
except Exception as e:
return BedrockFunctionsResponseBuilder(
BedrockResponse(
body=f"Error: {str(e)}",
status_code=500, # Using 500 to trigger FAILURE
),
).build(self.current_event)
2 changes: 2 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .appsync_resolver_events_event import AppSyncResolverEventsEvent
from .aws_config_rule_event import AWSConfigRuleEvent
from .bedrock_agent_event import BedrockAgentEvent
from .bedrock_agent_function_event import BedrockAgentFunctionEvent
from .cloud_watch_alarm_event import (
CloudWatchAlarmConfiguration,
CloudWatchAlarmData,
Expand Down Expand Up @@ -59,6 +60,7 @@
"AppSyncResolverEventsEvent",
"ALBEvent",
"BedrockAgentEvent",
"BedrockAgentFunctionEvent",
"CloudWatchAlarmData",
"CloudWatchAlarmEvent",
"CloudWatchAlarmMetric",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

from typing import Any

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper


class BedrockAgentInfo(DictWrapper):
@property
def name(self) -> str:
return self["name"]

@property
def id(self) -> str: # noqa: A003
return self["id"]

@property
def alias(self) -> str:
return self["alias"]

@property
def version(self) -> str:
return self["version"]


class BedrockAgentFunctionParameter(DictWrapper):
@property
def name(self) -> str:
return self["name"]

@property
def type(self) -> str: # noqa: A003
return self["type"]

@property
def value(self) -> str:
return self["value"]


class BedrockAgentFunctionEvent(DictWrapper):
"""
Bedrock Agent Function input event

Documentation:
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
"""

@classmethod
def validate_required_fields(cls, data: dict[str, Any]) -> None:
required_fields = {
"messageVersion": str,
"agent": dict,
"inputText": str,
"sessionId": str,
"actionGroup": str,
"function": str,
}

for field, field_type in required_fields.items():
if field not in data:
raise ValueError(f"Missing required field: {field}")
if not isinstance(data[field], field_type):
raise TypeError(f"Field {field} must be of type {field_type}")

# Validate agent structure
required_agent_fields = {"name", "id", "alias", "version"}
if not all(field in data["agent"] for field in required_agent_fields):
raise ValueError("Agent object missing required fields")

def __init__(self, data: dict[str, Any]) -> None:
super().__init__(data)
self.validate_required_fields(data)

@property
def message_version(self) -> str:
return self["messageVersion"]

@property
def input_text(self) -> str:
return self["inputText"]

@property
def session_id(self) -> str:
return self["sessionId"]

@property
def action_group(self) -> str:
return self["actionGroup"]

@property
def function(self) -> str:
return self["function"]

@property
def parameters(self) -> list[BedrockAgentFunctionParameter]:
parameters = self.get("parameters") or []
return [BedrockAgentFunctionParameter(x) for x in parameters]

@property
def agent(self) -> BedrockAgentInfo:
return BedrockAgentInfo(self["agent"])

@property
def session_attributes(self) -> dict[str, str]:
return self.get("sessionAttributes", {}) or {}

@property
def prompt_session_attributes(self) -> dict[str, str]:
return self.get("promptSessionAttributes", {}) or {}
32 changes: 32 additions & 0 deletions tests/events/bedrockAgentFunctionEvent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"messageVersion": "1.0",
"agent": {
"alias": "PROD",
"name": "hr-assistant-function-def",
"version": "1",
"id": "1234abcd"
},
"sessionId": "123456789123458",
"sessionAttributes": {
"employeeId": "EMP123"
},
"promptSessionAttributes": {
"lastInteraction": "2024-02-01T15:30:00Z",
"requestType": "vacation"
},
"inputText": "I want to request vacation from March 15 to March 20",
"actionGroup": "VacationsActionGroup",
"function": "submitVacationRequest",
"parameters": [
{
"name": "startDate",
"type": "string",
"value": "2024-03-15"
},
{
"name": "endDate",
"type": "string",
"value": "2024-03-20"
}
]
}
Loading
Loading