This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Seg #940 (Closed)
Changes from all commits (12 commits)
Commits (all by tornadomeet):
- e150a86 add vgg16_fcn model change, and rename fcn-xs branch to seg branch
- 88d23b8 update convenience for me when debug
- 7ba1c62 Merge branch 'master' of https://github.com/dmlc/mxnet into seg
- 9a9403a Merge branch 'master' of https://github.com/dmlc/mxnet into seg
- 4678892 update for segmentation
- 1a1fbb5 update for seg
- 4fdd318 Merge branch 'master' of https://github.com/dmlc/mxnet into seg
- 5ef42f7 Merge branch 'master' of https://github.com/dmlc/mxnet into seg
- 26f95c6 add the dependent source code for image segmentation
- b7e8bc2 add the fcn-xs example
- 119bb69 add fcn-xs result image
- 2ffdb05 add result image
Submodule dmlc-core updated (6 files):

| File | Changes |
|---|---|
| include/dmlc/data.h | +4 −49 |
| include/dmlc/registry.h | +2 −2 |
| scripts/lint.py | +1 −1 |
| scripts/lint3.py | +1 −1 |
| src/data.cc | +14 −32 |
| tracker/tracker.py | +1 −1 |
FCN-xs EXAMPLES
---------------
This folder contains examples of image segmentation in MXNet.

## Sample results


## How to train fcn-xs in mxnet
#### Step 1: get the fully convolutional style of the vgg16 model
* Download the vgg16 caffe model [VGG_ILSVRC_16_layers.caffemodel](http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel) and the corresponding [VGG_ILSVRC_16_layers_deploy.prototxt](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt). Note that the vgg16 model is [licensed](http://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only.
* Use convert_model.py to convert the caffe model to an mxnet model (shell):
```
vgg16_deploy=VGG_ILSVRC_16_layers_deploy.prototxt
vgg16_caffemodel=VGG_ILSVRC_16_layers.caffemodel
model_prefix=VGG_ILSVRC_16_layers
cmd=../../tools/caffe_converter/convert_model.py
python $cmd $vgg16_deploy $vgg16_caffemodel $model_prefix
mv ${model_prefix}-0001.params ${model_prefix}-0074.params
```
* Convert the conv + fully-connected style to the fully convolutional style (shell):
```
python create_vgg16fc_model.py VGG_ILSVRC_16_layers 74 VGG_FC_ILSVRC_16_layers
```
You can use the resulting vgg16fc model now, or download it directly from [yun.baidu](http://pan.baidu.com/s/1jGlOvno).

**`Be careful: if you feed a (very) large image to the vgg16fc model, increase the 'workspace_default' value in create_vgg16fc_model.py accordingly.`**
* Alternatively, download ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` directly from [yun.baidu](http://pan.baidu.com/s/1jGlOvno).
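The conversion in this step comes down to reshaping each fully-connected weight matrix into a convolution kernel; no values change, only the layout is reinterpreted. A minimal numpy sketch of the fc6 reshape (shapes taken from vgg16; the arrays here are zero-filled placeholders, not real weights):

```python
import numpy as np

# vgg16's fc6 maps the flattened pool5 output (512 channels x 7 x 7 spatial)
# to 4096 units, so its weight matrix has shape (4096, 512 * 7 * 7)
fc6_weight = np.zeros((4096, 512 * 7 * 7), dtype=np.float32)

# viewed as a 7x7 convolution, the same weights form a
# (num_filter, channels, kernel_h, kernel_w) tensor
conv6_weight = fc6_weight.reshape(4096, 512, 7, 7)

assert conv6_weight.shape == (4096, 512, 7, 7)
assert np.shares_memory(fc6_weight, conv6_weight)  # pure reshape, no copy
```

fc7 and fc8 follow the same pattern with 1x1 kernels, as `get_vgg16fc_arg_param` in create_vgg16fc_model.py does below.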
#### Step 2: prepare your training data
In this example, each line of the training image list has the form:
```index \t image_data_path \t image_label_path```
In image segmentation the label for an image is itself an image, with the same shape as the input image.
* Alternatively, download ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1jGlOvno) and extract it. The extracted files/folders are: the ```JPEGImages``` folder, the ```SegmentationClass``` folder, ```train.lst```, ```val.lst```, and ```test.lst```.
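The list-file format above is plain tab-separated text, so it is easy to generate and parse. A small sketch (the file name and image paths are illustrative only):

```python
import os
import tempfile

# a toy train.lst in the "index \t image_data_path \t image_label_path" format
lines = [
    "0\tJPEGImages/2007_000032.jpg\tSegmentationClass/2007_000032.png",
    "1\tJPEGImages/2007_000039.jpg\tSegmentationClass/2007_000039.png",
]
lst_path = os.path.join(tempfile.mkdtemp(), "train.lst")
with open(lst_path, "w") as f:
    f.write("\n".join(lines) + "\n")

def read_list(path):
    """Yield (index, image_path, label_path) tuples from a list file."""
    with open(path) as f:
        for line in f:
            idx, img, lbl = line.strip("\n").split("\t")
            yield int(idx), img, lbl

records = list(read_list(lst_path))
assert len(records) == 2
assert records[0] == (0, "JPEGImages/2007_000032.jpg",
                      "SegmentationClass/2007_000032.png")
```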
#### Step 3: begin training fcn-xs
If you want to train the fcn-8s model, it is better to train the fcn-32s and fcn-16s models first.
To train the fcn-32s model, run ```./run_fcn32s.sh``` in a shell.
* In fcn_xs.py (e.g. fcn_32s.py, fcn_16s.py, fcn_8s.py), you may need to change ```img_dir```, ```train_lst```, ```val_lst```, and ```fcnxs_model_prefix```.
* The output log looks like this (when training fcn-8s):
```
INFO:root:Start training with gpu(3)
INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318
INFO:root:Epoch[0] Batch [100] Speed: 1.11 samples/sec Train-accuracy=0.904681
INFO:root:Epoch[0] Batch [150] Speed: 1.13 samples/sec Train-accuracy=0.908053
INFO:root:Epoch[0] Batch [200] Speed: 1.12 samples/sec Train-accuracy=0.912219
INFO:root:Epoch[0] Batch [250] Speed: 1.13 samples/sec Train-accuracy=0.914238
INFO:root:Epoch[0] Batch [300] Speed: 1.13 samples/sec Train-accuracy=0.912170
INFO:root:Epoch[0] Batch [350] Speed: 1.12 samples/sec Train-accuracy=0.912080
```
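If you want to track speed and accuracy over a long run, the per-batch lines in a log like the one above can be pulled apart with a short regex. This is a hypothetical helper for illustration, not part of this PR:

```python
import re

# matches lines such as:
# INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318
LOG_LINE = re.compile(
    r"Epoch\[(\d+)\] Batch \[(\d+)\]\s+Speed: ([\d.]+) samples/sec\s+"
    r"Train-accuracy=([\d.]+)")

def parse_log_line(line):
    """Return (epoch, batch, speed, accuracy), or None for non-batch lines."""
    m = LOG_LINE.search(line)
    if m is None:
        return None
    epoch, batch, speed, acc = m.groups()
    return int(epoch), int(batch), float(speed), float(acc)

sample = "INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318"
assert parse_log_line(sample) == (0, 50, 1.16, 0.894318)
assert parse_log_line("INFO:root:Start training with gpu(3)") is None
```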
## TODO
* Add an example of using the pretrained model.
* Add a crop_offset function to the symbol API (both the C++ and Python sides of mxnet).
* Clean up the example code.
create_vgg16fc_model.py:
# pylint: skip-file
import sys, os
import argparse
import mxnet as mx
import numpy as np
import logging

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# default temporary workspace (MB) for convolution operators; increase it
# when feeding (very) large images to the fully convolutional model
workspace_default = 1024

## define vgg16
def get_vgg16_symbol():
    data = mx.symbol.Variable(name="data")
    # group 1
    conv1_1 = mx.symbol.Convolution(
        data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1")
    relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1")
    conv1_2 = mx.symbol.Convolution(
        data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_2")
    relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2")
    pool1 = mx.symbol.Pooling(
        data=relu1_2, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool1")
    # group 2
    conv2_1 = mx.symbol.Convolution(
        data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1")
    relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1")
    conv2_2 = mx.symbol.Convolution(
        data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_2")
    relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2")
    pool2 = mx.symbol.Pooling(
        data=relu2_2, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool2")
    # group 3
    conv3_1 = mx.symbol.Convolution(
        data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1")
    relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1")
    conv3_2 = mx.symbol.Convolution(
        data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2")
    relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2")
    conv3_3 = mx.symbol.Convolution(
        data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_3")
    relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3")
    pool3 = mx.symbol.Pooling(
        data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool3")
    # group 4
    conv4_1 = mx.symbol.Convolution(
        data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1")
    relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1")
    conv4_2 = mx.symbol.Convolution(
        data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2")
    relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2")
    conv4_3 = mx.symbol.Convolution(
        data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_3")
    relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3")
    pool4 = mx.symbol.Pooling(
        data=relu4_3, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool4")
    # group 5
    conv5_1 = mx.symbol.Convolution(
        data=pool4, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_1")
    relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1")
    conv5_2 = mx.symbol.Convolution(
        data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2")
    relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="relu5_2")
    conv5_3 = mx.symbol.Convolution(
        data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_3")
    relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3")
    pool5 = mx.symbol.Pooling(
        data=relu5_3, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool5")
    # group 6
    flatten = mx.symbol.Flatten(data=pool5, name="flatten")
    fc6 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
    # group 7
    fc7 = mx.symbol.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
    # output
    fc8 = mx.symbol.FullyConnected(data=drop7, num_hidden=1000, name="fc8")
    softmax = mx.symbol.SoftmaxOutput(data=fc8, name="prob")
    return softmax


## define vgg16 in fully convolutional style: fc6/fc7/fc8 become convolutions
def get_vgg16fc_symbol():
    data = mx.symbol.Variable(name="data")
    # group 1
    conv1_1 = mx.symbol.Convolution(
        data=data, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1", workspace=workspace_default)
    relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1")
    conv1_2 = mx.symbol.Convolution(
        data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_2", workspace=workspace_default)
    relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2")
    pool1 = mx.symbol.Pooling(
        data=relu1_2, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool1")
    # group 2
    conv2_1 = mx.symbol.Convolution(
        data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1", workspace=workspace_default)
    relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1")
    conv2_2 = mx.symbol.Convolution(
        data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_2", workspace=workspace_default)
    relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2")
    pool2 = mx.symbol.Pooling(
        data=relu2_2, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool2")
    # group 3
    conv3_1 = mx.symbol.Convolution(
        data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1", workspace=workspace_default)
    relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1")
    conv3_2 = mx.symbol.Convolution(
        data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2", workspace=workspace_default)
    relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2")
    conv3_3 = mx.symbol.Convolution(
        data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_3", workspace=workspace_default)
    relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3")
    pool3 = mx.symbol.Pooling(
        data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool3")
    # group 4
    conv4_1 = mx.symbol.Convolution(
        data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1", workspace=workspace_default)
    relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1")
    conv4_2 = mx.symbol.Convolution(
        data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2", workspace=workspace_default)
    relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2")
    conv4_3 = mx.symbol.Convolution(
        data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_3", workspace=workspace_default)
    relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3")
    pool4 = mx.symbol.Pooling(
        data=relu4_3, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool4")
    # group 5
    conv5_1 = mx.symbol.Convolution(
        data=pool4, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_1", workspace=workspace_default)
    relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1")
    conv5_2 = mx.symbol.Convolution(
        data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2", workspace=workspace_default)
    relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="relu5_2")
    conv5_3 = mx.symbol.Convolution(
        data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_3", workspace=workspace_default)
    relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3")
    pool5 = mx.symbol.Pooling(
        data=relu5_3, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool5")
    # group 6: fc6 becomes a 7x7 convolution
    fc6 = mx.symbol.Convolution(
        data=pool5, kernel=(7, 7), num_filter=4096, name="fc6", workspace=workspace_default)
    relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6")
    # group 7: fc7 becomes a 1x1 convolution
    fc7 = mx.symbol.Convolution(
        data=drop6, kernel=(1, 1), num_filter=4096, name="fc7", workspace=workspace_default)
    relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7")
    # group 8: fc8 becomes a 1x1 convolution
    fc8 = mx.symbol.Convolution(
        data=drop7, kernel=(1, 1), num_filter=1000, name="fc8", workspace=workspace_default)
    # output: multi_output gives a per-pixel softmax
    softmax = mx.symbol.SoftmaxOutput(data=fc8, multi_output=True, name="prob")
    return softmax


def get_vgg16fc_arg_param(vgg16_arg_params):
    """Reshape the fully-connected weight matrices of vgg16 into convolution
    kernels; the bias parameters are reused unchanged."""
    vgg16fc_arg_param = vgg16_arg_params
    vgg16fc_arg_param["fc6_weight"] = \
        mx.nd.array(vgg16_arg_params["fc6_weight"].asnumpy().reshape(4096, 512, 7, 7))
    vgg16fc_arg_param["fc7_weight"] = \
        mx.nd.array(vgg16_arg_params["fc7_weight"].asnumpy().reshape(4096, 4096, 1, 1))
    vgg16fc_arg_param["fc8_weight"] = \
        mx.nd.array(vgg16_arg_params["fc8_weight"].asnumpy().reshape(1000, 4096, 1, 1))
    return vgg16fc_arg_param


def main():
    parser = argparse.ArgumentParser(description='Convert a vgg16 model to a vgg16fc model.')
    parser.add_argument('prefix', nargs='?', default='VGG_ILSVRC_16_layers',
                        help='The prefix (including path) of the vgg16 model in mxnet format.')
    parser.add_argument('epoch', nargs='?', type=int, default=74,
                        help='The epoch number of the vgg16 model.')
    parser.add_argument('prefix_fc', nargs='?', default='VGG_FC_ILSVRC_16_layers',
                        help='The prefix (including path) under which to save the vgg16fc model.')
    args = parser.parse_args()

    vgg16_symbol, vgg16_arg_params, vgg16_aux_params = \
        mx.model.load_checkpoint(args.prefix, args.epoch)

    ## to get the original vgg16 symbol there are two ways:
    # way 1: build it with get_vgg16_symbol(); convenient for transfer learning.
    # softmax = get_vgg16_symbol()
    # vgg_model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax,
    #     arg_params=vgg16_arg_params, aux_params=vgg16_aux_params)

    # way 2: use the symbol produced by the caffe converter directly.
    # softmax = vgg16_symbol
    # vgg_model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax,
    #     arg_params=vgg16_arg_params, aux_params=vgg16_aux_params)

    # build the fully convolutional model and save it
    softmax = get_vgg16fc_symbol()
    vgg16fc_arg_params = get_vgg16fc_arg_param(vgg16_arg_params)
    vgg_model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax,
                                     arg_params=vgg16fc_arg_params, aux_params=vgg16_aux_params)
    vgg_model.save(prefix=args.prefix_fc, epoch=args.epoch)


if __name__ == "__main__":
    main()
# pylint: skip-file
""" data iterator for pascal voc 2012"""
import mxnet as mx
import numpy as np
import sys
import os
from mxnet.io import DataIter
from skimage import io
from PIL import Image


class FileIter(DataIter):
    """A simple file-list data iterator for image segmentation in mxnet.

    Each line of the list file has the form
    ``index \\t image_data_path \\t image_label_path``;
    the label of an image is itself an image with the same (h, w) shape.

    Parameters
    ----------
    root_dir : str
        Root directory of the dataset; paths in the list file are
        relative to it.
    flist_name : str
        Name of the list file, relative to root_dir.
    data_name : str
        Name of the data blob provided to the network.
    label_name : str
        Name of the label blob provided to the network.

    Note
    ----
    batch_size is fixed to 1 because the images in the dataset have
    different shapes.
    """
    def __init__(self, root_dir, flist_name, data_name="data", label_name="softmax_label"):
        super(FileIter, self).__init__()
        self.root_dir = root_dir
        self.data_name = data_name
        self.label_name = label_name
        self.flist_name = os.path.join(self.root_dir, flist_name)
        with open(self.flist_name, 'r') as f:
            self.num_data = len(f.readlines())
        self.img_path = "img_path"
        self.f = open(self.flist_name, 'r')
        self.mean = np.array([123.68, 116.779, 103.939])  # (R, G, B)
        self.data, self.label, self.img_name = self._read(self.f)
        self.cursor = -1

    def _read(self, f):
        """Read the next list entry and load its data and label images."""
        _, data_img_name, label_img_name = f.readline().strip('\n').split("\t")
        data = {self.data_name: self._read_img(data_img_name)}
        label = {self.label_name: self._read_img(label_img_name, True)}
        return list(data.items()), list(label.items()), data_img_name

    def _read_img(self, img_name, is_label_img=False):
        img = Image.open(os.path.join(self.root_dir, img_name))
        if not is_label_img:
            img = np.array(img, dtype=np.float32)  # (h, w, c)
            reshaped_mean = self.mean.reshape(1, 1, 3)
            img = img - reshaped_mean
            img = np.swapaxes(img, 0, 2)
            img = np.swapaxes(img, 1, 2)  # (c, h, w)
        else:
            img = np.array(img)  # (h, w)
            # img[img == 255] = 0  # optionally map the void label 255 to 0
        img = np.expand_dims(img, axis=0)  # (1, c, h, w) or (1, h, w)
        return img

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator."""
        return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.data]

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator."""
        return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label]

    @property
    def batch_size(self):
        return 1

    def reset(self):
        self.cursor = -1
        self.f.close()
        self.f = open(self.flist_name, 'r')

    def iter_next(self):
        self.cursor += 1
        return self.cursor < self.num_data - 1

    def next(self):
        if self.iter_next():
            self.data, self.label, self.img_name = self._read(self.f)
            return {self.data_name: self.getdata(),
                    self.label_name: self.getlabel(),
                    self.img_path: self.img_name}
        else:
            raise StopIteration

    def getdata(self):
        return self.data[0][1]

    def getlabel(self):
        return self.label[0][1]
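The image preprocessing in `_read_img` above (mean subtraction, then an HWC to CHW transpose, then a leading batch axis) can be checked in isolation with plain numpy. This sketch mirrors the data-image branch of that method on a dummy array:

```python
import numpy as np

mean = np.array([123.68, 116.779, 103.939])  # per-channel (R, G, B) mean

def preprocess(img_hwc):
    """Mirror FileIter._read_img for a data image: subtract the channel
    mean, move channels first, and add a batch axis -> (1, c, h, w)."""
    img = img_hwc.astype(np.float32) - mean.reshape(1, 1, 3)
    img = img.transpose(2, 0, 1)        # (h, w, c) -> (c, h, w)
    return np.expand_dims(img, axis=0)  # -> (1, c, h, w)

# a dummy 4x5 RGB "image" filled with the value 128
dummy = np.full((4, 5, 3), 128.0, dtype=np.float32)
out = preprocess(dummy)
assert out.shape == (1, 3, 4, 5)
# channel 0 now holds 128 - 123.68 everywhere
assert np.allclose(out[0, 0], 128.0 - 123.68)
```

The two `swapaxes` calls in the original are equivalent to the single `transpose(2, 0, 1)` used here.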
Review comments:
- This url should point to the image included in this PR.
- We can commit the image to dmlc/web-data to reduce repo size.