From b9bb2ed0b7e377a83220d51210f0c0d3143f5796 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Thu, 21 May 2020 07:32:10 -0700 Subject: [PATCH 1/2] Add backward Type inference to main DNN operators Signed-off-by: Serge Panev --- src/operator/contrib/batch_norm_relu.cc | 28 +++++++++++++++++-------- src/operator/nn/batch_norm.cc | 28 +++++++++++++++++-------- src/operator/nn/convolution.cc | 13 +++++++++--- src/operator/nn/deconvolution.cc | 13 +++++++++--- src/operator/softmax_output.cc | 13 +++++++++--- 5 files changed, 68 insertions(+), 27 deletions(-) diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc index 14452cc96729..25372ddb983b 100644 --- a/src/operator/contrib/batch_norm_relu.cc +++ b/src/operator/contrib/batch_norm_relu.cc @@ -84,14 +84,30 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); - const int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + const size_t n_out = 4; // For float16 input type beta, gamma, mean, and average are stored in float32. // For other input types, these parameters have the same type as input // NOTE: This requirement is from cuDNN (v. 
4 and 5) int dtype_param; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + int dtype = (*in_type)[0]; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + dtype_param = mshadow::DataType::kFlag; }); + } + } else { + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); + out_type->clear(); + out_type->push_back(dtype); + for (size_t i = 1; i < n_out; ++i) { + out_type->push_back(dtype_param); + } + } std::vector args{"data", "gamma", "beta", "mean", "var"}; CHECK_LE(in_type->size(), args.size()); for (size_t i = 1; i < in_type->size(); ++i) { @@ -101,12 +117,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]); } } - const size_t n_out = 4; - out_type->clear(); - out_type->push_back(dtype); - for (size_t i = 1; i < n_out; ++i) { - out_type->push_back(dtype_param); - } return true; } diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 815288cfe554..4b6bd939ae6b 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -352,14 +352,30 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); - const int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + const size_t n_out = 3; // For float16 input type beta, gamma, mean, and average are stored in float32. // For other input types, these parameters have the same type as input // NOTE: This requirement is from cuDNN (v. 
4 and 5) int dtype_param; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + int dtype = (*in_type)[0]; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + dtype_param = mshadow::DataType::kFlag; }); + } + } else { + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); + out_type->clear(); + out_type->push_back(dtype); + for (size_t i = 1; i < n_out; ++i) { + out_type->push_back(dtype_param); + } + } std::vector args{"data", "gamma", "beta", "mean", "var"}; CHECK_LE(in_type->size(), args.size()); for (size_t i = 1; i < in_type->size(); ++i) { @@ -369,12 +385,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]); } } - const size_t n_out = 3; - out_type->clear(); - out_type->push_back(dtype); - for (size_t i = 1; i < n_out; ++i) { - out_type->push_back(dtype_param); - } return true; } diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 8ff5ea75d5f7..3ebb67ad0aa0 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -285,7 +285,16 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param_ = nnvm::get(attrs.parsed); CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + } + } else { + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -293,8 +302,6 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); } 
} - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index cd22aced0d03..0b0fecc1849f 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -332,7 +332,16 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, const DeconvolutionParam& param_ = nnvm::get(attrs.parsed); CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + } + } else { + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -340,8 +349,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 13bb647f9d43..da02fcfabbf5 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -66,7 +66,16 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, std::vector *out_type) { CHECK_EQ(in_type->size(), 2U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + } + } else { + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -74,8 +83,6 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; 
} From 231d8b6c2bb677d51c79e0486570d1e5525a7237 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Thu, 4 Jun 2020 00:01:03 -0700 Subject: [PATCH 2/2] Add comments Signed-off-by: Serge Panev --- src/operator/contrib/batch_norm_relu.cc | 12 +++++++++--- src/operator/nn/batch_norm.cc | 5 +++++ src/operator/nn/deconvolution.cc | 5 +++++ src/operator/softmax_output.cc | 5 +++++ 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc index 25372ddb983b..e2c28e0b055b 100644 --- a/src/operator/contrib/batch_norm_relu.cc +++ b/src/operator/contrib/batch_norm_relu.cc @@ -90,16 +90,22 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, // NOTE: This requirement is from cuDNN (v. 4 and 5) int dtype_param; int dtype = (*in_type)[0]; + if (type_is_none(dtype)) { - if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be inferred for this op return false; - } else { + } else { + // Input type is undefined but output type is: backward inference dtype = (*out_type)[0]; (*in_type)[0] = dtype; MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); - } + } } else { + // Input type is defined but output type is not: forward inference MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); out_type->clear(); diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 4b6bd939ae6b..e3907993399d 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -359,15 +359,20 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, int dtype_param; int dtype = (*in_type)[0]; if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference if 
(out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be inferred for this op return false; } else { + // Input type is undefined but output type is: backward inference dtype = (*out_type)[0]; (*in_type)[0] = dtype; MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); } } else { + // Input type is defined but output type is not: forward inference MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); out_type->clear(); diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 0b0fecc1849f..08d6306730ef 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -333,12 +333,17 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be inferred for this op return false; } else { + // Input type is undefined but output type is: backward inference dtype = (*out_type)[0]; } } else { + // Input type is defined but output type is not: forward inference out_type->clear(); out_type->push_back(dtype); } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index da02fcfabbf5..d87b78145e9e 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -67,12 +67,17 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_type->size(), 2U); int dtype = (*in_type)[0]; if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be inferred for this op return false; } else { + // 
Input type is undefined but output type is: backward inference dtype = (*out_type)[0]; } } else { + // Input type is defined but output type is not: forward inference out_type->clear(); out_type->push_back(dtype); }