From 59f243f3edf46fde83687de1cb426608e77dea3d Mon Sep 17 00:00:00 2001 From: Shajan Dasan Date: Mon, 6 Jul 2020 22:25:33 -0700 Subject: [PATCH] Draft: Java API to use tf.function available on SavedModel. Python models that contain tf.function is inconvenient to be consumed by Java clients. This proposal provides an API to (a) Invoke a tf.function, given the signature name (b) Retrieve the node name in the graph corresponding to a tf.function --- .../java/org/tensorflow/SavedModelBundle.java | 128 ++++++++++++++ .../main/java/org/tensorflow/TfFunction.java | 157 ++++++++++++++++++ 2 files changed, 285 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 8f683a59d89..b9fbbd9dfd9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -20,6 +20,9 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerScope; @@ -32,6 +35,7 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.MetaGraphDef; import org.tensorflow.proto.framework.RunOptions; +import org.tensorflow.proto.framework.SignatureDef; /** * SavedModelBundle represents a model loaded from storage. @@ -94,6 +98,101 @@ private Loader(String exportDir) { private RunOptions runOptions = null; } + /** + * SignatureToNodeName finds the node names in the {@link Graph} corresponding to the + * input / output parameters of a tf.function + */ + public static final class SignatureToNodeName { + + public SignatureToNodeName(SavedModelBundle savedModelBundle) { + loadSignatures(savedModelBundle); + } + + /** + * Given a tf.function signature name, find the node names corresponding + * to the input arguments + * + * @param functionSignatureName tf.function signature name + * @return a map from input arguments to node names in the {@link Graph} + */ + public Map inputNameToNode(String functionSignatureName) { + NameContainer nc = this.functionMap.get(functionSignatureName); + return (nc == null) ? null : nc.inputNameToNode(); + } + + /** + * Given a tf.function signature name, find the node names corresponding + * to the output arguments + * + * @param functionSignatureName tf.function signature name + * @return a map from output arguments to node names in the {@link Graph} + */ + public Map outputNameToNode(String functionSignatureName) { + NameContainer nc = this.functionMap.get(functionSignatureName); + return (nc == null) ? null : nc.outputNameToNode(); + } + + /** + * Given a tf.function signature name, find the method name + */ + public String methodName(String functionSignatureName) { + NameContainer nc = this.functionMap.get(functionSignatureName); + return (nc == null) ? null : nc.methodName(); + } + + private void loadSignatures(SavedModelBundle savedModelBundle) { + MetaGraphDef metaGraph = savedModelBundle.metaGraphDef(); + Map signatureMap = metaGraph.getSignatureDefMap(); + + // A saved model can contain multiple SignatureDef + for (Map.Entry entry : signatureMap.entrySet()) { + NameContainer nc = new NameContainer(entry.getValue()); + this.functionMap.put(entry.getKey(), nc); + } + } + + private Map functionMap = new HashMap<>(); + + private static final class NameContainer { + NameContainer(SignatureDef sd) { + this.inputNameToNodeName = sd.getInputsMap() + .entrySet() + .stream() + .collect(Collectors.toMap( + e -> e.getKey(), + e -> e.getValue().getName() + )); + + this.outputNameToNodeName = sd.getOutputsMap() + .entrySet() + .stream() + .collect(Collectors.toMap( + e -> e.getKey(), + e -> e.getValue().getName() + )); + + this.method = sd.getMethodName(); + } + + public Map inputNameToNode() { + return this.inputNameToNodeName; + } + + public Map outputNameToNode() { + return this.outputNameToNodeName; + } + + public String methodName() { + return this.method; + } + + private Map inputNameToNodeName; + private Map outputNameToNodeName; + private String method; + } + } + /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model @@ -148,6 +247,34 @@ public Session session() { return session; } + /** + * Returns the {@link SignatureToNodeName} translator for the model. + * + * @return SignatureToNodeName translator + */ + public SignatureToNodeName getSignatureToNodeName() { + if (this.sigToNodeName == null) { + // no need to lock, ok to create multiple instances + this.sigToNodeName = new SignatureToNodeName(this); + } + return this.sigToNodeName; + } + + /** + * Return a {@link TfFunction} corresponding to the function signature. + * + *
{@code
+   * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+   * Map> outputTensorMap = myFunction.call(inputTensorMap);
+   * }
+ * + * @param functionSignatureName name of the {@code SignatureDef} in the saved model. + * @return TfFunction object that can be used to make calls to the tf.function + */ + public TfFunction function(String functionSignatureName) { + return new TfFunction(functionSignatureName, this.getSignatureToNodeName(), this.session); + } + /** * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model * bundle. @@ -161,6 +288,7 @@ public void close() { private final Graph graph; private final Session session; private final MetaGraphDef metaGraphDef; + private SignatureToNodeName sigToNodeName; private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef) { this.graph = graph; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java new file mode 100644 index 00000000000..5dc5a128898 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TfFunction.java @@ -0,0 +1,157 @@ +/* + * Copyright 2020 The TensorFlow Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow; + +import com.google.protobuf.InvalidProtocolBufferException; + +import java.util.List; +import java.util.ListIterator; +import java.util.HashMap; +import java.util.Map; + +/** + * Invoke
tf.function + * defined in a {@link SavedModelBundle}. + * + *
{@code
+ * TfFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
+ * Map> outputTensorMap = myFunction.call(inputTensorMap);
+ * }
+ * + */ +public class TfFunction { + + public TfFunction( + String functionSignatureName, + SavedModelBundle.SignatureToNodeName nameToNode, Session session) { + this.nameToNode = nameToNode; + this.session = session; + this.functionSignatureName = functionSignatureName; + } + + /** + * Invokes a tf.function. + * Caller is responsible for closing all Tensors. + * + * @param arguments map of input tensors + * @return map of output tensors + */ + public Map> call( + Map> arguments) throws IllegalArgumentException { + + Session.Runner runner = this.session.runner(); + + Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); + + if (inputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%s] is missing input", this.functionSignatureName)); + } + + // Join arguments.key, inputToNodeName.key + for (Map.Entry entry: inputToNode.entrySet()) { + String argName = entry.getKey(); + Tensor tensor = arguments.get(argName); + + if (tensor == null) { + throw new IllegalArgumentException(String.format("Missing argument [%s]", argName)); + } + + // Node name in the tensorflow graph, corresponding to the tf.function argument + runner = runner.feed(entry.getValue(), tensor); + } + + Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); + if (outputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%] is missing output", this.functionSignatureName)); + } + + for (String nodeName: outputToNode.values()) { + // Node names corresponding to the return value + runner = runner.fetch(nodeName); + } + + List> resultTensors = runner.run(); + ListIterator> resultTensorIter = resultTensors.listIterator(); + + Map> returnMap = new HashMap>(); + + // Use the output names as present in the signature definition + for (String nodeName: outputToNode.keySet()) { + returnMap.put(nodeName, resultTensorIter.next()); + } + + return returnMap; + } + + /** + * Invokes a tf.function. + * Caller is responsible for closing all Tensors. + * + * Throws IllegalArgumentException if there are multiple input or output parameters defined + * in the tf.function + * + * @param tensor input tensor + * @return output tensor + */ + public Tensor call(Tensor tensor) throws IllegalArgumentException { + Session.Runner runner = this.session.runner(); + + Map inputToNode = this.nameToNode.inputNameToNode(this.functionSignatureName); + + if (inputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%s] is missing input", this.functionSignatureName)); + } + + if (inputToNode.size() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] requires multiple inputs", this.functionSignatureName)); + } + + // Feed the single argument + for (Map.Entry entry: inputToNode.entrySet()) { + // Node name in the tensorflow graph, corresponding to the tf.function argument + runner = runner.feed(entry.getValue(), tensor); + } + + Map outputToNode = this.nameToNode.outputNameToNode(this.functionSignatureName); + if (outputToNode == null) { + throw new IllegalArgumentException( + String.format("Function [%] is missing output", this.functionSignatureName)); + } + + if (outputToNode.size() != 1) { + throw new IllegalArgumentException( + String.format("Function [%s] has multiple outputs", this.functionSignatureName)); + } + + // Fetch the single return tensor + for (String nodeName: outputToNode.values()) { + // Node names corresponding to the return value + runner = runner.fetch(nodeName); + } + + List> resultTensors = runner.run(); + + return resultTensors.get(0); + } + + private final Session session; + private final SavedModelBundle.SignatureToNodeName nameToNode; + private final String functionSignatureName; +}