diff --git a/flo_ai/router/flo_linear.py b/flo_ai/router/flo_linear.py index 6d5f8b82..e890b810 100644 --- a/flo_ai/router/flo_linear.py +++ b/flo_ai/router/flo_linear.py @@ -8,7 +8,12 @@ class FloLinear(FloRouter): - def __init__(self, session: FloSession, name: str, flo_team: FloTeam): + def __init__( + self, + session: FloSession, + name: str, + flo_team: FloTeam, + ): super().__init__( session=session, name=name, @@ -61,7 +66,12 @@ def create(session: FloSession, name: str, team: FloTeam): return FloLinear.Builder(session=session, name=name, flo_team=team).build() class Builder: - def __init__(self, session: FloSession, name: str, flo_team: FloTeam) -> None: + def __init__( + self, + session: FloSession, + name: str, + flo_team: FloTeam, + ) -> None: self.name = name self.session = session self.team = flo_team diff --git a/flo_ai/router/flo_llm_router.py b/flo_ai/router/flo_llm_router.py index e5bc0ed0..38b8a5b1 100644 --- a/flo_ai/router/flo_llm_router.py +++ b/flo_ai/router/flo_llm_router.py @@ -24,7 +24,7 @@ def __init__( executor: Runnable, flo_team: FloTeam, name: str, - model_name: str, + model_name: str = 'default', ) -> None: super().__init__( session=session, @@ -56,12 +56,14 @@ def create( router_prompt: str = None, llm: Union[BaseLanguageModel, None] = None, ): + model_name = 'default' if llm is None else llm.name return FloLLMRouter.Builder( session=session, name=name, flo_team=team, router_prompt=router_prompt, llm=llm, + model_nick_name=model_name, ).build() class Builder: @@ -72,6 +74,7 @@ def __init__( flo_team: FloTeam, router_prompt: str = None, llm: Union[BaseLanguageModel, None] = None, + model_nick_name: str = 'default', ) -> None: self.name = name self.session = session @@ -79,6 +82,7 @@ def __init__( self.flo_team = flo_team self.agents = flo_team.members self.members = [agent.name for agent in flo_team.members] + self.model_name = model_nick_name self.options = self.members + [FLO_FINISH] member_type = ( 'workers' if flo_team.members[0].type == 'agent' else 'team members' @@ -118,4 +122,5 @@ def build(self): flo_team=self.flo_team, name=self.name, session=self.session, + model_name=self.model_name, )