diff --git a/minette/__init__.py b/minette/__init__.py index 98d4e21..1f9a3c0 100755 --- a/minette/__init__.py +++ b/minette/__init__.py @@ -18,7 +18,8 @@ DialogService, EchoDialogService, ErrorDialogService, - DialogRouter + DialogRouter, + DependencyContainer ) from .models import * from .tagger import Tagger diff --git a/minette/core.py b/minette/core.py index 7ebb383..67d441d 100755 --- a/minette/core.py +++ b/minette/core.py @@ -22,7 +22,11 @@ SQLiteMessageLogStore ) from .config import Config -from .dialog import DialogService, DialogRouter +from .dialog import ( + DialogService, + DialogRouter, + DependencyContainer +) from .tagger import Tagger @@ -369,3 +373,38 @@ def _save_context(self, context, connection): context.reset(self.config.get("keep_context_data", False)) self.context_store.save(context, connection) return context_for_log + + def dialog_uses(self, dependency_rules=None, **defaults): + """ + Set dependency components for DialogServices/Router + + Examples + -------- + >>> bot = Minette(defautl_dialog_service=EchoDialogService) + >>> bot.dialog_uses(apiclient=apiclient, tagger=bot.tagger) + + You can use `apiclient` and `tagger` like below in your code of DialogService/Router. + + >>> self.dependencies.apiclient.get_profile(user.id) + >>> self.dependencies.tagger.parse(request.text) + + Or, you can set dialog specific dependencies. + + >>> bot = Minette(defautl_dialog_service=EchoDialogService) + >>> bot.dialog_uses({EchoDialogService: {"echo_engine": echo_engine}}, apiclient=apiclient, tagger=bot.tagger) + + Then, `echo_engine` can be used only in `EchoDialogService`, `apiclient` and `tagger` can be used any dialogs/router. + + Parameters + ---------- + dependency_rules : dict + Rules that defines on which components each DialogService/Router depends. + Key is DialogService/Router class, value is dict of dependencies (name: value). + + defaults : dict + Dependencies for all DialogServices/Router (name: value) + """ + self.dialog_router.dependency_rules = dependency_rules + self.dialog_router.default_dependencies = defaults + self.dialog_router.dependencies = DependencyContainer( + self.dialog_router, dependency_rules, **defaults) diff --git a/minette/dialog/dependency.py b/minette/dialog/dependency.py new file mode 100644 index 0000000..2477245 --- /dev/null +++ b/minette/dialog/dependency.py @@ -0,0 +1,14 @@ +""" Container class for components that DialogRouter/DialogServices depend """ + + +class DependencyContainer: + def __init__(self, dialog, dependency_rules=None, **defaults): + # set default dependencies + for k, v in defaults.items(): + setattr(self, k, v) + # set dialog specific dependencies + if dependency_rules: + dialog_dependencies = dependency_rules.get(type(dialog)) + if dialog_dependencies: + for k, v in dialog_dependencies.items(): + setattr(self, k, v) diff --git a/minette/dialog/router.py b/minette/dialog/router.py index 9a81bfd..d0b9c6c 100644 --- a/minette/dialog/router.py +++ b/minette/dialog/router.py @@ -5,6 +5,7 @@ from ..models import Message, Priority from .service import DialogService, ErrorDialogService +from .dependency import DependencyContainer class DialogRouter: @@ -21,6 +22,12 @@ class DialogRouter: Logger default_dialog_service : DialogService Dialog service used when intent is not clear + dependency_rules : dict + Rules that defines on which components each DialogService/Router depends + default_dependencies : dict + Dependency components for all DialogServices/Router + dependencies : DependencyContainer + Container to attach objects DialogRouter depends intent_resolver : dict Resolver for intent to dialog topic_resolver : dict @@ -45,6 +52,9 @@ def __init__(self, config=None, timezone=None, logger=None, self.timezone = timezone self.logger = logger or getLogger(__name__) self.default_dialog_service = default_dialog_service or DialogService + self.dependency_rules = None + self.default_dependencies = None or {} # empty dict is required to unpack + self.dependencies = None # set up intent_resolver self.intent_resolver = intent_resolver or {} self.register_intents() @@ -107,6 +117,10 @@ def execute(self, request, context, connection, performance): config=self.config, timezone=self.timezone, logger=self.logger ) + dialog_service.dependencies = DependencyContainer( + dialog_service, + self.dependency_rules, + **self.default_dependencies) performance.append("dialog_router.route") except Exception as ex: self.logger.error( diff --git a/minette/dialog/service.py b/minette/dialog/service.py index cc2ba9a..454376a 100644 --- a/minette/dialog/service.py +++ b/minette/dialog/service.py @@ -22,6 +22,8 @@ class DialogService: Timezone logger : logging.Logger Logger + dependencies : DependencyContainer + Container to attach objects DialogRouter depends """ @classmethod @@ -55,6 +57,7 @@ def __init__(self, config=None, timezone=None, logger=None): self.config = config self.timezone = timezone self.logger = logger or getLogger(__name__) + self.dependencies = None def execute(self, request, context, connection, performance): """ diff --git a/tests/dialog/test_dependency.py b/tests/dialog/test_dependency.py new file mode 100644 index 0000000..05c1bc1 --- /dev/null +++ b/tests/dialog/test_dependency.py @@ -0,0 +1,66 @@ +import sys +import os +sys.path.append(os.pardir) + +from minette.dialog import ( + DialogService, + DialogRouter, + DependencyContainer +) + + +class SobaDialogService(DialogService): + pass + + +class UdonDialogService(DialogService): + pass + + +class RamenDialogService(DialogService): + pass + + +class MenDialogRouter(DialogRouter): + pass + + +def test_dependency(): + # dependencies + d1 = 1 + d2 = 2 + d3 = 3 + d4 = 4 + d5 = 5 + d6 = 6 + d7 = 7 + + # define rules + dependency_rules = { + SobaDialogService: {"d1": d1, "d2": d2}, + UdonDialogService: {"d2": d2, "d3": d3}, + RamenDialogService: {"d3": d3, "d4": d4}, + MenDialogRouter: {"d4": d4, "d5": d5} + } + + # dialog service + soba_dep = DependencyContainer(SobaDialogService(), dependency_rules, d6=d6, d7=d7) + # dependencies for soba + assert soba_dep.d1 == 1 + assert soba_dep.d2 == 2 + # dependencies for all + assert soba_dep.d6 == 6 + assert soba_dep.d7 == 7 + # dependencies not for soba + assert hasattr(soba_dep, "d3") is False + + # dialog router + men_dep = DependencyContainer(MenDialogRouter(), dependency_rules, d6=d6, d7=d7) + # dependencies for men + assert men_dep.d4 == 4 + assert men_dep.d5 == 5 + # dependencies for all + assert men_dep.d6 == 6 + assert men_dep.d7 == 7 + # dependencies not for men + assert hasattr(men_dep, "d1") is False diff --git a/tests/test_core.py b/tests/test_core.py index 7f68a92..1f44319 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,5 +1,7 @@ -import pytest +import sys import os +sys.path.append(os.pardir) +import pytest from pytz import timezone from logging import Logger, FileHandler, getLogger from datetime import datetime @@ -7,9 +9,9 @@ from minette import ( Minette, DialogService, SQLiteConnectionProvider, SQLiteContextStore, SQLiteUserStore, SQLiteMessageLogStore, - Tagger, Config, DialogRouter, StoreSet, Message, User, Group + Tagger, Config, DialogRouter, StoreSet, Message, User, Group, + DependencyContainer ) - from minette.utils import date_to_unixtime now = datetime.now() @@ -313,3 +315,39 @@ def test_chat_timezone(): res = bot.chat("hello") # bot.timezone itself is +9:19 assert res.messages[0].timestamp.tzinfo == datetime.now(tz=bot.timezone).tzinfo + + +def test_dialog_uses(): + class HighCostToCreate: + pass + + class OnlyForFooDS: + pass + + class FooFialog(DialogService): + pass + + # run once when create bot + hctc = HighCostToCreate() + offds = OnlyForFooDS() + + # create bot + bot = Minette() + + # set dependencies to dialogs + bot.dialog_uses( + { + FooFialog: {"api": offds} + }, + highcost=hctc + ) + + assert bot.dialog_router.dependencies.highcost == hctc + assert hasattr(bot.dialog_router.dependencies, "api") is False + assert bot.dialog_router.dependency_rules[FooFialog]["api"] == offds + + # create bot and not set dialog dependencies + bot_no_dd = Minette() + assert bot_no_dd.dialog_router.dependencies is None + bot_no_dd.dialog_uses() + assert isinstance(bot_no_dd.dialog_router.dependencies, DependencyContainer)