1616package org .tensorflow ;
1717
1818import static org .tensorflow .Graph .resolveOutputs ;
19- import static org .tensorflow .internal .c_api .global .tensorflow .TF_OperationGetAttrType ;
19+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_CloseSession ;
20+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_DeleteSession ;
21+ import static org .tensorflow .internal .c_api .global .tensorflow .TF_NewSession ;
2022import static org .tensorflow .internal .c_api .global .tensorflow .TF_SessionRun ;
2123import static org .tensorflow .internal .c_api .global .tensorflow .TF_SetConfig ;
2224
3638import org .tensorflow .internal .c_api .TF_SessionOptions ;
3739import org .tensorflow .internal .c_api .TF_Status ;
3840import org .tensorflow .internal .c_api .TF_Tensor ;
39- import org .tensorflow .internal .types .registry .TensorTypeRegistry ;
4041import org .tensorflow .op .Op ;
41- import org .tensorflow .op .Ops ;
42- import org .tensorflow .op .core .ReadVariableOp ;
4342import org .tensorflow .proto .framework .ConfigProto ;
44- import org .tensorflow .proto .framework .DataType ;
4543import org .tensorflow .proto .framework .RunMetadata ;
4644import org .tensorflow .proto .framework .RunOptions ;
4745import org .tensorflow .proto .util .SaverDef ;
@@ -194,11 +192,6 @@ public Runner feed(String operation, int index, Tensor t) {
194192 * @return this session runner
195193 */
196194 public Runner feed (Operand <?> operand , Tensor t ) {
197- if (operand .env () != graph ) {
198- throw new IllegalStateException ("Can't feed value for operand " + operand + ", it is from " +
199- (operand .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
200- }
201-
202195 inputs .add (operand .asOutput ());
203196 inputTensors .add (t );
204197 return this ;
@@ -207,8 +200,6 @@ public Runner feed(Operand<?> operand, Tensor t) {
207200 /**
208201 * Make {@link #run()} return the output of {@code operation}.
209202 *
210- * If the output is a resource variable, will fetch the value.
211- *
212203 * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
213204 * fetch(operation, 0)}, or it is a string of the form
214205 * <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -224,8 +215,6 @@ public Runner fetch(String operation) {
224215 /**
225216 * Make {@link #run()} return the {@code index}-th output of {@code operation}.
226217 *
227- * If the output is a resource variable, will fetch the value.
228- *
229218 * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
230219 * one to return.
231220 *
@@ -236,61 +225,24 @@ public Runner fetch(String operation) {
236225 */
237226 public Runner fetch (String operation , int index ) {
238227 Operation op = graph .operationOrThrow (operation );
239- return fetch (op .output (index ));
228+ outputs .add (op .output (index ));
229+ return this ;
240230 }
241231
242232 /**
243233 * Makes {@link #run()} return the Tensor referred to by {@code output}.
244234 *
245- * If {@code output} is a resource variable, will fetch the value.
246- *
247235 * @param output the node to fetch the tensor from
248236 * @return this session runner
249237 */
250238 public Runner fetch (Output <?> output ) {
251- if (output .env () != graph ) {
252- throw new IllegalStateException ("Can't fetch output " + output + ", it is from " +
253- (output .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
254- }
255-
256- if (output .dataType () == DataType .DT_RESOURCE ) {
257- int [] rawDt = new int [1 ];
258-
259- GraphOperation graphOp = (GraphOperation ) output .op ();
260-
261- try (PointerScope scope = new PointerScope ()) {
262- TF_Status status = TF_Status .newStatus ();
263- TF_OperationGetAttrType (graphOp .getUnsafeNativeHandle (), "dtype" , rawDt , status );
264- status .throwExceptionIfNotOK ();
265- }
266-
267- DataType valueDt = DataType .forNumber (rawDt [0 ]);
268-
269- Operand <?> read = null ;
270- for (GraphOperation op : graphOp .consumers ()) {
271- if (op .dtype (0 ) == valueDt && op .type ().equals (ReadVariableOp .OP_NAME )) {
272- read = op .output (0 );
273- break ;
274- }
275- }
276-
277- if (read == null ) {
278- read = Ops .create (graph ).withSubScope ("session_reads" ).withName (output .op ().name () + "_read" )
279- .readVariableOp (output , TensorTypeRegistry .find (valueDt ).type ());
280- }
281-
282- outputs .add (read .asOutput ());
283- } else {
284- outputs .add (output );
285- }
239+ outputs .add (output );
286240 return this ;
287241 }
288242
289243 /**
290244 * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
291245 *
292- * If {@code operand} is a resource variable, will fetch the value.
293- *
294246 * @param operand the node to fetch the tensor from, as an operand
295247 * @return this session runner
296248 */
@@ -306,7 +258,9 @@ public Runner fetch(Operand<?> operand) {
306258 * @throws IllegalArgumentException if no operation exists with the provided name
307259 */
308260 public Runner addTarget (String operation ) {
309- return addTarget (graph .operationOrThrow (operation ));
261+ GraphOperation op = graph .operationOrThrow (operation );
262+ targets .add (op );
263+ return this ;
310264 }
311265
312266 /**
@@ -315,12 +269,13 @@ public Runner addTarget(String operation) {
315269 * @param operation the operation to execute
316270 * @return this session runner
317271 * @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
318- * @throws IllegalStateException if the operation is not from the session's graph.
319272 */
320273 public Runner addTarget (Operation operation ) {
321- if (operation .env () != graph ) {
322- throw new IllegalStateException ("Can't target operation " + operation + ", it is from " +
323- (operation .env ().isEager () ? "an eager session" : "a different graph" ) + "." );
274+ if (!(operation instanceof GraphOperation )) {
275+ throw new IllegalArgumentException (
276+ "Operation of type "
277+ + operation .getClass ().getName ()
278+ + " is not supported in graph sessions" );
324279 }
325280 targets .add ((GraphOperation ) operation );
326281 return this ;
@@ -639,12 +594,12 @@ private static void delete(TF_Session handle) {
639594 *
640595 * @param handle to the C API TF_Session object (Session.nativeHandle)
641596 * @param runOptions A RunOptions protocol buffer, or null
597+ * @param inputOpHandles (see inputOpIndices)
598+ * @param inputOpIndices (see inputTensorHandles)
642599 * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
643600 * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
644601 * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
645602 * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
646- * @param inputOpHandles (see inputOpIndices)
647- * @param inputOpIndices (see inputTensorHandles)
648603 * @param outputOpHandles (see outputOpIndices)
649604 * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
650605 * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==
0 commit comments