From 81eba24a60aac5dbe1f6c4c45475214bc1b6b1c8 Mon Sep 17 00:00:00 2001
From: miaobin
Date: Mon, 18 Jul 2022 15:38:49 +0800
Subject: [PATCH] [WebNN-native/Node] support context-based compute graph

---
 examples/LeNet/LeNet.h | 4 +-
 examples/LeNet/Main.cpp | 2 +-
 examples/MobileNetV2/Main.cpp | 4 +-
 examples/ResNet/Main.cpp | 4 +-
 examples/SampleUtils.cpp | 5 +-
 examples/SampleUtils.h | 8 +-
 examples/SqueezeNet/Main.cpp | 4 +-
 node/src/Context.cpp | 31 ++++-
 node/src/Context.h | 1 +
 node/src/Graph.cpp | 121 +-----------------
 node/src/Graph.h | 5 +-
 node/src/GraphBuilder.cpp | 1 -
 node/src/Utils.h | 92 +++++++++++++
 src/webnn/native/Context.cpp | 31 +++++
 src/webnn/native/Context.h | 11 ++
 src/webnn/native/Graph.cpp | 24 +---
 src/webnn/native/Graph.h | 9 +-
 src/webnn/native/dmlx/GraphDMLX.h | 3 +-
 src/webnn/native/mlas/GraphMLAS.h | 3 +-
 src/webnn/native/nnapi/ContextNnapi.h | 1 +
 src/webnn/native/nnapi/GraphNnapi.h | 5 +-
 src/webnn/native/null/ContextNull.h | 3 +-
 src/webnn/native/onednn/GraphDNNL.h | 3 +-
 src/webnn/native/openvino/GraphIE.h | 3 +-
 src/webnn/native/xnnpack/GraphXNN.h | 3 +-
 src/webnn/tests/end2end/AddTests.cpp | 6 +-
 src/webnn/tests/end2end/BatchNormTests.cpp | 2 +-
 src/webnn/tests/end2end/ClampTests.cpp | 2 +-
 src/webnn/tests/end2end/ConcatTests.cpp | 2 +-
 src/webnn/tests/end2end/Conv2dTests.cpp | 2 +-
 .../tests/end2end/ConvTranspose2dTests.cpp | 2 +-
 src/webnn/tests/end2end/DivTests.cpp | 4 +-
 .../tests/end2end/ElementWiseUnaryTests.cpp | 2 +-
 src/webnn/tests/end2end/GemmTests.cpp | 4 +-
 src/webnn/tests/end2end/GruTests.cpp | 6 +-
 src/webnn/tests/end2end/HardSwishTests.cpp | 2 +-
 src/webnn/tests/end2end/InstanceNormTests.cpp | 2 +-
 src/webnn/tests/end2end/LeakyReluTests.cpp | 2 +-
 src/webnn/tests/end2end/MatMulTests.cpp | 18 +-
 src/webnn/tests/end2end/MaxTests.cpp | 6 +-
 src/webnn/tests/end2end/MinTests.cpp | 6 +-
 src/webnn/tests/end2end/MulTests.cpp | 6 +-
 src/webnn/tests/end2end/PadTests.cpp | 2 +-
 src/webnn/tests/end2end/Pool2dTests.cpp | 90 ++++++------
 src/webnn/tests/end2end/PowTests.cpp | 10 +-
 src/webnn/tests/end2end/ReduceTests.cpp | 2 +-
 src/webnn/tests/end2end/ReluTests.cpp | 2 +-
 src/webnn/tests/end2end/Resample2dTests.cpp | 2 +-
 src/webnn/tests/end2end/ReshapeTests.cpp | 2 +-
 src/webnn/tests/end2end/SigmoidTests.cpp | 4 +-
 src/webnn/tests/end2end/SliceTests.cpp | 2 +-
 src/webnn/tests/end2end/SoftmaxTests.cpp | 2 +-
 src/webnn/tests/end2end/SplitTests.cpp | 2 +-
 src/webnn/tests/end2end/SqueezeTests.cpp | 2 +-
 src/webnn/tests/end2end/SubTests.cpp | 4 +-
 src/webnn/tests/end2end/TanhTests.cpp | 2 +-
 src/webnn/tests/end2end/TransposeTests.cpp | 2 +-
 .../models/MobileNetV2BatchNormNchw.cpp | 2 +-
 .../tests/end2end/models/MobileNetV2Nchw.cpp | 2 +-
 .../tests/end2end/models/MobileNetV2Nhwc.cpp | 2 +-
 src/webnn/tests/end2end/models/ResNetNchw.cpp | 2 +-
 src/webnn/tests/end2end/models/ResNetNhwc.cpp | 2 +-
 .../tests/end2end/models/SqueezeNetNchw.cpp | 2 +-
 .../tests/end2end/models/SqueezeNetNhwc.cpp | 2 +-
 .../end2end/models/SuperResolutionNchw.cpp | 2 +-
 .../tests/unittests/native/mocks/GraphMock.h | 4 -
 src/webnn/wire/client/ClientDoers.cpp | 22 ++--
 src/webnn/wire/client/Context.cpp | 58 +++++++++
 src/webnn/wire/client/Context.h | 16 +++
 src/webnn/wire/client/Graph.cpp | 53 --------
 src/webnn/wire/client/Graph.h | 15 ---
 src/webnn/wire/server/Server.h | 8 +-
 src/webnn/wire/server/ServerContext.cpp | 89 +++++++++++++
 src/webnn/wire/server/ServerGraph.cpp | 80 ------------
 webnn.json | 42 +++---
 webnn_wire.json | 44 ++++---
 76 files changed, 539 insertions(+), 493 deletions(-)

diff --git a/examples/LeNet/LeNet.h b/examples/LeNet/LeNet.h
index 740e8f6e6..1ad74d9a4 100644
--- a/examples/LeNet/LeNet.h
+++ b/examples/LeNet/LeNet.h
@@ -25,8 +25,6 @@ class LeNet {
     LeNet();
     ~LeNet() = default;
 
-    wnn::Graph Build(const std::string& weigthsPath);
-
-  private:
     wnn::Context mContext;
+    wnn::Graph Build(const std::string& weigthsPath);
 };
diff --git a/examples/LeNet/Main.cpp b/examples/LeNet/Main.cpp
index 5e5f69129..20aba35df 100644
--- a/examples/LeNet/Main.cpp
+++ b/examples/LeNet/Main.cpp
@@ -87,7 +87,7 @@ int main(int argc, const char* argv[]) {
     for (int i = 0; i < nIter; ++i) {
         std::chrono::time_point executionStartTime =
             std::chrono::high_resolution_clock::now();
-        utils::Compute(graph, {{"input", input}}, {{"output", result}});
+        utils::Compute(lenet.mContext, graph, {{"input", input}}, {{"output", result}});
         executionTimeVector.push_back(std::chrono::high_resolution_clock::now() -
                                       executionStartTime);
     }
diff --git a/examples/MobileNetV2/Main.cpp b/examples/MobileNetV2/Main.cpp
index 340d865e3..09256f504 100644
--- a/examples/MobileNetV2/Main.cpp
+++ b/examples/MobileNetV2/Main.cpp
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
     std::vector result(utils::SizeOfShape(mobilevetv2.mOutputShape));
     // Do the first inference for warming up if nIter > 1.
     if (mobilevetv2.mNIter > 1) {
-        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
+        utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
     }
 
     std::vector executionTime;
     for (int i = 0; i < mobilevetv2.mNIter; ++i) {
         std::chrono::time_point executionStartTime =
             std::chrono::high_resolution_clock::now();
-        utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
+        utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
         executionTime.push_back(std::chrono::high_resolution_clock::now() -
                                 executionStartTime);
     }
diff --git a/examples/ResNet/Main.cpp b/examples/ResNet/Main.cpp
index 64773cfd3..d4e4a1355 100644
--- a/examples/ResNet/Main.cpp
+++ b/examples/ResNet/Main.cpp
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
     std::vector result(utils::SizeOfShape(resnet.mOutputShape));
     // Do the first inference for warming up if nIter > 1.
if (resnet.mNIter > 1) { - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); + utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}}); } std::vector executionTime; for (int i = 0; i < resnet.mNIter; ++i) { std::chrono::time_point executionStartTime = std::chrono::high_resolution_clock::now(); - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); + utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}}); executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime); } diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index 3e2d922aa..c8f3b6a73 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -304,10 +304,11 @@ namespace utils { return builder.Build(namedOperands); } - void Compute(const wnn::Graph& graph, + void Compute(const wnn::Context& context, + const wnn::Graph& graph, const std::vector>& inputs, const std::vector>& outputs) { - return Compute(graph, inputs, outputs); + return Compute(context, graph, inputs, outputs); } std::vector ReadTopKLabel(const std::vector& topKIndex, diff --git a/examples/SampleUtils.h b/examples/SampleUtils.h index fea8c6c7f..71ac91edb 100644 --- a/examples/SampleUtils.h +++ b/examples/SampleUtils.h @@ -247,7 +247,8 @@ namespace utils { }; template - void Compute(const wnn::Graph& graph, + void Compute(const wnn::Context& context, + const wnn::Graph& graph, const std::vector>& inputs, const std::vector>& outputs) { if (graph.GetHandle() == nullptr) { @@ -277,11 +278,12 @@ namespace utils { mlOutputs.push_back(resource); namedOutputs.Set(output.name.c_str(), &mlOutputs.back()); } - graph.Compute(namedInputs, namedOutputs); + context.ComputeSync(graph, namedInputs, namedOutputs); DoFlush(); } - void Compute(const wnn::Graph& graph, + void Compute(const wnn::Context& context, + const wnn::Graph& graph, const std::vector>& inputs, const std::vector>& outputs); diff --git a/examples/SqueezeNet/Main.cpp b/examples/SqueezeNet/Main.cpp index feba69e18..b227a1075 100644 --- a/examples/SqueezeNet/Main.cpp +++ b/examples/SqueezeNet/Main.cpp @@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) { std::vector result(utils::SizeOfShape(squeezenet.mOutputShape)); // Do the first inference for warming up if nIter > 1. 
if (squeezenet.mNIter > 1) { - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); + utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}}); } std::vector executionTime; for (int i = 0; i < squeezenet.mNIter; ++i) { std::chrono::time_point executionStartTime = std::chrono::high_resolution_clock::now(); - utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}}); + utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}}); executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime); } diff --git a/node/src/Context.cpp b/node/src/Context.cpp index 0371515a3..9c4605ed0 100644 --- a/node/src/Context.cpp +++ b/node/src/Context.cpp @@ -17,7 +17,9 @@ #include #include +#include "Graph.h" #include "ML.h" +#include "Utils.h" Napi::FunctionReference node::Context::constructor; @@ -90,10 +92,37 @@ namespace node { Napi::Object Context::Initialize(Napi::Env env, Napi::Object exports) { Napi::HandleScope scope(env); - Napi::Function func = DefineClass(env, "MLContext", {}); + Napi::Function func = DefineClass( + env, "MLContext", {InstanceMethod("compute", &Context::Compute, napi_enumerable)}); constructor = Napi::Persistent(func); constructor.SuppressDestruct(); exports.Set("MLContext", func); return exports; } + + Napi::Value Context::Compute(const Napi::CallbackInfo& info) { + // status compute(NamedInputs inputs, NamedOutputs outputs); + WEBNN_NODE_ASSERT(info.Length() == 3, "The number of arguments is invalid."); + Napi::Object object = info[0].As(); + node::Graph* jsGraph = Napi::ObjectWrap::Unwrap(object); + + std::map inputs; + WEBNN_NODE_ASSERT(GetNamedInputs(info[1], inputs), "The inputs parameter is invalid."); + + std::map outputs; + WEBNN_NODE_ASSERT(GetNamedOutputs(info[2], outputs), "The outputs parameter is invalid."); + + wnn::NamedInputs namedInputs = wnn::CreateNamedInputs(); + for (auto& input : inputs) { + namedInputs.Set(input.first.data(), input.second.AsPtr()); + } + wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs(); + for (auto& output : outputs) { + namedOutputs.Set(output.first.data(), &output.second); + } + mImpl.ComputeSync(jsGraph->GetImpl(), namedInputs, namedOutputs); + + return Napi::Number::New(info.Env(), 0); + } + } // namespace node diff --git a/node/src/Context.h b/node/src/Context.h index 99d77c6fd..01bc60f6c 100644 --- a/node/src/Context.h +++ b/node/src/Context.h @@ -31,6 +31,7 @@ namespace node { wnn::Context GetImpl(); private: + Napi::Value Compute(const Napi::CallbackInfo& info); wnn::Context mImpl; }; diff --git a/node/src/Graph.cpp b/node/src/Graph.cpp index af4f09003..65640d225 100644 --- a/node/src/Graph.cpp +++ b/node/src/Graph.cpp @@ -13,135 +13,22 @@ // limitations under the License. 
#include "Graph.h" - -#include -#include - #include "Utils.h" namespace node { - struct Input { - public: - wnn::ArrayBufferView bufferView; - std::vector dimensions; - - const wnn::Input* AsPtr() { - mInput.resource.arrayBufferView = bufferView; - mInput.resource.gpuBufferView = {}; - if (!dimensions.empty()) { - mInput.dimensions = dimensions.data(); - mInput.dimensionsCount = dimensions.size(); - } - return &mInput; - } - - private: - wnn::Input mInput; - }; - - bool GetNamedInputs(const Napi::Value& jsValue, std::map& namedInputs) { - if (!jsValue.IsObject()) { - return false; - } - Napi::Object jsNamedInputs = jsValue.As(); - Napi::Array names = jsNamedInputs.GetPropertyNames(); - if (names.Length() == 0) { - return false; - } - // typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource; - // dictionary MLInput { - // required MLResource resource; - // required sequence dimensions; - // }; - // typedef record MLNamedInputs; - for (size_t i = 0; i < names.Length(); ++i) { - Input input = {}; - std::string name = names.Get(i).As().Utf8Value(); - // FIXME: validate the type of typed array. - Napi::TypedArray jsTypedArray; - if (jsNamedInputs.Get(name).IsTypedArray()) { - jsTypedArray = jsNamedInputs.Get(name).As(); - } else { - Napi::Object jsInput = jsNamedInputs.Get(name).As(); - if (!jsInput.Has("resource") || !jsInput.Has("dimensions")) { - // Input resource and dimensions are required. - return false; - } - if (!jsInput.Get("resource").IsTypedArray()) { - return false; - } - jsTypedArray = jsInput.Get("resource").As(); - - if (!GetArray(jsInput.Get("dimensions"), input.dimensions)) { - return false; - } - if (SizeOfShape(input.dimensions) != jsTypedArray.ElementSize()) { - return false; - } - } - if (!GetArrayBufferView(jsTypedArray, input.bufferView)) { - return false; - } - namedInputs[name] = input; - } - return true; + Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { } - bool GetNamedOutputs(const Napi::Value& jsValue, - std::map& namedOutputs) { - if (!jsValue.IsObject()) { - return false; - } - Napi::Object jsNamedOutputs = jsValue.As(); - Napi::Array names = jsNamedOutputs.GetPropertyNames(); - if (names.Length() == 0) { - return false; - } - // typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource; - // typedef record MLNamedOutputs; - for (size_t i = 0; i < names.Length(); ++i) { - wnn::ArrayBufferView arrayBuffer = {}; - std::string name = names.Get(i).As().Utf8Value(); - if (!GetArrayBufferView(jsNamedOutputs.Get(name), arrayBuffer)) { - return false; - } - namedOutputs[name] = {arrayBuffer, {}}; - } - return true; + wnn::Graph Graph::GetImpl() { + return mImpl; } Napi::FunctionReference Graph::constructor; - Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap(info) { - } - - Napi::Value Graph::Compute(const Napi::CallbackInfo& info) { - // status compute(NamedInputs inputs, NamedOutputs outputs); - WEBNN_NODE_ASSERT(info.Length() == 2, "The number of arguments is invalid."); - std::map inputs; - WEBNN_NODE_ASSERT(GetNamedInputs(info[0], inputs), "The inputs parameter is invalid."); - - std::map outputs; - WEBNN_NODE_ASSERT(GetNamedOutputs(info[1], outputs), "The outputs parameter is invalid."); - - wnn::NamedInputs namedInputs = wnn::CreateNamedInputs(); - for (auto& input : inputs) { - namedInputs.Set(input.first.data(), input.second.AsPtr()); - } - wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs(); - for (auto& output : outputs) { - namedOutputs.Set(output.first.data(), &output.second); - } - 
mImpl.Compute(namedInputs, namedOutputs); - - return Napi::Number::New(info.Env(), 0); - } - Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) { Napi::HandleScope scope(env); - Napi::Function func = DefineClass( - env, "MLGraph", {InstanceMethod("compute", &Graph::Compute, napi_enumerable)}); + Napi::Function func = DefineClass(env, "MLGraph", {}); constructor = Napi::Persistent(func); constructor.SuppressDestruct(); exports.Set("MLGraph", func); diff --git a/node/src/Graph.h b/node/src/Graph.h index a00d83883..c7f8d0dc6 100644 --- a/node/src/Graph.h +++ b/node/src/Graph.h @@ -32,14 +32,13 @@ namespace node { explicit Graph(const Napi::CallbackInfo& info); ~Graph() = default; + wnn::Graph GetImpl(); + private: friend BuildGraphWorker; friend GraphBuilder; - Napi::Value Compute(const Napi::CallbackInfo& info); - wnn::Graph mImpl; - std::vector mOutputNames; }; } // namespace node diff --git a/node/src/GraphBuilder.cpp b/node/src/GraphBuilder.cpp index d04c657bb..80cbfcbf5 100644 --- a/node/src/GraphBuilder.cpp +++ b/node/src/GraphBuilder.cpp @@ -306,7 +306,6 @@ namespace node { Napi::Object object = node::Graph::constructor.New({}); node::Graph* jsGraph = Napi::ObjectWrap::Unwrap(object); jsGraph->mImpl = graph; - jsGraph->mOutputNames = names; return object; } diff --git a/node/src/Utils.h b/node/src/Utils.h index c9c47b56c..6657c04f4 100644 --- a/node/src/Utils.h +++ b/node/src/Utils.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "Operand.h" @@ -516,6 +517,97 @@ namespace node { return jsOptions.Has(name) && !jsOptions.Get(name).IsUndefined(); } + struct Input { + public: + wnn::ArrayBufferView bufferView; + std::vector dimensions; + + const wnn::Input* AsPtr() { + mInput.resource.arrayBufferView = bufferView; + mInput.resource.gpuBufferView = {}; + if (!dimensions.empty()) { + mInput.dimensions = dimensions.data(); + mInput.dimensionsCount = dimensions.size(); + } + return &mInput; + } + + private: + wnn::Input mInput; + }; + + inline bool GetNamedInputs(const Napi::Value& jsValue, + std::map& namedInputs) { + if (!jsValue.IsObject()) { + return false; + } + Napi::Object jsNamedInputs = jsValue.As(); + Napi::Array names = jsNamedInputs.GetPropertyNames(); + if (names.Length() == 0) { + return false; + } + // typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource; + // dictionary MLInput { + // required MLResource resource; + // required sequence dimensions; + // }; + // typedef record MLNamedInputs; + for (size_t i = 0; i < names.Length(); ++i) { + Input input = {}; + std::string name = names.Get(i).As().Utf8Value(); + // FIXME: validate the type of typed array. + Napi::TypedArray jsTypedArray; + if (jsNamedInputs.Get(name).IsTypedArray()) { + jsTypedArray = jsNamedInputs.Get(name).As(); + } else { + Napi::Object jsInput = jsNamedInputs.Get(name).As(); + if (!jsInput.Has("resource") || !jsInput.Has("dimensions")) { + // Input resource and dimensions are required. 
+ return false; + } + if (!jsInput.Get("resource").IsTypedArray()) { + return false; + } + jsTypedArray = jsInput.Get("resource").As(); + + if (!GetArray(jsInput.Get("dimensions"), input.dimensions)) { + return false; + } + if (SizeOfShape(input.dimensions) != jsTypedArray.ElementSize()) { + return false; + } + } + if (!GetArrayBufferView(jsTypedArray, input.bufferView)) { + return false; + } + namedInputs[name] = input; + } + return true; + } + + inline bool GetNamedOutputs(const Napi::Value& jsValue, + std::map& namedOutputs) { + if (!jsValue.IsObject()) { + return false; + } + Napi::Object jsNamedOutputs = jsValue.As(); + Napi::Array names = jsNamedOutputs.GetPropertyNames(); + if (names.Length() == 0) { + return false; + } + // typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource; + // typedef record MLNamedOutputs; + for (size_t i = 0; i < names.Length(); ++i) { + wnn::ArrayBufferView arrayBuffer = {}; + std::string name = names.Get(i).As().Utf8Value(); + if (!GetArrayBufferView(jsNamedOutputs.Get(name), arrayBuffer)) { + return false; + } + namedOutputs[name] = {arrayBuffer, {}}; + } + return true; + } + } // namespace node #endif // NODE_UTILS_H_ diff --git a/src/webnn/native/Context.cpp b/src/webnn/native/Context.cpp index dc7152502..875e0aeb6 100644 --- a/src/webnn/native/Context.cpp +++ b/src/webnn/native/Context.cpp @@ -14,6 +14,7 @@ #include "webnn/native/Context.h" +#include "webnn/native/Graph.h" #include "webnn/native/ValidationUtils_autogen.h" #include "webnn/native/webnn_platform.h" @@ -65,6 +66,36 @@ namespace webnn::native { } #endif + void ContextBase::Compute(GraphBase* graph, + NamedInputsBase* inputs, + NamedOutputsBase* outputs, + WNNComputeAsyncCallback callback, + void* userdata) { + if (inputs == nullptr || outputs == nullptr) { + callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata); + } + MaybeError maybeError = ComputeImpl(graph, inputs, outputs); + if (maybeError.IsError()) { + std::unique_ptr errorData = maybeError.AcquireError(); + callback(static_cast(ToWNNErrorType(errorData->GetType())), + const_cast(errorData->GetMessage().c_str()), userdata); + } else { + callback(WNNErrorType_NoError, "", userdata); + } + } + + void ContextBase::ComputeSync(GraphBase* graph, + NamedInputsBase* inputs, + NamedOutputsBase* outputs) { + this->ConsumedError(ComputeImpl(graph, inputs, outputs)); + } + + MaybeError ContextBase::ComputeImpl(GraphBase* graph, + NamedInputsBase* inputs, + NamedOutputsBase* outputs) { + return graph->ComputeImpl(inputs, outputs); + } + void ContextBase::InjectError(wnn::ErrorType type, const char* message) { if (ConsumedError(ValidateErrorType(type))) { return; diff --git a/src/webnn/native/Context.h b/src/webnn/native/Context.h index d01cc16c8..5541e9c54 100644 --- a/src/webnn/native/Context.h +++ b/src/webnn/native/Context.h @@ -67,9 +67,20 @@ namespace webnn::native { return mContextOptions; } + // WebNN API + void Compute(GraphBase* graph, + NamedInputsBase* inputs, + NamedOutputsBase* outputs, + WNNComputeAsyncCallback callback, + void* userdata); + void ComputeSync(GraphBase* graph, NamedInputsBase* inputs, NamedOutputsBase* outputs); + private: // Create concrete model. 
virtual GraphBase* CreateGraphImpl() = 0; + MaybeError ComputeImpl(GraphBase* graph, + NamedInputsBase* inputs, + NamedOutputsBase* outputs); void HandleError(std::unique_ptr error); diff --git a/src/webnn/native/Graph.cpp b/src/webnn/native/Graph.cpp index d8359d042..060441a8e 100644 --- a/src/webnn/native/Graph.cpp +++ b/src/webnn/native/Graph.cpp @@ -32,9 +32,6 @@ namespace webnn::native { MaybeError CompileImpl() override { UNREACHABLE(); } - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override { - return DAWN_INTERNAL_ERROR("fail to build graph!"); - } }; } // namespace @@ -137,25 +134,8 @@ namespace webnn::native { return CompileImpl(); } - void GraphBase::Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs) { - GetContext()->ConsumedError(ComputeImpl(inputs, outputs)); - } - - void GraphBase::ComputeAsync(NamedInputsBase* inputs, - NamedOutputsBase* outputs, - WNNComputeAsyncCallback callback, - void* userdata) { - if (inputs == nullptr || outputs == nullptr) { - callback(WNNErrorType_Validation, "named inputs or outputs is empty.", userdata); - } - MaybeError maybeError = ComputeImpl(inputs, outputs); - if (maybeError.IsError()) { - std::unique_ptr errorData = maybeError.AcquireError(); - callback(static_cast(ToWNNErrorType(errorData->GetType())), - const_cast(errorData->GetMessage().c_str()), userdata); - } else { - callback(WNNErrorType_NoError, "", userdata); - } + MaybeError GraphBase::ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) { + return DAWN_INTERNAL_ERROR("fail to build graph!"); } GraphBase::GraphBase(ContextBase* context, ObjectBase::ErrorTag tag) diff --git a/src/webnn/native/Graph.h b/src/webnn/native/Graph.h index 7709a3855..f79d3527c 100644 --- a/src/webnn/native/Graph.h +++ b/src/webnn/native/Graph.h @@ -80,20 +80,13 @@ namespace webnn::native { virtual MaybeError AddInstanceNorm(const op::InstanceNorm* instanceNorm); virtual MaybeError Finish(); virtual MaybeError Compile(); - - // Webnn API - void Compute(NamedInputsBase* inputs, NamedOutputsBase* outputs); - void ComputeAsync(NamedInputsBase* inputs, - NamedOutputsBase* outputs, - WNNComputeAsyncCallback callback, - void* userdata); + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs); GraphBase(ContextBase* context, ObjectBase::ErrorTag tag); static GraphBase* MakeError(ContextBase* context); private: virtual MaybeError CompileImpl() = 0; - virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) = 0; }; } // namespace webnn::native diff --git a/src/webnn/native/dmlx/GraphDMLX.h b/src/webnn/native/dmlx/GraphDMLX.h index c5dad3bb2..d49cad1b1 100644 --- a/src/webnn/native/dmlx/GraphDMLX.h +++ b/src/webnn/native/dmlx/GraphDMLX.h @@ -98,9 +98,10 @@ namespace webnn::native::dmlx { virtual MaybeError AddInstanceNorm(const op::InstanceNorm* instanceNorm) override; virtual MaybeError Finish() override; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; ::dml::Expression BindingConstant(DML_TENSOR_DATA_TYPE dmlTensorType, ::dml::TensorDimensions dmlTensorDims, diff --git a/src/webnn/native/mlas/GraphMLAS.h b/src/webnn/native/mlas/GraphMLAS.h index a8caa8ada..eac4a61fc 100644 --- a/src/webnn/native/mlas/GraphMLAS.h +++ b/src/webnn/native/mlas/GraphMLAS.h @@ -53,9 +53,10 @@ namespace webnn::native::mlas { virtual MaybeError 
AddUnary(const op::Unary* unary) override; virtual MaybeError Finish() override; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; std::unordered_map> mInputs; std::unordered_map> mOutputs; diff --git a/src/webnn/native/nnapi/ContextNnapi.h b/src/webnn/native/nnapi/ContextNnapi.h index 7e48f09cb..a852a9929 100644 --- a/src/webnn/native/nnapi/ContextNnapi.h +++ b/src/webnn/native/nnapi/ContextNnapi.h @@ -16,6 +16,7 @@ #define WEBNN_NATIVE_NNAPI_CONTEXT_NN_H_ #include "webnn/native/Context.h" + #include "webnn/native/Graph.h" namespace webnn::native::nnapi { diff --git a/src/webnn/native/nnapi/GraphNnapi.h b/src/webnn/native/nnapi/GraphNnapi.h index c053281b3..f4285b498 100755 --- a/src/webnn/native/nnapi/GraphNnapi.h +++ b/src/webnn/native/nnapi/GraphNnapi.h @@ -85,6 +85,8 @@ namespace webnn::native::nnapi { MaybeError AddSoftMax(const std::shared_ptr& input0Node, std::shared_ptr outputNode); + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: uint32_t getOperandIdx() { return mOperandCount++; @@ -151,7 +153,7 @@ namespace webnn::native::nnapi { node->dimensions.push_back(static_cast(desc->dimensions[i])); } } - + MaybeError error; if (buffer) { error = mNnapiMgr->CreateOperandAndSetMemory(name, node, buffer); @@ -212,7 +214,6 @@ namespace webnn::native::nnapi { } MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; // Map the input name to NNAPI internal input number. std::map> mInputNameMap; // Map the output name to NNAPI internal original output name that will be updated after diff --git a/src/webnn/native/null/ContextNull.h b/src/webnn/native/null/ContextNull.h index 641e9f635..b83224e83 100644 --- a/src/webnn/native/null/ContextNull.h +++ b/src/webnn/native/null/ContextNull.h @@ -74,9 +74,10 @@ namespace webnn::native::null { virtual MaybeError AddInstanceNorm(const op::InstanceNorm* instanceNorm) override; virtual MaybeError Finish() override; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; }; } // namespace webnn::native::null diff --git a/src/webnn/native/onednn/GraphDNNL.h b/src/webnn/native/onednn/GraphDNNL.h index 45d0053ba..f768bcae6 100644 --- a/src/webnn/native/onednn/GraphDNNL.h +++ b/src/webnn/native/onednn/GraphDNNL.h @@ -50,6 +50,8 @@ namespace webnn::native::onednn { virtual MaybeError AddClamp(const op::Clamp* clamp) override; virtual MaybeError Finish() override; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: dnnl_status_t AddConv2dImpl(const op::Conv2d* conv2d, const op::Binary* add = nullptr, @@ -62,7 +64,6 @@ namespace webnn::native::onednn { dnnl_status_t BuildPrimitives(); MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; dnnl_engine_t GetEngine(); dnnl_status_t GetMemoryDesc(dnnl_memory_t memory, const dnnl_memory_desc_t** desc); dnnl_status_t ReorderIfNeeded(const dnnl_memory_desc_t* srcDesc, diff --git a/src/webnn/native/openvino/GraphIE.h b/src/webnn/native/openvino/GraphIE.h index d9b6a221b..62c732627 100644 --- a/src/webnn/native/openvino/GraphIE.h +++ 
b/src/webnn/native/openvino/GraphIE.h @@ -78,9 +78,10 @@ namespace webnn::native::ie { virtual MaybeError AddInstanceNorm(const op::InstanceNorm* InstanceNorm) override; virtual MaybeError Finish() override; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; // Map the input name to IE internal input number. std::map mInputIdMap; diff --git a/src/webnn/native/xnnpack/GraphXNN.h b/src/webnn/native/xnnpack/GraphXNN.h index d151bd2d8..7397521c9 100644 --- a/src/webnn/native/xnnpack/GraphXNN.h +++ b/src/webnn/native/xnnpack/GraphXNN.h @@ -61,9 +61,10 @@ namespace webnn::native::xnnpack { virtual MaybeError AddUnary(const op::Unary* unary) override; virtual MaybeError Finish() override; + virtual MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; + private: MaybeError CompileImpl() override; - MaybeError ComputeImpl(NamedInputsBase* inputs, NamedOutputsBase* outputs) override; pthreadpool_t GetThreadpool(); diff --git a/src/webnn/tests/end2end/AddTests.cpp b/src/webnn/tests/end2end/AddTests.cpp index 5bedd8848..833d06e68 100644 --- a/src/webnn/tests/end2end/AddTests.cpp +++ b/src/webnn/tests/end2end/AddTests.cpp @@ -45,7 +45,7 @@ TEST_F(AddTests, AddConstantAndInput) { -0.55656946, -0.735903, 0.22050636, -0.5008282, -1.3132697, 1.6642882, -0.48397836, 0.20099205, -0.28786168, 1.3315053, -0.41619393}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue( {-0.48879138, -2.0812354, 0.6382897, 0.07346585, -0.93846387, 2.9300475, 0.84765005, 1.2585825, -1.7465117, 2.0591164, 2.3115096, 1.2746171, -0.16182771, 0.29538065, @@ -87,7 +87,7 @@ TEST_F(AddTests, AddTwoInputs) { 0.3222213, 1.0590587, -1.7948701, -1.7195907, -0.9120889, -0.9391962, -0.2566791, -0.5464537, 1.4351872, 0.5705938, -0.30327085}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {-0.48879138, -2.0812354, 0.6382897, 0.07346585, -0.93846387, 2.9300475, 0.84765005, 1.2585825, -1.7465117, 2.0591164, 2.3115096, 1.2746171, -0.16182771, 0.29538065, @@ -122,7 +122,7 @@ TEST_F(AddTests, AddBroadcast) { 0.6338172, 1.630534, -1.3819867, -1.0427561, 1.058136, }; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {0.5484205, 1.7485408, -2.6178582, -0.7418642, 0.32369673, 2.123247, 1.7987677, -3.585476, 0.0313431, 0.7035562, 1.2490666, 2.0926871, -0.7827864, -1.8532355, diff --git a/src/webnn/tests/end2end/BatchNormTests.cpp b/src/webnn/tests/end2end/BatchNormTests.cpp index 95728c327..d2527f6ee 100644 --- a/src/webnn/tests/end2end/BatchNormTests.cpp +++ b/src/webnn/tests/end2end/BatchNormTests.cpp @@ -53,7 +53,7 @@ class BatchNormTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"output", output}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(input.shape)); - utils::Compute(graph, {{"input", input.value}}, {{"output", result}}); + utils::Compute(GetContext(), 
graph, {{"input", input.value}}, {{"output", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } diff --git a/src/webnn/tests/end2end/ClampTests.cpp b/src/webnn/tests/end2end/ClampTests.cpp index 6df0b8128..3b96d47f7 100644 --- a/src/webnn/tests/end2end/ClampTests.cpp +++ b/src/webnn/tests/end2end/ClampTests.cpp @@ -26,7 +26,7 @@ class ClampTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(inputShape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/ConcatTests.cpp b/src/webnn/tests/end2end/ConcatTests.cpp index 848285e1a..73f506138 100644 --- a/src/webnn/tests/end2end/ConcatTests.cpp +++ b/src/webnn/tests/end2end/ConcatTests.cpp @@ -47,7 +47,7 @@ class ConcatTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{outputName, output}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, namedInputs, {{outputName, result}}); + utils::Compute(GetContext(), graph, namedInputs, {{outputName, result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/Conv2dTests.cpp b/src/webnn/tests/end2end/Conv2dTests.cpp index 296637ab0..628b087a9 100644 --- a/src/webnn/tests/end2end/Conv2dTests.cpp +++ b/src/webnn/tests/end2end/Conv2dTests.cpp @@ -70,7 +70,7 @@ class Conv2dTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"output", y}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expected.shape)); - utils::Compute(graph, {{"input", input.value}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", input.value}}, {{"output", result}}); EXPECT_TRUE(utils::CheckValue(result, expected.value)); } diff --git a/src/webnn/tests/end2end/ConvTranspose2dTests.cpp b/src/webnn/tests/end2end/ConvTranspose2dTests.cpp index 022e1b70a..cf6d733cb 100644 --- a/src/webnn/tests/end2end/ConvTranspose2dTests.cpp +++ b/src/webnn/tests/end2end/ConvTranspose2dTests.cpp @@ -70,7 +70,7 @@ class ConvTranspose2dTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"output", y}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expected.shape)); - utils::Compute(graph, {{"input", input.value}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", input.value}}, {{"output", result}}); EXPECT_TRUE(utils::CheckValue(result, expected.value)); } diff --git a/src/webnn/tests/end2end/DivTests.cpp b/src/webnn/tests/end2end/DivTests.cpp index 83c22b804..36a77eaa8 100644 --- a/src/webnn/tests/end2end/DivTests.cpp +++ b/src/webnn/tests/end2end/DivTests.cpp @@ -44,7 +44,7 @@ TEST_F(DivTests, Div) { 2.4229836, 1.3960866, 0.40859735, 2.1244192, 1.7553957, 1.8674074, 0.34353632, -1.8345544, 3.116791, -0.61087835, 0.9642319}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {5.2773511e-01, 1.4511688e+00, -2.0733004e+00, 2.5239782e-02, 1.2193620e+00, 1.0799783e+00, -2.5929454e-01, -9.8252831e+00, -8.9286619e-01, -7.1994968e-02, @@ -80,7 +80,7 @@ TEST_F(DivTests, DivBroadcast) { 0.55578697, 0.01034931, 0.72003376, -1.8242567}; const std::vector 
dataB = {1.3041736, 1.5910654, 1.9217191, 1.8052639, 1.7239413}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {1.825482, 0.20777069, 0.49395692, -0.832231, -1.0311644, -0.40846005, 0.68554676, -0.18017694, -0.44017738, 0.11483412, 0.82959443, -0.9081589, -0.62992716, -0.436872, diff --git a/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp b/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp index 00241ee13..d264b091a 100644 --- a/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp +++ b/src/webnn/tests/end2end/ElementWiseUnaryTests.cpp @@ -69,7 +69,7 @@ class ElementWiseUnaryTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(shape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/GemmTests.cpp b/src/webnn/tests/end2end/GemmTests.cpp index 6d2ebfebb..c192f0c2e 100644 --- a/src/webnn/tests/end2end/GemmTests.cpp +++ b/src/webnn/tests/end2end/GemmTests.cpp @@ -59,9 +59,9 @@ class GemmTests : public WebnnTest { ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expectedShape)); if (constantWeight) { - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); } else { - utils::Compute(graph, {{"a", aData}, {"b", bData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}, {"b", bData}}, {{"c", result}}); } EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } diff --git a/src/webnn/tests/end2end/GruTests.cpp b/src/webnn/tests/end2end/GruTests.cpp index 39f8aa024..4c8f12398 100644 --- a/src/webnn/tests/end2end/GruTests.cpp +++ b/src/webnn/tests/end2end/GruTests.cpp @@ -16,7 +16,8 @@ class GruTests : public WebnnTest { void SetUp() override { - builder = wnn::CreateGraphBuilder(GetContext()); + context = GetContext(); + builder = wnn::CreateGraphBuilder(context); } protected: @@ -55,12 +56,13 @@ class GruTests : public WebnnTest { results.push_back(std::vector(utils::SizeOfShape(expected[i].shape))); namedOutputs.push_back({"gru" + std::to_string(i), results[i]}); } - utils::Compute(graph, {{"a", input.value}}, namedOutputs); + utils::Compute(context, graph, {{"a", input.value}}, namedOutputs); for (size_t i = 0; i < outputSize; ++i) { EXPECT_TRUE(utils::CheckValue(namedOutputs[i].resource, expected[i].value)); } } + wnn::Context context; wnn::GraphBuilder builder; }; diff --git a/src/webnn/tests/end2end/HardSwishTests.cpp b/src/webnn/tests/end2end/HardSwishTests.cpp index 2f54a23d8..5597a1abd 100644 --- a/src/webnn/tests/end2end/HardSwishTests.cpp +++ b/src/webnn/tests/end2end/HardSwishTests.cpp @@ -27,7 +27,7 @@ class HardSwishTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"y", y}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(inputShape)); - utils::Compute(graph, {{"x", inputBuffer}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", inputBuffer}}, {{"y", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedBuffer)); } }; diff --git a/src/webnn/tests/end2end/InstanceNormTests.cpp b/src/webnn/tests/end2end/InstanceNormTests.cpp index c6d5dc257..d374870d1 100644 --- 
a/src/webnn/tests/end2end/InstanceNormTests.cpp +++ b/src/webnn/tests/end2end/InstanceNormTests.cpp @@ -29,7 +29,7 @@ class InstanceNormTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(inputShape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } wnn::GraphBuilder builder; diff --git a/src/webnn/tests/end2end/LeakyReluTests.cpp b/src/webnn/tests/end2end/LeakyReluTests.cpp index 303b51ac1..7efe5899d 100644 --- a/src/webnn/tests/end2end/LeakyReluTests.cpp +++ b/src/webnn/tests/end2end/LeakyReluTests.cpp @@ -28,7 +28,7 @@ class LeakyReluTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(inputShape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/MatMulTests.cpp b/src/webnn/tests/end2end/MatMulTests.cpp index 10f94f618..482d3454a 100644 --- a/src/webnn/tests/end2end/MatMulTests.cpp +++ b/src/webnn/tests/end2end/MatMulTests.cpp @@ -27,7 +27,7 @@ TEST_F(MatMulTests, MatMul1d) { ASSERT_TRUE(graph); const std::vector aData = {0.9025404, 0.89538723, 0.16789329, 0.7440875}; std::vector result(utils::SizeOfShape({1})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = {1.1453342}; EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -46,7 +46,7 @@ TEST_F(MatMulTests, MatMul1dx2d) { ASSERT_TRUE(graph); const std::vector aData = {0.1309212, 0.9090703, 0.62183434, 0.9195683}; std::vector result(utils::SizeOfShape({1, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = {0.6616409, -0.80990994, 0.8797145}; EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -65,7 +65,7 @@ TEST_F(MatMulTests, MatMul2dx1d) { 0.4122941, 0.6787481, 0.15072346, 0.2820577, 0.67296237, 0.3856028, }; std::vector result(utils::SizeOfShape({3, 1})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = {0.8839391, 0.9928265, 0.5955407}; EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -85,7 +85,7 @@ TEST_F(MatMulTests, MatMul2d) { 0.40872088, 0.18995902, 0.69355214, -0.37210146, 0.18104352, 3.270753, -0.803097, -0.7268995}; std::vector result(utils::SizeOfShape({3, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = {1.5347629, -0.3981255, 2.6510081, -0.14295794, 0.6647107, -0.70315295, 1.3096018, 3.9256358, 3.873897}; @@ -112,7 +112,7 @@ TEST_F(MatMulTests, MatMul3d) { 0.7644939, -0.8567167, 0.3942727, -0.772506, -0.06412488, -0.9848109, }; std::vector result(utils::SizeOfShape({2, 3, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = { -0.10833447, -0.13393278, 0.8061598, -1.3357227, 2.449343, -2.801163, 0.31218773, 
-2.7866507, 1.7064441, 1.5293882, -0.02957799, 0.5971595, @@ -138,7 +138,7 @@ TEST_F(MatMulTests, MatMul3dx2d) { 0.09809388, 0.5084747, 0.76594603, 0.8050488, -0.03979152, 2.4019558, -0.54937273, -0.1696853, -1.223669, 1.0791223, -0.61921734, 2.1074235}; std::vector result(utils::SizeOfShape({2, 3, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = { -0.8885305, 1.0170201, 1.8490261, 1.8789318, -2.3183105, -2.9326258, 0.775168, 0.2069526, 2.7969716, 0.9437693, -0.5050435, 0.15727985, @@ -161,7 +161,7 @@ TEST_F(MatMulTests, MatMul3dx2dGet3d) { 0.9875369, 1.3744136, 0.61079186, 0.74018836, -0.56111795, -0.16432828, 1.3176169, -0.249416}; std::vector result(utils::SizeOfShape({1, 3, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = { 0.6533069, -1.4796758, -2.6561086, -1.607665, -0.04264185, 2.1811159, -0.13444155, 2.297084, 2.762841, @@ -190,7 +190,7 @@ TEST_F(MatMulTests, MatMul4d) { -0.07425678, -1.2638812, -1.1002079, -1.5324054, -1.1643038, -0.05644368, }; std::vector result(utils::SizeOfShape({1, 2, 3, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = { 1.2216457, -1.0545375, 1.2706597, -2.2521434, -0.4334606, 2.1588962, -0.1886742, 0.66638416, -1.1427099, 0.47668338, 1.464142, -0.84385866, @@ -218,7 +218,7 @@ TEST_F(MatMulTests, MatMul4dx2d) { -0.9099179, -0.6195976, 0.38710263, 0.5102191, -0.03610202, 1.2280966, }; std::vector result(utils::SizeOfShape({1, 2, 3, 3})); - utils::Compute(graph, {{"a", aData}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", aData}}, {{"c", result}}); const std::vector expectedValue = { 3.2632291, 0.19901966, 0.5334567, -1.3227482, -3.223286, -2.628851, 1.118986, 2.7767603, -0.25850934, -2.185273, 0.3517071, 2.061255, diff --git a/src/webnn/tests/end2end/MaxTests.cpp b/src/webnn/tests/end2end/MaxTests.cpp index 9bdce0581..d3e3cd31e 100644 --- a/src/webnn/tests/end2end/MaxTests.cpp +++ b/src/webnn/tests/end2end/MaxTests.cpp @@ -45,7 +45,7 @@ TEST_F(MaxTests, MaxConstantAndInput) { -0.92654175, -0.507083, -1.8776977, 0.57921827, 1.460351, 1.4930215, -0.757663, 1.0773797, -1.1858964, -0.5337765, 0.27636543}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue( {0.54270846, 0.3356357, 0.17466596, 1.6710619, 1.3720452, 1.4024457, -0.5183214, -0.26632488, 0.16786452, -0.2980101, 0.12268824, 1.8612522, 0.2960607, 0.85281086, @@ -87,7 +87,7 @@ TEST_F(MaxTests, MaxTwoInputs) { 0.5907742, -1.0454807, -0.8065648, 2.0162134, -0.30215183, 0.67375183, 1.6682644, -2.916385, 0.43166366, -0.7290503, 0.11509943}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {0.54270846, 0.3356357, 0.17466596, 1.6710619, 1.3720452, 1.4024457, -0.5183214, -0.26632488, 0.16786452, -0.2980101, 0.12268824, 1.8612522, 0.2960607, 0.85281086, @@ -120,7 +120,7 @@ TEST_F(MaxTests, MaxBroadcast) { -0.13389224, -0.5757679, 0.38655168, -0.39935285}; std::vector dataB = {0.67538136, 
0.3535401, 1.0303422, -0.50294054, -0.25600532}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {0.67538136, 0.3535401, 1.0303422, -0.24858657, 0.36215156, 0.67538136, 1.540389, 1.9143543, 0.4806893, 0.0123093, 1.2142435, 0.3535401, 1.0303422, 1.1247561, diff --git a/src/webnn/tests/end2end/MinTests.cpp b/src/webnn/tests/end2end/MinTests.cpp index 77f29a99a..fc600e7a7 100644 --- a/src/webnn/tests/end2end/MinTests.cpp +++ b/src/webnn/tests/end2end/MinTests.cpp @@ -45,7 +45,7 @@ TEST_F(MinTests, MinConstantAndInput) { -0.01093488, -0.3274254, 0.73195547, -0.5514492, -0.7521337, -1.0613606, 0.6751333, 0.9138903, 1.7775172, 0.5034791, 0.00691956}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue( {-0.3013072, -0.09710764, 0.11072686, 0.57673335, -0.9459303, -0.4660466, -0.51731133, -1.1046865, -0.7237214, -2.4551184, 0.05005725, -0.505013, -0.93030375, -0.46502006, @@ -87,7 +87,7 @@ TEST_F(MinTests, MinTwoInputs) { -0.8959271, 1.2020742, -0.24440259, 0.18198308, -1.3384086, -0.5169678, -0.6608337, 0.30539933, -1.529869, -0.70533603, -2.1911235}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {-0.3013072, -0.09710764, 0.11072686, 0.57673335, -0.9459303, -0.4660466, -0.51731133, -1.1046865, -0.7237214, -2.4551184, 0.05005725, -0.505013, -0.93030375, -0.46502006, @@ -120,7 +120,7 @@ TEST_F(MinTests, MinBroadcast) { 0.04377401, -0.26201916, -1.6016098, -0.74272215}; std::vector dataB = {0.6450575, -1.302236, 0.27485028, 1.8353013, -0.83993983}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {0.09259097, -1.302236, 0.27485028, 0.83395857, -0.83993983, -0.10002025, -1.302236, 0.27485028, 0.7070375, -0.83993983, -1.1588863, -1.302236, -0.27449006, 1.3718864, diff --git a/src/webnn/tests/end2end/MulTests.cpp b/src/webnn/tests/end2end/MulTests.cpp index cfd1a0a21..df3838584 100644 --- a/src/webnn/tests/end2end/MulTests.cpp +++ b/src/webnn/tests/end2end/MulTests.cpp @@ -48,7 +48,7 @@ TEST_F(MulTests, MulInputAndConstant) { 5.3133094e-01, 2.3897937e-01, -1.3832775e+00, 6.3414145e-01, 1.0691971e+00, 5.7040757e-01, 3.0711100e-01, 8.8405716e-01, -2.1583509e+00, 4.3243581e-01}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); std::vector expectedData = { 1.1491189e+00, 9.4631165e-03, 1.6490275e+00, -2.4890469e-02, 8.1811851e-01, 1.6337387e-01, -7.8853898e-02, -1.2602202e+00, -5.3575772e-01, -4.1527072e-01, @@ -96,7 +96,7 @@ TEST_F(MulTests, MulTwoInputs) { -1.3840157, 1.9665064, 0.35833818, -0.87076694, -0.76727265, 0.6157508, -0.5558823, 0.18417479, -0.93904793, -0.00859687, 0.5034271}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), 
graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); std::vector expectedData = { 1.1491189e+00, 9.4631165e-03, 1.6490275e+00, -2.4890469e-02, 8.1811851e-01, 1.6337387e-01, -7.8853898e-02, -1.2602202e+00, -5.3575772e-01, -4.1527072e-01, @@ -135,7 +135,7 @@ TEST_F(MulTests, MulBroadcast) { 1.167636, 0.03020451, 0.91373825, 1.0675793, }; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); std::vector expectedData = { -0.05412592, 0.192414, 1.707958, -0.31375682, -0.7771366, 0.9440262, 0.2743106, 3.045193, -1.1200235, -0.37519363, 0.3899556, 0.7535562, -0.82808685, 0.8451324, diff --git a/src/webnn/tests/end2end/PadTests.cpp b/src/webnn/tests/end2end/PadTests.cpp index fab73be21..d9d689cee 100644 --- a/src/webnn/tests/end2end/PadTests.cpp +++ b/src/webnn/tests/end2end/PadTests.cpp @@ -34,7 +34,7 @@ class PadTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"y", y}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, {{"x", inputData}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", inputData}}, {{"y", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/Pool2dTests.cpp b/src/webnn/tests/end2end/Pool2dTests.cpp index 39d8f6dc9..c78fd97f7 100644 --- a/src/webnn/tests/end2end/Pool2dTests.cpp +++ b/src/webnn/tests/end2end/Pool2dTests.cpp @@ -26,7 +26,7 @@ TEST_F(Pool2dTests, MaxPool2dDefault) { ASSERT_TRUE(graph); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; std::vector result(utils::SizeOfShape({1, 1, 2, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({11, 12, 15, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -42,7 +42,7 @@ TEST_F(Pool2dTests, MaxPool2dNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; std::vector result(utils::SizeOfShape({1, 2, 2, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({11, 12, 15, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -58,7 +58,7 @@ TEST_F(Pool2dTests, MaxPool2dDilationsDefault) { ASSERT_TRUE(graph); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; std::vector result(utils::SizeOfShape({1, 1, 2, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({11, 12, 15, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -75,7 +75,7 @@ TEST_F(Pool2dTests, MaxPool2dDilationsNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; std::vector result(utils::SizeOfShape({1, 2, 2, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({11, 12, 15, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -92,7 +92,7 @@ TEST_F(Pool2dTests, MaxPool2dPadsDefault) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector 
result(utils::SizeOfShape({1, 1, 5, 5})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -111,7 +111,7 @@ TEST_F(Pool2dTests, MaxPool2dPadsNhwc) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 5, 5, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -129,7 +129,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperDefault) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 1, 5, 5})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -152,7 +152,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {9, 11, 13, 14, 23, 25, 27, 28, 37, 39, 41, 42, 44, 46, 48, 49}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -176,7 +176,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes3x3Nhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 3, 3, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({17, 19, 21, 31, 33, 35, 45, 47, 49}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -199,7 +199,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitOutputSizes4x4Nhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {17, 19, 21, 21, 31, 33, 35, 35, 45, 47, 49, 49, 45, 47, 49, 49}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -223,7 +223,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadExplicitRoundingTypeFloorNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 3, 3, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({17, 19, 21, 31, 33, 35, 45, 47, 49}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -246,7 +246,7 @@ TEST_F(Pool2dTests, 
MaxPool2dAutoPadExplicitRoundingTypeCeilNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {17, 19, 21, 21, 31, 33, 35, 35, 45, 47, 49, 49, 45, 47, 49, 49}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -268,7 +268,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameLowerNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {9, 11, 13, 14, 23, 25, 27, 28, 37, 39, 41, 42, 44, 46, 48, 49}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -287,7 +287,7 @@ TEST_F(Pool2dTests, MaxPool2dAutoPadSameUpperNhwc) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 5, 5, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -305,7 +305,7 @@ TEST_F(Pool2dTests, MaxPool2dStridesDefault) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 1, 2, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({7, 9, 17, 19}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -323,7 +323,7 @@ TEST_F(Pool2dTests, MaxPool2dStridesNhwc) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 2, 2, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({7, 9, 17, 19}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -338,7 +338,7 @@ TEST_F(Pool2dTests, AveragePool2dDefault) { ASSERT_TRUE(graph); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; std::vector result(utils::SizeOfShape({1, 1, 2, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({6, 7, 10, 11}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -354,7 +354,7 @@ TEST_F(Pool2dTests, AveragePool2dNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; std::vector result(utils::SizeOfShape({1, 2, 2, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({6, 7, 10, 11}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -371,7 +371,7 @@ TEST_F(Pool2dTests, AveragePool2dPadsDefault) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 1, 5, 5})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19}); @@ -391,7 +391,7 @@ TEST_F(Pool2dTests, AveragePool2dPadsNhwc) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 5, 5, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19}); @@ -410,7 +410,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperDefault) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 1, 5, 5})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19}); @@ -430,7 +430,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameUpperNhwc) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 5, 5, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5, 19}); @@ -454,7 +454,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {5, 6, 8, 9.5, 12, 13, 15, 16.5, 26, 27, 29, 30.5, 36.5, 37.5, 39.5, 41}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -478,7 +478,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes3x3Nhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 3, 3, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({9, 10.5, 12.5, 19.5, 21, 23, 33.5, 35, 37}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -501,7 +501,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitOutputSizes4x4Nhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {9, 10.5, 12.5, 13.5, 19.5, 21, 23, 24, 33.5, 35, 37, 38, 40.5, 42, 44, 45}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -525,7 +525,7 @@ 
TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeFloorNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 3, 3, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({9, 10.5, 12.5, 19.5, 21, 23, 33.5, 35, 37}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -548,7 +548,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadExplicitRoundingTypeCeilNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {9, 10.5, 12.5, 13.5, 19.5, 21, 23, 24, 33.5, 35, 37, 38, 40.5, 42, 44, 45}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -570,7 +570,7 @@ TEST_F(Pool2dTests, AveragePool2dAutoPadSameLowerNhwc) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 4, 4, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {5, 6, 8, 9.5, 12, 13, 15, 16.5, 26, 27, 29, 30.5, 36.5, 37.5, 39.5, 41}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); @@ -588,7 +588,7 @@ TEST_F(Pool2dTests, AveragePool2dStridesDefault) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 1, 2, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({4, 6, 14, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -606,7 +606,7 @@ TEST_F(Pool2dTests, AveragePool2dStridesNhwc) { const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; std::vector result(utils::SizeOfShape({1, 2, 2, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({4, 6, 14, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -630,7 +630,7 @@ TEST_F(Pool2dTests, GlobalAveragePool2dDefault) { 0.20211531, 0.8832168, -0.19886094, -0.61088, 0.682026, -0.5253442, 1.5022339, 1.0256356, 1.0642492, -0.4169051, -0.8740329, 1.1494869}; std::vector result(utils::SizeOfShape({1, 3, 1, 1})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({0.07170041, 0.05194739, 0.07117923}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -656,7 +656,7 @@ TEST_F(Pool2dTests, GlobalAveragePool2dNhwc) { 1.1549494, 0.19907428, 1.0642492, 0.24823844, 0.20298219, -0.4169051, 0.75670505, -0.8399954, -0.8740329, -1.7108902, 1.3583295, 1.1494869}; std::vector result(utils::SizeOfShape({1, 1, 1, 3})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({0.07170041, 0.05194739, 0.07117923}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -671,7 +671,7 @@ 
TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesDefault) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 3})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 1, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -687,7 +687,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStrides) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -703,7 +703,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dStridesNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -720,7 +720,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsDefault) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -742,7 +742,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes3x3) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 1, 3, 3})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {10.692676544189453, 12.006942749023438, 13.790093421936035, 21.027759552001953, 22.438806533813477, 24.320772171020508, 34.41172409057617, 35.881752014160156, @@ -767,7 +767,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsOutputSizes4x4) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 1, 4, 4})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {10.692676544189453, 12.006942749023438, 13.790093421936035, 14.668560981750488, 21.027759552001953, 22.438806533813477, 24.320772171020508, 25.248762130737305, @@ -793,7 +793,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeFloor) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector result(utils::SizeOfShape({1, 1, 3, 3})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {10.692676544189453, 12.006942749023438, 13.790093421936035, 21.027759552001953, 22.438806533813477, 24.320772171020508, 34.41172409057617, 35.881752014160156, @@ -818,7 +818,7 @@ TEST_F(Pool2dTests, DISABLED_l2Pool2dPadsRoundingTypeCeil) { 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49}; std::vector 
result(utils::SizeOfShape({1, 1, 4, 4})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue( {10.692676544189453, 12.006942749023438, 13.790093421936035, 14.668560981750488, 21.027759552001953, 22.438806533813477, 24.320772171020508, 25.248762130737305, @@ -840,7 +840,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dPadsNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -857,7 +857,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperDefault) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -875,7 +875,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameUpperNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -892,7 +892,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerDefault) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -910,7 +910,7 @@ TEST_F(Pool2dTests, DISABLED_L2Pool2dSameLowerNhwc) { ASSERT_TRUE(graph); const std::vector dataX = {-1, 2, 0, 3, -2, 0, 0, -4}; std::vector result(utils::SizeOfShape({1, 1, 1, 2})); - utils::Compute(graph, {{"x", dataX}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", dataX}}, {{"y", result}}); const std::vector expectedValue({1.5, 2.5}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } diff --git a/src/webnn/tests/end2end/PowTests.cpp b/src/webnn/tests/end2end/PowTests.cpp index 43a11ad43..09f7a0b11 100644 --- a/src/webnn/tests/end2end/PowTests.cpp +++ b/src/webnn/tests/end2end/PowTests.cpp @@ -27,7 +27,7 @@ TEST_F(PowTests, Sqrt1d) { ASSERT_TRUE(graph); const std::vector dataA = {1, 4, 9}; std::vector result(utils::SizeOfShape({3})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue({1, 2, 3}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -52,7 +52,7 @@ TEST_F(PowTests, Sqrt3d) { 0.6550839, 0.7919175, 0.21990986, 0.2881369, 0.5660939, 0.54675615, 0.70638055, 0.82219034, 0.6266006, 0.89149487, 0.36557788}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue( {0.5782331, 0.7559077, 0.1920685, 
0.88435894, 0.8785719, 0.4208243, 1.0277354, 1.5064393, 1.0163065, 1.2666107, 1.4384935, 1.3356625, 1.2201996, 0.75858086, @@ -77,7 +77,7 @@ TEST_F(PowTests, Pow1d) { ASSERT_TRUE(graph); const std::vector dataA = {1, 2, 3}; std::vector result(utils::SizeOfShape({3})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue({1, 4, 9}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -93,7 +93,7 @@ TEST_F(PowTests, PowBroadcastScalar) { ASSERT_TRUE(graph); const std::vector dataA = {1, 2, 3, 4, 5, 6}; std::vector result(utils::SizeOfShape({2, 3})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue({1, 4, 9, 16, 25, 36}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } @@ -109,7 +109,7 @@ TEST_F(PowTests, PowBroadcast1d) { ASSERT_TRUE(graph); const std::vector dataA = {1, 2, 3, 4, 5, 6}; std::vector result(utils::SizeOfShape({2, 3})); - utils::Compute(graph, {{"a", dataA}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}}, {{"c", result}}); const std::vector expectedValue({1, 4, 27, 4, 25, 216}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } diff --git a/src/webnn/tests/end2end/ReduceTests.cpp b/src/webnn/tests/end2end/ReduceTests.cpp index 65e5c8e87..f0fb3a9a7 100644 --- a/src/webnn/tests/end2end/ReduceTests.cpp +++ b/src/webnn/tests/end2end/ReduceTests.cpp @@ -72,7 +72,7 @@ class ReduceTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/ReluTests.cpp b/src/webnn/tests/end2end/ReluTests.cpp index 4b8afc37e..163bf8113 100644 --- a/src/webnn/tests/end2end/ReluTests.cpp +++ b/src/webnn/tests/end2end/ReluTests.cpp @@ -33,7 +33,7 @@ TEST_F(ReluTests, Relu) { 1.1774449, 0.8999488, -1.1143959, 1.0122099, -0.48604885, -0.06009902, -0.1766853, 1.4515465, -0.7182982, 2.0361354, 0.7899623}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); const std::vector expectedData( {0., 0.6447428, 0., 0., 0.9777725, 0., 0., 0., 1.3725083, 0., 0., 0., 0.5027815, 0., 0., 0.00880813, diff --git a/src/webnn/tests/end2end/Resample2dTests.cpp b/src/webnn/tests/end2end/Resample2dTests.cpp index 42024904b..5dcbbcc6b 100644 --- a/src/webnn/tests/end2end/Resample2dTests.cpp +++ b/src/webnn/tests/end2end/Resample2dTests.cpp @@ -27,7 +27,7 @@ class Resample2dTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"output", output}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/ReshapeTests.cpp b/src/webnn/tests/end2end/ReshapeTests.cpp index 39c133390..d6ecc4cd1 100644 --- a/src/webnn/tests/end2end/ReshapeTests.cpp +++ b/src/webnn/tests/end2end/ReshapeTests.cpp @@ -29,7 +29,7 @@ class 
ReshapeTests : public WebnnTest { : std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, inputData)); } }; diff --git a/src/webnn/tests/end2end/SigmoidTests.cpp b/src/webnn/tests/end2end/SigmoidTests.cpp index 893904d88..c508e97d3 100644 --- a/src/webnn/tests/end2end/SigmoidTests.cpp +++ b/src/webnn/tests/end2end/SigmoidTests.cpp @@ -24,7 +24,7 @@ TEST_F(SigmoidTests, SigmoidWith1DTensor) { ASSERT_TRUE(graph); const std::vector inputData = {-1, 0, 1}; std::vector result(utils::SizeOfShape({3})); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); const std::vector expectedData = {0.26894143, 0.5, 0.7310586}; EXPECT_TRUE(utils::CheckValue(result, expectedData)); } @@ -47,7 +47,7 @@ TEST_F(SigmoidTests, SigmoidWith3DTensor) { 2.0984166, -1.2020895, 1.5637838, -0.7114222, }; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); const std::vector expectedData = { 0.4541994, 0.61787516, 0.9381, 0.50759846, 0.5104914, 0.23981662, 0.11790343, 0.29200357, 0.43591312, 0.7340212, 0.44518682, 0.53401446, 0.7021274, 0.4601159, diff --git a/src/webnn/tests/end2end/SliceTests.cpp b/src/webnn/tests/end2end/SliceTests.cpp index 759b5c9b7..094c883ae 100644 --- a/src/webnn/tests/end2end/SliceTests.cpp +++ b/src/webnn/tests/end2end/SliceTests.cpp @@ -37,7 +37,7 @@ class SliceTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"output", output}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expected.shape)); - utils::Compute(graph, {{"input", input.value}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", input.value}}, {{"output", result}}); EXPECT_TRUE(utils::CheckValue(result, expected.value)); } diff --git a/src/webnn/tests/end2end/SoftmaxTests.cpp b/src/webnn/tests/end2end/SoftmaxTests.cpp index cb8e3b89b..06dc8ccbe 100644 --- a/src/webnn/tests/end2end/SoftmaxTests.cpp +++ b/src/webnn/tests/end2end/SoftmaxTests.cpp @@ -26,7 +26,7 @@ TEST_F(SoftmaxTests, Softmax) { 0.58390397, 0.1735679, 0.539724, -0.953514, -0.59202826, -0.17344485, 0.14395015, -0.37920907}; std::vector result(utils::SizeOfShape({3, 4})); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); const std::vector expectedData = {0.32165375, 0.36157736, 0.0653337, 0.25143513, 0.35271573, 0.23400122, 0.33747196, 0.07581109, 0.17110129, 0.26004094, 0.35717794, 0.21167983}; diff --git a/src/webnn/tests/end2end/SplitTests.cpp b/src/webnn/tests/end2end/SplitTests.cpp index 69b6b21fd..2c6ec1ac0 100644 --- a/src/webnn/tests/end2end/SplitTests.cpp +++ b/src/webnn/tests/end2end/SplitTests.cpp @@ -47,7 +47,7 @@ class SplitTests : public WebnnTest { results.push_back(std::vector(utils::SizeOfShape(expectedArray[i].shape))); namedOutputs.push_back({"split" + std::to_string(i), results.back()}); } - utils::Compute(graph, {{"input", inputBuffer}}, namedOutputs); + utils::Compute(GetContext(), graph, {{"input", inputBuffer}}, namedOutputs); for (size_t i = 0; i < splittedOperands.Size(); ++i) { 
EXPECT_TRUE(utils::CheckValue(namedOutputs[i].resource, expectedArray[i].buffer)); } diff --git a/src/webnn/tests/end2end/SqueezeTests.cpp b/src/webnn/tests/end2end/SqueezeTests.cpp index af336a1b0..2da8070da 100644 --- a/src/webnn/tests/end2end/SqueezeTests.cpp +++ b/src/webnn/tests/end2end/SqueezeTests.cpp @@ -36,7 +36,7 @@ class SqueezeTests : public WebnnTest { input = std::rand(); } std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, {{"x", inputBuffer}}, {{"y", result}}); + utils::Compute(GetContext(), graph, {{"x", inputBuffer}}, {{"y", result}}); EXPECT_TRUE(utils::CheckValue(result, inputBuffer)); } }; diff --git a/src/webnn/tests/end2end/SubTests.cpp b/src/webnn/tests/end2end/SubTests.cpp index 16beeb7ab..ef0bcad0e 100644 --- a/src/webnn/tests/end2end/SubTests.cpp +++ b/src/webnn/tests/end2end/SubTests.cpp @@ -45,7 +45,7 @@ TEST_F(SubTests, SubTwoInputs) { 1.4805148, 1.867559, 0.90604466, -0.86122566, 1.9100649, -0.26800337, 0.8024564, 0.947252, -0.15501009, 0.61407936, 0.9222067}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue( {2.436513, 0.7597104, 1.7918843, 3.9671757, 1.6901319, -0.5754969, 2.5802867, -0.61413944, 0.80407953, 0.35865313, -0.585047, 1.3252906, -0.378363, 1.3565009, @@ -78,7 +78,7 @@ TEST_F(SubTests, SubBroadcast) { -0.80340964, -0.6895498, -0.4555325, 0.01747916}; const std::vector dataB = {-0.35399392, -1.3749512, -0.6436184, -2.2234032, 0.62523144}; std::vector result(utils::SizeOfShape({3, 4, 5})); - utils::Compute(graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); + utils::Compute(GetContext(), graph, {{"a", dataA}, {"b", dataB}}, {{"c", result}}); const std::vector expectedValue({ 0.73041946, 0.27555048, 0.9418566, 3.549789, -1.3197993, 0.20435938, 0.9397977, 2.4928823, 2.895698, -0.21776962, -0.41592214, 1.9142004, -0.03071427, 2.2552338, diff --git a/src/webnn/tests/end2end/TanhTests.cpp b/src/webnn/tests/end2end/TanhTests.cpp index 62b487ee4..84ce8bce3 100644 --- a/src/webnn/tests/end2end/TanhTests.cpp +++ b/src/webnn/tests/end2end/TanhTests.cpp @@ -25,7 +25,7 @@ class TanhTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(shape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedData)); } }; diff --git a/src/webnn/tests/end2end/TransposeTests.cpp b/src/webnn/tests/end2end/TransposeTests.cpp index c93e808cc..23a271d71 100644 --- a/src/webnn/tests/end2end/TransposeTests.cpp +++ b/src/webnn/tests/end2end/TransposeTests.cpp @@ -30,7 +30,7 @@ class TransposeTests : public WebnnTest { const wnn::Graph graph = utils::Build(builder, {{"b", b}}); ASSERT_TRUE(graph); std::vector result(utils::SizeOfShape(expectedShape)); - utils::Compute(graph, {{"a", inputData}}, {{"b", result}}); + utils::Compute(GetContext(), graph, {{"a", inputData}}, {{"b", result}}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } }; diff --git a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp index 084c841bd..7b4bb0e8c 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2BatchNormNchw.cpp @@ 
-32,7 +32,7 @@ class MobileNetV2BatchNormNchwTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1000})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp index 71a66ff5f..8a2e3a128 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nchw.cpp @@ -32,7 +32,7 @@ class MobileNetV2NchwTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1000})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp index 47842facc..9bb11fde2 100644 --- a/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp +++ b/src/webnn/tests/end2end/models/MobileNetV2Nhwc.cpp @@ -33,7 +33,7 @@ class MobileNetV2NhwcTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1001})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/ResNetNchw.cpp b/src/webnn/tests/end2end/models/ResNetNchw.cpp index 54e48122a..5a56379d3 100644 --- a/src/webnn/tests/end2end/models/ResNetNchw.cpp +++ b/src/webnn/tests/end2end/models/ResNetNchw.cpp @@ -32,7 +32,7 @@ class ResNetNchwTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1000})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/ResNetNhwc.cpp b/src/webnn/tests/end2end/models/ResNetNhwc.cpp index ad0fcb23c..8bc4aa503 100644 --- a/src/webnn/tests/end2end/models/ResNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/ResNetNhwc.cpp @@ -32,7 +32,7 @@ class ResNetNhwcTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1001})); - 
utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp index 9430f604f..8682fa223 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNchw.cpp @@ -32,7 +32,7 @@ class SqueezeNetNchwTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1000})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp index adfa1d6ff..201f875e1 100644 --- a/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp +++ b/src/webnn/tests/end2end/models/SqueezeNetNhwc.cpp @@ -33,7 +33,7 @@ class SqueezeNetNhwcTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({1, 1001})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nhwcPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/end2end/models/SuperResolutionNchw.cpp b/src/webnn/tests/end2end/models/SuperResolutionNchw.cpp index 6aae73076..f072c701d 100644 --- a/src/webnn/tests/end2end/models/SuperResolutionNchw.cpp +++ b/src/webnn/tests/end2end/models/SuperResolutionNchw.cpp @@ -32,7 +32,7 @@ class SuperResolutionNchwTests : public WebnnTest { const cnpy::NpyArray inputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + inputFile); const std::vector inputData = inputNpy.as_vec(); std::vector result(utils::SizeOfShape({/*TODO: batchSize?*/ 1, 1, 672, 672})); - utils::Compute(graph, {{"input", inputData}}, {{"output", result}}); + utils::Compute(GetContext(), graph, {{"input", inputData}}, {{"output", result}}); const cnpy::NpyArray outputNpy = cnpy::npy_load(nchwPath + "test_data_set/" + expectedFile); EXPECT_TRUE(utils::CheckValue(result, outputNpy.as_vec())); } diff --git a/src/webnn/tests/unittests/native/mocks/GraphMock.h b/src/webnn/tests/unittests/native/mocks/GraphMock.h index ed9bbcc88..d5a5593ca 100644 --- a/src/webnn/tests/unittests/native/mocks/GraphMock.h +++ b/src/webnn/tests/unittests/native/mocks/GraphMock.h @@ -56,10 +56,6 @@ namespace webnn::native { (override)); MOCK_METHOD(MaybeError, Finish, (), (override)); MOCK_METHOD(MaybeError, CompileImpl, (), (override)); - MOCK_METHOD(MaybeError, - ComputeImpl, - (NamedInputsBase * inputs, NamedOutputsBase* outputs), - (override)); }; } // namespace webnn::native diff --git a/src/webnn/wire/client/ClientDoers.cpp b/src/webnn/wire/client/ClientDoers.cpp index e867ee46c..d7801fedb 100644 --- 
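Every test and model update above is the same mechanical change: utils::Compute gains the owning wnn::Context as its first argument, because graph execution now lives on the context instead of on wnn::Graph. Below is a minimal sketch of the core of the reworked helper once the name/buffer pairs have been packed into wnn::NamedInputs/wnn::NamedOutputs as before; that packing and the actual SampleUtils/Utils changes are not reproduced in this hunk, so this is an assumption about their shape, not the patch itself.

    // Sketch only: execution moved from wnn::Graph to wnn::Context, so the
    // helper takes the context and passes the graph as an explicit argument.
    void Compute(const wnn::Context& context,
                 const wnn::Graph& graph,
                 const wnn::NamedInputs& inputs,
                 const wnn::NamedOutputs& outputs) {
        // Synchronous, context-based execution of a previously built graph.
        context.ComputeSync(graph, inputs, outputs);
    }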
a/src/webnn/wire/client/ClientDoers.cpp +++ b/src/webnn/wire/client/ClientDoers.cpp @@ -28,19 +28,19 @@ namespace webnn::wire::client { return context->OnPopErrorScopeCallback(requestSerial, errorType, message); } - bool Client::DoGraphComputeResult(NamedOutputs* namedOutputs, - char const* name, - uint8_t const* buffer, - size_t byteLength, - size_t byteOffset) { - return namedOutputs->OutputResult(name, buffer, byteLength, byteOffset); + bool Client::DoContextComputeCallback(Context* context, + uint64_t requestSerial, + WNNErrorType type, + const char* message) { + return context->OnComputeAsyncCallback(requestSerial, type, message); } - bool Client::DoGraphComputeAsyncCallback(Graph* graph, - uint64_t requestSerial, - WNNErrorType type, - const char* message) { - return graph->OnComputeAsyncCallback(requestSerial, type, message); + bool Client::DoContextComputeSyncResult(NamedOutputs* namedOutputs, + char const* name, + uint8_t const* buffer, + size_t byteLength, + size_t byteOffset) { + return namedOutputs->OutputResult(name, buffer, byteLength, byteOffset); } } // namespace webnn::wire::client diff --git a/src/webnn/wire/client/Context.cpp b/src/webnn/wire/client/Context.cpp index 7ae2074e6..55de6b48f 100644 --- a/src/webnn/wire/client/Context.cpp +++ b/src/webnn/wire/client/Context.cpp @@ -83,4 +83,62 @@ namespace webnn::wire::client { void Context::SetUncapturedErrorCallback(WNNErrorCallback callback, void* userdata) { } + void Context::Compute(WNNGraph wnnGraph, + WNNNamedInputs inputs, + WNNNamedOutputs outputs, + WNNComputeAsyncCallback callback, + void* userdata) { + if (client->IsDisconnected()) { + callback(WNNErrorType_DeviceLost, "WebNN context disconnected", userdata); + return; + } + + uint64_t serial = mComputeAsyncRequestSerial++; + ASSERT(mComputeAsyncRequests.find(serial) == mComputeAsyncRequests.end()); + + mComputeAsyncRequests[serial] = {callback, userdata}; + + Graph* graph = FromAPI(wnnGraph); + NamedInputs* namedInputs = FromAPI(inputs); + NamedOutputs* namedOutputs = FromAPI(outputs); + + ContextComputeCmd cmd; + cmd.contextId = this->id; + cmd.graphId = graph->id; + cmd.requestSerial = serial; + cmd.inputsId = namedInputs->id; + cmd.outputsId = namedOutputs->id; + + client->SerializeCommand(cmd); + } + + void Context::ComputeSync(WNNGraph wnnGraph, WNNNamedInputs inputs, WNNNamedOutputs outputs) { + Graph* graph = FromAPI(wnnGraph); + NamedInputs* namedInputs = FromAPI(inputs); + NamedOutputs* namedOutputs = FromAPI(outputs); + + ContextComputeSyncCmd cmd; + cmd.contextId = this->id; + cmd.graphId = graph->id; + cmd.inputsId = namedInputs->id; + cmd.outputsId = namedOutputs->id; + + client->SerializeCommand(cmd); + } + + bool Context::OnComputeAsyncCallback(uint64_t requestSerial, + WNNErrorType type, + const char* message) { + auto requestIt = mComputeAsyncRequests.find(requestSerial); + if (requestIt == mComputeAsyncRequests.end()) { + return false; + } + + ComputeAsyncRequest request = std::move(requestIt->second); + + mComputeAsyncRequests.erase(requestIt); + request.callback(type, message, request.userdata); + return true; + } + } // namespace webnn::wire::client diff --git a/src/webnn/wire/client/Context.h b/src/webnn/wire/client/Context.h index e41dd7875..427e462b9 100644 --- a/src/webnn/wire/client/Context.h +++ b/src/webnn/wire/client/Context.h @@ -36,6 +36,15 @@ namespace webnn::wire::client { WNNErrorType type, const char* message); + void Compute(WNNGraph wnnGraph, + WNNNamedInputs inputs, + WNNNamedOutputs outputs, + WNNComputeAsyncCallback 
callback, + void* userdata); + void ComputeSync(WNNGraph wnnGraph, WNNNamedInputs inputs, WNNNamedOutputs outputs); + + bool OnComputeAsyncCallback(uint64_t requestSerial, WNNErrorType type, const char* message); + private: struct ErrorScopeData { WNNErrorCallback callback = nullptr; @@ -44,6 +53,13 @@ namespace webnn::wire::client { std::map mErrorScopes; uint64_t mErrorScopeRequestSerial = 0; uint64_t mErrorScopeStackSize = 0; + + struct ComputeAsyncRequest { + WNNComputeAsyncCallback callback = nullptr; + void* userdata = nullptr; + }; + std::map mComputeAsyncRequests; + uint64_t mComputeAsyncRequestSerial = 0; }; } // namespace webnn::wire::client diff --git a/src/webnn/wire/client/Graph.cpp b/src/webnn/wire/client/Graph.cpp index f8b658a5c..c28349101 100644 --- a/src/webnn/wire/client/Graph.cpp +++ b/src/webnn/wire/client/Graph.cpp @@ -20,57 +20,4 @@ namespace webnn::wire::client { - void Graph::Compute(WNNNamedInputs inputs, WNNNamedOutputs outputs) { - NamedInputs* namedInputs = FromAPI(inputs); - NamedOutputs* namedOutputs = FromAPI(outputs); - - GraphComputeCmd cmd; - cmd.graphId = this->id; - cmd.inputsId = namedInputs->id; - cmd.outputsId = namedOutputs->id; - - client->SerializeCommand(cmd); - } - - void Graph::ComputeAsync(WNNNamedInputs inputs, - WNNNamedOutputs outputs, - WNNComputeAsyncCallback callback, - void* userdata) { - if (client->IsDisconnected()) { - callback(WNNErrorType_DeviceLost, "WebNN context disconnected", userdata); - return; - } - - uint64_t serial = mComputeAsyncRequestSerial++; - ASSERT(mComputeAsyncRequests.find(serial) == mComputeAsyncRequests.end()); - - mComputeAsyncRequests[serial] = {callback, userdata}; - - NamedInputs* namedInputs = FromAPI(inputs); - NamedOutputs* namedOutputs = FromAPI(outputs); - - GraphComputeAsyncCmd cmd; - cmd.graphId = this->id; - cmd.requestSerial = serial; - cmd.inputsId = namedInputs->id; - cmd.outputsId = namedOutputs->id; - - client->SerializeCommand(cmd); - } - - bool Graph::OnComputeAsyncCallback(uint64_t requestSerial, - WNNErrorType type, - const char* message) { - auto requestIt = mComputeAsyncRequests.find(requestSerial); - if (requestIt == mComputeAsyncRequests.end()) { - return false; - } - - ComputeAsyncRequest request = std::move(requestIt->second); - - mComputeAsyncRequests.erase(requestIt); - request.callback(type, message, request.userdata); - return true; - } - } // namespace webnn::wire::client diff --git a/src/webnn/wire/client/Graph.h b/src/webnn/wire/client/Graph.h index 7fe2d562d..1c3578a9d 100644 --- a/src/webnn/wire/client/Graph.h +++ b/src/webnn/wire/client/Graph.h @@ -27,21 +27,6 @@ namespace webnn::wire::client { class Graph final : public ObjectBase { public: using ObjectBase::ObjectBase; - - void Compute(WNNNamedInputs inputs, WNNNamedOutputs outputs); - void ComputeAsync(WNNNamedInputs inputs, - WNNNamedOutputs outputs, - WNNComputeAsyncCallback callback, - void* userdata); - bool OnComputeAsyncCallback(uint64_t requestSerial, WNNErrorType type, const char* message); - - private: - struct ComputeAsyncRequest { - WNNComputeAsyncCallback callback = nullptr; - void* userdata = nullptr; - }; - std::map mComputeAsyncRequests; - uint64_t mComputeAsyncRequestSerial = 0; }; } // namespace webnn::wire::client diff --git a/src/webnn/wire/server/Server.h b/src/webnn/wire/server/Server.h index 850134c1a..37ef233b7 100644 --- a/src/webnn/wire/server/Server.h +++ b/src/webnn/wire/server/Server.h @@ -101,7 +101,7 @@ namespace webnn::wire::server { struct ComputeAsyncUserdata : CallbackUserdata { using 
CallbackUserdata::CallbackUserdata; - ObjectHandle graph; + ObjectHandle context; uint64_t requestSerial; ObjectId namedOutputsObjectID; }; @@ -159,9 +159,9 @@ namespace webnn::wire::server { void OnContextPopErrorScope(ErrorScopeUserdata* userdata, WNNErrorType type, const char* message); - void OnGraphComputeAsyncCallback(ComputeAsyncUserdata* userdata, - WNNErrorType type, - const char* message); + void OnContextComputeCallback(ComputeAsyncUserdata* userdata, + WNNErrorType type, + const char* message); #include "webnn/wire/server/ServerPrototypes_autogen.inc" WireDeserializeAllocator mAllocator; diff --git a/src/webnn/wire/server/ServerContext.cpp b/src/webnn/wire/server/ServerContext.cpp index 00030d13f..a2379f0c9 100644 --- a/src/webnn/wire/server/ServerContext.cpp +++ b/src/webnn/wire/server/ServerContext.cpp @@ -48,4 +48,93 @@ namespace webnn::wire::server { SerializeCommand(cmd); } + bool Server::SerializeComputeResult(ObjectId outputsId) { + auto* namedOutputs = NamedOutputsObjects().Get(outputsId); + if (mOutputNamesMap.find(outputsId) == mOutputNamesMap.end()) { + return false; + } + for (auto& name : mOutputNamesMap[outputsId]) { + WNNArrayBufferView arrayBuffer = {}; + mProcs.namedOutputsGet(namedOutputs->handle, name.data(), &arrayBuffer); + if (arrayBuffer.buffer == nullptr) { + return false; + } + + // Return the result. + ReturnContextComputeSyncResultCmd cmd; + cmd.namedOutputs = ObjectHandle{outputsId, namedOutputs->generation}; + cmd.name = name.data(); + cmd.buffer = static_cast(arrayBuffer.buffer); + cmd.byteLength = arrayBuffer.byteLength; + cmd.byteOffset = arrayBuffer.byteOffset; + SerializeCommand(cmd); + } + // Reset the mOutputNamesMap which host in the server. + mOutputNamesMap.erase(outputsId); + return true; + } + + void Server::OnContextComputeCallback(ComputeAsyncUserdata* userdata, + WNNErrorType type, + const char* message) { + if (type == WNNErrorType_NoError) { + SerializeComputeResult(userdata->namedOutputsObjectID); + } + ReturnContextComputeCallbackCmd cmd; + cmd.context = userdata->context; + cmd.requestSerial = userdata->requestSerial; + cmd.type = type; + cmd.message = message; + + SerializeCommand(cmd); + } + + bool Server::DoContextComputeSync(ObjectId contextId, + ObjectId graphId, + ObjectId inputsId, + ObjectId outputsId) { + auto* context = ContextObjects().Get(contextId); + auto* graph = GraphObjects().Get(graphId); + auto* namedInputs = NamedInputsObjects().Get(inputsId); + auto* namedOutputs = NamedOutputsObjects().Get(outputsId); + if (context == nullptr || graph == nullptr || namedInputs == nullptr || + namedOutputs == nullptr) { + return false; + } + + mProcs.contextComputeSync(context->handle, graph->handle, namedInputs->handle, + namedOutputs->handle); + +#if defined(WEBNN_ENABLE_GPU_BUFFER) + return true; +#else + return SerializeComputeResult(outputsId); +#endif + } + + bool Server::DoContextCompute(ObjectId contextId, + ObjectId graphId, + uint64_t requestSerial, + ObjectId inputsId, + ObjectId outputsId) { + auto* context = ContextObjects().Get(contextId); + auto* graph = GraphObjects().Get(graphId); + auto* namedInputs = NamedInputsObjects().Get(inputsId); + auto* namedOutputs = NamedOutputsObjects().Get(outputsId); + if (context == nullptr || graph == nullptr || namedInputs == nullptr || + namedOutputs == nullptr) { + return false; + } + + auto userdata = MakeUserdata(); + userdata->requestSerial = requestSerial; + userdata->context = ObjectHandle{contextId, context->generation}; + userdata->namedOutputsObjectID = outputsId; + 
+ mProcs.contextCompute( + context->handle, graph->handle, namedInputs->handle, namedOutputs->handle, + ForwardToServer<&Server::OnContextComputeCallback>, userdata.release()); + return true; + } + } // namespace webnn::wire::server diff --git a/src/webnn/wire/server/ServerGraph.cpp b/src/webnn/wire/server/ServerGraph.cpp index b80ade192..1418adef9 100644 --- a/src/webnn/wire/server/ServerGraph.cpp +++ b/src/webnn/wire/server/ServerGraph.cpp @@ -18,84 +18,4 @@ namespace webnn::wire::server { - bool Server::SerializeComputeResult(ObjectId outputsId) { - auto* namedOutputs = NamedOutputsObjects().Get(outputsId); - if (mOutputNamesMap.find(outputsId) == mOutputNamesMap.end()) { - return false; - } - for (auto& name : mOutputNamesMap[outputsId]) { - WNNArrayBufferView arrayBuffer = {}; - mProcs.namedOutputsGet(namedOutputs->handle, name.data(), &arrayBuffer); - if (arrayBuffer.buffer == nullptr) { - return false; - } - - // Return the result. - ReturnGraphComputeResultCmd cmd; - cmd.namedOutputs = ObjectHandle{outputsId, namedOutputs->generation}; - cmd.name = name.data(); - cmd.buffer = static_cast(arrayBuffer.buffer); - cmd.byteLength = arrayBuffer.byteLength; - cmd.byteOffset = arrayBuffer.byteOffset; - SerializeCommand(cmd); - } - // Reset the mOutputNamesMap which host in the server. - mOutputNamesMap.erase(outputsId); - return true; - } - - bool Server::DoGraphCompute(ObjectId graphId, ObjectId inputsId, ObjectId outputsId) { - auto* graph = GraphObjects().Get(graphId); - auto* namedInputs = NamedInputsObjects().Get(inputsId); - auto* namedOutputs = NamedOutputsObjects().Get(outputsId); - if (graph == nullptr || namedInputs == nullptr || namedOutputs == nullptr) { - return false; - } - - mProcs.graphCompute(graph->handle, namedInputs->handle, namedOutputs->handle); - -#if defined(WEBNN_ENABLE_GPU_BUFFER) - return true; -#else - return SerializeComputeResult(outputsId); -#endif - } - - bool Server::DoGraphComputeAsync(ObjectId graphId, - uint64_t requestSerial, - ObjectId inputsId, - ObjectId outputsId) { - auto* graph = GraphObjects().Get(graphId); - auto* namedInputs = NamedInputsObjects().Get(inputsId); - auto* namedOutputs = NamedOutputsObjects().Get(outputsId); - if (graph == nullptr || namedInputs == nullptr || namedOutputs == nullptr) { - return false; - } - - auto userdata = MakeUserdata(); - userdata->requestSerial = requestSerial; - userdata->graph = ObjectHandle{graphId, graph->generation}; - userdata->namedOutputsObjectID = outputsId; - - mProcs.graphComputeAsync(graph->handle, namedInputs->handle, namedOutputs->handle, - ForwardToServer<&Server::OnGraphComputeAsyncCallback>, - userdata.release()); - return true; - } - - void Server::OnGraphComputeAsyncCallback(ComputeAsyncUserdata* userdata, - WNNErrorType type, - const char* message) { - if (type == WNNErrorType_NoError) { - SerializeComputeResult(userdata->namedOutputsObjectID); - } - ReturnGraphComputeAsyncCallbackCmd cmd; - cmd.graph = userdata->graph; - cmd.requestSerial = userdata->requestSerial; - cmd.type = type; - cmd.message = message; - - SerializeCommand(cmd); - } - } // namespace webnn::wire::server diff --git a/webnn.json b/webnn.json index b0e40fcb6..56b26f1fe 100644 --- a/webnn.json +++ b/webnn.json @@ -179,6 +179,26 @@ "context": { "category": "object", "methods": [ + { + "name": "compute", + "returns": "void", + "args": [ + {"name": "graph", "type": "graph"}, + {"name": "inputs", "type": "named inputs"}, + {"name": "outputs", "type": "named outputs"}, + {"name": "callback", "type": "compute async callback"}, 
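The wire changes give context-based compute the same request/response shape as the existing error-scope plumbing: the client-side Context assigns a request serial, stores the callback in mComputeAsyncRequests, and serializes a ContextComputeCmd; the server runs contextCompute and answers with the serialized output buffers (on success) followed by a ReturnContextComputeCallbackCmd; the client matches the serial and fires the user callback. A sketch of how an application would drive the two new entry points through the generated wnn:: wrappers follows, assuming the wrapper exposes the raw WNNComputeAsyncCallback signature; it is illustrative only.

    // Sketch only: both synchronous and asynchronous compute now hang off the context.
    void RunGraph(const wnn::Context& context,
                  const wnn::Graph& graph,
                  const wnn::NamedInputs& inputs,
                  const wnn::NamedOutputs& outputs) {
        // Synchronous path (what the updated utils::Compute helper relies on).
        context.ComputeSync(graph, inputs, outputs);

        // Asynchronous path: the callback fires when the wire client receives
        // ReturnContextComputeCallbackCmd for this request serial.
        context.Compute(
            graph, inputs, outputs,
            [](WNNErrorType type, char const* message, void* userdata) {
                // Output buffers are only serialized back when type is WNNErrorType_NoError.
            },
            nullptr);
    }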
+ {"name": "userdata", "type": "void", "annotation": "*"} + ] + }, + { + "name": "compute sync", + "returns": "void", + "args": [ + {"name": "graph", "type": "graph"}, + {"name": "inputs", "type": "named inputs"}, + {"name": "outputs", "type": "named outputs"} + ] + }, { "name": "inject error", "args": [ @@ -1073,26 +1093,6 @@ ] }, "graph": { - "category": "object", - "methods": [ - { - "name": "compute", - "returns": "void", - "args": [ - {"name": "inputs", "type": "named inputs"}, - {"name": "outputs", "type": "named outputs"} - ] - }, - { - "name": "compute async", - "returns": "void", - "args": [ - {"name": "inputs", "type": "named inputs"}, - {"name": "outputs", "type": "named outputs"}, - {"name": "callback", "type": "compute async callback"}, - {"name": "userdata", "type": "void", "annotation": "*"} - ] - } - ] + "category": "object" } } diff --git a/webnn_wire.json b/webnn_wire.json index c947002ca..058f1c8d4 100644 --- a/webnn_wire.json +++ b/webnn_wire.json @@ -20,6 +20,19 @@ {"name": "context id", "type": "ObjectId"}, {"name": "request serial", "type": "uint64_t"} ], + "context compute": [ + {"name": "context id", "type": "ObjectId"}, + { "name": "graph id", "type": "ObjectId" }, + { "name": "request serial", "type": "uint64_t" }, + {"name": "inputs id", "type": "ObjectId"}, + {"name": "outputs id", "type": "ObjectId"} + ], + "context compute sync": [ + {"name": "context id", "type": "ObjectId"}, + {"name": "graph id", "type": "ObjectId"}, + {"name": "inputs id", "type": "ObjectId"}, + {"name": "outputs id", "type": "ObjectId"} + ], "graph builder constant internal": [ {"name": "graph builder id", "type": "ObjectId"}, {"name": "desc", "type": "operand descriptor", "annotation": "const*"}, @@ -63,17 +76,6 @@ {"name": "generation", "type": "uint32_t", "default": 0}, {"name": "result", "type": "ObjectHandle", "handle_type": "context"} ], - "graph compute": [ - {"name": "graph id", "type": "ObjectId"}, - {"name": "inputs id", "type": "ObjectId"}, - {"name": "outputs id", "type": "ObjectId"} - ], - "graph compute async": [ - { "name": "graph id", "type": "ObjectId" }, - { "name": "request serial", "type": "uint64_t" }, - {"name": "inputs id", "type": "ObjectId"}, - {"name": "outputs id", "type": "ObjectId"} - ], "operand array size": [ {"name": "operand array id", "type": "ObjectId"} ], @@ -115,18 +117,18 @@ {"name": "type", "type": "error type"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"} ], - "graph compute result": [ + "context compute callback": [ + { "name": "context", "type": "ObjectHandle", "handle_type": "context" }, + { "name": "request serial", "type": "uint64_t" }, + { "name": "type", "type": "error type"}, + { "name": "message", "type": "char", "annotation": "const*", "length": "strlen" } + ], + "context compute sync result": [ {"name": "named outputs", "type": "ObjectHandle", "handle_type": "named outputs"}, {"name": "name", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "buffer", "type": "uint8_t", "annotation": "const*", "length": "byte length"}, {"name": "byte length", "type": "size_t"}, {"name": "byte offset", "type": "size_t", "default": 0} - ], - "graph compute async callback": [ - { "name": "graph", "type": "ObjectHandle", "handle_type": "graph" }, - { "name": "request serial", "type": "uint64_t" }, - { "name": "type", "type": "error type"}, - { "name": "message", "type": "char", "annotation": "const*", "length": "strlen" } ] }, "special items": { @@ -140,6 +142,8 @@ "client_side_commands": [ 
"ContextPopErrorScope", "ContextSetUncapturedErrorCallback", + "ContextCompute", + "ContextComputeSync", "GraphBuilderConstant", "GraphBuilderConstantWithGpuBuffer", "GraphBuilderGru", @@ -149,9 +153,7 @@ "NamedOutputsSet", "NamedOutputsGet", "OperandArraySize", - "OperatorArraySize", - "GraphComputeAsync", - "GraphCompute" + "OperatorArraySize" ], "client_handwritten_commands": [ "ContextPushErrorScope"