Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None

def __class_getitem__(cls, item):
def __class_getitem__(cls, item: Type[T]) -> Type["Flow"]:
class _FlowGeneric(cls):
_initial_state_T = item
_initial_state_T: Type[T] = item

_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
return _FlowGeneric

def __init__(self):
def __init__(self) -> None:
self._methods: Dict[str, Callable] = {}
self._state = self._create_initial_state()
self._state: T = self._create_initial_state()
self._completed_methods: Set[str] = set()
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
Expand Down Expand Up @@ -212,11 +213,11 @@ async def kickoff(self) -> Any:
else:
return None # Or raise an exception if no methods were executed

async def _execute_start_method(self, start_method: str):
async def _execute_start_method(self, start_method: str) -> None:
result = await self._execute_method(self._methods[start_method])
await self._execute_listeners(start_method, result)

async def _execute_method(self, method: Callable, *args, **kwargs):
async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any:
result = (
await method(*args, **kwargs)
if asyncio.iscoroutinefunction(method)
Expand All @@ -225,7 +226,7 @@ async def _execute_method(self, method: Callable, *args, **kwargs):
self._method_outputs.append(result) # Store the output
return result

async def _execute_listeners(self, trigger_method: str, result: Any):
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = []

if trigger_method in self._routers:
Expand Down Expand Up @@ -253,7 +254,7 @@ async def _execute_listeners(self, trigger_method: str, result: Any):
# Run all listener tasks concurrently and wait for them to complete
await asyncio.gather(*listener_tasks)

async def _execute_single_listener(self, listener: str, result: Any):
async def _execute_single_listener(self, listener: str, result: Any) -> None:
try:
method = self._methods[listener]
sig = inspect.signature(method)
Expand All @@ -277,7 +278,7 @@ async def _execute_single_listener(self, listener: str, result: Any):

traceback.print_exc()

def plot(self, filename: str = "crewai_flow"):
def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
)
Expand Down