diff --git a/dmlc-core b/dmlc-core
index 27013a86f8b8..a9b3320d2c6b 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit 27013a86f8b8fd8bb9ebf2253928436e0eb38e13
+Subproject commit a9b3320d2c6b29506139784b877142c9ee78caaf
diff --git a/example/fcn-xs/README.md b/example/fcn-xs/README.md
new file mode 100644
index 000000000000..a902fcdee7ac
--- /dev/null
+++ b/example/fcn-xs/README.md
@@ -0,0 +1,66 @@
+FCN-xs EXAMPLES
+---------------
+This folder contains examples of image segmentation in MXNet.
+
+## Sample results
+![fcn-xs pascal_voc result](https://github.com/dmlc/web-data/blob/master/mxnet/image/fcnxs-example-result.jpg)
+
+We have trained a simple fcn-xs model; the hyperparameters are below:
+
+| model   | lr (fixed) | epoch |
+| ------- | ---------: | ----: |
+| fcn-32s | 1e-10      | 31    |
+| fcn-16s | 1e-12      | 27    |
+| fcn-8s  | 1e-14      | 19    |
+
+The training set contains only 2027 images, and the validation set contains 462 images.
+
+## How to train fcn-xs in mxnet
+#### Step 1: download the vgg16fc model and experiment data
+* vgg16fc model: download ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` from [yun.baidu](http://pan.baidu.com/s/1bgz4PC).
+This is a fully convolutional variant of the original
+[VGG_ILSVRC_16_layers.caffemodel](http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel), with the corresponding [VGG_ILSVRC_16_layers_deploy.prototxt](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt). Note that the vgg16 model is [licensed](http://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only.
+* experiment data: download ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1bgz4PC) and extract it. The files/folders are:
+```JPEGImages folder```, ```SegmentationClass folder```, ```train.lst```, ```val.lst```, ```test.lst```
+
+#### Step 2: train the fcn-xs model
+* If you want to train the fcn-8s model, it is better to train the fcn-32s and fcn-16s models first.
+To train the fcn-32s model, run ```./run_fcnxs.sh``` in the shell; the script runs:
+```shell
+python -u fcn_xs.py --model=fcn32s --prefix=VGG_FC_ILSVRC_16_layers --epoch=74 --init-type=vgg16
+```
+* In ```fcn_xs.py``` you may need to change ```root_dir```, ```flist_name```, and ```fcnxs_model_prefix``` for your own data.
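+For example, the relevant defaults in ```fcn_xs.py``` look like this (a sketch of the shipped defaults; the paths are assumptions you should adapt to your own layout):
+```python
+# hypothetical values -- point these at your own data
+train_dataiter = FileIter(
+    root_dir   = "./VOC2012",                     # folder holding JPEGImages/, SegmentationClass/, *.lst
+    flist_name = "train.lst",                     # each line: index \t image_path \t label_path
+    rgb_mean   = (123.68, 116.779, 103.939),
+    )
+fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"  # checkpoint prefix written each epoch
+```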
+* To train the fcn-16s or fcn-8s model, change the command in ```run_fcnxs.sh``` accordingly. For example, to train fcn-16s, comment out the fcn32s line so the script looks like this:
+```shell
+ python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 --epoch=31 --init-type=fcnxs
+```
+* The output log should look like this (when training fcn-8s):
+```
+INFO:root:Start training with gpu(3)
+INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318
+INFO:root:Epoch[0] Batch [100] Speed: 1.11 samples/sec Train-accuracy=0.904681
+INFO:root:Epoch[0] Batch [150] Speed: 1.13 samples/sec Train-accuracy=0.908053
+INFO:root:Epoch[0] Batch [200] Speed: 1.12 samples/sec Train-accuracy=0.912219
+INFO:root:Epoch[0] Batch [250] Speed: 1.13 samples/sec Train-accuracy=0.914238
+INFO:root:Epoch[0] Batch [300] Speed: 1.13 samples/sec Train-accuracy=0.912170
+INFO:root:Epoch[0] Batch [350] Speed: 1.12 samples/sec Train-accuracy=0.912080
+```
+
+## Using the pre-trained model for image segmentation
+* Similarly, first download the pre-trained model from [yun.baidu](http://pan.baidu.com/s/1bgz4PC); the symbol and parameter files are ```FCN8s_VGG16-symbol.json``` and ```FCN8s_VGG16-0019.params```.
+* Then put the image to segment in your directory and change ```img = YOUR_IMAGE_NAME``` in ```image_segmentaion.py```.
+* Lastly, run ```python image_segmentaion.py``` in the shell to segment the image; you will get a segmentation image like the sample result above.
+
+## Tips
+* This example trains on whole images, that is to say, we do not resize/crop the images to a common size, so batch_size is set to 1 during training.
+* The fcn-xs models are based on the vgg16 model with crop, deconvolution, and element-wise sum layers added, so the model is fairly large. Moreover, since the example trains on whole images, a large input (such as 700*500) can consume a lot of GPU memory, so we suggest using a GPU with 12GB of memory.
+* If you don't have a GPU with 12GB of memory, you should set ```cut_off_size``` to a smaller value when you construct your FileIter, like this:
+```python
+train_dataiter = FileIter(
+      root_dir             = "./VOC2012",
+      flist_name           = "train.lst",
+      cut_off_size         = 400,
+      rgb_mean             = (123.68, 116.779, 103.939),
+      )
+```
+* We look forward to your contributions to make this example more powerful; thanks.
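A minimal sketch of how the `FileIter` defined in `data.py` below is consumed (assuming the VOC2012 layout from Step 1; `next()` returns a dict keyed by `data_name`/`label_name`):

```python
# sketch only, not part of the example scripts
from data import FileIter

train_dataiter = FileIter(root_dir="./VOC2012", flist_name="train.lst",
                          rgb_mean=(123.68, 116.779, 103.939))
train_dataiter.reset()
batch = train_dataiter.next()       # {"data": (1, c, h, w) array, "softmax_label": (1, h, w) array}
print(train_dataiter.provide_data)  # [("data", (1, c, h, w))]
```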
diff --git a/example/fcn-xs/data.py b/example/fcn-xs/data.py
new file mode 100644
index 000000000000..dcc958ea481a
--- /dev/null
+++ b/example/fcn-xs/data.py
@@ -0,0 +1,122 @@
+# pylint: skip-file
+""" file iterator for pascal voc 2012"""
+import mxnet as mx
+import numpy as np
+import sys, os
+from mxnet.io import DataIter
+from PIL import Image
+
+class FileIter(DataIter):
+    """FileIter object in fcn-xs example. Takes a file-list file to build the data iterator.
+    In this example we train fcn-xs on whole images, that is to say,
+    we do not resize/crop the images to a common size, so the batch_size is
+    set to 1 here.
+    Parameters
+    ----------
+    root_dir : string
+        the root directory the images/labels lie in
+    flist_name : string
+        the list file of images and labels; every line has the form:
+        index \t image_data_path \t image_label_path
+    cut_off_size : int
+        if any edge of an image is larger than cut_off_size, the image is
+        randomly cropped to a square no larger than cut_off_size x cut_off_size
+        (bounded by the image's shorter edge)
+    data_name : string
+        the data name used in the symbol, data (default data name)
+    label_name : string
+        the label name used in the symbol, softmax_label (default label name)
+    """
+    def __init__(self, root_dir, flist_name,
+                 rgb_mean = (117, 117, 117),
+                 cut_off_size = None,
+                 data_name = "data",
+                 label_name = "softmax_label"):
+        super(FileIter, self).__init__()
+        self.root_dir = root_dir
+        self.flist_name = os.path.join(self.root_dir, flist_name)
+        self.mean = np.array(rgb_mean)  # (R, G, B)
+        self.cut_off_size = cut_off_size
+        self.data_name = data_name
+        self.label_name = label_name
+
+        self.num_data = len(open(self.flist_name, 'r').readlines())
+        self.f = open(self.flist_name, 'r')
+        self.data, self.label = self._read()
+        self.cursor = -1
+
+    def _read(self):
+        """get two lists; each list contains two elements: name and nd.array value"""
+        _, data_img_name, label_img_name = self.f.readline().strip('\n').split("\t")
+        data = {}
+        label = {}
+        data[self.data_name], label[self.label_name] = self._read_img(data_img_name, label_img_name)
+        return list(data.items()), list(label.items())
+
+    def _read_img(self, img_name, label_name):
+        img = Image.open(os.path.join(self.root_dir, img_name))
+        label = Image.open(os.path.join(self.root_dir, label_name))
+        assert img.size == label.size
+        img = np.array(img, dtype=np.float32)  # (h, w, c)
+        label = np.array(label)  # (h, w)
+        if self.cut_off_size is not None:
+            max_hw = max(img.shape[0], img.shape[1])
+            min_hw = min(img.shape[0], img.shape[1])
+            if min_hw > self.cut_off_size:
+                rand_start_max = round(np.random.uniform(0, max_hw - self.cut_off_size - 1))
+                rand_start_min = round(np.random.uniform(0, min_hw - self.cut_off_size - 1))
+                if img.shape[0] == max_hw:
+                    img = img[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size]
+                    label = label[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size]
+                else:
+                    img = img[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size]
+                    label = label[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size]
+            elif max_hw > self.cut_off_size:
+                rand_start = round(np.random.uniform(0, max_hw - min_hw - 1))
+                if img.shape[0] == max_hw:
+                    img = img[rand_start : rand_start + min_hw, :]
+                    label = label[rand_start : rand_start + min_hw, :]
+                else:
+                    img = img[:, rand_start : rand_start + min_hw]
+                    label = label[:, rand_start : rand_start + min_hw]
+        reshaped_mean = self.mean.reshape(1, 1, 3)
+        img = img - reshaped_mean
+        img = np.swapaxes(img, 0, 2)
+        img = np.swapaxes(img, 1, 2)  # (c, h, w)
+        img = np.expand_dims(img, axis=0)  # (1, c, h, w)
+        label = np.array(label)  # (h, w)
+        label = np.expand_dims(label, axis=0)  # (1, h, w)
+        return (img, label)
+
+    @property
+    def provide_data(self):
+        """The name and shape of data provided by this iterator"""
+        return [(k, tuple([1] + list(v.shape[1:]))) for k, v in
self.data]
+
+    @property
+    def provide_label(self):
+        """The name and shape of label provided by this iterator"""
+        return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label]
+
+    def get_batch_size(self):
+        return 1
+
+    def reset(self):
+        self.cursor = -1
+        self.f.close()
+        self.f = open(self.flist_name, 'r')
+
+    def iter_next(self):
+        self.cursor += 1
+        return self.cursor < self.num_data - 1
+
+    def next(self):
+        """return one dict which contains "data" and "label" """
+        if self.iter_next():
+            self.data, self.label = self._read()
+            return {self.data_name  :  self.data[0][1],
+                    self.label_name :  self.label[0][1]}
+        else:
+            raise StopIteration
diff --git a/example/fcn-xs/fcn_xs.py b/example/fcn-xs/fcn_xs.py
new file mode 100644
index 000000000000..01344a7b123f
--- /dev/null
+++ b/example/fcn-xs/fcn_xs.py
@@ -0,0 +1,79 @@
+# pylint: skip-file
+import sys, os
+import argparse
+import mxnet as mx
+import numpy as np
+import logging
+import symbol_fcnxs
+import init_fcnxs
+from data import FileIter
+from solver import Solver
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+ctx = mx.gpu(0)
+
+def main():
+    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=21, workspace_default=1536)
+    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
+    if args.model == "fcn16s":
+        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=21, workspace_default=1536)
+        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
+    elif args.model == "fcn8s":
+        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=21, workspace_default=1536)
+        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"
+    arg_names = fcnxs.list_arguments()
+    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
+    if not args.retrain:
+        if args.init_type == "vgg16":
+            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
+        elif args.init_type == "fcnxs":
+            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
+    train_dataiter = FileIter(
+        root_dir             = "./VOC2012",
+        flist_name           = "train.lst",
+        # cut_off_size         = 400,
+        rgb_mean             = (123.68, 116.779, 103.939),
+        )
+    val_dataiter = FileIter(
+        root_dir             = "./VOC2012",
+        flist_name           = "val.lst",
+        # cut_off_size         = 400,
+        rgb_mean             = (123.68, 116.779, 103.939),
+        )
+    model = Solver(
+        ctx                  = ctx,
+        symbol               = fcnxs,
+        begin_epoch          = 0,
+        num_epoch            = 50,
+        arg_params           = fcnxs_args,
+        aux_params           = fcnxs_auxs,
+        learning_rate        = 1e-10,
+        momentum             = 0.99,
+        wd                   = 0.0005)
+    model.fit(
+        train_data           = train_dataiter,
+        eval_data            = val_dataiter,
+        batch_end_callback   = mx.callback.Speedometer(1, 10),
+        epoch_end_callback   = mx.callback.do_checkpoint(fcnxs_model_prefix))
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Train fcn-xs model for image segmentation.')
+    parser.add_argument('--model', default='fcnxs',
+                        help='The type of fcn-xs model, e.g. fcnxs, fcn16s, fcn8s.')
+    parser.add_argument('--prefix', default='VGG_FC_ILSVRC_16_layers',
+                        help='The prefix (including path) of the vgg16 model in mxnet format.')
+    parser.add_argument('--epoch', type=int, default=74,
+                        help='The epoch number of the vgg16 model.')
+    parser.add_argument('--init-type', default="vgg16",
+                        help='the init type of fcn-xs model, e.g. vgg16, fcnxs')
+    parser.add_argument('--retrain', action='store_true', default=False,
+                        help='true means continue training.')
+    args = parser.parse_args()
+    logging.info(args)
+    main()
diff --git a/example/fcn-xs/image_segmentaion.py b/example/fcn-xs/image_segmentaion.py
new file mode 100644
index 000000000000..56c7482fcb81
--- /dev/null
+++ b/example/fcn-xs/image_segmentaion.py
@@ -0,0 +1,60 @@
+# pylint: skip-file
+import numpy as np
+import mxnet as mx
+from PIL import Image
+
+palette = [ 0,0,0,
+            128,0,0,
+            0,128,0,
+            128,128,0,
+            0,0,128,
+            128,0,128,
+            0,128,128,
+            128,128,128,
+            64,0,0,
+            192,0,0,
+            64,128,0,
+            192,128,0,
+            64,0,128,
+            192,0,128,
+            64,128,128,
+            192,128,128,
+            0,64,0,
+            128,64,0,
+            0,192,0,
+            128,192,0,
+            0,64,128 ]
+img = "./person_bicycle.jpg"
+seg = img.replace("jpg", "png")
+model_prefix = "FCN8s_VGG16"
+epoch = 19
+ctx = mx.gpu(0)
+
+def get_data(img_path):
+    """get the (1, 3, h, w) np.array data for the img_path"""
+    mean = np.array([123.68, 116.779, 103.939])  # (R, G, B)
+    img = Image.open(img_path)
+    img = np.array(img, dtype=np.float32)
+    reshaped_mean = mean.reshape(1, 1, 3)
+    img = img - reshaped_mean
+    img = np.swapaxes(img, 0, 2)
+    img = np.swapaxes(img, 1, 2)
+    img = np.expand_dims(img, axis=0)
+    return img
+
+def main():
+    fcnxs, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(model_prefix, epoch)
+    fcnxs_args["data"] = mx.nd.array(get_data(img), ctx)
+    data_shape = fcnxs_args["data"].shape
+    label_shape = (1, data_shape[2]*data_shape[3])
+    fcnxs_args["softmax_label"] = mx.nd.empty(label_shape, ctx)
+    executor = fcnxs.bind(ctx, fcnxs_args, args_grad=None, grad_req="null", aux_states=fcnxs_auxs)
+    executor.forward(is_train=False)
+    output = executor.outputs[0]
+    out_img = np.uint8(np.squeeze(output.asnumpy().argmax(axis=1)))
+    out_img = Image.fromarray(out_img)
+    out_img.putpalette(palette)
+    out_img.save(seg)
+
+if __name__ == "__main__":
+    main()
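The init_fcnxs.py below seeds every deconvolution layer with a bilinear interpolation kernel built by upsample_filt. A quick worked check of the formula (computed by hand from the function body, not taken from the source):

```python
# upsample_filt(4): factor = (4 + 1) // 2 = 2, center = 2 - 0.5 = 1.5
# weight(i, j) = (1 - |i - 1.5| / 2) * (1 - |j - 1.5| / 2)
import numpy as np
og = np.ogrid[:4, :4]
print((1 - abs(og[0] - 1.5) / 2) * (1 - abs(og[1] - 1.5) / 2))
# [[ 0.0625  0.1875  0.1875  0.0625]
#  [ 0.1875  0.5625  0.5625  0.1875]
#  [ 0.1875  0.5625  0.5625  0.1875]
#  [ 0.0625  0.1875  0.1875  0.0625]]
```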
diff --git a/example/fcn-xs/init_fcnxs.py b/example/fcn-xs/init_fcnxs.py
new file mode 100644
index 000000000000..69295ce6be68
--- /dev/null
+++ b/example/fcn-xs/init_fcnxs.py
@@ -0,0 +1,89 @@
+# pylint: skip-file
+import mxnet as mx
+import numpy as np
+import sys
+import logging
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+# make a bilinear interpolation kernel, return a numpy.ndarray
+def upsample_filt(size):
+    factor = (size + 1) // 2
+    if size % 2 == 1:
+        center = factor - 1
+    else:
+        center = factor - 0.5
+    og = np.ogrid[:size, :size]
+    return (1 - abs(og[0] - center) / factor) * \
+           (1 - abs(og[1] - center) / factor)
+
+def init_from_vgg16(ctx, fcnxs_symbol, vgg16fc_args, vgg16fc_auxs):
+    fcnxs_args = vgg16fc_args.copy()
+    fcnxs_auxs = vgg16fc_auxs.copy()
+    for k, v in fcnxs_args.items():
+        if v.context != ctx:
+            fcnxs_args[k] = mx.nd.zeros(v.shape, ctx)
+            v.copyto(fcnxs_args[k])
+    for k, v in fcnxs_auxs.items():
+        if v.context != ctx:
+            fcnxs_auxs[k] = mx.nd.zeros(v.shape, ctx)
+            v.copyto(fcnxs_auxs[k])
+    data_shape = (1, 3, 500, 500)
+    arg_names = fcnxs_symbol.list_arguments()
+    arg_shapes, _, _ = fcnxs_symbol.infer_shape(data=data_shape)
+    rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes)
+                        if x[0] in ['score_weight', 'score_bias', 'score_pool4_weight', 'score_pool4_bias', \
+                        'score_pool3_weight', 'score_pool3_bias']])
+    fcnxs_args.update(rest_params)
+    deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes)
+                          if x[0] in ["bigscore_weight", 'score2_weight', 'score4_weight']])
+    for k, v in deconv_params.items():
+        filt = upsample_filt(v[3])
+        initw = np.zeros(v)
+        initw[range(v[0]), range(v[1]), :, :] = filt  # be careful: this is fancy-index (slice) assignment
+        fcnxs_args[k] = mx.nd.array(initw, ctx)
+    return fcnxs_args, fcnxs_auxs
+
+def init_from_fcnxs(ctx, fcnxs_symbol, fcnxs_args_from, fcnxs_auxs_from):
+    """ use zero initialization for better convergence, because it tends to output 0,
+    and the label 0 stands for background, which may occupy most of the area of one image.
+    """
+    fcnxs_args = fcnxs_args_from.copy()
+    fcnxs_auxs = fcnxs_auxs_from.copy()
+    for k, v in fcnxs_args.items():
+        if v.context != ctx:
+            fcnxs_args[k] = mx.nd.zeros(v.shape, ctx)
+            v.copyto(fcnxs_args[k])
+    for k, v in fcnxs_auxs.items():
+        if v.context != ctx:
+            fcnxs_auxs[k] = mx.nd.zeros(v.shape, ctx)
+            v.copyto(fcnxs_auxs[k])
+    data_shape = (1, 3, 500, 500)
+    arg_names = fcnxs_symbol.list_arguments()
+    arg_shapes, _, _ = fcnxs_symbol.infer_shape(data=data_shape)
+    rest_params = {}
+    deconv_params = {}
+    # this is fcn8s init from fcn16s
+    if 'score_pool3_weight' in arg_names:
+        rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes)
+                            if x[0] in ['score_pool3_bias', 'score_pool3_weight']])
+        deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \
+                              in ["bigscore_weight", 'score4_weight']])
+    # this is fcn16s init from fcn32s
+    elif 'score_pool4_weight' in arg_names:
+        rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes)
+                            if x[0] in ['score_pool4_weight', 'score_pool4_bias']])
+        deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \
+                              in ["bigscore_weight", 'score2_weight']])
+    # this is fcn32s init
+    else:
+        logging.error("you are initializing the fcn32s model, so you should use init_from_vgg16()")
+        sys.exit()
+    fcnxs_args.update(rest_params)
+    for k, v in deconv_params.items():
+        filt = upsample_filt(v[3])
+        initw = np.zeros(v)
+        initw[range(v[0]), range(v[1]), :, :] = filt  # be careful: this is fancy-index (slice) assignment
+        fcnxs_args[k] = mx.nd.array(initw, ctx)
+    return fcnxs_args, fcnxs_auxs
diff --git a/example/fcn-xs/run_fcnxs.sh b/example/fcn-xs/run_fcnxs.sh
new file mode 100755
index 000000000000..926f3f840415
--- /dev/null
+++ b/example/fcn-xs/run_fcnxs.sh
@@ -0,0 +1,11 @@
+# train fcn-32s model
+python -u fcn_xs.py --model=fcn32s --prefix=VGG_FC_ILSVRC_16_layers \
+       --epoch=74 --init-type=vgg16
+
+## train fcn-16s model
+#python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 \
+#       --epoch=31 --init-type=fcnxs
+
+# train fcn-8s model
+#python -u fcn_xs.py --model=fcn8s --prefix=FCN16s_VGG16 \
+#       --epoch=27 --init-type=fcnxs
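The solver.py below implements a small training loop used in place of mx.model.FeedForward. For reference, this is how fcn_xs.py above drives it (a restatement of calls already in this diff, not new API):

```python
# sketch: driving the Solver as fcn_xs.py does
model = Solver(ctx=mx.gpu(0), symbol=fcnxs, begin_epoch=0, num_epoch=50,
               arg_params=fcnxs_args, aux_params=fcnxs_auxs,
               learning_rate=1e-10, momentum=0.99, wd=0.0005)
model.fit(train_data=train_dataiter, eval_data=val_dataiter,
          batch_end_callback=mx.callback.Speedometer(1, 10),
          epoch_end_callback=mx.callback.do_checkpoint("model_pascal/FCN32s_VGG16"))
```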
diff --git a/example/fcn-xs/solver.py b/example/fcn-xs/solver.py
new file mode 100644
index 000000000000..953e0a986fd2
--- /dev/null
+++ b/example/fcn-xs/solver.py
@@ -0,0 +1,126 @@
+# pylint: skip-file
+import numpy as np
+import mxnet as mx
+import time
+import logging
+from collections import namedtuple
+from mxnet import optimizer as opt
+from mxnet.optimizer import get_updater
+from mxnet import metric
+
+# Parameter to pass to batch_end_callback
+BatchEndParam = namedtuple('BatchEndParams', ['epoch', 'nbatch', 'eval_metric'])
+class Solver(object):
+    def __init__(self, symbol, ctx=None,
+                 begin_epoch=0, num_epoch=None,
+                 arg_params=None, aux_params=None,
+                 optimizer='sgd', **kwargs):
+        self.symbol = symbol
+        if ctx is None:
+            ctx = mx.cpu()
+        self.ctx = ctx
+        self.begin_epoch = begin_epoch
+        self.num_epoch = num_epoch
+        self.arg_params = arg_params
+        self.aux_params = aux_params
+        self.optimizer = optimizer
+        self.kwargs = kwargs.copy()
+
+    def fit(self, train_data, eval_data=None,
+            eval_metric='acc',
+            grad_req='write',
+            epoch_end_callback=None,
+            batch_end_callback=None,
+            kvstore='local',
+            logger=None):
+        if logger is None:
+            logger = logging
+        logging.info('Start training with %s', str(self.ctx))
+        arg_shapes, out_shapes, aux_shapes = self.symbol.infer_shape(data=train_data.provide_data[0][1])
+        arg_names = self.symbol.list_arguments()
+        if grad_req != 'null':
+            self.grad_params = {}
+            for name, shape in zip(arg_names, arg_shapes):
+                if not (name.endswith('data') or name.endswith('label')):
+                    self.grad_params[name] = mx.nd.zeros(shape, self.ctx)
+        else:
+            self.grad_params = None
+        aux_names = self.symbol.list_auxiliary_states()
+        self.aux_params = {k : mx.nd.zeros(s) for k, s in zip(aux_names, aux_shapes)}
+        data_name = train_data.data_name
+        label_name = train_data.label_name
+        input_names = [data_name, label_name]
+        self.optimizer = opt.create(self.optimizer, rescale_grad=(1.0/train_data.get_batch_size()), **(self.kwargs))
+        self.updater = get_updater(self.optimizer)
+        eval_metric = metric.create(eval_metric)
+        # begin training
+        for epoch in range(self.begin_epoch, self.num_epoch):
+            nbatch = 0
+            train_data.reset()
+            eval_metric.reset()
+            for data in train_data:
+                nbatch += 1
+                label_shape = data[label_name].shape
+                self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx)
+                self.arg_params[label_name] = mx.nd.array(data[label_name].reshape(label_shape[0], \
+                    label_shape[1]*label_shape[2]), self.ctx)
+                output_names = self.symbol.list_outputs()
+                self.exector = self.symbol.bind(self.ctx, self.arg_params,
+                                args_grad=self.grad_params,
+                                grad_req=grad_req,
+                                aux_states=self.aux_params)
+                assert len(self.symbol.list_arguments()) == len(self.exector.grad_arrays)
+                update_dict = {name: nd for name, nd in zip(self.symbol.list_arguments(), \
+                    self.exector.grad_arrays) if nd}
+                output_dict = {}
+                output_buff = {}
+                for key, arr in zip(self.symbol.list_outputs(), self.exector.outputs):
+                    output_dict[key] = arr
+                    output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
+                self.exector.forward(is_train=True)
+                for key in output_dict:
+                    output_dict[key].copyto(output_buff[key])
+                self.exector.backward()
+                for key, arr in update_dict.items():
+                    if key != "bigscore_weight":
+                        self.updater(key, arr, self.arg_params[key])
+                pred_shape = self.exector.outputs[0].shape
+                label = mx.nd.array(data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2]))
+                pred = mx.nd.array(output_buff["softmax_output"].asnumpy().reshape(pred_shape[0], \
+                    pred_shape[1], pred_shape[2]*pred_shape[3]))
+                eval_metric.update([label], [pred])
+                self.exector.outputs[0].wait_to_read()
+                if batch_end_callback is not None:
+                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=eval_metric)
+                    batch_end_callback(batch_end_params)
+            if epoch_end_callback is not None:
+                epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params)
+            name, value = eval_metric.get()
+            logger.info(" --->Epoch[%d] Train-%s=%f", epoch, name, value)
+            # evaluation
+            if eval_data:
+                logger.info(" in eval process...")
+                nbatch = 0
+                eval_data.reset()
+                eval_metric.reset()
+                for data in eval_data:
+                    nbatch += 1
+                    label_shape = data[label_name].shape
+                    self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx)
+                    self.arg_params[label_name] = mx.nd.array(data[label_name].reshape(label_shape[0], \
+                        label_shape[1]*label_shape[2]), self.ctx)
+                    exector = self.symbol.bind(self.ctx, self.arg_params,
+                                    args_grad=self.grad_params,
+                                    grad_req=grad_req,
aux_states=self.aux_params)
+                    cpu_output_array = mx.nd.zeros(exector.outputs[0].shape)
+                    exector.forward(is_train=False)
+                    exector.outputs[0].copyto(cpu_output_array)
+                    pred_shape = cpu_output_array.shape
+                    label = mx.nd.array(data[label_name].reshape(label_shape[0], \
+                        label_shape[1]*label_shape[2]))
+                    pred = mx.nd.array(cpu_output_array.asnumpy().reshape(pred_shape[0], \
+                        pred_shape[1], pred_shape[2]*pred_shape[3]))
+                    eval_metric.update([label], [pred])
+                    exector.outputs[0].wait_to_read()
+                name, value = eval_metric.get()
+                logger.info('batch[%d] Validation-%s=%f', nbatch, name, value)
diff --git a/example/fcn-xs/symbol_fcnxs.py b/example/fcn-xs/symbol_fcnxs.py
new file mode 100644
index 000000000000..ab283fa13f50
--- /dev/null
+++ b/example/fcn-xs/symbol_fcnxs.py
@@ -0,0 +1,189 @@
+# pylint: skip-file
+import mxnet as mx
+
+def filter_map(kernel=1, stride=1, pad=0):
+    # why not return (stride, (kernel-stride)/2-pad)??
+    return (stride, (kernel-1)/2-pad)
+
+def compose_fp(fp_first, fp_second):
+    return (fp_first[0]*fp_second[0], fp_first[0]*fp_second[1]+fp_first[1])
+
+def compose_fp_list(fp_list):
+    fp_out = (1.0, 0.0)
+    for fp in fp_list:
+        fp_out = compose_fp(fp_out, fp)
+    return fp_out
+
+def inv_fp(fp_in):
+    return (1.0/fp_in[0], -1.0*fp_in[1]/fp_in[0])
+
+def offset():
+    conv1_1_fp = filter_map(kernel=3, pad=100)
+    conv1_2_fp = conv2_1_fp = conv2_2_fp = conv3_1_fp = conv3_2_fp = conv3_3_fp \
+               = conv4_1_fp = conv4_2_fp = conv4_3_fp = conv5_1_fp = conv5_2_fp \
+               = conv5_3_fp = filter_map(kernel=3, pad=1)
+    pool1_fp = pool2_fp = pool3_fp = pool4_fp = pool5_fp = filter_map(kernel=2, stride=2)
+    fc6_fp = filter_map(kernel=7)
+    fc7_fp = score_fp = score_pool4_fp = score_pool3_fp = filter_map()
+    # for fcn-32s
+    fcn32s_upscore_fp = inv_fp(filter_map(kernel=64, stride=32))
+    fcn32s_upscore_list = [conv1_1_fp, conv1_2_fp, pool1_fp, conv2_1_fp, conv2_2_fp,
+                           pool2_fp, conv3_1_fp, conv3_2_fp, conv3_3_fp, pool3_fp,
+                           conv4_1_fp, conv4_2_fp, conv4_3_fp, pool4_fp, conv5_1_fp,
+                           conv5_2_fp, conv5_3_fp, pool5_fp, fc6_fp, fc7_fp, score_fp,
+                           fcn32s_upscore_fp]
+    crop = {}
+    crop["fcn32s_upscore"] = (-int(round(compose_fp_list(fcn32s_upscore_list)[1])),
+                              -int(round(compose_fp_list(fcn32s_upscore_list)[1])))
+    # for fcn-16s
+    score2_fp = inv_fp(filter_map(kernel=4, stride=2))
+    fcn16s_upscore_fp = inv_fp(filter_map(kernel=32, stride=16))
+    score_pool4c_fp_list = [inv_fp(score2_fp), inv_fp(score_fp), inv_fp(fc7_fp), inv_fp(fc6_fp),
+                            inv_fp(pool5_fp), inv_fp(conv5_3_fp), inv_fp(conv5_2_fp),
+                            inv_fp(conv5_1_fp), score_pool4_fp]
+    crop["score_pool4c"] = (-int(round(compose_fp_list(score_pool4c_fp_list)[1])),
+                            -int(round(compose_fp_list(score_pool4c_fp_list)[1])))
+    fcn16s_upscore_list = [conv1_1_fp, conv1_2_fp, pool1_fp, conv2_1_fp, conv2_2_fp,
+                           pool2_fp, conv3_1_fp, conv3_2_fp, conv3_3_fp, pool3_fp,
+                           conv4_1_fp, conv4_2_fp, conv4_3_fp, pool4_fp, score_pool4_fp,
+                           inv_fp((1, -crop["score_pool4c"][0])), fcn16s_upscore_fp]
+    crop["fcn16s_upscore"] = (-int(round(compose_fp_list(fcn16s_upscore_list)[1])),
+                              -int(round(compose_fp_list(fcn16s_upscore_list)[1])))
+    # for fcn-8s
+    score4_fp = inv_fp(filter_map(kernel=4, stride=2))
+    fcn8s_upscore_fp = inv_fp(filter_map(kernel=16, stride=8))
+    score_pool3c_fp_list = [inv_fp(score4_fp), (1, -crop["score_pool4c"][0]), inv_fp(score_pool4_fp),
+                            inv_fp(pool4_fp), inv_fp(conv4_3_fp), inv_fp(conv4_2_fp),
+                            inv_fp(conv4_1_fp), score_pool3_fp, score_pool3_fp]
+    crop["score_pool3c"] = (-int(round(compose_fp_list(score_pool3c_fp_list)[1])),
+                            -int(round(compose_fp_list(score_pool3c_fp_list)[1])))
+    fcn8s_upscore_list = [conv1_1_fp, conv1_2_fp, pool1_fp, conv2_1_fp, conv2_2_fp, pool2_fp,
+                          conv3_1_fp, conv3_2_fp, conv3_3_fp, pool3_fp, score_pool3_fp,
+                          inv_fp((1, -crop["score_pool3c"][0])), fcn8s_upscore_fp]
+    crop["fcn8s_upscore"] = (-int(round(compose_fp_list(fcn8s_upscore_list)[1])),
+                             -int(round(compose_fp_list(fcn8s_upscore_list)[1])))
+    return crop
+
+def vgg16_pool3(input, workspace_default=1024):
+    # group 1
+    conv1_1 = mx.symbol.Convolution(data=input, kernel=(3, 3), pad=(100, 100), num_filter=64,
+                                 workspace=workspace_default, name="conv1_1")
+    relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1")
+    conv1_2 = mx.symbol.Convolution(data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64,
+                                 workspace=workspace_default, name="conv1_2")
+    relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2")
+    pool1 = mx.symbol.Pooling(data=relu1_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool1")
+    # group 2
+    conv2_1 = mx.symbol.Convolution(data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128,
+                                 workspace=workspace_default, name="conv2_1")
+    relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1")
+    conv2_2 = mx.symbol.Convolution(data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128,
+                                 workspace=workspace_default, name="conv2_2")
+    relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2")
+    pool2 = mx.symbol.Pooling(data=relu2_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool2")
+    # group 3
+    conv3_1 = mx.symbol.Convolution(data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256,
+                                 workspace=workspace_default, name="conv3_1")
+    relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1")
+    conv3_2 = mx.symbol.Convolution(data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256,
+                                 workspace=workspace_default, name="conv3_2")
+    relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2")
+    conv3_3 = mx.symbol.Convolution(data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256,
+                                 workspace=workspace_default, name="conv3_3")
+    relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3")
+    pool3 = mx.symbol.Pooling(data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool3")
+    return pool3
+
+def vgg16_pool4(input, workspace_default=1024):
+    # group 4
+    conv4_1 = mx.symbol.Convolution(data=input, kernel=(3, 3), pad=(1, 1), num_filter=512,
+                                 workspace=workspace_default, name="conv4_1")
+    relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1")
+    conv4_2 = mx.symbol.Convolution(data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512,
+                                 workspace=workspace_default, name="conv4_2")
+    relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2")
+    conv4_3 = mx.symbol.Convolution(data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512,
+                                 workspace=workspace_default, name="conv4_3")
+    relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3")
+    pool4 = mx.symbol.Pooling(data=relu4_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool4")
+    return pool4
+
+def vgg16_score(input, numclass, workspace_default=1024):
+    # group 5
+    conv5_1 = mx.symbol.Convolution(data=input, kernel=(3, 3), pad=(1, 1), num_filter=512,
+                                 workspace=workspace_default, name="conv5_1")
+    relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1")
+    conv5_2 = mx.symbol.Convolution(data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512,
+                                 workspace=workspace_default, name="conv5_2")
+    relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="relu5_2")
+    conv5_3 = mx.symbol.Convolution(data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512,
+                                 workspace=workspace_default, name="conv5_3")
+    relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3")
+    pool5 = mx.symbol.Pooling(data=relu5_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool5")
+    # group 6
+    fc6 = mx.symbol.Convolution(data=pool5, kernel=(7, 7), num_filter=4096,
+                                 workspace=workspace_default, name="fc6")
+    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
+    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
+    # group 7
+    fc7 = mx.symbol.Convolution(data=drop6, kernel=(1, 1), num_filter=4096,
+                                 workspace=workspace_default, name="fc7")
+    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
+    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
+    # group 8
+    score = mx.symbol.Convolution(data=drop7, kernel=(1, 1), num_filter=numclass,
+                                 workspace=workspace_default, name="score")
+    return score
+
+def fcnxs_score(input, crop, offset, kernel=(64,64), stride=(32,32), numclass=21, workspace_default=1024):
+    # score out
+    bigscore = mx.symbol.Deconvolution(data=input, kernel=kernel, stride=stride, num_filter=numclass,
+                                 workspace=workspace_default, name="bigscore")
+    upscore = mx.symbol.Crop(*[bigscore, crop], offset=offset, name="upscore")
+    softmax = mx.symbol.SoftmaxOutput(data=upscore, multi_output=True, use_ignore=True, ignore_label=255, name="softmax")
+    return softmax
+
+def get_fcn32s_symbol(numclass=21, workspace_default=1024):
+    data = mx.symbol.Variable(name="data")
+    pool3 = vgg16_pool3(data, workspace_default)
+    pool4 = vgg16_pool4(pool3, workspace_default)
+    score = vgg16_score(pool4, numclass, workspace_default)
+    softmax = fcnxs_score(score, data, offset()["fcn32s_upscore"], (64,64), (32,32), numclass, workspace_default)
+    return softmax
+
+def get_fcn16s_symbol(numclass=21, workspace_default=1024):
+    data = mx.symbol.Variable(name="data")
+    pool3 = vgg16_pool3(data, workspace_default)
+    pool4 = vgg16_pool4(pool3, workspace_default)
+    score = vgg16_score(pool4, numclass, workspace_default)
+    # score 2X
+    score2 = mx.symbol.Deconvolution(data=score, kernel=(4, 4), stride=(2, 2), num_filter=numclass,
+                                 workspace=workspace_default, name="score2")  # 2X
+    score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=numclass,
+                                 workspace=workspace_default, name="score_pool4")
+    score_pool4c = mx.symbol.Crop(*[score_pool4, score2], offset=offset()["score_pool4c"], name="score_pool4c")
+    score_fused = score2 + score_pool4c
+    softmax = fcnxs_score(score_fused, data, offset()["fcn16s_upscore"], (32, 32), (16, 16), numclass, workspace_default)
+    return softmax
+
+def get_fcn8s_symbol(numclass=21, workspace_default=1024):
+    data = mx.symbol.Variable(name="data")
+    pool3 = vgg16_pool3(data, workspace_default)
+    pool4 = vgg16_pool4(pool3, workspace_default)
+    score = vgg16_score(pool4, numclass, workspace_default)
+    # score 2X
+    score2 = mx.symbol.Deconvolution(data=score, kernel=(4, 4), stride=(2, 2), num_filter=numclass,
+                                 workspace=workspace_default, name="score2")  # 2X
+    score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=numclass,
+                                 workspace=workspace_default, name="score_pool4")
+    score_pool4c = mx.symbol.Crop(*[score_pool4, score2], offset=offset()["score_pool4c"], name="score_pool4c")
+    score_fused = score2 + score_pool4c
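+    # score_pool4c crops the 1x1-conv class scores of pool4 to the spatial size of the
+    # 2x-upsampled coarse scores (score2), using the offsets precomputed in offset(),
+    # so the element-wise sum fuses predictions at matching pixels; the same
+    # crop-then-sum pattern is repeated with pool3 just below.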
+    # score 4X
+    score4 = mx.symbol.Deconvolution(data=score_fused, kernel=(4, 4), stride=(2, 2), num_filter=numclass,
+                                 workspace=workspace_default, name="score4")  # 4X
+    score_pool3 = mx.symbol.Convolution(data=pool3, kernel=(1, 1), num_filter=numclass,
+                                 workspace=workspace_default, name="score_pool3")
+    score_pool3c = mx.symbol.Crop(*[score_pool3, score4], offset=offset()["score_pool3c"], name="score_pool3c")
+    score_final = score4 + score_pool3c
+    softmax = fcnxs_score(score_final, data, offset()["fcn8s_upscore"], (16, 16), (8, 8), numclass, workspace_default)
+    return softmax
diff --git a/ps-lite b/ps-lite
index b1da4b6e0f9e..d175ec2393c6 160000
--- a/ps-lite
+++ b/ps-lite
@@ -1 +1 @@
-Subproject commit b1da4b6e0f9e387ee30d2d02a063944986ff0cbd
+Subproject commit d175ec2393c6ab00d5d0a143b42ee6dc6efb7038
diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py
index 8d08e40ba7d3..c6f466b22269 100644
--- a/python/mxnet/callback.py
+++ b/python/mxnet/callback.py
@@ -76,8 +76,13 @@ def __call__(self, param):
         if self.init:
             if count % self.frequent == 0:
                 speed = self.frequent * self.batch_size / (time.time() - self.tic)
-                logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
-                             param.epoch, count, speed)
+                if param.eval_metric is not None:
+                    name, value = param.eval_metric.get()
+                    logging.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f",
+                                 param.epoch, count, speed, name, value)
+                else:
+                    logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
+                                 param.epoch, count, speed)
                 self.tic = time.time()
         else:
             self.init = True
diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py
index c008f058ab2a..e40e146a0af8 100644
--- a/python/mxnet/lr_scheduler.py
+++ b/python/mxnet/lr_scheduler.py
@@ -71,6 +71,6 @@ def __call__(self, num_update):
         if num_update > self.count + self.step:
             self.count += self.step
             self.base_lr *= self.factor
-            logging.info("Update[%d]: Change learning rate to %.5f",
+            logging.info("Update[%d]: Change learning rate to %0.5e",
                          num_update, self.base_lr)
         return self.base_lr
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 8e3efe511c0c..4cb807e7232c 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -56,7 +56,7 @@ def update(self, labels, preds):
             if label.shape[0] < pred_label.shape[0]:
                 raise Exception("Predict label is more than data label? ")
             self.sum_metric += numpy.sum(pred_label == label[:pred_label.shape[0]])
-            num_inst = pred_label.shape[0]
+            num_inst = pred_label.size
             self.num_inst += num_inst
 
 class MAE(EvalMetric):
diff --git a/src/operator/crop-inl.h b/src/operator/crop-inl.h
new file mode 100644
index 000000000000..98a081fed5b3
--- /dev/null
+++ b/src/operator/crop-inl.h
@@ -0,0 +1,210 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file crop-inl.h
+ * \brief
+ * \author Wei Wu
+*/
+#ifndef MXNET_OPERATOR_CROP_INL_H_
+#define MXNET_OPERATOR_CROP_INL_H_
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <cstring>
+#include <map>
+#include <string>
+#include <vector>
+#include <utility>
+#include "./operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+namespace crop_enum {
+enum CropOpInputs {kData, kCropLike};
+enum CropOpOutputs {kOut};
+}  // namespace crop_enum
+
+struct CropParam : public dmlc::Parameter<CropParam> {
+  int num_args;
+  TShape offset;
+  TShape h_w;
+  bool center_crop;
+  DMLC_DECLARE_PARAMETER(CropParam) {
+    DMLC_DECLARE_FIELD(num_args).set_range(1, 3)
+    .describe("Number of inputs for crop; if it equals one, then we will use h_w"
+              " for the crop height and width, else if it equals two, then we will use the"
+              " height and width of the second input symbol, which we name crop_like here");
+    int shape[] = {0, 0};
+    DMLC_DECLARE_FIELD(offset).set_default(TShape(shape, shape + 2))
+    .describe("crop offset coordinate: (y, x)");
+    DMLC_DECLARE_FIELD(h_w).set_default(TShape(shape, shape + 2))
+    .describe("crop height and width: (h, w)");
+    DMLC_DECLARE_FIELD(center_crop).set_default(false)
+    .describe("If set to true, then it will use the center crop,"
+              " otherwise it will crop using the shape of crop_like");
+  }
+};  // struct CropParam
+
+template<typename xpu>
+class CropOp : public Operator {
+ public:
+  explicit CropOp(CropParam param) {
+    this->param_ = param;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(static_cast<int>(in_data.size()), 2);
+    CHECK_EQ(out_data.size(), 1);
+    CHECK_EQ(req[crop_enum::kOut], kWriteTo);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> data = in_data[crop_enum::kData].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> out = out_data[crop_enum::kOut].get<xpu, 4, real_t>(s);
+    offset_hw_ = InferCropOfferset(data.shape_, out.shape_);
+    out = crop(data, Shape2(out.size(2), out.size(3)), offset_hw_[0], offset_hw_[1]);
+  }
+
+  // Because the crop_like input is only used for its shape, we should be
+  // careful to set its backward grad value to zeros, so that it will not hurt
+  // the connection to crop_like.
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_states) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_grad.size(), 2) << in_grad.size();
+    CHECK_EQ(out_grad.size(), 1) << out_grad.size();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> grad = out_grad[crop_enum::kOut].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> gdata = in_grad[crop_enum::kData].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> gcrop_like = in_grad[crop_enum::kCropLike].get<xpu, 4, real_t>(s);
+    gcrop_like = (real_t)0.0f;
+    offset_hw_ = InferCropOfferset(gdata.shape_, grad.shape_);
+    gdata = (real_t)0.0f;
+    slice<3>(slice<2>(gdata, offset_hw_[0], offset_hw_[0]+grad.size(2)),
+             offset_hw_[1], offset_hw_[1]+grad.size(3)) = grad;
+  }
+
+ private:
+  CropParam param_;
+  std::vector<int> offset_hw_;
+  std::vector<int> InferCropOfferset(const mshadow::Shape<4> &data_shape,
+                                     const mshadow::Shape<4> &out_shape) {
+    std::vector<int> offset_hw;
+    CHECK_GE(data_shape[2], out_shape[2]) <<
+      "data_shape's height should be larger than that of out_shape";
+    CHECK_GE(data_shape[3], out_shape[3]) <<
+      "data_shape's width should be larger than that of out_shape";
+    if (param_.center_crop) {
+      offset_hw.push_back(static_cast<int>((data_shape[2]-out_shape[2])/2));
+      offset_hw.push_back(static_cast<int>((data_shape[3]-out_shape[3])/2));
+    } else {
+      CHECK_GE(static_cast<int>(param_.offset[0]), 0) <<
+        "offset[0] should not be less than 0";
+      CHECK_LE(static_cast<int>(param_.offset[0]), data_shape[2]-out_shape[2]) <<
+        "offset[0] should be less than the residual space of height";
+      CHECK_GE(static_cast<int>(param_.offset[1]), 0) <<
+        "offset[1] should not be less than 0";
+      CHECK_LE(static_cast<int>(param_.offset[1]), data_shape[3]-out_shape[3]) <<
+        "offset[1] should be less than the residual space of width";
+      offset_hw.push_back(static_cast<int>(param_.offset[0]));
+      offset_hw.push_back(static_cast<int>(param_.offset[1]));
+    }
+    return offset_hw;
+  }
+};  // class CropOp
+
+template<typename xpu>
+Operator *CreateOp(CropParam param);
+
+#if DMLC_USE_CXX11
+class CropProp : public OperatorProperty {
+ public:
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+    param_.Init(kwargs);
+  }
+
+  std::map<std::string, std::string> GetParams() const override {
+    return param_.__DICT__();
+  }
+
+  std::vector<std::string> ListArguments() const override {
+    // return {"data", "crop_like"};
+    std::vector<std::string> ret;
+    for (int i = 0; i < param_.num_args; ++i) {
+      ret.push_back(std::string("arg") + static_cast<char>('0' + i));
+    }
+    return ret;
+  }
+
+  bool InferShape(std::vector<TShape> *in_shape,
+                  std::vector<TShape> *out_shape,
+                  std::vector<TShape> *aux_shape) const override {
+    using namespace mshadow;
+    CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+    TShape data_shape = in_shape->at(crop_enum::kData);
+    if (data_shape.ndim() == 0) return false;
+    CHECK_EQ(data_shape.ndim(), 4) << \
+      "Input data should be 4D in batch-num_filter-y-x";
+    std::vector<int> crop_shape;
+    if (param_.num_args == 1) {
+      CHECK_GE(static_cast<int>(param_.h_w[0]), 1) <<
+        "the crop height (h_w[0]) should not be less than 1";
+      CHECK_LE(static_cast<int>(param_.h_w[0]), static_cast<int>(data_shape[2])) <<
+        "the crop height (h_w[0]) should be less than the input data's height";
+      CHECK_GE(static_cast<int>(param_.h_w[1]), 1) <<
+        "the crop width (h_w[1]) should not be less than 1";
+      CHECK_LE(static_cast<int>(param_.h_w[1]), static_cast<int>(data_shape[3])) <<
+        "the crop width (h_w[1]) should be less than the input data's width";
+      crop_shape.push_back(param_.h_w[0]);
+      crop_shape.push_back(param_.h_w[1]);
+    } else if (param_.num_args == 2) {
+      TShape crop_like_shape = in_shape->at(crop_enum::kCropLike);
+      crop_shape.push_back(crop_like_shape[2]);
+      crop_shape.push_back(crop_like_shape[3]);
+    }
+    if (crop_shape.size() == 0) return false;
+    CHECK_EQ(crop_shape.size(), 2) << \
+      "Input crop_like should be 2D in height-width";
+    out_shape->clear();
+    data_shape[2] = crop_shape[0];
+    data_shape[3] = crop_shape[1];
+    out_shape->push_back(data_shape);
+    return true;
+  }
+
+  OperatorProperty* Copy() const override {
+    auto ptr = new CropProp();
+    ptr->param_ = param_;
+    return ptr;
+  }
+
+  std::string TypeString() const override {
+    return "Crop";
+  }
+
+  std::vector<int> DeclareBackwardDependency(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data) const override {
+    return out_grad;
+  }
+
+  Operator* CreateOperator(Context ctx) const override;
+
+ private:
+  CropParam param_;
+};  // class CropProp
+#endif  // DMLC_USE_CXX11
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_CROP_INL_H_
diff --git a/src/operator/crop.cc b/src/operator/crop.cc
new file mode 100644
index 000000000000..2d46a64df78e
--- /dev/null
+++ b/src/operator/crop.cc
@@ -0,0 +1,29 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file crop.cc
+ * \brief
+ * \author Wei Wu
+*/
+
+#include "./crop-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateOp<cpu>(CropParam param) {
+  return new CropOp<cpu>(param);
+}
+
+Operator* CropProp::CreateOperator(Context ctx) const {
+  DO_BIND_DISPATCH(CreateOp, param_);
+}
+
+DMLC_REGISTER_PARAMETER(CropParam);
+
+MXNET_REGISTER_OP_PROPERTY(Crop, CropProp)
+.describe("Crop the 2nd and 3rd dim of input data, either to the size given by h_w or"
+" to the width and height of the second input symbol")
+.add_arguments(CropParam::__FIELDS__())
+.set_key_var_num_args("num_args");
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/crop.cu b/src/operator/crop.cu
new file mode 100644
index 000000000000..64f8cb219f30
--- /dev/null
+++ b/src/operator/crop.cu
@@ -0,0 +1,18 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file crop.cu
+ * \brief
+ * \author Wei Wu
+*/
+
+#include "./crop-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateOp<gpu>(CropParam param) {
+  return new CropOp<gpu>(param);
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index 60877a6b0c3c..fb026df72e55 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -2,7 +2,7 @@
  * Copyright (c) 2015 by Contributors
  * \file softmax_output-inl.h
  * \brief
- * \author Junyuan Xie
+ * \author Bing Xu
 */
 #ifndef MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_
 #define MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_
@@ -27,14 +27,22 @@ enum SoftmaxOutputOpOutputs {kOut};
 
 struct SoftmaxOutputParam : public dmlc::Parameter<SoftmaxOutputParam> {
   float grad_scale;
+  float ignore_label;
   bool multi_output;
+  bool use_ignore;
   DMLC_DECLARE_PARAMETER(SoftmaxOutputParam) {
     DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f)
     .describe("Scale the gradient by a float factor");
+    DMLC_DECLARE_FIELD(ignore_label).set_default(-1.0f)
+    .describe("the ignore_label will not work in backward, and this is only"
+              " used when multi_output=true");
     DMLC_DECLARE_FIELD(multi_output).set_default(false)
     .describe("If set to true, for a (n,k,x_1,..,x_n) dimensional"
              " input tensor, softmax will generate n*x_1*...*x_n output, each"
            " has k classes");
+    DMLC_DECLARE_FIELD(use_ignore).set_default(false)
+    .describe("If set to true, the ignore_label value will not contribute"
+              " to the backward gradient");
   };
 };
 
@@ -88,8 +96,12 @@ class SoftmaxOutputOp : public Operator {
       Tensor<xpu, 2> label = in_data[softmaxout_enum::kLabel].FlatTo2D<xpu, real_t>(s);
       Tensor<xpu, 3> out = out_data[softmaxout_enum::kOut].get_with_shape<xpu, 3, real_t>(s3, s);
       Tensor<xpu, 3> grad = in_grad[softmaxout_enum::kData].get_with_shape<xpu, 3, real_t>(s3, s);
-      SoftmaxGrad(grad, out, label);
-      grad *= param_.grad_scale/s3[2];
+      if (param_.use_ignore) {
+        SoftmaxGrad(grad, out, label, static_cast<real_t>(param_.ignore_label));
+      } else {
+        SoftmaxGrad(grad, out, label);
+      }
+      grad *= param_.grad_scale;
     } else {
       Tensor<xpu, 1> label = in_data[softmaxout_enum::kLabel].get<xpu, 1, real_t>(s);
       Tensor<xpu, 2> out = out_data[softmaxout_enum::kOut].FlatTo2D<xpu, real_t>(s);
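To tie the operator changes back to the example: the new `Crop` operator and the `use_ignore`/`ignore_label` options on `SoftmaxOutput` are exactly what `symbol_fcnxs.py` relies on. A minimal sketch (names reuse the example above; not a new API):

```python
# sketch: crop upsampled scores back to the input size, then per-pixel softmax
import mxnet as mx

data = mx.symbol.Variable(name="data")
score = mx.symbol.Convolution(data=data, kernel=(1, 1), num_filter=21, name="score")
bigscore = mx.symbol.Deconvolution(data=score, kernel=(64, 64), stride=(32, 32),
                                   num_filter=21, name="bigscore")
upscore = mx.symbol.Crop(*[bigscore, data], name="upscore")  # num_args=2: crop_like form
# VOC boundary pixels are labeled 255 and excluded from the gradient
softmax = mx.symbol.SoftmaxOutput(data=upscore, multi_output=True,
                                  use_ignore=True, ignore_label=255, name="softmax")
```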