Skip to content

Commit

Permalink
fix #526 (#527)
Browse files Browse the repository at this point in the history
* Interface should be public for external usage

* Fix #523

* Fix google format

* fix #526

* Add test to CategoricalCrossentropyTest.java
  • Loading branch information
nfeybesse authored Mar 8, 2024
1 parent 3f89f60 commit 2b6d83f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits {
* <p>Usage:
*
* <pre>
* Operand&lt;TFloat32&gt; logits =
* tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
* Operand&lt;TFloat32&gt; labels =
* tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
* Operand&lt;TFloat32&gt; output =
* tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
* // output Shape = [2]
* // dataType = FLOAT (1)
* // values { 0.169846, 0.824745 }
* Operand&lt;TFloat32&gt; logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
* Operand&lt;TFloat32&gt; labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
* Operand&lt;TFloat32&gt; output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
* // output Shape = [2]
* // dataType = FLOAT (1)
* // values { 0.169846, 0.824745 }
* </pre>
*
* <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To
Expand Down Expand Up @@ -157,7 +154,7 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
* @return the flattened logits
*/
private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
Operand<TInt64> one = Constant.scalarOf(scope, 1L);
Operand<TInt64> one = Constant.arrayOf(scope, 1L);

Shape shape = logits.shape();
int ndims = shape.numDimensions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
Expand All @@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() {
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
Ops tf = testSession.getTF();

long[] trueArray = {
1L, 0L, 0L,
0L, 1L, 0L,
0L, 0L, 1L
};
float[] predArray = {
1.F, 0.F, 0.F,
0.F, 1.F, 0.F,
0.F, 0.F, 1.F
};
long[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
float[] predArray = {1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
Operand<TInt64> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
CategoricalCrossentropy instance = new CategoricalCrossentropy();
Expand All @@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() {
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
10.F, 0.F, 0.F,
0.F, 10.F, 0.F,
0.F, 0.F, 10.F
};
float[] logitsArray = {10.F, 0.F, 0.F, 0.F, 10.F, 0.F, 0.F, 0.F, 10.F};
yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
Expand All @@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() {
Ops tf = testSession.getTF();
CategoricalCrossentropy instance = new CategoricalCrossentropy();

float[] trueArray = {
1L, 0L, 0L,
0L, 1L, 0L,
0L, 0L, 1L
};
float[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
Operand<TFloat32> yTrue =
tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Expand All @@ -111,23 +99,15 @@ public void testUnweighted() {
CategoricalCrossentropy instance = new CategoricalCrossentropy();

int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] predArray = {
.9F, .05F, .05F,
.5F, .89F, .6F,
.05F, .01F, .94F
};
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> loss = instance.call(tf, yTrue, yPred);
float expected = 0.32396814F;
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
Expand All @@ -145,16 +125,8 @@ public void testScalarWeighted() {
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
Ops tf = testSession.getTF();

int[] trueArray = {
1, 0, 0,
0, 1, 0,
0, 0, 1
};
float[] predArray = {
.9F, .05F, .05F,
.5F, .89F, .6F,
.05F, .01F, .94F
};
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> sampleWeight = tf.constant(2.3F);
Expand All @@ -166,11 +138,7 @@ public void testScalarWeighted() {
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
Expand All @@ -189,16 +157,8 @@ public void testSsampleWeighted() {
CategoricalCrossentropy instance = new CategoricalCrossentropy();

float[] sampeWeightArray = {1.2F, 3.4F, 5.6F};
int[] trueArray = {
1, 0, 0,
0, 1, 0,
0, 0, 1
};
float[] predArray = {
.9F, .05F, .05F,
.5F, .89F, .6F,
.05F, .01F, .94F
};
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> sampleWeight =
Expand All @@ -208,11 +168,7 @@ public void testSsampleWeighted() {
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
Expand All @@ -231,11 +187,7 @@ public void testNoReduction() {

// Test with logits.
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
Expand Down Expand Up @@ -266,4 +218,34 @@ public void testLabelSmoothing() {
testSession.evaluate(expected, loss);
}
}

@Test
public void testCategoricalCrossEntopyWithDynamicBatchSize() {
  // Regression test for #526: CategoricalCrossentropy with logits must not fail
  // when the batch dimension of the predictions placeholder is unknown (-1).
  // NOTE(review): method name has a typo ("Entopy"); kept as-is since renaming a
  // @Test method changes the reported test identity.
  try (Graph graph = new Graph()) {
    Ops tf = Ops.create(graph);
    // Predictions: shape (-1, 3) — batch size only known at session run time.
    Operand<TFloat32> yPred =
        tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
    // Labels: identity one-hot matrix, statically shaped (3, 3).
    Operand<TFloat32> yTrue =
        tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
    CategoricalCrossentropy instance = new CategoricalCrossentropy(true);
    // Graph construction is the point of the test: without the fix this call
    // raised TFInvalidArgumentException on the dynamic batch dimension.
    Operand<TFloat32> loss = instance.call(tf, yTrue, yPred);
    try (Session session = new Session(graph);
        TFloat32 result =
            (TFloat32)
                session
                    .runner()
                    .feed(
                        yPred,
                        TFloat32.tensorOf(
                            Shape.of(3, 3),
                            DataBuffers.of(
                                new float[] {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})))
                    .fetch(loss)
                    .run()
                    .get(0)) {
      // Expected mean loss for logits == one-hot labels; loose tolerance since
      // this test predates the TestSession evaluate helpers used elsewhere.
      if (Math.abs(0.5514477f - result.getFloat()) > 0.01f) {
        throw new IllegalStateException("Invalid result :" + result.getFloat());
      }
    }
  }
}
}

0 comments on commit 2b6d83f

Please sign in to comment.