-
Notifications
You must be signed in to change notification settings - Fork 429
feat(bedrock_agent): add new Amazon Bedrock Agents Functions Resolver #6564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
anafalcao
wants to merge
21
commits into
develop
Choose a base branch
from
feat/bedrock_functions
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,032
−104
Open
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
41bc401
feat(bedrock_agent): create bedrock agents functions data class
anafalcao bed8f3f
create resolver
anafalcao a3765f0
mypy
anafalcao 44d80f8
add response
anafalcao abbc100
add name param to tool
anafalcao e42ceff
add response optional fields
anafalcao 86c7ab7
bedrockfunctionresponse and response state
anafalcao 34948d7
remove body message
anafalcao 24978cb
add parser
anafalcao 45f85f6
add test for required fields
anafalcao b420a90
Merge branch 'develop' into feat/bedrock_functions
anafalcao 84bb6b0
add more tests for parser and resolver
anafalcao 20bbe9f
Merge branch 'feat/bedrock_functions' of https://github.com/aws-power…
anafalcao d463304
add validation response state
anafalcao b4ab6b9
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena 39e0d36
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena 54a7edf
params injection
anafalcao fdde207
doc event handler, parser and data class
anafalcao c8b1b2f
fix doc typo
anafalcao db7d6b9
fix doc typo
anafalcao 266ebcb
mypy
anafalcao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
208 changes: 208 additions & 0 deletions
208
aws_lambda_powertools/event_handler/bedrock_agent_function.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
import warnings | ||
from typing import TYPE_CHECKING, Any, Literal | ||
|
||
from aws_lambda_powertools.warnings import PowertoolsUserWarning | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable | ||
|
||
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent | ||
|
||
|
||
class BedrockFunctionResponse: | ||
"""Response class for Bedrock Agent Functions | ||
|
||
Parameters | ||
---------- | ||
body : Any, optional | ||
Response body | ||
session_attributes : dict[str, str] | None | ||
Session attributes to include in the response | ||
prompt_session_attributes : dict[str, str] | None | ||
Prompt session attributes to include in the response | ||
response_state : Literal["FAILURE", "REPROMPT"] | None | ||
Response state ("FAILURE" or "REPROMPT") | ||
|
||
Examples | ||
-------- | ||
```python | ||
@app.tool(description="Function that uses session attributes") | ||
def test_function(): | ||
return BedrockFunctionResponse( | ||
body="Hello", | ||
session_attributes={"userId": "123"}, | ||
prompt_session_attributes={"lastAction": "login"} | ||
) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
body: Any = None, | ||
session_attributes: dict[str, str] | None = None, | ||
prompt_session_attributes: dict[str, str] | None = None, | ||
knowledge_bases: list[dict[str, Any]] | None = None, | ||
response_state: Literal["FAILURE", "REPROMPT"] | None = None, | ||
) -> None: | ||
if response_state is not None and response_state not in ["FAILURE", "REPROMPT"]: | ||
raise ValueError("responseState must be 'FAILURE' or 'REPROMPT'") | ||
|
||
self.body = body | ||
self.session_attributes = session_attributes | ||
self.prompt_session_attributes = prompt_session_attributes | ||
self.knowledge_bases = knowledge_bases | ||
self.response_state = response_state | ||
|
||
|
||
class BedrockFunctionsResponseBuilder: | ||
""" | ||
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda | ||
when using Bedrock Agent Functions. | ||
""" | ||
|
||
def __init__(self, result: BedrockFunctionResponse | Any) -> None: | ||
self.result = result | ||
|
||
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: | ||
"""Build the full response dict to be returned by the lambda""" | ||
if isinstance(self.result, BedrockFunctionResponse): | ||
body = self.result.body | ||
session_attributes = self.result.session_attributes | ||
prompt_session_attributes = self.result.prompt_session_attributes | ||
knowledge_bases = self.result.knowledge_bases | ||
response_state = self.result.response_state | ||
|
||
else: | ||
body = self.result | ||
session_attributes = None | ||
prompt_session_attributes = None | ||
knowledge_bases = None | ||
response_state = None | ||
|
||
# Per AWS Bedrock documentation, currently only "TEXT" is supported as the responseBody content type | ||
# https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html | ||
response: dict[str, Any] = { | ||
"messageVersion": "1.0", | ||
"response": { | ||
"actionGroup": event.action_group, | ||
"function": event.function, | ||
"functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}}, | ||
}, | ||
} | ||
|
||
# Add responseState if provided | ||
if response_state: | ||
response["response"]["functionResponse"]["responseState"] = response_state | ||
|
||
# Add session attributes if provided in response or maintain from input | ||
response.update( | ||
{ | ||
"sessionAttributes": session_attributes or event.session_attributes or {}, | ||
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {}, | ||
}, | ||
) | ||
|
||
# Add knowledge bases configuration if provided | ||
if knowledge_bases: | ||
response["knowledgeBasesConfiguration"] = knowledge_bases | ||
|
||
return response | ||
|
||
|
||
class BedrockAgentFunctionResolver: | ||
"""Bedrock Agent Function resolver that handles function definitions | ||
|
||
Examples | ||
-------- | ||
```python | ||
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver | ||
|
||
app = BedrockAgentFunctionResolver() | ||
|
||
@app.tool(description="Gets the current UTC time") | ||
def get_current_time(): | ||
from datetime import datetime | ||
return datetime.utcnow().isoformat() | ||
|
||
def lambda_handler(event, context): | ||
return app.resolve(event, context) | ||
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._tools: dict[str, dict[str, Any]] = {} | ||
self.current_event: BedrockAgentFunctionEvent | None = None | ||
self._response_builder_class = BedrockFunctionsResponseBuilder | ||
|
||
def tool( | ||
self, | ||
description: str | None = None, | ||
name: str | None = None, | ||
) -> Callable: | ||
"""Decorator to register a tool function | ||
|
||
Parameters | ||
---------- | ||
description : str | None | ||
Description of what the tool does | ||
name : str | None | ||
Custom name for the tool. If not provided, uses the function name | ||
""" | ||
|
||
def decorator(func: Callable) -> Callable: | ||
function_name = name or func.__name__ | ||
if function_name in self._tools: | ||
warnings.warn( | ||
f"Tool '{function_name}' already registered. Overwriting with new definition.", | ||
PowertoolsUserWarning, | ||
stacklevel=2, | ||
) | ||
|
||
self._tools[function_name] = { | ||
"function": func, | ||
"description": description, | ||
} | ||
return func | ||
|
||
return decorator | ||
|
||
def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]: | ||
"""Resolves the function call from Bedrock Agent event""" | ||
try: | ||
self.current_event = BedrockAgentFunctionEvent(event) | ||
return self._resolve() | ||
except KeyError as e: | ||
raise ValueError(f"Missing required field: {str(e)}") | ||
|
||
def _resolve(self) -> dict[str, Any]: | ||
"""Internal resolution logic""" | ||
if self.current_event is None: | ||
raise ValueError("No event to process") | ||
|
||
function_name = self.current_event.function | ||
|
||
try: | ||
parameters = {} | ||
if hasattr(self.current_event, "parameters"): | ||
for param in self.current_event.parameters: | ||
parameters[param.name] = param.value | ||
|
||
func = self._tools[function_name]["function"] | ||
sig = inspect.signature(func) | ||
|
||
valid_params = {} | ||
for name, value in parameters.items(): | ||
if name in sig.parameters: | ||
valid_params[name] = value | ||
|
||
result = func(**valid_params) | ||
return BedrockFunctionsResponseBuilder(result).build(self.current_event) | ||
except Exception as e: | ||
return BedrockFunctionsResponseBuilder( | ||
BedrockFunctionResponse( | ||
body=f"Error: {str(e)}", | ||
), | ||
).build(self.current_event) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from __future__ import annotations | ||
|
||
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper | ||
|
||
|
||
class BedrockAgentInfo(DictWrapper): | ||
@property | ||
def name(self) -> str: | ||
return self["name"] | ||
|
||
@property | ||
def id(self) -> str: # noqa: A003 | ||
return self["id"] | ||
|
||
@property | ||
def alias(self) -> str: | ||
return self["alias"] | ||
|
||
@property | ||
def version(self) -> str: | ||
return self["version"] | ||
|
||
|
||
class BedrockAgentFunctionParameter(DictWrapper): | ||
@property | ||
def name(self) -> str: | ||
return self["name"] | ||
|
||
@property | ||
def type(self) -> str: # noqa: A003 | ||
return self["type"] | ||
|
||
@property | ||
def value(self) -> str: | ||
return self["value"] | ||
|
||
|
||
class BedrockAgentFunctionEvent(DictWrapper): | ||
""" | ||
Bedrock Agent Function input event | ||
|
||
Documentation: | ||
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html | ||
""" | ||
|
||
@property | ||
def message_version(self) -> str: | ||
return self["messageVersion"] | ||
|
||
@property | ||
def input_text(self) -> str: | ||
return self["inputText"] | ||
|
||
@property | ||
def session_id(self) -> str: | ||
return self["sessionId"] | ||
|
||
@property | ||
def action_group(self) -> str: | ||
return self["actionGroup"] | ||
|
||
@property | ||
def function(self) -> str: | ||
return self["function"] | ||
|
||
@property | ||
def parameters(self) -> list[BedrockAgentFunctionParameter]: | ||
parameters = self.get("parameters") or [] | ||
return [BedrockAgentFunctionParameter(x) for x in parameters] | ||
|
||
@property | ||
def agent(self) -> BedrockAgentInfo: | ||
return BedrockAgentInfo(self["agent"]) | ||
|
||
@property | ||
def session_attributes(self) -> dict[str, str]: | ||
return self.get("sessionAttributes", {}) or {} | ||
|
||
@property | ||
def prompt_session_attributes(self) -> dict[str, str]: | ||
return self.get("promptSessionAttributes", {}) or {} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.