213 changes: 212 additions & 1 deletion python/mxnet/executor.py
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-locals
# pylint: disable=invalid-name, protected-access, too-many-locals, too-many-arguments
"""Symbolic Executor component of MXNet."""
from __future__ import absolute_import

@@ -8,6 +8,9 @@
from .base import mx_uint, NDArrayHandle, ExecutorHandle
from .base import check_call, c_array, py_str
from .ndarray import NDArray
from . import ndarray as nd
from .context import cpu
import logging

class Executor(object):
""" Executor is the actual executing object of MXNet."""
@@ -216,3 +219,211 @@ def debug_str(self):
check_call(_LIB.MXExecutorPrint(
self.handle, ctypes.byref(debug_str)))
return py_str(debug_str.value)

def _split_input_slice(batch_size, work_load_list):
"""Get input slice from the input shape.
Parameters
----------
batch_size : int
The number of samples in a mini-batch.
work_load_list : list of float or int, optional
        The list of work loads for the different devices,
        in the same order as ctx.
Returns
-------
    slices : list of slice
        The slices that partition a mini-batch across the devices.
Raises
------
ValueError
        If there are so many splits that some slices would be empty.
"""
total_work_load = sum(work_load_list)
batch_num_list = [round(work_load * batch_size / total_work_load)
for work_load in work_load_list]
batch_num_sum = sum(batch_num_list)
if batch_num_sum < batch_size:
batch_num_list[-1] += batch_size - batch_num_sum
slices = []
end = 0
for batch_num in batch_num_list:
begin = int(min((end, batch_size)))
end = int(min((begin + batch_num, batch_size)))
if begin >= end:
raise ValueError('Too many slices such that some splits are empty')
slices.append(slice(begin, end))
return slices
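
# Illustrative sketch (not part of this commit): with a batch of 10 samples
# and a 2:3 work load across two devices, the rounding above yields
#
#   _split_input_slice(10, [2, 3])
#   # -> [slice(0, 4), slice(4, 10)]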

def _check_arguments(symbol):
"""Check the argument names of symbol.
This function checks the duplication of arguments in Symbol.
The check is done for feedforward net for now.
Parameters
----------
symbol : Symbol
The network configuration
"""
arg_set = set()
arg_names = symbol.list_arguments()
for name in arg_names:
if name in arg_set:
            raise ValueError(('Found duplicated argument name \"%s\", ' +
                              'please make the weight names unique (using the name argument), ' +
                              'arguments are %s') % (name, str(arg_names)))
arg_set.add(name)

aux_set = set()
aux_names = symbol.list_auxiliary_states()
for name in aux_names:
if name in aux_set:
            raise ValueError(
                ('Found duplicated auxiliary param name \"%s\", ' +
                 'please make the weight names unique (using the name argument), ' +
                 'arguments are %s, auxiliary params are %s'
                ) % (name, str(arg_names), str(aux_names)))
aux_set.add(name)
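
# Hypothetical example of what the check catches (assumes the usual
# `import mxnet as mx`): reusing a layer name duplicates the shared weight
# name in list_arguments(), which raises ValueError here.
#
#   data = mx.symbol.Variable('data')
#   net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
#   net = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=10)
#   _check_arguments(net)  # ValueError on the duplicated "fc1_weight"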

def _load_general(data, targets):
"""Load a list of arrays into a list of arrays specified by slices"""
for d_src, d_targets in zip(data, targets):
if isinstance(d_targets, nd.NDArray):
d_src.copyto(d_targets)
else:
for slice_idx, d_dst in d_targets:
d_src[slice_idx].copyto(d_dst)
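
# Sketch with illustrative shapes only: targets may mix whole NDArrays with
# lists of (slice, NDArray) pairs, so one 10-row source can be scattered
# across two per-device buffers of 4 and 6 rows.
#
#   src = [nd.ones((10, 4))]
#   dst = [[(slice(0, 4), nd.zeros((4, 4))), (slice(4, 10), nd.zeros((6, 4)))]]
#   _load_general(src, dst)  # rows 0:4 -> first buffer, rows 4:10 -> second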

def _load_data(batch, targets):
"""Load data into sliced arrays"""
_load_general(batch.data, targets)

def _load_label(batch, targets):
"""Load label into sliced arrays"""
_load_general(batch.label, targets)
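
# Both helpers just forward to _load_general, so any batch object exposing
# `.data` and `.label` lists of NDArrays works. A hypothetical sketch:
#
#   batch = types.SimpleNamespace(data=[nd.ones((10, 4))], label=[nd.zeros((10,))])
#   _load_data(batch, data_targets)    # scatter inputs across devices
#   _load_label(batch, label_targets)  # scatter labels the same way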

class DataParallelExecutorManager(object):
""" Helper class to manage multiple executors for data parallelism.
Parameters
----------
symbol : Symbol
output symbol
ctx : list of Context
devices to run on
param_names: list of str
Name of all trainable parameters of the network.
arg_names: list of str
Name of all arguments of the network.
aux_names: list of str
Name of all auxiliary states of the network.
train_data : DataIter
Training data iterator.
work_load_list : list of float or int, optional
        The list of work loads for the different devices,
        in the same order as ctx.
    logger : logging logger, optional
        When not specified, the default logger is used.
"""
def __init__(self, symbol, ctx, train_data,
param_names, arg_names, aux_names,
work_load_list=None, logger=None):
if logger is None:
logger = logging
# preparation
num_device = len(ctx)
logger.info('Start training with %s', str(ctx))

# make sure the architecture is valid
_check_arguments(symbol)

if work_load_list is None:
work_load_list = [1] * num_device
assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \
"Invalid settings for work load. "

slices = _split_input_slice(train_data.batch_size, work_load_list)
self.slices = slices

self.train_execs = []
for i in range(len(ctx)):
data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
for k, v in train_data.provide_data}
train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes)
self.train_execs.append(train_exec)

# data structure
self.data_names = [x[0] for x in train_data.provide_data]
self.label_names = [x[0] for x in train_data.provide_label]
self.aux_names = aux_names

self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
for name in self.data_names]
self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)]
for name in self.label_names]

self.param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names]
self.param_names = [arg_names[i] for i in self.param_idx]
self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs]
for i in self.param_idx]
self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs]
for i in self.param_idx]

self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs]
for i in range(len(aux_names))]

batch_size = train_data.batch_size

output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in self.train_execs[0].outputs]
self.cpu_output_arrays = [nd.zeros(s) for s in output_shapes]
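
    # Hypothetical construction sketch (names are assumptions, not part of
    # this commit): given a symbol `net`, two GPUs, and a `train_iter` that
    # exposes batch_size/provide_data/provide_label, the manager binds one
    # executor per device and splits every batch 1:1 across them.
    #
    #   mgr = DataParallelExecutorManager(
    #       symbol=net, ctx=[mx.gpu(0), mx.gpu(1)], train_data=train_iter,
    #       param_names=params, arg_names=args, aux_names=auxs)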

def install_monitor(self, monitor):
""" Install monitor on all executors """
for train_exec in self.train_execs:
monitor.install(train_exec)

def set_params(self, arg_params, aux_params):
""" set parameter and aux values
Parameters
----------
        arg_params : dict of str to NDArray
            source parameter arrays, keyed by name
        aux_params : dict of str to NDArray
            source aux arrays, keyed by name
"""

for texec in self.train_execs:
texec.copy_params_from(arg_params, aux_params)
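
    # Sketch: this broadcasts one CPU copy of the parameters to every device,
    # e.g. right after initialization:
    #
    #   mgr.set_params(arg_params, aux_params)  # dicts of name -> NDArray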

def copy_to(self, arg_params, aux_params):
""" Copy data from each executor to `arg_params` and `aux_params`
Parameters
----------
        arg_params : dict of str to NDArray
            target parameter arrays, keyed by name
        aux_params : dict of str to NDArray
            target aux arrays, keyed by name
Notes
-----
        - This function updates the NDArrays in arg_params and aux_params in place.
"""
for name, block in zip(self.param_names, self.param_arrays):
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(arg_params[name])
for name, block in zip(self.aux_names, self.aux_arrays):
weight = sum(w.copyto(cpu()) for w in block) / len(block)
weight.copyto(aux_params[name])
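
    # A worked sketch of the averaging above: with two devices whose replicas
    # of a weight hold w0 and w1, arg_params[name] receives (w0 + w1) / 2,
    # computed on CPU, so all replicas are reconciled into a single copy.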

def load_data_batch(self, data_batch):
""" load data and labels into arrays """
_load_data(data_batch, self.data_arrays)
_load_label(data_batch, self.label_arrays)

def forward(self, is_train=False):
""" Perform a forward pass on each executor """
for texec, islice in zip(self.train_execs, self.slices):
texec.forward(is_train=is_train)
for cpu_out, dev_out in zip(self.cpu_output_arrays, texec.outputs):
dev_out.copyto(cpu_out[islice])
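
    # Sketch: with slices [0:4) and [4:10), each device writes its rows of
    # every output into cpu_output_arrays, so callers see one batch-sized
    # output per head no matter how many devices ran the pass.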

def backward(self):
""" Perform a backward pass on each executor """
for texec in self.train_execs:
texec.backward()
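
# Hypothetical end-to-end usage (a sketch under assumed names, not API
# defined in this commit): one epoch loads each batch, runs forward and
# backward on every device, updates params elsewhere, then gathers the
# averaged weights back to CPU.
#
#   for batch in train_iter:
#       mgr.load_data_batch(batch)
#       mgr.forward(is_train=True)
#       mgr.backward()
#       # ... optimizer step on mgr.param_arrays / mgr.grad_arrays ...
#   mgr.copy_to(arg_params, aux_params)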