218 changes: 152 additions & 66 deletions python/mxnet/model.py
@@ -9,6 +9,7 @@
from . import symbol as sym
from . import optimizer as opt
from . import metric
from . import kvstore
from .context import Context, cpu
from .initializer import Xavier

@@ -74,12 +75,54 @@ def _check_arguments(symbol):
return (data_index, label_index)


def _train(symbol, ctx, input_shape,
arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, logger=None):
"""Inernal training function.
def _split_input_slice(input_shape, num_split):
"""Get input slice from the input shape.

Parameters
----------
input_shape : tuple
The input shape of the net.

num_split : int
The number of splits we want to have.

Returns
-------
slices : list of slice
The slices that index each split along the batch dimension.

shapes : list of tuples
The shape of each split slice.

Raises
------
ValueError
If there are too many splits such that some slices would be empty.
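
Examples
--------
A small sketch of the expected behavior (hypothetical shapes):

>>> _split_input_slice((7, 3), 2)
([slice(0, 4), slice(4, 7)], [(4, 3), (3, 3)])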
"""
batch_size = input_shape[0]
step = (batch_size + num_split - 1) // num_split  # ceiling division keeps slice bounds integral
slices = []
shapes = []
for k in range(num_split):
begin = min(k * step, batch_size)
end = min((k+1) * step, batch_size)
if begin == end:
raise ValueError('Too many splits such that some slices are empty')
slices.append(slice(begin, end))
s = list(input_shape)
s[0] = end - begin
shapes.append(tuple(s))
return (slices, shapes)


def _train_multi_device(symbol, ctx, input_shape,
arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, logger=None):
"""Internal training function on multiple devices.

This function also works on a single device.

Parameters
----------
@@ -127,80 +170,121 @@ def _train(symbol, ctx, input_shape,
-----
This function will update the NDArrays in arg_params and aux_states in place.
"""
assert(len(ctx) == 1)
if logger is None:
logger = logging
# bind the symbol
train_exec = symbol.simple_bind(ctx[0], data=input_shape, grad_req='write')
# preparation
num_device = len(ctx)
logger.info('Start training with %d devices', num_device)

slices, shapes = _split_input_slice(input_shape, num_device)
train_execs = [symbol.simple_bind(ctx=c, data=s, grad_req='write')
for c, s in zip(ctx, shapes)]
arg_names = symbol.list_arguments()
aux_names = symbol.list_auxiliary_states()
arg_arrays = train_exec.arg_arrays
grad_arrays = train_exec.grad_arrays
aux_arrays = train_exec.aux_arrays
# copy initialized parameters to executor parameters
for key, weight in zip(arg_names, arg_arrays):
if key in arg_params:
arg_params[key].copyto(weight)
for key, weight in zip(aux_names, aux_arrays):
if key in aux_params:
aux_params[key].copyto(weight)
# setup helper data structures
# data structure
arg_blocks = [
[x.arg_arrays[index] for x in train_execs]
for index in range(len(train_execs[0].arg_arrays))]
grad_blocks = [
[x.grad_arrays[index] for x in train_execs]
for index in range(len(train_execs[0].grad_arrays))]
aux_blocks = [
[x.aux_arrays[index] for x in train_execs]
for index in range(len(train_execs[0].aux_arrays))]
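# These are transposed views of the executors' arrays: block[i][d] is
# parameter i on device d, so each block groups all device copies of a
# single parameter.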
for name, block in zip(arg_names, arg_blocks):
if name in arg_params:
for w in block:
arg_params[name].copyto(w)
for name, block in zip(aux_names, aux_blocks):
if name in aux_params:
for w in block:
aux_params[name].copyto(w)
# key-value store
kv = kvstore.create() if num_device != 1 else None
# If there are multiple devices, initialize each key with the weights from the first device.
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
arg, grad = pair
if kv and grad[0] is not None:
kv.init(index, arg[0])
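# A sketch of the aggregation pattern used below: each parameter index acts
# as a key; every device pushes its local gradient for that key and pulls
# back the sum, so all devices update with the same aggregated gradient.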
# Input and output data structure
data_index, label_index = _check_arguments(symbol)
data_array, label_array = arg_arrays[data_index], arg_arrays[label_index]
out_array = train_exec.outputs[0]
out_cpu_array = nd.zeros(out_array.shape)
arg_blocks = list(zip(arg_arrays, grad_arrays))

for i in range(begin_round, end_round):
# training phase
merged_shape = list(train_execs[0].outputs[0].shape)
merged_shape[0] = input_shape[0]
merged_shape = tuple(merged_shape)
out_cpu_array = nd.zeros(merged_shape, cpu())
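# out_cpu_array is a CPU buffer shaped like the full batch; each device's
# output slice is merged into it so eval_metric sees a whole-batch output.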

# Now start training
for iteration in range(begin_round, end_round):
# Training phase
tic = time.time()
train_data.reset()
optimizer.begin_round(i)
optimizer.begin_round(iteration)
eval_metric.reset()

# Iterate over training data.
for data, label in train_data:
label.copyto(label_array)
data.copyto(data_array)
train_exec.forward()
out_array.copyto(out_cpu_array)
train_exec.backward()
# Copy data into the target
for target, islice in zip(arg_blocks[label_index], slices):
label[islice].copyto(target)
for target, islice in zip(arg_blocks[data_index], slices):
data[islice].copyto(target)
# forward backward pass
for texec, islice in zip(train_execs, slices):
texec.forward()
texec.outputs[0].copyto(out_cpu_array[islice])
for texec in train_execs:
texec.backward()
# update the parameters
for index, block in enumerate(arg_blocks):
weight, grad = block
if grad is not None:
optimizer.update(index, weight, grad)
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
# Gradient synchronization
if kv:
# push gradient
kv.push(index, grad_list)
# pull the sum back into the same locations.
kv.pull(index, grad_list)
# optimize
for w, g in zip(arg_list, grad_list):
optimizer.update(index, w, g)
# evaluate at the end, so the copy into out_cpu_array can happen lazily
eval_metric.update(out_cpu_array, label)

name, value = eval_metric.get()
logger.info('Iteration[%d] Train-%s=%f', i, name, value)
logger.info('Iteration[%d] Train-%s=%f', iteration, name, value)
toc = time.time()
logger.info('Iteration[%d] Time cost=%.3f', i, (toc - tic))

# evaluation phase
if eval_data is not None:
logger.info('Iteration[%d] Time cost=%.3f', iteration, (toc - tic))
# evaluation
if eval_data:
eval_metric.reset()
eval_data.reset()
for data, label in eval_data:
data.copyto(data_array)
# TODO(bing): add is_train=False
train_exec.forward(is_train=False)
out_array.copyto(out_cpu_array)
eval_metric.update(out_array, label)

# Copy data into the target
for target, islice in zip(arg_blocks[label_index], slices):
label[islice].copyto(target)
for target, islice in zip(arg_blocks[data_index], slices):
data[islice].copyto(target)
# forward pass
for texec, islice in zip(train_execs, slices):
texec.forward(is_train=False)
texec.outputs[0].copyto(out_cpu_array[islice])
eval_metric.update(out_cpu_array, label)
name, value = eval_metric.get()
logger.info('Iteration[%d] Validation-%s=%f', i, name, value)
logger.info('Iteration[%d] Validation-%s=%f', iteration, name, value)

if iter_end_callback or i + 1 == end_round:
if iter_end_callback or iteration + 1 == end_round:
# copy parameters back to the CPU, averaging the per-device copies
for key, weight in zip(arg_names, arg_arrays):
if key in arg_params:
weight.copyto(arg_params[key])
for key, arr in zip(aux_names, aux_arrays):
arr.copyto(aux_params[key])
for name, block in zip(arg_names, arg_blocks):
if name in arg_params:
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(arg_params[name])
for name, block in zip(aux_names, aux_blocks):
if name in aux_params:
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(aux_params[name])
if iter_end_callback:
iter_end_callback(i, symbol, arg_params, aux_params)
# end of the function
iter_end_callback(iteration, symbol, arg_params, aux_params)
# end of all iterations
return
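
# A hypothetical direct invocation (normally reached via FeedForward.fit),
# assuming two CPU contexts and 784-dimensional inputs in batches of 128:
#   _train_multi_device(symbol, [cpu(0), cpu(1)], (128, 784),
#                       arg_params, aux_params, 0, 10, optimizer,
#                       train_data, eval_data=val_data)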


@@ -332,6 +416,8 @@ def __init__(self, symbol, ctx=None,
num_round=None, optimizer='sgd', initializer=Xavier(),
arg_params=None, aux_params=None,
**kwargs):
# check if symbol contains duplicated names.
_check_arguments(symbol)
# basic configuration
self.symbol = symbol
if ctx is None:
@@ -467,14 +553,14 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
batch_size = input_shape[0]
optimizer = opt.create(optimizer, rescale_grad=(1.0/batch_size), **(self.kwargs))
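# rescale_grad=1/batch_size turns the summed per-example gradients into a
# batch average before each update (assuming the loss sums over the batch)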
# do training
_train(self.symbol, self.ctx, input_shape,
self.arg_params, self.aux_params,
begin_round=0, end_round=self.num_round,
optimizer=optimizer,
train_data=X, eval_data=eval_data,
eval_metric=eval_metric,
iter_end_callback=iter_end_callback,
logger=logger)
_train_multi_device(self.symbol, self.ctx, input_shape,
self.arg_params, self.aux_params,
begin_round=0, end_round=self.num_round,
optimizer=optimizer,
train_data=X, eval_data=eval_data,
eval_metric=eval_metric,
iter_end_callback=iter_end_callback,
logger=logger)

def save(self, prefix, iteration=None):
"""Checkpoint the model checkpoint into file.