From d64624cd4cfb136c60dd211dd22cdaad022d7379 Mon Sep 17 00:00:00 2001 From: Belem Zhang Date: Mon, 18 Mar 2024 16:36:52 +0800 Subject: [PATCH] Refactor the MobileNet NCHW .js to fetch in parallel (#200) --- image_classification/mobilenet_nchw.js | 80 +++++++++++++------------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/image_classification/mobilenet_nchw.js b/image_classification/mobilenet_nchw.js index e9c6aaa6..c62cb51e 100644 --- a/image_classification/mobilenet_nchw.js +++ b/image_classification/mobilenet_nchw.js @@ -25,54 +25,52 @@ export class MobileNetV2Nchw { async buildConv_(input, name, relu6 = true, options = {}) { const prefix = this.weightsUrl_ + 'conv_' + name; const weightsName = prefix + '_weight.npy'; - const weights = - await buildConstantByNpy(this.builder_, weightsName); + const weights = buildConstantByNpy(this.builder_, weightsName); const biasName = prefix + '_bias.npy'; - const bias = - await buildConstantByNpy(this.builder_, biasName); - options.bias = bias; + const bias = buildConstantByNpy(this.builder_, biasName); + options.bias = await bias; if (relu6) { // TODO: Set clamp activation to options once it's supported in // WebNN DML backend. // Implement `clip` by `clamp` of WebNN API if (this.deviceType_ == 'gpu') { return this.builder_.clamp( - this.builder_.conv2d(input, weights, options), + this.builder_.conv2d(await input, await weights, options), {minValue: 0, maxValue: 6}); } else { options.activation = this.builder_.clamp({minValue: 0, maxValue: 6}); } } - return this.builder_.conv2d(input, weights, options); + return this.builder_.conv2d(await input, await weights, options); } async buildGemm_(input, name) { const prefix = this.weightsUrl_ + 'gemm_' + name; const weightsName = prefix + '_weight.npy'; - const weights = await buildConstantByNpy(this.builder_, weightsName); + const weights = buildConstantByNpy(this.builder_, weightsName); const biasName = prefix + '_bias.npy'; - const bias = await buildConstantByNpy(this.builder_, biasName); - const options = {c: bias, bTranspose: true}; - return this.builder_.gemm(input, weights, options); + const bias = buildConstantByNpy(this.builder_, biasName); + const options = {c: await bias, bTranspose: true}; + return this.builder_.gemm(await input, await weights, options); } async buildLinearBottleneck_( input, convNameArray, group, stride, shortcut = true) { - const conv1x1Relu6 = await this.buildConv_(input, convNameArray[0]); + const conv1x1Relu6 = this.buildConv_(await input, convNameArray[0]); const options = { padding: [1, 1, 1, 1], groups: group, strides: [stride, stride], }; - const dwise3x3Relu6 = await this.buildConv_( - conv1x1Relu6, convNameArray[1], true, options); - const conv1x1Linear = await this.buildConv_( - dwise3x3Relu6, convNameArray[2], false); + const dwise3x3Relu6 = this.buildConv_( + await conv1x1Relu6, convNameArray[1], true, options); + const conv1x1Linear = this.buildConv_( + await dwise3x3Relu6, convNameArray[2], false); if (shortcut) { - return this.builder_.add(input, conv1x1Linear); + return this.builder_.add(await input, await conv1x1Linear); } - return conv1x1Linear; + return await conv1x1Linear; } async load(contextOptions) { @@ -84,49 +82,49 @@ export class MobileNetV2Nchw { dataType: 'float32', dimensions: this.inputOptions.inputDimensions, }); - const conv0 = await this.buildConv_( + const conv0 = this.buildConv_( data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]}); - const conv1 = await this.buildConv_( + const conv1 = this.buildConv_( conv0, '2', true, {padding: [1, 1, 1, 1], groups: 32}); - const conv2 = await this.buildConv_(conv1, '4', false); - const bottleneck0 = await this.buildLinearBottleneck_( + const conv2 = this.buildConv_(conv1, '4', false); + const bottleneck0 = this.buildLinearBottleneck_( conv2, ['5', '7', '9'], 96, 2, false); - const bottleneck1 = await this.buildLinearBottleneck_( + const bottleneck1 = this.buildLinearBottleneck_( bottleneck0, ['10', '12', '14'], 144, 1); - const bottleneck2 = await this.buildLinearBottleneck_( + const bottleneck2 = this.buildLinearBottleneck_( bottleneck1, ['16', '18', '20'], 144, 2, false); - const bottleneck3 = await this.buildLinearBottleneck_( + const bottleneck3 = this.buildLinearBottleneck_( bottleneck2, ['21', '23', '25'], 192, 1); - const bottleneck4 = await this.buildLinearBottleneck_( + const bottleneck4 = this.buildLinearBottleneck_( bottleneck3, ['27', '29', '31'], 192, 1); - const bottleneck5 = await this.buildLinearBottleneck_( + const bottleneck5 = this.buildLinearBottleneck_( bottleneck4, ['33', '35', '37'], 192, 2, false); - const bottleneck6 = await this.buildLinearBottleneck_( + const bottleneck6 = this.buildLinearBottleneck_( bottleneck5, ['38', '40', '42'], 384, 1); - const bottleneck7 = await this.buildLinearBottleneck_( + const bottleneck7 = this.buildLinearBottleneck_( bottleneck6, ['44', '46', '48'], 384, 1); - const bottleneck8 = await this.buildLinearBottleneck_( + const bottleneck8 = this.buildLinearBottleneck_( bottleneck7, ['50', '52', '54'], 384, 1); - const bottleneck9 = await this.buildLinearBottleneck_( + const bottleneck9 = this.buildLinearBottleneck_( bottleneck8, ['56', '58', '60'], 384, 1, false); - const bottleneck10 = await this.buildLinearBottleneck_( + const bottleneck10 = this.buildLinearBottleneck_( bottleneck9, ['61', '63', '65'], 576, 1); - const bottleneck11 = await this.buildLinearBottleneck_( + const bottleneck11 = this.buildLinearBottleneck_( bottleneck10, ['67', '69', '71'], 576, 1); - const bottleneck12 = await this.buildLinearBottleneck_( + const bottleneck12 = this.buildLinearBottleneck_( bottleneck11, ['73', '75', '77'], 576, 2, false); - const bottleneck13 = await this.buildLinearBottleneck_( + const bottleneck13 = this.buildLinearBottleneck_( bottleneck12, ['78', '80', '82'], 960, 1); - const bottleneck14 = await this.buildLinearBottleneck_( + const bottleneck14 = this.buildLinearBottleneck_( bottleneck13, ['84', '86', '88'], 960, 1); - const bottleneck15 = await this.buildLinearBottleneck_( + const bottleneck15 = this.buildLinearBottleneck_( bottleneck14, ['90', '92', '94'], 960, 1, false); - const conv3 = await this.buildConv_(bottleneck15, '95', true); - const pool = this.builder_.averagePool2d(conv3); + const conv3 = this.buildConv_(bottleneck15, '95', true); + const pool = this.builder_.averagePool2d(await conv3); const reshape = this.builder_.reshape(pool, [1, 1280]); - const gemm = await this.buildGemm_(reshape, '104'); - return this.builder_.softmax(gemm); + const gemm = this.buildGemm_(reshape, '104'); + return await this.builder_.softmax(await gemm); } async build(outputOperand) {