Skip to content

Commit

Permalink
Refactor the MobileNet NCHW .js to fetch in parallel (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
ibelem authored Mar 18, 2024
1 parent 9f9e7b4 commit d64624c
Showing 1 changed file with 39 additions and 41 deletions.
80 changes: 39 additions & 41 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,52 @@ export class MobileNetV2Nchw {
async buildConv_(input, name, relu6 = true, options = {}) {
const prefix = this.weightsUrl_ + 'conv_' + name;
const weightsName = prefix + '_weight.npy';
const weights =
await buildConstantByNpy(this.builder_, weightsName);
const weights = buildConstantByNpy(this.builder_, weightsName);
const biasName = prefix + '_bias.npy';
const bias =
await buildConstantByNpy(this.builder_, biasName);
options.bias = bias;
const bias = buildConstantByNpy(this.builder_, biasName);
options.bias = await bias;
if (relu6) {
// TODO: Set clamp activation to options once it's supported in
// WebNN DML backend.
// Implement `clip` by `clamp` of WebNN API
if (this.deviceType_ == 'gpu') {
return this.builder_.clamp(
this.builder_.conv2d(input, weights, options),
this.builder_.conv2d(await input, await weights, options),
{minValue: 0, maxValue: 6});
} else {
options.activation = this.builder_.clamp({minValue: 0, maxValue: 6});
}
}
return this.builder_.conv2d(input, weights, options);
return this.builder_.conv2d(await input, await weights, options);
}

async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'gemm_' + name;
const weightsName = prefix + '_weight.npy';
const weights = await buildConstantByNpy(this.builder_, weightsName);
const weights = buildConstantByNpy(this.builder_, weightsName);
const biasName = prefix + '_bias.npy';
const bias = await buildConstantByNpy(this.builder_, biasName);
const options = {c: bias, bTranspose: true};
return this.builder_.gemm(input, weights, options);
const bias = buildConstantByNpy(this.builder_, biasName);
const options = {c: await bias, bTranspose: true};
return this.builder_.gemm(await input, await weights, options);
}

async buildLinearBottleneck_(
input, convNameArray, group, stride, shortcut = true) {
const conv1x1Relu6 = await this.buildConv_(input, convNameArray[0]);
const conv1x1Relu6 = this.buildConv_(await input, convNameArray[0]);
const options = {
padding: [1, 1, 1, 1],
groups: group,
strides: [stride, stride],
};
const dwise3x3Relu6 = await this.buildConv_(
conv1x1Relu6, convNameArray[1], true, options);
const conv1x1Linear = await this.buildConv_(
dwise3x3Relu6, convNameArray[2], false);
const dwise3x3Relu6 = this.buildConv_(
await conv1x1Relu6, convNameArray[1], true, options);
const conv1x1Linear = this.buildConv_(
await dwise3x3Relu6, convNameArray[2], false);

if (shortcut) {
return this.builder_.add(input, conv1x1Linear);
return this.builder_.add(await input, await conv1x1Linear);
}
return conv1x1Linear;
return await conv1x1Linear;
}

async load(contextOptions) {
Expand All @@ -84,49 +82,49 @@ export class MobileNetV2Nchw {
dataType: 'float32',
dimensions: this.inputOptions.inputDimensions,
});
const conv0 = await this.buildConv_(
const conv0 = this.buildConv_(
data, '0', true, {padding: [1, 1, 1, 1], strides: [2, 2]});
const conv1 = await this.buildConv_(
const conv1 = this.buildConv_(
conv0, '2', true, {padding: [1, 1, 1, 1], groups: 32});
const conv2 = await this.buildConv_(conv1, '4', false);
const bottleneck0 = await this.buildLinearBottleneck_(
const conv2 = this.buildConv_(conv1, '4', false);
const bottleneck0 = this.buildLinearBottleneck_(
conv2, ['5', '7', '9'], 96, 2, false);
const bottleneck1 = await this.buildLinearBottleneck_(
const bottleneck1 = this.buildLinearBottleneck_(
bottleneck0, ['10', '12', '14'], 144, 1);
const bottleneck2 = await this.buildLinearBottleneck_(
const bottleneck2 = this.buildLinearBottleneck_(
bottleneck1, ['16', '18', '20'], 144, 2, false);
const bottleneck3 = await this.buildLinearBottleneck_(
const bottleneck3 = this.buildLinearBottleneck_(
bottleneck2, ['21', '23', '25'], 192, 1);
const bottleneck4 = await this.buildLinearBottleneck_(
const bottleneck4 = this.buildLinearBottleneck_(
bottleneck3, ['27', '29', '31'], 192, 1);
const bottleneck5 = await this.buildLinearBottleneck_(
const bottleneck5 = this.buildLinearBottleneck_(
bottleneck4, ['33', '35', '37'], 192, 2, false);
const bottleneck6 = await this.buildLinearBottleneck_(
const bottleneck6 = this.buildLinearBottleneck_(
bottleneck5, ['38', '40', '42'], 384, 1);
const bottleneck7 = await this.buildLinearBottleneck_(
const bottleneck7 = this.buildLinearBottleneck_(
bottleneck6, ['44', '46', '48'], 384, 1);
const bottleneck8 = await this.buildLinearBottleneck_(
const bottleneck8 = this.buildLinearBottleneck_(
bottleneck7, ['50', '52', '54'], 384, 1);
const bottleneck9 = await this.buildLinearBottleneck_(
const bottleneck9 = this.buildLinearBottleneck_(
bottleneck8, ['56', '58', '60'], 384, 1, false);
const bottleneck10 = await this.buildLinearBottleneck_(
const bottleneck10 = this.buildLinearBottleneck_(
bottleneck9, ['61', '63', '65'], 576, 1);
const bottleneck11 = await this.buildLinearBottleneck_(
const bottleneck11 = this.buildLinearBottleneck_(
bottleneck10, ['67', '69', '71'], 576, 1);
const bottleneck12 = await this.buildLinearBottleneck_(
const bottleneck12 = this.buildLinearBottleneck_(
bottleneck11, ['73', '75', '77'], 576, 2, false);
const bottleneck13 = await this.buildLinearBottleneck_(
const bottleneck13 = this.buildLinearBottleneck_(
bottleneck12, ['78', '80', '82'], 960, 1);
const bottleneck14 = await this.buildLinearBottleneck_(
const bottleneck14 = this.buildLinearBottleneck_(
bottleneck13, ['84', '86', '88'], 960, 1);
const bottleneck15 = await this.buildLinearBottleneck_(
const bottleneck15 = this.buildLinearBottleneck_(
bottleneck14, ['90', '92', '94'], 960, 1, false);

const conv3 = await this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(conv3);
const conv3 = this.buildConv_(bottleneck15, '95', true);
const pool = this.builder_.averagePool2d(await conv3);
const reshape = this.builder_.reshape(pool, [1, 1280]);
const gemm = await this.buildGemm_(reshape, '104');
return this.builder_.softmax(gemm);
const gemm = this.buildGemm_(reshape, '104');
return await this.builder_.softmax(await gemm);
}

async build(outputOperand) {
Expand Down

0 comments on commit d64624c

Please sign in to comment.