fix #526 #527

Merged
merged 6 commits into from
Mar 8, 2024


nfeybesse
Contributor

No description provided.

@Craigacp
Collaborator

Craigacp commented Mar 4, 2024

Can you add a test that triggers this bug to the cross-entropy tests (https://github.com/tensorflow/java/blob/master/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java)? I think this used to work, so I worry the regression came from a TF upgrade and our tests didn't catch it.

@nfeybesse
Contributor Author

No, the problem is older; it is probably the dynamic batch size that triggers it. I will try to write a test case.

@nfeybesse
Contributor Author

@Test
public void testCategoricalCrossEntropyWithDynamicBatchSize() {
  try (Graph graph = new Graph()) {
    Ops tf = Ops.create(graph);
    // The batch dimension is left unknown (-1), so the loss must cope with a dynamic batch size.
    Operand<TFloat32> yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
    Operand<TFloat32> yTrue =
        tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
    CategoricalCrossentropy instance = new CategoricalCrossentropy(true);
    // Without the fix, this call throws TFInvalidArgumentException.
    Operand<TFloat32> loss = instance.call(tf, yTrue, yPred);
    // The fed tensor is declared as a resource so that it is closed along with the result.
    try (Session session = new Session(graph);
        TFloat32 input =
            TFloat32.tensorOf(
                Shape.of(3, 3), DataBuffers.of(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}));
        TFloat32 result = (TFloat32) session.runner().feed(yPred, input).fetch(loss).run().get(0)) {
      if (Math.abs(0.5514477f - result.getFloat()) > 0.01) {
        throw new IllegalStateException("Invalid result: " + result.getFloat());
      }
    }
  }
}
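
For readers wondering where the expected value in the test above comes from, here is a minimal plain-Java sketch that recomputes it without TensorFlow. It assumes the loss applies a softmax to the logits and then averages the per-row cross-entropy over the batch; the class and method names (`SoftmaxCrossEntropySketch`, `rowLoss`, `meanLoss`) are hypothetical and are not part of the TensorFlow Java API.

```java
public class SoftmaxCrossEntropySketch {

  /** Cross-entropy of one row: -sum(labels[i] * log(softmax(logits)[i])). */
  static double rowLoss(double[] labels, double[] logits) {
    // Subtract the largest logit before exponentiating, for numerical stability.
    double max = Double.NEGATIVE_INFINITY;
    for (double v : logits) {
      max = Math.max(max, v);
    }
    double sumExp = 0.0;
    for (double v : logits) {
      sumExp += Math.exp(v - max);
    }
    double logSumExp = max + Math.log(sumExp);
    double loss = 0.0;
    for (int i = 0; i < logits.length; i++) {
      // log(softmax(logits)[i]) == logits[i] - logSumExp
      loss -= labels[i] * (logits[i] - logSumExp);
    }
    return loss;
  }

  /** Mean of the per-row losses (assumed reduction); the batch size is just the array length. */
  static double meanLoss(double[][] labels, double[][] logits) {
    double total = 0.0;
    for (int r = 0; r < logits.length; r++) {
      total += rowLoss(labels[r], logits[r]);
    }
    return total / logits.length;
  }

  public static void main(String[] args) {
    // Same data as the test: a 3x3 identity matrix used as both labels and logits.
    double[][] identity = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
    System.out.println(meanLoss(identity, identity));
  }
}
```

Each row reduces to ln(e + 2) - 1, roughly 0.5514, which agrees with the 0.5514477 used in the test well within its 0.01 tolerance. Because the mean is taken over however many rows are passed in, the same value results for any number of such rows, which is the behaviour a dynamic batch size relies on.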

@nfeybesse
Contributor Author

I am confused, because I have the impression that no test has yet fed a model with batches of dynamic size. I know from experience that it largely works, but a few small bugs have to be tracked down. How should I integrate my test so that it is suitable?

@Craigacp
Collaborator

Craigacp commented Mar 6, 2024

Add it next to the other tests for that loss. If there are more issues then let's fix them.

More of the framework was in flight a couple of years ago, but we didn't get all of it merged, so I assume that some of those things were tested in the original codebase before it was broken up into smaller PRs.

@Craigacp Craigacp merged commit 2b6d83f into tensorflow:master Mar 8, 2024
9 checks passed
@Craigacp
Collaborator

Craigacp commented Mar 8, 2024

Thanks
