From 7a81613879b39ffa9651d1b3eba33b92092ff81b Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 13 Sep 2018 19:01:26 -0400 Subject: [PATCH] Stop treating ReadyForQuery as a universal result indicator ReadyForQuery is special in auth and simple query flows, where it _does_ signal the final confirmation of the result, but in all other flows other, more specific messages do that. Now, asyncpg will use the rules of a particular subprotocol when determining the timing of the result waiter wakeup. These changes also make most cases of Sync emission unnecessary, although that cleanup will be addressed in subsequent commits. This consolidation also results in a nice reduction of duplicated code. --- asyncpg/protocol/coreproto.pxd | 4 +- asyncpg/protocol/coreproto.pyx | 215 ++++++++++----------------------- asyncpg/protocol/protocol.pyx | 14 ++- 3 files changed, 78 insertions(+), 155 deletions(-) diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index d35334e5..c1fa6567 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -19,7 +19,7 @@ cdef enum ProtocolState: PROTOCOL_CANCELLED = 3 PROTOCOL_AUTH = 10 - PROTOCOL_PREPARE = 11 + PROTOCOL_PARSE_DESCRIBE = 11 PROTOCOL_BIND_EXECUTE = 12 PROTOCOL_BIND_EXECUTE_MANY = 13 PROTOCOL_CLOSE_STMT_PORTAL = 14 @@ -105,7 +105,7 @@ cdef class CoreProtocol: bint result_execute_completed cdef _process__auth(self, char mtype) - cdef _process__prepare(self, char mtype) + cdef _process__parse_describe(self, char mtype) cdef _process__bind_execute(self, char mtype) cdef _process__bind_execute_many(self, char mtype) cdef _process__close_stmt_portal(self, char mtype) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 3ac317bb..21498e7d 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -44,21 +44,56 @@ cdef class CoreProtocol: if mtype == b'S': # ParameterStatus self._parse_msg_parameter_status() - continue + elif mtype == b'A': # NotificationResponse self._parse_msg_notification() - continue + elif mtype == b'N': # 'N' - NoticeResponse self._on_notice(self._parse_msg_error_response(False)) - continue - if state == PROTOCOL_AUTH: + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + # In all cases, except Auth, ErrorResponse will + # be followed by a ReadyForQuery, which is when + # _push_result() will be called. + if state == PROTOCOL_AUTH: + self._push_result() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + + if state != PROTOCOL_BIND_EXECUTE_MANY: + self._push_result() + + else: + if self.result_type == RESULT_FAILED: + self._push_result() + else: + try: + buf = next(self._execute_iter) + except StopIteration: + self._push_result() + except Exception as e: + self.result_type = RESULT_FAILED + self.result = e + self._push_result() + else: + # Next iteration over the executemany() + # arg sequence. + self._send_bind_message( + self._execute_portal_name, + self._execute_stmt_name, + buf, 0) + + elif state == PROTOCOL_AUTH: self._process__auth(mtype) - elif state == PROTOCOL_PREPARE: - self._process__prepare(mtype) + elif state == PROTOCOL_PARSE_DESCRIBE: + self._process__parse_describe(mtype) elif state == PROTOCOL_BIND_EXECUTE: self._process__bind_execute(mtype) @@ -93,42 +128,26 @@ cdef class CoreProtocol: elif state == PROTOCOL_CANCELLED: # discard all messages until the sync message - if mtype == b'E': - self._parse_msg_error_response(True) - elif mtype == b'Z': - self._parse_msg_ready_for_query() - self._push_result() - else: - self.buffer.consume_message() + self.buffer.consume_message() elif state == PROTOCOL_ERROR_CONSUME: # Error in protocol (on asyncpg side); # discard all messages until sync message - - if mtype == b'Z': - # Sync point, self to push the result - if self.result_type != RESULT_FAILED: - self.result_type = RESULT_FAILED - self.result = apg_exc.InternalClientError( - 'unknown error in protocol implementation') - - self._push_result() - - else: - self.buffer.consume_message() + self.buffer.consume_message() else: raise apg_exc.InternalClientError( 'protocol is in an unknown state {}'.format(state)) except Exception as ex: + self.state = PROTOCOL_ERROR_CONSUME self.result_type = RESULT_FAILED self.result = ex if mtype == b'Z': + # This should only happen if _parse_msg_ready_for_query() + # has failed. self._push_result() - else: - self.state = PROTOCOL_ERROR_CONSUME finally: if self._skip_discard: @@ -153,43 +172,27 @@ cdef class CoreProtocol: # BackendKeyData self._parse_msg_backend_key_data() - elif mtype == b'E': - # ErrorResponse - self.con_status = CONNECTION_BAD - self._parse_msg_error_response(True) - self._push_result() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self.con_status = CONNECTION_OK - self._push_result() - - cdef _process__prepare(self, char mtype): - if mtype == b't': - # Parameters description - self.result_param_desc = self.buffer.consume_message().as_bytes() + # push_result() will be initiated by handling + # ReadyForQuery or ErrorResponse in the main loop. - elif mtype == b'1': + cdef _process__parse_describe(self, char mtype): + if mtype == b'1': # ParseComplete self.buffer.consume_message() + elif mtype == b't': + # ParameterDescription + self.result_param_desc = self.buffer.consume_message().as_bytes() + elif mtype == b'T': - # Row description + # RowDescription self.result_row_desc = self.buffer.consume_message().as_bytes() - - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() elif mtype == b'n': # NoData self.buffer.consume_message() + self._push_result() cdef _process__bind_execute(self, char mtype): if mtype == b'D': @@ -199,28 +202,22 @@ cdef class CoreProtocol: elif mtype == b's': # PortalSuspended self.buffer.consume_message() + self._push_result() elif mtype == b'C': # CommandComplete self.result_execute_completed = True self._parse_msg_command_complete() - - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) + self._push_result() elif mtype == b'2': # BindComplete self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - elif mtype == b'I': # EmptyQueryResponse self.buffer.consume_message() + self._push_result() cdef _process__bind_execute_many(self, char mtype): cdef WriteBuffer buf @@ -237,64 +234,24 @@ cdef class CoreProtocol: # CommandComplete self._parse_msg_command_complete() - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - elif mtype == b'2': # BindComplete self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - if self.result_type == RESULT_FAILED: - self._push_result() - else: - try: - buf = next(self._execute_iter) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e - self._push_result() - else: - # Next iteration over the executemany() arg sequence - self._send_bind_message( - self._execute_portal_name, self._execute_stmt_name, - buf, 0) - elif mtype == b'I': # EmptyQueryResponse self.buffer.consume_message() cdef _process__bind(self, char mtype): - if mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'2': + if mtype == b'2': # BindComplete self.buffer.consume_message() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _process__close_stmt_portal(self, char mtype): - if mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'3': + if mtype == b'3': # CloseComplete self.buffer.consume_message() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _process__simple_query(self, char mtype): @@ -304,42 +261,21 @@ cdef class CoreProtocol: # 'T' - RowDescription self.buffer.consume_message() - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - elif mtype == b'C': # CommandComplete self._parse_msg_command_complete() - else: # We don't really care about COPY IN etc self.buffer.consume_message() cdef _process__copy_out(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'H': + if mtype == b'H': # CopyOutResponse self._set_state(PROTOCOL_COPY_OUT_DATA) self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - cdef _process__copy_out_data(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'd': + if mtype == b'd': # CopyData self._parse_copy_data_msgs() @@ -351,37 +287,18 @@ cdef class CoreProtocol: elif mtype == b'C': # CommandComplete self._parse_msg_command_complete() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _process__copy_in(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'G': + if mtype == b'G': # CopyInResponse self._set_state(PROTOCOL_COPY_IN_DATA) self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - cdef _process__copy_in_data(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'C': + if mtype == b'C': # CommandComplete self._parse_msg_command_complete() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _parse_msg_command_complete(self): @@ -739,7 +656,7 @@ cdef class CoreProtocol: WriteBuffer buf self._ensure_connected() - self._set_state(PROTOCOL_PREPARE) + self._set_state(PROTOCOL_PARSE_DESCRIBE) buf = WriteBuffer.new_message(b'P') buf.write_str(stmt_name, self.encoding) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index e137c74b..3c222db8 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -713,6 +713,7 @@ cdef class BaseProtocol(CoreProtocol): return self.waiter cdef _on_result__connect(self, object waiter): + self.con_status = CONNECTION_OK waiter.set_result(True) cdef _on_result__prepare(self, object waiter): @@ -790,6 +791,10 @@ cdef class BaseProtocol(CoreProtocol): self.result, query=self.last_query) else: exc = self.result + + if self.state == PROTOCOL_AUTH: + self.con_status = CONNECTION_BAD + waiter.set_exception(exc) return @@ -797,7 +802,7 @@ cdef class BaseProtocol(CoreProtocol): if self.state == PROTOCOL_AUTH: self._on_result__connect(waiter) - elif self.state == PROTOCOL_PREPARE: + elif self.state == PROTOCOL_PARSE_DESCRIBE: self._on_result__prepare(waiter) elif self.state == PROTOCOL_BIND_EXECUTE: @@ -847,11 +852,12 @@ cdef class BaseProtocol(CoreProtocol): self.cancel_waiter = None if self.waiter is not None and self.waiter.done(): self.waiter = None - if self.waiter is None: - return try: - self._dispatch_result() + if self.waiter is not None: + # _on_result() might be called several times in the + # process, or the waiter might have been cancelled. + self._dispatch_result() finally: self.statement = None self.last_query = None