diff --git a/dmlc-core b/dmlc-core index a9b3320d2c6b..4b951c037838 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a9b3320d2c6b29506139784b877142c9ee78caaf +Subproject commit 4b951c0378386b7f4d9eae72be2ecd3b9c816afe diff --git a/example/fcn-xs/README.md b/example/fcn-xs/README.md new file mode 100644 index 000000000000..cc67846a4b6e --- /dev/null +++ b/example/fcn-xs/README.md @@ -0,0 +1,55 @@ +FCN-xs EXAMPLES +--------------- +This folder contains the examples of image segmentation in MXNet. + +## Sample results +![fcn-xs pascal_voc result](https://github.com/tornadomeet/mxnet/blob/seg/example/fcn-xs/fcn-xs_pascal.jpg) + +## How to train fcn-xs in mxnet +#### step1: get the fully convolutional style of vgg16 model +* download the vgg16 caffe-model from [VGG_ILSVRC_16_layers.caffemodel](http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel), and the corresponding [VGG_ILSVRC_16_layers_deploy.prototxt](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt), the vgg16 model has [license](http://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only. +* use convert_model.py to convert the caffe model to mxnet model, like(shell): +``` + vgg16_deploy=VGG_ILSVRC_16_layers_deploy.prototxt + vgg16_caffemodel=VGG_ILSVRC_16_layers.caffemodel + model_prefix=VGG_ILSVRC_16_layers + cmd=../../tools/caffe_converter/convert_model.py + python $cmd $vgg16_deploy $vgg16_caffemodel VGG_ILSVRC_16_layers + mv VGG_ILSVRC_16_layers-0001.params VGG_ILSVRC_16_layers-0074.params +``` +* convert conv+fully-connect style to fully convolutional style, like(shell): +``` + python create_vgg16fc_model.py VGG_ILSVRC_16_layers 74 VGG_FC_ILSVRC_16_layers +``` + you can use vgg16fc model now, or you can download it directly from [yun.baidu](http://pan.baidu.com/s/1jGlOvno). 
+ + **`be careful: if you feed one (very) large image into the vgg16fc model, you should increase the 'workspace_default' value (related to your use case) in create_vgg16fc_model.py.`** +* or you can directly download the ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` from [yun.baidu](http://pan.baidu.com/s/1jGlOvno) + +#### step2: prepare your training data +in the example here, the training image list has the form: +```index \t image_data_path \t image_label_path``` +the label for one image in the image segmentation field is also an image, with the same shape as the input image. +* or you can directly download ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1jGlOvno), and extract it. the files/folders will be: +```JPEGImages folder```, ```SegmentationClass folder```, ```train.lst```, ```val.lst```, ```test.lst``` + +#### step3: begin training fcn-xs +if you want to train the fcn-8s model, it's better to train the fcn-32s and fcn-16s models first. +when training the fcn-32s model, run in shell ```./run_fcn32s.sh``` +* in the fcn_xs.py(e.g.
fcn_32s.py, fcn_16s.py, fcn_8s.py), you may need to change the directory ```img_dir```, ```train_lst```, ```val_lst```, ```fcnxs_model_prefix``` +* the output log may like this(when training fcn-8s): +```c++ +INFO:root:Start training with gpu(3) +INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318 +INFO:root:Epoch[0] Batch [100] Speed: 1.11 samples/sec Train-accuracy=0.904681 +INFO:root:Epoch[0] Batch [150] Speed: 1.13 samples/sec Train-accuracy=0.908053 +INFO:root:Epoch[0] Batch [200] Speed: 1.12 samples/sec Train-accuracy=0.912219 +INFO:root:Epoch[0] Batch [250] Speed: 1.13 samples/sec Train-accuracy=0.914238 +INFO:root:Epoch[0] Batch [300] Speed: 1.13 samples/sec Train-accuracy=0.912170 +INFO:root:Epoch[0] Batch [350] Speed: 1.12 samples/sec Train-accuracy=0.912080 +``` + +## TODO +* add the example of using pretrained model +* add the crop_offset function in symbol(both c++ and python side of mxnet) +* make the example more cleaner(the code is some dirty here) diff --git a/example/fcn-xs/create_vgg16fc_model.py b/example/fcn-xs/create_vgg16fc_model.py new file mode 100644 index 000000000000..a35a5913339f --- /dev/null +++ b/example/fcn-xs/create_vgg16fc_model.py @@ -0,0 +1,204 @@ +# pylint: skip-file +import sys, os +import argparse +import mxnet as mx +import numpy as np +import logging + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +workspace_default = 1024 + +## define vgg16 +def get_vgg16_symbol(): + data = mx.symbol.Variable(name="data") + # group 1 + conv1_1 = mx.symbol.Convolution( + data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1") + relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1") + conv1_2 = mx.symbol.Convolution( + data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_2") + relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2") + pool1 = mx.symbol.Pooling( + data=relu1_2, pool_type="max", kernel=(2, 2), stride=(2,2), 
name="pool1") + # group 2 + conv2_1 = mx.symbol.Convolution( + data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1") + relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1") + conv2_2 = mx.symbol.Convolution( + data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_2") + relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2") + pool2 = mx.symbol.Pooling( + data=relu2_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool2") + # group 3 + conv3_1 = mx.symbol.Convolution( + data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1") + relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1") + conv3_2 = mx.symbol.Convolution( + data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2") + relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2") + conv3_3 = mx.symbol.Convolution( + data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_3") + relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3") + pool3 = mx.symbol.Pooling( + data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool3") + # group 4 + conv4_1 = mx.symbol.Convolution( + data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1") + relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1") + conv4_2 = mx.symbol.Convolution( + data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2") + relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2") + conv4_3 = mx.symbol.Convolution( + data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_3") + relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3") + pool4 = mx.symbol.Pooling( + data=relu4_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool4") + # group 5 + conv5_1 = mx.symbol.Convolution( + data=pool4, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_1") 
+ relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1") + conv5_2 = mx.symbol.Convolution( + data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2") + relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="conv1_2") + conv5_3 = mx.symbol.Convolution( + data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_3") + relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3") + pool5 = mx.symbol.Pooling( + data=relu5_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool5") + # group 6 + flatten = mx.symbol.Flatten(data=pool5, name="flatten") + fc6 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096, name="fc6") + relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6") + drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6") + # group 7 + fc7 = mx.symbol.FullyConnected(data=drop6, num_hidden=4096, name="fc7") + relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7") + drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7") + # output + fc8 = mx.symbol.FullyConnected(data=drop7, num_hidden=1000, name="fc8") + softmax = mx.symbol.SoftmaxOutput(data=fc8, name="prob") + return softmax + +## define vgg16 +def get_vgg16fc_symbol(): + data = mx.symbol.Variable(name="data") + # group 1 + conv1_1 = mx.symbol.Convolution( + data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1", workspace=workspace_default) + relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1") + conv1_2 = mx.symbol.Convolution( + data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_2", workspace=workspace_default) + relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2") + pool1 = mx.symbol.Pooling( + data=relu1_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool1") + # group 2 + conv2_1 = mx.symbol.Convolution( + data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1", 
workspace=workspace_default) + relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1") + conv2_2 = mx.symbol.Convolution( + data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_2", workspace=workspace_default) + relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2") + pool2 = mx.symbol.Pooling( + data=relu2_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool2") + # group 3 + conv3_1 = mx.symbol.Convolution( + data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1", workspace=workspace_default) + relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1") + conv3_2 = mx.symbol.Convolution( + data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2", workspace=workspace_default) + relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2") + conv3_3 = mx.symbol.Convolution( + data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_3", workspace=workspace_default) + relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3") + pool3 = mx.symbol.Pooling( + data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool3") + # group 4 + conv4_1 = mx.symbol.Convolution( + data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1", workspace=workspace_default) + relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1") + conv4_2 = mx.symbol.Convolution( + data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2", workspace=workspace_default) + relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2") + conv4_3 = mx.symbol.Convolution( + data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_3", workspace=workspace_default) + relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3") + pool4 = mx.symbol.Pooling( + data=relu4_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool4") + # group 5 + conv5_1 
= mx.symbol.Convolution( + data=pool4, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_1", workspace=workspace_default) + relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1") + conv5_2 = mx.symbol.Convolution( + data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2", workspace=workspace_default) + relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="conv1_2") + conv5_3 = mx.symbol.Convolution( + data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_3", workspace=workspace_default) + relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3") + pool5 = mx.symbol.Pooling( + data=relu5_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool5") + # group 6 + fc6 = mx.symbol.Convolution( + data=pool5, kernel=(7, 7), num_filter=4096, name="fc6", workspace=workspace_default) + relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6") + drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6") + # group 7 + fc7 = mx.symbol.Convolution( + data=drop6, kernel=(1, 1), num_filter=4096, name="fc7", workspace=workspace_default) + relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7") + drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7") + # group 8 + fc8 = mx.symbol.Convolution( + data=drop7, kernel=(1, 1), num_filter=1000, name="fc8", workspace=workspace_default) + # output + softmax = mx.symbol.SoftmaxOutput(data=fc8, multi_output=True, name="prob") + return softmax + +def get_vgg16fc_arg_param(vgg16_arg_params): + vgg16fc_arg_param = vgg16_arg_params + vgg16fc_arg_param["fc6_weight"] = \ + mx.nd.array(vgg16_arg_params["fc6_weight"].asnumpy().reshape(4096,512,7,7)) + vgg16fc_arg_param["fc7_weight"] = \ + mx.nd.array(vgg16_arg_params["fc7_weight"].asnumpy().reshape(4096,4096,1,1)) + vgg16fc_arg_param["fc8_weight"] = \ + mx.nd.array(vgg16_arg_params["fc8_weight"].asnumpy().reshape(1000,4096,1,1)) + return vgg16fc_arg_param + +def main(): + parser 
= argparse.ArgumentParser(description='Convert vgg16 model to vgg16fc model.') + parser.add_argument('prefix', default='VGG_ILSVRC_16_layers', + help='The prefix(include path) of vgg16 model with mxnet format.') + parser.add_argument('epoch', type=int, default=74, + help='The epoch number of vgg16 model.') + parser.add_argument('prefix_fc', default='VGG_FC_ILSVRC_16_layers', + help='The prefix(include path) of vgg16fc model which your want to save.') + args = parser.parse_args() + + vgg16_symbol, vgg16_arg_params, vgg16_aux_params = \ + mx.model.load_checkpoint(args.prefix, args.epoch) + + ## when your get original vgg16 symbol, here is two way: + # way 1: use get_vgg16_symbol() to get symbol, and it's good in + # transfer learning. + # softmax = get_vgg16_symbol() + # vgg_model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, + # arg_params=vgg16_arg_params, aux_params=vgg16_aux_params) + + # way 2: use caffe converter get the symbol directly + # softmax = get_vgg16_symbol() + # vgg_model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, + # arg_params=vgg16_symbol, aux_params=vgg16_aux_params) + + # vgg16fc_mxnet + softmax = get_vgg16fc_symbol() + vgg16fc_arg_params = get_vgg16fc_arg_param(vgg16_arg_params) + vgg_model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, + arg_params=vgg16fc_arg_params, aux_params=vgg16_aux_params) + + # vgg_model.save(prefix=args.prefix_fc, epoch=1) + vgg_model.save(prefix=args.prefix_fc, epoch=args.epoch) + return "" + +if __name__ == "__main__": + main() diff --git a/example/fcn-xs/data.py b/example/fcn-xs/data.py new file mode 100644 index 000000000000..9115d32a45a5 --- /dev/null +++ b/example/fcn-xs/data.py @@ -0,0 +1,108 @@ +# pylint: skip-file +""" data iterator for pasval voc 2012""" +import mxnet as mx +import numpy as np +import sys +import os +from mxnet.io import DataIter +from skimage import io +from PIL import Image + +class FileIter(DataIter): + """FileIter object in mxnet. 
Taking NDArray or numpy array to get dataiter. + Parameters + ---------- + data_list or data, label: a list of, or two separate NDArray or numpy.ndarray + list of NDArray for data. The last one is treated as label. + batch_size: int + Batch Size + shuffle: bool + Whether to shuffle the data + data_pad_value: float, optional + Padding value for data + label_pad_value: float, optionl + Padding value for label + last_batch_handle: 'pad', 'discard' or 'roll_over' + How to handle the last batch + Note + ---- + This iterator will pad, discard or roll over the last batch if + the size of data does not match batch_size. Roll over is intended + for training and can cause problems if used for prediction. + """ + def __init__(self, root_dir, flist_name, data_name="data", label_name="softmax_label"): + super(FileIter, self).__init__() + self.root_dir = root_dir + self.data_name = data_name + self.label_name = label_name + self.flist_name = os.path.join(self.root_dir, flist_name) + self.num_data = len(open(self.flist_name, 'r').readlines()) + self.img_path = "img_path" + self.f = open(self.flist_name, 'r') + self.mean = np.array([123.68, 116.779, 103.939]) # (R, G, B) + self.data, self.label, self.img_name = self._read(self.f) + self.cursor = -1 + + def _read(self, f): + _, data_img_name, label_img_name = f.readline().strip('\n').split("\t") + data = {} + label = {} + data[self.data_name] = self._read_img(data_img_name) + label[self.label_name] = self._read_img(label_img_name, True) + return list(data.items()), list(label.items()), data_img_name + + def _read_img(self, img_name, is_label_img=False): + if not is_label_img: + img = Image.open(os.path.join(self.root_dir, img_name)) + img = np.array(img, dtype=np.float32) # (h, w, c) + reshaped_mean = self.mean.reshape(1, 1, 3) + img = img - reshaped_mean + img = np.swapaxes(img, 0, 2) + img = np.swapaxes(img, 1, 2) # (c, h, w) + else: + img = Image.open(os.path.join(self.root_dir, img_name)) + img = np.array(img) # (h, w) + # 
img[img==255] = 0 # change the value of 255 to 0 + img = np.expand_dims(img, axis=0) # (1, c, h, w) or (1, h, w) + return img + + @property + def provide_data(self): + """The name and shape of data provided by this iterator""" + return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.data] + + @property + def provide_label(self): + """The name and shape of label provided by this iterator""" + return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label] + + @property + def batch_size(self): + return 1 + + def reset(self): + self.cursor = -1 + self.f.close() + self.f = open(self.flist_name, 'r') + + def iter_next(self): + self.cursor += 1 + if(self.cursor < self.num_data-1): + return True + else: + return False + + def next(self): + if self.iter_next(): + self.data, self.label, self.img_name = self._read(self.f) + return {self.data_name:self.getdata(), + self.label_name:self.getlabel(), + self.img_path:self.img_name} + else: + raise StopIteration + + def getdata(self): + return self.data[0][1] + + def getlabel(self): + return self.label[0][1] diff --git a/example/fcn-xs/fcn-xs_pascal.jpg b/example/fcn-xs/fcn-xs_pascal.jpg new file mode 100644 index 000000000000..70b9763db645 Binary files /dev/null and b/example/fcn-xs/fcn-xs_pascal.jpg differ diff --git a/example/fcn-xs/fcn_16s.py b/example/fcn-xs/fcn_16s.py new file mode 100644 index 000000000000..276414d77804 --- /dev/null +++ b/example/fcn-xs/fcn_16s.py @@ -0,0 +1,66 @@ +# pylint: skip-file +import sys, os +import argparse +import mxnet as mx +import numpy as np +import logging +import symbol_fcnxs +import init_fcnxs +from data import FileIter +from solver import Solver + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +np.set_printoptions(threshold=np.nan) + +img_dir = "./VOC2012" +train_lst = "train.lst" +val_lst = "val.lst" +fcn16s_model_prefix = "model_pascal/FCN16s_VGG16" +batch_size = 1 +workspace = 1536 +ctx = mx.gpu(0) +# ctx = mx.cpu() + +def norm_stat(d): + return 
mx.nd.norm(d)/np.sqrt(d.size) + +def main(): + fcn16s = symbol_fcnxs.get_fcn16s_symbol(21, workspace) + arg_names = fcn16s.list_arguments() + print "arg_names=", arg_names + arg_shapes, out_shapes, _ = fcn16s.infer_shape(data=(1,3,336,500)) + print "out_shapes[0]=", out_shapes[0] + arg_shapes_dict = dict(zip(arg_names, arg_shapes)) + _, vgg16fc_arg_params, vgg16fc_aux_params = mx.model.load_checkpoint(args.prefix, args.epoch) + # fcn16s_arg_params, fcn16s_aux_params = init_fcnxs.init_fcn16s_params(ctx, fcn16s, vgg16fc_arg_params, vgg16fc_aux_params) + fcn16s_arg_params, fcn16s_aux_params = init_fcnxs.init_fcn16s_params_from_fcn32s(ctx, fcn16s, vgg16fc_arg_params, vgg16fc_aux_params) + train_dataiter = FileIter(img_dir, train_lst) + val_dataiter = FileIter(img_dir, val_lst) + mon = mx.mon.Monitor(1, norm_stat) + model = Solver( + ctx = ctx, + symbol = fcn16s, + begin_epoch = 0, + num_epoch = 100, + arg_params = fcn16s_arg_params, + aux_params = fcn16s_aux_params, + learning_rate = 1e-12, + momentum = 0.99, + wd = 0.0005, + snapshot = 1, + monitor = None) + model.fit( + train_data = train_dataiter, + eval_data = val_dataiter, + batch_end_callback = mx.callback.Speedometer(batch_size, 50), + epoch_end_callback = mx.callback.do_checkpoint(fcn16s_model_prefix)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='training pascal voc segmentation using fcn-16s.') + parser.add_argument('prefix', default='FCN16s_VGG16', + help='The prefix(include path) of vgg16 model with mxnet format.') + parser.add_argument('epoch', type=int, default=199, + help='The epoch number of fcn16s model.') + args = parser.parse_args() + main() diff --git a/example/fcn-xs/fcn_32s.py b/example/fcn-xs/fcn_32s.py new file mode 100644 index 000000000000..5b467888c059 --- /dev/null +++ b/example/fcn-xs/fcn_32s.py @@ -0,0 +1,64 @@ +import sys, os +import argparse +import mxnet as mx +import numpy as np +import logging +import symbol_fcnxs +import init_fcnxs +from data import 
FileIter +from solver import Solver + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +np.set_printoptions(threshold=np.nan) + +img_dir = "./VOC2012" +train_lst = "train.lst" +val_lst = "val.lst" +fcn32s_model_prefix = "model_pascal/FCN32s_VGG16" +batch_size = 1 +workspace = 1536 +ctx = mx.gpu(0) +# ctx = mx.cpu() + +def norm_stat(d): + return mx.nd.norm(d)/np.sqrt(d.size) + +def main(): + fcn32s = symbol_fcnxs.get_fcn32s_symbol(21, workspace) + arg_names = fcn32s.list_arguments() + arg_shapes, out_shapes, _ = fcn32s.infer_shape(data=(1,3,336,500)) + print "out_shapes[0]=", out_shapes[0] + arg_shapes_dict = dict(zip(arg_names, arg_shapes)) + _, vgg16fc_arg_params, vgg16fc_aux_params = \ + mx.model.load_checkpoint(args.prefix, args.epoch) + fcn32s_arg_params, fcn32s_aux_params = \ + init_fcnxs.init_fcn32s_params(ctx, fcn32s, vgg16fc_arg_params, vgg16fc_aux_params) + train_dataiter = FileIter(img_dir, train_lst) + val_dataiter = FileIter(img_dir, val_lst) + mon = mx.mon.Monitor(10, norm_stat) + model = Solver( + ctx = ctx, + symbol = fcn32s, + begin_epoch = 0, + num_epoch = 200, + arg_params = fcn32s_arg_params, + aux_params = fcn32s_aux_params, + learning_rate = 1e-10, + momentum = 0.99, + wd = 0.0005, + monitor = None) + model.fit( + train_data = train_dataiter, + eval_data = val_dataiter, + batch_end_callback = mx.callback.Speedometer(batch_size, 10), + epoch_end_callback = mx.callback.do_checkpoint(fcn32s_model_prefix)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Convert vgg16 model to vgg16fc model.') + parser.add_argument('prefix', default='VGG_FC_ILSVRC_16_layers', + help='The prefix(include path) of vgg16 model with mxnet format.') + parser.add_argument('epoch', type=int, default=74, + help='The epoch number of vgg16 model.') + args = parser.parse_args() + main() diff --git a/example/fcn-xs/fcn_8s.py b/example/fcn-xs/fcn_8s.py new file mode 100644 index 000000000000..fb1d5861b73f --- /dev/null +++ 
b/example/fcn-xs/fcn_8s.py @@ -0,0 +1,64 @@ +import sys, os +import argparse +import mxnet as mx +import numpy as np +import logging +import symbol_fcnxs +import init_fcnxs +from data import FileIter +from solver import Solver + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +np.set_printoptions(threshold=np.nan) + +img_dir = "./VOC2012" +train_lst = "train.lst" +val_lst = "val.lst" +fcn8s_model_prefix = "model_pascal/FCN8s_VGG16" +batch_size = 1 +workspace = 1536 +ctx = mx.gpu(0) +# ctx = mx.cpu() + +def norm_stat(d): + return mx.nd.norm(d)/np.sqrt(d.size) + +def main(): + fcn8s = symbol_fcnxs.get_fcn8s_symbol(21, workspace) + arg_names = fcn8s.list_arguments() + print "arg_names=", arg_names + arg_shapes, out_shapes, _ = fcn8s.infer_shape(data=(1,3,336,500)) + print "out_shapes[0]=", out_shapes[0] + arg_shapes_dict = dict(zip(arg_names, arg_shapes)) + _, vgg16fc_arg_params, vgg16fc_aux_params = mx.model.load_checkpoint(args.prefix, args.epoch) + fcn8s_arg_params, fcn8s_aux_params = init_fcnxs.init_fcn8s_params_from_fcn16s(ctx, fcn8s, vgg16fc_arg_params, vgg16fc_aux_params) + train_dataiter = FileIter(img_dir, train_lst) + val_dataiter = FileIter(img_dir, val_lst) + mon = mx.mon.Monitor(10, norm_stat) + model = Solver( + ctx = ctx, + symbol = fcn8s, + begin_epoch = 0, + num_epoch = 100, + arg_params = fcn8s_arg_params, + aux_params = fcn8s_aux_params, + learning_rate = 1e-14, + momentum = 0.99, + wd = 0.0005, + snapshot = 1, + monitor = None) + model.fit( + train_data = train_dataiter, + eval_data = val_dataiter, + batch_end_callback = mx.callback.Speedometer(batch_size, 50), + epoch_end_callback = mx.callback.do_checkpoint(fcn8s_model_prefix)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='training pascal voc segmentation using fcn-8s.') + parser.add_argument('prefix', default='FCN32s_VGG16', + help='The prefix(include path) of vgg16 model with mxnet format.') + parser.add_argument('epoch', type=int, default=199, + 
help='The epoch number of fcn32s model.') + args = parser.parse_args() + main() diff --git a/example/fcn-xs/init_fcnxs.py b/example/fcn-xs/init_fcnxs.py new file mode 100644 index 000000000000..36f6e48b8cef --- /dev/null +++ b/example/fcn-xs/init_fcnxs.py @@ -0,0 +1,165 @@ +# pylint: skip-file +import mxnet as mx +import numpy as np + +# make a bilinear interpolation kernel, return a numpy.ndarray +def upsample_filt(size): + factor = (size + 1) // 2 + if size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = np.ogrid[:size, :size] + return (1 - abs(og[0] - center) / factor) * \ + (1 - abs(og[1] - center) / factor) + +def init_fcn32s_params(ctx, fcn32s_symbol, vgg16fc_arg_params, vgg16fc_aux_params, retrain): + fcn32s_arg_params = vgg16fc_arg_params.copy() + fcn32s_aux_params = vgg16fc_aux_params.copy() + if not retrain: + for k,v in fcn32s_arg_params.items(): + if(v.context != ctx): + fcn32s_arg_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn32s_arg_params[k]) + for k,v in fcn32s_aux_params.items(): + if(v.context != ctx): + fcn32s_aux_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn32s_aux_params[k]) + data_shape=(1,3,500,500) + arg_names = fcn32s_symbol.list_arguments() + arg_shapes, _, _ = fcn32s_symbol.infer_shape(data=data_shape) + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['fc8_weight', 'fc8_bias']]) + fcn32s_arg_params.update(rest_params) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ["deconv8_weight"]]) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcn32s_arg_params[k] = mx.nd.array(initw, ctx) + else: + print "it is retrain, so will use the model weight trained before." 
+ return fcn32s_arg_params, fcn32s_aux_params + +def init_fcn16s_params(ctx, fcn16s_symbol, vgg16fc_arg_params, vgg16fc_aux_params): + fcn16s_arg_params = vgg16fc_arg_params.copy() + fcn16s_aux_params = vgg16fc_aux_params.copy() + del fcn16s_arg_params["fc8_weight"] + del fcn16s_arg_params["fc8_bias"] + for k,v in fcn16s_arg_params.items(): + if(v.context != ctx): + fcn16s_arg_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn8s_arg_params[k]) + for k,v in fcn16s_aux_params.items(): + if(v.context != ctx): + fcn16s_aux_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn8s_aux_params[k]) + data_shape=(1,3,500,500) + arg_names = fcn16s_symbol.list_arguments() + arg_shapes, _, _ = fcn16s_symbol.infer_shape(data=data_shape) + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_weight', 'score_bias', + 'score_pool4_weight', 'score_pool4_bias']]) + fcn16s_arg_params.update(rest_params) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ['score2_weight', 'bigscore_weight']]) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcn16s_arg_params[k] = mx.nd.array(initw, ctx) + + # print "fcn16s_arg_params[conv1_1_weight]=", fcn16s_arg_params["conv1_1_weight"].asnumpy() + # print "fcn16s_arg_params[conv1_2_weight]=", fcn16s_arg_params["conv1_2_weight"].asnumpy() + return fcn16s_arg_params, fcn16s_aux_params + +def init_fcn16s_params_from_fcn32s(ctx, fcn16s_symbol, fcn32s_arg_params, fcn32s_aux_params): + fcn16s_arg_params = fcn32s_arg_params.copy() + fcn16s_aux_params = fcn32s_aux_params.copy() + fcn16s_arg_params["score_weight"] = fcn16s_arg_params["fc8_weight"] + fcn16s_arg_params["score_bias"] = fcn16s_arg_params["fc8_bias"] + del fcn16s_arg_params["deconv8_weight"] + for k,v in fcn16s_arg_params.items(): + if(v.context != ctx): + fcn16s_arg_params[k] = 
mx.nd.zeros(v.shape, ctx) + v.copyto(fcn16s_arg_params[k]) + for k,v in fcn16s_aux_params.items(): + if(v.context != ctx): + fcn16s_aux_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn16s_aux_params[k]) + data_shape=(1,3,500,500) + arg_names = fcn16s_symbol.list_arguments() + arg_shapes, _, _ = fcn16s_symbol.infer_shape(data=data_shape) + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_pool4_weight', 'score_pool4_bias']]) + fcn16s_arg_params.update(rest_params) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ['score2_weight', 'bigscore_weight']]) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcn16s_arg_params[k] = mx.nd.array(initw, ctx) + + return fcn16s_arg_params, fcn16s_aux_params + +def init_fcn8s_params(ctx, fcn8s_symbol, vgg16fc_arg_params, vgg16fc_aux_params): + fcn8s_arg_params = vgg16fc_arg_params.copy() + fcn8s_aux_params = vgg16fc_aux_params.copy() + del fcn8s_arg_params["fc8_weight"] + del fcn8s_arg_params["fc8_bias"] + for k,v in fcn8s_arg_params.items(): + if(v.context != ctx): + fcn8s_arg_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn8s_arg_params[k]) + for k,v in fcn8s_aux_params.items(): + if(v.context != ctx): + fcn8s_aux_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn8s_aux_params[k]) + data_shape=(1,3,500,500) + arg_names = fcn8s_symbol.list_arguments() + arg_shapes, _, _ = fcn8s_symbol.infer_shape(data=data_shape) + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_weight', 'score_bias', + 'score_pool4_weight', 'score_pool4_bias', \ + 'score_pool3_bias', 'score_pool3_weight']]) + fcn8s_arg_params.update(rest_params) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ['score2_weight', 'score4_weight', 
'bigscore_weight']]) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcn8s_arg_params[k] = mx.nd.array(initw, ctx) + + # print "fcn8s_arg_params[conv1_1_weight]=", fcn8s_arg_params["conv1_1_weight"].asnumpy() + # print "fcn8s_arg_params[conv1_2_weight]=", fcn8s_arg_params["conv1_2_weight"].asnumpy() + return fcn8s_arg_params, fcn8s_aux_params + +def init_fcn8s_params_from_fcn16s(ctx, fcn8s_symbol, fcn16s_arg_params, fcn16s_aux_params): + fcn8s_arg_params = fcn16s_arg_params.copy() + fcn8s_aux_params = fcn16s_aux_params.copy() + for k,v in fcn8s_arg_params.items(): + if(v.context != ctx): + fcn8s_arg_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn8s_arg_params[k]) + for k,v in fcn8s_aux_params.items(): + if(v.context != ctx): + fcn8s_aux_params[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcn8s_aux_params[k]) + data_shape=(1,3,500,500) + arg_names = fcn8s_symbol.list_arguments() + arg_shapes, _, _ = fcn8s_symbol.infer_shape(data=data_shape) + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_pool3_bias', 'score_pool3_weight']]) + fcn8s_arg_params.update(rest_params) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ['score4_weight', 'bigscore_weight']]) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcn8s_arg_params[k] = mx.nd.array(initw, ctx) + + return fcn8s_arg_params, fcn8s_aux_params diff --git a/example/fcn-xs/run_fcn16s.sh b/example/fcn-xs/run_fcn16s.sh new file mode 100755 index 000000000000..f5ce46cb0cae --- /dev/null +++ b/example/fcn-xs/run_fcn16s.sh @@ -0,0 +1,2 @@ +python -u fcn_16s.py FCN32s_VGG16 31 +#python -u fcn_16s.py VGG_FC_ILSVRC_16_layers 74 diff --git a/example/fcn-xs/run_fcn32s.sh 
b/example/fcn-xs/run_fcn32s.sh new file mode 100755 index 000000000000..744e1fc0f60e --- /dev/null +++ b/example/fcn-xs/run_fcn32s.sh @@ -0,0 +1 @@ +python -u fcn_32s.py VGG_FC_ILSVRC_16_layers 74 diff --git a/example/fcn-xs/run_fcn8s.sh b/example/fcn-xs/run_fcn8s.sh new file mode 100755 index 000000000000..480a1325b66e --- /dev/null +++ b/example/fcn-xs/run_fcn8s.sh @@ -0,0 +1,2 @@ +python -u fcn_8s.py FCN16s_VGG16 27 +# python -u fcn_8s.py VGG_FC_ILSVRC_16_layers 74 diff --git a/example/fcn-xs/solver.py b/example/fcn-xs/solver.py new file mode 100644 index 000000000000..7ca404323190 --- /dev/null +++ b/example/fcn-xs/solver.py @@ -0,0 +1,148 @@ +import numpy as np +import mxnet as mx +import time +import logging +from collections import namedtuple +from mxnet import optimizer as opt +from mxnet.optimizer import get_updater +from mxnet import metric + +# Parameter to pass to batch_end_callback +BatchEndParam = namedtuple('BatchEndParams', + ['epoch', + 'nbatch', + 'eval_metric']) +def norm_stat(d): + return mx.nd.norm(d)/np.sqrt(d.size) + +class Solver(object): + def __init__(self, symbol, ctx=None, + num_epoch=None, epoch_size=None, + optimizer='sgd', + arg_params=None, aux_params=None, + begin_epoch=0, snapshot=1, + monitor=None, + **kwargs): + self.symbol = symbol + if ctx is None: + ctx = mx.cpu() + self.ctx = ctx + self.begin_epoch = begin_epoch + self.num_epoch = num_epoch + self.epoch_size = epoch_size + self.kwargs = kwargs.copy() + self.optimizer = optimizer + self.arg_params = arg_params + self.aux_params = aux_params + self.begin_epoch = begin_epoch + self.snapshot = snapshot + self.monitor = monitor + + def fit(self, + train_data, + eval_data=None, eval_metric='acc', + grad_req='write', + epoch_end_callback=None, + batch_end_callback=None, + kvstore='local', + logger=None): + if logger is None: + logger = logging + logging.info('Start training with %s', str(self.ctx)) + mx.model._check_arguments(self.symbol) + arg_shapes, out_shapes, aux_shapes = 
self.symbol.infer_shape( + data=train_data.provide_data[0][1]) + arg_names = self.symbol.list_arguments() + if grad_req != 'null': + self.grad_params = {} + for name, shape in zip(arg_names, arg_shapes): + if not (name.endswith('data') or name.endswith('label')): + self.grad_params[name] = mx.nd.zeros(shape, self.ctx) + else: + self.grad_params = None + aux_names = self.symbol.list_auxiliary_states() + self.aux_params = {k : nd.zeros(s) for k, s in zip(aux_names, aux_shapes)} + # names + data_name = train_data.data_name + label_name = train_data.label_name + img_path_name = train_data.img_path + input_names = [data_name, label_name] + self.optimizer = opt.create(self.optimizer, + rescale_grad=(1.0/train_data.batch_size), + **(self.kwargs)) + self.updater = get_updater(self.optimizer) + # evaluation + eval_metric = metric.create(eval_metric) + + for epoch in range(self.begin_epoch, self.num_epoch): + nbatch = 0 + train_data.reset() + eval_metric.reset() + for data in train_data: + nbatch += 1 + label_shape = data[label_name].shape + self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx) + self.arg_params[label_name] = mx.nd.array(data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2]), self.ctx) + output_names = self.symbol.list_outputs() + self.exector = self.symbol.bind(self.ctx, self.arg_params, + args_grad=self.grad_params, + grad_req=grad_req, + aux_states=self.aux_params) + if self.monitor is not None: + self.monitor.install(self.exector) + self.monitor.tic() + assert len(self.symbol.list_arguments()) == len(self.exector.grad_arrays) + update_dict = {name: nd for name, nd in zip(self.symbol.list_arguments(), self.exector.grad_arrays) if nd} + output_dict = {} + output_buff = {} + for key, arr in zip(self.symbol.list_outputs(), self.exector.outputs): + output_dict[key] = arr + output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu()) + self.exector.forward(is_train=True) + for key in output_dict: + 
output_dict[key].copyto(output_buff[key]) + self.exector.backward() + for key, arr in update_dict.items(): + if key != "deconv8_weight" and key != "bigscore_weight": + self.updater(key, arr, self.arg_params[key]) + pred_shape = self.exector.outputs[0].shape + label = mx.nd.array(data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2])) + pred = mx.nd.array(output_buff["softmax_output"].asnumpy().reshape(pred_shape[0], pred_shape[1], pred_shape[2]*pred_shape[3])) + eval_metric.update([label], [pred]) + self.exector.outputs[0].wait_to_read() + if self.monitor is not None: + self.monitor.toc_print() + batch_end_params = BatchEndParam(epoch=epoch, + nbatch=nbatch, + eval_metric=eval_metric) + batch_end_callback(batch_end_params) + if epoch_end_callback != None and epoch % self.snapshot == 0: + epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params) + name, value = eval_metric.get() + logger.info(" --->Epoch[%d] Train-%s=%f", + epoch, name, value) + # evaluation + if eval_data: + print "in eval process..." 
+ nbatch = 0 + eval_data.reset() + eval_metric.reset() + for data in eval_data: + nbatch += 1 + label_shape = data[label_name].shape + self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx) + self.arg_params[label_name] = mx.nd.array(data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2]), self.ctx) + exector = self.symbol.bind(self.ctx, self.arg_params, + args_grad=self.grad_params, + grad_req=grad_req, + aux_states=self.aux_params) + cpu_output_array = mx.nd.zeros(exector.outputs[0].shape) + exector.forward(is_train=False) + exector.outputs[0].copyto(cpu_output_array) + pred_shape = cpu_output_array.shape + label = mx.nd.array(data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2])) + pred = mx.nd.array(cpu_output_array.asnumpy().reshape(pred_shape[0], pred_shape[1], pred_shape[2]*pred_shape[3])) + eval_metric.update([label], [pred]) + exector.outputs[0].wait_to_read() + name, value = eval_metric.get() + logger.info('batch[%d] Validation-%s=%f', nbatch, name, value) diff --git a/example/fcn-xs/symbol_fcnxs.py b/example/fcn-xs/symbol_fcnxs.py new file mode 100644 index 000000000000..4208cdfaa84e --- /dev/null +++ b/example/fcn-xs/symbol_fcnxs.py @@ -0,0 +1,478 @@ +# pylint: skip-file +import mxnet as mx + +def filter_map(kernel=1, stride=1, pad=0): + # why not return (stride, (kernel-stride)/2-pad)?? 
def filter_map(kernel=1, stride=1, pad=0):
    """Coordinate map of a conv/pool layer as a (scale, offset) pair.

    Output pixel y of the layer corresponds to input pixel
    stride * y + (kernel - 1) // 2 - pad.
    """
    # `//` keeps the integer-floor semantics the original relied on under
    # Python 2's int `/` (e.g. kernel=2 must give offset 0, not 0.5)
    return (stride, (kernel - 1) // 2 - pad)


def compose_fp(fp_first, fp_second):
    """Compose two coordinate maps: the result applies fp_second's mapping
    first, i.e. x = fp_first(fp_second(y))."""
    return (fp_first[0] * fp_second[0],
            fp_first[0] * fp_second[1] + fp_first[1])


def compose_fp_list(fp_list):
    """Fold an ordered stack of layer maps into one (scale, offset) map."""
    fp_out = (1.0, 0.0)
    for fp in fp_list:
        fp_out = compose_fp(fp_out, fp)
    return fp_out


def inv_fp(fp_in):
    """Inverse of a coordinate map — used for deconvolutions and for walking
    a path backwards through the network."""
    return (1.0 / fp_in[0], -1.0 * fp_in[1] / fp_in[0])


def _crop_offset(fp_list):
    """(y, x) crop offset implied by a composed map: the negated, rounded
    offset term, duplicated for both spatial axes."""
    off = -int(round(compose_fp_list(fp_list)[1]))
    return (off, off)


def _vgg16_trunk_fps():
    """Coordinate maps of the VGG16 trunk, conv1_1 through pool5, in order."""
    conv1_1_fp = filter_map(kernel=3, pad=100)  # the big pad FCN later crops away
    conv_fp = filter_map(kernel=3, pad=1)       # every other 'same'-padded 3x3 conv
    pool_fp = filter_map(kernel=2, stride=2)
    return [conv1_1_fp, conv_fp, pool_fp,            # group 1
            conv_fp, conv_fp, pool_fp,               # group 2
            conv_fp, conv_fp, conv_fp, pool_fp,      # group 3
            conv_fp, conv_fp, conv_fp, pool_fp,      # group 4
            conv_fp, conv_fp, conv_fp, pool_fp]      # group 5


def coord_map_fcn32s():
    """Crop offset for FCN-32s, aligning deconv8 with the input image.

    Returns {"crop8": (y_off, x_off)}.
    """
    fc6_fp = filter_map(kernel=7)
    one_fp = filter_map()  # the 1x1 conv layers fc7 and fc8
    deconv8_fp = inv_fp(filter_map(kernel=64, stride=32))  # 32x upsampling
    fp_list = _vgg16_trunk_fps() + [fc6_fp, one_fp, one_fp, deconv8_fp]
    crop = {}
    crop["crop8"] = _crop_offset(fp_list)
    return crop


def coord_map_fcn16s():
    """Crop offsets for FCN-16s.

    "score_pool4c" aligns the pool4 skip with score2; "upscore" aligns
    bigscore with the input image.
    """
    trunk = _vgg16_trunk_fps()
    conv_fp = filter_map(kernel=3, pad=1)
    pool_fp = filter_map(kernel=2, stride=2)
    fc6_fp = filter_map(kernel=7)
    one_fp = filter_map()  # the 1x1 conv layers: fc7, score, score_pool4
    score2_fp = inv_fp(filter_map(kernel=4, stride=2))      # 2x upsample
    bigscore_fp = inv_fp(filter_map(kernel=32, stride=16))  # 16x upsample
    crop = {}
    # walk backwards from score2 up to pool4, then forward through score_pool4
    score_pool4c_fp_list = [inv_fp(score2_fp), inv_fp(one_fp),
                            inv_fp(one_fp), inv_fp(fc6_fp),
                            inv_fp(pool_fp), inv_fp(conv_fp),
                            inv_fp(conv_fp), inv_fp(conv_fp),
                            one_fp]
    crop["score_pool4c"] = _crop_offset(score_pool4c_fp_list)
    # trunk[:14] is conv1_1 .. pool4
    upscore_fp_list = trunk[:14] + [one_fp,
                                    inv_fp((1, -crop["score_pool4c"][0])),
                                    bigscore_fp]
    crop["upscore"] = _crop_offset(upscore_fp_list)
    return crop


def coord_map_fcn8s():
    """Crop offsets for FCN-8s: "score_pool4c", "score_pool3c" and "upscore"."""
    trunk = _vgg16_trunk_fps()
    conv_fp = filter_map(kernel=3, pad=1)
    pool_fp = filter_map(kernel=2, stride=2)
    fc6_fp = filter_map(kernel=7)
    one_fp = filter_map()  # all 1x1 conv layers
    score2_fp = inv_fp(filter_map(kernel=4, stride=2))    # 2x upsample
    score4_fp = inv_fp(filter_map(kernel=4, stride=2))    # 2x upsample
    bigscore_fp = inv_fp(filter_map(kernel=16, stride=8))  # 8x upsample
    crop = {}
    score_pool4c_fp_list = [inv_fp(score2_fp), inv_fp(one_fp),
                            inv_fp(one_fp), inv_fp(fc6_fp),
                            inv_fp(pool_fp), inv_fp(conv_fp),
                            inv_fp(conv_fp), inv_fp(conv_fp),
                            one_fp]
    crop["score_pool4c"] = _crop_offset(score_pool4c_fp_list)
    # walk back from score4 to pool3, then forward through score_pool3.
    # NOTE: the original listed score_pool3_fp twice — it is the identity map
    # (1, 0), so the duplicate is dropped here without changing any offset.
    score_pool3c_fp_list = [inv_fp(score4_fp), (1, -crop["score_pool4c"][0]),
                            inv_fp(one_fp), inv_fp(pool_fp),
                            inv_fp(conv_fp), inv_fp(conv_fp),
                            inv_fp(conv_fp), one_fp]
    crop["score_pool3c"] = _crop_offset(score_pool3c_fp_list)
    # trunk[:10] is conv1_1 .. pool3
    upscore_fp_list = trunk[:10] + [one_fp,
                                    inv_fp((1, -crop["score_pool3c"][0])),
                                    bigscore_fp]
    crop["upscore"] = _crop_offset(upscore_fp_list)
    return crop
def _vgg16_fc_trunk(data, workspace_default):
    """Fully-convolutional VGG16 trunk shared by the FCN-xs heads.

    Layer names match the original VGG16 definition so pretrained parameters
    load unchanged.  Returns (pool3, pool4, drop7) — the taps consumed by the
    FCN-8s/16s/32s heads.
    """
    net = data
    pools = {}
    for group, (n_convs, n_filter) in enumerate(
            [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)], 1):
        for i in range(1, n_convs + 1):
            # conv1_1 pads by 100 so the final Crop layers never run out of
            # border; every other 3x3 conv is 'same'-padded
            pad = (100, 100) if (group, i) == (1, 1) else (1, 1)
            net = mx.symbol.Convolution(
                data=net, kernel=(3, 3), pad=pad, num_filter=n_filter,
                name="conv%d_%d" % (group, i), workspace=workspace_default)
            # BUGFIX: the original named relu5_2 "conv1_2", colliding with the
            # conv layer of that name; activations carry no parameters, so the
            # corrected name does not affect checkpoint loading
            net = mx.symbol.Activation(
                data=net, act_type="relu", name="relu%d_%d" % (group, i))
        net = mx.symbol.Pooling(
            data=net, pool_type="max", kernel=(2, 2), stride=(2, 2),
            name="pool%d" % group)
        pools[group] = net
    # fc6/fc7 re-cast as convolutions — the "fully convolutional" part
    fc6 = mx.symbol.Convolution(
        data=pools[5], kernel=(7, 7), num_filter=4096, name="fc6",
        workspace=workspace_default)
    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
    fc7 = mx.symbol.Convolution(
        data=drop6, kernel=(1, 1), num_filter=4096, name="fc7",
        workspace=workspace_default)
    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
    return pools[3], pools[4], drop7


def get_fcn32s_symbol(numclass=21, workspace_default=1024):
    """FCN-32s: one 32x deconvolution straight off the VGG16 trunk.

    numclass: number of segmentation classes (previously hard-coded to 21;
    the default preserves the old behavior).
    workspace_default: convolution workspace (MB) for every layer.
    """
    data = mx.symbol.Variable(name="data")
    _, _, drop7 = _vgg16_fc_trunk(data, workspace_default)
    fc8 = mx.symbol.Convolution(
        data=drop7, kernel=(1, 1), num_filter=numclass, name="fc8",
        workspace=workspace_default)
    deconv8 = mx.symbol.Deconvolution(
        data=fc8, kernel=(64, 64), stride=(32, 32), num_filter=numclass,
        name="deconv8", workspace=workspace_default)  # 32x upsampling
    crop8 = mx.symbol.Crop(
        data=deconv8, crop_like=data,
        offset=coord_map_fcn32s()["crop8"], name="crop8")
    # per-pixel softmax; 255 marks VOC "don't care" border pixels
    return mx.symbol.SoftmaxOutput(
        data=crop8, multi_output=True, ignore_label=255, name="softmax")


def get_fcn16s_symbol(numclass=21, workspace_default=1024):
    """FCN-16s: fuses a pool4 skip connection before a 16x deconvolution.

    numclass / workspace_default: as in get_fcn32s_symbol.
    """
    data = mx.symbol.Variable(name="data")
    _, pool4, drop7 = _vgg16_fc_trunk(data, workspace_default)
    crop_offsets = coord_map_fcn16s()  # computed once, used for both crops
    score = mx.symbol.Convolution(
        data=drop7, kernel=(1, 1), num_filter=numclass, name="score",
        workspace=workspace_default)
    score2 = mx.symbol.Deconvolution(
        data=score, kernel=(4, 4), stride=(2, 2), num_filter=numclass,
        name="score2", workspace=workspace_default)  # 2x
    score_pool4 = mx.symbol.Convolution(
        data=pool4, kernel=(1, 1), num_filter=numclass, name="score_pool4",
        workspace=workspace_default)
    score_pool4c = mx.symbol.Crop(
        data=score_pool4, crop_like=score2,
        offset=crop_offsets["score_pool4c"], name="score_pool4c")
    score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c],
                                           name='score_fused')
    bigscore = mx.symbol.Deconvolution(
        data=score_fused, kernel=(32, 32), stride=(16, 16),
        num_filter=numclass, name="bigscore",
        workspace=workspace_default)  # 16x
    upscore = mx.symbol.Crop(
        data=bigscore, crop_like=data,
        offset=crop_offsets["upscore"], name="upscore")
    return mx.symbol.SoftmaxOutput(
        data=upscore, multi_output=True, ignore_label=255, name="softmax")
def _vgg16_fc_trunk(data, workspace_default):
    """Fully-convolutional VGG16 trunk shared by the FCN-xs heads.

    Layer names match the original VGG16 definition so pretrained parameters
    load unchanged.  Returns (pool3, pool4, drop7) — the taps consumed by the
    FCN-8s/16s/32s heads.
    """
    net = data
    pools = {}
    for group, (n_convs, n_filter) in enumerate(
            [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)], 1):
        for i in range(1, n_convs + 1):
            # conv1_1 pads by 100 so the final Crop layers never run out of
            # border; every other 3x3 conv is 'same'-padded
            pad = (100, 100) if (group, i) == (1, 1) else (1, 1)
            net = mx.symbol.Convolution(
                data=net, kernel=(3, 3), pad=pad, num_filter=n_filter,
                name="conv%d_%d" % (group, i), workspace=workspace_default)
            # BUGFIX: the original named relu5_2 "conv1_2", colliding with the
            # conv layer of that name; activations carry no parameters, so the
            # corrected name does not affect checkpoint loading
            net = mx.symbol.Activation(
                data=net, act_type="relu", name="relu%d_%d" % (group, i))
        net = mx.symbol.Pooling(
            data=net, pool_type="max", kernel=(2, 2), stride=(2, 2),
            name="pool%d" % group)
        pools[group] = net
    # fc6/fc7 re-cast as convolutions — the "fully convolutional" part
    fc6 = mx.symbol.Convolution(
        data=pools[5], kernel=(7, 7), num_filter=4096, name="fc6",
        workspace=workspace_default)
    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
    fc7 = mx.symbol.Convolution(
        data=drop6, kernel=(1, 1), num_filter=4096, name="fc7",
        workspace=workspace_default)
    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
    return pools[3], pools[4], drop7


def get_fcn8s_symbol(numclass=21, workspace_default=1024):
    """FCN-8s: fuses pool4 and pool3 skip connections before an 8x
    deconvolution.

    numclass: number of segmentation classes (previously hard-coded to 21;
    the default preserves the old behavior).
    workspace_default: convolution workspace (MB) for every layer.
    """
    data = mx.symbol.Variable(name="data")
    pool3, pool4, drop7 = _vgg16_fc_trunk(data, workspace_default)
    crop_offsets = coord_map_fcn8s()  # computed once, used for all three crops
    score = mx.symbol.Convolution(
        data=drop7, kernel=(1, 1), num_filter=numclass, name="score",
        workspace=workspace_default)
    score2 = mx.symbol.Deconvolution(
        data=score, kernel=(4, 4), stride=(2, 2), num_filter=numclass,
        name="score2", workspace=workspace_default)  # 2x
    score_pool4 = mx.symbol.Convolution(
        data=pool4, kernel=(1, 1), num_filter=numclass, name="score_pool4",
        workspace=workspace_default)
    score_pool4c = mx.symbol.Crop(
        data=score_pool4, crop_like=score2,
        offset=crop_offsets["score_pool4c"], name="score_pool4c")
    score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c],
                                           name='score_fused')
    score4 = mx.symbol.Deconvolution(
        data=score_fused, kernel=(4, 4), stride=(2, 2), num_filter=numclass,
        name="score4", workspace=workspace_default)  # 2x
    score_pool3 = mx.symbol.Convolution(
        data=pool3, kernel=(1, 1), num_filter=numclass, name="score_pool3",
        workspace=workspace_default)
    score_pool3c = mx.symbol.Crop(
        data=score_pool3, crop_like=score4,
        offset=crop_offsets["score_pool3c"], name="score_pool3c")
    score_final = mx.symbol.ElementWiseSum(*[score4, score_pool3c],
                                           name='score_final')
    bigscore = mx.symbol.Deconvolution(
        data=score_final, kernel=(16, 16), stride=(8, 8),
        num_filter=numclass, name="bigscore",
        workspace=workspace_default)  # 8x
    upscore = mx.symbol.Crop(
        data=bigscore, crop_like=data,
        offset=crop_offsets["upscore"], name="upscore")
    return mx.symbol.SoftmaxOutput(
        data=upscore, multi_output=True, ignore_label=255, name="softmax")
self.tic) - logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", - param.epoch, count, speed) + name, value = param.eval_metric.get() + logging.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f", + param.epoch, count, speed, name, value) self.tic = time.time() else: self.init = True diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py index c008f058ab2a..e40e146a0af8 100644 --- a/python/mxnet/lr_scheduler.py +++ b/python/mxnet/lr_scheduler.py @@ -71,6 +71,6 @@ def __call__(self, num_update): if num_update > self.count + self.step: self.count += self.step self.base_lr *= self.factor - logging.info("Update[%d]: Change learning rate to %.5f", + logging.info("Update[%d]: Change learning rate to %0.5e", num_update, self.base_lr) return self.base_lr diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 69b3eea3f840..d8d058ad3ada 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -56,7 +56,7 @@ def update(self, labels, preds): if label.shape[0] < pred_label.shape[0]: raise Exception("Predict label is more than data label? ") self.sum_metric += numpy.sum(pred_label == label[:pred_label.shape[0]]) - num_inst = pred_label.shape[0] + num_inst = pred_label.size self.num_inst += num_inst diff --git a/python/mxnet/monitor.py b/python/mxnet/monitor.py index 4f1e236fb212..853135eb2a2a 100644 --- a/python/mxnet/monitor.py +++ b/python/mxnet/monitor.py @@ -52,7 +52,8 @@ def install(self, exe): the Executor (returned by symbol.bind) to install to. """ exe.set_monitor_callback(self.stat_helper) - self.exes.append(exe) + #self.exes.append(exe) + self.exes = [exe] def tic(self): """start collecting stats for current batch. diff --git a/src/operator/crop-inl.h b/src/operator/crop-inl.h new file mode 100644 index 000000000000..866a067f9d49 --- /dev/null +++ b/src/operator/crop-inl.h @@ -0,0 +1,182 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file crop-inl.h + * \brief + * \author Wei Wu +*/ +#ifndef MXNET_OPERATOR_CROP_INL_H_ +#define MXNET_OPERATOR_CROP_INL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +namespace crop_enum { +enum CropOpInputs {kData, kCropLike}; +enum CropOpOutputs {kOut}; +} // namespace crop_enum + +struct CropParam : public dmlc::Parameter { + TShape offset; + bool center_crop; + DMLC_DECLARE_PARAMETER(CropParam) { + int shape[] = {0, 0}; + DMLC_DECLARE_FIELD(offset).set_default(TShape(shape, shape + 2)) + .describe("corp offset coordinate: (y, x)"); + DMLC_DECLARE_FIELD(center_crop).set_default(false) + .describe("If set to true, then it will use be the center_crop," + "or it will crop using the shape of crop_like"); + } +}; // struct CropParam + +template +class CropOp : public Operator { + public: + explicit CropOp(CropParam param) { + this->param_ = param; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(static_cast(in_data.size()), 2); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req[crop_enum::kOut], kWriteTo); + Stream *s = ctx.get_stream(); + Tensor data = in_data[crop_enum::kData].get(s); + Tensor out = out_data[crop_enum::kOut].get(s); + offset_hw_ = InferCropOfferset(data.shape_, out.shape_); + out = crop(data, Shape2(out.size(2), out.size(3)), offset_hw_[0], offset_hw_[1]); + } + + // because the crop_like input is only used with it's shape, so we should be + // careful setting its backwrd grad value to zeros, so that it will not hurt + // the connection of crop_like. 
+ virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_grad.size(), 2) << in_grad.size(); + CHECK_EQ(out_grad.size(), 1) << out_grad.size(); + Stream *s = ctx.get_stream(); + Tensor grad = out_grad[crop_enum::kOut].get(s); + Tensor gdata = in_grad[crop_enum::kData].get(s); + Tensor gcrop_like = in_grad[crop_enum::kCropLike].get(s); + gcrop_like = (real_t)0.0f; + offset_hw_ = InferCropOfferset(gdata.shape_, grad.shape_); + gdata = (real_t)0.0f; + slice<3>(slice<2>(gdata, offset_hw_[0], offset_hw_[0]+grad.size(2)), + offset_hw_[1], offset_hw_[1]+grad.size(3)) = grad; + } + + private: + CropParam param_; + std::vector offset_hw_; + std::vector InferCropOfferset(mshadow::Shape<4> &data_shape, + mshadow::Shape<4> &out_shape) { + std::vector offset_hw; + CHECK_GE(data_shape[2], out_shape[2]) << + "data_shape'height should be larger than that of out_shape"; + CHECK_GE(data_shape[3], out_shape[3]) << + "data_shape'weight should be larger than that of out_shape"; + if (param_.center_crop) { + offset_hw.push_back(static_cast((data_shape[2]-out_shape[2])/2)); + offset_hw.push_back(static_cast((data_shape[3]-out_shape[3])/2)); + } else { + CHECK_GE(static_cast(param_.offset[0]), 0) << + "offset[0] should be larger than 0"; + CHECK_LE(static_cast(param_.offset[0]), data_shape[2]-out_shape[2]) << + "offset[0] should be less than the residual space of height"; + CHECK_GE(static_cast(param_.offset[1]), 0) << + "offset[1] should be larger than 0"; + CHECK_LE(static_cast(param_.offset[1]), data_shape[3]-out_shape[3]) << + "offset[1] should be less than the residual space of width"; + offset_hw.push_back(static_cast(param_.offset[0])); + offset_hw.push_back(static_cast(param_.offset[1])); + } + return offset_hw; + } +}; // class CropOp + 
+template
+Operator *CreateOp(CropParam param);
+
+#if DMLC_USE_CXX11
+class CropProp : public OperatorProperty {
+ public:
+ void Init(const std::vector >& kwargs) override {
+ param_.Init(kwargs);
+ }
+
+ std::map GetParams() const override {
+ return param_.__DICT__();
+ }
+
+ std::vector ListArguments() const override {
+ return {"data", "crop_like"};
+ }
+
+ bool InferShape(std::vector *in_shape,
+ std::vector *out_shape,
+ std::vector *aux_shape) const override {
+ using namespace mshadow;
+ CHECK_EQ(in_shape->size(), 2) << "Input:[data, crop_like]";
+ TShape data_shape = in_shape->at(crop_enum::kData);
+ if (data_shape.ndim() == 0) return false;
+ CHECK_EQ(data_shape.ndim(), 4) << \
+ "Input data should be 4D in batch-num_filter-y-x";
+ TShape crop_shape = in_shape->at(crop_enum::kCropLike);
+ if (crop_shape.ndim() == 0) return false;
+ CHECK_EQ(crop_shape.ndim(), 4) << \
+ "Input crop_like should be 4D in batch-num_filter/batch-num_channel-y-x";
+ out_shape->clear();
+ data_shape[2] = crop_shape[2];
+ data_shape[3] = crop_shape[3];
+ out_shape->push_back(data_shape);
+ return true;
+ }
+
+ OperatorProperty* Copy() const override {
+ auto ptr = new CropProp();
+ ptr->param_ = param_;
+ return ptr;
+ }
+
+ std::string TypeString() const override {
+ return "Crop";
+ }
+
+ std::vector DeclareBackwardDependency(
+ const std::vector &out_grad,
+ const std::vector &in_data,
+ const std::vector &out_data) const override {
+ return out_grad;
+ }
+
+ Operator* CreateOperator(Context ctx) const override;
+
+ private:
+ CropParam param_;
+}; // class CropProp
+#endif // DMLC_USE_CXX11
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_CROP_INL_H_
diff --git a/src/operator/crop.cc b/src/operator/crop.cc
new file mode 100644
index 000000000000..5a3315c24d63
--- /dev/null
+++ b/src/operator/crop.cc
@@ -0,0 +1,29 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file crop.cc
+ * \brief
+ * \author Wei Wu
+*/
+
+#include "./crop-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateOp(CropParam param) {
+ return new CropOp(param);
+}
+
+Operator* CropProp::CreateOperator(Context ctx) const {
+ DO_BIND_DISPATCH(CreateOp, param_);
+}
+
+DMLC_REGISTER_PARAMETER(CropParam);
+
+MXNET_REGISTER_OP_PROPERTY(Crop, CropProp)
+.add_argument("data", "Symbol", "Input data to the CropOp.")
+.add_argument("crop_like", "Symbol", "crop_like data to the CropOp.")
+.add_arguments(CropParam::__FIELDS__())
+.describe("Crop the 2th and 3th dim of input data, with the corresponding size of crop_like.");
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/crop.cu b/src/operator/crop.cu
new file mode 100644
index 000000000000..64f8cb219f30
--- /dev/null
+++ b/src/operator/crop.cu
@@ -0,0 +1,18 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file crop.cu
+ * \brief
+ * \author Wei Wu
+*/
+
+#include "./crop-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateOp(CropParam param) {
+ return new CropOp(param);
+}
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index 9528ed0a41c6..b75edd21a04b 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -28,6 +28,7 @@ enum SoftmaxOutputOpOutputs {kOut};
 struct SoftmaxOutputParam : public dmlc::Parameter {
 float grad_scale;
 bool multi_output;
+ float ignore_label;
 DMLC_DECLARE_PARAMETER(SoftmaxOutputParam) {
 DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f)
 .describe("Scale the gradient by a float factor");
@@ -35,6 +36,9 @@ struct SoftmaxOutputParam : public dmlc::Parameter {
 .describe("If set to true, for a (n,k,x_1,..,x_n) dimensional"
 "input tensor, softmax will generate n*x_1*...*x_n output, each"
 "has k classes");
+ DMLC_DECLARE_FIELD(ignore_label).set_default(-1.0f)
+ .describe("the
ignore_label will not work in backward, and this only"
+ "be used when multi_output=true");
 };
};
@@ -56,7 +60,7 @@ class SoftmaxOutputOp : public Operator {
 if (param_.multi_output) {
 int n = in_data[softmaxout_enum::kData].size(0);
 int k = in_data[softmaxout_enum::kData].size(1);
- Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k));
+ Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k)); // i.e. (n, k, c)
 Tensor data = in_data[softmaxout_enum::kData].get_with_shape(s3, s);
 Tensor out = out_data[softmaxout_enum::kOut].get_with_shape(s3, s);
 Softmax(out, data);
@@ -88,7 +92,7 @@ class SoftmaxOutputOp : public Operator {
 Tensor label = in_data[softmaxout_enum::kLabel].FlatTo2D(s);
 Tensor out = out_data[softmaxout_enum::kOut].get_with_shape(s3, s);
 Tensor grad = in_grad[softmaxout_enum::kData].get_with_shape(s3, s);
- SoftmaxGrad(grad, out, label);
+ SoftmaxGrad(grad, out, label, static_cast(param_.ignore_label));
 if (param_.grad_scale < 1.0) {
 grad *= param_.grad_scale;
 }