From ca0bf415f5da548b61265ec63407efe63f4f2a0b Mon Sep 17 00:00:00 2001 From: Matt Klein Date: Fri, 10 Mar 2017 09:59:20 -0800 Subject: [PATCH 1/4] redis: command splitting This commit adds end to end support for incr, incrby, and mget, including full command splitting for mget and proper hashing for single server commands. There is much more work needed (more logging, stats, timeouts, etc.) but the basic flow is working and I am able to proxy to multiple backend redis servers locally. --- include/envoy/redis/codec.h | 5 + include/envoy/redis/command_splitter.h | 58 +++ include/envoy/redis/conn_pool.h | 24 +- source/common/CMakeLists.txt | 1 + source/common/http/filter/ratelimit.h | 4 +- source/common/json/json_validator.h | 4 +- source/common/redis/codec_impl.cc | 25 ++ source/common/redis/command_splitter_impl.cc | 169 +++++++++ source/common/redis/command_splitter_impl.h | 113 ++++++ source/common/redis/conn_pool_impl.cc | 13 +- source/common/redis/conn_pool_impl.h | 18 +- source/common/redis/proxy_filter.cc | 34 +- source/common/redis/proxy_filter.h | 22 +- source/common/router/config_utility.h | 4 +- source/common/router/router_ratelimit.cc | 2 +- source/common/router/router_ratelimit.h | 2 +- source/server/config/network/redis_proxy.cc | 9 +- source/server/server.cc | 2 - source/server/server.h | 1 - test/CMakeLists.txt | 1 + test/common/redis/codec_impl_test.cc | 6 + .../redis/command_splitter_impl_test.cc | 333 ++++++++++++++++++ test/common/redis/conn_pool_impl_test.cc | 48 +-- test/common/redis/proxy_filter_test.cc | 76 ++-- test/mocks/common.h | 5 +- test/mocks/redis/mocks.cc | 21 +- test/mocks/redis/mocks.h | 54 ++- test/test_common/printers.cc | 12 +- test/test_common/printers.h | 14 +- 29 files changed, 910 insertions(+), 170 deletions(-) create mode 100644 include/envoy/redis/command_splitter.h create mode 100644 source/common/redis/command_splitter_impl.cc create mode 100644 source/common/redis/command_splitter_impl.h create mode 100644 test/common/redis/command_splitter_impl_test.cc diff --git a/include/envoy/redis/codec.h b/include/envoy/redis/codec.h index edfc218f164b5..c4ca0cd3560bd 100644 --- a/include/envoy/redis/codec.h +++ b/include/envoy/redis/codec.h @@ -19,6 +19,11 @@ class RespValue { RespValue() : type_(RespType::Null) {} ~RespValue() { cleanup(); } + /** + * Convert a RESP value to a string for debugging purposes. + */ + std::string toString() const; + /** * The following are getters and setters for the internal value. A RespValue start as null, * and much change type via type() before the following methods can be used. diff --git a/include/envoy/redis/command_splitter.h b/include/envoy/redis/command_splitter.h new file mode 100644 index 0000000000000..8fae618cb3df5 --- /dev/null +++ b/include/envoy/redis/command_splitter.h @@ -0,0 +1,58 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/redis/codec.h" + +namespace Redis { +namespace CommandSplitter { + +/** + * A handle to a split request. + */ +class SplitRequest { +public: + virtual ~SplitRequest() {} + + /** + * Cancel the request. No further request callbacks will be called. + */ + virtual void cancel() PURE; +}; + +typedef std::unique_ptr SplitRequestPtr; + +/** + * Split request callbacks. + */ +class SplitCallbacks { +public: + virtual ~SplitCallbacks() {} + + /** + * Called when the response is ready. + * @param value supplies the response which is now owned by the callee. + */ + virtual void onResponse(RespValuePtr&& value) PURE; +}; + +/** + * A command splitter that takes incoming redis commands and splits them as appropriate to a + * backend connection pool. + */ +class Instance { +public: + virtual ~Instance() {} + + /** + * Make a split redis request. + * @param request supplies the split request to make. + * @param callbacks supplies the split request completion callbacks. + * @return SplitRequestPtr a handle to the active request or nullptr if the request has already + * been satisfied (via onResponse() being called). The splitter ALWAYS calls + * onResponse() for a given request. + */ + virtual SplitRequestPtr makeRequest(const RespValue& request, SplitCallbacks& callbacks) PURE; +}; + +} // CommandSplitter +} // Redis diff --git a/include/envoy/redis/conn_pool.h b/include/envoy/redis/conn_pool.h index a783a0652b755..29d2cc515c54e 100644 --- a/include/envoy/redis/conn_pool.h +++ b/include/envoy/redis/conn_pool.h @@ -9,9 +9,9 @@ namespace ConnPool { /** * A handle to an outbound request. */ -class ActiveRequest { +class PoolRequest { public: - virtual ~ActiveRequest() {} + virtual ~PoolRequest() {} /** * Cancel the request. No further request callbacks will be called. @@ -22,9 +22,9 @@ class ActiveRequest { /** * Outbound request callbacks. */ -class ActiveRequestCallbacks { +class PoolCallbacks { public: - virtual ~ActiveRequestCallbacks() {} + virtual ~PoolCallbacks() {} /** * Called when a pipelined response is received. @@ -59,10 +59,10 @@ class Client { * Make a pipelined request to the remote redis server. * @param request supplies the RESP request to make. * @param callbacks supplies the request callbacks. - * @return ActiveRequest* a handle to the active request. + * @return PoolRequest* a handle to the active request or nullptr if the request could not be made + * for some reason. */ - virtual ActiveRequest* makeRequest(const RespValue& request, - ActiveRequestCallbacks& callbacks) PURE; + virtual PoolRequest* makeRequest(const RespValue& request, PoolCallbacks& callbacks) PURE; }; typedef std::unique_ptr ClientPtr; @@ -93,12 +93,14 @@ class Instance { * @param hash_key supplies the key to use for consistent hashing. * @param request supplies the request to make. * @param callbacks supplies the request completion callbacks. - * @return ActiveRequest* a handle to the active request or nullptr if the request could not - * be made for some reason. + * @return PoolRequest* a handle to the active request or nullptr if the request could not be made + * for some reason. */ - virtual ActiveRequest* makeRequest(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks) PURE; + virtual PoolRequest* makeRequest(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks) PURE; }; +typedef std::unique_ptr InstancePtr; + } // ConnPool } // Redis diff --git a/source/common/CMakeLists.txt b/source/common/CMakeLists.txt index efa3d490eae72..ce9ea91c2e291 100644 --- a/source/common/CMakeLists.txt +++ b/source/common/CMakeLists.txt @@ -88,6 +88,7 @@ add_library( profiler/profiler.cc ratelimit/ratelimit_impl.cc redis/codec_impl.cc + redis/command_splitter_impl.cc redis/conn_pool_impl.cc redis/proxy_filter.cc router/config_impl.cc diff --git a/source/common/http/filter/ratelimit.h b/source/common/http/filter/ratelimit.h index 64994fc58da32..437086e7de073 100644 --- a/source/common/http/filter/ratelimit.h +++ b/source/common/http/filter/ratelimit.h @@ -23,11 +23,11 @@ enum class FilterRequestType { Internal, External, Both }; /** * Global configuration for the HTTP rate limit filter. */ -class FilterConfig : Json::JsonValidator { +class FilterConfig : Json::Validator { public: FilterConfig(const Json::Object& config, const LocalInfo::LocalInfo& local_info, Stats::Store& global_store, Runtime::Loader& runtime, Upstream::ClusterManager& cm) - : Json::JsonValidator(config, Json::Schema::RATE_LIMIT_HTTP_FILTER_SCHEMA), + : Json::Validator(config, Json::Schema::RATE_LIMIT_HTTP_FILTER_SCHEMA), domain_(config.getString("domain")), stage_(static_cast(config.getInteger("stage", 0))), request_type_(stringToType(config.getString("request_type", "both"))), diff --git a/source/common/json/json_validator.h b/source/common/json/json_validator.h index c01c972522fde..0e957746ec932 100644 --- a/source/common/json/json_validator.h +++ b/source/common/json/json_validator.h @@ -7,9 +7,9 @@ namespace Json { /** * Base class to inherit from to validate config schema before initializing member variables. */ -class JsonValidator { +class Validator { public: - JsonValidator(const Json::Object& config, const std::string& schema) { + Validator(const Json::Object& config, const std::string& schema) { config.validateSchema(schema); } }; diff --git a/source/common/redis/codec_impl.cc b/source/common/redis/codec_impl.cc index b350b5356b004..802bc8ad4ffc9 100644 --- a/source/common/redis/codec_impl.cc +++ b/source/common/redis/codec_impl.cc @@ -5,6 +5,31 @@ namespace Redis { +std::string RespValue::toString() const { + switch (type_) { + case RespType::Array: { + std::string ret = "["; + for (uint64_t i = 0; i < asArray().size(); i++) { + ret += asArray()[i].toString(); + if (i != asArray().size() - 1) { + ret += ", "; + } + } + return ret + "]"; + } + case RespType::SimpleString: + case RespType::BulkString: + case RespType::Error: + return fmt::format("\"{}\"", asString()); + case RespType::Null: + return "null"; + case RespType::Integer: + return std::to_string(asInteger()); + } + + NOT_REACHED; +} + std::vector& RespValue::asArray() { ASSERT(type_ == RespType::Array); return array_; diff --git a/source/common/redis/command_splitter_impl.cc b/source/common/redis/command_splitter_impl.cc new file mode 100644 index 0000000000000..98860a84d0e77 --- /dev/null +++ b/source/common/redis/command_splitter_impl.cc @@ -0,0 +1,169 @@ +#include "command_splitter_impl.h" + +#include "common/common/assert.h" + +namespace Redis { +namespace CommandSplitter { + +RespValuePtr Utility::makeError(const std::string& error) { + RespValuePtr response(new RespValue()); + response->type(RespType::Error); + response->asString() = error; + return response; +} + +SplitRequestPtr AllParamsToOneServerCommandHandler::startRequest(const RespValue& request, + SplitCallbacks& callbacks) { + std::unique_ptr request_handle(new SplitRequestImpl(callbacks)); + request_handle->handle_ = + conn_pool_.makeRequest(request.asArray()[1].asString(), request, *request_handle); + if (!request_handle->handle_) { + callbacks.onResponse(Utility::makeError("no upstream host")); + return nullptr; + } + + return request_handle; +} + +AllParamsToOneServerCommandHandler::SplitRequestImpl::~SplitRequestImpl() { ASSERT(!handle_); } + +void AllParamsToOneServerCommandHandler::SplitRequestImpl::cancel() { + handle_->cancel(); + handle_ = nullptr; +} + +void AllParamsToOneServerCommandHandler::SplitRequestImpl::onResponse(RespValuePtr&& response) { + handle_ = nullptr; + log_debug("redis: response: '{}'", response->toString()); + callbacks_.onResponse(std::move(response)); +} + +void AllParamsToOneServerCommandHandler::SplitRequestImpl::onFailure() { + handle_ = nullptr; + callbacks_.onResponse(Utility::makeError("upstream failure")); +} + +SplitRequestPtr MGETCommandHandler::startRequest(const RespValue& request, + SplitCallbacks& callbacks) { + std::unique_ptr request_handle( + new SplitRequestImpl(callbacks, request.asArray().size() - 1)); + + // Create the get request that we will use for each split get below. + std::vector values(2); + values[0].type(RespType::BulkString); + values[0].asString() = "get"; + values[1].type(RespType::BulkString); + RespValue single_mget; + single_mget.type(RespType::Array); + single_mget.asArray().swap(values); + + for (uint64_t i = 1; i < request.asArray().size(); i++) { + request_handle->pending_requests_.emplace_back(*request_handle, i - 1); + SplitRequestImpl::PendingRequest& pending_request = request_handle->pending_requests_.back(); + + single_mget.asArray()[1].asString() = request.asArray()[i].asString(); + log_debug("redis: parallel get: '{}'", single_mget.toString()); + pending_request.handle_ = + conn_pool_.makeRequest(request.asArray()[i].asString(), single_mget, pending_request); + if (!pending_request.handle_) { + pending_request.onResponse(Utility::makeError("no upstream host")); + } + } + + return request_handle->pending_responses_ > 0 ? std::move(request_handle) : nullptr; +} + +MGETCommandHandler::SplitRequestImpl::SplitRequestImpl(SplitCallbacks& callbacks, + uint32_t num_responses) + : callbacks_(callbacks), pending_responses_(num_responses) { + pending_response_.reset(new RespValue()); + pending_response_->type(RespType::Array); + std::vector responses(num_responses); + pending_response_->asArray().swap(responses); + pending_requests_.reserve(num_responses); +} + +MGETCommandHandler::SplitRequestImpl::~SplitRequestImpl() { +#ifndef NDEBUG + for (const PendingRequest& request : pending_requests_) { + ASSERT(!request.handle_); + } +#endif +} + +void MGETCommandHandler::SplitRequestImpl::cancel() { + for (PendingRequest& request : pending_requests_) { + if (request.handle_) { + request.handle_->cancel(); + request.handle_ = nullptr; + } + } +} + +void MGETCommandHandler::SplitRequestImpl::onResponse(RespValuePtr&& value, uint32_t index) { + pending_requests_[index].handle_ = nullptr; + + pending_response_->asArray()[index].type(value->type()); + switch (value->type()) { + case RespType::Array: + case RespType::Integer: { + pending_response_->asArray()[index].type(RespType::Error); + pending_response_->asArray()[index].asString() = "upstream protocol error"; + break; + } + case RespType::SimpleString: + case RespType::BulkString: + case RespType::Error: { + pending_response_->asArray()[index].asString().swap(value->asString()); + break; + } + case RespType::Null: + break; + } + + ASSERT(pending_responses_ > 0); + if (--pending_responses_ == 0) { + log_debug("redis: response: '{}'", pending_response_->toString()); + callbacks_.onResponse(std::move(pending_response_)); + } +} + +void MGETCommandHandler::SplitRequestImpl::onFailure(uint32_t index) { + onResponse(Utility::makeError("upstream failure"), index); +} + +InstanceImpl::InstanceImpl(ConnPool::InstancePtr&& conn_pool) + : conn_pool_(std::move(conn_pool)), all_to_one_handler_(*conn_pool_), + mget_handler_(*conn_pool_) { + // TODO(mattklein123) PERF: Make this a trie (like in header_map_impl). + // TODO(mattklein123): Make not case sensitive (like in header_map_impl). + command_map_.emplace("incr", all_to_one_handler_); + command_map_.emplace("incrby", all_to_one_handler_); + command_map_.emplace("mget", mget_handler_); +} + +SplitRequestPtr InstanceImpl::makeRequest(const RespValue& request, SplitCallbacks& callbacks) { + if (request.type() != RespType::Array || request.asArray().size() < 2) { + callbacks.onResponse(Utility::makeError("invalid request")); + return nullptr; + } + + for (const RespValue& value : request.asArray()) { + if (value.type() != RespType::BulkString) { + callbacks.onResponse(Utility::makeError("invalid request")); + return nullptr; + } + } + + auto handler = command_map_.find(request.asArray()[0].asString()); + if (handler == command_map_.end()) { + callbacks.onResponse(Utility::makeError("unsupported command")); + return nullptr; + } + + log_debug("redis: splitting '{}'", request.toString()); + return handler->second.get().startRequest(request, callbacks); +} + +} // CommandSplitter +} // Redis diff --git a/source/common/redis/command_splitter_impl.h b/source/common/redis/command_splitter_impl.h new file mode 100644 index 0000000000000..f787777790b51 --- /dev/null +++ b/source/common/redis/command_splitter_impl.h @@ -0,0 +1,113 @@ +#pragma once + +#include "envoy/redis/command_splitter.h" +#include "envoy/redis/conn_pool.h" + +#include "common/common/logger.h" + +namespace Redis { +namespace CommandSplitter { + +class Utility { +public: + static RespValuePtr makeError(const std::string& error); +}; + +class CommandHandler { +public: + virtual ~CommandHandler() {} + + virtual SplitRequestPtr startRequest(const RespValue& request, SplitCallbacks& callbacks) PURE; +}; + +class CommandHandlerBase { +protected: + CommandHandlerBase(ConnPool::Instance& conn_pool) : conn_pool_(conn_pool) {} + + ConnPool::Instance& conn_pool_; +}; + +class AllParamsToOneServerCommandHandler : public CommandHandler, + CommandHandlerBase, + Logger::Loggable { +public: + AllParamsToOneServerCommandHandler(ConnPool::Instance& conn_pool) + : CommandHandlerBase(conn_pool) {} + + // Redis::CommandSplitter::CommandHandler + SplitRequestPtr startRequest(const RespValue& request, SplitCallbacks& callbacks) override; + +private: + struct SplitRequestImpl : public SplitRequest, public ConnPool::PoolCallbacks { + SplitRequestImpl(SplitCallbacks& callbacks) : callbacks_(callbacks) {} + ~SplitRequestImpl(); + + // Redis::CommandSplitter::SplitRequest + void cancel() override; + + // Redis::ConnPool::PoolCallbacks + void onResponse(RespValuePtr&& value) override; + void onFailure() override; + + SplitCallbacks& callbacks_; + ConnPool::PoolRequest* handle_{}; + }; +}; + +class MGETCommandHandler : public CommandHandler, + CommandHandlerBase, + Logger::Loggable { +public: + MGETCommandHandler(ConnPool::Instance& conn_pool) : CommandHandlerBase(conn_pool) {} + + // Redis::CommandSplitter::CommandHandler + SplitRequestPtr startRequest(const RespValue& request, SplitCallbacks& callbacks) override; + +private: + struct SplitRequestImpl : public SplitRequest { + struct PendingRequest : public ConnPool::PoolCallbacks { + PendingRequest(SplitRequestImpl& parent, uint32_t index) : parent_(parent), index_(index) {} + + // Redis::ConnPool::PoolCallbacks + void onResponse(RespValuePtr&& value) override { + parent_.onResponse(std::move(value), index_); + } + void onFailure() override { parent_.onFailure(index_); } + + SplitRequestImpl& parent_; + const uint32_t index_; + ConnPool::PoolRequest* handle_{}; + }; + + SplitRequestImpl(SplitCallbacks& callbacks, uint32_t num_responses); + ~SplitRequestImpl(); + + void onResponse(RespValuePtr&& value, uint32_t index); + void onFailure(uint32_t index); + + // Redis::CommandSplitter::SplitRequest + void cancel() override; + + SplitCallbacks& callbacks_; + RespValuePtr pending_response_; + std::vector pending_requests_; + uint32_t pending_responses_; + }; +}; + +class InstanceImpl : public Instance, Logger::Loggable { +public: + InstanceImpl(ConnPool::InstancePtr&& conn_pool); + + // Redis::CommandSplitter::Instance + SplitRequestPtr makeRequest(const RespValue& request, SplitCallbacks& callbacks) override; + +private: + ConnPool::InstancePtr conn_pool_; + AllParamsToOneServerCommandHandler all_to_one_handler_; + MGETCommandHandler mget_handler_; + std::unordered_map> command_map_; +}; + +} // CommandSplitter +} // Redis diff --git a/source/common/redis/conn_pool_impl.cc b/source/common/redis/conn_pool_impl.cc index 35c55cedbcdcd..f971ce4c41a56 100644 --- a/source/common/redis/conn_pool_impl.cc +++ b/source/common/redis/conn_pool_impl.cc @@ -24,8 +24,7 @@ ClientImpl::~ClientImpl() { void ClientImpl::close() { connection_->close(Network::ConnectionCloseType::NoFlush); } -ActiveRequest* ClientImpl::makeRequest(const RespValue& request, - ActiveRequestCallbacks& callbacks) { +PoolRequest* ClientImpl::makeRequest(const RespValue& request, PoolCallbacks& callbacks) { ASSERT(connection_->state() == Network::Connection::State::Open); pending_requests_.emplace_back(callbacks); encoder_->encode(request, encoder_buffer_); @@ -86,8 +85,8 @@ InstanceImpl::InstanceImpl(const std::string& cluster_name, Upstream::ClusterMan }); } -ActiveRequest* InstanceImpl::makeRequest(const std::string& hash_key, const RespValue& value, - ActiveRequestCallbacks& callbacks) { +PoolRequest* InstanceImpl::makeRequest(const std::string& hash_key, const RespValue& value, + PoolCallbacks& callbacks) { return tls_.getTyped(tls_slot_).makeRequest(hash_key, value, callbacks); } @@ -113,9 +112,9 @@ void InstanceImpl::ThreadLocalPool::onHostsRemoved( } } -ActiveRequest* InstanceImpl::ThreadLocalPool::makeRequest(const std::string& hash_key, - const RespValue& request, - ActiveRequestCallbacks& callbacks) { +PoolRequest* InstanceImpl::ThreadLocalPool::makeRequest(const std::string& hash_key, + const RespValue& request, + PoolCallbacks& callbacks) { LbContextImpl lb_context(hash_key); Upstream::HostConstSharedPtr host = cluster_->loadBalancer().chooseHost(&lb_context); if (!host) { diff --git a/source/common/redis/conn_pool_impl.h b/source/common/redis/conn_pool_impl.h index e22b92b7c88ff..34eb896e99cdd 100644 --- a/source/common/redis/conn_pool_impl.h +++ b/source/common/redis/conn_pool_impl.h @@ -28,7 +28,7 @@ class ClientImpl : public Client, public DecoderCallbacks, public Network::Conne connection_->addConnectionCallbacks(callbacks); } void close() override; - ActiveRequest* makeRequest(const RespValue& request, ActiveRequestCallbacks& callbacks) override; + PoolRequest* makeRequest(const RespValue& request, PoolCallbacks& callbacks) override; private: struct UpstreamReadFilter : public Network::ReadFilterBaseImpl { @@ -43,13 +43,13 @@ class ClientImpl : public Client, public DecoderCallbacks, public Network::Conne ClientImpl& parent_; }; - struct PendingRequest : public ActiveRequest { - PendingRequest(ActiveRequestCallbacks& callbacks) : callbacks_(callbacks) {} + struct PendingRequest : public PoolRequest { + PendingRequest(PoolCallbacks& callbacks) : callbacks_(callbacks) {} - // Redis::ConnPool::ActiveRequest + // Redis::ConnPool::PoolRequest void cancel() override; - ActiveRequestCallbacks& callbacks_; + PoolCallbacks& callbacks_; bool canceled_{}; }; @@ -88,8 +88,8 @@ class InstanceImpl : public Instance { ClientFactory& client_factory, ThreadLocal::Instance& tls); // Redis::ConnPool::Instance - ActiveRequest* makeRequest(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks) override; + PoolRequest* makeRequest(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks) override; private: struct ThreadLocalPool; @@ -112,8 +112,8 @@ class InstanceImpl : public Instance { ThreadLocalPool(InstanceImpl& parent, Event::Dispatcher& dispatcher, const std::string& cluster_name); - ActiveRequest* makeRequest(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks); + PoolRequest* makeRequest(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks); void onHostsRemoved(const std::vector& hosts_removed); // ThreadLocal::ThreadLocalObject diff --git a/source/common/redis/proxy_filter.cc b/source/common/redis/proxy_filter.cc index dd830694344f1..95f39512dbd04 100644 --- a/source/common/redis/proxy_filter.cc +++ b/source/common/redis/proxy_filter.cc @@ -6,10 +6,8 @@ namespace Redis { ProxyFilterConfig::ProxyFilterConfig(const Json::Object& config, Upstream::ClusterManager& cm) - : cluster_name_{config.getString("cluster_name")} { - - config.validateSchema(Json::Schema::REDIS_PROXY_NETWORK_FILTER_SCHEMA); - + : Json::Validator(config, Json::Schema::REDIS_PROXY_NETWORK_FILTER_SCHEMA), + cluster_name_{config.getString("cluster_name")} { if (!cm.get(cluster_name_)) { throw EnvoyException( fmt::format("redis filter config: unknown cluster name '{}'", cluster_name_)); @@ -21,11 +19,8 @@ ProxyFilter::~ProxyFilter() { ASSERT(pending_requests_.empty()); } void ProxyFilter::onRespValue(RespValuePtr&& value) { pending_requests_.emplace_back(*this); PendingRequest& request = pending_requests_.back(); - request.request_handle_ = conn_pool_.makeRequest("", *value, request); - if (!request.request_handle_) { - respondWithFailure("no healthy upstream"); - pending_requests_.pop_back(); - } + request.request_handle_ = splitter_.makeRequest(*value, request); + // The splitter can immediately respond. } void ProxyFilter::onEvent(uint32_t events) { @@ -55,30 +50,19 @@ void ProxyFilter::onResponse(PendingRequest& request, RespValuePtr&& value) { } } -void ProxyFilter::onFailure(PendingRequest& request) { - RespValuePtr error(new RespValue()); - error->type(RespType::Error); - error->asString() = "upstream connection error"; - onResponse(request, std::move(error)); -} - Network::FilterStatus ProxyFilter::onData(Buffer::Instance& data) { try { decoder_->decode(data); return Network::FilterStatus::Continue; } catch (ProtocolError&) { - respondWithFailure("downstream protocol error"); + RespValue error; + error.type(RespType::Error); + error.asString() = "downstream protocol error"; + encoder_->encode(error, encoder_buffer_); + callbacks_->connection().write(encoder_buffer_); callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); return Network::FilterStatus::StopIteration; } } -void ProxyFilter::respondWithFailure(const std::string& message) { - RespValue error; - error.type(RespType::Error); - error.asString() = message; - encoder_->encode(error, encoder_buffer_); - callbacks_->connection().write(encoder_buffer_); -} - } // Redis diff --git a/source/common/redis/proxy_filter.h b/source/common/redis/proxy_filter.h index 1048d0d12277d..4c48b60d6c4ac 100644 --- a/source/common/redis/proxy_filter.h +++ b/source/common/redis/proxy_filter.h @@ -2,20 +2,21 @@ #include "envoy/network/filter.h" #include "envoy/redis/codec.h" -#include "envoy/redis/conn_pool.h" +#include "envoy/redis/command_splitter.h" +#include "envoy/upstream/cluster_manager.h" #include "common/buffer/buffer_impl.h" #include "common/json/json_loader.h" +#include "common/json/json_validator.h" namespace Redis { // TODO(mattklein123): Stats -// TODO(mattklein123): Actual multiplexing, command verification, and splitting /** * Configuration for the redis proxy filter. */ -class ProxyFilterConfig { +class ProxyFilterConfig : Json::Validator { public: ProxyFilterConfig(const Json::Object& config, Upstream::ClusterManager& cm); @@ -33,8 +34,8 @@ class ProxyFilter : public Network::ReadFilter, public DecoderCallbacks, public Network::ConnectionCallbacks { public: - ProxyFilter(DecoderFactory& factory, EncoderPtr&& encoder, ConnPool::Instance& conn_pool) - : decoder_(factory.create(*this)), encoder_(std::move(encoder)), conn_pool_(conn_pool) {} + ProxyFilter(DecoderFactory& factory, EncoderPtr&& encoder, CommandSplitter::Instance& splitter) + : decoder_(factory.create(*this)), encoder_(std::move(encoder)), splitter_(splitter) {} ~ProxyFilter(); @@ -53,25 +54,22 @@ class ProxyFilter : public Network::ReadFilter, void onRespValue(RespValuePtr&& value) override; private: - struct PendingRequest : public ConnPool::ActiveRequestCallbacks { + struct PendingRequest : public CommandSplitter::SplitCallbacks { PendingRequest(ProxyFilter& parent) : parent_(parent) {} - // Redis::ConnPool::ActiveRequestCallbacks + // Redis::CommandSplitter::SplitCallbacks void onResponse(RespValuePtr&& value) override { parent_.onResponse(*this, std::move(value)); } - void onFailure() override { parent_.onFailure(*this); } ProxyFilter& parent_; RespValuePtr pending_response_; - ConnPool::ActiveRequest* request_handle_; + CommandSplitter::SplitRequestPtr request_handle_; }; void onResponse(PendingRequest& request, RespValuePtr&& value); - void onFailure(PendingRequest& request); - void respondWithFailure(const std::string& message); DecoderPtr decoder_; EncoderPtr encoder_; - ConnPool::Instance& conn_pool_; + CommandSplitter::Instance& splitter_; Buffer::OwnedImpl encoder_buffer_; Network::ReadFilterCallbacks* callbacks_{}; std::list pending_requests_; diff --git a/source/common/router/config_utility.h b/source/common/router/config_utility.h index e60f5a6c2afaf..54ddfd9ab77f9 100644 --- a/source/common/router/config_utility.h +++ b/source/common/router/config_utility.h @@ -15,12 +15,12 @@ namespace Router { */ class ConfigUtility { public: - struct HeaderData : Json::JsonValidator { + struct HeaderData : Json::Validator { // An empty header value allows for matching to be only based on header presence. // Regex is an opt-in. Unless explicitly mentioned, the header values will be used for // exact string matching. HeaderData(const Json::Object& config) - : Json::JsonValidator(config, Json::Schema::HEADER_DATA_CONFIGURATION_SCHEMA), + : Json::Validator(config, Json::Schema::HEADER_DATA_CONFIGURATION_SCHEMA), name_(config.getString("name")), value_(config.getString("value", EMPTY_STRING)), regex_pattern_(value_, std::regex::optimize), is_regex_(config.getBoolean("regex", false)) {} diff --git a/source/common/router/router_ratelimit.cc b/source/common/router/router_ratelimit.cc index 3c1020c0587cf..c10b49bc2c0f0 100644 --- a/source/common/router/router_ratelimit.cc +++ b/source/common/router/router_ratelimit.cc @@ -69,7 +69,7 @@ void HeaderValueMatchAction::populateDescriptor(const Router::RouteEntry&, } RateLimitPolicyEntryImpl::RateLimitPolicyEntryImpl(const Json::Object& config) - : Json::JsonValidator(config, Json::Schema::HTTP_RATE_LIMITS_CONFIGURATION_SCHEMA), + : Json::Validator(config, Json::Schema::HTTP_RATE_LIMITS_CONFIGURATION_SCHEMA), disable_key_(config.getString("disable_key", "")), stage_(static_cast(config.getInteger("stage", 0))) { for (const Json::ObjectPtr& action : config.getObjectArray("actions")) { diff --git a/source/common/router/router_ratelimit.h b/source/common/router/router_ratelimit.h index ec3f0cb26e1f0..87a700ee0a4c9 100644 --- a/source/common/router/router_ratelimit.h +++ b/source/common/router/router_ratelimit.h @@ -98,7 +98,7 @@ class HeaderValueMatchAction : public RateLimitAction { /* * Implementation of RateLimitPolicyEntry that holds the action for the configuration. */ -class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry, Json::JsonValidator { +class RateLimitPolicyEntryImpl : public RateLimitPolicyEntry, Json::Validator { public: RateLimitPolicyEntryImpl(const Json::Object& config); diff --git a/source/server/config/network/redis_proxy.cc b/source/server/config/network/redis_proxy.cc index 65579019006e2..10930da400e6f 100644 --- a/source/server/config/network/redis_proxy.cc +++ b/source/server/config/network/redis_proxy.cc @@ -1,6 +1,7 @@ #include "redis_proxy.h" #include "common/redis/codec_impl.h" +#include "common/redis/command_splitter_impl.h" #include "common/redis/conn_pool_impl.h" #include "common/redis/proxy_filter.h" @@ -15,13 +16,15 @@ NetworkFilterFactoryCb RedisProxyFilterConfigFactory::tryCreateFilterFactory( } Redis::ProxyFilterConfig filter_config(config, server.clusterManager()); - std::shared_ptr conn_pool(new Redis::ConnPool::InstanceImpl( + Redis::ConnPool::InstancePtr conn_pool(new Redis::ConnPool::InstanceImpl( filter_config.clusterName(), server.clusterManager(), Redis::ConnPool::ClientFactoryImpl::instance_, server.threadLocal())); - return [conn_pool](Network::FilterManager& filter_manager) -> void { + std::shared_ptr splitter( + new Redis::CommandSplitter::InstanceImpl(std::move(conn_pool))); + return [splitter](Network::FilterManager& filter_manager) -> void { Redis::DecoderFactoryImpl factory; filter_manager.addReadFilter(Network::ReadFilterSharedPtr{ - new Redis::ProxyFilter(factory, Redis::EncoderPtr{new Redis::EncoderImpl()}, *conn_pool)}); + new Redis::ProxyFilter(factory, Redis::EncoderPtr{new Redis::EncoderImpl()}, *splitter)}); }; } diff --git a/source/server/server.cc b/source/server/server.cc index cc9cd2a9deb69..f25e179095189 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -361,8 +361,6 @@ void InstanceImpl::run() { Runtime::Loader& InstanceImpl::runtime() { return *runtime_loader_; } -InstanceImpl::~InstanceImpl() {} - void InstanceImpl::shutdown() { log().warn("shutdown invoked. sending SIGTERM to self"); kill(getpid(), SIGTERM); diff --git a/source/server/server.h b/source/server/server.h index d2d49c1560ac6..d9972fdbd1180 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -94,7 +94,6 @@ class InstanceImpl : Logger::Loggable, public Instance { InstanceImpl(Options& options, TestHooks& hooks, HotRestart& restarter, Stats::StoreRoot& store, Thread::BasicLockable& access_log_lock, ComponentFactory& component_factory, const LocalInfo::LocalInfo& local_info); - ~InstanceImpl(); void run(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 966f39067f293..7484d813a5a49 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -82,6 +82,7 @@ add_executable(envoy-test common/network/utility_test.cc common/ratelimit/ratelimit_impl_test.cc common/redis/codec_impl_test.cc + common/redis/command_splitter_impl_test.cc common/redis/conn_pool_impl_test.cc common/redis/proxy_filter_test.cc common/router/config_impl_test.cc diff --git a/test/common/redis/codec_impl_test.cc b/test/common/redis/codec_impl_test.cc index 939c8700aae00..0e835aeaac609 100644 --- a/test/common/redis/codec_impl_test.cc +++ b/test/common/redis/codec_impl_test.cc @@ -25,6 +25,7 @@ class RedisEncoderDecoderImplTest : public testing::Test, public DecoderCallback TEST_F(RedisEncoderDecoderImplTest, Null) { RespValue value; + EXPECT_EQ("null", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("$-1\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -36,6 +37,7 @@ TEST_F(RedisEncoderDecoderImplTest, Error) { RespValue value; value.type(RespType::Error); value.asString() = "error"; + EXPECT_EQ("\"error\"", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("-error\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -47,6 +49,7 @@ TEST_F(RedisEncoderDecoderImplTest, SimpleString) { RespValue value; value.type(RespType::SimpleString); value.asString() = "simple string"; + EXPECT_EQ("\"simple string\"", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("+simple string\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -58,6 +61,7 @@ TEST_F(RedisEncoderDecoderImplTest, Integer) { RespValue value; value.type(RespType::Integer); value.asInteger() = std::numeric_limits::max(); + EXPECT_EQ("9223372036854775807", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ(":9223372036854775807\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -79,6 +83,7 @@ TEST_F(RedisEncoderDecoderImplTest, NegativeInteger) { TEST_F(RedisEncoderDecoderImplTest, EmptyArray) { RespValue value; value.type(RespType::Array); + EXPECT_EQ("[]", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("*0\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); @@ -96,6 +101,7 @@ TEST_F(RedisEncoderDecoderImplTest, Array) { RespValue value; value.type(RespType::Array); value.asArray().swap(values); + EXPECT_EQ("[\"hello\", -5]", value.toString()); encoder_.encode(value, buffer_); EXPECT_EQ("*2\r\n$5\r\nhello\r\n:-5\r\n", TestUtility::bufferToString(buffer_)); decoder_.decode(buffer_); diff --git a/test/common/redis/command_splitter_impl_test.cc b/test/common/redis/command_splitter_impl_test.cc new file mode 100644 index 0000000000000..5490e299da074 --- /dev/null +++ b/test/common/redis/command_splitter_impl_test.cc @@ -0,0 +1,333 @@ +#include "common/redis/command_splitter_impl.h" + +#include "test/mocks/common.h" +#include "test/mocks/redis/mocks.h" + +using testing::_; +using testing::ByRef; +using testing::DoAll; +using testing::Eq; +using testing::InSequence; +using testing::Ref; +using testing::Return; +using testing::WithArg; + +namespace Redis { +namespace CommandSplitter { + +class RedisCommandSplitterImplTest : public testing::Test { +public: + void makeBulkStringArray(RespValue& value, const std::vector& strings) { + std::vector values(strings.size()); + for (uint64_t i = 0; i < strings.size(); i++) { + values[i].type(RespType::BulkString); + values[i].asString() = strings[i]; + } + + value.type(RespType::Array); + value.asArray().swap(values); + } + + ConnPool::MockInstance* conn_pool_{new ConnPool::MockInstance()}; + InstanceImpl splitter_{ConnPool::InstancePtr{conn_pool_}}; + MockSplitCallbacks callbacks_; + SplitRequestPtr handle_; +}; + +TEST_F(RedisCommandSplitterImplTest, InvalidRequestNotArray) { + RespValue response; + response.type(RespType::Error); + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +TEST_F(RedisCommandSplitterImplTest, InvalidRequestArrayTooSmall) { + RespValue response; + response.type(RespType::Error); + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + makeBulkStringArray(request, {"incr"}); + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +TEST_F(RedisCommandSplitterImplTest, InvalidRequestArrayNotStrings) { + RespValue response; + response.type(RespType::Error); + response.asString() = "invalid request"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + makeBulkStringArray(request, {"incr", ""}); + request.asArray()[1].type(RespType::Null); + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +TEST_F(RedisCommandSplitterImplTest, UnsupportedCommand) { + RespValue response; + response.type(RespType::Error); + response.asString() = "unsupported command"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + RespValue request; + makeBulkStringArray(request, {"newcommand", "hello"}); + EXPECT_EQ(nullptr, splitter_.makeRequest(request, callbacks_)); +} + +class RedisAllParamsToOneServerCommandHandlerTest : public RedisCommandSplitterImplTest { +public: + void makeRequest(const std::string& hash_key, const RespValue& request) { + EXPECT_CALL(*conn_pool_, makeRequest(hash_key, Ref(request), _)) + .WillOnce(DoAll(WithArg<2>(SaveArgAddress(&pool_callbacks_)), Return(&pool_request_))); + handle_ = splitter_.makeRequest(request, callbacks_); + } + + void fail() { + RespValue response; + response.type(RespType::Error); + response.asString() = "upstream failure"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + pool_callbacks_->onFailure(); + } + + void respond() { + RespValuePtr response1(new RespValue()); + RespValue* response1_ptr = response1.get(); + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(response1_ptr))); + pool_callbacks_->onResponse(std::move(response1)); + } + + ConnPool::PoolCallbacks* pool_callbacks_; + ConnPool::MockPoolRequest pool_request_; +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrSuccess) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + makeRequest("hello", request); + EXPECT_NE(nullptr, handle_); + + respond(); +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrFail) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + makeRequest("hello", request); + EXPECT_NE(nullptr, handle_); + + fail(); +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrCancel) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + makeRequest("hello", request); + EXPECT_NE(nullptr, handle_); + + EXPECT_CALL(pool_request_, cancel()); + handle_->cancel(); +}; + +TEST_F(RedisAllParamsToOneServerCommandHandlerTest, IncrNoUpstream) { + InSequence s; + + RespValue request; + makeBulkStringArray(request, {"incr", "hello"}); + EXPECT_CALL(*conn_pool_, makeRequest("hello", Ref(request), _)).WillOnce(Return(nullptr)); + RespValue response; + response.type(RespType::Error); + response.asString() = "no upstream host"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); + handle_ = splitter_.makeRequest(request, callbacks_); + EXPECT_EQ(nullptr, handle_); +}; + +class RedisMGETCommandHandlerTest : public RedisCommandSplitterImplTest { +public: + void setup(uint32_t num_gets, const std::list& null_handle_indexes) { + std::vector request_strings = {"mget"}; + for (uint32_t i = 0; i < num_gets; i++) { + request_strings.push_back(std::to_string(i)); + } + + RespValue request; + makeBulkStringArray(request, request_strings); + + std::vector tmp_expected_requests(num_gets); + expected_requests_.swap(tmp_expected_requests); + pool_callbacks_.resize(num_gets); + std::vector tmp_pool_requests(num_gets); + pool_requests_.swap(tmp_pool_requests); + for (uint32_t i = 0; i < num_gets; i++) { + makeBulkStringArray(expected_requests_[i], {"get", std::to_string(i)}); + ConnPool::PoolRequest* request_to_use = nullptr; + if (std::find(null_handle_indexes.begin(), null_handle_indexes.end(), i) == + null_handle_indexes.end()) { + request_to_use = &pool_requests_[i]; + } + EXPECT_CALL(*conn_pool_, makeRequest(std::to_string(i), Eq(ByRef(expected_requests_[i])), _)) + .WillOnce(DoAll(WithArg<2>(SaveArgAddress(&pool_callbacks_[i])), Return(request_to_use))); + } + + handle_ = splitter_.makeRequest(request, callbacks_); + } + + std::vector expected_requests_; + std::vector pool_callbacks_; + std::vector pool_requests_; +}; + +TEST_F(RedisMGETCommandHandlerTest, Normal) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::BulkString); + elements[0].asString() = "response"; + elements[1].type(RespType::BulkString); + elements[1].asString() = "5"; + expected_response.asArray().swap(elements); + + RespValuePtr response2(new RespValue()); + response2->type(RespType::BulkString); + response2->asString() = "5"; + pool_callbacks_[1]->onResponse(std::move(response2)); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::BulkString); + response1->asString() = "response"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, NormalWithNull) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::BulkString); + elements[0].asString() = "response"; + expected_response.asArray().swap(elements); + + RespValuePtr response2(new RespValue()); + pool_callbacks_[1]->onResponse(std::move(response2)); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::BulkString); + response1->asString() = "response"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, NoUpstreamHostForAll) { + // No InSequence to avoid making setup() more complicated. + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::Error); + elements[0].asString() = "no upstream host"; + elements[1].type(RespType::Error); + elements[1].asString() = "no upstream host"; + expected_response.asArray().swap(elements); + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + setup(2, {0, 1}); + EXPECT_EQ(nullptr, handle_); +}; + +TEST_F(RedisMGETCommandHandlerTest, NoUpstreamHostForOne) { + InSequence s; + + setup(2, {0}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::Error); + elements[0].asString() = "no upstream host"; + elements[1].type(RespType::Error); + elements[1].asString() = "upstream failure"; + expected_response.asArray().swap(elements); + + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[1]->onFailure(); +}; + +TEST_F(RedisMGETCommandHandlerTest, Failure) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::BulkString); + elements[0].asString() = "response"; + elements[1].type(RespType::Error); + elements[1].asString() = "upstream failure"; + expected_response.asArray().swap(elements); + + pool_callbacks_[1]->onFailure(); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::BulkString); + response1->asString() = "response"; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, InvalidUpstreamResponse) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + RespValue expected_response; + expected_response.type(RespType::Array); + std::vector elements(2); + elements[0].type(RespType::Error); + elements[0].asString() = "upstream protocol error"; + elements[1].type(RespType::Error); + elements[1].asString() = "upstream failure"; + expected_response.asArray().swap(elements); + + pool_callbacks_[1]->onFailure(); + + RespValuePtr response1(new RespValue()); + response1->type(RespType::Integer); + response1->asInteger() = 5; + EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&expected_response))); + pool_callbacks_[0]->onResponse(std::move(response1)); +}; + +TEST_F(RedisMGETCommandHandlerTest, Cancel) { + InSequence s; + + setup(2, {}); + EXPECT_NE(nullptr, handle_); + + EXPECT_CALL(pool_requests_[0], cancel()); + EXPECT_CALL(pool_requests_[1], cancel()); + handle_->cancel(); +}; + +} // CommandSplitter +} // Redis diff --git a/test/common/redis/conn_pool_impl_test.cc b/test/common/redis/conn_pool_impl_test.cc index 9a0351b996f5e..ad16d12a95226 100644 --- a/test/common/redis/conn_pool_impl_test.cc +++ b/test/common/redis/conn_pool_impl_test.cc @@ -53,15 +53,15 @@ TEST_F(RedisClientImplTest, Basic) { setup(); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); RespValue request2; - MockActiveRequestCallbacks callbacks2; + MockPoolCallbacks callbacks2; EXPECT_CALL(*encoder_, encode(Ref(request2), _)); - ActiveRequest* handle2 = client_->makeRequest(request2, callbacks2); + PoolRequest* handle2 = client_->makeRequest(request2, callbacks2); EXPECT_NE(nullptr, handle2); Buffer::OwnedImpl fake_data; @@ -86,15 +86,15 @@ TEST_F(RedisClientImplTest, Cancel) { setup(); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); RespValue request2; - MockActiveRequestCallbacks callbacks2; + MockPoolCallbacks callbacks2; EXPECT_CALL(*encoder_, encode(Ref(request2), _)); - ActiveRequest* handle2 = client_->makeRequest(request2, callbacks2); + PoolRequest* handle2 = client_->makeRequest(request2, callbacks2); EXPECT_NE(nullptr, handle2); handle1->cancel(); @@ -124,9 +124,9 @@ TEST_F(RedisClientImplTest, FailAll) { client_->addConnectionCallbacks(connection_callbacks); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); EXPECT_CALL(connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)); @@ -138,9 +138,9 @@ TEST_F(RedisClientImplTest, ProtocolError) { setup(); RespValue request1; - MockActiveRequestCallbacks callbacks1; + MockPoolCallbacks callbacks1; EXPECT_CALL(*encoder_, encode(Ref(request1), _)); - ActiveRequest* handle1 = client_->makeRequest(request1, callbacks1); + PoolRequest* handle1 = client_->makeRequest(request1, callbacks1); EXPECT_NE(nullptr, handle1); Buffer::OwnedImpl fake_data; @@ -181,8 +181,8 @@ TEST_F(RedisConnPoolImplTest, Basic) { InSequence s; RespValue value; - MockActiveRequest active_request; - MockActiveRequestCallbacks callbacks; + MockPoolRequest active_request; + MockPoolCallbacks callbacks; MockClient* client = new NiceMock(); EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)) @@ -193,7 +193,7 @@ TEST_F(RedisConnPoolImplTest, Basic) { })); EXPECT_CALL(*this, create_(_, _)).WillOnce(Return(client)); EXPECT_CALL(*client, makeRequest(Ref(value), Ref(callbacks))).WillOnce(Return(&active_request)); - ActiveRequest* request = conn_pool_.makeRequest("foo", value, callbacks); + PoolRequest* request = conn_pool_.makeRequest("foo", value, callbacks); EXPECT_EQ(&active_request, request); EXPECT_CALL(*client, close()); @@ -202,7 +202,7 @@ TEST_F(RedisConnPoolImplTest, Basic) { TEST_F(RedisConnPoolImplTest, HostRemove) { InSequence s; - MockActiveRequestCallbacks callbacks; + MockPoolCallbacks callbacks; RespValue value; std::shared_ptr host1(new Upstream::MockHost()); @@ -213,17 +213,17 @@ TEST_F(RedisConnPoolImplTest, HostRemove) { EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)).WillOnce(Return(host1)); EXPECT_CALL(*this, create_(Eq(host1), _)).WillOnce(Return(client1)); - MockActiveRequest active_request1; + MockPoolRequest active_request1; EXPECT_CALL(*client1, makeRequest(Ref(value), Ref(callbacks))).WillOnce(Return(&active_request1)); - ActiveRequest* request1 = conn_pool_.makeRequest("foo", value, callbacks); + PoolRequest* request1 = conn_pool_.makeRequest("foo", value, callbacks); EXPECT_EQ(&active_request1, request1); EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)).WillOnce(Return(host2)); EXPECT_CALL(*this, create_(Eq(host2), _)).WillOnce(Return(client2)); - MockActiveRequest active_request2; + MockPoolRequest active_request2; EXPECT_CALL(*client2, makeRequest(Ref(value), Ref(callbacks))).WillOnce(Return(&active_request2)); - ActiveRequest* request2 = conn_pool_.makeRequest("bar", value, callbacks); + PoolRequest* request2 = conn_pool_.makeRequest("bar", value, callbacks); EXPECT_EQ(&active_request2, request2); EXPECT_CALL(*client2, close()); @@ -237,9 +237,9 @@ TEST_F(RedisConnPoolImplTest, NoHost) { InSequence s; RespValue value; - MockActiveRequestCallbacks callbacks; + MockPoolCallbacks callbacks; EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)).WillOnce(Return(nullptr)); - ActiveRequest* request = conn_pool_.makeRequest("foo", value, callbacks); + PoolRequest* request = conn_pool_.makeRequest("foo", value, callbacks); EXPECT_EQ(nullptr, request); tls_.shutdownThread(); @@ -249,8 +249,8 @@ TEST_F(RedisConnPoolImplTest, RemoteClose) { InSequence s; RespValue value; - MockActiveRequest active_request; - MockActiveRequestCallbacks callbacks; + MockPoolRequest active_request; + MockPoolCallbacks callbacks; MockClient* client = new NiceMock(); EXPECT_CALL(cm_.thread_local_cluster_.lb_, chooseHost(_)); diff --git a/test/common/redis/proxy_filter_test.cc b/test/common/redis/proxy_filter_test.cc index c59c69f4375df..875caae6550f2 100644 --- a/test/common/redis/proxy_filter_test.cc +++ b/test/common/redis/proxy_filter_test.cc @@ -73,8 +73,8 @@ class RedisProxyFilterTest : public testing::Test, public DecoderFactory { MockEncoder* encoder_{new MockEncoder()}; MockDecoder* decoder_{new MockDecoder()}; DecoderCallbacks* decoder_callbacks_{}; - ConnPool::MockInstance conn_pool_; - ProxyFilter filter_{*this, EncoderPtr{encoder_}, conn_pool_}; + CommandSplitter::MockInstance splitter_; + ProxyFilter filter_{*this, EncoderPtr{encoder_}, splitter_}; NiceMock filter_callbacks_; }; @@ -82,22 +82,22 @@ TEST_F(RedisProxyFilterTest, OutOfOrderResponse) { InSequence s; Buffer::OwnedImpl fake_data; - ConnPool::MockActiveRequest request_handle1; - ConnPool::ActiveRequestCallbacks* request_callbacks1; - ConnPool::MockActiveRequest request_handle2; - ConnPool::ActiveRequestCallbacks* request_callbacks2; + CommandSplitter::MockSplitRequest* request_handle1 = new CommandSplitter::MockSplitRequest(); + CommandSplitter::SplitCallbacks* request_callbacks1; + CommandSplitter::MockSplitRequest* request_handle2 = new CommandSplitter::MockSplitRequest(); + CommandSplitter::SplitCallbacks* request_callbacks2; EXPECT_CALL(*decoder_, decode(Ref(fake_data))) .WillOnce(Invoke([&](Buffer::Instance&) -> void { RespValuePtr request1(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)) + EXPECT_CALL(splitter_, makeRequest_(Ref(*request1), _)) .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks1)), Return(&request_handle1))); + DoAll(WithArg<1>(SaveArgAddress(&request_callbacks1)), Return(request_handle1))); decoder_callbacks_->onRespValue(std::move(request1)); RespValuePtr request2(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request2), _)) + EXPECT_CALL(splitter_, makeRequest_(Ref(*request2), _)) .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks2)), Return(&request_handle2))); + DoAll(WithArg<1>(SaveArgAddress(&request_callbacks2)), Return(request_handle2))); decoder_callbacks_->onRespValue(std::move(request2)); })); EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); @@ -115,53 +115,27 @@ TEST_F(RedisProxyFilterTest, OutOfOrderResponse) { filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); } -TEST_F(RedisProxyFilterTest, UpstreamFailure) { - InSequence s; - - Buffer::OwnedImpl fake_data; - ConnPool::MockActiveRequest request_handle1; - ConnPool::ActiveRequestCallbacks* request_callbacks1; - EXPECT_CALL(*decoder_, decode(Ref(fake_data))) - .WillOnce(Invoke([&](Buffer::Instance&) -> void { - RespValuePtr request1(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)) - .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks1)), Return(&request_handle1))); - decoder_callbacks_->onRespValue(std::move(request1)); - })); - EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); - - RespValue error; - error.type(RespType::Error); - error.asString() = "upstream connection error"; - EXPECT_CALL(*encoder_, encode(Eq(ByRef(error)), _)); - EXPECT_CALL(filter_callbacks_.connection_, write(_)); - request_callbacks1->onFailure(); - - filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::LocalClose); -} - TEST_F(RedisProxyFilterTest, DownstreamDisconnectWithActive) { InSequence s; Buffer::OwnedImpl fake_data; - ConnPool::MockActiveRequest request_handle1; - ConnPool::ActiveRequestCallbacks* request_callbacks1; + CommandSplitter::MockSplitRequest* request_handle1 = new CommandSplitter::MockSplitRequest(); + CommandSplitter::SplitCallbacks* request_callbacks1; EXPECT_CALL(*decoder_, decode(Ref(fake_data))) .WillOnce(Invoke([&](Buffer::Instance&) -> void { RespValuePtr request1(new RespValue()); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)) + EXPECT_CALL(splitter_, makeRequest_(Ref(*request1), _)) .WillOnce( - DoAll(WithArg<2>(SaveArgAddress(&request_callbacks1)), Return(&request_handle1))); + DoAll(WithArg<1>(SaveArgAddress(&request_callbacks1)), Return(request_handle1))); decoder_callbacks_->onRespValue(std::move(request1)); })); EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); - EXPECT_CALL(request_handle1, cancel()); + EXPECT_CALL(*request_handle1, cancel()); filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); } -TEST_F(RedisProxyFilterTest, NoClient) { +TEST_F(RedisProxyFilterTest, ImmediateResponse) { InSequence s; Buffer::OwnedImpl fake_data; @@ -169,15 +143,19 @@ TEST_F(RedisProxyFilterTest, NoClient) { EXPECT_CALL(*decoder_, decode(Ref(fake_data))) .WillOnce(Invoke([&](Buffer::Instance&) -> void { decoder_callbacks_->onRespValue(std::move(request1)); })); - EXPECT_CALL(conn_pool_, makeRequest("", Ref(*request1), _)).WillOnce(Return(nullptr)); + EXPECT_CALL(splitter_, makeRequest_(Ref(*request1), _)) + .WillOnce(Invoke([&](const RespValue&, CommandSplitter::SplitCallbacks& callbacks) + -> CommandSplitter::SplitRequest* { + RespValuePtr error(new RespValue()); + error->type(RespType::Error); + error->asString() = "no healthy upstream"; + EXPECT_CALL(*encoder_, encode(Eq(ByRef(*error)), _)); + EXPECT_CALL(filter_callbacks_.connection_, write(_)); + callbacks.onResponse(std::move(error)); + return nullptr; + })); - RespValue error; - error.type(RespType::Error); - error.asString() = "no healthy upstream"; - EXPECT_CALL(*encoder_, encode(Eq(ByRef(error)), _)); - EXPECT_CALL(filter_callbacks_.connection_, write(_)); EXPECT_EQ(Network::FilterStatus::Continue, filter_.onData(fake_data)); - filter_callbacks_.connection_.raiseEvents(Network::ConnectionEvent::RemoteClose); } diff --git a/test/mocks/common.h b/test/mocks/common.h index 27a59c92dcde5..8076be7290859 100644 --- a/test/mocks/common.h +++ b/test/mocks/common.h @@ -12,7 +12,10 @@ ACTION_P(SaveArgAddress, target) { *target = &arg0; } /** * Matcher that matches on whether the pointee of both lhs and rhs are equal. */ -MATCHER_P(PointeesEq, rhs, "") { return *arg == *rhs; } +MATCHER_P(PointeesEq, rhs, "") { + *result_listener << testing::PrintToString(*arg) + " != " + testing::PrintToString(*rhs); + return *arg == *rhs; +} /** * Simple mock that just lets us make sure a method gets called or not called form a lambda. diff --git a/test/mocks/redis/mocks.cc b/test/mocks/redis/mocks.cc index 6ffb7072be7db..582dd5013ab16 100644 --- a/test/mocks/redis/mocks.cc +++ b/test/mocks/redis/mocks.cc @@ -65,14 +65,27 @@ MockClient::MockClient() { MockClient::~MockClient() {} -MockActiveRequest::MockActiveRequest() {} -MockActiveRequest::~MockActiveRequest() {} +MockPoolRequest::MockPoolRequest() {} +MockPoolRequest::~MockPoolRequest() {} -MockActiveRequestCallbacks::MockActiveRequestCallbacks() {} -MockActiveRequestCallbacks::~MockActiveRequestCallbacks() {} +MockPoolCallbacks::MockPoolCallbacks() {} +MockPoolCallbacks::~MockPoolCallbacks() {} MockInstance::MockInstance() {} MockInstance::~MockInstance() {} } // ConnPool + +namespace CommandSplitter { + +MockSplitRequest::MockSplitRequest() {} +MockSplitRequest::~MockSplitRequest() {} + +MockSplitCallbacks::MockSplitCallbacks() {} +MockSplitCallbacks::~MockSplitCallbacks() {} + +MockInstance::MockInstance() {} +MockInstance::~MockInstance() {} + +} // CommandSplitter } // Redis diff --git a/test/mocks/redis/mocks.h b/test/mocks/redis/mocks.h index bf0cd3e5559fa..e0da90d1c93dc 100644 --- a/test/mocks/redis/mocks.h +++ b/test/mocks/redis/mocks.h @@ -1,5 +1,6 @@ #pragma once +#include "envoy/redis/command_splitter.h" #include "envoy/redis/conn_pool.h" #include "common/redis/codec_impl.h" @@ -42,24 +43,23 @@ class MockClient : public Client { MOCK_METHOD1(addConnectionCallbacks, void(Network::ConnectionCallbacks& callbacks)); MOCK_METHOD0(close, void()); - MOCK_METHOD2(makeRequest, - ActiveRequest*(const RespValue& request, ActiveRequestCallbacks& callbacks)); + MOCK_METHOD2(makeRequest, PoolRequest*(const RespValue& request, PoolCallbacks& callbacks)); std::list callbacks_; }; -class MockActiveRequest : public ActiveRequest { +class MockPoolRequest : public PoolRequest { public: - MockActiveRequest(); - ~MockActiveRequest(); + MockPoolRequest(); + ~MockPoolRequest(); MOCK_METHOD0(cancel, void()); }; -class MockActiveRequestCallbacks : public ActiveRequestCallbacks { +class MockPoolCallbacks : public PoolCallbacks { public: - MockActiveRequestCallbacks(); - ~MockActiveRequestCallbacks(); + MockPoolCallbacks(); + ~MockPoolCallbacks(); void onResponse(RespValuePtr&& value) override { onResponse_(value); } @@ -72,9 +72,43 @@ class MockInstance : public Instance { MockInstance(); ~MockInstance(); - MOCK_METHOD3(makeRequest, ActiveRequest*(const std::string& hash_key, const RespValue& request, - ActiveRequestCallbacks& callbacks)); + MOCK_METHOD3(makeRequest, PoolRequest*(const std::string& hash_key, const RespValue& request, + PoolCallbacks& callbacks)); }; } // ConnPool + +namespace CommandSplitter { + +class MockSplitRequest : public SplitRequest { +public: + MockSplitRequest(); + ~MockSplitRequest(); + + MOCK_METHOD0(cancel, void()); +}; + +class MockSplitCallbacks : public SplitCallbacks { +public: + MockSplitCallbacks(); + ~MockSplitCallbacks(); + + void onResponse(RespValuePtr&& value) override { onResponse_(value); } + + MOCK_METHOD1(onResponse_, void(RespValuePtr& value)); +}; + +class MockInstance : public Instance { +public: + MockInstance(); + ~MockInstance(); + + SplitRequestPtr makeRequest(const RespValue& request, SplitCallbacks& callbacks) override { + return SplitRequestPtr{makeRequest_(request, callbacks)}; + } + + MOCK_METHOD2(makeRequest_, SplitRequest*(const RespValue& request, SplitCallbacks& callbacks)); +}; + +} // CommandSplitter } // Redis diff --git a/test/test_common/printers.cc b/test/test_common/printers.cc index b0ec10c69832b..ef0308efbe196 100644 --- a/test/test_common/printers.cc +++ b/test/test_common/printers.cc @@ -1,5 +1,7 @@ #include "printers.h" +#include "envoy/redis/codec.h" + #include "common/buffer/buffer_impl.h" #include "common/http/header_map_impl.h" @@ -18,7 +20,7 @@ void PrintTo(const HeaderMapPtr& headers, std::ostream* os) { void PrintTo(const HeaderMap& headers, std::ostream* os) { PrintTo(*dynamic_cast(&headers), os); } -} +} // Http namespace Buffer { void PrintTo(const Instance& buffer, std::ostream* os) { @@ -28,4 +30,10 @@ void PrintTo(const Instance& buffer, std::ostream* os) { void PrintTo(const Buffer::OwnedImpl& buffer, std::ostream* os) { PrintTo(dynamic_cast(buffer), os); } -} +} // Buffer + +namespace Redis { +void PrintTo(const RespValue& value, std::ostream* os) { *os << value.toString(); } + +void PrintTo(const RespValuePtr& value, std::ostream* os) { *os << value->toString(); } +} // Redis diff --git a/test/test_common/printers.h b/test/test_common/printers.h index a56db2b1e4e93..ba62cd170d340 100644 --- a/test/test_common/printers.h +++ b/test/test_common/printers.h @@ -14,7 +14,7 @@ class HeaderMap; typedef std::unique_ptr HeaderMapPtr; void PrintTo(const HeaderMap& headers, std::ostream* os); void PrintTo(const HeaderMapPtr& headers, std::ostream* os); -} +} // Http namespace Buffer { /** @@ -28,4 +28,14 @@ void PrintTo(const Instance& buffer, std::ostream* os); */ class OwnedImpl; void PrintTo(const OwnedImpl& buffer, std::ostream* os); -} +} // Buffer + +namespace Redis { +/** + * Pretty print const RespValue& value + */ +class RespValue; +typedef std::unique_ptr RespValuePtr; +void PrintTo(const RespValue& value, std::ostream* os); +void PrintTo(const RespValuePtr& value, std::ostream* os); +} // Redis From d7ba1c343fd371b08457fe94eaf09ea243811a0d Mon Sep 17 00:00:00 2001 From: Matt Klein Date: Thu, 23 Mar 2017 16:41:11 -0700 Subject: [PATCH 2/4] fix --- source/common/redis/command_splitter_impl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/common/redis/command_splitter_impl.cc b/source/common/redis/command_splitter_impl.cc index 98860a84d0e77..414dcd8e47c1d 100644 --- a/source/common/redis/command_splitter_impl.cc +++ b/source/common/redis/command_splitter_impl.cc @@ -22,7 +22,7 @@ SplitRequestPtr AllParamsToOneServerCommandHandler::startRequest(const RespValue return nullptr; } - return request_handle; + return std::move(request_handle); } AllParamsToOneServerCommandHandler::SplitRequestImpl::~SplitRequestImpl() { ASSERT(!handle_); } From f30581eb893d4bfd225ce4c7ddf20d03f3a54857 Mon Sep 17 00:00:00 2001 From: Matt Klein Date: Thu, 23 Mar 2017 19:40:22 -0700 Subject: [PATCH 3/4] fix --- source/common/redis/proxy_filter.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/source/common/redis/proxy_filter.cc b/source/common/redis/proxy_filter.cc index 95f39512dbd04..f68f7522ac6c8 100644 --- a/source/common/redis/proxy_filter.cc +++ b/source/common/redis/proxy_filter.cc @@ -19,8 +19,12 @@ ProxyFilter::~ProxyFilter() { ASSERT(pending_requests_.empty()); } void ProxyFilter::onRespValue(RespValuePtr&& value) { pending_requests_.emplace_back(*this); PendingRequest& request = pending_requests_.back(); - request.request_handle_ = splitter_.makeRequest(*value, request); - // The splitter can immediately respond. + CommandSplitter::SplitRequestPtr split = splitter_.makeRequest(*value, request); + if (split) { + // The splitter can immediately respond and destroy the pending request. Only store the handle + // if the request is still alive. + request.request_handle_ = std::move(split); + } } void ProxyFilter::onEvent(uint32_t events) { From 2176c065efffdf0fe1cb363804d12e444312600f Mon Sep 17 00:00:00 2001 From: Matt Klein Date: Mon, 27 Mar 2017 14:13:07 -0700 Subject: [PATCH 4/4] comment --- source/common/redis/command_splitter_impl.cc | 3 ++- test/common/redis/command_splitter_impl_test.cc | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/source/common/redis/command_splitter_impl.cc b/source/common/redis/command_splitter_impl.cc index 414dcd8e47c1d..72902c2fed424 100644 --- a/source/common/redis/command_splitter_impl.cc +++ b/source/common/redis/command_splitter_impl.cc @@ -157,7 +157,8 @@ SplitRequestPtr InstanceImpl::makeRequest(const RespValue& request, SplitCallbac auto handler = command_map_.find(request.asArray()[0].asString()); if (handler == command_map_.end()) { - callbacks.onResponse(Utility::makeError("unsupported command")); + callbacks.onResponse(Utility::makeError( + fmt::format("unsupported command '{}'", request.asArray()[0].asString()))); return nullptr; } diff --git a/test/common/redis/command_splitter_impl_test.cc b/test/common/redis/command_splitter_impl_test.cc index 5490e299da074..74841023921e7 100644 --- a/test/common/redis/command_splitter_impl_test.cc +++ b/test/common/redis/command_splitter_impl_test.cc @@ -67,7 +67,7 @@ TEST_F(RedisCommandSplitterImplTest, InvalidRequestArrayNotStrings) { TEST_F(RedisCommandSplitterImplTest, UnsupportedCommand) { RespValue response; response.type(RespType::Error); - response.asString() = "unsupported command"; + response.asString() = "unsupported command 'newcommand'"; EXPECT_CALL(callbacks_, onResponse_(PointeesEq(&response))); RespValue request; makeBulkStringArray(request, {"newcommand", "hello"});