From a75b6de9b52cac1ceca31406b68f46df2757b591 Mon Sep 17 00:00:00 2001 From: Shiyi Zou Date: Tue, 12 Nov 2024 16:55:54 -0800 Subject: [PATCH] webnn: implement reverse operator This CL adds IDL and mojo definition of reverse operator according to the spec issue [1] and implements it on DirectML backend. [1] https://github.com/webmachinelearning/webnn/issues/773 Bug: 376707210 Change-Id: I0d42b49b87ce243db9d44512e6000c7ee901077b Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac15.arm64-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5979825 Commit-Queue: ningxin hu Auto-Submit: Shiyi Zou Reviewed-by: Austin Sullivan Reviewed-by: ningxin hu Cr-Commit-Position: refs/heads/main@{#1382078} --- webnn/conformance_tests/reverse.https.any.js | 158 +++++++++++++++++++ webnn/validation_tests/reverse.https.any.js | 59 +++++++ 2 files changed, 217 insertions(+) create mode 100644 webnn/conformance_tests/reverse.https.any.js create mode 100644 webnn/validation_tests/reverse.https.any.js diff --git a/webnn/conformance_tests/reverse.https.any.js b/webnn/conformance_tests/reverse.https.any.js new file mode 100644 index 00000000000000..b198497613fac7 --- /dev/null +++ b/webnn/conformance_tests/reverse.https.any.js @@ -0,0 +1,158 @@ +// META: title=test WebNN API reverse operation +// META: global=window,dedicatedworker +// META: variant=?cpu +// META: variant=?gpu +// META: variant=?npu +// META: script=../resources/utils.js +// META: timeout=long + +'use strict'; + +// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-reverse-method +// Reverse the order of the input tensor along specified axes. +// +// dictionary MLReverseOptions : MLOperatorOptions { +// sequence<[EnforceRange] unsigned long> axes; +// }; +// +// MLOperand reverse(MLOperand input, optional MLReverseOptions options = {}); + + +const reverseTests = [ + { + 'name': 'reverse float32 2D input with default options', + 'graph': { + 'inputs': { + 'reverseInput': { + 'data': [ + -30.0561466217041, 99.56941986083984, 88.04620361328125, + -91.87507629394531, -23.7972354888916, -91.28665161132812, + -63.15204620361328, 12.0669527053833, -96.1172866821289, + -44.77365493774414, -80.08650970458984, -64.43756866455078 + ], + 'descriptor': {shape: [3, 4], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'reverse', + 'arguments': [{'input': 'reverseInput'}], + 'outputs': 'reverseOutput' + }], + 'expectedOutputs': { + 'reverseOutput': { + 'data': [ + -64.43756866455078, -80.08650970458984, -44.77365493774414, + -96.1172866821289, 12.0669527053833, -63.15204620361328, + -91.28665161132812, -23.7972354888916, -91.87507629394531, + 88.04620361328125, 99.56941986083984, -30.0561466217041 + ], + 'descriptor': {shape: [3, 4], dataType: 'float32'} + } + } + } + }, + { + 'name': 'reverse float32 3D input options.axes=[1, 2]', + 'graph': { + 'inputs': { + 'reverseInput': { + 'data': [ + -30.0561466217041, 99.56941986083984, 88.04620361328125, + -91.87507629394531, -23.7972354888916, -91.28665161132812, + -63.15204620361328, 12.0669527053833, -96.1172866821289, + -44.77365493774414, -80.08650970458984, -64.43756866455078 + ], + 'descriptor': {shape: [3, 2, 2], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'reverse', + 'arguments': [{'input': 'reverseInput'}, {'options': {'axes': [1, 2]}}], + 'outputs': 'reverseOutput' + }], + 'expectedOutputs': { + 'reverseOutput': { + 'data': [ + -91.87507629394531, 88.04620361328125, 99.56941986083984, + -30.0561466217041, 12.0669527053833, -63.15204620361328, + -91.28665161132812, -23.7972354888916, -64.43756866455078, + -80.08650970458984, -44.77365493774414, -96.1172866821289 + ], + 'descriptor': {shape: [3, 2, 2], dataType: 'float32'} + } + } + } + }, + { + 'name': 'reverse float32 4D input options.axes=[3, 1]', + 'graph': { + 'inputs': { + 'reverseInput': { + 'data': [ + -30.0561466217041, 99.56941986083984, 88.04620361328125, + -91.87507629394531, -23.7972354888916, -91.28665161132812, + -63.15204620361328, 12.0669527053833, -96.1172866821289, + -44.77365493774414, -80.08650970458984, -64.43756866455078 + ], + 'descriptor': {shape: [3, 2, 1, 2], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'reverse', + 'arguments': [{'input': 'reverseInput'}, {'options': {'axes': [3, 1]}}], + 'outputs': 'reverseOutput' + }], + 'expectedOutputs': { + 'reverseOutput': { + 'data': [ + -91.87507629394531, 88.04620361328125, 99.56941986083984, + -30.0561466217041, 12.0669527053833, -63.15204620361328, + -91.28665161132812, -23.7972354888916, -64.43756866455078, + -80.08650970458984, -44.77365493774414, -96.1172866821289 + ], + 'descriptor': {shape: [3, 2, 1, 2], dataType: 'float32'} + } + } + } + }, + { + 'name': 'reverse float32 4D input options.axes=[]', + 'graph': { + 'inputs': { + 'reverseInput': { + 'data': [ + -30.0561466217041, 99.56941986083984, 88.04620361328125, + -91.87507629394531, -23.7972354888916, -91.28665161132812, + -63.15204620361328, 12.0669527053833, -96.1172866821289, + -44.77365493774414, -80.08650970458984, -64.43756866455078 + ], + 'descriptor': {shape: [2, 1, 2, 3], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'reverse', + 'arguments': [{'input': 'reverseInput'}, {'options': {'axes': []}}], + 'outputs': 'reverseOutput' + }], + 'expectedOutputs': { + 'reverseOutput': { + 'data': [ + -30.0561466217041, 99.56941986083984, 88.04620361328125, + -91.87507629394531, -23.7972354888916, -91.28665161132812, + -63.15204620361328, 12.0669527053833, -96.1172866821289, + -44.77365493774414, -80.08650970458984, -64.43756866455078 + ], + 'descriptor': {shape: [2, 1, 2, 3], dataType: 'float32'} + } + } + } + } +]; + +if (navigator.ml) { + reverseTests.forEach((test) => { + webnn_conformance_test(buildGraphAndCompute, getPrecisionTolerance, test); + }); +} else { + test(() => assert_implements(navigator.ml, 'missing navigator.ml')); +} diff --git a/webnn/validation_tests/reverse.https.any.js b/webnn/validation_tests/reverse.https.any.js new file mode 100644 index 00000000000000..bee8f2c63de741 --- /dev/null +++ b/webnn/validation_tests/reverse.https.any.js @@ -0,0 +1,59 @@ +// META: title=validation tests for WebNN API reverse operation +// META: global=window,dedicatedworker +// META: variant=?cpu +// META: variant=?gpu +// META: variant=?npu +// META: script=../resources/utils_validation.js + +'use strict'; + +const tests = [ + { + name: '[reverse] Test reverse with default options', + input: {dataType: 'float32', shape: [3, 3]}, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[reverse] Test reverse with axes = [0, 1]', + input: {dataType: 'int32', shape: [1, 2, 3]}, + axes: [0, 1], + output: {dataType: 'int32', shape: [1, 2, 3]} + }, + { + name: '[reverse] Throw if axes is greater than input rank', + input: {dataType: 'float32', shape: [3, 3]}, + axes: [3] + }, + { + name: '[reverse] Throw if axes is duplicated', + input: {dataType: 'float32', shape: [1, 2, 3, 4]}, + axes: [2, 2, 3] + } +]; + +tests.forEach(test => promise_test(async t => { + const builder = new MLGraphBuilder(context); + const input = builder.input('input', test.input); + const options = {}; + if (test.axes) { + options.axes = test.axes; + } + + if (test.output) { + const output = builder.reverse(input, options); + assert_equals(output.dataType, test.output.dataType); + assert_array_equals(output.shape, test.output.shape); + } else { + const label = 'reverse_1' + options.label = label; + const regexp = new RegExp('\\[' + label + '\\]'); + assert_throws_with_label( + () => builder.reverse(input, options), regexp); + } + }, test.name)); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = + otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]}); + assert_throws_js(TypeError, () => builder.reverse(input)); +}, '[reverse] Throw if input is from another builder');