Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 20 additions & 30 deletions src/crewai/agents/crew_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from typing import Any, Dict, List, Union

from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import CrewAgentParser
from crewai.agents.parser import (
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
AgentAction,
AgentFinish,
CrewAgentParser,
OutputParserException,
)
from crewai.agents.tools_handler import ToolsHandler
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N, Printer
Expand All @@ -13,12 +19,6 @@
)
from crewai.utilities.logger import Logger
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.agents.parser import (
AgentAction,
AgentFinish,
OutputParserException,
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE,
)


class CrewAgentExecutor(CrewAgentExecutorMixin):
Expand Down Expand Up @@ -307,34 +307,24 @@ def _handle_crew_training_output(
) -> None:
"""Function to handle the process of the training data."""
agent_id = str(self.agent.id)
if (
CrewTrainingHandler(TRAINING_DATA_FILE).load()
and not self.ask_for_human_input
):
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
if training_data.get(agent_id):
training_data[agent_id][self.crew._train_iteration][
"improved_output"
] = result.output
CrewTrainingHandler(TRAINING_DATA_FILE).save(training_data)

if self.ask_for_human_input and human_feedback is not None:
training_data = {
"initial_output": result.output,
"human_feedback": human_feedback,
"agent": agent_id,
"agent_role": self.agent.role,
}

# Load training data
training_handler = CrewTrainingHandler(TRAINING_DATA_FILE)
training_data = training_handler.load()

# Check if training data exists, human input is not requested, and self.crew is valid
if training_data and not self.ask_for_human_input:
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
train_iteration = self.crew._train_iteration
if isinstance(train_iteration, int):
CrewTrainingHandler(TRAINING_DATA_FILE).append(
train_iteration, agent_id, training_data
)
if agent_id in training_data and isinstance(train_iteration, int):
training_data[agent_id][train_iteration][
"improved_output"
] = result.output
training_handler.save(training_data)
else:
self._logger.log(
"error",
"Invalid train iteration type. Expected int.",
"Invalid train iteration type or agent_id not in training data.",
color="red",
)
else:
Expand Down
3 changes: 3 additions & 0 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import uuid
import warnings
from concurrent.futures import Future
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -57,6 +58,8 @@
if TYPE_CHECKING:
from crewai.pipeline.pipeline import Pipeline

warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")


class Crew(BaseModel):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None

def __class_getitem__(cls, item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls):
_initial_state_T: Type[T] = item
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore

_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
return _FlowGeneric
Expand Down
82 changes: 37 additions & 45 deletions src/crewai/project/annotations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import wraps
from typing import Callable

from crewai.project.utils import memoize
from crewai import Crew
from crewai.project.utils import memoize


def task(func):
if not hasattr(task, "registration_order"):
task.registration_order = []
func.is_task = True

@wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -15,9 +15,6 @@ def wrapper(*args, **kwargs):
result.name = func.__name__
return result

setattr(wrapper, "is_task", True)
task.registration_order.append(func.__name__)

return memoize(wrapper)


Expand Down Expand Up @@ -73,50 +70,45 @@ def pipeline(func):
return memoize(func)


def crew(func) -> "Crew":
def wrapper(self, *args, **kwargs):
def crew(func) -> Callable[..., Crew]:
def wrapper(self, *args, **kwargs) -> Crew:
instantiated_tasks = []
instantiated_agents = []

agent_roles = set()
all_functions = {
name: getattr(self, name)
for name in dir(self)
if callable(getattr(self, name))
}
tasks = {
name: func
for name, func in all_functions.items()
if hasattr(func, "is_task")
}
agents = {
name: func
for name, func in all_functions.items()
if hasattr(func, "is_agent")
}

# Sort tasks by their registration order
sorted_task_names = sorted(
tasks, key=lambda name: task.registration_order.index(name)
)

# Instantiate tasks in the order they were defined
for task_name in sorted_task_names:
task_instance = tasks[task_name]()

# Collect methods from crew in order
all_functions = [
(name, getattr(self, name))
for name, attr in self.__class__.__dict__.items()
if callable(attr)
]
tasks = [
(name, method)
for name, method in all_functions
if hasattr(method, "is_task")
]

agents = [
(name, method)
for name, method in all_functions
if hasattr(method, "is_agent")
]

# Instantiate tasks in order
for task_name, task_method in tasks:
task_instance = task_method()
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance is not None:
agent_instance = task_instance.agent
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)

# Instantiate any additional agents not already included by tasks
for agent_name in agents:
temp_agent_instance = agents[agent_name]()
if temp_agent_instance.role not in agent_roles:
instantiated_agents.append(temp_agent_instance)
agent_roles.add(temp_agent_instance.role)
if agent_instance and agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)

# Instantiate agents not included by tasks
for agent_name, agent_method in agents:
agent_instance = agent_method()
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)

self.agents = instantiated_agents
self.tasks = instantiated_tasks
Expand Down
8 changes: 4 additions & 4 deletions src/crewai/project/crew_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import inspect
from pathlib import Path
from typing import Any, Callable, Dict, Type, TypeVar
from typing import Any, Callable, Dict, TypeVar, cast

import yaml
from dotenv import load_dotenv

load_dotenv()

T = TypeVar("T", bound=Type[Any])
T = TypeVar("T", bound=type)


def CrewBase(cls: T) -> T:
class WrappedClass(cls):
class WrappedClass(cls): # type: ignore
is_crew_class: bool = True # type: ignore

# Get the directory of the class being decorated
Expand Down Expand Up @@ -180,4 +180,4 @@ def _map_task_variables(
callback_functions[callback]() for callback in callbacks
]

return WrappedClass
return cast(T, WrappedClass)