Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove softmax workaround for NPU #237

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion image_classification/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module.exports = {
globals: {
'MLGraphBuilder': 'readonly',
'tf': 'readonly',
},
};
9 changes: 3 additions & 6 deletions image_classification/efficientnet_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,9 @@ export class EfficientNetFP16Nchw {
const pool1 = this.builder_.averagePool2d(await conv22);
const reshape = this.builder_.reshape(pool1, [1, 1280]);
const gemm = this.buildGemm_(reshape, '0');
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(await gemm, 'float32');
} else {
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}
const softmax = this.builder_.softmax(await gemm);

return this.builder_.cast(softmax, 'float32');
}

async build(outputOperand) {
Expand Down
3 changes: 0 additions & 3 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,6 @@ <h2 class="text-uppercase text-info">No model selected</h2>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/umd/popper.min.js"
integrity="sha384-9/reFTGAW83EW2RDu2S0VKaIzap3H66lZH81PoYlFhbGU+6BZp6G7niu735Sk7lN"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"
integrity="sha256-28ZvjeNGrGNEIj9/2D8YAPE6Vm5JSvvDs+LI4ED31x8="
crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"
integrity="sha384-B4gt1jrGC7Jh4AgTPSdUtOBvfO8shuf57BaghqFfPlYxofvL8/KUEfYiJOMMV+rV"
crossorigin="anonymous"></script>
Expand Down
12 changes: 0 additions & 12 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,6 @@ async function renderCamStream() {

// Get top 3 classes of labels from output buffer
function getTopClasses(buffer, labels) {
// Currently we need to fallback softmax to tf.softmax because
// NPU doesn't support softmax.
// TODO: Remove this workaround once NPU supports softmax.
if (deviceType === 'npu') {
// Softmax
buffer = tf.tidy(() => {
const a =
tf.tensor(buffer, netInstance.outputDimensions, 'float32');
const b = tf.softmax(a);
return b.dataSync();
});
}
const probs = Array.from(buffer);
const indexes = probs.map((prob, index) => [prob, index]);
const sorted = indexes.sort((a, b) => {
Expand Down
8 changes: 2 additions & 6 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,8 @@ export class MobileNetV2Nchw {
{groups: 1280, strides: [7, 7]});
const conv5 = this.buildConv_(await conv4, '104', false);
const reshape = this.builder_.reshape(await conv5, [1, 1000]);
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(reshape, 'float32');
} else {
const softmax = this.builder_.softmax(reshape);
return this.builder_.cast(softmax, 'float32');
}
const softmax = this.builder_.softmax(reshape);
return this.builder_.cast(softmax, 'float32');
}
}

Expand Down
8 changes: 2 additions & 6 deletions image_classification/resnet50v1_fp16_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,8 @@ export class ResNet50V1FP16Nchw {
const pool2 = this.builder_.averagePool2d(await bottleneck16);
const reshape = this.builder_.reshape(pool2, [1, 2048]);
const gemm = this.buildGemm_(reshape, '0');
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(await gemm, 'float32');
} else {
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}
const softmax = this.builder_.softmax(await gemm);
return this.builder_.cast(softmax, 'float32');
}

async build(outputOperand) {
Expand Down