diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 2c4614b1dade..1fd73eb4033e 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -9,6 +9,7 @@
 from . import symbol as sym
 from . import optimizer as opt
 from . import metric
+from . import kvstore
 from .context import Context, cpu
 from .initializer import Xavier
@@ -74,12 +75,54 @@ def _check_arguments(symbol):
     return (data_index, label_index)


-def _train(symbol, ctx, input_shape,
-           arg_params, aux_params,
-           begin_round, end_round, optimizer,
-           train_data, eval_data=None, eval_metric=None,
-           iter_end_callback=None, logger=None):
-    """Inernal training function.
+def _split_input_slice(input_shape, num_split):
+    """Get input slices from the input shape.
+
+    Parameters
+    ----------
+    input_shape : tuple
+        The input shape of the net.
+
+    num_split : int
+        The number of splits we want to have.
+
+    Returns
+    -------
+    slices : list of slice
+        The slices that cut the batch dimension into per-device pieces.
+
+    shapes : list of tuples
+        The shape of each split slice.
+
+    Raises
+    ------
+    ValueError
+        If there are too many splits such that some slices would be empty.
+    """
+    batch_size = input_shape[0]
+    step = (batch_size + num_split - 1) // num_split
+    slices = []
+    shapes = []
+    for k in range(num_split):
+        begin = min(k * step, batch_size)
+        end = min((k+1) * step, batch_size)
+        if begin == end:
+            raise ValueError('Too many splits such that some slices are empty')
+        slices.append(slice(begin, end))
+        s = list(input_shape)
+        s[0] = end - begin
+        shapes.append(tuple(s))
+    return (slices, shapes)
+
+
+def _train_multi_device(symbol, ctx, input_shape,
+                        arg_params, aux_params,
+                        begin_round, end_round, optimizer,
+                        train_data, eval_data=None, eval_metric=None,
+                        iter_end_callback=None, logger=None):
+    """Internal training function on multiple devices.
+
+    This function also works for a single device.

     Parameters
     ----------
@@ -127,80 +170,121 @@ def _train(symbol, ctx, input_shape,
     -----
     This function will inplace update the NDArrays in arg_parans and aux_states.
""" - assert(len(ctx) == 1) if logger is None: logger = logging - # bind the symbol - train_exec = symbol.simple_bind(ctx[0], data=input_shape, grad_req='write') + # preparation + num_device = len(ctx) + logging.info('Start training with %d devices', num_device) + + slices, shapes = _split_input_slice(input_shape, num_device) + train_execs = [symbol.simple_bind(ctx=c, data=s, grad_req='write') + for c, s in zip(ctx, shapes)] arg_names = symbol.list_arguments() aux_names = symbol.list_auxiliary_states() - arg_arrays = train_exec.arg_arrays - grad_arrays = train_exec.grad_arrays - aux_arrays = train_exec.aux_arrays - # copy initialized parameters to executor parameters - for key, weight in zip(arg_names, arg_arrays): - if key in arg_params: - arg_params[key].copyto(weight) - for key, weight in zip(aux_names, aux_arrays): - if key in aux_params: - aux_params[key].copyto(weight) - # setup helper data structures + # data structure + arg_blocks = [ + [x.arg_arrays[index] for x in train_execs] + for index in range(len(train_execs[0].arg_arrays))] + grad_blocks = [ + [x.grad_arrays[index] for x in train_execs] + for index in range(len(train_execs[0].grad_arrays))] + aux_blocks = [ + [x.aux_arrays[index] for x in train_execs] + for index in range(len(train_execs[0].aux_arrays))] + for name, block in zip(arg_names, arg_blocks): + if name in arg_params: + for w in block: + arg_params[name].copyto(w) + for name, block in zip(aux_names, aux_blocks): + if name in aux_params: + for w in block: + aux_params[name].copyto(w) + # ky value store + kv = kvstore.create() if num_device != 1 else None + # If there are multiple devices, initialize the weights. + for index, pair in enumerate(zip(arg_blocks, grad_blocks)): + arg, grad = pair + if kv and grad[0] is not None: + kv.init(index, arg[0]) + # Input and output data structure data_index, label_index = _check_arguments(symbol) - data_array, label_array = arg_arrays[data_index], arg_arrays[label_index] - out_array = train_exec.outputs[0] - out_cpu_array = nd.zeros(out_array.shape) - arg_blocks = list(zip(arg_arrays, grad_arrays)) - - for i in range(begin_round, end_round): - # training phase + merged_shape = list(train_execs[0].outputs[0].shape) + merged_shape[0] = input_shape[0] + merged_shape = tuple(merged_shape) + out_cpu_array = nd.zeros(merged_shape, cpu()) + + # Now start training + for iteration in range(begin_round, end_round): + # Training phase tic = time.time() train_data.reset() - optimizer.begin_round(i) + optimizer.begin_round(iteration) eval_metric.reset() - + # Iterate over training data. 
         for data, label in train_data:
-            label.copyto(label_array)
-            data.copyto(data_array)
-            train_exec.forward()
-            out_array.copyto(out_cpu_array)
-            train_exec.backward()
+            # Copy data into the target
+            for target, islice in zip(arg_blocks[label_index], slices):
+                label[islice].copyto(target)
+            for target, islice in zip(arg_blocks[data_index], slices):
+                data[islice].copyto(target)
+            # forward backward pass
+            for texec, islice in zip(train_execs, slices):
+                texec.forward()
+                texec.outputs[0].copyto(out_cpu_array[islice])
+            for texec in train_execs:
+                texec.backward()
             # update the parameters
-            for index, block in enumerate(arg_blocks):
-                weight, grad = block
-                if grad is not None:
-                    optimizer.update(index, weight, grad)
+            for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
+                arg_list, grad_list = pair
+                if grad_list[0] is None:
+                    continue
+                # Gradient synchronization
+                if kv:
+                    # push gradient
+                    kv.push(index, grad_list)
+                    # pull back the sum to the same locations
+                    kv.pull(index, grad_list)
+                # optimize
+                for w, g in zip(arg_list, grad_list):
+                    optimizer.update(index, w, g)
             # evaluate at end, so out_cpu_array can lazy copy
             eval_metric.update(out_cpu_array, label)

         name, value = eval_metric.get()
-        logger.info('Iteration[%d] Train-%s=%f', i, name, value)
+        logger.info('Iteration[%d] Train-%s=%f', iteration, name, value)
         toc = time.time()
-        logger.info('Iteration[%d] Time cost=%.3f', i, (toc - tic))
-
-        # evaluation phase
-        if eval_data is not None:
+        logger.info('Iteration[%d] Time cost=%.3f', iteration, (toc - tic))
+        # evaluation
+        if eval_data:
             eval_metric.reset()
             eval_data.reset()
             for data, label in eval_data:
-                data.copyto(data_array)
-                # TODO(bing): add is_train=False
-                train_exec.forward(is_train=False)
-                out_array.copyto(out_cpu_array)
-                eval_metric.update(out_array, label)
-
+                # Copy data into the target
+                for target, islice in zip(arg_blocks[label_index], slices):
+                    label[islice].copyto(target)
+                for target, islice in zip(arg_blocks[data_index], slices):
+                    data[islice].copyto(target)
+                # forward pass
+                for texec, islice in zip(train_execs, slices):
+                    texec.forward(is_train=False)
+                    texec.outputs[0].copyto(out_cpu_array[islice])
+                eval_metric.update(out_cpu_array, label)
             name, value = eval_metric.get()
-            logger.info('Iteration[%d] Validation-%s=%f', i, name, value)
+            logger.info('Iteration[%d] Validation-%s=%f', iteration, name, value)

-        if iter_end_callback or i + 1 == end_round:
+        if iter_end_callback or iteration + 1 == end_round:
             # copy data back to cpu
-            for key, weight in zip(arg_names, arg_arrays):
-                if key in arg_params:
-                    weight.copyto(arg_params[key])
-            for key, arr in zip(aux_names, aux_arrays):
-                arr.copyto(aux_params[key])
+            for name, block in zip(arg_names, arg_blocks):
+                if name in arg_params:
+                    weight = sum(w.copyto(cpu()) for w in block) / len(block)
+                    weight.copyto(arg_params[name])
+            for name, block in zip(aux_names, aux_blocks):
+                if name in aux_params:
+                    weight = sum(w.copyto(cpu()) for w in block) / len(block)
+                    weight.copyto(aux_params[name])
         if iter_end_callback:
-            iter_end_callback(i, symbol, arg_params, aux_params)
-    # end of the function
+            iter_end_callback(iteration, symbol, arg_params, aux_params)
+    # end of all iterations
     return


@@ -332,6 +416,8 @@ def __init__(self, symbol, ctx=None,
                  num_round=None, optimizer='sgd', initializer=Xavier(),
                  arg_params=None, aux_params=None,
                  **kwargs):
+        # check if symbol contains duplicated names.
+        _check_arguments(symbol)
         # basic configuration
         self.symbol = symbol
         if ctx is None:
@@ -467,14 +553,14 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
         batch_size = input_shape[0]
         optimizer = opt.create(optimizer, rescale_grad=(1.0/batch_size), **(self.kwargs))
         # do training
-        _train(self.symbol, self.ctx, input_shape,
-               self.arg_params, self.aux_params,
-               begin_round=0, end_round=self.num_round,
-               optimizer=optimizer,
-               train_data=X, eval_data=eval_data,
-               eval_metric=eval_metric,
-               iter_end_callback=iter_end_callback,
-               logger=logger)
+        _train_multi_device(self.symbol, self.ctx, input_shape,
+                            self.arg_params, self.aux_params,
+                            begin_round=0, end_round=self.num_round,
+                            optimizer=optimizer,
+                            train_data=X, eval_data=eval_data,
+                            eval_metric=eval_metric,
+                            iter_end_callback=iter_end_callback,
+                            logger=logger)

     def save(self, prefix, iteration=None):
         """Checkpoint the model checkpoint into file.
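
A few illustrative sketches follow; they are reviewer-style notes, not part of the patch. The first reproduces the batch-splitting logic of the new _split_input_slice helper as a standalone function (plain Python, no MXNet required) and works through how a batch of 128 is cut across 3 devices:

    # Standalone copy of the batch-splitting logic in _split_input_slice,
    # plus a worked example for batch 128 on 3 devices.
    def split_input_slice(input_shape, num_split):
        batch_size = input_shape[0]
        step = (batch_size + num_split - 1) // num_split   # ceiling division
        slices, shapes = [], []
        for k in range(num_split):
            begin = min(k * step, batch_size)
            end = min((k + 1) * step, batch_size)
            if begin == end:
                raise ValueError('Too many splits such that some slices are empty')
            slices.append(slice(begin, end))
            shapes.append((end - begin,) + tuple(input_shape[1:]))
        return slices, shapes

    slices, shapes = split_input_slice((128, 3, 224, 224), 3)
    # slices -> [slice(0, 43), slice(43, 86), slice(86, 128)]
    # shapes -> [(43, 3, 224, 224), (43, 3, 224, 224), (42, 3, 224, 224)]

Each executor is then bound with its per-slice shape, and the training loop scatters every batch with data[islice].copyto(target) using exactly these slices.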
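
The gradient synchronization only needs three key-value store operations: init a key, push the per-device gradients (which get summed), and pull the sum back into every device-local buffer. MockKVStore below is a hypothetical in-process stand-in written purely to illustrate those semantics; the patch itself goes through kvstore.create():

    import numpy as np

    # Hypothetical in-process stand-in for the key-value store, for
    # illustration only: push() stores the sum of the per-device values
    # for a key, pull() writes that sum back into every output buffer.
    class MockKVStore(object):
        def __init__(self):
            self._store = {}

        def init(self, key, value):
            self._store[key] = value.copy()

        def push(self, key, values):
            self._store[key] = sum(values)

        def pull(self, key, outs):
            for out in outs:
                out[:] = self._store[key]

    kv = MockKVStore()
    grads = [np.ones(4) * (i + 1) for i in range(2)]   # gradients from 2 devices
    kv.init(0, np.zeros(4))
    kv.push(0, grads)   # aggregate: 1 + 2 = 3 per element
    kv.pull(0, grads)   # every device now holds the summed gradient
    assert all((g == 3.0).all() for g in grads)

In the patch this pattern runs once per parameter index: kv.push(index, grad_list) followed by kv.pull(index, grad_list), after which every device applies the same optimizer.update.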
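
Because every device receives the same initial parameters (the arg_params copy loops) and then applies the same aggregated gradient, the per-device weight copies should never diverge, so the averaging in the checkpoint branch is effectively just a gather to CPU. A small sketch of that invariant, with numpy arrays standing in for NDArray:

    import numpy as np

    # Data-parallel update with numpy standing in for NDArray: identical
    # initial weights, the summed gradient applied on every device, so the
    # checkpoint-time average equals any single device's copy.
    lr = 0.1
    weights = [np.full(3, 0.5) for _ in range(2)]            # same init on 2 devices
    grads = [np.array([0.1, 0.2, 0.3]), np.array([0.3, 0.2, 0.1])]

    total = sum(grads)                                       # kv.push + kv.pull
    for w in weights:
        w -= lr * total                                      # optimizer.update per device

    checkpoint = sum(w.copy() for w in weights) / len(weights)
    assert np.allclose(checkpoint, weights[0])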
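
One scaling convention is worth spelling out: fit() creates the optimizer with rescale_grad=1.0/batch_size, and each executor is assumed to produce a gradient summed over the examples in its own slice, so the kvstore sum followed by that rescale reproduces the full-batch mean gradient a single device would compute. A quick numeric check under that assumption:

    import numpy as np

    # Numeric check, assuming each executor's gradient is the sum over its
    # own slice of examples: kvstore sum + rescale_grad=1/batch_size equals
    # the full-batch mean gradient from a single device.
    batch_size = 8
    per_example = np.arange(batch_size, dtype=float)          # toy per-example gradients

    single_device = per_example.sum() / batch_size            # one device, full batch
    slices = [slice(0, 4), slice(4, 8)]                       # _split_input_slice(..., 2)
    per_slice = [per_example[s].sum() for s in slices]        # one executor per slice
    multi_device = sum(per_slice) / batch_size                # kv sum, then rescale_grad

    assert np.isclose(single_device, multi_device)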