Commit 4740bd4

Rebase from master
1 parent 7897b25 commit 4740bd4

16 files changed, with 220 additions and 116 deletions

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java

Lines changed: 146 additions & 32 deletions

@@ -18,7 +18,6 @@
 package org.tensorflow.op;
 
 import java.util.List;
-import org.tensorflow.DataType;
 import org.tensorflow.Operand;
 import org.tensorflow.op.nn.AvgPool;
 import org.tensorflow.op.nn.AvgPool3d;
@@ -84,6 +83,7 @@
 import org.tensorflow.op.nn.Relu;
 import org.tensorflow.op.nn.Relu6;
 import org.tensorflow.op.nn.Selu;
+import org.tensorflow.op.nn.SigmoidCrossEntropyWithLogits;
 import org.tensorflow.op.nn.Softmax;
 import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
 import org.tensorflow.op.nn.Softsign;
@@ -103,10 +103,13 @@
  * @see {@link Ops}
  */
 public final class NnOps {
+  public final NnRawOps raw;
+
   private final Scope scope;
 
   NnOps(Scope scope) {
     this.scope = scope;
+    raw = new NnRawOps(scope);
   }
 
   /**
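
The new public raw field gives the generated low-level kernels a home alongside the shape-aware wrappers this commit adds further down. A minimal sketch of that split, assuming an Ops handle created with Ops.create() and that NnRawOps mirrors the generated op names (neither is shown in this hunk):

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    class NnOpsGroups {
      static void demo() {
        Ops tf = Ops.create();  // eager execution
        Operand<TFloat32> features = tf.constant(new float[] {-1.0F, 2.0F});
        Operand<TFloat32> activated = tf.nn.relu(features);  // scoped op, unchanged by this commit
        // Generated kernels stay reachable under the new group, e.g. (assumed name):
        // tf.nn.raw.softmaxCrossEntropyWithLogits(logits, labels);
      }
    }
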
@@ -1342,7 +1345,7 @@ public <T extends TNumber> MaxPool3d<T> maxPool3d(Operand<T> input, List<Long> k
   }
 
   /**
-   * Computes gradients of max pooling function.
+   * Computes gradients of 3D max pooling function.
    *
    * @param <U> data type for {@code output()} output
    * @param origInput The original input tensor.
@@ -1767,6 +1770,56 @@ public <T extends TNumber> Selu<T> selu(Operand<T> features) {
     return Selu.create(scope, features);
   }
 
+  /**
+   * Computes sigmoid cross entropy given <code>logits</code>.
+   *
+   * <p>Measures the probability error in discrete classification tasks in which each class is
+   * independent and not mutually exclusive. For instance, one could perform multilabel
+   * classification where a picture can contain both an elephant and a dog at the same time.
+   *
+   * <p>For brevity, let <code>x = logits</code>, <code>z = labels</code>. The logistic loss in
+   * pseudo-code is
+   *
+   * <pre>
+   * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+   * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+   * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+   * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+   * = (1 - z) * x + log(1 + exp(-x))
+   * = x - x * z + log(1 + exp(-x))
+   * </pre>
+   *
+   * <p>For <code>x < 0</code>, to avoid overflow in <code>exp(-x)</code>, we reformulate the above
+   *
+   * <pre>
+   * x - x * z + log(1 + exp(-x))
+   * = log(exp(x)) - x * z + log(1 + exp(-x))
+   * = - x * z + log(1 + exp(x))
+   * </pre>
+   *
+   * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent
+   * formulation
+   *
+   * <pre>
+   * max(x, 0) - x * z + log(1 + exp(-abs(x)))
+   * </pre>
+   *
+   * <p><code>logits</code> and <code>labels</code> must have the same type and shape.
+   *
+   * @param labels the labels
+   * @param logits the logits of type float32 or float64
+   * @param <T> the type of labels and logits
+   * @return the component-wise logistic losses.
+   * @throws IllegalArgumentException if logits and labels do not have the same shape
+   */
+  public <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits(Operand<T> labels,
+      Operand<T> logits) {
+    return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits);
+  }
+
   /**
    * Computes softmax activations.
    * <p>
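
A minimal sketch of calling the new endpoint, assuming an Ops handle built with Ops.create(); the values are illustrative only:

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    class SigmoidCeExample {
      static Operand<TFloat32> perElementLoss(Ops tf) {
        // Multi-label targets: each class is independent, so several can be 1 at once.
        Operand<TFloat32> labels = tf.constant(new float[][] {{1.0F, 0.0F, 1.0F}});
        Operand<TFloat32> logits = tf.constant(new float[][] {{2.3F, -1.2F, 0.4F}});
        // Component-wise logistic loss, same shape as the inputs.
        return tf.nn.sigmoidCrossEntropyWithLogits(labels, logits);
      }
    }
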
@@ -1783,20 +1836,54 @@ public <T extends TNumber> Softmax<T> softmax(Operand<T> logits) {
   }
 
   /**
-   * Computes softmax cross entropy cost and gradients to backpropagate.
-   * <p>
-   * Inputs are the logits, not probabilities.
+   * Computes softmax cross entropy between <code>logits</code> and <code>labels</code>.
    *
-   * @param <T> data type for {@code loss()} output
-   * @param features batch_size x num_classes matrix
-   * @param labels batch_size x num_classes matrix
-   * The caller must ensure that each batch of labels represents a valid
-   * probability distribution.
-   * @return a new instance of SoftmaxCrossEntropyWithLogits
+   * <p>Measures the probability error in discrete classification tasks in which the classes are
+   * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is
+   * labeled with one and only one label: an image can be a dog or a truck, but not both.
+   *
+   * <p><b>NOTE:</b>
+   *
+   * <p>While the classes are mutually exclusive, their probabilities need not be. All that is
+   * required is that each row of <code>labels</code> is a valid probability distribution. If they
+   * are not, the computation of the gradient will be incorrect.
+   *
+   * <p>If using exclusive <code>labels</code> (wherein one and only one class is true at a time),
+   * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits}
+   *
+   * <p>Usage:
+   *
+   * <pre>
+   * Operand&lt;TFloat32&gt; logits =
+   *     tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   * Operand&lt;TFloat32&gt; labels =
+   *     tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   * Operand&lt;TFloat32&gt; output =
+   *     tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   * // output Shape = [2]
+   * // dataType = FLOAT (1)
+   * // values { 0.169846, 0.824745 }
+   * </pre>
+   *
+   * <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To
+   * disallow backpropagation into <code>labels</code>, pass label tensors through <code>
+   * tf.stopGradient</code> before feeding them to this function.
+   *
+   * @param labels Each vector along the class dimension should hold a valid probability
+   *     distribution e.g. for the case in which labels are of shape <code>[batch_size, num_classes]
+   *     </code>, each row of <code>labels[i]</code> must be a valid probability distribution.
+   * @param logits Per-label activations, typically a linear output. These activation energies are
+   *     interpreted as unnormalized log probabilities.
+   * @param axis The class dimension. -1 is the last dimension.
+   * @param <T> the number type of the operands
+   * @return the softmax cross entropy loss. Its type is the same as <code>logits</code> and its
+   *     shape is the same as <code>labels</code> except that it does not have the last dimension of
+   *     <code>labels</code>.
    */
-  public <T extends TNumber> SoftmaxCrossEntropyWithLogits<T> softmaxCrossEntropyWithLogits(
-      Operand<T> features, Operand<T> labels) {
-    return SoftmaxCrossEntropyWithLogits.create(scope, features, labels);
+  public <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(
+      Operand<U> labels, Operand<T> logits, int axis) {
+    return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis);
   }
 
   /**
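
Where the labels are themselves produced by the graph, the note above about gradient flow matters; a small sketch of the hedge it recommends, assuming an Ops handle named tf:

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    class FrozenLabelLoss {
      // Sketch: keep gradients out of graph-produced labels before taking the loss.
      static Operand<TFloat32> lossWithFrozenLabels(
          Ops tf, Operand<TFloat32> labels, Operand<TFloat32> logits) {
        Operand<TFloat32> frozenLabels = tf.stopGradient(labels);
        return tf.nn.softmaxCrossEntropyWithLogits(frozenLabels, logits, -1);
      }
    }
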
@@ -1988,24 +2075,51 @@ public <T extends TType> SpaceToDepth<T> spaceToDepth(Operand<T> input, Long blo
   }
 
   /**
-   * Computes softmax cross entropy cost and gradients to backpropagate.
-   * <p>
-   * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
-   * a matrix of label probabilities, but rather a single label per row
-   * of features. This label is considered to have probability 1.0 for the
-   * given row.
-   * <p>
-   * Inputs are the logits, not probabilities.
-   *
-   * @param <T> data type for {@code loss()} output
-   * @param features batch_size x num_classes matrix
-   * @param labels batch_size vector with values in [0, num_classes).
-   * This is the label for the given minibatch entry.
-   * @return a new instance of SparseSoftmaxCrossEntropyWithLogits
-   */
-  public <T extends TNumber, U extends TNumber> SparseSoftmaxCrossEntropyWithLogits<T> sparseSoftmaxCrossEntropyWithLogits(
-      Operand<T> features, Operand<U> labels) {
-    return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels);
+   * Computes sparse softmax cross entropy between <code>logits</code> and <code>labels</code>.
+   *
+   * <p>Measures the probability error in discrete classification tasks in which the classes are
+   * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is
+   * labeled with one and only one label: an image can be a dog or a truck, but not both.
+   *
+   * <p><b>NOTE:</b>
+   *
+   * <p>For this operation, the probability of a given label is considered exclusive. That is, soft
+   * classes are not allowed, and the <code>labels</code> vector must provide a single specific
+   * index for the true class for each row of <code>logits</code> (each minibatch entry). For soft
+   * softmax classification with a probability distribution for each entry, see {@link
+   * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}.
+   *
+   * <p><b>WARNING:</b>
+   *
+   * <p>This op expects unscaled logits, since it performs a <code>softmax</code> on <code>logits
+   * </code> internally for efficiency. Do not call this op with the output of <code>softmax</code>,
+   * as it will produce incorrect results.
+   *
+   * <p>A common use case is to have logits of shape <code>[batchSize, numClasses]</code> and have
+   * labels of shape <code>[batchSize]</code>, but higher dimensions are supported, in which case
+   * the <code>dim</code>-th dimension is assumed to be of size <code>numClasses</code>. <code>
+   * logits</code> must have the <code>dataType</code> of <code>TFloat16</code>, <code>TFloat32</code>
+   * , or <code>TFloat64</code>, and <code>labels</code> must have the dtype of <code>TInt32</code>
+   * or <code>TInt64</code>.
+   *
+   * @param labels <code>Tensor</code> of shape <code>[d_0, d_1, ..., d_{r-1}]</code> (where <code>r
+   *     </code> is rank of <code>labels</code> and result) and the dataType is <code>TInt32</code>
+   *     or <code>TInt64</code>. Each entry in <code>labels</code> must be an index in <code>[0,
+   *     numClasses)</code>. Other values will raise an exception when this op is run on CPU, and
+   *     return <code>NaN</code> for corresponding loss and gradient rows on GPU.
+   * @param logits Per-label activations (typically a linear output) of shape <code>[d_0, d_1, ...,
+   *     d_{r-1}, numClasses]</code> and dataType of <code>TFloat16</code>, <code>TFloat32</code>,
+   *     or <code>TFloat64</code>. These activation energies are interpreted as unnormalized log
+   *     probabilities.
+   * @return A <code>Tensor</code> of the same shape as <code>labels</code> and of the same type as
+   *     <code>logits</code> with the softmax cross entropy loss.
+   * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank
+   *     of the labels is not equal to the rank of the logits minus one.
+   */
+  public <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossEntropyWithLogits(
+      Operand<T> labels, Operand<U> logits) {
+    return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits);
   }
 
   /**
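
A minimal sketch of the sparse form with integer class indices, assuming an Ops handle named tf; shapes follow the Javadoc above:

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;
    import org.tensorflow.types.TInt32;

    class SparseCeExample {
      static Operand<?> perExampleLoss(Ops tf) {
        // Two examples, three classes: labels hold the index of the true class.
        Operand<TInt32> labels = tf.constant(new int[] {0, 2});
        Operand<TFloat32> logits =
            tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}});
        // Loss has shape [2], one value per example; note the raw Operand return type in this commit.
        return tf.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits);
      }
    }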

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java

Lines changed: 3 additions & 6 deletions

@@ -39,13 +39,10 @@ public void createScalar() {
 
   @Test
   public void createrScalarLongerThan127() {
-    Tensor<TString> tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !");
+    TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !");
     assertNotNull(tensor);
-
-    TString data = tensor.data();
-    assertNotNull(data);
-    assertEquals(Shape.scalar(), data.shape());
-    assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", data.getObject());
+    assertEquals(Shape.scalar(), tensor.shape());
+    assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject());
   }
 

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java

Lines changed: 1 addition & 2 deletions

@@ -14,7 +14,6 @@
  =======================================================================*/
 package org.tensorflow.framework.activations;
 
-import org.tensorflow.DataType;
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.TBool;
@@ -89,7 +88,7 @@ public Operand<T> call(Operand<T> input) {
     Operand<T> result = tf.nn.elu(input);
     if (alpha == 1.0) return result;
     else {
-      DataType<T> dataType = input.asOutput().dataType();
+      Class<T> dataType = input.asOutput().type();
       Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType));
       Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType));
       return tf.select(cond, result, y);
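
The same substitution recurs across the framework classes below: the tensor's Java type token now comes from type() and is passed straight to cast. A minimal sketch of the pattern in isolation (hypothetical constantLike helper, not part of the commit; assumes an Ops handle tf):

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.family.TNumber;

    class CastToInputType {
      // Build a constant with the same dtype as `input`, using the new Class<T>-based API.
      static <T extends TNumber> Operand<T> constantLike(Ops tf, Operand<T> input, float value) {
        Class<T> dataType = input.asOutput().type();          // was: DataType<T> ... dataType()
        return tf.dtypes.cast(tf.constant(value), dataType);  // cast accepts the Class<T> token
      }
    }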

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java

Lines changed: 1 addition & 2 deletions

@@ -14,7 +14,6 @@
  =======================================================================*/
 package org.tensorflow.framework.activations;
 
-import org.tensorflow.DataType;
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.family.TFloating;
@@ -63,7 +62,7 @@ public HardSigmoid(Ops tf) {
    */
   @Override
   public Operand<T> call(Operand<T> input) {
-    DataType<T> dataType = input.asOutput().dataType();
+    Class<T> dataType = input.asOutput().type();
     Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), dataType);
     Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), dataType);
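
The hunk stops after the two constants; the textbook hard-sigmoid they feed, max(0, min(1, 0.2 * x + 0.5)), can be sketched with the same ops. This is the standard formula, not necessarily the exact remainder of this class:

    import org.tensorflow.Operand;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.family.TFloating;

    class HardSigmoidSketch {
      static <T extends TFloating> Operand<T> hardSigmoid(Ops tf, Operand<T> x) {
        Class<T> dataType = x.asOutput().type();
        Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), dataType);
        Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), dataType);
        Operand<T> zero = tf.dtypes.cast(tf.constant(0.0), dataType);
        Operand<T> one = tf.dtypes.cast(tf.constant(1.0), dataType);
        // clip(0.2 * x + 0.5, 0, 1)
        Operand<T> y = tf.math.add(tf.math.mul(x, point2), point5);
        return tf.math.minimum(tf.math.maximum(y, zero), one);
      }
    }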

tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java

Lines changed: 1 addition & 2 deletions

@@ -14,7 +14,6 @@
  =======================================================================*/
 package org.tensorflow.framework.activations;
 
-import org.tensorflow.DataType;
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.math.Greater;
@@ -99,7 +98,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) {
   @Override
   public Operand<T> call(Operand<T> input) {
 
-    DataType<T> dataType = input.asOutput().dataType();
+    Class<T> dataType = input.asOutput().type();
 
     boolean clipMax = !Float.isNaN(maxValue);
     Operand<T> negativePart = null;
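
For context, a sketch of driving this activation through the constructor shown in the hunk header (alpha, maxValue, threshold), assuming the class is parameterized by the tensor type as its call signature suggests; the values are illustrative:

    import org.tensorflow.Operand;
    import org.tensorflow.framework.activations.ReLU;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    class ReluActivationExample {
      static Operand<TFloat32> bounded(Ops tf, Operand<TFloat32> input) {
        // Leaky slope 0.2 below threshold 0, clipped at 6 (a "ReLU6"-style activation).
        ReLU<TFloat32> relu = new ReLU<>(tf, 0.2f, 6.0f, 0.0f);
        return relu.call(input);
      }
    }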

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java

Lines changed: 2 additions & 2 deletions

@@ -217,8 +217,8 @@ public <T extends TNumber, U extends TNumber> Operand<T> call(
           getTF(),
           "predictions range check [0-1]",
           predictions,
-          cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()),
-          cast(getTF(), getTF().constant(1), predictions.asOutput().dataType()));
+          cast(getTF(), getTF().constant(0), predictions.asOutput().type()),
+          cast(getTF(), getTF().constant(1), predictions.asOutput().type()));
 
     } else {
       lPredictions = predictions;
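
A sketch of exercising this path from user code. The constructor and the null sample-weight convention are not shown in this diff, so treat them as assumptions; the range check above is why the probabilities must already sit in [0, 1]:

    import org.tensorflow.Operand;
    import org.tensorflow.framework.losses.BinaryCrossentropy;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    class BinaryCrossentropyExample {
      static Operand<TFloat32> loss(Ops tf) {
        Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 1f}, {1f, 0f}});
        Operand<TFloat32> predictions = tf.constant(new float[][] {{0.1f, 0.9f}, {0.6f, 0.4f}});
        BinaryCrossentropy bce = new BinaryCrossentropy(tf);  // assumed convenience constructor
        return bce.call(labels, predictions, null);           // null: no sample weights (assumption)
      }
    }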

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java

Lines changed: 2 additions & 2 deletions

@@ -256,8 +256,8 @@ public <T extends TNumber, U extends TNumber> Operand<T> call(
           getTF(),
           "predictions range check [0-1]",
           predictions,
-          cast(getTF(), getTF().constant(0), predictions.asOutput().dataType()),
-          cast(getTF(), getTF().constant(1), predictions.asOutput().dataType()));
+          cast(getTF(), getTF().constant(0), predictions.asOutput().type()),
+          cast(getTF(), getTF().constant(1), predictions.asOutput().type()));
 
     } else {
       lPredictions = predictions;

tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java

Lines changed: 3 additions & 3 deletions

@@ -124,15 +124,15 @@ public Hinge(Ops tf, String name, Reduction reduction) {
   public <T extends TNumber, U extends TNumber> Operand<T> call(
       Operand<U> labels, Operand<T> predictions, Operand<T> sampleWeights) {
     @SuppressWarnings("unchecked")
-    Operand<T> tLabels = predictions.asOutput().dataType() == labels.asOutput().dataType() ?
+    Operand<T> tLabels = predictions.asOutput().type() == labels.asOutput().type() ?
         (Operand<T>)labels :
-        cast(tf, labels, predictions.asOutput().dataType());
+        cast(tf, labels, predictions.asOutput().type());
     tLabels = LossesHelper.valueCheck(
         getTF(),
         "labels value check [-1, 0, 1]",
         tLabels,
         cast(getTF(), getTF().constant(new int[] { -1, 0, 1}),
-            predictions.asOutput().dataType()));
+            predictions.asOutput().type()));
 
     Operand<T> losses = Losses.hinge(getTF(), tLabels, predictions);
     return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights);
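
A sketch of driving this loss with {-1, 0, 1} labels, using the constructor from the hunk header. The Reduction value, its import location, and the null sample-weight convention are assumptions, not shown in this diff:

    import org.tensorflow.Operand;
    import org.tensorflow.framework.losses.Hinge;
    import org.tensorflow.framework.losses.Reduction;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    class HingeLossExample {
      static Operand<TFloat32> loss(Ops tf) {
        // Labels restricted to {-1, 0, 1}, which the value check above enforces.
        Operand<TFloat32> labels = tf.constant(new float[][] {{-1f, 1f}, {1f, -1f}});
        Operand<TFloat32> predictions = tf.constant(new float[][] {{-0.7f, 0.8f}, {0.3f, -0.2f}});
        Hinge hinge = new Hinge(tf, "hinge_loss", Reduction.AUTO);
        return hinge.call(labels, predictions, null);  // null: no sample weights (assumption)
      }
    }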
