From 11c72074d72b422267beb82031b807475e7a06e6 Mon Sep 17 00:00:00 2001 From: uezo Date: Sat, 5 Sep 2020 08:11:43 +0900 Subject: [PATCH] Add argument `tagger_max_length` to constructor of Minette The given value is set to `tagger.max_length`. You can skip automatic parsing with `tagger_max_length=0` and parse morphemes manually by calling `parse` with a sufficiently large `max_length`. The Janome tagger is required for testing (added to requirements-dev.txt). --- minette/core.py | 12 +++++-- requirements-dev.txt | 1 + tests/test_core.py | 85 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 94 insertions(+), 4 deletions(-) diff --git a/minette/core.py b/minette/core.py index 67d441d..86ae75d 100755 --- a/minette/core.py +++ b/minette/core.py @@ -67,7 +67,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, user_store=None, user_table=None, messagelog_store=None, messagelog_table=None, default_dialog_service=None, dialog_router=None, - tagger=None, prepare_table=True, **kwargs): + tagger=None, tagger_max_length=None, prepare_table=True, **kwargs): """ Parameters ---------- @@ -126,6 +126,8 @@ def __init__(self, *, config=None, config_file=None, timezone=None, and return proper DialogService for intent tagger: minette.Tagger or type, default None Morphological analysis engine + tagger_max_length: int, default None + Max length of the text to parse morph. Set to `tagger.max_length` prepare_table: bool, default True Create tables for data stores if they don't exist. 
""" @@ -161,6 +163,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None, "dialog_router": dialog_router, "default_dialog_service": default_dialog_service, "tagger": tagger, + "tagger_max_length": tagger_max_length, } setter_args.update({k: v for k, v in kwargs.items() if k not in setter_args}) @@ -263,10 +266,13 @@ def _get_dialog_router(self, dialog_router, default_dialog_service=None, dr = dr(default_dialog_service=default_dialog_service, **kwargs) return dr - def _get_tagger(self, tagger, **kwargs): + def _get_tagger(self, tagger, tagger_max_length, **kwargs): tg = tagger or Tagger if issubclass(tg, Tagger): - tg = tg(**kwargs) + if tagger_max_length is not None: + tg = tg(max_length=tagger_max_length, **kwargs) + else: + tg = tg(**kwargs) return tg def chat(self, request): diff --git a/requirements-dev.txt b/requirements-dev.txt index 9057f74..b76b71f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,4 @@ pytz==2020.1 requests==2.24.0 schedule==0.6.0 pytest==6.0.1 +Janome==0.4.0 diff --git a/tests/test_core.py b/tests/test_core.py index 1f44319..6b84b5a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,14 +5,16 @@ from pytz import timezone from logging import Logger, FileHandler, getLogger from datetime import datetime +from types import GeneratorType from minette import ( Minette, DialogService, SQLiteConnectionProvider, SQLiteContextStore, SQLiteUserStore, SQLiteMessageLogStore, Tagger, Config, DialogRouter, StoreSet, Message, User, Group, - DependencyContainer + DependencyContainer, Payload ) from minette.utils import date_to_unixtime +from minette.tagger.janometagger import JanomeTagger now = datetime.now() user_id = "user_id" + str(date_to_unixtime(now)) @@ -63,6 +65,31 @@ def __init__(self, custom_router_arg=None, **kwargs): self.custom_attr = custom_router_arg +class TaggerDialog(DialogService): + def compose_response(self, request, context, connection): + return request.to_reply( + text=request.text, 
+ payloads=[Payload(content_type="data", content=request.words)]) + + +class TaggerManuallyParseDialog(DialogService): + def compose_response(self, request, context, connection): + assert request.words == [] + request.words = self.dependencies.tagger.parse(request.text, max_length=10) + return request.to_reply( + text=request.text, + payloads=[Payload(content_type="data", content=request.words)]) + + +class TaggerManuallyParseGeneratorDialog(DialogService): + def compose_response(self, request, context, connection): + assert request.words == [] + request.words = self.dependencies.tagger.parse_as_generator(request.text, max_length=10) + return request.to_reply( + text=request.text, + payloads=[Payload(content_type="data", content=request.words)]) + + def test_init(): # without config bot = Minette() @@ -317,6 +344,62 @@ def test_chat_timezone(): assert res.messages[0].timestamp.tzinfo == datetime.now(tz=bot.timezone).tzinfo +def test_chat_with_tagger(): + bot = Minette( + default_dialog_service=TaggerDialog, + tagger=JanomeTagger) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + words = res.messages[0].payloads[0].content + assert words[0].surface == "今日" + assert words[1].surface == "は" + assert words[2].surface == "いい" + assert words[3].surface == "天気" + assert words[4].surface == "です" + + +def test_chat_with_tagger_no_parse(): + bot = Minette( + default_dialog_service=TaggerDialog, + tagger=JanomeTagger, tagger_max_length=0) + assert bot.tagger.max_length == 0 + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + words = res.messages[0].payloads[0].content + assert words == [] + + +def test_chat_parse_morph_manually(): + bot = Minette( + default_dialog_service=TaggerManuallyParseDialog, + tagger=JanomeTagger, tagger_max_length=0) + bot.dialog_uses(tagger=bot.tagger) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + words = res.messages[0].payloads[0].content + assert 
words[0].surface == "今日" + assert words[1].surface == "は" + assert words[2].surface == "いい" + assert words[3].surface == "天気" + assert words[4].surface == "です" + + +def test_chat_parse_morph_manually_generator(): + bot = Minette( + default_dialog_service=TaggerManuallyParseGeneratorDialog, + tagger=JanomeTagger, tagger_max_length=0) + bot.dialog_uses(tagger=bot.tagger) + res = bot.chat("今日はいい天気です。") + assert res.messages[0].text == "今日はいい天気です。" + assert isinstance(res.messages[0].payloads[0].content, GeneratorType) + words = [w for w in res.messages[0].payloads[0].content] + assert words[0].surface == "今日" + assert words[1].surface == "は" + assert words[2].surface == "いい" + assert words[3].surface == "天気" + assert words[4].surface == "です" + + def test_dialog_uses(): class HighCostToCreate: pass