From dcae2bdd2f8b0b509593dce76bdc66d0a0b635c2 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 4 Dec 2020 17:56:55 -0800 Subject: [PATCH 1/4] Add ones op Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 16 +++- .../java/org/tensorflow/op/core/Ones.java | 77 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 69ebbfc271b..0f685fe899d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -128,6 +128,7 @@ import org.tensorflow.op.core.NextIteration; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Ones; import org.tensorflow.op.core.OnesLike; import org.tensorflow.op.core.OrderedMapClear; import org.tensorflow.op.core.OrderedMapIncompleteSize; @@ -3426,6 +3427,19 @@ public OneHot oneHot(Operand indices, return OneHot.create(scope, indices, depth, onValue, offValue, options); } + /** + * Creates a one valued tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype. Can not be TString. + * @return a constant tensor initialized with zeros + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. + */ + public Ones ones(Operand dims, DataType type) { + return Ones.create(scope, dims, type); + } + /** * Returns a tensor of ones with the same shape and type as x. * @@ -7726,7 +7740,7 @@ public Ops withName(String opName) { } /** - * Returns an API that uses the provided DeviceSpec for an op. + * Returns an API that places the created operations on the device(s) matching the provided spec. * * @see {@link Scope#withDevice(DeviceSpec)} */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java new file mode 100644 index 00000000000..3b7f5634e43 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * An operator creating a constant initialized with ones of the shape given by `dims`. + * + *

For example, the following expression + *

{@code tf.ones(tf.constant(shape), TFloat32.DTYPE)
+ * is the equivalent of + *
{@code tf.fill(tf.constant(shape), tf.constant(1.0f))
+ * + * @param constant type + */ +@Operator +public final class Ones implements Op, Operand { + + /** + * Creates a one valued tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype. Can not be TString. + * @return a constant tensor initialized with zeros + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. + */ + @Endpoint + public static Ones create(Scope scope, Operand dims, DataType type) { + Scope zerosScope = scope.withSubScope("Ones"); + if (type == TString.DTYPE) { + throw new IllegalArgumentException("Can't create Ones of String DataType"); + } + Operand one = Cast.create(zerosScope.withName("One"), Constant.scalarOf(zerosScope, 1), type); + return new Ones<>(Fill.create(zerosScope.withName("Fill"), dims, one)); + } + + @Override + public Operation op() { + return fill.op(); + } + + @Override + public Output asOutput() { + return fill.asOutput(); + } + + private final Fill fill; + + private Ones(Fill fill) { + this.fill = fill; + } +} From 612fed86d89da9527ee33e1baf61d97687c9eeb5 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 4 Dec 2020 17:59:35 -0800 Subject: [PATCH 2/4] missed a zeroes Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/op/core/Ones.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java index 3b7f5634e43..75f41fb4efc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java @@ -46,17 +46,17 @@ public final class Ones implements Op, Operand { * @param scope is a scope used to add the underlying operation * @param dims a 1-D operand that represents the shape of the output tensor * @param type the output tensor datatype. Can not be TString. - * @return a constant tensor initialized with zeros + * @return a constant tensor initialized with ones * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. */ @Endpoint public static Ones create(Scope scope, Operand dims, DataType type) { - Scope zerosScope = scope.withSubScope("Ones"); + Scope onesScope = scope.withSubScope("Ones"); if (type == TString.DTYPE) { throw new IllegalArgumentException("Can't create Ones of String DataType"); } - Operand one = Cast.create(zerosScope.withName("One"), Constant.scalarOf(zerosScope, 1), type); - return new Ones<>(Fill.create(zerosScope.withName("Fill"), dims, one)); + Operand one = Cast.create(onesScope.withName("One"), Constant.scalarOf(onesScope, 1), type); + return new Ones<>(Fill.create(onesScope.withName("Fill"), dims, one)); } @Override From 099d7e87e856f74bcae4808a4f569d2e22d32d53 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 4 Dec 2020 18:01:11 -0800 Subject: [PATCH 3/4] rerun codegen Signed-off-by: Ryan Nett --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 0f685fe899d..59ab4d05f73 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -3433,7 +3433,7 @@ public OneHot oneHot(Operand indices, * @param scope is a scope used to add the underlying operation * @param dims a 1-D operand that represents the shape of the output tensor * @param type the output tensor datatype. Can not be TString. - * @return a constant tensor initialized with zeros + * @return a constant tensor initialized with ones * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. */ public Ones ones(Operand dims, DataType type) { From ecbf7d5f0a9c8b1e0153d3e4df4eaaf56ff8fb3b Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 5 Dec 2020 14:28:52 -0800 Subject: [PATCH 4/4] Fix license and javadoc Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/op/core/Ones.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java index 75f41fb4efc..3af0846b441 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -31,9 +31,9 @@ * An operator creating a constant initialized with ones of the shape given by `dims`. * *

For example, the following expression - *

{@code tf.ones(tf.constant(shape), TFloat32.DTYPE)
+ *
{@code tf.ones(tf.constant(shape), TFloat32.DTYPE)}
* is the equivalent of - *
{@code tf.fill(tf.constant(shape), tf.constant(1.0f))
+ *
{@code tf.fill(tf.constant(shape), tf.constant(1.0f))}
* * @param constant type */