diff --git a/README.md b/README.md index f10247c..fc506d7 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ This command binds a particular message type `msgtype` (passed as a string) to a connect() -Makes the bot connect to the specified irc server. +Makes the bot connect to the specified irc server. This function is blocking until the server sends the end of the motd or the connect_timeout is reached. The return value indicates which case you have (`True` means successful). debug(state) diff --git a/ircbotframe.py b/ircbotframe.py index b08230a..46be5ae 100644 --- a/ircbotframe.py +++ b/ircbotframe.py @@ -4,50 +4,41 @@ import ssl import threading import time +import sched +import queue + class ircOutputBuffer: - # Delays consecutive messages by at least 1 second. - # This prevents the bot spamming the IRC server. + # This class provides buffered and unbuffered sending to a socket def __init__(self, irc): - self.waiting = False self.irc = irc - self.queue = [] - self.error = False - - def __pop(self): - if len(self.queue) == 0: - self.waiting = False - else: - self.sendImmediately(self.queue[0]) - self.queue = self.queue[1:] - self.__startPopTimer() - - def __startPopTimer(self): - self.timer = threading.Timer(1, self.__pop) - self.timer.start() + self.queue = queue.Queue() def sendBuffered(self, string): # Sends the given string after the rest of the messages in the buffer. - # There is a 1 second gap between each message. - if self.waiting: - self.queue.append(string) - else: - self.waiting = True - self.sendImmediately(string) - self.__startPopTimer() + self.queue.put_nowait(string) + return True + + def sendFromQueue(self): + # Send the oldest message in the buffer if there is one + try: + string = self.queue.get_nowait() + result = self.sendImmediately(string) + self.queue.task_done() + return result + except queue.Empty: + return True def sendImmediately(self, string): # Sends the given string without buffering. - if not self.error: - try: - self.irc.send((string + "\r\n").encode("utf-8")) - except socket.error as msg: - self.error = True - print("Output error", msg) - print("Was sending \"" + string + "\"") + try: + self.irc.send((string + "\r\n").encode("utf-8")) + return True + except socket.error as msg: + print("Output error", msg) + print("Was sending \"" + string + "\"") + return False - def isInError(self): - return self.error class ircInputBuffer: # Keeps a record of the last line fragment received by the socket which is usually not a complete line. @@ -61,8 +52,6 @@ def __recv(self): # Receives new data from the socket and splits it into lines. try: data = self.buffer + self.irc.recv(4096).decode("utf-8") - except socket.error as msg: - raise socket.error(msg) except UnicodeDecodeError: data = '' self.lines += data.split("\r\n") @@ -71,19 +60,23 @@ def __recv(self): self.lines = self.lines[:-1] def getLine(self): - # Returns the next line of IRC received by the socket. + # Returns the next line of IRC received by the socket or None. # This should already be in the standard string format. - # If no lines are buffered, this blocks until a line is received. + # If no lines are buffered, this blocks until a line is received + # or we reach the socket timeout. When the timeout is + # reached, the function returns None. + while len(self.lines) == 0: try: self.__recv() - except socket.error as msg: - raise socket.error(msg) - time.sleep(1); + except socket.timeout: + return None + line = self.lines[0] self.lines = self.lines[1:] return line + class ircBot(threading.Thread): def __init__(self, network, port, name, description, password=None, ssl=False, ip_ver=None): threading.Thread.__init__(self) @@ -101,6 +94,19 @@ def __init__(self, network, port, name, description, password=None, ssl=False, i self.default_log_length = 200 self.log_own_messages = True self.channel_data = {} + self.irc = None + self.outBuf = None + self.inBuf = None + self.connected = False + self.connect_timeout = 30 + self.reconnect_interval = 30 + self.ping_timeout = 10 + self.ping_interval = 60 + + self.bind("PONG", self.__handlePong) + self.__unansweredPing = False + self.__sched = sched.scheduler() + if ip_ver == 4: self.socket_family = socket.AF_INET elif ip_ver == 6: @@ -112,6 +118,7 @@ def __init__(self, network, port, name, description, password=None, ssl=False, i for family, _, _, _, _ in socket.getaddrinfo(network, port, proto=socket.IPPROTO_TCP): if family == socket.AF_INET6: self.socket_family = socket.AF_INET6 + break else: self.socket_family = socket.AF_INET else: @@ -191,6 +198,8 @@ def __processLine(self, line): else: msgtype = headers[1] self.__debugPrint('[' + msgtype + '] ' + message) + if msgtype == '376': + self.connected = True if msgtype in ['307', '330'] and len(headers) >= 4: self.__identAccept(headers[3]) if msgtype == '318' and len(headers) >= 4: @@ -206,6 +215,60 @@ def __debugPrint(self, s): if self.debug: print(s) + def __periodicSend(self): + if not self.irc: + return + + if not self.outBuf.sendFromQueue(): + self.close() + return + + # Delays consecutive messages by at least 1 second. + # This prevents the bot spamming the IRC server. + self.__sched.enter(1, priority=10, action=self.__periodicSend) + + def __periodicRecv(self): + if not self.irc: + return + + try: + line = self.inBuf.getLine() + except socket.error as msg: + self.__debugPrint("Input error", msg) + self.close() + return + + if line is not None: + if line.startswith("PING"): + if not self.outBuf.sendImmediately("PONG " + line.split()[1]): + self.close() + return + else: + self.__processLine(line) + + # next recv should be directly but with verly low priority + self.__sched.enter(0.01, priority=1, action=self.__periodicRecv) + + def __periodicPing(self): + self.ping() + self.__sched.enter(self.ping_interval, 1, self.__periodicPing) + + def __handlePong(self, sender, headers, message): + self.__unansweredPing = False + + def __handlePingTimeout(self): + if self.__unansweredPing: + self.__debugPrint("Ping timeout reached. Killing the connection.") + self.close() + + def ping(self): + if self.__unansweredPing: + return + + self.outBuf.sendImmediately('PING %s' % self.network) + self.__unansweredPing = True + self.__sched.enter(self.ping_timeout, 1, self.__handlePingTimeout) + def log(self, channel, msgtype, sender, headers, message): if channel in self.channel_data: self.channel_data[channel]['log'].append((msgtype, sender, headers, message)) @@ -216,41 +279,81 @@ def log(self, channel, msgtype, sender, headers, message): def ban(self, banMask, channel, reason): # only bans, no kick. self.__debugPrint("Banning " + banMask + "...") - self.outBuf.sendBuffered("MODE +b " + channel + " " + banMask) + self.send("MODE +b " + channel + " " + banMask) # TODO get nick #self.kick(nick, channel, reason) def bind(self, msgtype, callback): self.binds[msgtype] = callback + def __handleConnectingTimeout(self): + if not self.connected: + self.close() + def connect(self): self.__debugPrint("Connecting...") self.irc = socket.socket(self.socket_family, socket.SOCK_STREAM) + self.irc.settimeout(self.connect_timeout) + if self.ssl: self.irc = ssl.wrap_socket(self.irc) - self.irc.connect((self.network, self.port)) + + try: + self.irc.connect((self.network, self.port)) + except socket.error as msg: + self.__debugPrint("Connection failed: %s" % msg) + self.close() + return False + + self.irc.settimeout(1.0) + self.inBuf = ircInputBuffer(self.irc) self.outBuf = ircOutputBuffer(self.irc) + if self.password is not None: self.outBuf.sendBuffered("PASS " + self.password) + self.outBuf.sendBuffered("NICK " + self.name) self.outBuf.sendBuffered("USER " + self.name + " 0 * :" + self.desc) + self.connected = False + + self.__periodicSend() + self.__periodicRecv() + self.__sched.enter(self.connect_timeout, priority=20, action=self.__handleConnectingTimeout) + + while True: + if self.connected: + self.__debugPrint("Connection was successful!") + return True + + if self.irc is None: + return False + + self.__sched.run(blocking=False) + def debugging(self, state): self.debug = state + def close(self): + self.outBuf = None + self.inBuf = None + self.irc.close() + self.irc = None + self.connected = False + def disconnect(self, qMessage): self.__debugPrint("Disconnecting...") # TODO make the following block until the message is sent - self.outBuf.sendBuffered("QUIT :" + qMessage) - self.irc.close() + self.send("QUIT :" + qMessage) + self.close() def identify(self, nick, approvedFunc, approvedParams, deniedFunc, deniedParams): self.__debugPrint("Verifying " + nick + "...") self.identifyNickCommands += [(nick, approvedFunc, approvedParams, deniedFunc, deniedParams)] # TODO this doesn't seem right if not self.identifyLock: - self.outBuf.sendBuffered("WHOIS " + nick) + self.send("WHOIS " + nick) self.identifyLock = True def joinchan(self, channel): @@ -259,42 +362,49 @@ def joinchan(self, channel): 'log': [], 'log_length': self.default_log_length } - self.outBuf.sendBuffered("JOIN " + channel) + self.send("JOIN " + channel) def kick(self, nick, channel, reason): self.__debugPrint("Kicking " + nick + "...") - self.outBuf.sendBuffered("KICK " + channel + " " + nick + " :" + reason) + self.send("KICK " + channel + " " + nick + " :" + reason) + + def reconnect(self, gracefully=True): + if gracefully: + self.disconnect("Reconnecting") + else: + self.close() - def reconnect(self): - self.disconnect("Reconnecting") self.__debugPrint("Pausing before reconnecting...") - time.sleep(5) + time.sleep(self.reconnect_interval) self.connect() def run(self): self.__debugPrint("Bot is now running.") self.connect() + + self.__periodicPing() + while self.keepGoing: - line = "" - while len(line) == 0: - try: - line = self.inBuf.getLine() - except socket.error as msg: - print("Input error", msg) - self.reconnect() - if line.startswith("PING"): - self.outBuf.sendImmediately("PONG " + line.split()[1]) - else: - self.__processLine(line) - if self.outBuf.isInError(): - self.reconnect() + if self.irc is None: + self.__debugPrint("Pausing before reconnecting...") + time.sleep(self.reconnect_interval) + self.connect() + continue + + self.__sched.run(blocking=False) + + self.disconnect() def say(self, recipient, message): if self.log_own_messages: self.log(recipient, 'PRIVMSG', self.name, [recipient], message) - self.outBuf.sendBuffered("PRIVMSG " + recipient + " :" + message) + self.send("PRIVMSG " + recipient + " :" + message) def send(self, string): + if not self.connected: + self.__debugPrint("WARNING: you are trying to send without being connected - \"", string, "\"") + return + self.outBuf.sendBuffered(string) def stop(self): @@ -305,4 +415,4 @@ def topic(self, channel, message): def unban(self, banMask, channel): self.__debugPrint('Unbanning ' + banMask + '...') - self.outBuf.sendBuffered('MODE -b ' + channel + ' ' + banMask) + self.send('MODE -b ' + channel + ' ' + banMask)