-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WebNN: Implement
scatterElements
operator in DirectML backend
The `scatterElements` operator is proposed by WebML WG [1] for supporting popular transformer-based models. This CL adds the IDL and mojo definitions of scatterElements, and implements it in the DirectML backend by mapping to `DML_OPERATOR_SCATTER` [2]. This CL also adds the `scatterElements` validation and conformance tests into WPT. [1]: webmachinelearning/webnn#375 (comment) [2]: https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc Bug: 370536101,370538328 Change-Id: Ifb73bed5eb05cb919b106b4aaea5127ec099edb2 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5921136 Reviewed-by: Alex Gough <[email protected]> Reviewed-by: Weizhong Xia <[email protected]> Auto-Submit: ningxin hu <[email protected]> Commit-Queue: ningxin hu <[email protected]> Commit-Queue: Weizhong Xia <[email protected]> Reviewed-by: Rafael Cintron <[email protected]> Reviewed-by: Austin Sullivan <[email protected]> Cr-Commit-Position: refs/heads/main@{#1368312}
- Loading branch information
1 parent
a2dc005
commit 74fc762
Showing
2 changed files
with
241 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// META: title=test WebNN API scatterElements operation | ||
// META: global=window,dedicatedworker | ||
// META: variant=?cpu | ||
// META: variant=?gpu | ||
// META: variant=?npu | ||
// META: script=../resources/utils.js | ||
// META: timeout=long | ||
|
||
'use strict'; | ||
|
||
const getScatterElementsPrecisionTolerance = () => { | ||
return {metricType: 'ULP', value: 0}; | ||
}; | ||
|
||
const scatterElementsTests = [ | ||
{ | ||
'name': 'Scatter elements along axis 0', | ||
'graph': { | ||
'inputs': { | ||
'input': { | ||
'data': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | ||
'descriptor': {shape: [3, 3], dataType: 'float32'} | ||
}, | ||
'indices': { | ||
'data': [1, 0, 2, 0, 2, 1], | ||
'descriptor': {shape: [2, 3], dataType: 'int32'}, | ||
}, | ||
'updates': { | ||
'data': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2], | ||
'descriptor': {shape: [2, 3], dataType: 'float32'} | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'scatterElements', | ||
'arguments': [ | ||
{'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, | ||
{'options': {'axis': 0}} | ||
], | ||
'outputs': 'output' | ||
}], | ||
'expectedOutputs': { | ||
'output': { | ||
'data': [2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2], | ||
'descriptor': {shape: [3, 3], dataType: 'float32'} | ||
} | ||
} | ||
} | ||
}, | ||
{ | ||
'name': 'Scatter elements along axis 1', | ||
'graph': { | ||
'inputs': { | ||
'input': { | ||
'data': [1.0, 2.0, 3.0, 4.0, 5.0], | ||
'descriptor': {shape: [1, 5], dataType: 'float32'} | ||
}, | ||
'indices': { | ||
'data': [1, 3], | ||
'descriptor': {shape: [1, 2], dataType: 'int32'}, | ||
}, | ||
'updates': { | ||
'data': [1.1, 2.1], | ||
'descriptor': {shape: [1, 2], dataType: 'float32'} | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'scatterElements', | ||
'arguments': [ | ||
{'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, | ||
{'options': {'axis': 1}} | ||
], | ||
'outputs': 'output' | ||
}], | ||
'expectedOutputs': { | ||
'output': { | ||
'data': [1.0, 1.1, 3.0, 2.1, 5.0], | ||
'descriptor': {shape: [1, 5], dataType: 'float32'} | ||
} | ||
} | ||
} | ||
} | ||
]; | ||
|
||
if (navigator.ml) { | ||
scatterElementsTests.forEach((test) => { | ||
webnn_conformance_test( | ||
buildGraphAndCompute, getScatterElementsPrecisionTolerance, test); | ||
}); | ||
} else { | ||
test(() => assert_implements(navigator.ml, 'missing navigator.ml')); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// META: title=validation tests for WebNN API scatterElements operation | ||
// META: global=window,dedicatedworker | ||
// META: variant=?cpu | ||
// META: variant=?gpu | ||
// META: variant=?npu | ||
// META: script=../resources/utils_validation.js | ||
|
||
'use strict'; | ||
|
||
const tests = [ | ||
{ | ||
name: '[scatterElements] Test scatterElements with default options', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3]}, | ||
output: {dataType: 'float32', shape: [3, 3]} | ||
}, | ||
{ | ||
name: '[scatterElements] Test scatterElements with axis = 0', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3]}, | ||
axis: 0, | ||
output: {dataType: 'float32', shape: [3, 3]} | ||
}, | ||
{ | ||
name: '[scatterElements] Test scatterElements with axis = 1', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [3, 2]}, | ||
updates: {dataType: 'float32', shape: [3, 2]}, | ||
axis: 1, | ||
output: {dataType: 'float32', shape: [3, 3]} | ||
}, | ||
{ | ||
name: '[scatterElements] Throw if axis is greater than input rank', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3]}, | ||
axis: 2 | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if updates tensor data type is not the same as input data type', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float16', shape: [2, 3]}, | ||
}, | ||
{ | ||
name: '[scatterElements] Throw if input, indices and updates are scalar', | ||
input: {dataType: 'float32', shape: []}, | ||
indices: {dataType: 'int32', shape: []}, | ||
updates: {dataType: 'float32', shape: []}, | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices rank is not the same as input rank', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3, 3]}, | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices size is not the same as input size along axis 1', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 4]}, | ||
updates: {dataType: 'float32', shape: [2, 4]}, | ||
axis: 0 | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices size is not the same as input size along axis 0', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 2]}, | ||
updates: {dataType: 'float32', shape: [2, 2]}, | ||
axis: 1 | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices rank is not the same as updates rank', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3, 3]}, | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices shape is not the same as updates shape', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 4]}, | ||
} | ||
]; | ||
|
||
tests.forEach( | ||
test => promise_test(async t => { | ||
const builder = new MLGraphBuilder(context); | ||
const input = builder.input('input', test.input); | ||
const indices = builder.input('indices', test.indices); | ||
const updates = builder.input('updates', test.updates); | ||
|
||
const options = {}; | ||
if (test.axis) { | ||
options.axis = test.axis; | ||
} | ||
|
||
if (test.output) { | ||
const output = | ||
builder.scatterElements(input, indices, updates, options); | ||
assert_equals(output.dataType(), test.output.dataType); | ||
assert_array_equals(output.shape(), test.output.shape); | ||
} else { | ||
const label = 'a_scatter_elements' | ||
options.label = label; | ||
const regexp = new RegExp('\\[' + label + '\\]'); | ||
assert_throws_with_label( | ||
() => builder.scatterElements(input, indices, updates, options), | ||
regexp); | ||
} | ||
}, test.name)); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = | ||
otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]}); | ||
const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); | ||
const updates = | ||
builder.input('updates', {dataType: 'float32', shape: [2, 3]}); | ||
|
||
assert_throws_js( | ||
TypeError, () => builder.scatterElements(input, indices, updates)); | ||
}, '[scatterElements] Throw if input is from another builder'); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); | ||
const indices = | ||
otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]}); | ||
const updates = | ||
builder.input('updates', {dataType: 'float32', shape: [2, 3]}); | ||
|
||
assert_throws_js( | ||
TypeError, () => builder.scatterElements(input, indices, updates)); | ||
}, '[scatterElements] Throw if indices is from another builder'); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); | ||
const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); | ||
const updates = | ||
otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]}); | ||
|
||
assert_throws_js( | ||
TypeError, () => builder.scatterElements(input, indices, updates)); | ||
}, '[scatterElements] Throw if updates is from another builder'); |