diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index 50625b4b17a4..3a197a416219 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -673,6 +673,31 @@ def _handle_darknet_rnn_layers(self, layer_num, sym): self._sym_array[layer_num] = sym processed = True + elif LAYERTYPE.CRNN == layer.type: + attr.update({'n' : layer.n}) + attr.update({'batch' : layer.batch}) + attr.update({'num_hidden' : str(layer.outputs)}) + + state = self._get_rnn_state_buffer(layer) + + for _ in range(layer.steps): + input_layer = layer.input_layer + sym = self._get_darknet_rnn_attrs(input_layer, sym) + + self_layer = layer.self_layer + state = self._get_darknet_rnn_attrs(self_layer, state) + + op_name, new_attrs = 'elemwise_add', {} + new_inputs = _as_list([sym, state]) + state = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs) + self._outs.append(state) + + output_layer = layer.output_layer + sym = self._get_darknet_rnn_attrs(output_layer, state) + + self._sym_array[layer_num] = sym + processed = True + return processed, sym def from_darknet(self): diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py index 961da238452c..e68aed085664 100644 --- a/nnvm/tests/python/frontend/darknet/test_forward.py +++ b/nnvm/tests/python/frontend/darknet/test_forward.py @@ -324,6 +324,31 @@ def test_forward_rnn(): test_rnn_forward(net) LIB.free_network(net) +def test_forward_crnn(): + '''test softmax layer''' + net = LIB.make_network(1) + batch = 1 + c = 3 + h = 224 + w = 224 + hidden_filters = c + output_filters = c + steps = 1 + activation = 0 + batch_normalize = 0 + inputs = 256 + outputs = 256 + layer_1 = LIB.make_crnn_layer(batch, h, w, c, hidden_filters, output_filters, + steps, activation, batch_normalize) + net.layers[0] = layer_1 + net.inputs = inputs + net.outputs = output_filters * h * w + net.w = w + net.h = h + LIB.resize_network(net, net.w, net.h) + test_forward(net) + LIB.free_network(net) + def test_forward_activation_logistic(): '''test logistic activation layer''' net = LIB.make_network(1) @@ -369,4 +394,5 @@ def test_forward_activation_logistic(): test_forward_region() test_forward_elu() test_forward_rnn() + test_forward_crnn() test_forward_activation_logistic() \ No newline at end of file