diff --git a/source/common/buffer/buffer_impl.cc b/source/common/buffer/buffer_impl.cc index d3a09d2e932ab..c7751bce8eb97 100644 --- a/source/common/buffer/buffer_impl.cc +++ b/source/common/buffer/buffer_impl.cc @@ -6,6 +6,15 @@ namespace Buffer { +// RawSlice is the same structure as evbuffer_iovec. This was put into place to avoid leaking +// libevent into most code since we will likely replace evbuffer with our own implementation at +// some point. However, we can avoid a bunch of copies since the structure is the same. +static_assert(sizeof(RawSlice) == sizeof(evbuffer_iovec), "RawSlice != evbuffer_iovec"); +static_assert(offsetof(RawSlice, mem_) == offsetof(evbuffer_iovec, iov_base), + "RawSlice != evbuffer_iovec"); +static_assert(offsetof(RawSlice, len_) == offsetof(evbuffer_iovec, iov_len), + "RawSlice != evbuffer_iovec"); + void OwnedImpl::add(const void* data, uint64_t size) { evbuffer_add(buffer_.get(), data, size); } void OwnedImpl::add(const std::string& data) { @@ -22,12 +31,8 @@ void OwnedImpl::add(const Instance& data) { } void OwnedImpl::commit(RawSlice* iovecs, uint64_t num_iovecs) { - evbuffer_iovec local_iovecs[num_iovecs]; - for (uint64_t i = 0; i < num_iovecs; i++) { - local_iovecs[i].iov_len = iovecs[i].len_; - local_iovecs[i].iov_base = iovecs[i].mem_; - } - int rc = evbuffer_commit_space(buffer_.get(), local_iovecs, num_iovecs); + int rc = + evbuffer_commit_space(buffer_.get(), reinterpret_cast<evbuffer_iovec*>(iovecs), num_iovecs); ASSERT(rc == 0); UNREFERENCED_PARAMETER(rc); } @@ -40,13 +45,8 @@ void OwnedImpl::drain(uint64_t size) { } uint64_t OwnedImpl::getRawSlices(RawSlice* out, uint64_t out_size) const { - evbuffer_iovec iovecs[out_size]; - uint64_t needed_size = evbuffer_peek(buffer_.get(), -1, nullptr, iovecs, out_size); - for (uint64_t i = 0; i < std::min(out_size, needed_size); i++) { - out[i].mem_ = iovecs[i].iov_base; - out[i].len_ = iovecs[i].iov_len; - } - return needed_size; + return evbuffer_peek(buffer_.get(), -1, nullptr, 
reinterpret_cast<evbuffer_iovec*>(out), + out_size); } uint64_t OwnedImpl::length() const { return evbuffer_get_length(buffer_.get()); } @@ -79,13 +79,9 @@ int OwnedImpl::read(int fd, uint64_t max_length) { } uint64_t OwnedImpl::reserve(uint64_t length, RawSlice* iovecs, uint64_t num_iovecs) { - evbuffer_iovec local_iovecs[num_iovecs]; - uint64_t ret = evbuffer_reserve_space(buffer_.get(), length, local_iovecs, num_iovecs); + uint64_t ret = evbuffer_reserve_space(buffer_.get(), length, + reinterpret_cast<evbuffer_iovec*>(iovecs), num_iovecs); ASSERT(ret >= 1); - for (uint64_t i = 0; i < ret; i++) { - iovecs[i].len_ = local_iovecs[i].iov_len; - iovecs[i].mem_ = local_iovecs[i].iov_base; - } return ret; } diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index 81e2b1177352c..642a167083f82 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -231,6 +231,12 @@ void ConnectionImpl::write(Buffer::Instance& data) { if (data.length() > 0) { conn_log_trace("writing {} bytes", *this, data.length()); + // TODO(mattklein123): All data currently gets moved from the source buffer to the write buffer. + // This can lead to inefficient behavior if writing a bunch of small chunks. In this case, it + // would likely be more efficient to copy data below a certain size. VERY IMPORTANT: If this is + // ever changed, read the comment in Ssl::ConnectionImpl::doWriteToSocket() VERY carefully. + // That code assumes that we never change existing write_buffer_ chain elements between calls + // to SSL_write(). That code will have to change if we ever copy here. 
write_buffer_.move(data); if (!(state_ & InternalState::Connecting)) { file_event_->activate(Event::FileReadyType::Write); diff --git a/source/common/ssl/connection_impl.cc b/source/common/ssl/connection_impl.cc index 3a119b91eaf43..81798fa9ca052 100644 --- a/source/common/ssl/connection_impl.cc +++ b/source/common/ssl/connection_impl.cc @@ -134,47 +134,58 @@ Network::ConnectionImpl::IoResult ConnectionImpl::doWriteToSocket() { } } - if (write_buffer_.length() == 0) { - return {PostIoAction::KeepOpen, 0}; - } + uint64_t total_bytes_written = 0; + bool keep_writing = true; + while ((write_buffer_.length() > 0) && keep_writing) { + // Protect against stack overflow if the buffer has a very large buffer chain. + // TODO(mattklein123): The current evbuffer Buffer::Instance implementation will iterate through + // the entire chain each time this is called to determine how many slices would be needed. In + // this case, we don't care, and only want to fill up to MAX_SLICES. When we swap out evbuffer + // we can change this behavior. + // TODO(mattklein123): As it relates to our fairness efforts, we might want to limit the number + // of iterations of this loop, either by pure iterations, bytes written, etc. + const uint64_t MAX_SLICES = 32; + Buffer::RawSlice slices[MAX_SLICES]; + uint64_t num_slices = std::min(MAX_SLICES, write_buffer_.getRawSlices(slices, MAX_SLICES)); + + uint64_t inner_bytes_written = 0; + for (uint64_t i = 0; i < num_slices; i++) { + // SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call + // it again with the same parameters. Most implementations keep track of the last write size. + // In our case we don't need to do that because: a) SSL_write() will not write partial + // buffers. b) We only move() into the write buffer, which means that it's impossible for a + // particular chain to increase in size. 
So as long as we start writing where we left off we + // are guaranteed to call SSL_write() with the same parameters. + int rc = SSL_write(ssl_.get(), slices[i].mem_, slices[i].len_); + conn_log_trace("ssl write returns: {}", *this, rc); + if (rc > 0) { + inner_bytes_written += rc; + total_bytes_written += rc; + } else { + int err = SSL_get_error(ssl_.get(), rc); + switch (err) { + case SSL_ERROR_WANT_WRITE: + keep_writing = false; + break; + case SSL_ERROR_WANT_READ: + // Renegotiation has started. We don't handle renegotiation so just fall through. + default: + drainErrorQueue(); + return {PostIoAction::Close, total_bytes_written}; + } - uint64_t num_slices = write_buffer_.getRawSlices(nullptr, 0); - Buffer::RawSlice slices[num_slices]; - write_buffer_.getRawSlices(slices, num_slices); - - uint64_t bytes_written = 0; - for (uint64_t i = 0; i < num_slices; i++) { - // SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call - // it again with the same parameters. Most implementations keep track of the last write size. - // In our case we don't need to do that because: a) SSL_write() will not write partial buffers. - // b) We only move() into the write buffer, which means that it's impossible for a particular - // chain to increase in size. So as long as we start writing where we left off we are guaranteed - // to call SSL_write() with the same parameters. - int rc = SSL_write(ssl_.get(), slices[i].mem_, slices[i].len_); - conn_log_trace("ssl write returns: {}", *this, rc); - if (rc > 0) { - bytes_written += rc; - } else { - int err = SSL_get_error(ssl_.get(), rc); - switch (err) { - case SSL_ERROR_WANT_WRITE: break; - case SSL_ERROR_WANT_READ: - // Renegotiation has started. We don't handle renegotiation so just fall through. 
- default: - drainErrorQueue(); - return {PostIoAction::Close, bytes_written}; } - break; } } - if (bytes_written > 0) { - write_buffer_.drain(bytes_written); + // Draining must happen on every pass of the outer loop, after the inner slice loop; otherwise + // we will keep getting the same slices at the beginning of the buffer. + if (inner_bytes_written > 0) { + write_buffer_.drain(inner_bytes_written); + } } - return {PostIoAction::KeepOpen, bytes_written}; + return {PostIoAction::KeepOpen, total_bytes_written}; } void ConnectionImpl::onConnected() { ASSERT(!handshake_complete_); } diff --git a/test/common/ssl/connection_impl_test.cc b/test/common/ssl/connection_impl_test.cc index e8ff44fe4a5fa..13fa2a85d9aad 100644 --- a/test/common/ssl/connection_impl_test.cc +++ b/test/common/ssl/connection_impl_test.cc @@ -247,9 +247,8 @@ TEST(SslConnectionImplTest, SslError) { class SslReadBufferLimitTest : public testing::Test { public: - void readBufferLimitTest(uint32_t read_buffer_limit, uint32_t expected_chunk_size) { - const uint32_t buffer_size = 256 * 1024; - + void readBufferLimitTest(uint32_t read_buffer_limit, uint32_t expected_chunk_size, + uint32_t write_size, uint32_t num_writes) { Stats::IsolatedStoreImpl stats_store; Event::DispatcherImpl dispatcher; Network::TcpListenSocket socket(uint32_t(10000), true); @@ -306,10 +305,10 @@ class SslReadBufferLimitTest : public testing::Test { EXPECT_CALL(*read_filter, onNewConnection()); EXPECT_CALL(*read_filter, onData(_)) .WillRepeatedly(Invoke([&](Buffer::Instance& data) -> Network::FilterStatus { - EXPECT_EQ(expected_chunk_size, data.length()); + EXPECT_GE(expected_chunk_size, data.length()); filter_seen += data.length(); data.drain(data.length()); - if (filter_seen == buffer_size) { + if (filter_seen == (write_size * num_writes)) { server_connection->close(Network::ConnectionCloseType::FlushWrite); } return Network::FilterStatus::StopIteration; @@ -320,18 +319,27 @@ class SslReadBufferLimitTest : public testing::Test { EXPECT_CALL(client_callbacks, 
onEvent(Network::ConnectionEvent::Connected)); EXPECT_CALL(client_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)) .WillOnce(Invoke([&](uint32_t) -> void { - EXPECT_EQ(buffer_size, filter_seen); + EXPECT_EQ((write_size * num_writes), filter_seen); dispatcher.exit(); })); - Buffer::OwnedImpl data(std::string(buffer_size, 'a')); - client_connection->write(data); + for (uint32_t i = 0; i < num_writes; i++) { + Buffer::OwnedImpl data(std::string(write_size, 'a')); + client_connection->write(data); + } + dispatcher.run(Event::Dispatcher::RunType::Block); } }; -TEST_F(SslReadBufferLimitTest, NoLimit) { readBufferLimitTest(0, 256 * 1024); } +TEST_F(SslReadBufferLimitTest, NoLimit) { readBufferLimitTest(0, 256 * 1024, 256 * 1024, 1); } -TEST_F(SslReadBufferLimitTest, SomeLimit) { readBufferLimitTest(32 * 1024, 32 * 1024); } +TEST_F(SslReadBufferLimitTest, NoLimitSmallWrites) { + readBufferLimitTest(0, 256 * 1024, 1, 256 * 1024); +} + +TEST_F(SslReadBufferLimitTest, SomeLimit) { + readBufferLimitTest(32 * 1024, 32 * 1024, 256 * 1024, 1); +} } // Ssl