Skip to content

Commit

Permalink
webnn: implement reverse operator
Browse files Browse the repository at this point in the history
This CL adds IDL and mojo definition of reverse operator according to
the spec issue [1] and implements it on DirectML backend.

[1] webmachinelearning/webnn#773

Bug: 376707210
Change-Id: I0d42b49b87ce243db9d44512e6000c7ee901077b
Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel
  • Loading branch information
shiyi9801 authored and chromium-wpt-export-bot committed Nov 8, 2024
1 parent af91552 commit ca4b672
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 0 deletions.
158 changes: 158 additions & 0 deletions webnn/conformance_tests/reverse.https.any.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// META: title=test WebNN API reverse operation
// META: global=window,dedicatedworker
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils.js
// META: timeout=long

'use strict';

// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-reverse-method
// Reverse the order of the input tensor along specified axes.
//
// dictionary MLReverseOptions : MLOperatorOptions {
// sequence<[EnforceRange] unsigned long> axes;
// };
//
// MLOperand reverse(MLOperand input, optional MLReverseOptions options = {});


const reverseTests = [
{
'name': 'reverse float32 2D input with default options',
'graph': {
'inputs': {
'reverseInput': {
'data': [
-30.0561466217041, 99.56941986083984, 88.04620361328125,
-91.87507629394531, -23.7972354888916, -91.28665161132812,
-63.15204620361328, 12.0669527053833, -96.1172866821289,
-44.77365493774414, -80.08650970458984, -64.43756866455078
],
'descriptor': {shape: [3, 4], dataType: 'float32'}
}
},
'operators': [{
'name': 'reverse',
'arguments': [{'input': 'reverseInput'}],
'outputs': 'reverseOutput'
}],
'expectedOutputs': {
'reverseOutput': {
'data': [
-64.43756866455078, -80.08650970458984, -44.77365493774414,
-96.1172866821289, 12.0669527053833, -63.15204620361328,
-91.28665161132812, -23.7972354888916, -91.87507629394531,
88.04620361328125, 99.56941986083984, -30.0561466217041
],
'descriptor': {shape: [3, 4], dataType: 'float32'}
}
}
}
},
{
'name': 'reverse float32 3D input options.axes=[1, 2]',
'graph': {
'inputs': {
'reverseInput': {
'data': [
-30.0561466217041, 99.56941986083984, 88.04620361328125,
-91.87507629394531, -23.7972354888916, -91.28665161132812,
-63.15204620361328, 12.0669527053833, -96.1172866821289,
-44.77365493774414, -80.08650970458984, -64.43756866455078
],
'descriptor': {shape: [3, 2, 2], dataType: 'float32'}
}
},
'operators': [{
'name': 'reverse',
'arguments': [{'input': 'reverseInput'}, {'options': {'axes': [1, 2]}}],
'outputs': 'reverseOutput'
}],
'expectedOutputs': {
'reverseOutput': {
'data': [
-91.87507629394531, 88.04620361328125, 99.56941986083984,
-30.0561466217041, 12.0669527053833, -63.15204620361328,
-91.28665161132812, -23.7972354888916, -64.43756866455078,
-80.08650970458984, -44.77365493774414, -96.1172866821289
],
'descriptor': {shape: [3, 2, 2], dataType: 'float32'}
}
}
}
},
{
'name': 'reverse float32 4D input options.axes=[3, 1]',
'graph': {
'inputs': {
'reverseInput': {
'data': [
-30.0561466217041, 99.56941986083984, 88.04620361328125,
-91.87507629394531, -23.7972354888916, -91.28665161132812,
-63.15204620361328, 12.0669527053833, -96.1172866821289,
-44.77365493774414, -80.08650970458984, -64.43756866455078
],
'descriptor': {shape: [3, 2, 1, 2], dataType: 'float32'}
}
},
'operators': [{
'name': 'reverse',
'arguments': [{'input': 'reverseInput'}, {'options': {'axes': [3, 1]}}],
'outputs': 'reverseOutput'
}],
'expectedOutputs': {
'reverseOutput': {
'data': [
-91.87507629394531, 88.04620361328125, 99.56941986083984,
-30.0561466217041, 12.0669527053833, -63.15204620361328,
-91.28665161132812, -23.7972354888916, -64.43756866455078,
-80.08650970458984, -44.77365493774414, -96.1172866821289
],
'descriptor': {shape: [3, 2, 1, 2], dataType: 'float32'}
}
}
}
},
{
'name': 'reverse float32 4D input options.axes=[]',
'graph': {
'inputs': {
'reverseInput': {
'data': [
-30.0561466217041, 99.56941986083984, 88.04620361328125,
-91.87507629394531, -23.7972354888916, -91.28665161132812,
-63.15204620361328, 12.0669527053833, -96.1172866821289,
-44.77365493774414, -80.08650970458984, -64.43756866455078
],
'descriptor': {shape: [2, 1, 2, 3], dataType: 'float32'}
}
},
'operators': [{
'name': 'reverse',
'arguments': [{'input': 'reverseInput'}, {'options': {'axes': []}}],
'outputs': 'reverseOutput'
}],
'expectedOutputs': {
'reverseOutput': {
'data': [
-30.0561466217041, 99.56941986083984, 88.04620361328125,
-91.87507629394531, -23.7972354888916, -91.28665161132812,
-63.15204620361328, 12.0669527053833, -96.1172866821289,
-44.77365493774414, -80.08650970458984, -64.43756866455078
],
'descriptor': {shape: [2, 1, 2, 3], dataType: 'float32'}
}
}
}
}
];

if (navigator.ml) {
reverseTests.forEach((test) => {
webnn_conformance_test(buildGraphAndCompute, getPrecisionTolerance, test);
});
} else {
test(() => assert_implements(navigator.ml, 'missing navigator.ml'));
}
59 changes: 59 additions & 0 deletions webnn/validation_tests/reverse.https.any.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// META: title=validation tests for WebNN API reverse operation
// META: global=window,dedicatedworker
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js

'use strict';

const tests = [
{
name: '[reverse] Test reverse with default options',
input: {dataType: 'float32', shape: [3, 3]},
output: {dataType: 'float32', shape: [3, 3]}
},
{
name: '[reverse] Test reverse with axes = [0, 1]',
input: {dataType: 'int32', shape: [1, 2, 3]},
axes: [0, 1],
output: {dataType: 'int32', shape: [1, 2, 3]}
},
{
name: '[reverse] Throw if axes is greater than input rank',
input: {dataType: 'float32', shape: [3, 3]},
axes: [3]
},
{
name: '[reverse] Throw if axes is duplicated',
input: {dataType: 'float32', shape: [1, 2, 3, 4]},
axes: [2, 2, 3]
}
];

tests.forEach(test => promise_test(async t => {
const builder = new MLGraphBuilder(context);
const input = builder.input('input', test.input);
const options = {};
if (test.axes) {
options.axes = test.axes;
}

if (test.output) {
const output = builder.reverse(input, options);
assert_equals(output.dataType, test.output.dataType);
assert_array_equals(output.shape, test.output.shape);
} else {
const label = 'reverse_1'
options.label = label;
const regexp = new RegExp('\\[' + label + '\\]');
assert_throws_with_label(
() => builder.reverse(input, options), regexp);
}
}, test.name));

multi_builder_test(async (t, builder, otherBuilder) => {
const input =
otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]});
assert_throws_js(TypeError, () => builder.reverse(input));
}, '[reverse] Throw if input is from another builder');

0 comments on commit ca4b672

Please sign in to comment.