Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions libs/core/llmstudio_core/agents/bedrock/data_models.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
from llmstudio_core.agents.data_models import (
AgentBase,
CreateAgentRequest,
ResultBase,
RunAgentRequest,
RunBase,
)


class BedrockAgent(AgentBase):
agentResourceRoleArn: str
agentStatus: str
agentVersion: str
agentArn: str
agent_resource_role_arn: str
agent_status: str
agent_arn: str
agent_alias_id: str


class BedrockRun(RunBase):
session_id: str
response: dict


class BedrockResult(ResultBase):
session_id: str


class BedrockCreateAgentRequest(CreateAgentRequest):
agent_resourcerole_arn: str
agent_resource_role_arn: str
agent_alias: str
name: str


class BedrockRunAgentRequest(RunAgentRequest):
session_id: str
agent_alias_id: str
269 changes: 253 additions & 16 deletions libs/core/llmstudio_core/agents/bedrock/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,43 @@
import boto3
from llmstudio_core.agents.bedrock.data_models import (
BedrockAgent,
BedrockResult,
BedrockCreateAgentRequest,
BedrockRun,
BedrockRunAgentRequest,
)
from llmstudio_core.agents.data_models import (
Attachment,
ImageFile,
ImageFileContent,
Message,
ResultBase,
RetrieveResultRequest,
TextContent,
)
from llmstudio_core.agents.manager import AgentManager, agent_manager
from llmstudio_core.exceptions import AgentError
from pydantic import ValidationError

SERVICE = "bedrock-agent"
AGENT_SERVICE = "bedrock-agent"
RUNTIME_SERVICE = "bedrock-agent-runtime"


@agent_manager
class BedrockAgentManager(AgentManager):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._client = boto3.client(
SERVICE,
service_name=AGENT_SERVICE,
region_name=self.region if self.region else os.getenv("BEDROCK_REGION"),
aws_access_key_id=self.access_key
if self.access_key
else os.getenv("BEDROCK_ACCESS_KEY"),
aws_secret_access_key=self.secret_key
if self.secret_key
else os.getenv("BEDROCK_SECRET_KEY"),
)
self._runtime_client = boto3.client(
service_name=RUNTIME_SERVICE,
region_name=self.region if self.region else os.getenv("BEDROCK_REGION"),
aws_access_key_id=self.access_key
if self.access_key
Expand All @@ -31,31 +54,245 @@ def _agent_config_name():
return "bedrock"

def _validate_create_request(self, request):
raise NotImplementedError("Agents need to implement the method")
return BedrockCreateAgentRequest(**request)

def _validate_run_request(self, request):
raise NotImplementedError("Agents need to implement the method")
return BedrockRunAgentRequest(**request)

def _validate_create_request(self, request):
raise NotImplementedError("Agents need to implement the method")
def _validate_result_request(self, request):
return RetrieveResultRequest(**request)

def create_agent(self, **kargs) -> BedrockAgent:
def create_agent(self, **kwargs) -> BedrockAgent:
"""
Creates a new instance of the agent.
This method validates the input parameters, creates a new agent using the client,
waits for the agent to reach the 'NOT_PREPARED' status, adds tools to the agent,
prepares the agent for use, creates an alias for the agent, and waits for the alias
to be prepared.

Args:
**kwargs: Agent creation parameters.

Returns:
BedrockAgent: An instance of the created BedrockAgent.

Raises:
AgentError: If there is a validation error or if an unsupported tool type is provided.

"""

raise NotImplementedError("Agents need to implement the 'create' method.")
try:
agent_request = self._validate_create_request(
dict(
**kwargs,
)
)

except ValidationError as e:
raise AgentError(str(e))

bedrock_create = self._client.create_agent(
agentName=agent_request.name,
foundationModel=agent_request.model,
instruction=agent_request.instructions,
agentResourceRoleArn=agent_request.agent_resource_role_arn,
)

agentId = bedrock_create["agent"]["agentId"]

# Wait for agent to reach 'NOT_PREPARED' status
agentStatus = ""
while agentStatus != "NOT_PREPARED":
response = self._client.get_agent(agentId=agentId)
agentStatus = response["agent"]["agentStatus"]

# Add tools to the agent
for tool in agent_request.tools:
if tool.type == "code_interpreter":
response = self._client.create_agent_action_group(
actionGroupName="CodeInterpreterAction",
actionGroupState="ENABLED",
agentId=agentId,
agentVersion="DRAFT",
parentActionGroupSignature="AMAZON.CodeInterpreter",
)

actionGroupId = response["agentActionGroup"]["actionGroupId"]

actionGroupStatus = ""
while actionGroupStatus != "ENABLED":
response = self._client.get_agent_action_group(
agentId=agentId,
actionGroupId=actionGroupId,
agentVersion="DRAFT",
)
actionGroupStatus = response["agentActionGroup"]["actionGroupState"]
else:
raise AgentError(f"Tool {tool.get('type')} not supported")

# Prepare the agent for use
response = self._client.prepare_agent(agentId=agentId)

# Wait for agent to reach 'PREPARED' status
agentStatus = ""
while agentStatus != "PREPARED":
response = self._client.get_agent(agentId=agentId)
agentStatus = response["agent"]["agentStatus"]

# Create an alias for the agent
response = self._client.create_agent_alias(
agentAliasName=agent_request.agent_alias, agentId=agentId
)

agentAliasId = response["agentAlias"]["agentAliasId"]

# Wait for agent alias to be prepared
agentAliasStatus = ""
while agentAliasStatus != "PREPARED":
response = self._client.get_agent_alias(
agentId=agentId, agentAliasId=agentAliasId
)
agentAliasStatus = response["agentAlias"]["agentAliasStatus"]

return BedrockAgent(
id=agentId,
created_at=int(bedrock_create["agent"]["createdAt"].timestamp()),
name=bedrock_create["agent"]["agentName"],
description=bedrock_create.get("agent", {}).get("description", None),
model=agent_request.model,
instructions=bedrock_create["agent"]["instruction"],
tools=agent_request.tools,
agent_arn=bedrock_create["agent"]["agentArn"],
agent_resource_role_arn=bedrock_create["agent"]["agentResourceRoleArn"],
agent_status=bedrock_create["agent"]["agentStatus"],
agent_alias_id=agentAliasId,
)

def run_agent(self, **kwargs) -> BedrockRun:
"""
Runs the agent
Runs the agent with the provided keyword arguments.

This method validates the run request and invokes the agent using the runtime client.
If the validation fails, an AgentError is raised.

Returns:
BedrockRun: An object containing the agent ID, status, session ID, and response of the run.

Raises:
AgentError: If the run request validation fails.
"""
raise NotImplementedError(
"Agents need to implement the 'create_thread_and_run' method."

try:
run_request = self._validate_run_request(
dict(
**kwargs,
)
)
except ValidationError as e:
raise AgentError(str(e))

sessionState = {"files": []}

for attachment in run_request.message.attachments:
if any(tool.type == "code_interpreter" for tool in attachment.tools):
sessionState["files"].append(
{
"name": attachment.file_name,
"source": {
"byteContent": {
"data": attachment.file_content,
"mediaType": attachment.file_type,
},
"sourceType": "BYTE_CONTENT",
},
"useCase": "CODE_INTERPRETER",
}
)

if isinstance(run_request.message.content, str):
input_text = run_request.message.content # Use it directly if it's a string
elif isinstance(run_request.message.content, list):
input_text = " ".join(
item.text
for item in run_request.message.content
if isinstance(item, TextContent)
)
else:
input_text = "" # Default to an empty string if content is not valid

invoke_request = self._runtime_client.invoke_agent(
agentId=run_request.agent_id,
agentAliasId=run_request.agent_alias_id,
sessionId=run_request.session_id,
inputText=input_text,
sessionState=sessionState,
)

return BedrockRun(
agent_id=run_request.agent_id,
status="completed",
session_id=run_request.session_id,
response=invoke_request,
)

def retrieve_result(self, **kwargs) -> BedrockResult:
def retrieve_result(self, **kwargs) -> ResultBase:
"""
Retrieves an existing agent.
Retrieve the result based on the provided keyword arguments.
This method validates the result request and processes the event stream to
extract content and attachments. It constructs a message with the extracted
content and attachments and returns it wrapped in a ResultBase object.

Returns:
ResultBase: An object containing the constructed message with content and attachments.
Raises:
AgentError: If the result request validation fails.
"""
raise NotImplementedError("Agents need to implement the 'retrieve' method.")

try:
result_request = self._validate_result_request(
dict(
**kwargs,
)
)

except ValidationError as e:
raise AgentError(str(e))

content = []
attachments = []
event_stream = result_request.run.response.get("completion")
for event in event_stream:
if "chunk" in event:
chunk = event["chunk"]
if "bytes" in chunk:
content.append(TextContent(text=chunk["bytes"].decode("utf-8")))

if "files" in event:
files = event["files"]["files"]
for file in files:
if file["type"] == "image/png":
content.append(
ImageFileContent(
image_file=ImageFile(
file_name=file["name"],
file_content=file["bytes"],
file_type=file["type"],
)
)
)
else:
attachments.append(
Attachment(
file_name=file["name"],
file_content=file["bytes"],
file_type=file["type"],
)
)

message = Message(
thread_id=result_request.run.session_id,
role="assistant",
content=content,
attachments=attachments,
)

return ResultBase(message=message)
Loading