diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 8805850e3145..428d71960bb9 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -759,6 +759,7 @@ integrationtest_ubuntu_cpu_dist_kvstore() { ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision + ../../tools/launch.py -n 3 --launcher local python test_server_profiling.py } integrationtest_ubuntu_gpu_scala() { diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py index 67cda78172b6..b3b13053addf 100755 --- a/example/image-classification/common/fit.py +++ b/example/image-classification/common/fit.py @@ -135,6 +135,12 @@ def add_fit_args(parser): help='the epochs to ramp-up lr to scaled large-batch value') train.add_argument('--warmup-strategy', type=str, default='linear', help='the ramping-up strategy for large batch sgd') + train.add_argument('--profile-worker-suffix', type=str, default='', + help='profile workers actions into this file. 
During distributed training\ + filename saved will be rank1_ followed by this suffix') + train.add_argument('--profile-server-suffix', type=str, default='', + help='profile server actions into a file with name like rank1_ followed by this suffix \ + during distributed training') return train @@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs): if args.gc_type != 'none': kv.set_gradient_compression({'type': args.gc_type, 'threshold': args.gc_threshold}) + if args.profile_server_suffix: + mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server') + mx.profiler.set_state(state='run', profile_process='server') + + if args.profile_worker_suffix: + if kv.num_workers > 1: + filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix + else: + filename = args.profile_worker_suffix + mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker') + mx.profiler.set_state(state='run', profile_process='worker') # logging head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s' @@ -180,7 +197,6 @@ def fit(args, network, data_loader, **kwargs): logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i, args.disp_batches * args.batch_size / (time.time() - tic)) tic = time.time() - return # load model @@ -314,3 +330,8 @@ def fit(args, network, data_loader, **kwargs): epoch_end_callback=checkpoint, allow_missing=True, monitor=monitor) + + if args.profile_server_suffix: + mx.profiler.set_state(state='stop', profile_process='server')  # training finished: stop the server profiler started before training + if args.profile_worker_suffix: + mx.profiler.set_state(state='stop', profile_process='worker')  # training finished: stop the worker profiler started before training \ No newline at end of file diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 75147cfd706d..6bbe9dfe8f0a 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -230,7 +230,19 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id); MXNET_DLL int MXNotifyShutdown(); /*! 
- * \brief Set up configuration of profiler + * \brief Set up configuration of profiler for the process passed as profile_process in keys + * \param num_params Number of parameters + * \param keys array of parameter keys + * \param vals array of parameter values + * \param kvstoreHandle handle to kvstore + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys, + const char* const* vals, + KVStoreHandle kvstoreHandle); + +/*! + * \brief Set up configuration of profiler for worker/current process * \param num_params Number of parameters * \param keys array of parameter keys * \param vals array of parameter values @@ -239,7 +251,21 @@ MXNET_DLL int MXNotifyShutdown(); MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals); /*! - * \brief Set up state of profiler + * \brief Set up state of profiler for either worker or server process + * \param state indicate the working state of profiler, + * profiler not running when state == 0, + * profiler running when state == 1 + * \param profile_process an int, + * when 0 command is for worker/current process, + * when 1 command is for server process + * \param kvStoreHandle handle to kvstore, needed for server process profiling + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process, + KVStoreHandle kvStoreHandle); + +/*! + * \brief Set up state of profiler for current process * \param state indicate the working state of profiler, * profiler not running when state == 0, * profiler running when state == 1 @@ -250,11 +276,22 @@ MXNET_DLL int MXSetProfilerState(int state); /*! 
* \brief Save profile and stop profiler * \param finished true if stat output should stop after this point + * \param profile_process an int, + * when 0 command is for worker/current process, + * when 1 command is for server process + * \param kvStoreHandle handle to kvstore * \return 0 when success, -1 when failure happens. */ -MXNET_DLL int MXDumpProfile(int finished); +MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle); +/*! + * \brief Save profile and stop profiler for worker/current process + * \param finished true if stat output should stop after this point + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXDumpProfile(int finished); + /*! * \brief Print aggregate stats to the a string * \param out_str Will receive a pointer to the output string @@ -267,6 +304,16 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset); /*! * \brief Pause profiler tuning collection * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues + * \param profile_process an int, 0 for the worker/current process, 1 for the server process + * \param kvStoreHandle handle to kvstore + * \return 0 when success, -1 when failure happens. + * \note pausing and resuming is global and not recursive + */ +MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle); + +/*! + * \brief Pause profiler tuning collection for worker/current process + * \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues * \return 0 when success, -1 when failure happens. 
* \note pausing and resuming is global and not recursive */ @@ -2145,8 +2192,7 @@ typedef void (MXKVStoreServerController)(int head, void *controller_handle); /** - * \return Run as server (or scheduler) - * + * \brief Run as server (or scheduler) * \param handle handle to the KVStore * \param controller the user-defined server controller * \param controller_handle helper handle for implementing controller @@ -2157,8 +2203,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle, void *controller_handle); /** - * \return Send a command to all server nodes - * + * \brief Send a command to all server nodes * \param handle handle to the KVStore * \param cmd_id the head of the command * \param cmd_body the body of the command diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index e10bd213aa26..a73d96356132 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -38,6 +38,18 @@ #endif // MXNET_USE_DIST_KVSTORE namespace mxnet { + +/*! + * \brief enum to denote types of commands kvstore sends to server regarding profiler + * kSetConfig sets profiler configs. Similar to mx.profiler.set_config() + * kState allows changing state of profiler to stop or run + * kPause allows pausing and resuming of profiler + * kDump asks profiler to dump output + */ +enum class KVStoreServerProfilerCommand { + kSetConfig, kState, kPause, kDump +}; + /*! * \brief distributed key-value store * @@ -364,6 +376,20 @@ class KVStore { */ virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { } + /** + * \brief Sends server profiler commands to all server nodes + * Only the worker with rank=0 sends the command which will be received by all servers + * \param type ProfilerCommand type + * \param params parameters for that command in the form of a string + */ + virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, + const std::string& params) { + LOG(INFO) << "Unable to pass server the profiler command. 
If you are using " << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1. " << "If you are training on single machine, then there is no server process " << "to profile. Please profile the worker process instead."; + } + /** * \brief the prototype of a server controller */ diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 609733659753..a54817501391 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -28,6 +28,7 @@ from .base import check_call, string_types, mx_uint, py_str from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt +from .profiler import set_kvstore_handle def _ctype_key_value(keys, vals): """ @@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command): 'kSetMultiPrecision': 1, 'kStopServer': 2, 'kSyncMode': 3, - 'kSetGradientCompression': 4} + 'kSetGradientCompression': 4, + 'kSetProfilerParams': 5} assert (command in command_types), "Unknown command type to send to server" return command_types[command] @@ -670,4 +672,6 @@ def create(name='local'): handle = KVStoreHandle() check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(handle))) - return KVStore(handle) + kv = KVStore(handle) + set_kvstore_handle(kv.handle) + return kv diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py index 0e7a31c687ef..0b5e85b1eb54 100644 --- a/python/mxnet/profiler.py +++ b/python/mxnet/profiler.py @@ -22,8 +22,13 @@ from __future__ import absolute_import import ctypes import warnings -from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str +from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle +profiler_kvstore_handle = KVStoreHandle() + +def set_kvstore_handle(handle): + global profiler_kvstore_handle + profiler_kvstore_handle = handle def set_config(**kwargs): """Set up the configure of profiler (only accepts keyword arguments). 
@@ -49,12 +54,17 @@ def set_config(**kwargs): aggregate_stats : boolean, whether to maintain aggregate stats in memory for console dump. Has some negative performance impact. + profile_process : string + whether to profile kvstore `server` or `worker`. + server can only be profiled when kvstore is of type dist. + if this is not passed, defaults to `worker` """ kk = kwargs.keys() vv = kwargs.values() - check_call(_LIB.MXSetProfilerConfig(len(kwargs), - c_str_array([key for key in kk]), - c_str_array([str(val) for val in vv]))) + check_call(_LIB.MXSetProcessProfilerConfig(len(kwargs), + c_str_array([key for key in kk]), + c_str_array([str(val) for val in vv]), + profiler_kvstore_handle)) def profiler_set_config(mode='symbolic', filename='profile.json'): @@ -73,10 +83,10 @@ def profiler_set_config(mode='symbolic', filename='profile.json'): keys = c_str_array([key for key in ["profile_" + mode, "filename"]]) values = c_str_array([str(val) for val in [True, filename]]) assert len(keys) == len(values) - check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values)) + check_call(_LIB.MXSetProcessProfilerConfig(len(keys), keys, values, profiler_kvstore_handle)) -def set_state(state='stop'): +def set_state(state='stop', profile_process='worker'): """Set up the profiler state to 'run' or 'stop'. Parameters @@ -84,9 +94,16 @@ def set_state(state='stop'): state : string, optional Indicates whether to run the profiler, can be 'stop' or 'run'. Default is `stop`. + profile_process : string + whether to profile kvstore `server` or `worker`. + server can only be profiled when kvstore is of type dist. 
+ if this is not passed, defaults to `worker` """ state2int = {'stop': 0, 'run': 1} - check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state]))) + profile_process2int = {'worker': 0, 'server': 1} + check_call(_LIB.MXSetProcessProfilerState(ctypes.c_int(state2int[state]), + profile_process2int[profile_process], + profiler_kvstore_handle)) def profiler_set_state(state='stop'): @@ -102,7 +119,7 @@ def profiler_set_state(state='stop'): 'Please use profiler.set_state() instead') set_state(state) -def dump(finished=True): +def dump(finished=True, profile_process='worker'): """Dump profile and stop profiler. Use this to save profile in advance in case your program cannot exit normally. @@ -111,9 +128,16 @@ def dump(finished=True): finished : boolean Indicates whether to stop statistic output (dumping) after this dump. Default is True + profile_process : string + whether to profile kvstore `server` or `worker`. + server can only be profiled when kvstore is of type dist. + if this is not passed, defaults to `worker` """ - fin = 1 if finished is True else False - check_call(_LIB.MXDumpProfile(fin)) + fin = 1 if finished is True else 0 + profile_process2int = {'worker': 0, 'server': 1} + check_call(_LIB.MXDumpProcessProfile(fin, + profile_process2int[profile_process], + profiler_kvstore_handle)) def dump_profile(): @@ -138,14 +162,37 @@ def dumps(reset=False): return py_str(debug_str.value) -def pause(): - """Pause profiling.""" - check_call(_LIB.MXProfilePause(int(1))) +def pause(profile_process='worker'): + """Pause profiling. + + Parameters + ---------- + profile_process : string + whether to profile kvstore `server` or `worker`. + server can only be profiled when kvstore is of type dist. 
+ if this is not passed, defaults to `worker` + """ + profile_process2int = {'worker': 0, 'server': 1} + check_call(_LIB.MXProcessProfilePause(int(1), + profile_process2int[profile_process], + profiler_kvstore_handle)) + +def resume(profile_process='worker'): + """ + Resume paused profiling. -def resume(): - """Resume paused profiling.""" - check_call(_LIB.MXProfilePause(int(0))) + Parameters + ---------- + profile_process : string + whether to profile kvstore `server` or `worker`. + server can only be profiled when kvstore is of type dist. + if this is not passed, defaults to `worker` + """ + profile_process2int = {'worker': 0, 'server': 1} + check_call(_LIB.MXProcessProfilePause(int(0), + profile_process2int[profile_process], + profiler_kvstore_handle)) class Domain(object): diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc index c5841775794d..9c03b339e3ca 100644 --- a/src/c_api/c_api_profile.cc +++ b/src/c_api/c_api_profile.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include "./c_api_common.h" #include "../profiler/profiler.h" @@ -197,6 +198,10 @@ struct PythonProfileObjects { }; static PythonProfileObjects python_profile_objects; +enum class ProfileProcess { + kWorker, kServer +}; + struct ProfileConfigParam : public dmlc::Parameter { bool profile_all; bool profile_symbolic; @@ -207,6 +212,7 @@ struct ProfileConfigParam : public dmlc::Parameter { bool continuous_dump; float dump_period; bool aggregate_stats; + int profile_process; DMLC_DECLARE_PARAMETER(ProfileConfigParam) { DMLC_DECLARE_FIELD(profile_all).set_default(false) .describe("Profile all."); @@ -228,6 +234,13 @@ struct ProfileConfigParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(aggregate_stats).set_default(false) .describe("Maintain aggregate stats, required for MXDumpAggregateStats. 
Note that " "this can have anegative performance impact."); + DMLC_DECLARE_FIELD(profile_process) + .add_enum("worker", static_cast(ProfileProcess::kWorker)) + .add_enum("server", static_cast(ProfileProcess::kServer)) + .set_default(static_cast(ProfileProcess::kWorker)) + .describe("Specifies which process to profile: " + "worker: this is default. for single node training it should always be worker." + "server: for distributed training, this profiles server process"); } }; @@ -248,7 +261,8 @@ struct ProfileMarkerScopeParam : public dmlc::Parameter DMLC_REGISTER_PARAMETER(ProfileMarkerScopeParam); -int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) { +int MXSetProcessProfilerConfig(int num_params, const char* const* keys, const char* const* vals, + KVStoreHandle kvstoreHandle) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); std::vector> kwargs; @@ -260,19 +274,37 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con } ProfileConfigParam param; param.Init(kwargs); - int mode = 0; - if (param.profile_api || param.profile_all) { mode |= profiler::Profiler::kAPI; } - if (param.profile_symbolic || param.profile_all) { mode |= profiler::Profiler::kSymbolic; } - if (param.profile_imperative || param.profile_all) { mode |= profiler::Profiler::kImperative; } - if (param.profile_memory || param.profile_all) { mode |= profiler::Profiler::kMemory; } - profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode), - std::string(param.filename), - param.continuous_dump, - param.dump_period, - param.aggregate_stats); + if (static_cast(param.profile_process) == ProfileProcess::kServer) { + std::ostringstream os; + for (int i = 0; i < num_params; ++i) { + // this will be sent to the server now, those configs shouldn't have profile server again + if (strcmp(keys[i], "profile_process") == 0) continue; + os << keys[i] << ":" << vals[i]; + if (i != num_params - 1) os << ","; + } + 
CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null"; + static_cast(kvstoreHandle)->SetServerProfilerCommand( + mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str()); + } else { + int mode = 0; + if (param.profile_api || param.profile_all) { mode |= profiler::Profiler::kAPI; } + if (param.profile_symbolic || param.profile_all) { mode |= profiler::Profiler::kSymbolic; } + if (param.profile_imperative || + param.profile_all) { mode |= profiler::Profiler::kImperative; } + if (param.profile_memory || param.profile_all) { mode |= profiler::Profiler::kMemory; } + profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode), + std::string(param.filename), + param.continuous_dump, + param.dump_period, + param.aggregate_stats); + } API_END(); } +int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals) { + return MXSetProcessProfilerConfig(num_params, keys, vals, nullptr); +} + int MXAggregateProfileStatsPrint(const char **out_str, int reset) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); @@ -293,19 +325,40 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) { } int MXDumpProfile(int finished) { + return MXDumpProcessProfile(finished, static_cast(ProfileProcess::kWorker), nullptr); +} + +int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); + if (static_cast(profile_process) == ProfileProcess::kServer) { + CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null"; + static_cast(kvStoreHandle)->SetServerProfilerCommand( + mxnet::KVStoreServerProfilerCommand::kDump, + std::to_string(finished)); + } else { profiler::Profiler *profiler = profiler::Profiler::Get(); CHECK(profiler->IsEnableOutput()) << "Profiler hasn't been run. 
Config and start profiler first"; profiler->DumpProfile(finished != 0); + } API_END() } int MXSetProfilerState(int state) { + return MXSetProcessProfilerState(state, static_cast(ProfileProcess::kWorker), nullptr); +} + +int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle) { mxnet::IgnoreProfileCallScope ignore; // state, kNotRunning: 0, kRunning: 1 API_BEGIN(); + if (static_cast(profile_process) == ProfileProcess::kServer) { + CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null"; + static_cast(kvStoreHandle)->SetServerProfilerCommand( + mxnet::KVStoreServerProfilerCommand::kState, + std::to_string(state)); + } else { switch (state) { case profiler::Profiler::kNotRunning: profiler::vtune::vtune_pause(); @@ -315,6 +368,7 @@ int MXSetProfilerState(int state) { break; } profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(state)); + } API_END(); } @@ -450,8 +504,18 @@ int MXProfileDurationStop(ProfileHandle duration_handle) { } int MXProfilePause(int paused) { + return MXProcessProfilePause(paused, static_cast(ProfileProcess::kWorker), nullptr); +} + +int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); + if (static_cast(profile_process) == ProfileProcess::kServer) { + CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null"; + static_cast(kvStoreHandle)->SetServerProfilerCommand( + mxnet::KVStoreServerProfilerCommand::kPause, + std::to_string(paused)); + } else { if (paused) { profiler::vtune::vtune_pause(); profiler::Profiler::Get()->set_paused(true); @@ -459,6 +523,7 @@ int MXProfilePause(int paused) { profiler::Profiler::Get()->set_paused(false); profiler::vtune::vtune_resume(); } + } API_END(); } diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc index e94a0570d1f4..e4a06fa9a1f2 100644 --- a/src/kvstore/gradient_compression.cc +++ 
b/src/kvstore/gradient_compression.cc @@ -23,31 +23,14 @@ * \author Rahul Huilgol */ -#include #include +#include "kvstore_local.h" #include "gradient_compression.h" #include "gradient_compression-inl.h" namespace mxnet { namespace kvstore { -/*! - * \brief Splits a string into smaller strings using char as delimiter - * Example: "a,b,c,,d" is split into ["a","b","c","","d"] - * \param s string to split - * \param delim char to split string around - * \param result container for tokens extracted after splitting - */ -template -void split(const std::string &s, const char delim, Out result) { - std::stringstream ss; - ss.str(s); - std::string item; - while (std::getline(ss, item, delim)) { - *(result++) = item; - } -} - DMLC_REGISTER_PARAMETER(GradientCompressionParam); GradientCompression::GradientCompression() { @@ -90,7 +73,7 @@ std::string GradientCompression::EncodeParams() { void GradientCompression::DecodeParams(const std::string &s) { std::vector elems; - split(s, ',', std::back_inserter(elems)); + mxnet::kvstore::split(s, ',', std::back_inserter(elems)); type_ = static_cast(stoi(elems[0])); if (elems.size() > 1) { if (!elems[1].empty()) { diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 7e2f5cb5faa9..23fbf67474ee 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -93,6 +93,15 @@ class KVStoreDist : public KVStoreLocal { } } + void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, + const std::string& params) override { + if (get_rank() == 0) { + SendCommandToServers(static_cast(CommandType::kSetProfilerParams), + params + std::to_string(static_cast(type))); + } + } + + void Barrier() override { ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup); } diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 451fb78a6229..372b58dbbf3d 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ 
-24,6 +24,9 @@ */ #ifndef MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_ #define MXNET_KVSTORE_KVSTORE_DIST_SERVER_H_ +#include +#include +#include #include #include #include @@ -32,8 +35,7 @@ #include #include #include -#include "ps/ps.h" -#include "mxnet/kvstore.h" +#include "../profiler/profiler.h" #include "../operator/tensor/elemwise_binary_op-inl.h" #include "../operator/tensor/init_op.h" @@ -42,7 +44,8 @@ namespace kvstore { // maintain same order in frontend. enum class CommandType { - kController, kSetMultiPrecision, kStopServer, kSyncMode, kSetGradientCompression, + kController, kSetMultiPrecision, kStopServer, kSyncMode, + kSetGradientCompression, kSetProfilerParams }; enum class RequestType { @@ -164,6 +167,7 @@ class KVStoreDistServer { } ~KVStoreDistServer() { + profiler::Profiler::Get()->SetState(profiler::Profiler::ProfilerState(0)); delete ps_server_; } @@ -194,27 +198,37 @@ class KVStoreDistServer { void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) { CommandType recved_type = static_cast(recved.head); - if (recved_type == CommandType::kStopServer) { - exec_.Stop(); - } else if (recved_type == CommandType::kSyncMode) { - sync_mode_ = true; - } else if (recved_type == CommandType::kSetGradientCompression) { - gradient_compression_->DecodeParams(recved.body); - } else if (recved_type == CommandType::kSetMultiPrecision) { - // uses value 1 for message id from frontend - if (!multi_precision_) { - multi_precision_ = true; - CreateMultiPrecisionCopies(); - } - } else if (recved_type == CommandType::kController) { - // value of 0 - // let the main thread to execute ctrl, which is necessary for python - exec_.Exec([this, recved]() { - CHECK(controller_); - controller_(recved.head, recved.body); - }); - } else { - LOG(FATAL) << "Unknown command type received " << recved.head; + switch (recved_type) { + case CommandType::kStopServer: + exec_.Stop(); + break; + case CommandType::kSyncMode: + sync_mode_ = true; + break; + case 
CommandType::kSetGradientCompression: + gradient_compression_->DecodeParams(recved.body); + break; + case CommandType::kSetProfilerParams: + // last char is the type of profiler command + ProcessServerProfilerCommands(static_cast + (recved.body.back() - '0'), + recved.body); + break; + case CommandType::kSetMultiPrecision: + // uses value 1 for message id from frontend + if (!multi_precision_) { + multi_precision_ = true; + CreateMultiPrecisionCopies(); + } + break; + case CommandType::kController: + // this uses value 0 for message id from frontend + // let the main thread to execute ctrl, which is necessary for python + exec_.Exec([this, recved]() { + CHECK(controller_); + controller_(recved.head, recved.body); + }); + break; } app->Response(recved); } @@ -225,11 +239,11 @@ class KVStoreDistServer { * some keys are initialized before optimizer is set. */ void CreateMultiPrecisionCopies() { - for (auto const& stored_entry : store_) { + for (auto const &stored_entry : store_) { const int key = stored_entry.first; - const NDArray& stored = stored_entry.second; + const NDArray &stored = stored_entry.second; if (stored.dtype() != mshadow::kFloat32) { - auto& stored_realt = store_realt_[key]; + auto &stored_realt = store_realt_[key]; if (stored.storage_type() == kRowSparseStorage) { stored_realt = NDArray(kRowSparseStorage, stored.shape(), stored.ctx(), true, mshadow::kFloat32); @@ -237,7 +251,7 @@ class KVStoreDistServer { stored_realt = NDArray(stored.shape(), stored.ctx(), false, mshadow::kFloat32); } - auto& update = update_buf_[key]; + auto &update = update_buf_[key]; if (!update.merged.is_none()) { if (update.merged.storage_type() == kRowSparseStorage) { update.merged = NDArray(kRowSparseStorage, update.merged.shape(), update.merged.ctx(), @@ -254,11 +268,60 @@ class KVStoreDistServer { CopyFromTo(stored, stored_realt); } } - for (auto const& stored_realt_entry : store_realt_) { + for (auto const &stored_realt_entry : store_realt_) { 
stored_realt_entry.second.WaitToRead(); } } + void ProcessServerProfilerCommands(KVStoreServerProfilerCommand type, const std::string& body) { + switch (type) { + case KVStoreServerProfilerCommand::kSetConfig: + SetProfilerConfig(body.substr(0, body.size() - 1)); + break; + case KVStoreServerProfilerCommand::kState: + MXSetProfilerState(static_cast(body.front() - '0')); + break; + case KVStoreServerProfilerCommand::kPause: + MXProfilePause(static_cast(body.front() - '0')); + break; + case KVStoreServerProfilerCommand::kDump: + MXDumpProfile(static_cast(body.front() - '0')); + break; + } + } + + void SetProfilerConfig(std::string params_str) { + std::vector elems; + mxnet::kvstore::split(params_str, ',', std::back_inserter(elems)); + std::vector ckeys; + std::vector cvals; + ckeys.reserve(elems.size()); + cvals.reserve(elems.size()); + + for (size_t i=0; i < elems.size(); i++) { + std::vector parts; + mxnet::kvstore::split(elems[i], ':', std::back_inserter(parts)); + CHECK_EQ(parts.size(), 2) << "Improper profiler config passed from worker"; + CHECK(!parts[0].empty()) << "ProfilerConfig parameter is empty"; + CHECK(!parts[1].empty()) << "ProfilerConfig value is empty for parameter "<< parts[0]; + if (parts[0] == "filename") { + parts[1] = "rank" + std::to_string(ps::MyRank()) + "_" + parts[1]; + } + char* ckey = new char[parts[0].length() + 1]; + std::snprintf(ckey, parts[0].length() + 1, "%s", parts[0].c_str()); + ckeys.push_back(ckey); + + char* cval = new char[parts[1].length() + 1]; + std::snprintf(cval, parts[1].length() + 1, "%s", parts[1].c_str()); + cvals.push_back(cval); + } + MXSetProfilerConfig(elems.size(), &ckeys[0], &cvals[0]); + for (size_t i=0; i < ckeys.size(); i++) { + delete[] ckeys[i]; + delete[] cvals[i]; + } + } + void DataHandleEx(const ps::KVMeta& req_meta, const ps::KVPairs& req_data, ps::KVServer* server) { diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 324bc2c9558a..4e004a3a3008 100644 --- 
a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -40,6 +40,22 @@ namespace mxnet { namespace kvstore { +/*! + * \brief Splits a string into smaller strings using char as delimiter + * Example: "a,b,c,,d" is split into ["a","b","c","","d"] + * \param s string to split + * \param delim char to split string around + * \param result container for tokens extracted after splitting + */ +template +void split(const std::string &s, const char delim, Out result) { + std::stringstream ss; + ss.str(s); + std::string item; + while (std::getline(ss, item, delim)) { + *(result++) = item; + } +} enum KeyType { kUndefinedKey = -1, diff --git a/tests/nightly/test_server_profiling.py b/tests/nightly/test_server_profiling.py new file mode 100644 index 000000000000..7d157a3e4189 --- /dev/null +++ b/tests/nightly/test_server_profiling.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import mxnet as mx +import json + +key = '99' +shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND +kv = mx.kv.create('dist_sync') + +def init_kv(): + # init kv dns keys + kv.init(key, mx.nd.ones(shape)) + kv.set_optimizer(mx.optimizer.create('sgd')) + return kv, kv.rank, kv.num_workers + +def test_sync_push_pull(): + kv, my_rank, nworker = init_kv() + def check_default_keys(kv, my_rank): + nrepeat = 10 + # checks pull after push in loop, because behavior during + # consecutive pushes doesn't offer any guarantees + for i in range(nrepeat): + kv.push(key, mx.nd.ones(shape, dtype='float32') * (my_rank+1)) + val = mx.nd.zeros(shape, dtype='float32') + kv.pull(key, out=val) + mx.nd.waitall() + check_default_keys(kv, my_rank) + +if __name__ == "__main__": + server_filename_suffix = 'test_profile_server.json' + worker_filename_suffix = 'test_profile_worker.json' + mx.profiler.set_config(filename=server_filename_suffix, profile_all=True, profile_process='server') + mx.profiler.set_config(filename='rank' + str(kv.rank) + '_' + worker_filename_suffix, profile_all=True, profile_process='worker') + mx.profiler.set_state(state='run', profile_process='server') + mx.profiler.set_state(state='run', profile_process='worker') + test_sync_push_pull() + mx.profiler.set_state(state='stop', profile_process='server') + mx.profiler.set_state(state='stop', profile_process='worker') + + import glob, os + + # will only work when launcher mode is local, as used for integration test + if kv.rank == 0: + for rank in range(kv.num_workers): + for suffix in [worker_filename_suffix, server_filename_suffix]: + # throws value error if file is not proper json + filename = 'rank' + str(rank) + '_' + suffix + print(glob.glob('*'), os.getcwd()) + with open(filename, 'r') as f: + j = json.load(f) + + +