From ccc7820d54f875530f29c0f32fb853633da499e3 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 15 Jan 2021 10:01:35 -0500 Subject: [PATCH 01/14] Initial Checkin --- .../tensorflow/framework/regularizers/L1.java | 48 +++++ .../framework/regularizers/L1L2.java | 129 ++++++++++++ .../framework/regularizers/L1_L2.java | 60 ++++++ .../tensorflow/framework/regularizers/L2.java | 48 +++++ .../framework/regularizers/Regularizer.java | 85 ++++++++ .../regularizers/RegularizerLoss.java | 68 +++++++ .../framework/regularizers/CommonTest.java | 63 ++++++ .../framework/regularizers/L1L2Test.java | 110 +++++++++++ .../framework/regularizers/L1Test.java | 77 ++++++++ .../framework/regularizers/L1_L2Test.java | 116 +++++++++++ .../framework/regularizers/L2Test.java | 79 ++++++++ .../org/tensorflow/framework/utils/ND.java | 183 +++++++++++++++++- 12 files changed, 1062 insertions(+), 4 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java new file mode 100644 index 00000000000..8d3469dca98 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies a L1 regularization penalty. + * + *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) + * + * @param the data type for the weights + */ +public class L1 extends L1L2 { + + /** + * Create a regularizer that applies an L1 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L1(Ops tf, Class type) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + } + + /** + * Create a regularizer that applies an L1 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l1 the L1 regularization penalty + */ + public L1(Ops tf, float l1, Class type) { + super(tf, l1, null, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java new file mode 100644 index 00000000000..565384f1fa5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies both L1 and L2 regularization penalties. + * + *

+ * <p>The L1 regularization penalty is computed as:
+ *
+ * <pre>loss = l1 * reduceSum(abs(x))</pre>
+ *
+ * <p>The L2 regularization penalty is computed as:
+ *
+ * <pre>loss = l2 * reduceSum(square(x))</pre>
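+ *
+ * <p>Illustrative usage (a sketch only; assumes {@code tf} is a TensorFlow {@code Ops} instance
+ * and {@code weights} is an {@code Operand<TFloat32>} holding the weights to penalize):
+ *
+ * <pre>{@code
+ * L1L2<TFloat32> regularizer = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);
+ * Operand<TFloat32> penalty = regularizer.call(weights);
+ * }</pre>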
+ * + * @param the data type for the weights + */ +public class L1L2 extends Regularizer { + + private final Float l1; + private final Float l2; + + /** + * Creates an L1L2 regularizer with no l1 or l2 penalty with default penal + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights + */ + public L1L2(Ops tf, Class type) { + this(tf, null, null, type); + } + + /** + * Creates an L1L2 regularizer + * + * @param tf the TensorFlow Ops + * @param l1 L1 regularization factor, if null it is set to 0. + * @param l2 L2 regularization factor, if null it is set to 0. + * @param type the data type for the weights + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link }NaN or is + * infinite. + */ + public L1L2(Ops tf, Float l1, Float l2, Class type) { + super(tf, type); + if (l1 != null) { + if (l1.isNaN() || l1.isInfinite()) { + throw new IllegalArgumentException( + String.format( + "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l1)); + } + this.l1 = l1; + } else { + this.l1 = 0f; + } + if (l2 != null) { + if (l2.isNaN() || l2.isInfinite()) { + throw new IllegalArgumentException( + String.format( + "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l2)); + } + this.l2 = l2; + } else { + this.l2 = 0f; + } + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + Ops tf = getTF(); + if (this.getL1() == null && this.getL2() == null) { + return tf.dtypes.cast(tf.constant(0), input.type()); + } + Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); + + if (this.getL1() != null && this.getL1() != 0.f) { + Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); + Operand abs = tf.math.abs(input); + Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); + regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); + } + + if (this.getL2() != null && this.getL2() != 0.f) { + Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); + Operand sqr = tf.math.abs(input); + Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); + regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); + } + + return regularization; + } + + /** + * Gets the L1 regularization factor + * + * @return the L1 regularization factor + */ + public Float getL1() { + return l1; + } + + /** + * Gets the L2 regularization factor + * + * @return the L2 regularization factor + */ + public Float getL2() { + return l2; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java new file mode 100644 index 00000000000..0089395acc0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies both L1 and L2 regularization penalties. + * + *

+ * <p>The L1 regularization penalty is computed as:
+ *
+ * <pre>loss = l1 * reduceSum(abs(x))</pre>
+ *
+ * <p>The L2 regularization penalty is computed as:
+ *
+ * <pre>loss = l2 * reduceSum(square(x))</pre>
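+ *
+ * <p>Illustrative usage (a sketch only; assumes {@code tf} is a TensorFlow {@code Ops} instance
+ * and {@code weights} is an {@code Operand<TFloat32>}):
+ *
+ * <pre>{@code
+ * // both penalties default to DEFAULT_REGULARIZATION_PENALTY (0.01f)
+ * L1_L2<TFloat32> regularizer = new L1_L2<>(tf, TFloat32.class);
+ * Operand<TFloat32> penalty = regularizer.call(weights);
+ * }</pre>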
+ * + *

The difference between this class and the {@link L1L2} is use of the default regularization + * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. + * + * @param the data type for the weights + */ +public class L1_L2 extends L1L2 { + + /** + * Create a regularizer that applies an L1 and l2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L1_L2(Ops tf, Class type) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type); + } + + /** + * Create a regularizer that applies an L1 and l2 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l1 the L1 regularization penalty + */ + public L1_L2(Ops tf, Float l1, Float l2, Class type) { + super(tf, + l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1, + l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2, + type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java new file mode 100644 index 00000000000..3ce0f84581c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies a L2 regularization penalty. + * + *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) + * + * @param the data type for the weights + */ +public class L2 extends L1L2 { + + /** + * Create a regularizer that applies an L2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L2(Ops tf, Class type) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + } + + /** + * Create a regularizer that applies an L1 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l2 the L2 regularization penalty + */ + public L2(Ops tf, float l2, Class type) { + super(tf, null, l2, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java new file mode 100644 index 00000000000..aeacee28025 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** @param the data type of the result */ +public abstract class Regularizer { + + public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; + + private final Ops tf; + private final String name; + protected Class type; + + /** + * Create a Regularizer + * + * @param tf the TensorFlow ops. + */ + protected Regularizer(Ops tf, Class type) { + this(tf, null, type); + } + /** + * Create a Regularizer + * + * @param tf the TensorFlow ops. + */ + protected Regularizer(Ops tf, String name, Class type) { + this.tf = tf; + this.type = type; + this.name = name == null ? this.getClass().getSimpleName() : name; + } + + /** + * Returns this Regularizer as a Loss This is a convenience tp regularize a loss. Only + * sampleWeights are applied to the regularizer. + * + * @return this Regularizer as a Loss + */ + public Loss asLoss() { + return new RegularizerLoss(this.tf, this); + } + + /** + * Computes a regularization penalty from an input. 
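+   * For example, an {@link L1L2} regularizer computes {@code l1 * reduceSum(abs(input)) + l2 *
+   * reduceSum(square(input))}, as described in that class's documentation.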
+ * + * @param input teh weighted input + * @return the result of computing the regularization penalty + */ + public abstract Operand call(Operand input); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the name for this regularizer + * + * @return the name for this regularizer + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java new file mode 100644 index 00000000000..bcde00fe959 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * A Regularizer call wrapped as a Loss instance + * + *

This class facilitates using a regularizing as a loss, only sampleWeights are + * regularized. + * + * @param the datatype for the weights type + */ +class RegularizerLoss extends Loss { + + private final Regularizer regularizer; + private final Class type; + /** + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT} + * + * @param tf the TensorFlow Ops + */ + public RegularizerLoss(Ops tf, Regularizer regularizer) { + this(tf, null, regularizer); + } + + /** + * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * + * @param tf the TensorFlow Ops + * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + */ + public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { + super(tf, name); + this.regularizer = regularizer; + this.type = regularizer.type; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + if (sampleWeights == null) { + throw new IllegalArgumentException("sampleWeights cannot be null"); + } + Operand result = regularizer.call(cast(getTF(), sampleWeights, type)); + return cast(tf, result, sampleWeights.type()); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java new file mode 100644 index 00000000000..63ecc155fd1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.framework.utils.ND; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.StdArrays; + +public class CommonTest { + + protected float regularizeL1L2(float[][] w, float l1, float l2) { + return regularizeL1(w, l1) + regularizeL2(w, l2); + } + + protected float regularizeL1(float[][] w, float l1) { + FloatNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.abs(fa); + FloatNdArray sum = ND.sum(fa); + FloatNdArray mul = ND.mul(sum, l1); + return mul.getFloat(); + } + + protected float regularizeL2(float[][] w, float l2) { + FloatNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.square(fa); + FloatNdArray sum = ND.sum(fa); + FloatNdArray mul = ND.mul(sum, l2); + return mul.getFloat(); + } + + protected double regularizeL1L2(double[][] w, float l1, float l2) { + return regularizeL1(w, l1) + regularizeL2(w, l2); + } + + protected double regularizeL1(double[][] w, float l1) { + DoubleNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.abs(fa); + DoubleNdArray sum = ND.sum(fa); + DoubleNdArray mul = ND.mul(sum, l1); + return mul.getDouble(); + } + + protected double regularizeL2(double[][] w, float l2) { + DoubleNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.square(fa); + DoubleNdArray sum = ND.sum(fa); + DoubleNdArray mul = ND.mul(sum, l2); + return mul.getDouble(); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java new file mode 100644 index 00000000000..178468f42e8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -0,0 +1,110 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.2f, 0.3f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.3f, instance.getL2()); + + instance = new L1L2<>(tf, null, null, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1L2<>(tf, 0.5f, null, TFloat32.class); + assertEquals(0.5f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1L2<>(tf, null, 0.5f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.5f, instance.getL2()); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCall() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, TFloat32.class); + Operand result = instance.call(tf.constant(555f)); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. 
*/ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, TFloat32.class); + Operand weights = + tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1L2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.09f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.01f, null, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, 0.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, null, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java new file mode 100644 index 00000000000..f2031a57834 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -0,0 +1,77 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.*; + +class L1Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.2f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1<>(tf, 0f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1<>(tf, TFloat32.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(0.f, instance.getL2()); + } + } + + /** Test of call method, of class L1L2. 
*/ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1_2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1(w, 0.02f); + session.evaluate(expected, result); + } + } +} \ No newline at end of file diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java new file mode 100644 index 00000000000..d35179cdd62 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java @@ -0,0 +1,116 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1_L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0.2f, 0.3f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.3f, instance.getL2()); + + instance = new L1_L2<>(tf, 0.5f, 0f, TFloat32.class); + assertEquals(0.5f, instance.getL1()); + assertEquals(0f, instance.getL2()); + + instance = new L1_L2<>(tf, 0f, 0.5f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.5f, instance.getL2()); + + instance = new L1_L2<>(tf, TFloat32.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + } + } + + /** Test of call method, of class L1_L2<>. 
*/ + @Test + public void testCallZero() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0f, 0f, TFloat32.class); + Operand result = instance.call(tf.constant(555f)); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = + regularizeL1L2( + w, + Regularizer.DEFAULT_REGULARIZATION_PENALTY, + Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallL1L2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0.01f, 0.02f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0.01f, 0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, 0.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1_L2<>. 
*/ + @Test + public void testCallL2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0f, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java new file mode 100644 index 00000000000..cbb019796f1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -0,0 +1,79 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, 0.2f, TFloat32.class); + assertEquals(0.2f, instance.getL2()); + assertEquals(0.f, instance.getL1()); + + instance = new L2<>(tf, 0f, TFloat32.class); + assertEquals(0.f, instance.getL2()); + assertEquals(0.f, instance.getL1()); + + instance = new L2<>(tf, TFloat32.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + assertEquals(0.f, instance.getL1()); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, 0.0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. 
*/ + @Test + public void testCallL1_2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 0503a41dfc2..694287d4970 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,10 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.*; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -120,6 +117,23 @@ public static FloatNdArray square(FloatNdArray a) { return result; } + /** + * Gets the square of an array. + * + * @param a the array + * @return the square of the array. + */ + public static DoubleNdArray square(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(v.getDouble() * v.getDouble(), idx); + }); + return result; + } + /** * Adds two arrays * @@ -284,6 +298,64 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { return mul(a, scalar); } + /** + * Multiply 2 arrays + * + * @param a the first array + * @param b the second array + * @return the resulting array from the muliply operation + */ + public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) { + if (!a.shape().equals(b.shape())) + throw new IllegalArgumentException( + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + boolean sameSize = a.shape().size() == b.shape().size(); + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + if (sameSize) { + result.setDouble(v.getDouble() * b.getDouble(idx), idx); + } else { + double value = v.getDouble() * b.getDouble(idx[0], 0L); + result.setDouble(value, idx); + } + }); + return result; + } + + /** + * Multiply an array with a scalar value + * + * @param a the array + * @param scalar the scalar value + * @return the resulting array from the Multiply operation + */ + public static DoubleNdArray mul(DoubleNdArray a, float scalar) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + if (a.shape().isScalar()) { + a.scalars().forEach(f -> result.setDouble(f.getDouble() * scalar)); + } else { + a.scalars().forEachIndexed((idx, f) -> result.setDouble(f.getDouble() * scalar, idx)); + } + + return result; + } + + /** + * Multiply a scalar value with an array + * + * @param scalar the scalar value + * @param a the array + * @return the resulting array from the Multiply operation + */ + public static DoubleNdArray mul(float scalar, DoubleNdArray a) { + return mul(a, scalar); + } + /** * Divide two arrays * @@ -556,6 +628,18 @@ 
public static FloatNdArray abs(FloatNdArray a) { return result; } + /** + * Get the absolute value of each member of the array + * + * @param a the array + * @return the array with the absolute value of each item. + */ + public static DoubleNdArray abs(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + a.scalars().forEachIndexed((idx, f) -> result.setDouble( Math.abs(f.getDouble()), idx)); + return result; + } + /** * Sum all elements of an array * @@ -647,6 +731,97 @@ public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) } } + /** + * Sum all elements of an array + * + * @param a the array + * @return an a array with one element containing the sum. + */ + public static DoubleNdArray sum(DoubleNdArray a) { + AtomicReference sum = new AtomicReference<>(0.); + a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + return NdArrays.scalarOf(sum.get()); + } + + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @return an a array the sum over the axis less the diemsnion + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis) { + return sum(a, axis, false); + } + + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. + * @return an a array the sum over the axis + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { + Shape shape = a.shape(); + int nDims = shape.numDimensions(); + int xis = nDims - 1 - axis; + long totalSize = shape.size(); + long axisSize = shape.size(xis); + final double[] sums = new double[(int) axisSize]; + + a.scalars() + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); + + if (keepDims) { + long[] newDims = shape.asArray(); + newDims[axis] = 1; + final AtomicInteger counter = new AtomicInteger(); + DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); + arrayK + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); + return arrayK; + } else { + return NdArrays.vectorOf(sums); + } + } + + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axes the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. 
+ * @return an a array the sum over the axis + */ + public static DoubleNdArray sum(DoubleNdArray a, Integer[] axes, boolean keepDims) { + Shape shape = a.shape(); + if (axes == null) { + DoubleNdArray result = sum(a); + if (keepDims) { + double scalar = result.getDouble(0); + long[] dims = {1, 1}; + Shape bShape = Shape.of(dims); + DoubleNdArray resultK = NdArrays.ofDoubles(bShape); + resultK.setDouble(scalar, 0, 0); + return resultK; + } + return result; + } else if (axes.length == 1) { + return sum(a, axes[0], keepDims); + } else { + // TODO + throw new UnsupportedOperationException("Multi Axis Not implemented Yet"); + } + } + /** * Calculate the l2 norm of the array * From 05ec6e8833526d4c15a778663e7aa39b8b5d68e9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 26 Jan 2021 10:18:48 -0500 Subject: [PATCH 02/14] Clean up JavaDoc Make sure Tests test TFloat32 1nd TFloat64 --- .../tensorflow/framework/regularizers/L1.java | 4 +- .../framework/regularizers/L1L2.java | 7 +- .../framework/regularizers/L1_L2.java | 6 +- .../tensorflow/framework/regularizers/L2.java | 5 +- .../framework/regularizers/Regularizer.java | 19 ++- .../regularizers/RegularizerLoss.java | 2 +- .../framework/regularizers/L1L2Test.java | 32 +++-- .../framework/regularizers/L1Test.java | 117 +++++++++--------- .../framework/regularizers/L1_L2Test.java | 30 +++-- .../framework/regularizers/L2Test.java | 15 +-- .../regularizers/RegularizerLossTest.java | 7 ++ 11 files changed, 142 insertions(+), 102 deletions(-) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 8d3469dca98..740338350e3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -18,7 +18,8 @@ import org.tensorflow.types.family.TNumber; /** - * A regularizer that applies a L1 regularization penalty. + * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) Regression, + * regularization penalty. * *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) * @@ -41,6 +42,7 @@ public L1(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param l1 the L1 regularization penalty + * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ public L1(Ops tf, float l1, Class type) { super(tf, l1, null, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 565384f1fa5..2908387b9d4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -30,6 +30,9 @@ * *

  * <pre>loss = l2 * reduceSum(square(x))</pre>
  *
+ * <p>
The difference between this class and the {@link L1_L2} is use of the default regularization + * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. + * * @param the data type for the weights */ public class L1L2 extends Regularizer { @@ -54,8 +57,8 @@ public L1L2(Ops tf, Class type) { * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. * @param type the data type for the weights - * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link }NaN or is - * infinite. + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} + * of {@link Float#isInfinite} */ public L1L2(Ops tf, Float l1, Float l2, Class type) { super(tf, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java index 0089395acc0..95eecc2dd5f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java @@ -36,7 +36,7 @@ public class L1_L2 extends L1L2 { /** - * Create a regularizer that applies an L1 and l2 regularization penalty of {@link + * Creates a regularizer that applies an L1 and l2 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * * @param tf the TensorFlow Ops @@ -46,10 +46,12 @@ public L1_L2(Ops tf, Class type) { } /** - * Create a regularizer that applies an L1 and l2 regularization penalty + * Creates a regularizer that applies an L1 and l2 regularization penalty * * @param tf the TensorFlow Ops * @param l1 the L1 regularization penalty + * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite. */ public L1_L2(Ops tf, Float l1, Float l2, Class type) { super(tf, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 3ce0f84581c..8298cd4aba5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -18,11 +18,11 @@ import org.tensorflow.types.family.TNumber; /** - * A regularizer that applies a L2 regularization penalty. + * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) * - * @param the data type for the weights + * @param the data type for the operands and result */ public class L2 extends L1L2 { @@ -41,6 +41,7 @@ public L2(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ public L2(Ops tf, float l2, Class type) { super(tf, null, l2, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index aeacee28025..906efee7f3d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -19,7 +19,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** @param the data type of the result */ +/** + * Base class for Regularizers + * + *

Regularizers allow you to apply penalties on layer parameters or layer activity during + * optimization. These penalties are summed into the loss function that the network optimizes. + * + * @param the data type of the operands and result + */ public abstract class Regularizer { public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; @@ -29,7 +36,7 @@ public abstract class Regularizer { protected Class type; /** - * Create a Regularizer + * Creates a Regularizer * * @param tf the TensorFlow ops. */ @@ -37,7 +44,7 @@ protected Regularizer(Ops tf, Class type) { this(tf, null, type); } /** - * Create a Regularizer + * Creates a Regularizer * * @param tf the TensorFlow ops. */ @@ -48,19 +55,19 @@ protected Regularizer(Ops tf, String name, Class type) { } /** - * Returns this Regularizer as a Loss This is a convenience tp regularize a loss. Only + * Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only * sampleWeights are applied to the regularizer. * * @return this Regularizer as a Loss */ public Loss asLoss() { - return new RegularizerLoss(this.tf, this); + return new RegularizerLoss<>(this.tf, this); } /** * Computes a regularization penalty from an input. * - * @param input teh weighted input + * @param input the weighted input * @return the result of computing the regularization penalty */ public abstract Operand call(Operand input); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index bcde00fe959..04414285d77 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -24,7 +24,7 @@ /** * A Regularizer call wrapped as a Loss instance * - *

This class facilitates using a regularizing as a loss, only sampleWeights are + *

This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. * * @param the datatype for the weights type diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 178468f42e8..0f3213ed6eb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -35,21 +35,19 @@ public void testCreate() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCall() { + public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); L1L2 instance = new L1L2<>(tf, TFloat32.class); Operand result = instance.call(tf.constant(555f)); - session.evaluate(0, result); + session.evaluate(0f, result); } } - /** Test of call method, of class L1L2. */ @Test - public void testCallNO() { + public void testCallL1L20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -61,9 +59,8 @@ public void testCallNO() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1L2() { + public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -77,9 +74,23 @@ public void testCallL1L2() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1() { + public void testCallL1L2TFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.09f); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL2Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -92,9 +103,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1L2. 
*/ @Test - public void testCallL2() { + public void testCallL1Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index f2031a57834..6d67bb44d3c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -7,71 +7,68 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; class L1Test extends CommonTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - @Test - public void testCreate() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.2f, TFloat32.class); - assertEquals(0.2f, instance.getL1()); - assertEquals(0.f, instance.getL2()); + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.2f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.f, instance.getL2()); - instance = new L1<>(tf, 0f, TFloat32.class); - assertEquals(0.f, instance.getL1()); - assertEquals(0.f, instance.getL2()); + instance = new L1<>(tf, 0f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); - instance = new L1<>(tf, TFloat32.class); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); - assertEquals(0.f, instance.getL2()); - } - } + instance = new L1<>(tf, TFloat32.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(0.f, instance.getL2()); + } + } - /** Test of call method, of class L1L2. */ - @Test - public void testCallNO() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.0f, TFloat32.class); - float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - session.evaluate(0, result); - } - } + @Test + public void testCallL10() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0f, result); + } + } - /** Test of call method, of class L1L2. 
*/ - @Test - public void testCallL1() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1 instance = new L1<>(tf, TFloat32.class); - float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); - session.evaluate(expected, result); - } - } + @Test + public void testCallL1TFloat32() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.evaluate(expected, result); + } + } - /** Test of call method, of class L1L2. */ - @Test - public void testCallL1_2() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.02f, TFloat64.class); - double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - double expected = regularizeL1(w, 0.02f); - session.evaluate(expected, result); - } - } -} \ No newline at end of file + @Test + public void testCallL1TFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1(w, 0.02f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java index d35179cdd62..5aeb5a5d9ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java @@ -35,7 +35,6 @@ public void testCreate() { } } - /** Test of call method, of class L1_L2<>. */ @Test public void testCallZero() { for (TestSession.Mode tfMode : tfModes) @@ -47,9 +46,8 @@ public void testCallZero() { } } - /** Test of call method, of class L1_L2<>. */ @Test - public void testCallNO() { + public void testCallDefaultTFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -67,7 +65,25 @@ public void testCallNO() { } } - /** Test of call method, of class L1_L2<>. 
*/ + @Test + public void testCallDefaultTFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, TFloat64.class); + double[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = + regularizeL1L2( + w, + Regularizer.DEFAULT_REGULARIZATION_PENALTY, + Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + @Test public void testCallL1L2() { for (TestSession.Mode tfMode : tfModes) @@ -83,9 +99,8 @@ public void testCallL1L2() { } } - /** Test of call method, of class L1_L2<>. */ @Test - public void testCallL1() { + public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -98,9 +113,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1_L2<>. */ @Test - public void testCallL2() { + public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index cbb019796f1..7f593a2dd14 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -25,15 +25,14 @@ public void testCreate() { assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2<>(tf, TFloat32.class); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); - assertEquals(0.f, instance.getL1()); + L2 instance64 = new L2<>(tf, TFloat64.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); + assertEquals(0.f, instance64.getL1()); } } - /** Test of call method, of class L1L2. */ @Test - public void testCallNO() { + public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -45,9 +44,8 @@ public void testCallNO() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1() { + public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -61,9 +59,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1L2. 
*/ @Test - public void testCallL1_2() { + public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java new file mode 100644 index 00000000000..e694d9409a0 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -0,0 +1,7 @@ +package org.tensorflow.framework.regularizers; + +import static org.junit.jupiter.api.Assertions.*; + +class RegularizerLossTest { + +} \ No newline at end of file From b446618fd29ae796575e78f115013a66621e9955 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 13 Feb 2021 08:03:03 -0500 Subject: [PATCH 03/14] Fix to match the lates version of losses.Loss --- .../tensorflow/framework/regularizers/RegularizerLoss.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index 04414285d77..0dc137ed747 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -57,8 +57,8 @@ public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { if (sampleWeights == null) { throw new IllegalArgumentException("sampleWeights cannot be null"); } From b5c7c78648053635aa464de029a7dee05c3f9dcd Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 13 Feb 2021 13:45:37 -0500 Subject: [PATCH 04/14] Updates based on comments from PR. Removed generic from Regularizer class and changed the call method to define the generic return based on the weights parameter. Added static method l1_l2() to L1L2 class. Fixed JavaDoc comments. 
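For reviewers, a rough sketch of how the reworked API is expected to be used. The eager-mode setup (Ops.create()) and the asTensor() result retrieval below are illustrative assumptions for this note only and are not part of the change; the class, constructor, and l1_l2() factory names come from this patch.

    import org.tensorflow.Operand;
    import org.tensorflow.framework.regularizers.L1L2;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class RegularizerUsageSketch {
      public static void main(String[] args) {
        Ops tf = Ops.create(); // eager mode, illustration only

        // The regularizer class is no longer generic; call() now takes its
        // result type from the weights operand.
        L1L2 instance = new L1L2(tf, 0.01f, 0.02f);
        Operand<TFloat32> weights =
            tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});
        Operand<TFloat32> penalty = instance.call(weights);

        // Convenience factory added in this patch: both factors default to
        // DEFAULT_REGULARIZATION_PENALTY (0.01f).
        L1L2 defaulted = L1L2.l1_l2(tf);
        Operand<TFloat32> defaultPenalty = defaulted.call(weights);

        // Assumes eager execution, so the penalties can be read back directly.
        System.out.println(penalty.asTensor().getFloat());
        System.out.println(defaultPenalty.asTensor().getFloat());
      }
    }

Dropping the class-level type parameter lets a single regularizer instance be reused across weight tensors of different floating-point types.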
--- .../tensorflow/framework/regularizers/L1.java | 17 +++-- .../framework/regularizers/L1L2.java | 38 ++++++------ .../framework/regularizers/L1_L2.java | 25 ++++---- .../tensorflow/framework/regularizers/L2.java | 13 ++-- .../framework/regularizers/Regularizer.java | 20 +++--- .../regularizers/RegularizerLoss.java | 20 +++--- .../framework/regularizers/L1L2Test.java | 20 +++--- .../framework/regularizers/L1Test.java | 12 ++-- .../framework/regularizers/L1_L2Test.java | 20 +++--- .../framework/regularizers/L2Test.java | 12 ++-- .../regularizers/RegularizerLossTest.java | 24 ++++++- .../org/tensorflow/framework/utils/ND.java | 62 +++++++++---------- 12 files changed, 145 insertions(+), 138 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 740338350e3..074e881c1cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -15,17 +15,14 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; /** - * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) Regression, - * regularization penalty. + * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) + * Regression, regularization penalty. * *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) - * - * @param the data type for the weights */ -public class L1 extends L1L2 { +public class L1 extends L1L2 { /** * Create a regularizer that applies an L1 regularization penalty of {@link @@ -33,8 +30,8 @@ public class L1 extends L1L2 { * * @param tf the TensorFlow Ops */ - public L1(Ops tf, Class type) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + public L1(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY); } /** @@ -44,7 +41,7 @@ public L1(Ops tf, Class type) { * @param l1 the L1 regularization penalty * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf, float l1, Class type) { - super(tf, l1, null, type); + public L1(Ops tf, float l1) { + super(tf, l1, null); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 2908387b9d4..89b407e0940 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -32,22 +32,19 @@ * *

The difference between this class and the {@link L1_L2} is use of the default regularization * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. - * - * @param the data type for the weights */ -public class L1L2 extends Regularizer { +public class L1L2 extends Regularizer { - private final Float l1; - private final Float l2; + private final float l1; + private final float l2; /** - * Creates an L1L2 regularizer with no l1 or l2 penalty with default penal + * Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty * * @param tf the TensorFlow Ops - * @param type the data type for the weights */ - public L1L2(Ops tf, Class type) { - this(tf, null, null, type); + public L1L2(Ops tf) { + this(tf, null, null); } /** @@ -56,12 +53,11 @@ public L1L2(Ops tf, Class type) { * @param tf the TensorFlow Ops * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. - * @param type the data type for the weights * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, Float l1, Float l2, Class type) { - super(tf, type); + public L1L2(Ops tf, Float l1, Float l2) { + super(tf); if (l1 != null) { if (l1.isNaN() || l1.isInfinite()) { throw new IllegalArgumentException( @@ -86,25 +82,29 @@ public L1L2(Ops tf, Float l1, Float l2, Class type) { } } + public static L1L2 l1_l2(Ops tf) { + return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + } + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Ops tf = getTF(); - if (this.getL1() == null && this.getL2() == null) { + if (this.getL1() == 0f && this.getL2() == 0f) { return tf.dtypes.cast(tf.constant(0), input.type()); } Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); - if (this.getL1() != null && this.getL1() != 0.f) { + if (this.getL1() != 0.f) { Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); Operand abs = tf.math.abs(input); Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); } - if (this.getL2() != null && this.getL2() != 0.f) { + if (this.getL2() != 0.f) { Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); - Operand sqr = tf.math.abs(input); + Operand sqr = tf.math.square(input); Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); } @@ -117,7 +117,7 @@ public Operand call(Operand input) { * * @return the L1 regularization factor */ - public Float getL1() { + public float getL1() { return l1; } @@ -126,7 +126,7 @@ public Float getL1() { * * @return the L2 regularization factor */ - public Float getL2() { + public float getL2() { return l2; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java index 95eecc2dd5f..44e04ad4d94 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; /** * A regularizer that applies both 
L1 and L2 regularization penalties. @@ -30,10 +29,8 @@ * *

The difference between this class and the {@link L1L2} is use of the default regularization * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. - * - * @param the data type for the weights */ -public class L1_L2 extends L1L2 { +public class L1_L2 extends L1L2 { /** * Creates a regularizer that applies an L1 and l2 regularization penalty of {@link @@ -41,22 +38,24 @@ public class L1_L2 extends L1L2 { * * @param tf the TensorFlow Ops */ - public L1_L2(Ops tf, Class type) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type); + public L1_L2(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } /** * Creates a regularizer that applies an L1 and l2 regularization penalty * * @param tf the TensorFlow Ops - * @param l1 the L1 regularization penalty - * @param l2 the L2 regularization penalty + * @param l1 the L1 regularization penalty. If null, then l1 will be set to {@link + * #DEFAULT_REGULARIZATION_PENALTY}. + * @param l2 the L2 regularization penalty. If null, then l2 will be set to {@link + * #DEFAULT_REGULARIZATION_PENALTY}. * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite. */ - public L1_L2(Ops tf, Float l1, Float l2, Class type) { - super(tf, - l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1, - l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2, - type); + public L1_L2(Ops tf, Float l1, Float l2) { + super( + tf, + l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1, + l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 8298cd4aba5..b09b93a76d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -15,16 +15,13 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; /** * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) - * - * @param the data type for the operands and result */ -public class L2 extends L1L2 { +public class L2 extends L1L2 { /** * Create a regularizer that applies an L2 regularization penalty of {@link @@ -32,8 +29,8 @@ public class L2 extends L1L2 { * * @param tf the TensorFlow Ops */ - public L2(Ops tf, Class type) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + public L2(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY); } /** @@ -43,7 +40,7 @@ public L2(Ops tf, Class type) { * @param l2 the L2 regularization penalty * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ - public L2(Ops tf, float l2, Class type) { - super(tf, null, l2, type); + public L2(Ops tf, float l2) { + super(tf, null, l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index 906efee7f3d..d1c17d4fc8c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -24,33 +24,31 @@ * *

Regularizers allow you to apply penalties on layer parameters or layer activity during * optimization. These penalties are summed into the loss function that the network optimizes. - * - * @param the data type of the operands and result */ -public abstract class Regularizer { +public abstract class Regularizer { public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; private final Ops tf; private final String name; - protected Class type; /** - * Creates a Regularizer + * Creates a Regularizer, using {@link Class#getSimpleName()} for the name * * @param tf the TensorFlow ops. */ - protected Regularizer(Ops tf, Class type) { - this(tf, null, type); + protected Regularizer(Ops tf) { + this(tf, null); } /** * Creates a Regularizer * * @param tf the TensorFlow ops. + * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the + * name. */ - protected Regularizer(Ops tf, String name, Class type) { + protected Regularizer(Ops tf, String name) { this.tf = tf; - this.type = type; this.name = name == null ? this.getClass().getSimpleName() : name; } @@ -61,7 +59,7 @@ protected Regularizer(Ops tf, String name, Class type) { * @return this Regularizer as a Loss */ public Loss asLoss() { - return new RegularizerLoss<>(this.tf, this); + return new RegularizerLoss(this.tf, this); } /** @@ -70,7 +68,7 @@ public Loss asLoss() { * @param input the weighted input * @return the result of computing the regularization penalty */ - public abstract Operand call(Operand input); + public abstract Operand call(Operand input); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index 0dc137ed747..582cd038f8f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -19,27 +19,24 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Regularizer call wrapped as a Loss instance * *

This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. - * - * @param the datatype for the weights type */ -class RegularizerLoss extends Loss { +class RegularizerLoss extends Loss { + + private final Regularizer regularizer; - private final Regularizer regularizer; - private final Class type; /** * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, Regularizer regularizer) { + public RegularizerLoss(Ops tf, Regularizer regularizer) { this(tf, null, regularizer); } @@ -48,11 +45,11 @@ public RegularizerLoss(Ops tf, Regularizer regularizer) { * * @param tf the TensorFlow Ops * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { + public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { super(tf, name); this.regularizer = regularizer; - this.type = regularizer.type; } /** {@inheritDoc} */ @@ -62,7 +59,6 @@ public Operand call( if (sampleWeights == null) { throw new IllegalArgumentException("sampleWeights cannot be null"); } - Operand result = regularizer.call(cast(getTF(), sampleWeights, type)); - return cast(tf, result, sampleWeights.type()); + return regularizer.call(sampleWeights); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 0f3213ed6eb..3c6dd83731b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -17,19 +17,19 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.2f, 0.3f, TFloat32.class); + L1L2 instance = new L1L2(tf, 0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1L2<>(tf, null, null, TFloat32.class); + instance = new L1L2(tf, null, null); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2<>(tf, 0.5f, null, TFloat32.class); + instance = new L1L2(tf, 0.5f, null); assertEquals(0.5f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2<>(tf, null, 0.5f, TFloat32.class); + instance = new L1L2(tf, null, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); } @@ -40,7 +40,7 @@ public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, TFloat32.class); + L1L2 instance = new L1L2(tf); Operand result = instance.call(tf.constant(555f)); session.evaluate(0f, result); } @@ -51,7 +51,7 @@ public void testCallL1L20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, TFloat32.class); + L1L2 instance = new L1L2(tf); Operand weights = tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); Operand result 
= instance.call(weights); @@ -64,7 +64,7 @@ public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class); + L1L2 instance = new L1L2(tf, 0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -79,7 +79,7 @@ public void testCallL1L2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat64.class); + L1L2 instance = new L1L2(tf, 0.01f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -94,7 +94,7 @@ public void testCallL2Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.01f, null, TFloat32.class); + L1L2 instance = new L1L2(tf, 0.01f, null); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -108,7 +108,7 @@ public void testCallL1Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, null, 0.02f, TFloat64.class); + L1L2 instance = new L1L2(tf, null, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index 6d67bb44d3c..0e42a257816 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -17,15 +17,15 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.2f, TFloat32.class); + L1 instance = new L1(tf, 0.2f); assertEquals(0.2f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1<>(tf, 0f, TFloat32.class); + instance = new L1(tf, 0f); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1<>(tf, TFloat32.class); + instance = new L1(tf); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(0.f, instance.getL2()); } @@ -36,7 +36,7 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.0f, TFloat32.class); + L1 instance = new L1(tf, 0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -49,7 +49,7 @@ public void testCallL1TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, TFloat32.class); + L1 instance = new L1(tf); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); 
Operand result = instance.call(weights); @@ -63,7 +63,7 @@ public void testCallL1TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.02f, TFloat64.class); + L1 instance = new L1(tf, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java index 5aeb5a5d9ad..e4b4e7cc7a3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java @@ -17,19 +17,19 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0.2f, 0.3f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1_L2<>(tf, 0.5f, 0f, TFloat32.class); + instance = new L1_L2(tf, 0.5f, 0f); assertEquals(0.5f, instance.getL1()); assertEquals(0f, instance.getL2()); - instance = new L1_L2<>(tf, 0f, 0.5f, TFloat32.class); + instance = new L1_L2(tf, 0f, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); - instance = new L1_L2<>(tf, TFloat32.class); + instance = new L1_L2(tf); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); } @@ -40,7 +40,7 @@ public void testCallZero() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0f, 0f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0f, 0f); Operand result = instance.call(tf.constant(555f)); session.evaluate(0, result); } @@ -51,7 +51,7 @@ public void testCallDefaultTFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, TFloat32.class); + L1_L2 instance = new L1_L2(tf); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -70,7 +70,7 @@ public void testCallDefaultTFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, TFloat64.class); + L1_L2 instance = new L1_L2(tf); double[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -89,7 +89,7 @@ public void testCallL1L2() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0.01f, 0.02f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -104,7 +104,7 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) 
{ Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0.01f, 0f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0.01f, 0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -118,7 +118,7 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0f, 0.02f, TFloat64.class); + L1_L2 instance = new L1_L2(tf, 0f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index 7f593a2dd14..aba036ee306 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -17,15 +17,15 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, 0.2f, TFloat32.class); + L2 instance = new L2(tf, 0.2f); assertEquals(0.2f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2<>(tf, 0f, TFloat32.class); + instance = new L2(tf, 0f); assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - L2 instance64 = new L2<>(tf, TFloat64.class); + L2 instance64 = new L2(tf); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); assertEquals(0.f, instance64.getL1()); } @@ -36,7 +36,7 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, 0.0f, TFloat32.class); + L2 instance = new L2(tf, 0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -49,7 +49,7 @@ public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, TFloat32.class); + L2 instance = new L2(tf); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -64,7 +64,7 @@ public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, 0.02f, TFloat64.class); + L2 instance = new L2(tf, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java index e694d9409a0..836503af1fa 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -1,7 +1,27 @@ package org.tensorflow.framework.regularizers; -import static org.junit.jupiter.api.Assertions.*; +import 
org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; class RegularizerLossTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; -} \ No newline at end of file + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 regularizer = new L1L2(tf, 0.01f, null); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand regularizerResult = regularizer.call(weights); + RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer); + + Operand loss = lossInstance.call(null, null, weights); + session.evaluate(regularizerResult, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 694287d4970..c0c0f12fbf9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -127,10 +127,10 @@ public static DoubleNdArray square(DoubleNdArray a) { DoubleNdArray result = NdArrays.ofDoubles(a.shape()); int nDims = a.shape().numDimensions(); a.elements(nDims - 1) - .forEachIndexed( - (idx, v) -> { - result.setDouble(v.getDouble() * v.getDouble(), idx); - }); + .forEachIndexed( + (idx, v) -> { + result.setDouble(v.getDouble() * v.getDouble(), idx); + }); return result; } @@ -308,22 +308,22 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) { if (!a.shape().equals(b.shape())) throw new IllegalArgumentException( - String.format( - "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); boolean sameSize = a.shape().size() == b.shape().size(); DoubleNdArray result = NdArrays.ofDoubles(a.shape()); int nDims = a.shape().numDimensions(); a.elements(nDims - 1) - .forEachIndexed( - (idx, v) -> { - if (sameSize) { - result.setDouble(v.getDouble() * b.getDouble(idx), idx); - } else { - double value = v.getDouble() * b.getDouble(idx[0], 0L); - result.setDouble(value, idx); - } - }); + .forEachIndexed( + (idx, v) -> { + if (sameSize) { + result.setDouble(v.getDouble() * b.getDouble(idx), idx); + } else { + double value = v.getDouble() * b.getDouble(idx[0], 0L); + result.setDouble(value, idx); + } + }); return result; } @@ -528,7 +528,7 @@ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.max(v.getFloat(), b.getFloat(idx)), idx); + result.setFloat(Math.max(v.getFloat(), b.getFloat(idx)), idx); }); return result; } @@ -547,7 +547,7 @@ public static FloatNdArray max(FloatNdArray a, float scalar) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.max(v.getFloat(), scalar), idx); + result.setFloat(Math.max(v.getFloat(), scalar), idx); }); return result; } @@ -580,7 +580,7 @@ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.min(v.getFloat(), b.getFloat(idx)), idx); + result.setFloat(Math.min(v.getFloat(), 
b.getFloat(idx)), idx); }); return result; } @@ -599,7 +599,7 @@ public static FloatNdArray min(FloatNdArray a, float scalar) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.min(v.getFloat(), scalar), idx); + result.setFloat(Math.min(v.getFloat(), scalar), idx); }); return result; } @@ -624,7 +624,7 @@ public static FloatNdArray min(float scalar, FloatNdArray a) { */ public static FloatNdArray abs(FloatNdArray a) { FloatNdArray result = NdArrays.ofFloats(a.shape()); - a.scalars().forEachIndexed((idx, f) -> result.setFloat((float) Math.abs(f.getFloat()), idx)); + a.scalars().forEachIndexed((idx, f) -> result.setFloat(Math.abs(f.getFloat()), idx)); return result; } @@ -636,7 +636,7 @@ public static FloatNdArray abs(FloatNdArray a) { */ public static DoubleNdArray abs(DoubleNdArray a) { DoubleNdArray result = NdArrays.ofDoubles(a.shape()); - a.scalars().forEachIndexed((idx, f) -> result.setDouble( Math.abs(f.getDouble()), idx)); + a.scalars().forEachIndexed((idx, f) -> result.setDouble(Math.abs(f.getDouble()), idx)); return result; } @@ -755,7 +755,7 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis) { } /** - * Sum all elements of an array based on the specified axis + * Sum all elements of an array over on the specified axis * * @param a the array * @param axis the axis to sum @@ -771,10 +771,10 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { final double[] sums = new double[(int) axisSize]; a.scalars() - .forEachIndexed( - (idx, f) -> { - sums[(int) idx[xis]] += f.getDouble(); - }); + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); if (keepDims) { long[] newDims = shape.asArray(); @@ -782,11 +782,11 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { final AtomicInteger counter = new AtomicInteger(); DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); arrayK - .elements(newDims.length - 1) - .forEachIndexed( - (idx, v) -> { - v.setDouble(sums[counter.getAndAdd(1)]); - }); + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); return arrayK; } else { return NdArrays.vectorOf(sums); From a3ccf617de1ebe32580294ea15b7ab52483a3621 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 14 Feb 2021 10:49:12 -0500 Subject: [PATCH 05/14] Add JavDoc to new method l1_l2 --- .../java/org/tensorflow/framework/regularizers/L1L2.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 89b407e0940..2670de799e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -82,6 +82,13 @@ public L1L2(Ops tf, Float l1, Float l2) { } } + /** + * Creates an L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 + * values. + * + * @param tf the TensorFlow Ops + * @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values. + */ public static L1L2 l1_l2(Ops tf) { return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } From 8c792147da6c0cc062f50f7987ec68d5e565f17a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 14 Feb 2021 11:04:00 -0500 Subject: [PATCH 06/14] change l1_l2 to create. 
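A minimal sketch of the rename, assuming an eager-mode Ops purely for illustration (the setup is not part of this change; only the create() name is):

    import org.tensorflow.framework.regularizers.L1L2;
    import org.tensorflow.op.Ops;

    public class L1L2FactorySketch {
      public static void main(String[] args) {
        Ops tf = Ops.create(); // eager mode, illustration only

        // Previously: L1L2 reg = L1L2.l1_l2(tf);
        // After this commit:
        L1L2 reg = L1L2.create(tf);

        // Behaviour is unchanged; both factors still default to
        // DEFAULT_REGULARIZATION_PENALTY (0.01f), equivalent to:
        L1L2 explicit = new L1L2(tf, 0.01f, 0.01f);
      }
    }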
--- .../main/java/org/tensorflow/framework/regularizers/L1L2.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 2670de799e1..b2a7edbb187 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -89,7 +89,7 @@ public L1L2(Ops tf, Float l1, Float l2) { * @param tf the TensorFlow Ops * @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values. */ - public static L1L2 l1_l2(Ops tf) { + public static L1L2 create(Ops tf) { return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } From 1af455298f82bfab8f39a7fcfeb4c7089a52ee04 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 14 Feb 2021 12:06:08 -0500 Subject: [PATCH 07/14] delete class L1_L2 modified Float to float for l1 and l2 parameters Change ctor L1L2(Ops tf) to use DEFAULT_REGULARIZATION_PENALTY for l1/l2 parameters Fix JavaDoc --- .../tensorflow/framework/regularizers/L1.java | 2 +- .../framework/regularizers/L1L2.java | 49 ++----- .../framework/regularizers/L1_L2.java | 61 -------- .../tensorflow/framework/regularizers/L2.java | 2 +- .../framework/regularizers/Regularizer.java | 1 + .../framework/regularizers/L1L2Test.java | 24 ++-- .../framework/regularizers/L1_L2Test.java | 130 ------------------ .../regularizers/RegularizerLossTest.java | 2 +- 8 files changed, 33 insertions(+), 238 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 074e881c1cd..7c8c2a1360a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -42,6 +42,6 @@ public L1(Ops tf) { * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ public L1(Ops tf, float l1) { - super(tf, l1, null); + super(tf, l1, 0f); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index b2a7edbb187..29e411f9897 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -30,8 +30,6 @@ * *

loss = l2 * reduceSum(square(x))
* - *

The difference between this class and the {@link L1_L2} is use of the default regularization - * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. */ public class L1L2 extends Regularizer { @@ -44,7 +42,7 @@ public class L1L2 extends Regularizer { * @param tf the TensorFlow Ops */ public L1L2(Ops tf) { - this(tf, null, null); + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } /** @@ -56,42 +54,25 @@ public L1L2(Ops tf) { * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, Float l1, Float l2) { + public L1L2(Ops tf, float l1, float l2) { super(tf); - if (l1 != null) { - if (l1.isNaN() || l1.isInfinite()) { - throw new IllegalArgumentException( - String.format( - "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", - l1)); - } - this.l1 = l1; - } else { - this.l1 = 0f; + if (Float.isNaN(l1) || Float.isInfinite(l1)) { + throw new IllegalArgumentException( + String.format( + "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l1)); } - if (l2 != null) { - if (l2.isNaN() || l2.isInfinite()) { - throw new IllegalArgumentException( - String.format( - "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", - l2)); - } - this.l2 = l2; - } else { - this.l2 = 0f; + this.l1 = l1; + + if (Float.isNaN(l2) || Float.isInfinite(l2)) { + throw new IllegalArgumentException( + String.format( + "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l2)); } + this.l2 = l2; } - /** - * Creates an L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 - * values. - * - * @param tf the TensorFlow Ops - * @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values. - */ - public static L1L2 create(Ops tf) { - return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); - } /** {@inheritDoc} */ @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java deleted file mode 100644 index 44e04ad4d94..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.regularizers; - -import org.tensorflow.op.Ops; - -/** - * A regularizer that applies both L1 and L2 regularization penalties. - * - *

The L1 regularization penalty is computed as: - * - *

loss = l1 * reduceSum(abs(x))
- * - *

The L2 regularization penalty is computed as - * - *

loss = l2 * reduceSum(square(x))
- * - *

The difference between this class and the {@link L1L2} is use of the default regularization - * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. - */ -public class L1_L2 extends L1L2 { - - /** - * Creates a regularizer that applies an L1 and l2 regularization penalty of {@link - * #DEFAULT_REGULARIZATION_PENALTY} - * - * @param tf the TensorFlow Ops - */ - public L1_L2(Ops tf) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); - } - - /** - * Creates a regularizer that applies an L1 and l2 regularization penalty - * - * @param tf the TensorFlow Ops - * @param l1 the L1 regularization penalty. If null, then l1 will be set to {@link - * #DEFAULT_REGULARIZATION_PENALTY}. - * @param l2 the L2 regularization penalty. If null, then l2 will be set to {@link - * #DEFAULT_REGULARIZATION_PENALTY}. - * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite. - */ - public L1_L2(Ops tf, Float l1, Float l2) { - super( - tf, - l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1, - l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2); - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index b09b93a76d9..7b8f5b28a70 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -41,6 +41,6 @@ public L2(Ops tf) { * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ public L2(Ops tf, float l2) { - super(tf, null, l2); + super(tf, 0f, l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index d1c17d4fc8c..5d9ff0e3e10 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -67,6 +67,7 @@ public Loss asLoss() { * * @param input the weighted input * @return the result of computing the regularization penalty + * @param the data type of the input and result */ public abstract Operand call(Operand input); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 3c6dd83731b..181ae367f07 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -21,17 +21,21 @@ public void testCreate() { assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1L2(tf, null, null); + instance = new L1L2(tf, 0, 0); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, 0.5f, null); + instance = new L1L2(tf, 0.5f, 0); assertEquals(0.5f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2(tf, null, 0.5f); + instance = new L1L2(tf, 0, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); + + instance = new L1L2(tf); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, 
instance.getL2()); } } @@ -42,16 +46,16 @@ public void testCallDefaultsConstant() { Ops tf = session.getTF(); L1L2 instance = new L1L2(tf); Operand result = instance.call(tf.constant(555f)); - session.evaluate(0f, result); + session.evaluate(3085.8f, result); } } @Test - public void testCallL1L20() { + public void testCallL1L2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf); + L1L2 instance = new L1L2(tf, 0, 0); Operand weights = tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); Operand result = instance.call(weights); @@ -90,11 +94,11 @@ public void testCallL1L2TFloat64() { } @Test - public void testCallL2Null() { + public void testCallL2_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, 0.01f, null); + L1L2 instance = new L1L2(tf, 0.01f, 0); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -104,11 +108,11 @@ public void testCallL2Null() { } @Test - public void testCallL1Null() { + public void testCallL1_0() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2(tf, null, 0.02f); + L1L2 instance = new L1L2(tf, 0, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java deleted file mode 100644 index e4b4e7cc7a3..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java +++ /dev/null @@ -1,130 +0,0 @@ -package org.tensorflow.framework.regularizers; - -import org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -class L1_L2Test extends CommonTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - - @Test - public void testCreate() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf, 0.2f, 0.3f); - assertEquals(0.2f, instance.getL1()); - assertEquals(0.3f, instance.getL2()); - - instance = new L1_L2(tf, 0.5f, 0f); - assertEquals(0.5f, instance.getL1()); - assertEquals(0f, instance.getL2()); - - instance = new L1_L2(tf, 0f, 0.5f); - assertEquals(0.f, instance.getL1()); - assertEquals(0.5f, instance.getL2()); - - instance = new L1_L2(tf); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); - } - } - - @Test - public void testCallZero() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf, 0f, 0f); - Operand result = instance.call(tf.constant(555f)); - session.evaluate(0, result); - } - } - - @Test - public void 
testCallDefaultTFloat32() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf); - float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = - regularizeL1L2( - w, - Regularizer.DEFAULT_REGULARIZATION_PENALTY, - Regularizer.DEFAULT_REGULARIZATION_PENALTY); - session.setEpsilon(.01f); - session.evaluate(expected, result); - } - } - - @Test - public void testCallDefaultTFloat64() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf); - double[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - double expected = - regularizeL1L2( - w, - Regularizer.DEFAULT_REGULARIZATION_PENALTY, - Regularizer.DEFAULT_REGULARIZATION_PENALTY); - session.setEpsilon(.01f); - session.evaluate(expected, result); - } - } - - @Test - public void testCallL1L2() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf, 0.01f, 0.02f); - float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL1L2(w, 0.01f, 0.02f); - session.setEpsilon(.01f); - session.evaluate(expected, result); - } - } - - @Test - public void testCallL20() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf, 0.01f, 0f); - float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - float expected = regularizeL1(w, 0.01f); - session.evaluate(expected, result); - } - } - - @Test - public void testCallL10() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - L1_L2 instance = new L1_L2(tf, 0f, 0.02f); - double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; - Operand weights = tf.constant(w); - Operand result = instance.call(weights); - double expected = regularizeL2(w, 0.02f); - session.setEpsilon(.01f); - session.evaluate(expected, result); - } - } -} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java index 836503af1fa..fe2624cec3d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -14,7 +14,7 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 regularizer = new L1L2(tf, 0.01f, null); + L1L2 regularizer = new L1L2(tf, 0.01f, 0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand regularizerResult = regularizer.call(weights); From 54f18021acbf53772774729581984625f229a08e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 1 May 2021 18:13:25 -0400 Subject: 
[PATCH 08/14] Rebase with tensorflow Master --- .../tensorflow/framework/regularizers/L1.java | 48 +++++ .../framework/regularizers/L1L2.java | 129 +++++++++++++ .../framework/regularizers/L1_L2.java | 61 ++++++ .../tensorflow/framework/regularizers/L2.java | 48 +++++ .../framework/regularizers/Regularizer.java | 92 +++++++++ .../regularizers/RegularizerLoss.java | 69 +++++++ .../framework/regularizers/CommonTest.java | 63 +++++++ .../framework/regularizers/L1L2Test.java | 110 +++++++++++ .../framework/regularizers/L1Test.java | 77 ++++++++ .../framework/regularizers/L1_L2Test.java | 116 ++++++++++++ .../framework/regularizers/L2Test.java | 79 ++++++++ .../org/tensorflow/framework/utils/ND.java | 175 ++++++++++++++---- 12 files changed, 1033 insertions(+), 34 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java new file mode 100644 index 00000000000..8d3469dca98 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies a L1 regularization penalty. + * + *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) + * + * @param the data type for the weights + */ +public class L1 extends L1L2 { + + /** + * Create a regularizer that applies an L1 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L1(Ops tf, Class type) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + } + + /** + * Create a regularizer that applies an L1 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l1 the L1 regularization penalty + */ + public L1(Ops tf, float l1, Class type) { + super(tf, l1, null, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java new file mode 100644 index 00000000000..565384f1fa5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies both L1 and L2 regularization penalties. + * + *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x))
+ *
+ * The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x))
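For orientation, a minimal usage sketch of the L1L2 regularizer being added here (not taken from the patch; it assumes an eager-mode Ops instance and an illustrative weight constant, and the factors 0.01/0.02 are arbitrary):

    Ops tf = Ops.create();  // eager mode
    Operand<TFloat32> weights =
        tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});
    L1L2<TFloat32> regularizer = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);
    // penalty = 0.01 * reduceSum(abs(weights)) + 0.02 * reduceSum(square(weights))
    Operand<TFloat32> penalty = regularizer.call(weights);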
+ * + * @param the data type for the weights + */ +public class L1L2 extends Regularizer { + + private final Float l1; + private final Float l2; + + /** + * Creates an L1L2 regularizer with no l1 or l2 penalty, i.e. both penalties default to 0 + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights + */ + public L1L2(Ops tf, Class type) { + this(tf, null, null, type); + } + + /** + * Creates an L1L2 regularizer + * + * @param tf the TensorFlow Ops + * @param l1 L1 regularization factor, if null it is set to 0. + * @param l2 L2 regularization factor, if null it is set to 0. + * @param type the data type for the weights + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link }NaN or is + * infinite. + */ + public L1L2(Ops tf, Float l1, Float l2, Class type) { + super(tf, type); + if (l1 != null) { + if (l1.isNaN() || l1.isInfinite()) { + throw new IllegalArgumentException( + String.format( + "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l1)); + } + this.l1 = l1; + } else { + this.l1 = 0f; + } + if (l2 != null) { + if (l2.isNaN() || l2.isInfinite()) { + throw new IllegalArgumentException( + String.format( + "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l2)); + } + this.l2 = l2; + } else { + this.l2 = 0f; + } + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + Ops tf = getTF(); + if (this.getL1() == null && this.getL2() == null) { + return tf.dtypes.cast(tf.constant(0), input.type()); + } + Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); + + if (this.getL1() != null && this.getL1() != 0.f) { + Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); + Operand abs = tf.math.abs(input); + Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); + regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); + } + + if (this.getL2() != null && this.getL2() != 0.f) { + Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); + Operand sqr = tf.math.square(input); + Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); + regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); + } + + return regularization; + } + + /** + * Gets the L1 regularization factor + * + * @return the L1 regularization factor + */ + public Float getL1() { + return l1; + } + + /** + * Gets the L2 regularization factor + * + * @return the L2 regularization factor + */ + public Float getL2() { + return l2; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java new file mode 100644 index 00000000000..fe763236cea --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies both L1 and L2 regularization penalties. + * + *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x))
+ *
+ * The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x))
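A worked instance of the two formulas above, using the 2x3 weight matrix that the unit tests later in this patch reuse and the default penalty of 0.01 for both factors:

    float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
    // reduceSum(abs(w))    = 1.0 + 0.9 + 0.8 + 1.2 + 0.7 + 1.1      = 5.7
    // reduceSum(square(w)) = 1.0 + 0.81 + 0.64 + 1.44 + 0.49 + 1.21 = 5.59
    // with l1 = l2 = 0.01 (DEFAULT_REGULARIZATION_PENALTY):
    // loss = 0.01 * 5.7 + 0.01 * 5.59 = 0.057 + 0.0559 = 0.1129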
The difference between this class and the {@link L1L2} is use of the default regularization + * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. + * + * @param the data type for the weights + */ +public class L1_L2 extends L1L2 { + + /** + * Create a regularizer that applies an L1 and l2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L1_L2(Ops tf, Class type) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type); + } + + /** + * Create a regularizer that applies an L1 and l2 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l1 the L1 regularization penalty + */ + public L1_L2(Ops tf, Float l1, Float l2, Class type) { + super( + tf, + l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1, + l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2, + type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java new file mode 100644 index 00000000000..3ce0f84581c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies a L2 regularization penalty. + * + *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) + * + * @param the data type for the weights + */ +public class L2 extends L1L2 { + + /** + * Create a regularizer that applies an L2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L2(Ops tf, Class type) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + } + + /** + * Create a regularizer that applies an L1 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l2 the L2 regularization penalty + */ + public L2(Ops tf, float l2, Class type) { + super(tf, null, l2, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java new file mode 100644 index 00000000000..20701fb6348 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Regularizer base class. + * + *

Regularizers allow you to apply penalties on layer parameters or layer activity during + * optimization. These penalties are summed into the loss function that the network optimizes. + * + * @param the data type of the result + */ +public abstract class Regularizer { + + public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; + + private final Ops tf; + private final String name; + protected Class type; + + /** + * Create a Regularizer + * + * @param tf the TensorFlow ops. + */ + protected Regularizer(Ops tf, Class type) { + this(tf, null, type); + } + /** + * Create a Regularizer + * + * @param tf the TensorFlow ops. + */ + protected Regularizer(Ops tf, String name, Class type) { + this.tf = tf; + this.type = type; + this.name = name == null ? this.getClass().getSimpleName() : name; + } + + /** + * Returns this Regularizer as a Loss This is a convenience tp regularize a loss. Only + * sampleWeights are applied to the regularizer. + * + * @return this Regularizer as a Loss + */ + public Loss asLoss() { + return new RegularizerLoss(this.tf, this); + } + + /** + * Computes a regularization penalty from an input. + * + * @param input teh weighted input + * @return the result of computing the regularization penalty + */ + public abstract Operand call(Operand input); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the name for this regularizer + * + * @return the name for this regularizer + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java new file mode 100644 index 00000000000..e750eabebab --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * A Regularizer call wrapped as a Loss instance + * + *
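A sketch of the asLoss() convenience described above (assuming an existing Ops instance tf and a TFloat32 weights operand; the null labels and predictions only illustrate that RegularizerLoss ignores them and regularizes sampleWeights alone):

    L1L2<TFloat32> regularizer = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);
    Loss regularizerLoss = regularizer.asLoss();
    // labels and predictions are not used; only sampleWeights is regularized
    Operand<TFloat32> penalty = regularizerLoss.call(null, null, weights);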

This class facilitates using a regularizing as a loss, only sampleWeights are + * regularized. + * + * @param the datatype for the weights type + */ +class RegularizerLoss extends Loss { + + private final Regularizer regularizer; + private final Class type; + /** + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT} + * + * @param tf the TensorFlow Ops + */ + public RegularizerLoss(Ops tf, Regularizer regularizer) { + this(tf, null, regularizer); + } + + /** + * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * + * @param tf the TensorFlow Ops + * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + */ + public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { + super(tf, name); + this.regularizer = regularizer; + this.type = regularizer.type; + } + + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + if (sampleWeights == null) { + throw new IllegalArgumentException("sampleWeights cannot be null"); + } + Operand result = regularizer.call(cast(getTF(), sampleWeights, type)); + return cast(tf, result, sampleWeights.type()); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java new file mode 100644 index 00000000000..63ecc155fd1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.framework.utils.ND; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.StdArrays; + +public class CommonTest { + + protected float regularizeL1L2(float[][] w, float l1, float l2) { + return regularizeL1(w, l1) + regularizeL2(w, l2); + } + + protected float regularizeL1(float[][] w, float l1) { + FloatNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.abs(fa); + FloatNdArray sum = ND.sum(fa); + FloatNdArray mul = ND.mul(sum, l1); + return mul.getFloat(); + } + + protected float regularizeL2(float[][] w, float l2) { + FloatNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.square(fa); + FloatNdArray sum = ND.sum(fa); + FloatNdArray mul = ND.mul(sum, l2); + return mul.getFloat(); + } + + protected double regularizeL1L2(double[][] w, float l1, float l2) { + return regularizeL1(w, l1) + regularizeL2(w, l2); + } + + protected double regularizeL1(double[][] w, float l1) { + DoubleNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.abs(fa); + DoubleNdArray sum = ND.sum(fa); + DoubleNdArray mul = ND.mul(sum, l1); + return mul.getDouble(); + } + + protected double regularizeL2(double[][] w, float l2) { + DoubleNdArray fa = StdArrays.ndCopyOf(w); + fa = ND.square(fa); + DoubleNdArray sum = ND.sum(fa); + DoubleNdArray mul = ND.mul(sum, l2); + return mul.getDouble(); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java new file mode 100644 index 00000000000..178468f42e8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -0,0 +1,110 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.2f, 0.3f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.3f, instance.getL2()); + + instance = new L1L2<>(tf, null, null, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1L2<>(tf, 0.5f, null, TFloat32.class); + assertEquals(0.5f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1L2<>(tf, null, 0.5f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.5f, instance.getL2()); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCall() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, TFloat32.class); + Operand result = instance.call(tf.constant(555f)); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. 
*/ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, TFloat32.class); + Operand weights = + tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1L2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.09f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.01f, null, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, 0.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, null, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java new file mode 100644 index 00000000000..1a4416b3302 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -0,0 +1,77 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.2f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1<>(tf, 0f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1<>(tf, TFloat32.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(0.f, instance.getL2()); + } + } + + /** Test of call method, of class L1L2. 
*/ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1_2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1<>(tf, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1(w, 0.02f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java new file mode 100644 index 00000000000..d35179cdd62 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java @@ -0,0 +1,116 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1_L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0.2f, 0.3f, TFloat32.class); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.3f, instance.getL2()); + + instance = new L1_L2<>(tf, 0.5f, 0f, TFloat32.class); + assertEquals(0.5f, instance.getL1()); + assertEquals(0f, instance.getL2()); + + instance = new L1_L2<>(tf, 0f, 0.5f, TFloat32.class); + assertEquals(0.f, instance.getL1()); + assertEquals(0.5f, instance.getL2()); + + instance = new L1_L2<>(tf, TFloat32.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallZero() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0f, 0f, TFloat32.class); + Operand result = instance.call(tf.constant(555f)); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1_L2<>. 
*/ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = + regularizeL1L2( + w, + Regularizer.DEFAULT_REGULARIZATION_PENALTY, + Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallL1L2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0.01f, 0.02f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0.01f, 0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, 0.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallL2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, 0f, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java new file mode 100644 index 00000000000..cbb019796f1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -0,0 +1,79 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, 0.2f, TFloat32.class); + assertEquals(0.2f, instance.getL2()); + assertEquals(0.f, instance.getL1()); + + instance = new L2<>(tf, 0f, TFloat32.class); + assertEquals(0.f, instance.getL2()); + assertEquals(0.f, instance.getL1()); + + instance = new L2<>(tf, TFloat32.class); + 
assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + assertEquals(0.f, instance.getL1()); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallNO() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, 0.0f, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, TFloat32.class); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + /** Test of call method, of class L1L2. */ + @Test + public void testCallL1_2() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2<>(tf, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index ef8bb71d724..7a1f3c685ed 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,7 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.*; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -315,6 +319,64 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { return mul(a, scalar); } + /** + * Multiply 2 arrays + * + * @param a the first array + * @param b the second array + * @return the resulting array from the muliply operation + */ + public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) { + if (!a.shape().equals(b.shape())) + throw new IllegalArgumentException( + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + boolean sameSize = a.shape().size() == b.shape().size(); + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + if (sameSize) { + result.setDouble(v.getDouble() * b.getDouble(idx), idx); + } else { + double value = v.getDouble() * b.getDouble(idx[0], 0L); + result.setDouble(value, idx); + } + }); + return result; + } + + /** + * Multiply an array with a scalar value + * + * @param a the array + * @param scalar the scalar value + * @return the resulting 
array from the Multiply operation + */ + public static DoubleNdArray mul(DoubleNdArray a, float scalar) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + if (a.shape().isScalar()) { + a.scalars().forEach(f -> result.setDouble(f.getDouble() * scalar)); + } else { + a.scalars().forEachIndexed((idx, f) -> result.setDouble(f.getDouble() * scalar, idx)); + } + + return result; + } + + /** + * Multiply a scalar value with an array + * + * @param scalar the scalar value + * @param a the array + * @return the resulting array from the Multiply operation + */ + public static DoubleNdArray mul(float scalar, DoubleNdArray a) { + return mul(a, scalar); + } + /** * Divide two arrays * @@ -588,15 +650,15 @@ public static FloatNdArray abs(FloatNdArray a) { } /** - * Sum all elements of an array + * Get the absolute value of each member of the array * * @param a the array - * @return an a array with one element containing the sum. + * @return the array with the absolute value of each item. */ - public static FloatNdArray sum(FloatNdArray a) { - AtomicReference sum = new AtomicReference<>(0.f); - a.scalars().forEach(f -> sum.set(sum.get() + f.getFloat())); - return NdArrays.scalarOf(sum.get()); + public static DoubleNdArray abs(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + a.scalars().forEachIndexed((idx, f) -> result.setDouble( Math.abs(f.getDouble()), idx)); + return result; } /** @@ -605,12 +667,14 @@ public static FloatNdArray sum(FloatNdArray a) { * @param a the array * @return an a array with one element containing the sum. */ - public static DoubleNdArray sum(DoubleNdArray a) { - AtomicReference sum = new AtomicReference(0D); - a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + public static FloatNdArray sum(FloatNdArray a) { + AtomicReference sum = new AtomicReference<>(0.f); + a.scalars().forEach(f -> sum.set(sum.get() + f.getFloat())); return NdArrays.scalarOf(sum.get()); } + + /** * Sum all elements of an array based on the specified axis * @@ -622,16 +686,7 @@ public static FloatNdArray sum(FloatNdArray a, int axis) { return sum(a, axis, false); } - /** - * Sum all elements of an array based on the specified axis - * - * @param a the array - * @param axis the axis to sum - * @return an a array the sum over the axis less the diemsnion - */ - public static DoubleNdArray sum(DoubleNdArray a, int axis) { - return sum(a, axis, false); - } + /** * Sum all elements of an array based on the specified axis @@ -672,6 +727,58 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { } } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axes the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. 
+ * @return an a array the sum over the axis + */ + public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) { + Shape shape = a.shape(); + if (axes == null) { + FloatNdArray result = sum(a); + if (keepDims) { + float scalar = result.getFloat(0); + long[] dims = {1, 1}; + Shape bShape = Shape.of(dims); + FloatNdArray resultK = NdArrays.ofFloats(bShape); + resultK.setFloat(scalar, 0, 0); + return resultK; + } + return result; + } else if (axes.length == 1) { + return sum(a, axes[0], keepDims); + } else { + // TODO + throw new UnsupportedOperationException("Multi Axis Not implemented Yet"); + } + } + + /** + * Sum all elements of an array + * + * @param a the array + * @return an a array with one element containing the sum. + */ + public static DoubleNdArray sum(DoubleNdArray a) { + AtomicReference sum = new AtomicReference<>(0.); + a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + return NdArrays.scalarOf(sum.get()); + } + + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @return an a array the sum over the axis less the diemsnion + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis) { + return sum(a, axis, false); + } + /** * Sum all elements of an array based on the specified axis * @@ -689,10 +796,10 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { final double[] sums = new double[(int) axisSize]; a.scalars() - .forEachIndexed( - (idx, f) -> { - sums[(int) idx[xis]] += f.getDouble(); - }); + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); if (keepDims) { long[] newDims = shape.asArray(); @@ -700,11 +807,11 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { final AtomicInteger counter = new AtomicInteger(); DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); arrayK - .elements(newDims.length - 1) - .forEachIndexed( - (idx, v) -> { - v.setDouble(sums[counter.getAndAdd(1)]); - }); + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); return arrayK; } else { return NdArrays.vectorOf(sums); @@ -719,16 +826,16 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { * @param keepDims indicates whether the dimensions over the sum should be kept or not. 
* @return an a array the sum over the axis */ - public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) { + public static DoubleNdArray sum(DoubleNdArray a, Integer[] axes, boolean keepDims) { Shape shape = a.shape(); if (axes == null) { - FloatNdArray result = sum(a); + DoubleNdArray result = sum(a); if (keepDims) { - float scalar = result.getFloat(0); + double scalar = result.getDouble(0); long[] dims = {1, 1}; Shape bShape = Shape.of(dims); - FloatNdArray resultK = NdArrays.ofFloats(bShape); - resultK.setFloat(scalar, 0, 0); + DoubleNdArray resultK = NdArrays.ofDoubles(bShape); + resultK.setDouble(scalar, 0, 0); return resultK; } return result; From bbd3bc31caa34175b93ebd34bedb09b506ac4240 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 08:33:34 -0400 Subject: [PATCH 09/14] Updating fixed local copy to repair broken remote copy --- .../org/tensorflow/framework/utils/ND.java | 110 +++++------------- 1 file changed, 29 insertions(+), 81 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 7a1f3c685ed..d0cbae56628 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -98,9 +98,7 @@ public static FloatNdArray sqrt(FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.sqrt(v.getFloat()), idx); - }); + (idx, v) -> result.setFloat((float) Math.sqrt(v.getFloat()), idx)); return result; } @@ -115,9 +113,7 @@ public static DoubleNdArray sqrt(DoubleNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setDouble(Math.sqrt(v.getDouble()), idx); - }); + (idx, v) -> result.setDouble(Math.sqrt(v.getDouble()), idx)); return result; } @@ -132,9 +128,7 @@ public static FloatNdArray square(FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() * v.getFloat(), idx); - }); + (idx, v) -> result.setFloat(v.getFloat() * v.getFloat(), idx)); return result; } @@ -149,9 +143,7 @@ public static DoubleNdArray square(DoubleNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setDouble(v.getDouble() * v.getDouble(), idx); - }); + (idx, v) -> result.setDouble(v.getDouble() * v.getDouble(), idx)); return result; } @@ -169,9 +161,7 @@ public static FloatNdArray add(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() + b.getFloat(idx), idx); - }); + (idx, v) -> result.setFloat(v.getFloat() + b.getFloat(idx), idx)); return result; } @@ -188,9 +178,7 @@ public static FloatNdArray add(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() + scalar, idx); - }); + (idx, v) -> result.setFloat(v.getFloat() + scalar, idx)); return result; } @@ -219,9 +207,7 @@ public static FloatNdArray sub(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() - b.getFloat(idx), idx); - }); + (idx, v) -> result.setFloat(v.getFloat() - b.getFloat(idx), idx)); 
return result; } @@ -237,9 +223,7 @@ public static FloatNdArray sub(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() - scalar, idx); - }); + (idx, v) -> result.setFloat(v.getFloat() - scalar, idx)); return result; } @@ -255,9 +239,7 @@ public static FloatNdArray sub(float scalar, FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(scalar - v.getFloat(), idx); - }); + (idx, v) -> result.setFloat(scalar - v.getFloat(), idx)); return result; } @@ -391,9 +373,7 @@ public static FloatNdArray div(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() / b.getFloat(idx), idx); - }); + (idx, v) -> result.setFloat(v.getFloat() / b.getFloat(idx), idx)); return result; } @@ -410,9 +390,7 @@ public static FloatNdArray div(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat(v.getFloat() / scalar, idx); - }); + (idx, v) -> result.setFloat(v.getFloat() / scalar, idx)); return result; } @@ -449,9 +427,7 @@ public static FloatNdArray pow(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.pow(v.getFloat(), b.getFloat(idx)), idx); - }); + (idx, v) -> result.setFloat((float) Math.pow(v.getFloat(), b.getFloat(idx)), idx)); return result; } @@ -467,9 +443,7 @@ public static FloatNdArray pow(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.pow(v.getFloat(), scalar), idx); - }); + (idx, v) -> result.setFloat((float) Math.pow(v.getFloat(), scalar), idx)); return result; } @@ -485,9 +459,7 @@ public static FloatNdArray pow(float scalar, FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.pow(scalar, v.getFloat()), idx); - }); + (idx, v) -> result.setFloat((float) Math.pow(scalar, v.getFloat()), idx)); return result; } @@ -503,9 +475,7 @@ public static float[] flatten(FloatNdArray a) { AtomicInteger counter = new AtomicInteger(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result[counter.getAndAdd(1)] = v.getFloat(); - }); + (idx, v) -> result[counter.getAndAdd(1)] = v.getFloat()); return result; } @@ -537,7 +507,7 @@ public static float min(FloatNdArray a) { * Get the maximum value of comparing the arrays * * @param a the first array - * @param a the second array + * @param b the second array * @return the resulting array with the maximum values between each element of the arrays. * @throws AssertionError if the two arrays are not the same size. 
*/ @@ -548,9 +518,7 @@ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.max(v.getFloat(), b.getFloat(idx)), idx); - }); + (idx, v) -> result.setFloat(Math.max(v.getFloat(), b.getFloat(idx)), idx)); return result; } @@ -567,9 +535,7 @@ public static FloatNdArray max(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.max(v.getFloat(), scalar), idx); - }); + (idx, v) -> result.setFloat(Math.max(v.getFloat(), scalar), idx)); return result; } @@ -589,7 +555,7 @@ public static FloatNdArray max(float scalar, FloatNdArray a) { * Get the minimum value of comparing the arrays * * @param a the first array - * @param a the second array + * @param b the second array * @return the resulting array with the minimum values between each element of the arrays. * @throws AssertionError if the two arrays are not the same size. */ @@ -600,9 +566,7 @@ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.min(v.getFloat(), b.getFloat(idx)), idx); - }); + (idx, v) -> result.setFloat(Math.min(v.getFloat(), b.getFloat(idx)), idx)); return result; } @@ -619,9 +583,7 @@ public static FloatNdArray min(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> { - result.setFloat((float) Math.min(v.getFloat(), scalar), idx); - }); + (idx, v) -> result.setFloat( Math.min(v.getFloat(), scalar), idx)); return result; } @@ -645,7 +607,7 @@ public static FloatNdArray min(float scalar, FloatNdArray a) { */ public static FloatNdArray abs(FloatNdArray a) { FloatNdArray result = NdArrays.ofFloats(a.shape()); - a.scalars().forEachIndexed((idx, f) -> result.setFloat((float) Math.abs(f.getFloat()), idx)); + a.scalars().forEachIndexed((idx, f) -> result.setFloat(Math.abs(f.getFloat()), idx)); return result; } @@ -686,8 +648,6 @@ public static FloatNdArray sum(FloatNdArray a, int axis) { return sum(a, axis, false); } - - /** * Sum all elements of an array based on the specified axis * @@ -706,9 +666,7 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { a.scalars() .forEachIndexed( - (idx, f) -> { - sums[(int) idx[xis]] += f.getFloat(); - }); + (idx, f) -> sums[(int) idx[xis]] += f.getFloat()); if (keepDims) { long[] newDims = shape.asArray(); @@ -718,9 +676,7 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { arrayK .elements(newDims.length - 1) .forEachIndexed( - (idx, v) -> { - v.setFloat(sums[counter.getAndAdd(1)]); - }); + (idx, v) -> v.setFloat(sums[counter.getAndAdd(1)])); return arrayK; } else { return NdArrays.vectorOf(sums); @@ -780,7 +736,7 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis) { } /** - * Sum all elements of an array based on the specified axis + * Sum all elements of an array over the specified axis * * @param a the array * @param axis the axis to sum @@ -797,9 +753,7 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { a.scalars() .forEachIndexed( - (idx, f) -> { - sums[(int) idx[xis]] += f.getDouble(); - }); + (idx, f) -> sums[(int) idx[xis]] += f.getDouble()); if (keepDims) { long[] newDims = shape.asArray(); @@ -809,9 +763,7 @@ public static DoubleNdArray 
sum(DoubleNdArray a, int axis, boolean keepDims) { arrayK .elements(newDims.length - 1) .forEachIndexed( - (idx, v) -> { - v.setDouble(sums[counter.getAndAdd(1)]); - }); + (idx, v) -> v.setDouble(sums[counter.getAndAdd(1)])); return arrayK; } else { return NdArrays.vectorOf(sums); @@ -883,9 +835,7 @@ public static void print(FloatNdArray a) { } else { a.elements(a.shape().numDimensions() - 1) .forEachIndexed( - (idx, v) -> { - System.out.printf("%s == %f\n", Arrays.toString(idx), v.getFloat()); - }); + (idx, v) -> System.out.printf("%s == %f\n", Arrays.toString(idx), v.getFloat())); } System.out.println(); } @@ -916,9 +866,7 @@ public static FloatNdArray create(float[] y, Shape shape) { result .elements(shape.numDimensions() - 1) .forEachIndexed( - (idx, v) -> { - v.setFloat(y[index.getAndAdd(1)]); - }); + (idx, v) -> v.setFloat(y[index.getAndAdd(1)])); return result; } } From 6c481312fac779a4aab02dcac6d3b3ea495f59c4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 26 Jan 2021 10:18:48 -0500 Subject: [PATCH 10/14] Clean up JavaDoc Make sure Tests test TFloat32 1nd TFloat64 --- .../tensorflow/framework/regularizers/L1.java | 4 ++- .../framework/regularizers/L1L2.java | 7 ++-- .../framework/regularizers/L1_L2.java | 6 ++-- .../tensorflow/framework/regularizers/L2.java | 5 +-- .../framework/regularizers/Regularizer.java | 14 ++++---- .../regularizers/RegularizerLoss.java | 2 +- .../framework/regularizers/L1L2Test.java | 32 ++++++++++++------- .../framework/regularizers/L1Test.java | 11 +++---- .../framework/regularizers/L1_L2Test.java | 30 ++++++++++++----- .../framework/regularizers/L2Test.java | 15 ++++----- .../regularizers/RegularizerLossTest.java | 7 ++++ 11 files changed, 83 insertions(+), 50 deletions(-) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 8d3469dca98..740338350e3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -18,7 +18,8 @@ import org.tensorflow.types.family.TNumber; /** - * A regularizer that applies a L1 regularization penalty. + * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) Regression, + * regularization penalty. * *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) * @@ -41,6 +42,7 @@ public L1(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param l1 the L1 regularization penalty + * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ public L1(Ops tf, float l1, Class type) { super(tf, l1, null, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 565384f1fa5..2908387b9d4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -30,6 +30,9 @@ * *

loss = l2 * reduceSum(square(x))
* + *

The difference between this class and the {@link L1_L2} is use of the default regularization + * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. + * * @param the data type for the weights */ public class L1L2 extends Regularizer { @@ -54,8 +57,8 @@ public L1L2(Ops tf, Class type) { * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. * @param type the data type for the weights - * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link }NaN or is - * infinite. + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} + * of {@link Float#isInfinite} */ public L1L2(Ops tf, Float l1, Float l2, Class type) { super(tf, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java index fe763236cea..3e21c136a9d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java @@ -36,7 +36,7 @@ public class L1_L2 extends L1L2 { /** - * Create a regularizer that applies an L1 and l2 regularization penalty of {@link + * Creates a regularizer that applies an L1 and l2 regularization penalty of {@link * #DEFAULT_REGULARIZATION_PENALTY} * * @param tf the TensorFlow Ops @@ -46,10 +46,12 @@ public L1_L2(Ops tf, Class type) { } /** - * Create a regularizer that applies an L1 and l2 regularization penalty + * Creates a regularizer that applies an L1 and l2 regularization penalty * * @param tf the TensorFlow Ops * @param l1 the L1 regularization penalty + * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite. */ public L1_L2(Ops tf, Float l1, Float l2, Class type) { super( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 3ce0f84581c..8298cd4aba5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -18,11 +18,11 @@ import org.tensorflow.types.family.TNumber; /** - * A regularizer that applies a L2 regularization penalty. + * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) * - * @param the data type for the weights + * @param the data type for the operands and result */ public class L2 extends L1L2 { @@ -41,6 +41,7 @@ public L2(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ public L2(Ops tf, float l2, Class type) { super(tf, null, l2, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index 20701fb6348..906efee7f3d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -20,12 +20,12 @@ import org.tensorflow.types.family.TNumber; /** - * Regularizer base class. + * Base class for Regularizers * *

Regularizers allow you to apply penalties on layer parameters or layer activity during * optimization. These penalties are summed into the loss function that the network optimizes. * - * @param the data type of the result + * @param the data type of the operands and result */ public abstract class Regularizer { @@ -36,7 +36,7 @@ public abstract class Regularizer { protected Class type; /** - * Create a Regularizer + * Creates a Regularizer * * @param tf the TensorFlow ops. */ @@ -44,7 +44,7 @@ protected Regularizer(Ops tf, Class type) { this(tf, null, type); } /** - * Create a Regularizer + * Creates a Regularizer * * @param tf the TensorFlow ops. */ @@ -55,19 +55,19 @@ protected Regularizer(Ops tf, String name, Class type) { } /** - * Returns this Regularizer as a Loss This is a convenience tp regularize a loss. Only + * Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only * sampleWeights are applied to the regularizer. * * @return this Regularizer as a Loss */ public Loss asLoss() { - return new RegularizerLoss(this.tf, this); + return new RegularizerLoss<>(this.tf, this); } /** * Computes a regularization penalty from an input. * - * @param input teh weighted input + * @param input the weighted input * @return the result of computing the regularization penalty */ public abstract Operand call(Operand input); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index e750eabebab..bdd7cfcf1cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -24,7 +24,7 @@ /** * A Regularizer call wrapped as a Loss instance * - *

This class facilitates using a regularizing as a loss, only sampleWeights are + *

This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. * * @param the datatype for the weights type diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 178468f42e8..0f3213ed6eb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -35,21 +35,19 @@ public void testCreate() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCall() { + public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); L1L2 instance = new L1L2<>(tf, TFloat32.class); Operand result = instance.call(tf.constant(555f)); - session.evaluate(0, result); + session.evaluate(0f, result); } } - /** Test of call method, of class L1L2. */ @Test - public void testCallNO() { + public void testCallL1L20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -61,9 +59,8 @@ public void testCallNO() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1L2() { + public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -77,9 +74,23 @@ public void testCallL1L2() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1() { + public void testCallL1L2TFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat64.class); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.09f); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL2Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -92,9 +103,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL2() { + public void testCallL1Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index 1a4416b3302..6d67bb44d3c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -31,9 +31,8 @@ public void testCreate() { } } - /** Test of call method, of class L1L2. 
*/ @Test - public void testCallNO() { + public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -41,13 +40,12 @@ public void testCallNO() { float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); - session.evaluate(0, result); + session.evaluate(0f, result); } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1() { + public void testCallL1TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -60,9 +58,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1_2() { + public void testCallL1TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java index d35179cdd62..5aeb5a5d9ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java @@ -35,7 +35,6 @@ public void testCreate() { } } - /** Test of call method, of class L1_L2<>. */ @Test public void testCallZero() { for (TestSession.Mode tfMode : tfModes) @@ -47,9 +46,8 @@ public void testCallZero() { } } - /** Test of call method, of class L1_L2<>. */ @Test - public void testCallNO() { + public void testCallDefaultTFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -67,7 +65,25 @@ public void testCallNO() { } } - /** Test of call method, of class L1_L2<>. */ + @Test + public void testCallDefaultTFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1_L2 instance = new L1_L2<>(tf, TFloat64.class); + double[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = + regularizeL1L2( + w, + Regularizer.DEFAULT_REGULARIZATION_PENALTY, + Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + @Test public void testCallL1L2() { for (TestSession.Mode tfMode : tfModes) @@ -83,9 +99,8 @@ public void testCallL1L2() { } } - /** Test of call method, of class L1_L2<>. */ @Test - public void testCallL1() { + public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -98,9 +113,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1_L2<>. 
*/ @Test - public void testCallL2() { + public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index cbb019796f1..7f593a2dd14 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -25,15 +25,14 @@ public void testCreate() { assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2<>(tf, TFloat32.class); - assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); - assertEquals(0.f, instance.getL1()); + L2 instance64 = new L2<>(tf, TFloat64.class); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); + assertEquals(0.f, instance64.getL1()); } } - /** Test of call method, of class L1L2. */ @Test - public void testCallNO() { + public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -45,9 +44,8 @@ public void testCallNO() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1() { + public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -61,9 +59,8 @@ public void testCallL1() { } } - /** Test of call method, of class L1L2. */ @Test - public void testCallL1_2() { + public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java new file mode 100644 index 00000000000..e694d9409a0 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -0,0 +1,7 @@ +package org.tensorflow.framework.regularizers; + +import static org.junit.jupiter.api.Assertions.*; + +class RegularizerLossTest { + +} \ No newline at end of file From 3c45a87aeda72a5f5f755dc7aac97f9311ff2a4e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 13 Feb 2021 13:45:37 -0500 Subject: [PATCH 11/14] Updates based on comments from PR. Removed generic from Regularizer class and changed the call method to define the generic return based on the weights parameter. Added static method l1_l2() to L1L2 class. Fixed JavaDoc comments. 
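
For illustration, a rough before/after sketch of what this refactor means for calling code. The snippet is not part of the commit; it assumes an in-scope `Ops tf` (for example from one of the test sessions used in these tests) and imports of `Operand` and `TFloat32`, and it reuses the 2x3 float weights from the tests.

    // Sketch only: L1L2 usage before and after this commit.
    Operand<TFloat32> weights =
        tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});

    // Before: the regularizer was generic and required a Class token for the weight type.
    //   L1L2<TFloat32> reg = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class);

    // After: no class-level generic; call() infers its return type from the weights operand.
    L1L2 reg = new L1L2(tf, 0.01f, 0.02f);
    Operand<TFloat32> penalty = reg.call(weights);

    // New static factory added in this commit: l1 and l2 both set to DEFAULT_REGULARIZATION_PENALTY.
    L1L2 defaults = L1L2.l1_l2(tf);
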
--- .../tensorflow/framework/regularizers/L1.java | 17 +- .../framework/regularizers/L1L2.java | 38 ++--- .../framework/regularizers/L1_L2.java | 20 +-- .../tensorflow/framework/regularizers/L2.java | 13 +- .../framework/regularizers/Regularizer.java | 20 +-- .../regularizers/RegularizerLoss.java | 20 +-- .../framework/regularizers/L1L2Test.java | 20 +-- .../framework/regularizers/L1Test.java | 12 +- .../framework/regularizers/L1_L2Test.java | 20 +-- .../framework/regularizers/L2Test.java | 12 +- .../regularizers/RegularizerLossTest.java | 24 ++- .../org/tensorflow/framework/utils/ND.java | 155 ++++++++++-------- 12 files changed, 202 insertions(+), 169 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 740338350e3..074e881c1cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -15,17 +15,14 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; /** - * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) Regression, - * regularization penalty. + * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) + * Regression, regularization penalty. * *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) - * - * @param the data type for the weights */ -public class L1 extends L1L2 { +public class L1 extends L1L2 { /** * Create a regularizer that applies an L1 regularization penalty of {@link @@ -33,8 +30,8 @@ public class L1 extends L1L2 { * * @param tf the TensorFlow Ops */ - public L1(Ops tf, Class type) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + public L1(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY); } /** @@ -44,7 +41,7 @@ public L1(Ops tf, Class type) { * @param l1 the L1 regularization penalty * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ - public L1(Ops tf, float l1, Class type) { - super(tf, l1, null, type); + public L1(Ops tf, float l1) { + super(tf, l1, null); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 2908387b9d4..89b407e0940 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -32,22 +32,19 @@ * *

The difference between this class and the {@link L1_L2} is use of the default regularization * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. - * - * @param the data type for the weights */ -public class L1L2 extends Regularizer { +public class L1L2 extends Regularizer { - private final Float l1; - private final Float l2; + private final float l1; + private final float l2; /** - * Creates an L1L2 regularizer with no l1 or l2 penalty with default penal + * Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty * * @param tf the TensorFlow Ops - * @param type the data type for the weights */ - public L1L2(Ops tf, Class type) { - this(tf, null, null, type); + public L1L2(Ops tf) { + this(tf, null, null); } /** @@ -56,12 +53,11 @@ public L1L2(Ops tf, Class type) { * @param tf the TensorFlow Ops * @param l1 L1 regularization factor, if null it is set to 0. * @param l2 L2 regularization factor, if null it is set to 0. - * @param type the data type for the weights * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, Float l1, Float l2, Class type) { - super(tf, type); + public L1L2(Ops tf, Float l1, Float l2) { + super(tf); if (l1 != null) { if (l1.isNaN() || l1.isInfinite()) { throw new IllegalArgumentException( @@ -86,25 +82,29 @@ public L1L2(Ops tf, Float l1, Float l2, Class type) { } } + public static L1L2 l1_l2(Ops tf) { + return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + } + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Ops tf = getTF(); - if (this.getL1() == null && this.getL2() == null) { + if (this.getL1() == 0f && this.getL2() == 0f) { return tf.dtypes.cast(tf.constant(0), input.type()); } Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); - if (this.getL1() != null && this.getL1() != 0.f) { + if (this.getL1() != 0.f) { Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); Operand abs = tf.math.abs(input); Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); } - if (this.getL2() != null && this.getL2() != 0.f) { + if (this.getL2() != 0.f) { Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); - Operand sqr = tf.math.abs(input); + Operand sqr = tf.math.square(input); Operand reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); } @@ -117,7 +117,7 @@ public Operand call(Operand input) { * * @return the L1 regularization factor */ - public Float getL1() { + public float getL1() { return l1; } @@ -126,7 +126,7 @@ public Float getL1() { * * @return the L2 regularization factor */ - public Float getL2() { + public float getL2() { return l2; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java index 3e21c136a9d..44e04ad4d94 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; /** * A regularizer that applies both 
L1 and L2 regularization penalties. @@ -30,10 +29,8 @@ * *

The difference between this class and the {@link L1L2} is use of the default regularization * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. - * - * @param the data type for the weights */ -public class L1_L2 extends L1L2 { +public class L1_L2 extends L1L2 { /** * Creates a regularizer that applies an L1 and l2 regularization penalty of {@link @@ -41,23 +38,24 @@ public class L1_L2 extends L1L2 { * * @param tf the TensorFlow Ops */ - public L1_L2(Ops tf, Class type) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY, type); + public L1_L2(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } /** * Creates a regularizer that applies an L1 and l2 regularization penalty * * @param tf the TensorFlow Ops - * @param l1 the L1 regularization penalty - * @param l2 the L2 regularization penalty + * @param l1 the L1 regularization penalty. If null, then l1 will be set to {@link + * #DEFAULT_REGULARIZATION_PENALTY}. + * @param l2 the L2 regularization penalty. If null, then l2 will be set to {@link + * #DEFAULT_REGULARIZATION_PENALTY}. * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite. */ - public L1_L2(Ops tf, Float l1, Float l2, Class type) { + public L1_L2(Ops tf, Float l1, Float l2) { super( tf, l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1, - l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2, - type); + l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java index 8298cd4aba5..b09b93a76d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -15,16 +15,13 @@ package org.tensorflow.framework.regularizers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TNumber; /** * A regularizer that applies a L2 (Ridge Regression) regularization penalty. * *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) - * - * @param the data type for the operands and result */ -public class L2 extends L1L2 { +public class L2 extends L1L2 { /** * Create a regularizer that applies an L2 regularization penalty of {@link @@ -32,8 +29,8 @@ public class L2 extends L1L2 { * * @param tf the TensorFlow Ops */ - public L2(Ops tf, Class type) { - this(tf, DEFAULT_REGULARIZATION_PENALTY, type); + public L2(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY); } /** @@ -43,7 +40,7 @@ public L2(Ops tf, Class type) { * @param l2 the L2 regularization penalty * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. */ - public L2(Ops tf, float l2, Class type) { - super(tf, null, l2, type); + public L2(Ops tf, float l2) { + super(tf, null, l2); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java index 906efee7f3d..d1c17d4fc8c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -24,33 +24,31 @@ * *

Regularizers allow you to apply penalties on layer parameters or layer activity during * optimization. These penalties are summed into the loss function that the network optimizes. - * - * @param the data type of the operands and result */ -public abstract class Regularizer { +public abstract class Regularizer { public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; private final Ops tf; private final String name; - protected Class type; /** - * Creates a Regularizer + * Creates a Regularizer, using {@link Class#getSimpleName()} for the name * * @param tf the TensorFlow ops. */ - protected Regularizer(Ops tf, Class type) { - this(tf, null, type); + protected Regularizer(Ops tf) { + this(tf, null); } /** * Creates a Regularizer * * @param tf the TensorFlow ops. + * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the + * name. */ - protected Regularizer(Ops tf, String name, Class type) { + protected Regularizer(Ops tf, String name) { this.tf = tf; - this.type = type; this.name = name == null ? this.getClass().getSimpleName() : name; } @@ -61,7 +59,7 @@ protected Regularizer(Ops tf, String name, Class type) { * @return this Regularizer as a Loss */ public Loss asLoss() { - return new RegularizerLoss<>(this.tf, this); + return new RegularizerLoss(this.tf, this); } /** @@ -70,7 +68,7 @@ public Loss asLoss() { * @param input the weighted input * @return the result of computing the regularization penalty */ - public abstract Operand call(Operand input); + public abstract Operand call(Operand input); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java index bdd7cfcf1cd..d3adfbf68ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -19,27 +19,24 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Regularizer call wrapped as a Loss instance * *

This class facilitates using a regularizer as a loss, only sampleWeights are * regularized. - * - * @param the datatype for the weights type */ -class RegularizerLoss extends Loss { +class RegularizerLoss extends Loss { + + private final Regularizer regularizer; - private final Regularizer regularizer; - private final Class type; /** * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, Regularizer regularizer) { + public RegularizerLoss(Ops tf, Regularizer regularizer) { this(tf, null, regularizer); } @@ -48,11 +45,11 @@ public RegularizerLoss(Ops tf, Regularizer regularizer) { * * @param tf the TensorFlow Ops * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + * @param regularizer the regularizer used to calculate the loss */ - public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { + public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { super(tf, name); this.regularizer = regularizer; - this.type = regularizer.type; } @@ -63,7 +60,6 @@ public Operand call( if (sampleWeights == null) { throw new IllegalArgumentException("sampleWeights cannot be null"); } - Operand result = regularizer.call(cast(getTF(), sampleWeights, type)); - return cast(tf, result, sampleWeights.type()); + return regularizer.call(sampleWeights); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java index 0f3213ed6eb..3c6dd83731b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -17,19 +17,19 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.2f, 0.3f, TFloat32.class); + L1L2 instance = new L1L2(tf, 0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1L2<>(tf, null, null, TFloat32.class); + instance = new L1L2(tf, null, null); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2<>(tf, 0.5f, null, TFloat32.class); + instance = new L1L2(tf, 0.5f, null); assertEquals(0.5f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1L2<>(tf, null, 0.5f, TFloat32.class); + instance = new L1L2(tf, null, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); } @@ -40,7 +40,7 @@ public void testCallDefaultsConstant() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, TFloat32.class); + L1L2 instance = new L1L2(tf); Operand result = instance.call(tf.constant(555f)); session.evaluate(0f, result); } @@ -51,7 +51,7 @@ public void testCallL1L20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, TFloat32.class); + L1L2 instance = new L1L2(tf); Operand weights = tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); Operand result = 
instance.call(weights); @@ -64,7 +64,7 @@ public void testCallL1L2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat32.class); + L1L2 instance = new L1L2(tf, 0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -79,7 +79,7 @@ public void testCallL1L2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.01f, 0.02f, TFloat64.class); + L1L2 instance = new L1L2(tf, 0.01f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -94,7 +94,7 @@ public void testCallL2Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, 0.01f, null, TFloat32.class); + L1L2 instance = new L1L2(tf, 0.01f, null); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -108,7 +108,7 @@ public void testCallL1Null() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1L2 instance = new L1L2<>(tf, null, 0.02f, TFloat64.class); + L1L2 instance = new L1L2(tf, null, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java index 6d67bb44d3c..0e42a257816 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -17,15 +17,15 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.2f, TFloat32.class); + L1 instance = new L1(tf, 0.2f); assertEquals(0.2f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1<>(tf, 0f, TFloat32.class); + instance = new L1(tf, 0f); assertEquals(0.f, instance.getL1()); assertEquals(0.f, instance.getL2()); - instance = new L1<>(tf, TFloat32.class); + instance = new L1(tf); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(0.f, instance.getL2()); } @@ -36,7 +36,7 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.0f, TFloat32.class); + L1 instance = new L1(tf, 0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -49,7 +49,7 @@ public void testCallL1TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, TFloat32.class); + L1 instance = new L1(tf); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); 
Operand result = instance.call(weights); @@ -63,7 +63,7 @@ public void testCallL1TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1 instance = new L1<>(tf, 0.02f, TFloat64.class); + L1 instance = new L1(tf, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java index 5aeb5a5d9ad..e4b4e7cc7a3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java @@ -17,19 +17,19 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0.2f, 0.3f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0.2f, 0.3f); assertEquals(0.2f, instance.getL1()); assertEquals(0.3f, instance.getL2()); - instance = new L1_L2<>(tf, 0.5f, 0f, TFloat32.class); + instance = new L1_L2(tf, 0.5f, 0f); assertEquals(0.5f, instance.getL1()); assertEquals(0f, instance.getL2()); - instance = new L1_L2<>(tf, 0f, 0.5f, TFloat32.class); + instance = new L1_L2(tf, 0f, 0.5f); assertEquals(0.f, instance.getL1()); assertEquals(0.5f, instance.getL2()); - instance = new L1_L2<>(tf, TFloat32.class); + instance = new L1_L2(tf); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); } @@ -40,7 +40,7 @@ public void testCallZero() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0f, 0f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0f, 0f); Operand result = instance.call(tf.constant(555f)); session.evaluate(0, result); } @@ -51,7 +51,7 @@ public void testCallDefaultTFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, TFloat32.class); + L1_L2 instance = new L1_L2(tf); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -70,7 +70,7 @@ public void testCallDefaultTFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, TFloat64.class); + L1_L2 instance = new L1_L2(tf); double[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -89,7 +89,7 @@ public void testCallL1L2() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0.01f, 0.02f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0.01f, 0.02f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -104,7 +104,7 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) 
{ Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0.01f, 0f, TFloat32.class); + L1_L2 instance = new L1_L2(tf, 0.01f, 0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -118,7 +118,7 @@ public void testCallL10() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L1_L2 instance = new L1_L2<>(tf, 0f, 0.02f, TFloat64.class); + L1_L2 instance = new L1_L2(tf, 0f, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java index 7f593a2dd14..aba036ee306 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -17,15 +17,15 @@ public void testCreate() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, 0.2f, TFloat32.class); + L2 instance = new L2(tf, 0.2f); assertEquals(0.2f, instance.getL2()); assertEquals(0.f, instance.getL1()); - instance = new L2<>(tf, 0f, TFloat32.class); + instance = new L2(tf, 0f); assertEquals(0.f, instance.getL2()); assertEquals(0.f, instance.getL1()); - L2 instance64 = new L2<>(tf, TFloat64.class); + L2 instance64 = new L2(tf); assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); assertEquals(0.f, instance64.getL1()); } @@ -36,7 +36,7 @@ public void testCallL20() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, 0.0f, TFloat32.class); + L2 instance = new L2(tf, 0.0f); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -49,7 +49,7 @@ public void testCallL2TFloat32() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, TFloat32.class); + L2 instance = new L2(tf); float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); @@ -64,7 +64,7 @@ public void testCallL2TFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - L2 instance = new L2<>(tf, 0.02f, TFloat64.class); + L2 instance = new L2(tf, 0.02f); double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; Operand weights = tf.constant(w); Operand result = instance.call(weights); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java index e694d9409a0..836503af1fa 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -1,7 +1,27 @@ package org.tensorflow.framework.regularizers; -import static org.junit.jupiter.api.Assertions.*; +import 
org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; class RegularizerLossTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; -} \ No newline at end of file + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 regularizer = new L1L2(tf, 0.01f, null); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand regularizerResult = regularizer.call(weights); + RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer); + + Operand loss = lossInstance.call(null, null, weights); + session.evaluate(regularizerResult, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index d0cbae56628..c0c0f12fbf9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,11 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.*; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -98,22 +94,9 @@ public static FloatNdArray sqrt(FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat((float) Math.sqrt(v.getFloat()), idx)); - return result; - } - - /** - * Gets the square root of an array. - * - * @param a the array - * @return the square root of the array. 
- */ - public static DoubleNdArray sqrt(DoubleNdArray a) { - DoubleNdArray result = NdArrays.ofDoubles(a.shape()); - int nDims = a.shape().numDimensions(); - a.elements(nDims - 1) - .forEachIndexed( - (idx, v) -> result.setDouble(Math.sqrt(v.getDouble()), idx)); + (idx, v) -> { + result.setFloat((float) Math.sqrt(v.getFloat()), idx); + }); return result; } @@ -128,7 +111,9 @@ public static FloatNdArray square(FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() * v.getFloat(), idx)); + (idx, v) -> { + result.setFloat(v.getFloat() * v.getFloat(), idx); + }); return result; } @@ -143,7 +128,9 @@ public static DoubleNdArray square(DoubleNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setDouble(v.getDouble() * v.getDouble(), idx)); + (idx, v) -> { + result.setDouble(v.getDouble() * v.getDouble(), idx); + }); return result; } @@ -161,7 +148,9 @@ public static FloatNdArray add(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() + b.getFloat(idx), idx)); + (idx, v) -> { + result.setFloat(v.getFloat() + b.getFloat(idx), idx); + }); return result; } @@ -178,7 +167,9 @@ public static FloatNdArray add(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() + scalar, idx)); + (idx, v) -> { + result.setFloat(v.getFloat() + scalar, idx); + }); return result; } @@ -207,7 +198,9 @@ public static FloatNdArray sub(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() - b.getFloat(idx), idx)); + (idx, v) -> { + result.setFloat(v.getFloat() - b.getFloat(idx), idx); + }); return result; } @@ -223,7 +216,9 @@ public static FloatNdArray sub(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() - scalar, idx)); + (idx, v) -> { + result.setFloat(v.getFloat() - scalar, idx); + }); return result; } @@ -239,7 +234,9 @@ public static FloatNdArray sub(float scalar, FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(scalar - v.getFloat(), idx)); + (idx, v) -> { + result.setFloat(scalar - v.getFloat(), idx); + }); return result; } @@ -311,22 +308,22 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) { if (!a.shape().equals(b.shape())) throw new IllegalArgumentException( - String.format( - "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); boolean sameSize = a.shape().size() == b.shape().size(); DoubleNdArray result = NdArrays.ofDoubles(a.shape()); int nDims = a.shape().numDimensions(); a.elements(nDims - 1) - .forEachIndexed( - (idx, v) -> { - if (sameSize) { - result.setDouble(v.getDouble() * b.getDouble(idx), idx); - } else { - double value = v.getDouble() * b.getDouble(idx[0], 0L); - result.setDouble(value, idx); - } - }); + .forEachIndexed( + (idx, v) -> { + if (sameSize) { + result.setDouble(v.getDouble() * b.getDouble(idx), idx); + } else { + double value = v.getDouble() * b.getDouble(idx[0], 0L); + 
result.setDouble(value, idx); + } + }); return result; } @@ -373,7 +370,9 @@ public static FloatNdArray div(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() / b.getFloat(idx), idx)); + (idx, v) -> { + result.setFloat(v.getFloat() / b.getFloat(idx), idx); + }); return result; } @@ -390,7 +389,9 @@ public static FloatNdArray div(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(v.getFloat() / scalar, idx)); + (idx, v) -> { + result.setFloat(v.getFloat() / scalar, idx); + }); return result; } @@ -427,7 +428,9 @@ public static FloatNdArray pow(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat((float) Math.pow(v.getFloat(), b.getFloat(idx)), idx)); + (idx, v) -> { + result.setFloat((float) Math.pow(v.getFloat(), b.getFloat(idx)), idx); + }); return result; } @@ -443,7 +446,9 @@ public static FloatNdArray pow(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat((float) Math.pow(v.getFloat(), scalar), idx)); + (idx, v) -> { + result.setFloat((float) Math.pow(v.getFloat(), scalar), idx); + }); return result; } @@ -459,7 +464,9 @@ public static FloatNdArray pow(float scalar, FloatNdArray a) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat((float) Math.pow(scalar, v.getFloat()), idx)); + (idx, v) -> { + result.setFloat((float) Math.pow(scalar, v.getFloat()), idx); + }); return result; } @@ -475,7 +482,9 @@ public static float[] flatten(FloatNdArray a) { AtomicInteger counter = new AtomicInteger(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result[counter.getAndAdd(1)] = v.getFloat()); + (idx, v) -> { + result[counter.getAndAdd(1)] = v.getFloat(); + }); return result; } @@ -507,7 +516,7 @@ public static float min(FloatNdArray a) { * Get the maximum value of comparing the arrays * * @param a the first array - * @param b the second array + * @param a the second array * @return the resulting array with the maximum values between each element of the arrays. * @throws AssertionError if the two arrays are not the same size. */ @@ -518,7 +527,9 @@ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(Math.max(v.getFloat(), b.getFloat(idx)), idx)); + (idx, v) -> { + result.setFloat(Math.max(v.getFloat(), b.getFloat(idx)), idx); + }); return result; } @@ -535,7 +546,9 @@ public static FloatNdArray max(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(Math.max(v.getFloat(), scalar), idx)); + (idx, v) -> { + result.setFloat(Math.max(v.getFloat(), scalar), idx); + }); return result; } @@ -555,7 +568,7 @@ public static FloatNdArray max(float scalar, FloatNdArray a) { * Get the minimum value of comparing the arrays * * @param a the first array - * @param b the second array + * @param a the second array * @return the resulting array with the minimum values between each element of the arrays. * @throws AssertionError if the two arrays are not the same size. 
*/ @@ -566,7 +579,9 @@ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat(Math.min(v.getFloat(), b.getFloat(idx)), idx)); + (idx, v) -> { + result.setFloat(Math.min(v.getFloat(), b.getFloat(idx)), idx); + }); return result; } @@ -583,7 +598,9 @@ public static FloatNdArray min(FloatNdArray a, float scalar) { int nDims = a.shape().numDimensions(); a.elements(nDims - 1) .forEachIndexed( - (idx, v) -> result.setFloat( Math.min(v.getFloat(), scalar), idx)); + (idx, v) -> { + result.setFloat(Math.min(v.getFloat(), scalar), idx); + }); return result; } @@ -619,7 +636,7 @@ public static FloatNdArray abs(FloatNdArray a) { */ public static DoubleNdArray abs(DoubleNdArray a) { DoubleNdArray result = NdArrays.ofDoubles(a.shape()); - a.scalars().forEachIndexed((idx, f) -> result.setDouble( Math.abs(f.getDouble()), idx)); + a.scalars().forEachIndexed((idx, f) -> result.setDouble(Math.abs(f.getDouble()), idx)); return result; } @@ -635,8 +652,6 @@ public static FloatNdArray sum(FloatNdArray a) { return NdArrays.scalarOf(sum.get()); } - - /** * Sum all elements of an array based on the specified axis * @@ -666,7 +681,9 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { a.scalars() .forEachIndexed( - (idx, f) -> sums[(int) idx[xis]] += f.getFloat()); + (idx, f) -> { + sums[(int) idx[xis]] += f.getFloat(); + }); if (keepDims) { long[] newDims = shape.asArray(); @@ -676,7 +693,9 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { arrayK .elements(newDims.length - 1) .forEachIndexed( - (idx, v) -> v.setFloat(sums[counter.getAndAdd(1)])); + (idx, v) -> { + v.setFloat(sums[counter.getAndAdd(1)]); + }); return arrayK; } else { return NdArrays.vectorOf(sums); @@ -736,7 +755,7 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis) { } /** - * Sum all elements of an array over the specified axis + * Sum all elements of an array over on the specified axis * * @param a the array * @param axis the axis to sum @@ -752,8 +771,10 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { final double[] sums = new double[(int) axisSize]; a.scalars() - .forEachIndexed( - (idx, f) -> sums[(int) idx[xis]] += f.getDouble()); + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); if (keepDims) { long[] newDims = shape.asArray(); @@ -761,9 +782,11 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { final AtomicInteger counter = new AtomicInteger(); DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); arrayK - .elements(newDims.length - 1) - .forEachIndexed( - (idx, v) -> v.setDouble(sums[counter.getAndAdd(1)])); + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); return arrayK; } else { return NdArrays.vectorOf(sums); @@ -835,7 +858,9 @@ public static void print(FloatNdArray a) { } else { a.elements(a.shape().numDimensions() - 1) .forEachIndexed( - (idx, v) -> System.out.printf("%s == %f\n", Arrays.toString(idx), v.getFloat())); + (idx, v) -> { + System.out.printf("%s == %f\n", Arrays.toString(idx), v.getFloat()); + }); } System.out.println(); } @@ -866,7 +891,9 @@ public static FloatNdArray create(float[] y, Shape shape) { result .elements(shape.numDimensions() - 1) .forEachIndexed( - (idx, v) -> v.setFloat(y[index.getAndAdd(1)])); + (idx, v) -> { + v.setFloat(y[index.getAndAdd(1)]); + }); 
return result; } } From 2bd80b32f55d5f8fe07519caeebb22b6fa2a5c00 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 14 Feb 2021 10:49:12 -0500 Subject: [PATCH 12/14] Add JavDoc to new method l1_l2 --- .../java/org/tensorflow/framework/regularizers/L1L2.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 89b407e0940..2670de799e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -82,6 +82,13 @@ public L1L2(Ops tf, Float l1, Float l2) { } } + /** + * Creates an L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 + * values. + * + * @param tf the TensorFlow Ops + * @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values. + */ public static L1L2 l1_l2(Ops tf) { return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } From da7a10b152eeb370d4f83423a9a7def0aab91086 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 14 Feb 2021 11:04:00 -0500 Subject: [PATCH 13/14] change l1_l2 to create. --- .../main/java/org/tensorflow/framework/regularizers/L1L2.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index 2670de799e1..b2a7edbb187 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -89,7 +89,7 @@ public L1L2(Ops tf, Float l1, Float l2) { * @param tf the TensorFlow Ops * @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values. 
*/ - public static L1L2 l1_l2(Ops tf) { + public static L1L2 create(Ops tf) { return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } From 9ea1d9af4f008e169174fecde90923caf6784996 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 14 Feb 2021 12:06:08 -0500 Subject: [PATCH 14/14] delete class L1_L2 modified Float to float for l1 and l2 parameters Change ctor L1L2(Ops tf) to use DEFAULT_REGULARIZATION_PENALTY for l1/l2 parameters Fix JavaDoc --- .../tensorflow/framework/regularizers/L1.java | 2 +- .../framework/regularizers/L1L2.java | 49 ++----- .../framework/regularizers/L1_L2.java | 61 -------- .../tensorflow/framework/regularizers/L2.java | 2 +- .../framework/regularizers/Regularizer.java | 1 + .../framework/regularizers/L1L2Test.java | 24 ++-- .../framework/regularizers/L1_L2Test.java | 130 ------------------ .../regularizers/RegularizerLossTest.java | 2 +- 8 files changed, 33 insertions(+), 238 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java index 074e881c1cd..7c8c2a1360a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -42,6 +42,6 @@ public L1(Ops tf) { * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. */ public L1(Ops tf, float l1) { - super(tf, l1, null); + super(tf, l1, 0f); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java index b2a7edbb187..29e411f9897 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -30,8 +30,6 @@ * *

loss = l2 * reduceSum(square(x))
* - *

The difference between this class and the {@link L1_L2} is use of the default regularization - * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0. */ public class L1L2 extends Regularizer { @@ -44,7 +42,7 @@ public class L1L2 extends Regularizer { * @param tf the TensorFlow Ops */ public L1L2(Ops tf) { - this(tf, null, null); + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); } /** @@ -56,42 +54,25 @@ public L1L2(Ops tf) { * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} * of {@link Float#isInfinite} */ - public L1L2(Ops tf, Float l1, Float l2) { + public L1L2(Ops tf, float l1, float l2) { super(tf); - if (l1 != null) { - if (l1.isNaN() || l1.isInfinite()) { - throw new IllegalArgumentException( - String.format( - "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", - l1)); - } - this.l1 = l1; - } else { - this.l1 = 0f; + if (Float.isNaN(l1) || Float.isInfinite(l1)) { + throw new IllegalArgumentException( + String.format( + "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l1)); } - if (l2 != null) { - if (l2.isNaN() || l2.isInfinite()) { - throw new IllegalArgumentException( - String.format( - "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", - l2)); - } - this.l2 = l2; - } else { - this.l2 = 0f; + this.l1 = l1; + + if (Float.isNaN(l2) || Float.isInfinite(l2)) { + throw new IllegalArgumentException( + String.format( + "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l2)); } + this.l2 = l2; } - /** - * Creates an L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 - * values. - * - * @param tf the TensorFlow Ops - * @return a L1L2 instance using {@link #DEFAULT_REGULARIZATION_PENALTY} for the l1 and l2 values. - */ - public static L1L2 create(Ops tf) { - return new L1L2(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); - } /** {@inheritDoc} */ @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java deleted file mode 100644 index 44e04ad4d94..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1_L2.java +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.regularizers; - -import org.tensorflow.op.Ops; - -/** - * A regularizer that applies both L1 and L2 regularization penalties. - * - *

The L1 regularization penalty is computed as: - * - *

loss = l1 * reduceSum(abs(x))
- * - *

The L2 regularization penalty is computed as - * - *

loss = l2 * reduceSum(square(x))
- * - *

- * <p>The difference between this class and the {@link L1L2} is use of the default regularization
- * penalty {@link #DEFAULT_REGULARIZATION_PENALTY}, whereas {@link L1L2} defaults to 0.
- */
-public class L1_L2 extends L1L2 {
-
-  /**
-   * Creates a regularizer that applies an L1 and l2 regularization penalty of {@link
-   * #DEFAULT_REGULARIZATION_PENALTY}
-   *
-   * @param tf the TensorFlow Ops
-   */
-  public L1_L2(Ops tf) {
-    this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY);
-  }
-
-  /**
-   * Creates a regularizer that applies an L1 and l2 regularization penalty
-   *
-   * @param tf the TensorFlow Ops
-   * @param l1 the L1 regularization penalty. If null, then l1 will be set to {@link
-   *     #DEFAULT_REGULARIZATION_PENALTY}.
-   * @param l2 the L2 regularization penalty. If null, then l2 will be set to {@link
-   *     #DEFAULT_REGULARIZATION_PENALTY}.
-   * @throws IllegalArgumentException if the l1 or l2 regularization factor is NaN or is infinite.
-   */
-  public L1_L2(Ops tf, Float l1, Float l2) {
-    super(
-        tf,
-        l1 == null ? DEFAULT_REGULARIZATION_PENALTY : l1,
-        l2 == null ? DEFAULT_REGULARIZATION_PENALTY : l2);
-  }
-}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
index b09b93a76d9..7b8f5b28a70 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
@@ -41,6 +41,6 @@ public L2(Ops tf) {
    * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite.
    */
   public L2(Ops tf, float l2) {
-    super(tf, null, l2);
+    super(tf, 0f, l2);
   }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
index d1c17d4fc8c..5d9ff0e3e10 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
@@ -67,6 +67,7 @@ public Loss asLoss() {
    *
    * @param input the weighted input
    * @return the result of computing the regularization penalty
+   * @param <R> the data type of the input and result
    */
   public abstract <R extends TNumber> Operand<R> call(Operand<R> input);
 
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
index 3c6dd83731b..181ae367f07 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
@@ -21,17 +21,21 @@ public void testCreate() {
         assertEquals(0.2f, instance.getL1());
         assertEquals(0.3f, instance.getL2());
 
-        instance = new L1L2(tf, null, null);
+        instance = new L1L2(tf, 0, 0);
         assertEquals(0.f, instance.getL1());
         assertEquals(0.f, instance.getL2());
 
-        instance = new L1L2(tf, 0.5f, null);
+        instance = new L1L2(tf, 0.5f, 0);
         assertEquals(0.5f, instance.getL1());
         assertEquals(0.f, instance.getL2());
 
-        instance = new L1L2(tf, null, 0.5f);
+        instance = new L1L2(tf, 0, 0.5f);
         assertEquals(0.f, instance.getL1());
         assertEquals(0.5f, instance.getL2());
+
+        instance = new L1L2(tf);
+        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
+        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2());
       }
   }
 
@@ -42,16 +46,16 @@ public void testCallDefaultsConstant() {
         Ops tf = session.getTF();
         L1L2 instance = new L1L2(tf);
         Operand<TFloat32> result = instance.call(tf.constant(555f));
-        session.evaluate(0f, result);
+        session.evaluate(3085.8f, result);
       }
   }
 
   @Test
-  public void testCallL1L20() {
+  public void testCallL1L2_0() {
     for (TestSession.Mode tfMode : tfModes)
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
-        L1L2 instance = new L1L2(tf);
+        L1L2 instance = new L1L2(tf, 0, 0);
         Operand<TFloat32> weights =
             tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});
         Operand<TFloat32> result = instance.call(weights);
@@ -90,11 +94,11 @@ public void testCallL1L2TFloat64() {
   }
 
   @Test
-  public void testCallL2Null() {
+  public void testCallL2_0() {
     for (TestSession.Mode tfMode : tfModes)
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
-        L1L2 instance = new L1L2(tf, 0.01f, null);
+        L1L2 instance = new L1L2(tf, 0.01f, 0);
         float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
         Operand<TFloat32> weights = tf.constant(w);
         Operand<TFloat32> result = instance.call(weights);
@@ -104,11 +108,11 @@ public void testCallL2Null() {
   }
 
   @Test
-  public void testCallL1Null() {
+  public void testCallL1_0() {
     for (TestSession.Mode tfMode : tfModes)
      try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
-        L1L2 instance = new L1L2(tf, null, 0.02f);
+        L1L2 instance = new L1L2(tf, 0, 0.02f);
         double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
         Operand<TFloat64> weights = tf.constant(w);
         Operand<TFloat64> result = instance.call(weights);
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java
deleted file mode 100644
index e4b4e7cc7a3..00000000000
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1_L2Test.java
+++ /dev/null
@@ -1,130 +0,0 @@
-package org.tensorflow.framework.regularizers;
-
-import org.junit.jupiter.api.Test;
-import org.tensorflow.Operand;
-import org.tensorflow.framework.utils.TestSession;
-import org.tensorflow.op.Ops;
-import org.tensorflow.types.TFloat32;
-import org.tensorflow.types.TFloat64;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-class L1_L2Test extends CommonTest {
-  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
-
-  @Test
-  public void testCreate() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf, 0.2f, 0.3f);
-        assertEquals(0.2f, instance.getL1());
-        assertEquals(0.3f, instance.getL2());
-
-        instance = new L1_L2(tf, 0.5f, 0f);
-        assertEquals(0.5f, instance.getL1());
-        assertEquals(0f, instance.getL2());
-
-        instance = new L1_L2(tf, 0f, 0.5f);
-        assertEquals(0.f, instance.getL1());
-        assertEquals(0.5f, instance.getL2());
-
-        instance = new L1_L2(tf);
-        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
-        assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2());
-      }
-  }
-
-  @Test
-  public void testCallZero() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf, 0f, 0f);
-        Operand<TFloat32> result = instance.call(tf.constant(555f));
-        session.evaluate(0, result);
-      }
-  }
-
-  @Test
-  public void testCallDefaultTFloat32() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf);
-        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
-        Operand<TFloat32> weights = tf.constant(w);
-        Operand<TFloat32> result = instance.call(weights);
-        float expected =
-            regularizeL1L2(
-                w,
-                Regularizer.DEFAULT_REGULARIZATION_PENALTY,
-                Regularizer.DEFAULT_REGULARIZATION_PENALTY);
-        session.setEpsilon(.01f);
-        session.evaluate(expected, result);
-      }
-  }
-
-  @Test
-  public void testCallDefaultTFloat64() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf);
-        double[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
-        Operand<TFloat64> weights = tf.constant(w);
-        Operand<TFloat64> result = instance.call(weights);
-        double expected =
-            regularizeL1L2(
-                w,
-                Regularizer.DEFAULT_REGULARIZATION_PENALTY,
-                Regularizer.DEFAULT_REGULARIZATION_PENALTY);
-        session.setEpsilon(.01f);
-        session.evaluate(expected, result);
-      }
-  }
-
-  @Test
-  public void testCallL1L2() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf, 0.01f, 0.02f);
-        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
-        Operand<TFloat32> weights = tf.constant(w);
-        Operand<TFloat32> result = instance.call(weights);
-        float expected = regularizeL1L2(w, 0.01f, 0.02f);
-        session.setEpsilon(.01f);
-        session.evaluate(expected, result);
-      }
-  }
-
-  @Test
-  public void testCallL20() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf, 0.01f, 0f);
-        float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
-        Operand<TFloat32> weights = tf.constant(w);
-        Operand<TFloat32> result = instance.call(weights);
-        float expected = regularizeL1(w, 0.01f);
-        session.evaluate(expected, result);
-      }
-  }
-
-  @Test
-  public void testCallL10() {
-    for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        L1_L2 instance = new L1_L2(tf, 0f, 0.02f);
-        double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
-        Operand<TFloat64> weights = tf.constant(w);
-        Operand<TFloat64> result = instance.call(weights);
-        double expected = regularizeL2(w, 0.02f);
-        session.setEpsilon(.01f);
-        session.evaluate(expected, result);
-      }
-  }
-}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
index 836503af1fa..fe2624cec3d 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
@@ -14,7 +14,7 @@ public void testCreate() {
     for (TestSession.Mode tfMode : tfModes)
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
-        L1L2 regularizer = new L1L2(tf, 0.01f, null);
+        L1L2 regularizer = new L1L2(tf, 0.01f, 0f);
         float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
         Operand<TFloat32> weights = tf.constant(w);
         Operand<TFloat32> regularizerResult = regularizer.call(weights);