diff --git a/include/nighthawk/client/client_worker.h b/include/nighthawk/client/client_worker.h
index ecb651dca..ad234aa72 100644
--- a/include/nighthawk/client/client_worker.h
+++ b/include/nighthawk/client/client_worker.h
@@ -34,6 +34,11 @@ class ClientWorker : virtual public Worker {
    * @return const Phase& associated to this worker.
    */
   virtual const Phase& phase() const PURE;
+
+  /**
+   * Requests execution cancellation.
+   */
+  virtual void requestExecutionCancellation() PURE;
 };
 
 using ClientWorkerPtr = std::unique_ptr<ClientWorker>;
diff --git a/include/nighthawk/client/process.h b/include/nighthawk/client/process.h
index 39c214e4c..943517726 100644
--- a/include/nighthawk/client/process.h
+++ b/include/nighthawk/client/process.h
@@ -23,6 +23,11 @@ class Process {
    * Shuts down the worker. Mandatory call before destructing.
    */
   virtual void shutdown() PURE;
+
+  /**
+   * Requests all workers to cancel execution as soon as possible.
+   */
+  virtual bool requestExecutionCancellation() PURE;
 };
 
 using ProcessPtr = std::unique_ptr<Process>;
diff --git a/source/client/client.cc b/source/client/client.cc
index cd3b46743..8c08eda00 100644
--- a/source/client/client.cc
+++ b/source/client/client.cc
@@ -23,6 +23,7 @@
 #include "api/client/service.grpc.pb.h"
 
 #include "common/frequency.h"
+#include "common/signal_handler.h"
 #include "common/uri_impl.h"
 #include "common/utility.h"
 
@@ -73,16 +74,21 @@ bool Main::run() {
   }
   OutputFormatterFactoryImpl output_formatter_factory;
   OutputCollectorImpl output_collector(time_system, *options_);
-  const bool res = process->run(output_collector);
+  bool result;
+  {
+    auto signal_handler =
+        std::make_unique<SignalHandler>([&process]() { process->requestExecutionCancellation(); });
+    result = process->run(output_collector);
+  }
   auto formatter = output_formatter_factory.create(options_->outputFormat());
   std::cout << formatter->formatProto(output_collector.toProto());
   process->shutdown();
-  if (!res) {
+  if (!result) {
     ENVOY_LOG(error, "An error ocurred.");
   } else {
     ENVOY_LOG(info, "Done.");
   }
-  return res;
+  return result;
 }
 
 } // namespace Client
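The scoped `signal_handler` above comes from `common/signal_handler.h`, which is not part of this diff. Judging from the usage — construct with a callback, let the destructor tear it down once `run()` returns — it is an RAII helper along these lines. A minimal sketch, assuming a flag-plus-watcher-thread design; the class name, the polling loop, and the 50ms interval are illustrative, not Nighthawk's actual implementation:

```cpp
#include <atomic>
#include <chrono>
#include <csignal>
#include <functional>
#include <thread>

// RAII signal handler sketch: installing the handlers in the constructor and
// restoring them in the destructor is what makes the enclosing scope in
// Main::run() meaningful -- signals only trigger the callback while the load
// test is actually running.
class SignalHandlerSketch {
public:
  explicit SignalHandlerSketch(std::function<void()> callback) : callback_(std::move(callback)) {
    // A signal handler may only touch async-signal-safe state, so it just
    // flips an atomic flag; a watcher thread invokes the real callback.
    signaled().store(false);
    std::signal(SIGTERM, [](int) { signaled().store(true); });
    std::signal(SIGINT, [](int) { signaled().store(true); });
    watcher_ = std::thread([this] {
      while (!done_.load()) {
        if (signaled().exchange(false)) {
          callback_();
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
      }
    });
  }
  ~SignalHandlerSketch() {
    // Stop watching and restore default signal disposition.
    done_.store(true);
    watcher_.join();
    std::signal(SIGTERM, SIG_DFL);
    std::signal(SIGINT, SIG_DFL);
  }

private:
  static std::atomic<bool>& signaled() {
    static std::atomic<bool> flag{false};
    return flag;
  }
  std::function<void()> callback_;
  std::atomic<bool> done_{false};
  std::thread watcher_;
};
```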
diff --git a/source/client/client_worker_impl.cc b/source/client/client_worker_impl.cc
index e4231bc87..09aaba66a 100644
--- a/source/client/client_worker_impl.cc
+++ b/source/client/client_worker_impl.cc
@@ -87,6 +87,14 @@ void ClientWorkerImpl::work() {
 
 void ClientWorkerImpl::shutdownThread() { benchmark_client_->terminate(); }
 
+void ClientWorkerImpl::requestExecutionCancellation() {
+  // We just bump a counter, which is watched by a static termination predicate.
+  // A useful side effect is that this counter will propagate to the output, which
+  // leaves a note that execution was subject to cancellation.
+  dispatcher_->post(
+      [this]() { worker_number_scope_->counterFromString("graceful_stop_requested").inc(); });
+}
+
 StatisticPtrMap ClientWorkerImpl::statistics() const {
   StatisticPtrMap statistics;
   StatisticPtrMap s1 = benchmark_client_->statistics();
diff --git a/source/client/client_worker_impl.h b/source/client/client_worker_impl.h
index 05f2fcb35..41b2660bc 100644
--- a/source/client/client_worker_impl.h
+++ b/source/client/client_worker_impl.h
@@ -46,6 +46,8 @@ class ClientWorkerImpl : public WorkerImpl, virtual public ClientWorker {
 
   void shutdownThread() override;
 
+  void requestExecutionCancellation() override;
+
 protected:
   void work() override;
 
diff --git a/source/client/factories_impl.cc b/source/client/factories_impl.cc
index 0f63700b8..a2e3eb611 100644
--- a/source/client/factories_impl.cc
+++ b/source/client/factories_impl.cc
@@ -174,14 +174,18 @@ TerminationPredicateFactoryImpl::TerminationPredicateFactoryImpl(const Options&
 TerminationPredicatePtr TerminationPredicateFactoryImpl::create(
     Envoy::TimeSource& time_source, Envoy::Stats::Scope& scope,
     const Envoy::MonotonicTime scheduled_starting_time) const {
-  TerminationPredicatePtr root_predicate;
-  if (options_.noDuration()) {
-    root_predicate = std::make_unique<NullTerminationPredicateImpl>();
-  } else {
-    root_predicate = std::make_unique<DurationTerminationPredicateImpl>(
-        time_source, options_.duration(), scheduled_starting_time);
-  }
+  // We'll always link a predicate which checks for requests to cancel.
+  TerminationPredicatePtr root_predicate =
+      std::make_unique<StatsCounterAbsoluteThresholdTerminationPredicateImpl>(
+          scope.counterFromString("graceful_stop_requested"), 0,
+          TerminationPredicate::Status::TERMINATE);
+  TerminationPredicate* current_predicate = root_predicate.get();
+  if (!options_.noDuration()) {
+    current_predicate = &current_predicate->link(std::make_unique<DurationTerminationPredicateImpl>(
+        time_source, options_.duration(), scheduled_starting_time));
+  }
+
   current_predicate = linkConfiguredPredicates(*current_predicate, options_.failurePredicates(),
                                                TerminationPredicate::Status::FAIL, scope);
   linkConfiguredPredicates(*current_predicate, options_.terminationPredicates(),
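After this change the root of every predicate chain is the cancellation check, with the duration predicate (when a duration is configured) and any configured failure/termination predicates linked behind it. Below is a simplified, self-contained model of how such a chain evaluates. Only the `link()`/`evaluate()` names and the `Status` values mirror this diff; the `Predicate` type, the `evaluateChain()` walk, and the plain `uint64_t` counter are stand-ins for the real Nighthawk interfaces:

```cpp
#include <cstdint>
#include <iostream>
#include <memory>

// Simplified model of the chain TerminationPredicateFactoryImpl::create()
// now builds; everything beyond link()/evaluate()/Status is illustrative.
struct Predicate {
  enum class Status { PROCEED, TERMINATE, FAIL };
  virtual ~Predicate() = default;
  virtual Status evaluate() = 0;
  // link() hands back a reference to the newly attached child, which is what
  // lets the factory keep appending via a bare current_predicate pointer.
  Predicate& link(std::unique_ptr<Predicate> child) {
    child_ = std::move(child);
    return *child_;
  }
  // Walk the chain; the first non-PROCEED verdict wins.
  Status evaluateChain() {
    const Status status = evaluate();
    if (status != Status::PROCEED) {
      return status;
    }
    return child_ != nullptr ? child_->evaluateChain() : Status::PROCEED;
  }
  std::unique_ptr<Predicate> child_;
};

// The unconditional root after this change: TERMINATE as soon as the
// graceful_stop_requested counter exceeds the threshold of zero.
struct CancellationPredicate : public Predicate {
  explicit CancellationPredicate(const uint64_t& counter) : counter_(counter) {}
  Status evaluate() override { return counter_ > 0 ? Status::TERMINATE : Status::PROCEED; }
  const uint64_t& counter_;
};

int main() {
  uint64_t graceful_stop_requested = 0;
  CancellationPredicate root(graceful_stop_requested);
  // Before cancellation is requested, the chain proceeds (prints 1)...
  std::cout << (root.evaluateChain() == Predicate::Status::PROCEED) << "\n";
  // ...and a single counter bump, as posted by requestExecutionCancellation(),
  // is enough to terminate (prints 1).
  ++graceful_stop_requested;
  std::cout << (root.evaluateChain() == Predicate::Status::TERMINATE) << "\n";
}
```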
diff --git a/source/client/process_impl.cc b/source/client/process_impl.cc
index 311ef938f..85799d723 100644
--- a/source/client/process_impl.cc
+++ b/source/client/process_impl.cc
@@ -126,11 +126,15 @@ void ProcessImpl::shutdown() {
   // Before we shut down the worker threads, stop threading.
   tls_.shutdownGlobalThreading();
   store_root_.shutdownThreading();
-  // Before shutting down the cluster manager, stop the workers.
-  for (auto& worker : workers_) {
-    worker->shutdown();
+
+  {
+    auto guard = std::make_unique<Envoy::Thread::LockGuard>(workers_lock_);
+    // Before shutting down the cluster manager, stop the workers.
+    for (auto& worker : workers_) {
+      worker->shutdown();
+    }
+    workers_.clear();
   }
-  workers_.clear();
   if (cluster_manager_ != nullptr) {
     cluster_manager_->shutdown();
   }
@@ -138,7 +142,17 @@ void ProcessImpl::shutdown() {
   shutdown_ = true;
 }
 
-const std::vector<ClientWorkerPtr>& ProcessImpl::createWorkers(const uint32_t concurrency) {
+bool ProcessImpl::requestExecutionCancellation() {
+  ENVOY_LOG(debug, "Requesting workers to cancel execution");
+  auto guard = std::make_unique<Envoy::Thread::LockGuard>(workers_lock_);
+  for (auto& worker : workers_) {
+    worker->requestExecutionCancellation();
+  }
+  cancelled_ = true;
+  return true;
+}
+
+void ProcessImpl::createWorkers(const uint32_t concurrency) {
   // TODO(oschaaf): Expose kMinimalDelay in configuration.
   const std::chrono::milliseconds kMinimalWorkerDelay = 500ms + (concurrency * 50ms);
   ASSERT(workers_.empty());
@@ -168,7 +182,6 @@ const std::vector<ClientWorkerPtr>& ProcessImpl::createWorkers(const uint32_t concurrency) {
                            : ClientWorkerImpl::HardCodedWarmupStyle::OFF));
     worker_number++;
   }
-  return workers_;
 }
 
 void ProcessImpl::configureComponentLogLevels(spdlog::level::level_enum level) {
@@ -381,44 +394,50 @@ void ProcessImpl::addRequestSourceCluster(
 
 bool ProcessImpl::runInternal(OutputCollector& collector, const std::vector<UriPtr>& uris,
                               const UriPtr& request_source_uri, const UriPtr& tracing_uri) {
-  int number_of_workers = determineConcurrency();
-  shutdown_ = false;
-  const std::vector<ClientWorkerPtr>& workers = createWorkers(number_of_workers);
-  tls_.registerThread(*dispatcher_, true);
-  store_root_.initializeThreading(*dispatcher_, tls_);
-  runtime_singleton_ = std::make_unique<Envoy::Runtime::ScopedLoaderSingleton>(
-      Envoy::Runtime::LoaderPtr{new Envoy::Runtime::LoaderImpl(
-          *dispatcher_, tls_, {}, *local_info_, store_root_, generator_,
-          Envoy::ProtobufMessage::getStrictValidationVisitor(), *api_)});
-  ssl_context_manager_ =
-      std::make_unique<Envoy::Extensions::TransportSockets::Tls::ContextManagerImpl>(time_system_);
-  cluster_manager_factory_ = std::make_unique<ClusterManagerFactory>(
-      admin_, Envoy::Runtime::LoaderSingleton::get(), store_root_, tls_, generator_,
-      dispatcher_->createDnsResolver({}, false), *ssl_context_manager_, *dispatcher_, *local_info_,
-      secret_manager_, validation_context_, *api_, http_context_, grpc_context_,
-      access_log_manager_, *singleton_manager_);
-  cluster_manager_factory_->setConnectionReuseStrategy(
-      options_.h1ConnectionReuseStrategy() == nighthawk::client::H1ConnectionReuseStrategy::LRU
-          ? Http1PoolImpl::ConnectionReuseStrategy::LRU
-          : Http1PoolImpl::ConnectionReuseStrategy::MRU);
-  cluster_manager_factory_->setPrefetchConnections(options_.prefetchConnections());
-  envoy::config::bootstrap::v3::Bootstrap bootstrap;
-  createBootstrapConfiguration(bootstrap, uris, request_source_uri, number_of_workers);
-  if (tracing_uri != nullptr) {
-    setupTracingImplementation(bootstrap, *tracing_uri);
-    addTracingCluster(bootstrap, *tracing_uri);
-  }
-  ENVOY_LOG(debug, "Computed configuration: {}", bootstrap.DebugString());
-  cluster_manager_ = cluster_manager_factory_->clusterManagerFromProto(bootstrap);
-  maybeCreateTracingDriver(bootstrap.tracing());
-  cluster_manager_->setInitializedCb([this]() -> void { init_manager_.initialize(init_watcher_); });
+  {
+    auto guard = std::make_unique<Envoy::Thread::LockGuard>(workers_lock_);
+    if (cancelled_) {
+      return true;
+    }
+    int number_of_workers = determineConcurrency();
+    shutdown_ = false;
+    createWorkers(number_of_workers);
+    tls_.registerThread(*dispatcher_, true);
+    store_root_.initializeThreading(*dispatcher_, tls_);
+    runtime_singleton_ = std::make_unique<Envoy::Runtime::ScopedLoaderSingleton>(
+        Envoy::Runtime::LoaderPtr{new Envoy::Runtime::LoaderImpl(
+            *dispatcher_, tls_, {}, *local_info_, store_root_, generator_,
+            Envoy::ProtobufMessage::getStrictValidationVisitor(), *api_)});
+    ssl_context_manager_ =
+        std::make_unique<Envoy::Extensions::TransportSockets::Tls::ContextManagerImpl>(time_system_);
+    cluster_manager_factory_ = std::make_unique<ClusterManagerFactory>(
+        admin_, Envoy::Runtime::LoaderSingleton::get(), store_root_, tls_, generator_,
+        dispatcher_->createDnsResolver({}, false), *ssl_context_manager_, *dispatcher_,
+        *local_info_, secret_manager_, validation_context_, *api_, http_context_, grpc_context_,
+        access_log_manager_, *singleton_manager_);
+    cluster_manager_factory_->setConnectionReuseStrategy(
+        options_.h1ConnectionReuseStrategy() == nighthawk::client::H1ConnectionReuseStrategy::LRU
+            ? Http1PoolImpl::ConnectionReuseStrategy::LRU
+            : Http1PoolImpl::ConnectionReuseStrategy::MRU);
+    cluster_manager_factory_->setPrefetchConnections(options_.prefetchConnections());
+    envoy::config::bootstrap::v3::Bootstrap bootstrap;
+    createBootstrapConfiguration(bootstrap, uris, request_source_uri, number_of_workers);
+    if (tracing_uri != nullptr) {
+      setupTracingImplementation(bootstrap, *tracing_uri);
+      addTracingCluster(bootstrap, *tracing_uri);
+    }
+    ENVOY_LOG(debug, "Computed configuration: {}", bootstrap.DebugString());
+    cluster_manager_ = cluster_manager_factory_->clusterManagerFromProto(bootstrap);
+    maybeCreateTracingDriver(bootstrap.tracing());
+    cluster_manager_->setInitializedCb(
+        [this]() -> void { init_manager_.initialize(init_watcher_); });
 
-  Runtime::LoaderSingleton::get().initialize(*cluster_manager_);
+    Runtime::LoaderSingleton::get().initialize(*cluster_manager_);
 
-  for (auto& w : workers_) {
-    w->start();
+    for (auto& w : workers_) {
+      w->start();
+    }
   }
-
   for (auto& w : workers_) {
     w->waitForCompletion();
   }
@@ -447,7 +466,7 @@ bool ProcessImpl::runInternal(OutputCollector& collector, const std::vector<UriPtr>& uris,
   const auto& counters = Utility().mapCountersFromStore(
       store_root_, [](absl::string_view, const uint64_t value) { return value > 0; });
   StatisticFactoryImpl statistic_factory(options_);
-  collector.addResult("global", mergeWorkerStatistics(workers), counters,
+  collector.addResult("global", mergeWorkerStatistics(workers_), counters,
                       total_execution_duration / workers_.size());
   return counters.find("sequencer.failed_terminations") == counters.end();
 }
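The `workers_lock_`/`cancelled_` pair exists because cancellation (driven by a signal on another thread) can race with `runInternal()` before any worker has been created. A minimal sketch of that scheme, with simplified names and `std::mutex` standing in for the `Envoy::Thread` primitives used above:

```cpp
#include <functional>
#include <mutex>
#include <vector>

// Essence of the locking scheme: cancellation races with startup, so both
// paths serialize on one mutex, and a cancel that wins the race marks the
// process so run() degrades to a successful no-op.
class CancellableProcessSketch {
public:
  bool requestExecutionCancellation() {
    std::lock_guard<std::mutex> guard(workers_lock_);
    // If workers already exist, each one is asked to stop; in the real code
    // this posts a "graceful_stop_requested" counter bump onto each worker's
    // dispatcher, which its termination predicate chain observes.
    for (auto& cancel_worker : workers_) {
      cancel_worker();
    }
    cancelled_ = true;
    return true;
  }

  bool run() {
    {
      std::lock_guard<std::mutex> guard(workers_lock_);
      if (cancelled_) {
        // Cancelled before any worker was created: succeed with empty
        // results instead of starting the load test.
        return true;
      }
      // ... create and start workers while still holding the lock ...
    }
    // ... wait for completion outside the lock so a cancel can get in ...
    return true;
  }

private:
  std::mutex workers_lock_;
  std::vector<std::function<void()>> workers_;
  bool cancelled_{false};
};
```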
diff --git a/source/client/process_impl.h b/source/client/process_impl.h
index d1f0334cf..11d265fb8 100644
--- a/source/client/process_impl.h
+++ b/source/client/process_impl.h
@@ -80,6 +80,8 @@ class ProcessImpl : public Process, public Envoy::Logger::Loggable<Envoy::Logger::Id::main> {
   bool run(OutputCollector& collector) override;
   void shutdown() override;
 
+  bool requestExecutionCancellation() override;
+
@@ -120,7 +122,13 @@ class ProcessImpl : public Process, public Envoy::Logger::Loggable<Envoy::Logger::Id::main> {
-  const std::vector<ClientWorkerPtr>& createWorkers(const uint32_t concurrency);
+  /**
+   * Prepare the ProcessImpl instance by creating and configuring the workers it needs for
+   * execution of the load test.
+   *
+   * @param concurrency the number of workers that should be created.
+   */
+  void createWorkers(const uint32_t concurrency);
   std::vector<StatisticPtr> vectorizeStatisticPtrMap(const StatisticPtrMap& statistics) const;
   std::vector<StatisticPtr>
   mergeWorkerStatistics(const std::vector<ClientWorkerPtr>& workers) const;
@@ -145,6 +153,8 @@ class ProcessImpl : public Process, public Envoy::Logger::Loggable<Envoy::Logger::Id::main> {
   std::vector<ClientWorkerPtr> workers_;
+  Envoy::Thread::MutexBasicLockable workers_lock_;
+  bool cancelled_{false};
diff --git a/source/common/termination_predicate_impl.cc b/source/common/termination_predicate_impl.cc
--- a/source/common/termination_predicate_impl.cc
+++ b/source/common/termination_predicate_impl.cc
@@ -5,10 +5,6 @@
-TerminationPredicate::Status NullTerminationPredicateImpl::evaluate() {
-  return TerminationPredicate::Status::PROCEED;
-}
-
 TerminationPredicate::Status DurationTerminationPredicateImpl::evaluate() {
   return time_source_.monotonicTime() - start_ > duration_ ? TerminationPredicate::Status::TERMINATE
                                                            : TerminationPredicate::Status::PROCEED;
 }
diff --git a/source/common/termination_predicate_impl.h b/source/common/termination_predicate_impl.h
index 80e4c67a9..c1c761345 100644
--- a/source/common/termination_predicate_impl.h
+++ b/source/common/termination_predicate_impl.h
@@ -27,14 +27,6 @@ class TerminationPredicateBaseImpl : public TerminationPredicate {
   TerminationPredicatePtr linked_child_;
 };
 
-/**
- * Predicate which always returns TerminationPredicate::Status::PROCEED.
- */
-class NullTerminationPredicateImpl : public TerminationPredicateBaseImpl {
-public:
-  TerminationPredicate::Status evaluate() override;
-};
-
 /**
  * Predicate which indicates termination iff the passed in duration has expired.
  * time tracking starts at the first call to evaluate().
diff --git a/test/integration/test_integration_basics.py b/test/integration/test_integration_basics.py
index df02ef193..d82fdb478 100644
--- a/test/integration/test_integration_basics.py
+++ b/test/integration/test_integration_basics.py
@@ -1,9 +1,13 @@
 #!/usr/bin/env python3
+import json
 import logging
 import os
+import subprocess
 import sys
 import pytest
+import time
+from threading import Thread
 
 from test.integration.common import IpVersion
 from test.integration.integration_test_fixtures import (
@@ -654,6 +658,35 @@ def test_http_request_release_timing(http_test_server_fixture, qps_parameterization_fixture):
   assertCounterEqual(counters, "benchmark.http_2xx", (total_requests))
 
 
+def _send_sigterm(process):
+  # Sleep for a while; under tsan the client needs a lot of time
+  # to start up. 10 seconds has been determined to work through
+  # empirical observation.
+  time.sleep(10)
+  process.terminate()
+
+
+def test_cancellation(http_test_server_fixture):
+  """
+  Make sure that we can use signals to cancel execution.
+  """
+  args = [
+      http_test_server_fixture.nighthawk_client_path, "--concurrency", "2",
+      http_test_server_fixture.getTestServerRootUri(), "--duration", "1000", "--output-format",
+      "json"
+  ]
+  client_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+  Thread(target=(lambda: _send_sigterm(client_process))).start()
+  stdout, stderr = client_process.communicate()
+  client_process.wait()
+  output = stdout.decode('utf-8')
+  assertEqual(client_process.returncode, 0)
+  parsed_json = json.loads(output)
+  counters = http_test_server_fixture.getNighthawkCounterMapFromJson(parsed_json)
+  assertCounterEqual(counters, "graceful_stop_requested", 2)
+  assertCounterGreaterEqual(counters, "benchmark.http_2xx", 1)
+
+
 def _run_client_with_args(args):
   return run_binary_with_args("nighthawk_client", args)
diff --git a/test/process_test.cc b/test/process_test.cc
index 5072bf245..b2b977175 100644
--- a/test/process_test.cc
+++ b/test/process_test.cc
@@ -1,3 +1,4 @@
+#include <thread>
 #include "nighthawk/common/exception.h"
 
@@ -32,12 +33,48 @@ class ProcessTest : public TestWithParam<Envoy::Network::Address::IpVersion> {
       : loopback_address_(Envoy::Network::Test::getLoopbackAddressUrlString(GetParam())),
         options_(TestUtility::createOptionsImpl(
             fmt::format("foo --duration 1 -v error --rps 10 https://{}/", loopback_address_))){};
-  void runProcess(RunExpectation expectation) {
+
+  void runProcess(RunExpectation expectation, bool do_cancel = false,
+                  bool terminate_right_away = false) {
     ProcessPtr process = std::make_unique<ProcessImpl>(*options_, time_system_);
     OutputCollectorImpl collector(time_system_, *options_);
+    std::thread cancel_thread;
+    if (do_cancel) {
+      cancel_thread = std::thread([&process, terminate_right_away] {
+        if (!terminate_right_away) {
+          // We sleep to give the load test execution in the other thread a chance to get
+          // started before we request cancellation. Five seconds has been determined to work
+          // with the sanitizer runs in CI through empirical observation.
+          sleep(5);
+        }
+        process->requestExecutionCancellation();
+      });
+      if (terminate_right_away) {
+        cancel_thread.join();
+      }
+    }
     const auto result = process->run(collector) ? RunExpectation::EXPECT_SUCCESS
                                                 : RunExpectation::EXPECT_FAILURE;
     EXPECT_EQ(result, expectation);
+    if (do_cancel) {
+      if (cancel_thread.joinable()) {
+        cancel_thread.join();
+      }
+      auto proto = collector.toProto();
+      if (terminate_right_away) {
+        EXPECT_EQ(0, proto.results().size());
+      } else {
+        int graceful_stop_requested = 0;
+        for (const auto& result : proto.results()) {
+          for (const auto& counter : result.counters()) {
+            if (counter.name() == "graceful_stop_requested") {
+              graceful_stop_requested++;
+            }
+          }
+        }
+        EXPECT_EQ(3, graceful_stop_requested); // global results + two workers
+      }
+    }
     process->shutdown();
   }
 
@@ -64,5 +101,22 @@ TEST_P(ProcessTest, BadTracerSpec) {
   runProcess(RunExpectation::EXPECT_FAILURE);
 }
 
+TEST_P(ProcessTest, CancelDuringLoadTest) {
+  // The failure predicate below is there to wipe out any stock ones. We want this to run for a
+  // long time, even if the upstream fails (there is no live upstream in this test; we send
+  // traffic into the void), so we can check that cancellation works.
+  options_ = TestUtility::createOptionsImpl(
+      fmt::format("foo --duration 300 --failure-predicate foo:0 --concurrency 2 https://{}/",
+                  loopback_address_));
+  runProcess(RunExpectation::EXPECT_SUCCESS, true);
+}
+
+TEST_P(ProcessTest, CancelExecutionBeforeBeginLoadTest) {
+  options_ = TestUtility::createOptionsImpl(
+      fmt::format("foo --duration 300 --failure-predicate foo:0 --concurrency 2 https://{}/",
+                  loopback_address_));
+  runProcess(RunExpectation::EXPECT_SUCCESS, true, true);
+}
+
 } // namespace Client
 } // namespace Nighthawk
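The new `Process::requestExecutionCancellation()` entry point is deliberately callable from a foreign thread: both the signal handler in `client.cc` and the `cancel_thread` above exercise exactly that path. A hypothetical embedding under the same assumption — the `runWithBudget` helper is illustrative and the exact include paths are assumptions, not part of this diff:

```cpp
#include <chrono>
#include <thread>

#include "nighthawk/client/output_collector.h" // header path assumed
#include "nighthawk/client/process.h"          // header path assumed

// Run a load test and cancel it from a watchdog thread after a time budget,
// mirroring what the tests above do with their cancel_thread.
bool runWithBudget(Nighthawk::Client::Process& process,
                   Nighthawk::Client::OutputCollector& collector,
                   std::chrono::seconds budget) {
  std::thread watchdog([&process, budget] {
    std::this_thread::sleep_for(budget);
    // Safe to call from another thread: it takes workers_lock_ and posts the
    // counter bump onto each worker's dispatcher.
    process.requestExecutionCancellation();
  });
  const bool ok = process.run(collector);
  watchdog.join();
  process.shutdown();
  return ok;
}
```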