From 70a1361200d3b32539bc93058813a57b8a7aa73a Mon Sep 17 00:00:00 2001 From: Frank Dai Date: Mon, 4 Nov 2019 21:07:39 -0800 Subject: [PATCH] Add overly-safe synchronization --- ircbot/ircbot.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ircbot/ircbot.py b/ircbot/ircbot.py index 16da42c..9b21e2d 100644 --- a/ircbot/ircbot.py +++ b/ircbot/ircbot.py @@ -54,6 +54,15 @@ MAX_CLIENT_MSG = 435 +def synchronize(method): + """Decorator to wrap a method in a lock-acquiring context manager""" + @functools.wraps(method) + def new_method(self, *args, **kwargs): + with self.lock: + return method(self, *args, **kwargs) + return new_method + + class Listener(NamedTuple): pattern: Pattern fn: FunctionType @@ -135,6 +144,9 @@ def __init__( self.plugins: Dict[str, ModuleType] = {} self.extra_channels: Set[str] = set() # plugins can add stuff here + # As we use threads, we should ensure that we use them safely + self.lock = threading.RLock() + # Register plugins before joining the server. self.register_plugins() @@ -146,6 +158,7 @@ def __init__( connect_factory=factory, ) + @synchronize def register_plugins(self): for importer, mod_name, _ in pkgutil.iter_modules(['ircbot/plugin']): mod = importer.find_module(mod_name).load_module(mod_name) @@ -154,6 +167,7 @@ def register_plugins(self): if register is not None: register(self) + @synchronize def handle_error(self, error_message): # for debugging purposes print(error_message) @@ -162,6 +176,7 @@ def handle_error(self, error_message): if not TESTING: send_problem_report(error_message) + @synchronize def listen( self, pattern, @@ -183,6 +198,7 @@ def listen( ), ) + @synchronize def on_welcome(self, conn, _): conn.privmsg('NickServ', f'identify {self.nickserv_password}') @@ -190,6 +206,7 @@ def on_welcome(self, conn, _): for channel in IRC_CHANNELS_OPER | IRC_CHANNELS_ANNOUNCE | self.extra_channels: conn.join(channel) + @synchronize def on_pubmsg(self, conn, event): if event.target in self.channels: is_oper = False @@ -286,10 +303,12 @@ def respond(raw_text, ping=True): if raw_text[0] != '!': self.recent_messages[event.target].appendleft((user, raw_text)) + @synchronize def on_currenttopic(self, connection, event): channel, topic = event.arguments self.topics[channel] = topic + @synchronize def on_topic(self, connection, event): topic, = event.arguments self.topics[event.target] = topic @@ -299,6 +318,7 @@ def on_invite(self, connection, event): import ircbot.plugin.channels return ircbot.plugin.channels.on_invite(self, connection, event) + @synchronize def add_thread(self, func): def thread_func(): try: @@ -330,6 +350,7 @@ def thread_func(): thread = threading.Thread(target=thread_func, daemon=True) thread.start() + @synchronize def bump_topic(self): for channel, topic in self.topics.items(): def plusone(m): @@ -339,6 +360,7 @@ def plusone(m): if topic != new_topic: self.connection.topic(channel, new_topic=new_topic) + @synchronize def say(self, channel, message): # Find the length of the full message msg_len = len(f'PRIVMSG {channel} :{message}\r\n'.encode('utf-8'))