-
Notifications
You must be signed in to change notification settings - Fork 429
feat(bedrock_agent): add new Amazon Bedrock Agents Functions Resolver #6564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 6 commits
41bc401
bed8f3f
a3765f0
44d80f8
abbc100
e42ceff
86c7ab7
34948d7
24978cb
45f85f6
b420a90
84bb6b0
20bbe9f
d463304
b4ab6b9
39e0d36
54a7edf
fdde207
c8b1b2f
db7d6b9
266ebcb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable | ||
|
||
from enum import Enum | ||
|
||
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent | ||
|
||
|
||
class ResponseState(Enum): | ||
FAILURE = "FAILURE" | ||
REPROMPT = "REPROMPT" | ||
|
||
|
||
class BedrockResponse: | ||
"""Response class for Bedrock Agent Functions | ||
|
||
Parameters | ||
---------- | ||
body : Any, optional | ||
Response body | ||
session_attributes : dict[str, str] | None | ||
Session attributes to include in the response | ||
prompt_session_attributes : dict[str, str] | None | ||
Prompt session attributes to include in the response | ||
status_code : int | ||
Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE) | ||
|
||
Examples | ||
-------- | ||
```python | ||
@app.tool(description="Function that uses session attributes") | ||
def test_function(): | ||
return BedrockResponse( | ||
body="Hello", | ||
session_attributes={"userId": "123"}, | ||
prompt_session_attributes={"lastAction": "login"} | ||
) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
body: Any = None, | ||
session_attributes: dict[str, str] | None = None, | ||
prompt_session_attributes: dict[str, str] | None = None, | ||
knowledge_bases: list[dict[str, Any]] | None = None, | ||
status_code: int = 200, | ||
) -> None: | ||
self.body = body | ||
self.session_attributes = session_attributes | ||
self.prompt_session_attributes = prompt_session_attributes | ||
self.knowledge_bases = knowledge_bases | ||
self.status_code = status_code | ||
|
||
|
||
class BedrockFunctionsResponseBuilder: | ||
""" | ||
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda | ||
when using Bedrock Agent Functions. | ||
|
||
Since the payload format is different from the standard API Gateway Proxy event, | ||
we override the build method. | ||
""" | ||
|
||
def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None: | ||
self.result = result | ||
self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code | ||
|
||
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: | ||
"""Build the full response dict to be returned by the lambda""" | ||
if isinstance(self.result, BedrockResponse): | ||
body = self.result.body | ||
session_attributes = self.result.session_attributes | ||
prompt_session_attributes = self.result.prompt_session_attributes | ||
knowledge_bases = self.result.knowledge_bases | ||
else: | ||
body = self.result | ||
session_attributes = None | ||
prompt_session_attributes = None | ||
knowledge_bases = None | ||
|
||
response: dict[str, Any] = { | ||
"messageVersion": "1.0", | ||
"response": { | ||
"actionGroup": event.action_group, | ||
"function": event.function, | ||
"functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}}, | ||
}, | ||
} | ||
|
||
# Add responseState if it's an error | ||
if self.status_code >= 400: | ||
response["response"]["functionResponse"]["responseState"] = ( | ||
ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value | ||
) | ||
anafalcao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Add session attributes if provided in response or maintain from input | ||
response.update( | ||
{ | ||
"sessionAttributes": session_attributes or event.session_attributes or {}, | ||
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {}, | ||
}, | ||
) | ||
|
||
# Add knowledge bases configuration if provided | ||
if knowledge_bases: | ||
response["knowledgeBasesConfiguration"] = knowledge_bases | ||
|
||
return response | ||
|
||
|
||
class BedrockAgentFunctionResolver: | ||
"""Bedrock Agent Function resolver that handles function definitions | ||
|
||
Examples | ||
-------- | ||
```python | ||
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver | ||
|
||
app = BedrockAgentFunctionResolver() | ||
|
||
@app.tool(description="Gets the current UTC time") | ||
def get_current_time(): | ||
from datetime import datetime | ||
return datetime.utcnow().isoformat() | ||
|
||
def lambda_handler(event, context): | ||
return app.resolve(event, context) | ||
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._tools: dict[str, dict[str, Any]] = {} | ||
self.current_event: BedrockAgentFunctionEvent | None = None | ||
self._response_builder_class = BedrockFunctionsResponseBuilder | ||
|
||
def tool( | ||
self, | ||
description: str | None = None, | ||
name: str | None = None, | ||
) -> Callable: | ||
"""Decorator to register a tool function | ||
|
||
Parameters | ||
---------- | ||
description : str | None | ||
Description of what the tool does | ||
name : str | None | ||
Custom name for the tool. If not provided, uses the function name | ||
""" | ||
|
||
def decorator(func: Callable) -> Callable: | ||
if not description: | ||
raise ValueError("Tool description is required") | ||
anafalcao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function_name = name or func.__name__ | ||
if function_name in self._tools: | ||
raise ValueError(f"Tool '{function_name}' already registered") | ||
anafalcao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self._tools[function_name] = { | ||
"function": func, | ||
"description": description, | ||
} | ||
return func | ||
|
||
return decorator | ||
|
||
def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]: | ||
"""Resolves the function call from Bedrock Agent event""" | ||
try: | ||
self.current_event = BedrockAgentFunctionEvent(event) | ||
return self._resolve() | ||
except KeyError as e: | ||
raise ValueError(f"Missing required field: {str(e)}") | ||
|
||
def _resolve(self) -> dict[str, Any]: | ||
"""Internal resolution logic""" | ||
if self.current_event is None: | ||
raise ValueError("No event to process") | ||
|
||
function_name = self.current_event.function | ||
|
||
if function_name not in self._tools: | ||
return BedrockFunctionsResponseBuilder( | ||
BedrockResponse( | ||
body=f"Function not found: {function_name}", | ||
status_code=400, # Using 400 to trigger REPROMPT | ||
), | ||
).build(self.current_event) | ||
|
||
try: | ||
result = self._tools[function_name]["function"]() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A critical part of this agent's workflow is passing parameters to functions. These parameters contain values that will be useful for the function in question. In this discussion, we talk about this: Support for injecting parameters into the function signature. Can you pls add this support? |
||
return BedrockFunctionsResponseBuilder(result).build(self.current_event) | ||
except Exception as e: | ||
return BedrockFunctionsResponseBuilder( | ||
BedrockResponse( | ||
body=f"Error: {str(e)}", | ||
status_code=500, # Using 500 to trigger FAILURE | ||
), | ||
).build(self.current_event) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper | ||
|
||
|
||
class BedrockAgentInfo(DictWrapper): | ||
@property | ||
def name(self) -> str: | ||
return self["name"] | ||
|
||
@property | ||
def id(self) -> str: # noqa: A003 | ||
return self["id"] | ||
|
||
@property | ||
def alias(self) -> str: | ||
return self["alias"] | ||
|
||
@property | ||
def version(self) -> str: | ||
return self["version"] | ||
|
||
|
||
class BedrockAgentFunctionParameter(DictWrapper): | ||
@property | ||
def name(self) -> str: | ||
return self["name"] | ||
|
||
@property | ||
def type(self) -> str: # noqa: A003 | ||
return self["type"] | ||
|
||
@property | ||
def value(self) -> str: | ||
return self["value"] | ||
|
||
|
||
class BedrockAgentFunctionEvent(DictWrapper): | ||
""" | ||
Bedrock Agent Function input event | ||
|
||
Documentation: | ||
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html | ||
""" | ||
|
||
@classmethod | ||
def validate_required_fields(cls, data: dict[str, Any]) -> None: | ||
required_fields = { | ||
"messageVersion": str, | ||
"agent": dict, | ||
"inputText": str, | ||
"sessionId": str, | ||
"actionGroup": str, | ||
"function": str, | ||
} | ||
|
||
for field, field_type in required_fields.items(): | ||
if field not in data: | ||
raise ValueError(f"Missing required field: {field}") | ||
if not isinstance(data[field], field_type): | ||
raise TypeError(f"Field {field} must be of type {field_type}") | ||
|
||
# Validate agent structure | ||
required_agent_fields = {"name", "id", "alias", "version"} | ||
if not all(field in data["agent"] for field in required_agent_fields): | ||
raise ValueError("Agent object missing required fields") | ||
|
||
def __init__(self, data: dict[str, Any]) -> None: | ||
super().__init__(data) | ||
self.validate_required_fields(data) | ||
|
||
@property | ||
def message_version(self) -> str: | ||
return self["messageVersion"] | ||
|
||
@property | ||
def input_text(self) -> str: | ||
return self["inputText"] | ||
|
||
@property | ||
def session_id(self) -> str: | ||
return self["sessionId"] | ||
|
||
@property | ||
def action_group(self) -> str: | ||
return self["actionGroup"] | ||
|
||
@property | ||
def function(self) -> str: | ||
return self["function"] | ||
|
||
@property | ||
def parameters(self) -> list[BedrockAgentFunctionParameter]: | ||
parameters = self.get("parameters") or [] | ||
return [BedrockAgentFunctionParameter(x) for x in parameters] | ||
|
||
@property | ||
def agent(self) -> BedrockAgentInfo: | ||
return BedrockAgentInfo(self["agent"]) | ||
|
||
@property | ||
def session_attributes(self) -> dict[str, str]: | ||
return self.get("sessionAttributes", {}) or {} | ||
|
||
@property | ||
def prompt_session_attributes(self) -> dict[str, str]: | ||
return self.get("promptSessionAttributes", {}) or {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
{ | ||
"messageVersion": "1.0", | ||
"agent": { | ||
"alias": "PROD", | ||
"name": "hr-assistant-function-def", | ||
"version": "1", | ||
"id": "1234abcd" | ||
}, | ||
"sessionId": "123456789123458", | ||
"sessionAttributes": { | ||
"employeeId": "EMP123" | ||
}, | ||
"promptSessionAttributes": { | ||
"lastInteraction": "2024-02-01T15:30:00Z", | ||
"requestType": "vacation" | ||
}, | ||
"inputText": "I want to request vacation from March 15 to March 20", | ||
"actionGroup": "VacationsActionGroup", | ||
"function": "submitVacationRequest", | ||
"parameters": [ | ||
{ | ||
"name": "startDate", | ||
"type": "string", | ||
"value": "2024-03-15" | ||
}, | ||
{ | ||
"name": "endDate", | ||
"type": "string", | ||
"value": "2024-03-20" | ||
} | ||
] | ||
} |
Uh oh!
There was an error while loading. Please reload this page.