From cc3d2687b328d03bd0e9e98bb177ac160f6cd82a Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Thu, 11 Jun 2020 12:45:36 -0700 Subject: [PATCH 1/5] Initial - Add DT_HALF support --- ngraph_bridge/ngraph_builder.cc | 4 +++- ngraph_bridge/ngraph_mark_for_clustering.cc | 3 +-- ngraph_bridge/ngraph_utils.cc | 20 ++++++++++++-------- ngraph_bridge/ngraph_utils.h | 4 ++++ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index dbf6af9dd..dc1dd4373 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -350,7 +350,9 @@ Builder::TF_NGRAPH_CONST_MAP() { {DataType::DT_UINT16, make_pair(MakeConstOp, ng::element::u16)}, {DataType::DT_BOOL, - make_pair(MakeConstOp, ng::element::boolean)}}; + make_pair(MakeConstOp, ng::element::boolean)}, + {DataType::DT_HALF, + make_pair(MakeConstOp, ng::element::f16)}}; return the_map; } diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc index f55f09a7f..db40c685f 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.cc +++ b/ngraph_bridge/ngraph_mark_for_clustering.cc @@ -78,9 +78,7 @@ static Status TypeConstraintOk(Node* node, for (const auto& name_and_set : itr->second) { auto& type_attr_name = name_and_set.first; auto& allowed_types = name_and_set.second; - DataType dt; - if (GetNodeAttr(node->attrs(), type_attr_name, &dt) != Status::OK() || std::find(allowed_types.begin(), allowed_types.end(), dt) == allowed_types.end()) { @@ -577,6 +575,7 @@ const TypeConstraintMap& GetTypeConstraintMap() { type_constraint_map["NonMaxSuppressionV4"]["T"] = { DT_FLOAT}; // TF allows half too type_constraint_map["OneHot"]["T"] = NGraphDTypes(); + type_constraint_map["OneHot"]["TI"] = NGraphIndexDTypes(); type_constraint_map["Pack"]["T"] = NGraphDTypes(); type_constraint_map["RandomUniform"]["T"] = NGraphDTypes(); type_constraint_map["Pad"]["T"] = NGraphDTypes(); diff --git a/ngraph_bridge/ngraph_utils.cc b/ngraph_bridge/ngraph_utils.cc index 79fd93d3f..5f8c078cb 100644 --- a/ngraph_bridge/ngraph_utils.cc +++ b/ngraph_bridge/ngraph_utils.cc @@ -276,6 +276,9 @@ Status TFDataTypeToNGraphElementType(DataType tf_dt, case DataType::DT_BFLOAT16: *ng_et = ng::element::bf16; break; + case DataType::DT_HALF: + *ng_et = ng::element::f16; + break; default: return errors::Unimplemented("Unsupported TensorFlow data type: ", DataType_Name(tf_dt)); @@ -325,23 +328,23 @@ void print_node_histogram(const std::unordered_map& histogram, const gtl::ArraySlice& NGraphDTypes() { static gtl::ArraySlice result{ - DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, - DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, - DT_BOOL, DT_QINT8, DT_QUINT8, DT_BFLOAT16}; + DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, + DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, + DT_BOOL, DT_QINT8, DT_QUINT8, DT_BFLOAT16, DT_HALF}; return result; } const gtl::ArraySlice& NGraphNumericDTypes() { static gtl::ArraySlice result{ - DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, - DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16}; + DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, + DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_HALF}; return result; } const gtl::ArraySlice& NGraphNumericAndQuantizedDTypes() { static gtl::ArraySlice result{ - DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, - DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_QINT8, DT_QUINT8}; + DT_FLOAT, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, + DT_UINT16, DT_UINT32, DT_UINT64, DT_QINT8, DT_QUINT8, DT_HALF}; return result; } @@ -362,7 +365,8 @@ const gtl::ArraySlice& NGraphSupportedQuantizedDTypes() { } const gtl::ArraySlice& NGraphRealDTypes() { - static gtl::ArraySlice result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}; + static gtl::ArraySlice result{DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, + DT_HALF}; return result; } diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h index b1987be60..adfb421fa 100644 --- a/ngraph_bridge/ngraph_utils.h +++ b/ngraph_bridge/ngraph_utils.h @@ -156,6 +156,10 @@ Status ValuesFromConstNode(const NodeDef& node, switch (dt) { // TODO(amprocte/NGRAPH-2502): there are more element types to support // here + case DT_HALF: + val_size = tensor.half_val_size(); + if (val_size > 0) val_i = (T)tensor.half_val()[i]; + break; case DT_INT32: val_size = tensor.int_val_size(); if (val_size > 0) val_i = tensor.int_val()[i]; From e3b742a2abb7d19d5e1e48f78d6176f0de65eb3c Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Thu, 11 Jun 2020 15:00:38 -0700 Subject: [PATCH 2/5] Fix half-precision cast issues --- ngraph_bridge/ngraph_builder.cc | 5 ++++- ngraph_bridge/ngraph_utils.h | 15 ++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index dc1dd4373..b1d2946e4 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -249,6 +249,9 @@ static Status TensorDataToVector(const Tensor& tensor, std::vector* vector) { // Else we have to convert. else { switch (dt) { + case DT_HALF: + ConvertTensorDataToVector(tensor, vector); + break; case DT_FLOAT: ConvertTensorDataToVector(tensor, vector); break; @@ -352,7 +355,7 @@ Builder::TF_NGRAPH_CONST_MAP() { {DataType::DT_BOOL, make_pair(MakeConstOp, ng::element::boolean)}, {DataType::DT_HALF, - make_pair(MakeConstOp, ng::element::f16)}}; + make_pair(MakeConstOp, ng::element::f16)}}; return the_map; } diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h index adfb421fa..3e432135f 100644 --- a/ngraph_bridge/ngraph_utils.h +++ b/ngraph_bridge/ngraph_utils.h @@ -102,7 +102,8 @@ Status ValuesFromConstNode(const NodeDef& node, return errors::InvalidArgument("Node not a Const"); } - if (node.attr().at("dtype").type() != DataTypeToEnum::value) { + if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 || + node.attr().at("dtype").type() != DataTypeToEnum::value) { std::stringstream ss; ss << "Invalid data type defined for Const. Defined: " << node.attr().at("dtype").type(); @@ -158,27 +159,27 @@ Status ValuesFromConstNode(const NodeDef& node, // here case DT_HALF: val_size = tensor.half_val_size(); - if (val_size > 0) val_i = (T)tensor.half_val()[i]; + if (val_size > 0) val_i = static_cast(tensor.half_val()[i]); break; case DT_INT32: val_size = tensor.int_val_size(); - if (val_size > 0) val_i = tensor.int_val()[i]; + if (val_size > 0) val_i = static_cast(tensor.int_val()[i]); break; case DT_INT64: val_size = tensor.int64_val_size(); - if (val_size > 0) val_i = tensor.int64_val()[i]; + if (val_size > 0) val_i = static_cast(tensor.int64_val()[i]); break; case DT_FLOAT: val_size = tensor.float_val_size(); - if (val_size > 0) val_i = tensor.float_val()[i]; + if (val_size > 0) val_i = static_cast(tensor.float_val()[i]); break; case DT_BOOL: val_size = tensor.bool_val_size(); - if (val_size > 0) val_i = tensor.bool_val()[i]; + if (val_size > 0) val_i = static_cast(tensor.bool_val()[i]); break; case DT_DOUBLE: val_size = tensor.double_val_size(); - if (val_size > 0) val_i = tensor.double_val()[i]; + if (val_size > 0) val_i = static_cast(tensor.double_val()[i]); break; default: NGRAPH_VLOG(0) From ed72cc063800112fb7c636d1156b80bb6ead7a36 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Thu, 11 Jun 2020 15:01:52 -0700 Subject: [PATCH 3/5] Format --- ngraph_bridge/ngraph_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h index 3e432135f..24fdd3d45 100644 --- a/ngraph_bridge/ngraph_utils.h +++ b/ngraph_bridge/ngraph_utils.h @@ -103,7 +103,7 @@ Status ValuesFromConstNode(const NodeDef& node, } if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 || - node.attr().at("dtype").type() != DataTypeToEnum::value) { + node.attr().at("dtype").type() != DataTypeToEnum::value) { std::stringstream ss; ss << "Invalid data type defined for Const. Defined: " << node.attr().at("dtype").type(); From 9ba5778a6f3e165e9ae43a09d5a99beb92b95ecf Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Thu, 11 Jun 2020 17:20:00 -0700 Subject: [PATCH 4/5] Treat DT_HALF as a float before converting it to ng::float16 --- ngraph_bridge/ngraph_utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h index 3e432135f..9b2869a3c 100644 --- a/ngraph_bridge/ngraph_utils.h +++ b/ngraph_bridge/ngraph_utils.h @@ -102,8 +102,9 @@ Status ValuesFromConstNode(const NodeDef& node, return errors::InvalidArgument("Node not a Const"); } + auto dt = node.attr().at("dtype").type(); if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 || - node.attr().at("dtype").type() != DataTypeToEnum::value) { + (dt != DT_HALF && dt != DataTypeToEnum::value)) { std::stringstream ss; ss << "Invalid data type defined for Const. Defined: " << node.attr().at("dtype").type(); From fdc5a48f66d80acf21d246f508e88d6202d85695 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Thu, 11 Jun 2020 17:24:20 -0700 Subject: [PATCH 5/5] Code format --- ngraph_bridge/ngraph_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h index 9b2869a3c..2ca30ad65 100644 --- a/ngraph_bridge/ngraph_utils.h +++ b/ngraph_bridge/ngraph_utils.h @@ -104,7 +104,7 @@ Status ValuesFromConstNode(const NodeDef& node, auto dt = node.attr().at("dtype").type(); if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 || - (dt != DT_HALF && dt != DataTypeToEnum::value)) { + (dt != DT_HALF && dt != DataTypeToEnum::value)) { std::stringstream ss; ss << "Invalid data type defined for Const. Defined: " << node.attr().at("dtype").type();