Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f36c352

Browse files
committed
Fix GPU build
1 parent 4e65b41 commit f36c352

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

python/mxnet/amp/lists/symbol_fp16.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,6 @@
641641
'_mod',
642642
'_not_equal',
643643
'_npi_column_stack',
644-
'_npi_concatenate',
645644
'_npi_copysign',
646645
'_npi_cross',
647646
'_npi_dot',

src/operator/subgraph/tensorrt/nnvm_to_onnx.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,13 +669,14 @@ void ConvertConcatenate(GraphProto* graph_proto,
669669
NodeProto* node_proto = graph_proto->add_node();
670670
node_proto->set_name(node_name);
671671
const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
672+
const int param_dim = _param.dim.has_value() ? _param.dim.value() : 0;
672673
node_proto->set_op_type("Concat");
673674
node_proto->set_name(attrs.name);
674675
// axis
675676
AttributeProto* const axis = node_proto->add_attribute();
676677
axis->set_name("axis");
677678
axis->set_type(AttributeProto::INT);
678-
axis->set_i(static_cast<int64_t>(_param.dim));
679+
axis->set_i(static_cast<int64_t>(param_dim));
679680
DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
680681
}
681682

src/operator/subgraph/tensorrt/tensorrt-inl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ class TensorrtSelector : public SubgraphSelector {
193193

194194
if (op_name == "Concat") {
195195
const auto& param = nnvm::get<ConcatParam>(n.attrs.parsed);
196-
return (param.dim != 0);
196+
const int param_dim = param.dim.has_value() ? param.dim.value() : 0;
197+
return (param_dim != 0);
197198
}
198199

199200
if (op_name == "Dropout") {

0 commit comments

Comments
 (0)