diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java index a95110c9a96..10106723fba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits { *

Usage: * *

-   *   Operand<TFloat32> logits =
-   *       tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-   *   Operand<TFloat32> labels =
-   *       tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-   *   Operand<TFloat32> output =
-   *       tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-   *   // output Shape = [2]
-   *   // dataType = FLOAT (1)
-   *   // values { 0.169846, 0.824745 }
+   * Operand<TFloat32> logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
+   * Operand<TFloat32> labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
+   * Operand<TFloat32> output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   * // output Shape = [2]
+   * // dataType = FLOAT (1)
+   * // values { 0.169846, 0.824745 }
    * 
* *

Backpropagation will happen into both logits and labels. To @@ -157,7 +154,7 @@ public static Operand softmaxCrossEntr * @return the flattened logits */ private static Operand flattenOuterDims(Scope scope, Operand logits) { - Operand one = Constant.scalarOf(scope, 1L); + Operand one = Constant.arrayOf(scope, 1L); Shape shape = logits.shape(); int ndims = shape.numDimensions(); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 25f5e5a54f1..1be85927d4f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -17,10 +17,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; +import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; @@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - long[] trueArray = { - 1L, 0L, 0L, - 0L, 1L, 0L, - 0L, 0L, 1L - }; - float[] predArray = { - 1.F, 0.F, 0.F, - 0.F, 1.F, 0.F, - 0.F, 0.F, 1.F - }; + long[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L}; + float[] predArray = {1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); CategoricalCrossentropy instance = new CategoricalCrossentropy(); @@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() { testSession.evaluate(expected, loss); // Test with logits. - float[] logitsArray = { - 10.F, 0.F, 0.F, - 0.F, 10.F, 0.F, - 0.F, 0.F, 10.F - }; + float[] logitsArray = {10.F, 0.F, 0.F, 0.F, 10.F, 0.F, 0.F, 0.F, 10.F}; yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); @@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() { Ops tf = testSession.getTF(); CategoricalCrossentropy instance = new CategoricalCrossentropy(); - float[] trueArray = { - 1L, 0L, 0L, - 0L, 1L, 0L, - 0L, 0L, 1L - }; + float[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L}; float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); @@ -111,11 +99,7 @@ public void testUnweighted() { CategoricalCrossentropy instance = new CategoricalCrossentropy(); int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; - float[] predArray = { - .9F, .05F, .05F, - .5F, .89F, .6F, - .05F, .01F, .94F - }; + float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand loss = instance.call(tf, yTrue, yPred); @@ -123,11 +107,7 @@ public void testUnweighted() { testSession.evaluate(expected, loss); // Test with logits. - float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(true); @@ -145,16 +125,8 @@ public void testScalarWeighted() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - int[] trueArray = { - 1, 0, 0, - 0, 1, 0, - 0, 0, 1 - }; - float[] predArray = { - .9F, .05F, .05F, - .5F, .89F, .6F, - .05F, .01F, .94F - }; + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3F); @@ -166,11 +138,7 @@ public void testScalarWeighted() { testSession.evaluate(expected, loss); // Test with logits. - float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(true); @@ -189,16 +157,8 @@ public void testSsampleWeighted() { CategoricalCrossentropy instance = new CategoricalCrossentropy(); float[] sampeWeightArray = {1.2F, 3.4F, 5.6F}; - int[] trueArray = { - 1, 0, 0, - 0, 1, 0, - 0, 0, 1 - }; - float[] predArray = { - .9F, .05F, .05F, - .5F, .89F, .6F, - .05F, .01F, .94F - }; + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = @@ -208,11 +168,7 @@ public void testSsampleWeighted() { testSession.evaluate(expected, loss); // Test with logits. - float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(true); @@ -231,11 +187,7 @@ public void testNoReduction() { // Test with logits. int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; - float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); @@ -266,4 +218,34 @@ public void testLabelSmoothing() { testSession.evaluate(expected, loss); } } + + @Test + public void testCategoricalCrossEntopyWithDynamicBatchSize() { + try (Graph graph = new Graph()) { + Ops tf = Ops.create(graph); + Operand yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3))); + Operand yTrue = + tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3)); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true); + Operand loss = + instance.call(tf, yTrue, yPred); // Throw TFInvalidArgument Exception without fix + try (Session session = new Session(graph); + TFloat32 result = + (TFloat32) + session + .runner() + .feed( + yPred, + TFloat32.tensorOf( + Shape.of(3, 3), + DataBuffers.of( + new float[] {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}))) + .fetch(loss) + .run() + .get(0)) { + if (Math.abs(0.5514477f - result.getFloat()) > 0.01) + throw new IllegalStateException("Invalid result :" + result.getFloat()); + } + } + } }