Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ integrationtest_ubuntu_cpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision
../../tools/launch.py -n 3 --launcher local python test_server_profiling.py
}

integrationtest_ubuntu_gpu_scala() {
Expand Down
23 changes: 22 additions & 1 deletion example/image-classification/common/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def add_fit_args(parser):
help='the epochs to ramp-up lr to scaled large-batch value')
train.add_argument('--warmup-strategy', type=str, default='linear',
help='the ramping-up strategy for large batch sgd')
train.add_argument('--profile-worker-suffix', type=str, default='',
help='profile workers actions into this file. During distributed training\
filename saved will be rank1_ followed by this suffix')
train.add_argument('--profile-server-suffix', type=str, default='',
help='profile server actions into a file with name like rank1_ followed by this suffix \
during distributed training')
return train


Expand All @@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs):
if args.gc_type != 'none':
kv.set_gradient_compression({'type': args.gc_type,
'threshold': args.gc_threshold})
if args.profile_server_suffix:
mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
mx.profiler.set_state(state='run', profile_process='server')

if args.profile_worker_suffix:
if kv.num_workers > 1:
filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
else:
filename = args.profile_worker_suffix
mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
mx.profiler.set_state(state='run', profile_process='worker')

# logging
head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
Expand Down Expand Up @@ -180,7 +197,6 @@ def fit(args, network, data_loader, **kwargs):
logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
args.disp_batches * args.batch_size / (time.time() - tic))
tic = time.time()

return

# load model
Expand Down Expand Up @@ -314,3 +330,8 @@ def fit(args, network, data_loader, **kwargs):
epoch_end_callback=checkpoint,
allow_missing=True,
monitor=monitor)

# Training has finished: stop the profilers that were switched to 'run'
# earlier in fit(). Setting the state to 'run' again here would leave
# profiling active and never mark the end of the profiled region.
if args.profile_server_suffix:
    mx.profiler.set_state(state='stop', profile_process='server')
if args.profile_worker_suffix:
    mx.profiler.set_state(state='stop', profile_process='worker')
59 changes: 52 additions & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,19 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
MXNET_DLL int MXNotifyShutdown();

/*!
* \brief Set up configuration of profiler
* \brief Set up configuration of profiler for the process passed as profile_process in keys
* \param num_params Number of parameters
* \param keys array of parameter keys
* \param vals array of parameter values
* \param kvstoreHandle handle to kvstore
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys,
const char* const* vals,
KVStoreHandle kvstoreHandle);

/*!
* \brief Set up configuration of profiler for worker/current process
* \param num_params Number of parameters
* \param keys array of parameter keys
* \param vals array of parameter values
Expand All @@ -239,7 +251,21 @@ MXNET_DLL int MXNotifyShutdown();
MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals);

/*!
* \brief Set up state of profiler
* \brief Set up state of profiler for either worker or server process
* \param state indicate the working state of profiler,
* profiler not running when state == 0,
* profiler running when state == 1
* \param profile_process an int,
* when 0 command is for worker/current process,
* when 1 command is for server process
* \param kvStoreHandle handle to kvstore, needed for server process profiling
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
KVStoreHandle kvStoreHandle);

/*!
* \brief Set up state of profiler for current process
* \param state indicate the working state of profiler,
* profiler not running when state == 0,
* profiler running when state == 1
Expand All @@ -250,11 +276,22 @@ MXNET_DLL int MXSetProfilerState(int state);
/*!
* \brief Save profile and stop profiler
* \param finished true if stat output should stop after this point
* \param profile_process an int,
* when 0 command is for worker/current process,
* when 1 command is for server process
* \param kvStoreHandle handle to kvstore
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXDumpProfile(int finished);
MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);


/*!
* \brief Save profile and stop profiler for worker/current process
* \param finished true if stat output should stop after this point
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXDumpProfile(int finished);

/*!
* \brief Print aggregate stats to a string
* \param out_str Will receive a pointer to the output string
Expand All @@ -267,6 +304,16 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
/*!
* \brief Pause profiler tuning collection
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
* \param profile_process an int,
* when 0 command is for worker/current process,
* when 1 command is for server process
* \param kvStoreHandle handle to kvstore
* \return 0 when success, -1 when failure happens.
* \note pausing and resuming is global and not recursive
*/
MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);

/*!
* \brief Pause profiler tuning collection for worker/current process
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
* \return 0 when success, -1 when failure happens.
* \note pausing and resuming is global and not recursive
*/
Expand Down Expand Up @@ -2145,8 +2192,7 @@ typedef void (MXKVStoreServerController)(int head,
void *controller_handle);

/**
* \return Run as server (or scheduler)
*
* \brief Run as server (or scheduler)
* \param handle handle to the KVStore
* \param controller the user-defined server controller
* \param controller_handle helper handle for implementing controller
Expand All @@ -2157,8 +2203,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
void *controller_handle);

/**
* \return Send a command to all server nodes
*
* \brief Send a command to all server nodes
* \param handle handle to the KVStore
* \param cmd_id the head of the command
* \param cmd_body the body of the command
Expand Down
26 changes: 26 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@
#endif // MXNET_USE_DIST_KVSTORE

namespace mxnet {

/*!
* \brief enum to denote types of commands kvstore sends to server regarding profiler
* kSetConfig sets profiler configs. Similar to mx.profiler.set_config()
* kState allows changing state of profiler to stop or run
* kPause allows pausing and resuming of profiler
* kDump asks profiler to dump output
*/
enum class KVStoreServerProfilerCommand {
  kSetConfig,  // set profiler configs (similar to mx.profiler.set_config())
  kState,      // change the profiler state to stop or run
  kPause,      // pause or resume the profiler
  kDump        // ask the profiler to dump its output
};

/*!
* \brief distributed key-value store
*
Expand Down Expand Up @@ -364,6 +376,20 @@ class KVStore {
*/
virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }

/**
* \brief Sends server profiler commands to all server nodes
* Only the worker with rank=0 sends the command which will be received by all servers
* \param type ProfilerCommand type
* \param params parameters for that command in the form of a string
*/
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
                                      const std::string& params) {
  // Base-class fallback: nothing is sent; the caller is only told why the
  // command could not be delivered. Every fragment that is followed by
  // another one ends with a space, because operator<< concatenates the
  // pieces verbatim and the sentences would otherwise run together.
  LOG(INFO) << "Unable to pass server the profiler command. If you are using "
            << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1. "
            << "If you are training on single machine, then there is no server process "
            << "to profile. Please profile the worker process instead.";
}

/**
* \brief the prototype of a server controller
*/
Expand Down
8 changes: 6 additions & 2 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .base import check_call, string_types, mx_uint, py_str
from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt
from .profiler import set_kvstore_handle

def _ctype_key_value(keys, vals):
"""
Expand Down Expand Up @@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command):
'kSetMultiPrecision': 1,
'kStopServer': 2,
'kSyncMode': 3,
'kSetGradientCompression': 4}
'kSetGradientCompression': 4,
'kSetProfilerParams': 5}
assert (command in command_types), "Unknown command type to send to server"
return command_types[command]

Expand Down Expand Up @@ -670,4 +672,6 @@ def create(name='local'):
handle = KVStoreHandle()
check_call(_LIB.MXKVStoreCreate(c_str(name),
ctypes.byref(handle)))
return KVStore(handle)
kv = KVStore(handle)
set_kvstore_handle(kv.handle)
return kv
79 changes: 63 additions & 16 deletions python/mxnet/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
from __future__ import absolute_import
import ctypes
import warnings
from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str
from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle

profiler_kvstore_handle = KVStoreHandle()

def set_kvstore_handle(handle):
    """Store the kvstore handle that the profiler functions pass to the C API.

    Rebinds the module-level `profiler_kvstore_handle`, which the profiler
    calls in this module hand to the `_LIB.MX*Process*` functions.

    Parameters
    ----------
    handle : KVStoreHandle
        Handle of the kvstore to use for profiler commands
        (presumably the handle of the most recently created KVStore).
    """
    global profiler_kvstore_handle
    profiler_kvstore_handle = handle

def set_config(**kwargs):
"""Set up the configure of profiler (only accepts keyword arguments).
Expand All @@ -49,12 +54,17 @@ def set_config(**kwargs):
aggregate_stats : boolean,
whether to maintain aggregate stats in memory for console
dump. Has some negative performance impact.
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
kk = kwargs.keys()
vv = kwargs.values()
check_call(_LIB.MXSetProfilerConfig(len(kwargs),
c_str_array([key for key in kk]),
c_str_array([str(val) for val in vv])))
check_call(_LIB.MXSetProcessProfilerConfig(len(kwargs),
c_str_array([key for key in kk]),
c_str_array([str(val) for val in vv]),
profiler_kvstore_handle))


def profiler_set_config(mode='symbolic', filename='profile.json'):
Expand All @@ -73,20 +83,27 @@ def profiler_set_config(mode='symbolic', filename='profile.json'):
keys = c_str_array([key for key in ["profile_" + mode, "filename"]])
values = c_str_array([str(val) for val in [True, filename]])
assert len(keys) == len(values)
check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values))
check_call(_LIB.MXSetProcessProfilerConfig(len(keys), keys, values, profiler_kvstore_handle))


def set_state(state='stop'):
def set_state(state='stop', profile_process='worker'):
"""Set up the profiler state to 'run' or 'stop'.

Parameters
----------
state : string, optional
Indicates whether to run the profiler, can
be 'stop' or 'run'. Default is `stop`.
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
state2int = {'stop': 0, 'run': 1}
check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXSetProcessProfilerState(ctypes.c_int(state2int[state]),
profile_process2int[profile_process],
profiler_kvstore_handle))


def profiler_set_state(state='stop'):
Expand All @@ -102,7 +119,7 @@ def profiler_set_state(state='stop'):
'Please use profiler.set_state() instead')
set_state(state)

def dump(finished=True):
def dump(finished=True, profile_process='worker'):
"""Dump profile and stop profiler. Use this to save profile
in advance in case your program cannot exit normally.

Expand All @@ -111,9 +128,16 @@ def dump(finished=True):
finished : boolean
Indicates whether to stop statistic output (dumping) after this dump.
Default is True
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
fin = 1 if finished is True else False
check_call(_LIB.MXDumpProfile(fin))
fin = 1 if finished is True else 0
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXDumpProcessProfile(fin,
profile_process2int[profile_process],
profiler_kvstore_handle))


def dump_profile():
Expand All @@ -138,14 +162,37 @@ def dumps(reset=False):
return py_str(debug_str.value)


def pause():
"""Pause profiling."""
check_call(_LIB.MXProfilePause(int(1)))
def pause(profile_process='worker'):
"""Pause profiling.

Parameters
----------
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXProcessProfilePause(int(1),
profile_process2int[profile_process],
profiler_kvstore_handle))


def resume(profile_process='worker'):
"""
Resume paused profiling.

def resume():
"""Resume paused profiling."""
check_call(_LIB.MXProfilePause(int(0)))
Parameters
----------
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXProcessProfilePause(int(0),
profile_process2int[profile_process],
profiler_kvstore_handle))


class Domain(object):
Expand Down
Loading