Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #526 #527

Merged
merged 6 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits {
* <p>Usage:
*
* <pre>
* Operand&lt;TFloat32&gt; logits =
* tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
* Operand&lt;TFloat32&gt; labels =
* tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
* Operand&lt;TFloat32&gt; output =
* tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
* // output Shape = [2]
* // dataType = FLOAT (1)
* // values { 0.169846, 0.824745 }
* Operand&lt;TFloat32&gt; logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
* Operand&lt;TFloat32&gt; labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
* Operand&lt;TFloat32&gt; output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
* // output Shape = [2]
* // dataType = FLOAT (1)
* // values { 0.169846, 0.824745 }
* </pre>
*
* <p>Backpropagation will happen into both <code>logits</code> and <code>labels</code>. To
Expand Down Expand Up @@ -157,7 +154,7 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
* @return the flattened logits
*/
private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
Operand<TInt64> one = Constant.scalarOf(scope, 1L);
Operand<TInt64> one = Constant.arrayOf(scope, 1L);

Shape shape = logits.shape();
int ndims = shape.numDimensions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
Expand All @@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() {
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
Ops tf = testSession.getTF();

long[] trueArray = {
1L, 0L, 0L,
0L, 1L, 0L,
0L, 0L, 1L
};
float[] predArray = {
1.F, 0.F, 0.F,
0.F, 1.F, 0.F,
0.F, 0.F, 1.F
};
long[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
float[] predArray = {1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
Operand<TInt64> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
CategoricalCrossentropy instance = new CategoricalCrossentropy();
Expand All @@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() {
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
10.F, 0.F, 0.F,
0.F, 10.F, 0.F,
0.F, 0.F, 10.F
};
float[] logitsArray = {10.F, 0.F, 0.F, 0.F, 10.F, 0.F, 0.F, 0.F, 10.F};
yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
Expand All @@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() {
Ops tf = testSession.getTF();
CategoricalCrossentropy instance = new CategoricalCrossentropy();

float[] trueArray = {
1L, 0L, 0L,
0L, 1L, 0L,
0L, 0L, 1L
};
float[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L};
float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F};
Operand<TFloat32> yTrue =
tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Expand All @@ -111,23 +99,15 @@ public void testUnweighted() {
CategoricalCrossentropy instance = new CategoricalCrossentropy();

int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] predArray = {
.9F, .05F, .05F,
.5F, .89F, .6F,
.05F, .01F, .94F
};
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> loss = instance.call(tf, yTrue, yPred);
float expected = 0.32396814F;
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
Expand All @@ -145,16 +125,8 @@ public void testScalarWeighted() {
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
Ops tf = testSession.getTF();

int[] trueArray = {
1, 0, 0,
0, 1, 0,
0, 0, 1
};
float[] predArray = {
.9F, .05F, .05F,
.5F, .89F, .6F,
.05F, .01F, .94F
};
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> sampleWeight = tf.constant(2.3F);
Expand All @@ -166,11 +138,7 @@ public void testScalarWeighted() {
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
Expand All @@ -189,16 +157,8 @@ public void testSsampleWeighted() {
CategoricalCrossentropy instance = new CategoricalCrossentropy();

float[] sampeWeightArray = {1.2F, 3.4F, 5.6F};
int[] trueArray = {
1, 0, 0,
0, 1, 0,
0, 0, 1
};
float[] predArray = {
.9F, .05F, .05F,
.5F, .89F, .6F,
.05F, .01F, .94F
};
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> sampleWeight =
Expand All @@ -208,11 +168,7 @@ public void testSsampleWeighted() {
testSession.evaluate(expected, loss);

// Test with logits.
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
instance = new CategoricalCrossentropy(true);
Expand All @@ -231,11 +187,7 @@ public void testNoReduction() {

// Test with logits.
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1};
float[] logitsArray = {
8.F, 1.F, 1.F,
0.F, 9.F, 1.F,
2.F, 3.F, 5.F
};
float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F};
Operand<TInt32> yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3)));
Operand<TFloat32> logits =
tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3)));
Expand Down Expand Up @@ -266,4 +218,34 @@ public void testLabelSmoothing() {
testSession.evaluate(expected, loss);
}
}

@Test
public void testCategoricalCrossEntopyWithDynamicBatchSize() {
  try (Graph graph = new Graph()) {
    Ops tf = Ops.create(graph);
    // Predictions placeholder with a dynamic (unknown) batch dimension: shape (-1, 3).
    // Parameterized as Operand<TFloat32> (rather than a raw Operand) to avoid unchecked
    // warnings and keep the feed/fetch calls type-safe.
    Operand<TFloat32> yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
    // One-hot labels for a 3x3 identity batch.
    Operand<TFloat32> yTrue =
        tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
    CategoricalCrossentropy instance = new CategoricalCrossentropy(true);
    // Regression check: before the fix this call threw a TFInvalidArgument exception
    // whenever the batch size was dynamic.
    Operand<TFloat32> loss =
        instance.call(tf, yTrue, yPred); // Throw TFInvalidArgument Exception without fix
    try (Session session = new Session(graph);
        TFloat32 result =
            (TFloat32)
                session
                    .runner()
                    .feed(
                        yPred,
                        TFloat32.tensorOf(
                            Shape.of(3, 3),
                            DataBuffers.of(
                                new float[] {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})))
                    .fetch(loss)
                    .run()
                    .get(0)) {
      // Expected mean categorical cross-entropy for these logits/labels; tolerance 0.01.
      if (Math.abs(0.5514477f - result.getFloat()) > 0.01)
        throw new IllegalStateException("Invalid result :" + result.getFloat());
    }
  }
}
}
Loading