From 70bb8d6d93b90b5fae3fbd2f7a27fe7c244b79b8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 May 2022 01:41:28 -0400 Subject: [PATCH] fix grappler compilation error with TF 1.15 ~ 2.6 --- source/op/optimizer/parallel.cc | 13 ++++++++++--- source/op/optimizer/parallel.h | 5 +++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/source/op/optimizer/parallel.cc b/source/op/optimizer/parallel.cc index 55e120e989..317de4b11f 100644 --- a/source/op/optimizer/parallel.cc +++ b/source/op/optimizer/parallel.cc @@ -2,6 +2,13 @@ #include "tensorflow/core/public/version.h" #if TF_MAJOR_VERSION >= 2 || (TF_MAJOR_VERSION == 1 && TF_MINOR_VERSION >= 15) +#if TF_MAJOR_VERSION >= 2 && TF_MINOR_VERSION >= 7 +// breaking change in tf 2.7: Renaming of tensorflow::int64 to int_64_t +#define TF_INT64 int64_t +#else +#define TF_INT64 tensorflow::int64 +#endif + #include "parallel.h" #include "tensorflow/core/grappler/devices.h" @@ -34,10 +41,10 @@ bool FindProdForce(RemapperContext *ctx, int node_index) { return IsProdForce(*node_def); } -int64_t GetNThreads() { +TF_INT64 GetNThreads() { // the number of threads is based on the session... // For convenience, we use environment variable directly - int64_t tot = 1; + TF_INT64 tot = 1; Status status = ReadInt64FromEnvVar("TF_INTER_OP_PARALLELISM_THREADS", 1, &tot); if (!status.ok()) { @@ -55,7 +62,7 @@ Status ParallelProdForce(RemapperContext *ctx, int node_index, const NodeDef *ori_node = ctx->graph_view.GetNode(node_index)->node(); auto &src_attr = ori_node->attr(); - int64_t tot = GetNThreads(); + TF_INT64 tot = GetNThreads(); if (tot <= 1) return Status::OK(); diff --git a/source/op/optimizer/parallel.h b/source/op/optimizer/parallel.h index 7de9f0b7ea..efedf65da8 100644 --- a/source/op/optimizer/parallel.h +++ b/source/op/optimizer/parallel.h @@ -16,6 +16,11 @@ class DPParallel : public CustomGraphOptimizer { bool UsesFunctionLibrary() const override { return false; } Status Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) override; +#if TF_MAJOR_VERSION >= 2 && TF_MINOR_VERSION < 6 +// TF 3457a2b122e50b4d44ceaaed5a663d635e5c22df + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} +#endif }; #endif // DP_REMAPPER_H_ \ No newline at end of file