Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
ef0ce67
Initial checkin of Keras Optimzers and helper classes.
JimClarke5 Jul 28, 2020
9c113a7
Added static final NAME to replace hardcoded String in the create met…
JimClarke5 Aug 20, 2020
824d487
Changed of method to use the DataType NAME attribute rather than hard…
JimClarke5 Aug 20, 2020
07a83a5
Added method WriteFieldWithInitializer to output a "final static Stri…
JimClarke5 Aug 20, 2020
3d26831
Added tf.nn.softmaxCrossEntropyWitLogits() and tf.nn.raw.softmaxCross…
JimClarke5 Aug 20, 2020
11cda5f
Moved SoftmaxCrossEntropyWithLogits and SparseSoftmaxCrossEntropyWit…
JimClarke5 Aug 20, 2020
9c7dfaa
Generated classes now have public static final String OP_NAME = "XXXX…
JimClarke5 Aug 20, 2020
84f49db
Generated classes now have public static final String OP_NAME = "XXXX…
JimClarke5 Aug 20, 2020
208b84a
fix dependencies for other Tensorflow Java modules
JimClarke5 Aug 20, 2020
3913161
formatting fix
JimClarke5 Aug 20, 2020
b5a7c0f
Fix ctors with name to properly pass the name to the the super ctor.
JimClarke5 Aug 20, 2020
fcba0a5
change asserts to IllegalArgumentException
JimClarke5 Aug 20, 2020
960cfc3
change asserts to IllegalArgumentException
JimClarke5 Aug 20, 2020
d37298a
Moved back to tests
JimClarke5 Aug 20, 2020
c68812c
Moved SoftmaxCrossEntropyWithLogits.java and SparseSoftmaxCrossEntrop…
JimClarke5 Aug 20, 2020
6b8eb26
Deleted files that are not necessary yet
JimClarke5 Aug 20, 2020
6515c24
Added nn.raw group for softmaxCrossEntropyWithLogits() and sparseSoft…
JimClarke5 Aug 20, 2020
76d0fe5
Added nn.raw group for softmaxCrossEntropyWithLogits() and sparseSoft…
JimClarke5 Aug 20, 2020
d2201df
Merge branch 'master' into master
JimClarke5 Aug 20, 2020
ab379d1
Refactor NN into individual operations under org.tensorflow.op.nn. Fi…
JimClarke5 Sep 3, 2020
889d67e
Refactor NN into individual operations under org.tensorflow.op.nn. Fi…
JimClarke5 Sep 3, 2020
515b799
Reformatted code
JimClarke5 Sep 3, 2020
5a9fe37
Added sub scope
JimClarke5 Sep 3, 2020
8d21dd7
Miscellaneous fixes based on review comments.
JimClarke5 Sep 3, 2020
4c3cc78
Fixed op_generator.cc to remove a spurious new line in the generated …
JimClarke5 Sep 3, 2020
44f530f
Changed back to non-generic Operand until we resolve how to handle ge…
JimClarke5 Sep 3, 2020
b8d3ac2
Regenerated due to creation of SoftmaxCrossEntropyWithLogits.java, S…
JimClarke5 Sep 3, 2020
c32fc5b
change snake case to camel case. format code
JimClarke5 Sep 7, 2020
171cd2f
clean upd warning, format code
JimClarke5 Sep 7, 2020
e9c3134
Added Adamax, Ftrl, and Nadam Optimizers. Added Optimizers enum for e…
JimClarke5 Sep 9, 2020
5c30a72
Removed optimize classes from tensorflow-keras, moved optimizer test …
JimClarke5 Sep 9, 2020
ebefc2e
Fixed generics
JimClarke5 Sep 9, 2020
7915e63
Fixed from Unit test results
JimClarke5 Sep 9, 2020
ec4f679
added @SuppressWarnings("unchecked") on Variable array
JimClarke5 Sep 9, 2020
c86d09b
Merge pull request #1 from tensorflow/master
JimClarke5 Sep 18, 2020
e9cd56a
Add initializers
JimClarke5 Sep 20, 2020
9360efe
Add initializers
JimClarke5 Sep 20, 2020
8e28bb5
Remove @author
JimClarke5 Sep 20, 2020
33530bb
Fix javadoc, change name of squeeze to reduce. Add logic to in reduc…
JimClarke5 Sep 20, 2020
3261888
Update JavaDoc to highlight difference between compatible shapes and …
JimClarke5 Sep 22, 2020
8860c56
Add handling of TUint8 data types, and add a predicate to evaluate TS…
JimClarke5 Sep 22, 2020
c3fa457
Change all ctors to require seed param, when seed is used. Refactor L…
JimClarke5 Sep 22, 2020
ccf7b53
Fix all tests to run in both Eager and Graph Mode.
JimClarke5 Sep 22, 2020
a8e4407
Fix formatting
JimClarke5 Sep 22, 2020
f5d1216
Fix formatting
JimClarke5 Sep 22, 2020
d834827
Fix formatting
JimClarke5 Sep 22, 2020
6fccd59
Fix formatting
JimClarke5 Sep 22, 2020
43f3fb7
Added Reproducible tests to make sure that for each initilaizer insta…
JimClarke5 Sep 22, 2020
dcf82a7
Fixed JavaDoc and default value references
JimClarke5 Sep 25, 2020
1f32de2
Fixed JavaDoc
JimClarke5 Sep 25, 2020
fb52dd4
Change UNTRUNCATED_NORMAL to NORMAL
JimClarke5 Sep 25, 2020
e328cbf
Changed Long seed to long seed.
JimClarke5 Sep 30, 2020
80f2fb0
Remove snake case
JimClarke5 Sep 30, 2020
b0b747e
Change snake case to camel case
JimClarke5 Oct 2, 2020
505d0d6
Change snake case to camel case
JimClarke5 Oct 2, 2020
c4a7bfb
Moved isCompatibleWith to Shape
JimClarke5 Oct 2, 2020
9229789
Moved isCompatibleWith to Shape
JimClarke5 Oct 3, 2020
f0934ea
Merge branch 'master' into Initializers1
JimClarke5 Oct 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 123 additions & 55 deletions ndarray/src/main/java/org/tensorflow/ndarray/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ public static Shape scalar() {
/**
* Create a Shape representing a scalar or an N-dimensional value.
*
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1),
* with the provided size for each dimension. A -1 indicates that the size of the corresponding
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created.
* For example:
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
* the provided size for each dimension. A -1 indicates that the size of the corresponding
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
* example:
*
* <pre>{@code
* // A 2-element vector.
Expand Down Expand Up @@ -88,11 +88,11 @@ public static Shape of(long... dimensionSizes) {
/**
* Returns the total number of elements a Tensor with this Shape would have.
*
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true,
* {@link Shape#UNKNOWN_SIZE} is returned.
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
* Shape#UNKNOWN_SIZE} is returned.
*
* @return The total number of elements a Tensor with this shape would have if it can be
* calculated, else {@link Shape#UNKNOWN_SIZE}.
* calculated, else {@link Shape#UNKNOWN_SIZE}.
*/
public long size() {
if (size == null) {
Expand All @@ -108,12 +108,11 @@ public long size() {
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
*
* @param i the index of the dimension to get the size for. If this Shape has a known number of
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative,
* in which case the position is counted from the end of the shape. E.g.:
* {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of
* the second to last dimension etc.
* dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
* otherwise.
* otherwise.
*/
public long size(int i) {
if (dimensionSizes == null) {
Expand Down Expand Up @@ -167,8 +166,8 @@ public boolean isUnknown() {
}

/**
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not
* change this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
*/
public long[] asArray() {
if (this.dimensionSizes == null) {
Expand All @@ -186,15 +185,16 @@ public int hashCode() {
/**
* Equals implementation for Shapes. Two Shapes are considered equal iff:
*
* <p>
* <ul>
* <li>the number of dimensions is defined and equal for both
* <li>the size of each dimension is defined and equal for both
* <li>the number of dimensions is defined and equal for both
* <li>the size of each dimension is defined and equal for both
* </ul>
*
* <p>If either Shape has unknown dimensions (even if they are the same in both) or if either
* shape has an unknown number of dimensions (even if both return {@code true} for
* {@link Shape#isUnknown()}), they are not considered equal! However, a shape will always
* equal itself, even if it is unknown or contains unknown dimensions.
* shape has an unknown number of dimensions (even if both return {@code true} for {@link
* Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself,
* even if it is unknown or contains unknown dimensions.
*/
@Override
public boolean equals(Object obj) {
Expand Down Expand Up @@ -233,17 +233,17 @@ public Shape head() {
}

/**
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions
* of this shape
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
* shape
*
* @param n the number of leading dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions
* of this Shape
* @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
* this Shape
*/
public Shape take(int n) {
if (n > numDimensions()) {
throw new ArrayIndexOutOfBoundsException("Cannot take " + n +
" dimensions, shape has only " + numDimensions() + ".");
throw new ArrayIndexOutOfBoundsException(
"Cannot take " + n + " dimensions, shape has only " + numDimensions() + ".");
}
long[] newDimensions = new long[n];
System.arraycopy(dimensionSizes, 0, newDimensions, 0, n);
Expand All @@ -257,18 +257,18 @@ public Shape tail() {
}

/**
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions
* of this Shape.
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
* Shape.
*
* @param n the number of trailing dimensions to get, must be &lt;= than
* {@link Shape#numDimensions()}
* @param n the number of trailing dimensions to get, must be <= than {@link
* Shape#numDimensions()}
* @return an n-dimensional shape with the dimensions matching the last n dimensions of this
* Shape, never null
* Shape, never null
*/
public Shape takeLast(int n) {
if (n > numDimensions()) {
throw new ArrayIndexOutOfBoundsException("Cannot take last " + n +
" dimensions, shape has only " + numDimensions() + ".");
throw new ArrayIndexOutOfBoundsException(
"Cannot take last " + n + " dimensions, shape has only " + numDimensions() + ".");
}
long[] newDimensions = new long[n];
System.arraycopy(dimensionSizes, numDimensions() - n, newDimensions, 0, n);
Expand All @@ -280,8 +280,8 @@ public Shape takeLast(int n) {
* {@link Shape#isUnknown()} must be {@code false}.
*
* @param firstDimension the dimension to prepend
* @return a new shape with the given dimension first, followed by this Shape's dimensions,
* never null
* @return a new shape with the given dimension first, followed by this Shape's dimensions, never
* null
*/
public Shape prepend(long firstDimension) {
long[] newDimensions = new long[dimensionSizes.length + 1];
Expand All @@ -292,8 +292,8 @@ public Shape prepend(long firstDimension) {
}

/**
* Returns a new Shape, with a new last dimension added. In order for this call to succeed,
* {@link Shape#isUnknown()} must be {@code false}.
* Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
* Shape#isUnknown()} must be {@code false}.
*
* @param lastDimension the dimension to append
* @return a new Shape with this Shape's dimensions followed by the given dimension, never null
Expand All @@ -307,38 +307,36 @@ public Shape append(long lastDimension) {
}

/**
* Returns a new Shape, with another Shape's dimensions prepended.
* For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
* E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
* Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
* other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
* Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
*
* @param other another Shape, must not be {@code null}, must not be unknown
* @return A new Shape consisting of the given Shapes's dimensions followed by this Shape's
* dimensions, never null
* @return A new Shape consisting of the given Shape's dimensions followed by this Shape's
* dimensions, never null
*/
public Shape prepend(Shape other) {
long[] newDimensions = new long[other.dimensionSizes.length + dimensionSizes.length];
System.arraycopy(other.dimensionSizes, 0,
newDimensions, 0, other.dimensionSizes.length);
System.arraycopy(dimensionSizes, 0,
newDimensions, other.dimensionSizes.length, dimensionSizes.length);
System.arraycopy(other.dimensionSizes, 0, newDimensions, 0, other.dimensionSizes.length);
System.arraycopy(
dimensionSizes, 0, newDimensions, other.dimensionSizes.length, dimensionSizes.length);
return Shape.of(newDimensions);
}

/**
* Returns a new Shape, with another Shapes' dimensions appended.
* For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
* e.g. {@code Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
* Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
* other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
* Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
*
* @param other another Shape, must not be {@code null}, must not be unknown
* @return A new Shape consisting of this Shapes's dimensions followed by the given Shape's
* dimensions
* @return A new Shape consisting of this Shape's dimensions followed by the given Shape's
* dimensions
*/
public Shape append(Shape other) {
long[] newDimensions = new long[dimensionSizes.length + other.dimensionSizes.length];
System.arraycopy(dimensionSizes, 0,
newDimensions, 0, dimensionSizes.length);
System.arraycopy(other.dimensionSizes, 0,
newDimensions, dimensionSizes.length, other.dimensionSizes.length);
System.arraycopy(dimensionSizes, 0, newDimensions, 0, dimensionSizes.length);
System.arraycopy(
other.dimensionSizes, 0, newDimensions, dimensionSizes.length, other.dimensionSizes.length);
return Shape.of(newDimensions);
}

Expand All @@ -355,4 +353,74 @@ private static long computeSize(long[] dimensionSizes) {
}
return computedSize;
}

/**
* Determines whether another shape is compatible with this one.
*
* <p>
*
* <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
* that both shapes can represent. Thus, compatibility allows the shape inference code to reason
* about partially-defined shapes. For example:
*
* <ul>
* <li><code>Shape.unknown()</code> is compatible with all shapes.
* <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
* shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
* not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
* <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
* size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
* <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
* </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
* <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
* Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
* compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
* </code>.
* </ul>
*
* <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
* <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
* Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
* </code> is not compatible with <code>Shape(4, 4)</code>.
*
* <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
* of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
* at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
*
* <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
* one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
* is "stretched" with dimensions of 1.
*
* @param shape The other shape
* @return true, if the two shapes are compatible.
*/
public boolean isCompatibleWith(Shape shape) {
if (!this.isUnknown() && !shape.isUnknown()) {
if (numDimensions() != shape.numDimensions()) {
return false;
}
for (int i = 0; i < numDimensions(); i++) {
if (!isCompatible(size(i), shape.size(i))) {
return false;
}
}
}
return true;
}

/**
* Test to see if two shape dimensions are compatible.
*
* <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
* dimensions are equal
*
* @param dim the first dimension
* @param otherDim the second dimension
* @return true, if both dimensions are compatible
*/
public static boolean isCompatible(long dim, long otherDim) {
return dim == Shape.UNKNOWN_SIZE || otherDim == Shape.UNKNOWN_SIZE || dim == otherDim;
}
}
36 changes: 34 additions & 2 deletions ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
*/
package org.tensorflow.ndarray;

import static org.junit.jupiter.api.Assertions.*;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;

public class ShapeTest {

@Test
Expand Down Expand Up @@ -135,4 +135,36 @@ public void testShapeModification() {
internalShape[0] = 42L;
assertEquals(2L, one.size(0));
}

@Test
public void testShapeCompatible() {
Shape a = Shape.unknown();
Shape b = Shape.of(2, 2);
assertTrue(a.isCompatibleWith(b));
assertTrue(b.isCompatibleWith(a));

a = Shape.of(2, 2);
assertTrue(a.isCompatibleWith(b));
assertTrue(b.isCompatibleWith(a));

a = Shape.of(2, -1);
assertTrue(a.isCompatibleWith(b));
assertTrue(b.isCompatibleWith(a));

a = Shape.of(-1, 2);
assertTrue(a.isCompatibleWith(b));
assertTrue(b.isCompatibleWith(a));

a = Shape.of(-1, -1);
assertTrue(a.isCompatibleWith(b));
assertTrue(b.isCompatibleWith(a));

a = Shape.of(1, 2);
assertFalse(a.isCompatibleWith(b));
assertFalse(b.isCompatibleWith(a));

a = Shape.of(1, 2, 3);
assertFalse(a.isCompatibleWith(b));
assertFalse(b.isCompatibleWith(a));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.initializers;

import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TType;

/** Abstract base class for all Initializers */
public abstract class BaseInitializer<T extends TType> implements Initializer<T> {

protected final Ops tf;

/**
* Creates an Initializer
*
* @param tf the TensorFlow Ops
*/
protected BaseInitializer(Ops tf) {
this.tf = tf;
}

/**
* Gets the TensorFlow Ops
*
* @return the TensorFlow Ops
*/
public Ops getTF() {
return tf;
}
}
Loading