From 74fc76284bcb36946f540d2b14da069cd333b048 Mon Sep 17 00:00:00 2001 From: Ningxin Hu Date: Mon, 14 Oct 2024 10:18:11 -0700 Subject: [PATCH] WebNN: Implement `scatterElements` operator in DirectML backend The `scatterElements` operator is proposed by WebML WG [1] for supporting popular transformer-based models. This CL adds the IDL and mojo definitions of scatterElements, and implements it in the DirectML backend by mapping to `DML_OPERATOR_SCATTER` [2]. This CL also adds the `scatterElements` validation and conformance tests into WPT. [1]: https://github.com/webmachinelearning/webnn/issues/375#issuecomment-2292466613 [2]: https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc Bug: 370536101,370538328 Change-Id: Ifb73bed5eb05cb919b106b4aaea5127ec099edb2 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5921136 Reviewed-by: Alex Gough Reviewed-by: Weizhong Xia Auto-Submit: ningxin hu Commit-Queue: ningxin hu Commit-Queue: Weizhong Xia Reviewed-by: Rafael Cintron Reviewed-by: Austin Sullivan Cr-Commit-Position: refs/heads/main@{#1368312} --- .../scatterElements.https.any.js | 91 +++++++++++ .../scatterElements.https.any.js | 150 ++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 webnn/conformance_tests/scatterElements.https.any.js create mode 100644 webnn/validation_tests/scatterElements.https.any.js diff --git a/webnn/conformance_tests/scatterElements.https.any.js b/webnn/conformance_tests/scatterElements.https.any.js new file mode 100644 index 00000000000000..561260d47ecf66 --- /dev/null +++ b/webnn/conformance_tests/scatterElements.https.any.js @@ -0,0 +1,91 @@ +// META: title=test WebNN API scatterElements operation +// META: global=window,dedicatedworker +// META: variant=?cpu +// META: variant=?gpu +// META: variant=?npu +// META: script=../resources/utils.js +// META: timeout=long + +'use strict'; + +const getScatterElementsPrecisionTolerance = () => { + return {metricType: 'ULP', value: 0}; +}; + +const scatterElementsTests = [ + { + 'name': 'Scatter elements along axis 0', + 'graph': { + 'inputs': { + 'input': { + 'data': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 0, 2, 0, 2, 1], + 'descriptor': {shape: [2, 3], dataType: 'int32'}, + }, + 'updates': { + 'data': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2], + 'descriptor': {shape: [2, 3], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 0}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + } + } + } + }, + { + 'name': 'Scatter elements along axis 1', + 'graph': { + 'inputs': { + 'input': { + 'data': [1.0, 2.0, 3.0, 4.0, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 3], + 'descriptor': {shape: [1, 2], dataType: 'int32'}, + }, + 'updates': { + 'data': [1.1, 2.1], + 'descriptor': {shape: [1, 2], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 1}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [1.0, 1.1, 3.0, 2.1, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + } + } + } + } +]; + +if (navigator.ml) { + scatterElementsTests.forEach((test) => { + webnn_conformance_test( + buildGraphAndCompute, getScatterElementsPrecisionTolerance, test); + }); +} else { + test(() => assert_implements(navigator.ml, 'missing navigator.ml')); +} diff --git a/webnn/validation_tests/scatterElements.https.any.js b/webnn/validation_tests/scatterElements.https.any.js new file mode 100644 index 00000000000000..15551b2bbe5b48 --- /dev/null +++ b/webnn/validation_tests/scatterElements.https.any.js @@ -0,0 +1,150 @@ +// META: title=validation tests for WebNN API scatterElements operation +// META: global=window,dedicatedworker +// META: variant=?cpu +// META: variant=?gpu +// META: variant=?npu +// META: script=../resources/utils_validation.js + +'use strict'; + +const tests = [ + { + name: '[scatterElements] Test scatterElements with default options', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Test scatterElements with axis = 0', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + axis: 0, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Test scatterElements with axis = 1', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [3, 2]}, + updates: {dataType: 'float32', shape: [3, 2]}, + axis: 1, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Throw if axis is greater than input rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + axis: 2 + }, + { + name: + '[scatterElements] Throw if updates tensor data type is not the same as input data type', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float16', shape: [2, 3]}, + }, + { + name: '[scatterElements] Throw if input, indices and updates are scalar', + input: {dataType: 'float32', shape: []}, + indices: {dataType: 'int32', shape: []}, + updates: {dataType: 'float32', shape: []}, + }, + { + name: + '[scatterElements] Throw if indices rank is not the same as input rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3, 3]}, + updates: {dataType: 'float32', shape: [2, 3, 3]}, + }, + { + name: + '[scatterElements] Throw if indices size is not the same as input size along axis 1', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 4]}, + updates: {dataType: 'float32', shape: [2, 4]}, + axis: 0 + }, + { + name: + '[scatterElements] Throw if indices size is not the same as input size along axis 0', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 2]}, + updates: {dataType: 'float32', shape: [2, 2]}, + axis: 1 + }, + { + name: + '[scatterElements] Throw if indices rank is not the same as updates rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3, 3]}, + }, + { + name: + '[scatterElements] Throw if indices shape is not the same as updates shape', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 4]}, + } +]; + +tests.forEach( + test => promise_test(async t => { + const builder = new MLGraphBuilder(context); + const input = builder.input('input', test.input); + const indices = builder.input('indices', test.indices); + const updates = builder.input('updates', test.updates); + + const options = {}; + if (test.axis) { + options.axis = test.axis; + } + + if (test.output) { + const output = + builder.scatterElements(input, indices, updates, options); + assert_equals(output.dataType(), test.output.dataType); + assert_array_equals(output.shape(), test.output.shape); + } else { + const label = 'a_scatter_elements' + options.label = label; + const regexp = new RegExp('\\[' + label + '\\]'); + assert_throws_with_label( + () => builder.scatterElements(input, indices, updates, options), + regexp); + } + }, test.name)); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = + otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + builder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if input is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = + otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + builder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if indices is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if updates is from another builder');