diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
index a95110c9a96..10106723fba 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java
@@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits {
*
Usage:
*
*
- * Operand<TFloat32> logits =
- * tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
- * Operand<TFloat32> labels =
- * tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
- * Operand<TFloat32> output =
- * tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
- * // output Shape = [2]
- * // dataType = FLOAT (1)
- * // values { 0.169846, 0.824745 }
+ * Operand<TFloat32> logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
+ * Operand<TFloat32> labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
+ * Operand<TFloat32> output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+ * // output Shape = [2]
+ * // dataType = FLOAT (1)
+ * // values { 0.169846, 0.824745 }
*
*
* Backpropagation will happen into both logits
and labels
. To
@@ -157,7 +154,7 @@ public static Operand softmaxCrossEntr
* @return the flattened logits
*/
private static Operand flattenOuterDims(Scope scope, Operand logits) {
- Operand one = Constant.scalarOf(scope, 1L);
+ Operand one = Constant.arrayOf(scope, 1L);
Shape shape = logits.shape();
int ndims = shape.numDimensions();
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java
index 25f5e5a54f1..1be85927d4f 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java
@@ -17,10 +17,14 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import org.junit.jupiter.api.Test;
+import org.tensorflow.Graph;
import org.tensorflow.Operand;
+import org.tensorflow.Session;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
@@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() {
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
Ops tf = testSession.getTF();
- long[] trueArray = {
- 1L, 0L, 0L,
- 0L, 1L, 0L,
- 0L, 0L, 1L
- };
- float[] predArray = {
- 1.F, 0.F, 0.F,
- 0.F, 1.F, 0.F,
- 0.F, 0.F, 1.F
- };
+ long[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
+ float[] predArray = {1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
CategoricalCrossentropy instance = new CategoricalCrossentropy();
@@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() {
testSession.evaluate(expected, loss);
// Test with logits.
- float[] logitsArray = {
- 10.F, 0.F, 0.F,
- 0.F, 10.F, 0.F,
- 0.F, 0.F, 10.F
- };
+ float[] logitsArray = {10.F, 0.F, 0.F, 0.F, 10.F, 0.F, 0.F, 0.F, 10.F};
yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
@@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() {
Ops tf = testSession.getTF();
CategoricalCrossentropy instance = new CategoricalCrossentropy();
- float[] trueArray = {
- 1L, 0L, 0L,
- 0L, 1L, 0L,
- 0L, 0L, 1L
- };
+ float[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
Operand yTrue =
tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
@@ -111,11 +99,7 @@ public void testUnweighted() {
CategoricalCrossentropy instance = new CategoricalCrossentropy();
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
- float[] predArray = {
- .9F, .05F, .05F,
- .5F, .89F, .6F,
- .05F, .01F, .94F
- };
+ float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand loss = instance.call(tf, yTrue, yPred);
@@ -123,11 +107,7 @@ public void testUnweighted() {
testSession.evaluate(expected, loss);
// Test with logits.
- float[] logitsArray = {
- 8.F, 1.F, 1.F,
- 0.F, 9.F, 1.F,
- 2.F, 3.F, 5.F
- };
+ float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
@@ -145,16 +125,8 @@ public void testScalarWeighted() {
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
Ops tf = testSession.getTF();
- int[] trueArray = {
- 1, 0, 0,
- 0, 1, 0,
- 0, 0, 1
- };
- float[] predArray = {
- .9F, .05F, .05F,
- .5F, .89F, .6F,
- .05F, .01F, .94F
- };
+ int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
+ float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand sampleWeight = tf.constant(2.3F);
@@ -166,11 +138,7 @@ public void testScalarWeighted() {
testSession.evaluate(expected, loss);
// Test with logits.
- float[] logitsArray = {
- 8.F, 1.F, 1.F,
- 0.F, 9.F, 1.F,
- 2.F, 3.F, 5.F
- };
+ float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
@@ -189,16 +157,8 @@ public void testSsampleWeighted() {
CategoricalCrossentropy instance = new CategoricalCrossentropy();
float[] sampeWeightArray = {1.2F, 3.4F, 5.6F};
- int[] trueArray = {
- 1, 0, 0,
- 0, 1, 0,
- 0, 0, 1
- };
- float[] predArray = {
- .9F, .05F, .05F,
- .5F, .89F, .6F,
- .05F, .01F, .94F
- };
+ int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
+ float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand sampleWeight =
@@ -208,11 +168,7 @@ public void testSsampleWeighted() {
testSession.evaluate(expected, loss);
// Test with logits.
- float[] logitsArray = {
- 8.F, 1.F, 1.F,
- 0.F, 9.F, 1.F,
- 2.F, 3.F, 5.F
- };
+ float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
@@ -231,11 +187,7 @@ public void testNoReduction() {
// Test with logits.
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
- float[] logitsArray = {
- 8.F, 1.F, 1.F,
- 0.F, 9.F, 1.F,
- 2.F, 3.F, 5.F
- };
+ float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
@@ -266,4 +218,34 @@ public void testLabelSmoothing() {
testSession.evaluate(expected, loss);
}
}
+
+ @Test
+ public void testCategoricalCrossEntropyWithDynamicBatchSize() {
+ try (Graph graph = new Graph()) {
+ Ops tf = Ops.create(graph);
+ Operand yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
+ Operand yTrue =
+ tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
+ CategoricalCrossentropy instance = new CategoricalCrossentropy(true);
+ Operand loss =
+ instance.call(tf, yTrue, yPred); // threw TFInvalidArgumentException before the scalarOf -> arrayOf fix in flattenOuterDims
+ try (Session session = new Session(graph);
+ TFloat32 result =
+ (TFloat32)
+ session
+ .runner()
+ .feed(
+ yPred,
+ TFloat32.tensorOf(
+ Shape.of(3, 3),
+ DataBuffers.of(
+ new float[] {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})))
+ .fetch(loss)
+ .run()
+ .get(0)) {
+ if (Math.abs(0.5514477f - result.getFloat()) > 0.01)
+ throw new IllegalStateException("Invalid result :" + result.getFloat());
+ }
+ }
+ }
}