Skip to content
This repository was archived by the owner on Dec 18, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions minette/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None,
user_store=None, user_table=None,
messagelog_store=None, messagelog_table=None,
default_dialog_service=None, dialog_router=None,
tagger=None, prepare_table=True, **kwargs):
tagger=None, tagger_max_length=None, prepare_table=True, **kwargs):
"""
Parameters
----------
Expand Down Expand Up @@ -126,6 +126,8 @@ def __init__(self, *, config=None, config_file=None, timezone=None,
and return proper DialogService for intent
tagger: minette.Tagger or type, default None
Morphological analysis engine
tagger_max_length: Max length of the text to parse morph, default None
Morphological analysis engine
prepare_table: bool, default True
Create tables for data stores if they don't exist.
"""
Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(self, *, config=None, config_file=None, timezone=None,
"dialog_router": dialog_router,
"default_dialog_service": default_dialog_service,
"tagger": tagger,
"tagger_max_length": tagger_max_length,
}
setter_args.update({k: v for k, v in kwargs.items() if k not in setter_args})

Expand Down Expand Up @@ -263,10 +266,13 @@ def _get_dialog_router(self, dialog_router, default_dialog_service=None,
dr = dr(default_dialog_service=default_dialog_service, **kwargs)
return dr

def _get_tagger(self, tagger, **kwargs):
def _get_tagger(self, tagger, tagger_max_length, **kwargs):
tg = tagger or Tagger
if issubclass(tg, Tagger):
tg = tg(**kwargs)
if tagger_max_length is not None:
tg = tg(max_length=tagger_max_length, **kwargs)
else:
tg = tg(**kwargs)
return tg

def chat(self, request):
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytz==2020.1
requests==2.24.0
schedule==0.6.0
pytest==6.0.1
Janome==0.4.0
85 changes: 84 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from pytz import timezone
from logging import Logger, FileHandler, getLogger
from datetime import datetime
from types import GeneratorType

from minette import (
Minette, DialogService, SQLiteConnectionProvider,
SQLiteContextStore, SQLiteUserStore, SQLiteMessageLogStore,
Tagger, Config, DialogRouter, StoreSet, Message, User, Group,
DependencyContainer
DependencyContainer, Payload
)
from minette.utils import date_to_unixtime
from minette.tagger.janometagger import JanomeTagger

now = datetime.now()
user_id = "user_id" + str(date_to_unixtime(now))
Expand Down Expand Up @@ -63,6 +65,31 @@ def __init__(self, custom_router_arg=None, **kwargs):
self.custom_attr = custom_router_arg


class TaggerDialog(DialogService):
def compose_response(self, request, context, connection):
return request.to_reply(
text=request.text,
payloads=[Payload(content_type="data", content=request.words)])


class TaggerManuallyParseDialog(DialogService):
def compose_response(self, request, context, connection):
assert request.words == []
request.words = self.dependencies.tagger.parse(request.text, max_length=10)
return request.to_reply(
text=request.text,
payloads=[Payload(content_type="data", content=request.words)])


class TaggerManuallyParseGeneratorDialog(DialogService):
def compose_response(self, request, context, connection):
assert request.words == []
request.words = self.dependencies.tagger.parse_as_generator(request.text, max_length=10)
return request.to_reply(
text=request.text,
payloads=[Payload(content_type="data", content=request.words)])


def test_init():
# without config
bot = Minette()
Expand Down Expand Up @@ -317,6 +344,62 @@ def test_chat_timezone():
assert res.messages[0].timestamp.tzinfo == datetime.now(tz=bot.timezone).tzinfo


def test_chat_with_tagger():
bot = Minette(
default_dialog_service=TaggerDialog,
tagger=JanomeTagger)
res = bot.chat("今日はいい天気です。")
assert res.messages[0].text == "今日はいい天気です。"
words = res.messages[0].payloads[0].content
assert words[0].surface == "今日"
assert words[1].surface == "は"
assert words[2].surface == "いい"
assert words[3].surface == "天気"
assert words[4].surface == "です"


def test_chat_with_tagger_no_parse():
bot = Minette(
default_dialog_service=TaggerDialog,
tagger=JanomeTagger, tagger_max_length=0)
assert bot.tagger.max_length == 0
res = bot.chat("今日はいい天気です。")
assert res.messages[0].text == "今日はいい天気です。"
words = res.messages[0].payloads[0].content
assert words == []


def test_chat_parse_morph_manually():
bot = Minette(
default_dialog_service=TaggerManuallyParseDialog,
tagger=JanomeTagger, tagger_max_length=0)
bot.dialog_uses(tagger=bot.tagger)
res = bot.chat("今日はいい天気です。")
assert res.messages[0].text == "今日はいい天気です。"
words = res.messages[0].payloads[0].content
assert words[0].surface == "今日"
assert words[1].surface == "は"
assert words[2].surface == "いい"
assert words[3].surface == "天気"
assert words[4].surface == "です"


def test_chat_parse_morph_manually_generator():
bot = Minette(
default_dialog_service=TaggerManuallyParseGeneratorDialog,
tagger=JanomeTagger, tagger_max_length=0)
bot.dialog_uses(tagger=bot.tagger)
res = bot.chat("今日はいい天気です。")
assert res.messages[0].text == "今日はいい天気です。"
assert isinstance(res.messages[0].payloads[0].content, GeneratorType)
words = [w for w in res.messages[0].payloads[0].content]
assert words[0].surface == "今日"
assert words[1].surface == "は"
assert words[2].surface == "いい"
assert words[3].surface == "天気"
assert words[4].surface == "です"


def test_dialog_uses():
class HighCostToCreate:
pass
Expand Down