diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 69ebbfc271b..59ab4d05f73 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -128,6 +128,7 @@ import org.tensorflow.op.core.NextIteration; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Ones; import org.tensorflow.op.core.OnesLike; import org.tensorflow.op.core.OrderedMapClear; import org.tensorflow.op.core.OrderedMapIncompleteSize; @@ -3426,6 +3427,19 @@ public OneHot oneHot(Operand indices, return OneHot.create(scope, indices, depth, onValue, offValue, options); } + /** + * Creates a one valued tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype. Can not be TString. + * @return a constant tensor initialized with ones + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. + */ + public Ones ones(Operand dims, DataType type) { + return Ones.create(scope, dims, type); + } + /** * Returns a tensor of ones with the same shape and type as x. * @@ -7726,7 +7740,7 @@ public Ops withName(String opName) { } /** - * Returns an API that uses the provided DeviceSpec for an op. + * Returns an API that places the created operations on the device(s) matching the provided spec. * * @see {@link Scope#withDevice(DeviceSpec)} */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java new file mode 100644 index 00000000000..3af0846b441 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Ones.java @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +/** + * An operator creating a constant initialized with ones of the shape given by `dims`. + * + *

For example, the following expression + *

{@code tf.ones(tf.constant(shape), TFloat32.DTYPE)}
+ * is the equivalent of + *
{@code tf.fill(tf.constant(shape), tf.constant(1.0f))}
+ * + * @param constant type + */ +@Operator +public final class Ones implements Op, Operand { + + /** + * Creates a one valued tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype. Can not be TString. + * @return a constant tensor initialized with ones + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones. + */ + @Endpoint + public static Ones create(Scope scope, Operand dims, DataType type) { + Scope onesScope = scope.withSubScope("Ones"); + if (type == TString.DTYPE) { + throw new IllegalArgumentException("Can't create Ones of String DataType"); + } + Operand one = Cast.create(onesScope.withName("One"), Constant.scalarOf(onesScope, 1), type); + return new Ones<>(Fill.create(onesScope.withName("Fill"), dims, one)); + } + + @Override + public Operation op() { + return fill.op(); + } + + @Override + public Output asOutput() { + return fill.asOutput(); + } + + private final Fill fill; + + private Ones(Fill fill) { + this.fill = fill; + } +}