From 671901810a55cd7de9a3783993a7fab01b42ee9e Mon Sep 17 00:00:00 2001 From: uezo Date: Thu, 3 Sep 2020 21:58:39 +0900 Subject: [PATCH 1/3] Support returning generator from Tagger Add `parse_as_generator` method to MeCabTagger, JanomeTagger and Tagger(base) to save CPU and memory resources. For now, `tagger.parse()` is called automatically in `core.chat()` so this new method is unused. TODO: Enable to switch parsing morph automatically or manually --- minette/tagger/base.py | 18 ++++++++++++++++- minette/tagger/janometagger.py | 28 ++++++++++++++++++++------- minette/tagger/mecabtagger.py | 24 +++++++++++++++++++---- tests/tagger/test_janometagger.py | 32 +++++++++++++++++++++++++++++++ tests/tagger/test_mecabtagger.py | 32 +++++++++++++++++++++++++++++++ tests/tagger/test_tagger_base.py | 6 ++++++ 6 files changed, 128 insertions(+), 12 deletions(-) diff --git a/minette/tagger/base.py b/minette/tagger/base.py index 0b5fef0..befac61 100644 --- a/minette/tagger/base.py +++ b/minette/tagger/base.py @@ -42,7 +42,23 @@ def parse(self, text): Returns ------- - words : list of minette.WordNode + words : list of minette.WordNode (empty) Word nodes """ return [] + + def parse_as_generator(self, text): + """ + Analyze and parse text, returns Generator + + Parameters + ---------- + text : str + Text to analyze + + Returns + ------- + words : Generator of minette.WordNode (empty) + Word nodes + """ + yield from () diff --git a/minette/tagger/janometagger.py b/minette/tagger/janometagger.py index 28d6244..0cfc778 100644 --- a/minette/tagger/janometagger.py +++ b/minette/tagger/janometagger.py @@ -96,9 +96,9 @@ def __init__(self, config=None, timezone=None, logger=None, *, else: self.tokenizer = Tokenizer() - def parse(self, text): + def parse_as_generator(self, text): """ - Parse and annotate using Janome + Parse and annotate using Janome, returns Generator Parameters ---------- @@ -107,17 +107,31 @@ def parse(self, text): Returns ------- - words : list of 
minette.minette.tagger.janometagger.JanomeNode + words : Generator of minette.tagger.janometagger.JanomeNode Janome nodes """ - ret = [] if not text: - return ret + return try: for token in self.tokenizer.tokenize(text): - ret.append(JanomeNode.create(token.surface, token)) + yield JanomeNode.create(token.surface, token) except Exception as ex: self.logger.error( "Janome parsing error: " + str(ex) + "\n" + traceback.format_exc()) - return ret + + def parse(self, text): + """ + Parse and annotate using Janome + + Parameters + ---------- + text : str + Text to analyze + + Returns + ------- + words : list of minette.tagger.janometagger.JanomeNode + Janome nodes + """ + return [jn for jn in self.parse_as_generator(text)] diff --git a/minette/tagger/mecabtagger.py b/minette/tagger/mecabtagger.py index 791c4c5..4d8ca15 100755 --- a/minette/tagger/mecabtagger.py +++ b/minette/tagger/mecabtagger.py @@ -74,9 +74,9 @@ class MeCabTagger(Tagger): Logger """ - def parse(self, text): + def parse_as_generator(self, text): """ - Analyze and parse text + Analyze and parse text using MeCab, returns Generator Parameters ---------- @@ -99,10 +99,26 @@ def parse(self, text): while node: features = node.feature.split(",") if features[0] != "BOS/EOS": - ret.append(MeCabNode.create(node.surface, features)) + # ret.append(MeCabNode.create(node.surface, features)) + yield MeCabNode.create(node.surface, features) node = node.next except Exception as ex: self.logger.error( "MeCab parsing error: " + str(ex) + "\n" + traceback.format_exc()) - return ret + + def parse(self, text): + """ + Analyze and parse text + + Parameters + ---------- + text : str + Text to analyze + + Returns + ------- + words : list of minette.tagger.mecabtagger.MeCabNode + MeCab word nodes + """ + return [mn for mn in self.parse_as_generator(text)] diff --git a/tests/tagger/test_janometagger.py b/tests/tagger/test_janometagger.py index 4d9fcb0..a20db7d 100644 --- a/tests/tagger/test_janometagger.py 
+++ b/tests/tagger/test_janometagger.py @@ -1,5 +1,6 @@ import pytest from pytz import timezone +from types import GeneratorType try: from minette.tagger.janometagger import JanomeTagger, JanomeNode @@ -38,6 +39,37 @@ def test_parse(): assert words[2].pronunciation == "ヨイ" +def test_parse_as_generator(): + tagger = JanomeTagger() + # 空文字列 + empty_words_gen = tagger.parse_as_generator("") + assert isinstance(empty_words_gen, GeneratorType) + empty_words = [ew for ew in empty_words_gen] + assert empty_words == [] + # センテンスあり + words = tagger.parse_as_generator("今日は良い天気です") + assert isinstance(words, GeneratorType) + i = 0 + for w in words: + if i == 0: + assert w.surface == "今日" + assert w.part == "名詞" + assert w.part_detail1 == "副詞可能" + assert w.word == "今日" + assert w.kana == "キョウ" + assert w.pronunciation == "キョー" + elif i == 2: + assert w.surface == "良い" + assert w.part == "形容詞" + assert w.part_detail1 == "自立" + assert w.stem_type == "形容詞・アウオ段" + assert w.stem_form == "基本形" + assert w.word == "良い" + assert w.kana == "ヨイ" + assert w.pronunciation == "ヨイ" + i += 1 + + def test_error(): tagger = JanomeTagger() assert tagger.parse(object()) == [] diff --git a/tests/tagger/test_mecabtagger.py b/tests/tagger/test_mecabtagger.py index 28f87d8..7c9a974 100644 --- a/tests/tagger/test_mecabtagger.py +++ b/tests/tagger/test_mecabtagger.py @@ -1,5 +1,6 @@ import pytest from pytz import timezone +from types import GeneratorType try: from minette.tagger.mecabtagger import MeCabTagger, MeCabNode @@ -38,6 +39,37 @@ def test_parse(): assert words[2].pronunciation == "ヨイ" +def test_parse_as_generator(): + tagger = MeCabTagger() + # 空文字列 + empty_words_gen = tagger.parse_as_generator("") + assert isinstance(empty_words_gen, GeneratorType) + empty_words = [ew for ew in empty_words_gen] + assert empty_words == [] + # センテンスあり + words = tagger.parse_as_generator("今日は良い天気です") + assert isinstance(words, GeneratorType) + i = 0 + for w in words: + if i == 0: + assert w.surface == "今日" + 
assert w.part == "名詞" + assert w.part_detail1 == "副詞可能" + assert w.word == "今日" + assert w.kana == "キョウ" + assert w.pronunciation == "キョー" + elif i == 2: + assert w.surface == "良い" + assert w.part == "形容詞" + assert w.part_detail1 == "自立" + assert w.stem_type == "形容詞・アウオ段" + assert w.stem_form == "基本形" + assert w.word == "良い" + assert w.kana == "ヨイ" + assert w.pronunciation == "ヨイ" + i += 1 + + def test_error(): tagger = MeCabTagger() assert tagger.parse(object()) == [] diff --git a/tests/tagger/test_tagger_base.py b/tests/tagger/test_tagger_base.py index e8bba96..dd6c073 100644 --- a/tests/tagger/test_tagger_base.py +++ b/tests/tagger/test_tagger_base.py @@ -1,5 +1,6 @@ import pytest from pytz import timezone +from types import GeneratorType from minette import Tagger @@ -12,3 +13,8 @@ def test_init(): def test_parse(): tagger = Tagger() assert tagger.parse("今日は良い天気です") == [] + + +def test_parse_as_generator(): + tagger = Tagger() + assert isinstance(tagger.parse_as_generator("今日は良い天気です"), GeneratorType) From 2ea36630d2d483ede94d534cee5ebb98cb670dd9 Mon Sep 17 00:00:00 2001 From: uezo Date: Thu, 3 Sep 2020 23:01:26 +0900 Subject: [PATCH 2/3] Enable to switch parsing morph automatically or manually Enable to switch off parsing morph by passing `parse_morph=False` when create Minette object. Also add `tagger` to DialogService and DialogRouter. To parse manually call `self.tagger.parse(request.text)` or `self.tagger.parse_as_generator(request.text)`. 
--- minette/core.py | 28 +++++++----- minette/dialog/router.py | 10 ++++- minette/dialog/service.py | 7 ++- requirements-dev.txt | 2 +- tests/test_core.py | 90 +++++++++++++++++++++++++++++++++++++-- 5 files changed, 120 insertions(+), 17 deletions(-) diff --git a/minette/core.py b/minette/core.py index 7ebb383..0f9911d 100755 --- a/minette/core.py +++ b/minette/core.py @@ -63,7 +63,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, user_store=None, user_table=None, messagelog_store=None, messagelog_table=None, default_dialog_service=None, dialog_router=None, - tagger=None, prepare_table=True, **kwargs): + tagger=None, parse_morph=True, prepare_table=True, **kwargs): """ Parameters ---------- @@ -122,6 +122,8 @@ def __init__(self, *, config=None, config_file=None, timezone=None, and return proper DialogService for intent tagger: minette.Tagger or type, default None Morphological analysis engine + parse_morph: bool, default True + Parse morph of request message automatically or not prepare_table: bool, default True Create tables for data stores if they don't exist. 
""" @@ -133,6 +135,9 @@ def __init__(self, *, config=None, config_file=None, timezone=None, self.timezone = timezone or tz(self.config.get("timezone") or "UTC") self.logger = self._get_logger( logger, log_file=log_file, logger_name=logger_name) + self.tagger = self._get_tagger( + tagger, logger=self.logger, + config=self.config, timezone=self.timezone, **kwargs) self.connection_provider = self._get_connection_provider( connection_provider or ( data_stores.connection_provider if data_stores else None), @@ -143,6 +148,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, "config": self.config, "timezone": self.timezone, "logger": self.logger, + "tagger": self.tagger, "connection_provider": self.connection_provider, "context_store": context_store or ( data_stores.context_store if data_stores else None), @@ -156,7 +162,6 @@ def __init__(self, *, config=None, config_file=None, timezone=None, "messagelog_table": messagelog_table, "dialog_router": dialog_router, "default_dialog_service": default_dialog_service, - "tagger": tagger, } setter_args.update({k: v for k, v in kwargs.items() if k not in setter_args}) @@ -166,7 +171,6 @@ def __init__(self, *, config=None, config_file=None, timezone=None, self.messagelog_store = self._get_messagelog_store(**setter_args) self.default_dialog_service = default_dialog_service self.dialog_router = self._get_dialog_router(**setter_args) - self.tagger = self._get_tagger(**setter_args) # prepare tables if prepare_table is True: @@ -178,6 +182,9 @@ def __init__(self, *, config=None, config_file=None, timezone=None, if hasattr(connection, "close"): connection.close() + # other runtime members + self.parse_morph = parse_morph + def _get_logger(self, logger, log_file=None, logger_name=None): lg = logger # use passed logger if already setup @@ -207,6 +214,12 @@ def _get_logger(self, logger, log_file=None, logger_name=None): lg.addHandler(file_handler) return lg + def _get_tagger(self, tagger, config, logger, timezone, 
**kwargs): + tg = tagger or Tagger + if issubclass(tg, Tagger): + tg = tg(config, logger, timezone, **kwargs) + return tg + def _get_connection_provider(self, connection_provider, connection_str=None, **kwargs): cp = connection_provider or SQLiteConnectionProvider @@ -259,12 +272,6 @@ def _get_dialog_router(self, dialog_router, default_dialog_service=None, dr = dr(default_dialog_service=default_dialog_service, **kwargs) return dr - def _get_tagger(self, tagger, **kwargs): - tg = tagger or Tagger - if issubclass(tg, Tagger): - tg = tg(**kwargs) - return tg - def chat(self, request): """ Get response from chatbot @@ -296,7 +303,8 @@ def chat(self, request): connection = self.connection_provider.get_connection() performance.append("connection_provider.get_connection") # tagger - request.words = self.tagger.parse(request.text) + if self.parse_morph: + request.words = self.tagger.parse(request.text) performance.append("tagger.parse") # user request.user = self._get_user(request, connection) diff --git a/minette/dialog/router.py b/minette/dialog/router.py index 9a81bfd..b603930 100644 --- a/minette/dialog/router.py +++ b/minette/dialog/router.py @@ -4,6 +4,7 @@ from logging import Logger, getLogger from ..models import Message, Priority +from ..tagger import Tagger from .service import DialogService, ErrorDialogService @@ -27,7 +28,7 @@ class DialogRouter: Resolver for topic to dialog for successive chatting """ - def __init__(self, config=None, timezone=None, logger=None, + def __init__(self, config=None, timezone=None, logger=None, tagger=None, default_dialog_service=None, intent_resolver=None, **kwargs): """ Parameters @@ -38,12 +39,17 @@ def __init__(self, config=None, timezone=None, logger=None, Timezone logger : logging.Logger, default None Logger + tagger : minette.tagger.Tagger, default None + Tagger default_dialog_service : minette.DialogService or type, default None Dialog service used when intent is not clear. 
+ intent_resolver : dict, default None """ self.config = config self.timezone = timezone self.logger = logger or getLogger(__name__) + self.tagger = tagger or \ + Tagger(config=config, timezone=timezone, logger=logger) self.default_dialog_service = default_dialog_service or DialogService # set up intent_resolver self.intent_resolver = intent_resolver or {} @@ -105,7 +111,7 @@ def execute(self, request, context, connection, performance): if issubclass(dialog_service, DialogService): dialog_service = dialog_service( config=self.config, timezone=self.timezone, - logger=self.logger + logger=self.logger, tagger=self.tagger ) performance.append("dialog_router.route") except Exception as ex: diff --git a/minette/dialog/service.py b/minette/dialog/service.py index cc2ba9a..4eb389e 100644 --- a/minette/dialog/service.py +++ b/minette/dialog/service.py @@ -8,6 +8,7 @@ Context, PerformanceInfo ) +from ..tagger import Tagger class DialogService: @@ -41,7 +42,7 @@ def topic_name(cls): cls_name = cls_name[:-6] return cls_name - def __init__(self, config=None, timezone=None, logger=None): + def __init__(self, config=None, timezone=None, logger=None, tagger=None): """ Parameters ---------- @@ -51,10 +52,14 @@ def __init__(self, config=None, timezone=None, logger=None): Timezone logger : logging.Logger, default None Logger + tagger : minette.tagger.Tagger, default None + Tagger """ self.config = config self.timezone = timezone self.logger = logger or getLogger(__name__) + self.tagger = tagger or \ + Tagger(config=config, timezone=timezone, logger=logger) def execute(self, request, context, connection, performance): """ diff --git a/requirements-dev.txt b/requirements-dev.txt index 9057f74..ecd80f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ pytz==2020.1 -requests==2.24.0 schedule==0.6.0 pytest==6.0.1 +Janome==0.4.0 diff --git a/tests/test_core.py b/tests/test_core.py index 7f68a92..a9230e7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ 
-1,5 +1,6 @@ import pytest import os +from types import GeneratorType from pytz import timezone from logging import Logger, FileHandler, getLogger from datetime import datetime @@ -7,7 +8,12 @@ from minette import ( Minette, DialogService, SQLiteConnectionProvider, SQLiteContextStore, SQLiteUserStore, SQLiteMessageLogStore, - Tagger, Config, DialogRouter, StoreSet, Message, User, Group + Tagger, Config, DialogRouter, StoreSet, Message, User, Group, Payload +) + +from minette.tagger.janometagger import ( + JanomeTagger, + JanomeNode ) from minette.utils import date_to_unixtime @@ -55,6 +61,31 @@ def compose_response(self, request, context, connection): return "res:" + request.text +class TaggerDialog(DialogService): + def compose_response(self, request, context, connection): + return request.to_reply( + text=request.text, + payloads=[Payload(content_type="data", content=request.words)]) + + +class TaggerManuallyParseDialog(DialogService): + def compose_response(self, request, context, connection): + assert request.words == [] + request.words = self.tagger.parse(request.text) + return request.to_reply( + text=request.text, + payloads=[Payload(content_type="data", content=request.words)]) + + +class TaggerManuallyParseGeneratorDialog(DialogService): + def compose_response(self, request, context, connection): + assert request.words == [] + request.words = self.tagger.parse_as_generator(request.text) + return request.to_reply( + text=request.text, + payloads=[Payload(content_type="data", content=request.words)]) + + class MyDialogRouter(DialogRouter): def __init__(self, custom_router_arg=None, **kwargs): super().__init__(**kwargs) @@ -68,12 +99,12 @@ def test_init(): assert bot.timezone == timezone("UTC") assert isinstance(bot.logger, Logger) assert bot.logger.name == "minette" + assert isinstance(bot.tagger, Tagger) assert isinstance(bot.connection_provider, SQLiteConnectionProvider) assert isinstance(bot.context_store, SQLiteContextStore) assert 
isinstance(bot.user_store, SQLiteUserStore) assert isinstance(bot.messagelog_store, SQLiteMessageLogStore) assert bot.default_dialog_service is None - assert isinstance(bot.tagger, Tagger) def test_init_config(): @@ -125,6 +156,7 @@ def test_init_args(): assert bot.config.get("test_key", section="test_section") == "test_value" assert bot.timezone == timezone("Asia/Tokyo") assert bot.logger.name == "test_core_logger" + assert isinstance(bot.tagger, CustomTagger) assert isinstance(bot.connection_provider, CustomConnectionProvider) assert isinstance(bot.context_store, CustomContextStore) assert isinstance(bot.user_store, CustomUserStore) @@ -132,7 +164,6 @@ def test_init_args(): assert bot.default_dialog_service is MyDialog assert isinstance(bot.dialog_router, MyDialogRouter) assert bot.dialog_router.custom_attr == "router_value" - assert isinstance(bot.tagger, CustomTagger) # create bot with data_stores bot = Minette( @@ -288,6 +319,59 @@ def test_chat(): assert res.messages[0].text == "res:hello" +def test_chat_with_tagger(): + bot = Minette( + default_dialog_service=TaggerDialog, + tagger=JanomeTagger) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + words = res.messages[0].payloads[0].content + assert words[0].surface == "今日" + assert words[1].surface == "は" + assert words[2].surface == "いい" + assert words[3].surface == "天気" + assert words[4].surface == "です" + + +def test_chat_with_tagger_no_parse(): + bot = Minette( + default_dialog_service=TaggerDialog, + tagger=JanomeTagger, parse_morph=False) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + words = res.messages[0].payloads[0].content + assert words == [] + + +def test_chat_parse_morph_manually(): + bot = Minette( + default_dialog_service=TaggerManuallyParseDialog, + tagger=JanomeTagger, parse_morph=False) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + words = res.messages[0].payloads[0].content + assert words[0].surface 
== "今日" + assert words[1].surface == "は" + assert words[2].surface == "いい" + assert words[3].surface == "天気" + assert words[4].surface == "です" + + +def test_chat_parse_morph_manually_generator(): + bot = Minette( + default_dialog_service=TaggerManuallyParseGeneratorDialog, + tagger=JanomeTagger, parse_morph=False) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + assert isinstance(res.messages[0].payloads[0].content, GeneratorType) + words = [w for w in res.messages[0].payloads[0].content] + assert words[0].surface == "今日" + assert words[1].surface == "は" + assert words[2].surface == "いい" + assert words[3].surface == "天気" + assert words[4].surface == "です" + + def test_chat_error(): bot = Minette(default_dialog_service=MyDialog) bot.connection_provider = None From 471023cf50bf049b5df159e34d312f40be1efc71 Mon Sep 17 00:00:00 2001 From: uezo Date: Thu, 3 Sep 2020 23:49:05 +0900 Subject: [PATCH 3/3] Revert "Enable to switch parsing morph automatically or manually" This reverts commit 2ea36630d2d483ede94d534cee5ebb98cb670dd9. 
--- minette/core.py | 28 +++++------- minette/dialog/router.py | 10 +---- minette/dialog/service.py | 7 +-- requirements-dev.txt | 2 +- tests/test_core.py | 90 ++------------------------------------- 5 files changed, 17 insertions(+), 120 deletions(-) diff --git a/minette/core.py b/minette/core.py index 0f9911d..7ebb383 100755 --- a/minette/core.py +++ b/minette/core.py @@ -63,7 +63,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, user_store=None, user_table=None, messagelog_store=None, messagelog_table=None, default_dialog_service=None, dialog_router=None, - tagger=None, parse_morph=True, prepare_table=True, **kwargs): + tagger=None, prepare_table=True, **kwargs): """ Parameters ---------- @@ -122,8 +122,6 @@ def __init__(self, *, config=None, config_file=None, timezone=None, and return proper DialogService for intent tagger: minette.Tagger or type, default None Morphological analysis engine - parse_morph: bool, default True - Parse morph of request message automatically or not prepare_table: bool, default True Create tables for data stores if they don't exist. 
""" @@ -135,9 +133,6 @@ def __init__(self, *, config=None, config_file=None, timezone=None, self.timezone = timezone or tz(self.config.get("timezone") or "UTC") self.logger = self._get_logger( logger, log_file=log_file, logger_name=logger_name) - self.tagger = self._get_tagger( - tagger, logger=self.logger, - config=self.config, timezone=self.timezone, **kwargs) self.connection_provider = self._get_connection_provider( connection_provider or ( data_stores.connection_provider if data_stores else None), @@ -148,7 +143,6 @@ def __init__(self, *, config=None, config_file=None, timezone=None, "config": self.config, "timezone": self.timezone, "logger": self.logger, - "tagger": self.tagger, "connection_provider": self.connection_provider, "context_store": context_store or ( data_stores.context_store if data_stores else None), @@ -162,6 +156,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, "messagelog_table": messagelog_table, "dialog_router": dialog_router, "default_dialog_service": default_dialog_service, + "tagger": tagger, } setter_args.update({k: v for k, v in kwargs.items() if k not in setter_args}) @@ -171,6 +166,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, self.messagelog_store = self._get_messagelog_store(**setter_args) self.default_dialog_service = default_dialog_service self.dialog_router = self._get_dialog_router(**setter_args) + self.tagger = self._get_tagger(**setter_args) # prepare tables if prepare_table is True: @@ -182,9 +178,6 @@ def __init__(self, *, config=None, config_file=None, timezone=None, if hasattr(connection, "close"): connection.close() - # other runtime members - self.parse_morph = parse_morph - def _get_logger(self, logger, log_file=None, logger_name=None): lg = logger # use passed logger if already setup @@ -214,12 +207,6 @@ def _get_logger(self, logger, log_file=None, logger_name=None): lg.addHandler(file_handler) return lg - def _get_tagger(self, tagger, config, logger, timezone, 
**kwargs): - tg = tagger or Tagger - if issubclass(tg, Tagger): - tg = tg(config, logger, timezone, **kwargs) - return tg - def _get_connection_provider(self, connection_provider, connection_str=None, **kwargs): cp = connection_provider or SQLiteConnectionProvider @@ -272,6 +259,12 @@ def _get_dialog_router(self, dialog_router, default_dialog_service=None, dr = dr(default_dialog_service=default_dialog_service, **kwargs) return dr + def _get_tagger(self, tagger, **kwargs): + tg = tagger or Tagger + if issubclass(tg, Tagger): + tg = tg(**kwargs) + return tg + def chat(self, request): """ Get response from chatbot @@ -303,8 +296,7 @@ def chat(self, request): connection = self.connection_provider.get_connection() performance.append("connection_provider.get_connection") # tagger - if self.parse_morph: - request.words = self.tagger.parse(request.text) + request.words = self.tagger.parse(request.text) performance.append("tagger.parse") # user request.user = self._get_user(request, connection) diff --git a/minette/dialog/router.py b/minette/dialog/router.py index b603930..9a81bfd 100644 --- a/minette/dialog/router.py +++ b/minette/dialog/router.py @@ -4,7 +4,6 @@ from logging import Logger, getLogger from ..models import Message, Priority -from ..tagger import Tagger from .service import DialogService, ErrorDialogService @@ -28,7 +27,7 @@ class DialogRouter: Resolver for topic to dialog for successive chatting """ - def __init__(self, config=None, timezone=None, logger=None, tagger=None, + def __init__(self, config=None, timezone=None, logger=None, default_dialog_service=None, intent_resolver=None, **kwargs): """ Parameters @@ -39,17 +38,12 @@ def __init__(self, config=None, timezone=None, logger=None, tagger=None, Timezone logger : logging.Logger, default None Logger - tagger : minette.tagger.Tagger, default None - Tagger default_dialog_service : minette.DialogService or type, default None Dialog service used when intent is not clear. 
- intent_resolver : dict, default None """ self.config = config self.timezone = timezone self.logger = logger or getLogger(__name__) - self.tagger = tagger or \ - Tagger(config=config, timezone=timezone, logger=logger) self.default_dialog_service = default_dialog_service or DialogService # set up intent_resolver self.intent_resolver = intent_resolver or {} @@ -111,7 +105,7 @@ def execute(self, request, context, connection, performance): if issubclass(dialog_service, DialogService): dialog_service = dialog_service( config=self.config, timezone=self.timezone, - logger=self.logger, tagger=self.tagger + logger=self.logger ) performance.append("dialog_router.route") except Exception as ex: diff --git a/minette/dialog/service.py b/minette/dialog/service.py index 4eb389e..cc2ba9a 100644 --- a/minette/dialog/service.py +++ b/minette/dialog/service.py @@ -8,7 +8,6 @@ Context, PerformanceInfo ) -from ..tagger import Tagger class DialogService: @@ -42,7 +41,7 @@ def topic_name(cls): cls_name = cls_name[:-6] return cls_name - def __init__(self, config=None, timezone=None, logger=None, tagger=None): + def __init__(self, config=None, timezone=None, logger=None): """ Parameters ---------- @@ -52,14 +51,10 @@ def __init__(self, config=None, timezone=None, logger=None, tagger=None): Timezone logger : logging.Logger, default None Logger - tagger : minette.tagger.Tagger, default None - Tagger """ self.config = config self.timezone = timezone self.logger = logger or getLogger(__name__) - self.tagger = tagger or \ - Tagger(config=config, timezone=timezone, logger=logger) def execute(self, request, context, connection, performance): """ diff --git a/requirements-dev.txt b/requirements-dev.txt index ecd80f1..9057f74 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,4 @@ pytz==2020.1 +requests==2.24.0 schedule==0.6.0 pytest==6.0.1 -Janome==0.4.0 diff --git a/tests/test_core.py b/tests/test_core.py index a9230e7..7f68a92 100644 --- a/tests/test_core.py +++ 
b/tests/test_core.py @@ -1,6 +1,5 @@ import pytest import os -from types import GeneratorType from pytz import timezone from logging import Logger, FileHandler, getLogger from datetime import datetime @@ -8,12 +7,7 @@ from minette import ( Minette, DialogService, SQLiteConnectionProvider, SQLiteContextStore, SQLiteUserStore, SQLiteMessageLogStore, - Tagger, Config, DialogRouter, StoreSet, Message, User, Group, Payload -) - -from minette.tagger.janometagger import ( - JanomeTagger, - JanomeNode + Tagger, Config, DialogRouter, StoreSet, Message, User, Group ) from minette.utils import date_to_unixtime @@ -61,31 +55,6 @@ def compose_response(self, request, context, connection): return "res:" + request.text -class TaggerDialog(DialogService): - def compose_response(self, request, context, connection): - return request.to_reply( - text=request.text, - payloads=[Payload(content_type="data", content=request.words)]) - - -class TaggerManuallyParseDialog(DialogService): - def compose_response(self, request, context, connection): - assert request.words == [] - request.words = self.tagger.parse(request.text) - return request.to_reply( - text=request.text, - payloads=[Payload(content_type="data", content=request.words)]) - - -class TaggerManuallyParseGeneratorDialog(DialogService): - def compose_response(self, request, context, connection): - assert request.words == [] - request.words = self.tagger.parse_as_generator(request.text) - return request.to_reply( - text=request.text, - payloads=[Payload(content_type="data", content=request.words)]) - - class MyDialogRouter(DialogRouter): def __init__(self, custom_router_arg=None, **kwargs): super().__init__(**kwargs) @@ -99,12 +68,12 @@ def test_init(): assert bot.timezone == timezone("UTC") assert isinstance(bot.logger, Logger) assert bot.logger.name == "minette" - assert isinstance(bot.tagger, Tagger) assert isinstance(bot.connection_provider, SQLiteConnectionProvider) assert isinstance(bot.context_store, SQLiteContextStore) 
assert isinstance(bot.user_store, SQLiteUserStore) assert isinstance(bot.messagelog_store, SQLiteMessageLogStore) assert bot.default_dialog_service is None + assert isinstance(bot.tagger, Tagger) def test_init_config(): @@ -156,7 +125,6 @@ def test_init_args(): assert bot.config.get("test_key", section="test_section") == "test_value" assert bot.timezone == timezone("Asia/Tokyo") assert bot.logger.name == "test_core_logger" - assert isinstance(bot.tagger, CustomTagger) assert isinstance(bot.connection_provider, CustomConnectionProvider) assert isinstance(bot.context_store, CustomContextStore) assert isinstance(bot.user_store, CustomUserStore) @@ -164,6 +132,7 @@ def test_init_args(): assert bot.default_dialog_service is MyDialog assert isinstance(bot.dialog_router, MyDialogRouter) assert bot.dialog_router.custom_attr == "router_value" + assert isinstance(bot.tagger, CustomTagger) # create bot with data_stores bot = Minette( @@ -319,59 +288,6 @@ def test_chat(): assert res.messages[0].text == "res:hello" -def test_chat_with_tagger(): - bot = Minette( - default_dialog_service=TaggerDialog, - tagger=JanomeTagger) - res = bot.chat("今日はいい天気です。") - assert res.messages[0].text == "今日はいい天気です。" - words = res.messages[0].payloads[0].content - assert words[0].surface == "今日" - assert words[1].surface == "は" - assert words[2].surface == "いい" - assert words[3].surface == "天気" - assert words[4].surface == "です" - - -def test_chat_with_tagger_no_parse(): - bot = Minette( - default_dialog_service=TaggerDialog, - tagger=JanomeTagger, parse_morph=False) - res = bot.chat("今日はいい天気です。") - assert res.messages[0].text == "今日はいい天気です。" - words = res.messages[0].payloads[0].content - assert words == [] - - -def test_chat_parse_morph_manually(): - bot = Minette( - default_dialog_service=TaggerManuallyParseDialog, - tagger=JanomeTagger, parse_morph=False) - res = bot.chat("今日はいい天気です。") - assert res.messages[0].text == "今日はいい天気です。" - words = res.messages[0].payloads[0].content - assert 
words[0].surface == "今日" - assert words[1].surface == "は" - assert words[2].surface == "いい" - assert words[3].surface == "天気" - assert words[4].surface == "です" - - -def test_chat_parse_morph_manually_generator(): - bot = Minette( - default_dialog_service=TaggerManuallyParseGeneratorDialog, - tagger=JanomeTagger, parse_morph=False) - res = bot.chat("今日はいい天気です。") - assert res.messages[0].text == "今日はいい天気です。" - assert isinstance(res.messages[0].payloads[0].content, GeneratorType) - words = [w for w in res.messages[0].payloads[0].content] - assert words[0].surface == "今日" - assert words[1].surface == "は" - assert words[2].surface == "いい" - assert words[3].surface == "天気" - assert words[4].surface == "です" - - def test_chat_error(): bot = Minette(default_dialog_service=MyDialog) bot.connection_provider = None