[WebNN-native/Node] support context-based compute graph
miaobin committed Jul 18, 2022
1 parent 31d48a4 commit 38bb77b
Showing 76 changed files with 539 additions and 493 deletions.
4 changes: 1 addition & 3 deletions examples/LeNet/LeNet.h
@@ -25,8 +25,6 @@ class LeNet {
LeNet();
~LeNet() = default;

wnn::Graph Build(const std::string& weigthsPath);

private:
wnn::Context mContext;
wnn::Graph Build(const std::string& weigthsPath);
};
2 changes: 1 addition & 1 deletion examples/LeNet/Main.cpp
@@ -87,7 +87,7 @@ int main(int argc, const char* argv[]) {
for (int i = 0; i < nIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
utils::Compute(graph, {{"input", input}}, {{"output", result}});
utils::Compute(lenet.mContext, graph, {{"input", input}}, {{"output", result}});
executionTimeVector.push_back(std::chrono::high_resolution_clock::now() -
executionStartTime);
}
4 changes: 2 additions & 2 deletions examples/MobileNetV2/Main.cpp
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
std::vector<float> result(utils::SizeOfShape(mobilevetv2.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (mobilevetv2.mNIter > 1) {
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
}

std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < mobilevetv2.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}

4 changes: 2 additions & 2 deletions examples/ResNet/Main.cpp
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
std::vector<float> result(utils::SizeOfShape(resnet.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (resnet.mNIter > 1) {
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
}

std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < resnet.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}

5 changes: 3 additions & 2 deletions examples/SampleUtils.cpp
@@ -290,10 +290,11 @@ namespace utils {
return builder.Build(namedOperands);
}

void Compute(const wnn::Graph& graph,
void Compute(const wnn::Context& context,
const wnn::Graph& graph,
const std::vector<NamedInput<float>>& inputs,
const std::vector<NamedOutput<float>>& outputs) {
return Compute<float>(graph, inputs, outputs);
return Compute<float>(context, graph, inputs, outputs);
}

std::vector<std::string> ReadTopKLabel(const std::vector<size_t>& topKIndex,
8 changes: 5 additions & 3 deletions examples/SampleUtils.h
@@ -244,7 +244,8 @@ namespace utils {
};

template <typename T>
void Compute(const wnn::Graph& graph,
void Compute(const wnn::Context& context,
const wnn::Graph& graph,
const std::vector<NamedInput<T>>& inputs,
const std::vector<NamedOutput<T>>& outputs) {
if (graph.GetHandle() == nullptr) {
@@ -274,11 +275,12 @@ namespace utils {
mlOutputs.push_back(resource);
namedOutputs.Set(output.name.c_str(), &mlOutputs.back());
}
graph.Compute(namedInputs, namedOutputs);
context.ComputeSync(graph, namedInputs, namedOutputs);
DoFlush();
}

void Compute(const wnn::Graph& graph,
void Compute(const wnn::Context& context,
const wnn::Graph& graph,
const std::vector<NamedInput<float>>& inputs,
const std::vector<NamedOutput<float>>& outputs);

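With this helper change, every sample now routes execution through the context that owns the graph. A minimal sketch of the updated calling convention; the buffer sizes and tensor names below are illustrative, not taken from a specific sample:

    // Sketch only: assumes `context` and `graph` were produced by the usual
    // wnn::CreateContext / utils::Build flow used in the samples.
    std::vector<float> input(1 * 3 * 224 * 224);   // hypothetical input buffer
    std::vector<float> result(1000);               // hypothetical output buffer
    // Before this commit: utils::Compute(graph, {{"input", input}}, {{"output", result}});
    // After this commit, the context drives the compute:
    utils::Compute(context, graph, {{"input", input}}, {{"output", result}});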
4 changes: 2 additions & 2 deletions examples/SqueezeNet/Main.cpp
@@ -60,14 +60,14 @@ int main(int argc, const char* argv[]) {
std::vector<float> result(utils::SizeOfShape(squeezenet.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (squeezenet.mNIter > 1) {
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
}

std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < squeezenet.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
utils::Compute(context, graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}

31 changes: 30 additions & 1 deletion node/src/Context.cpp
@@ -17,7 +17,9 @@
#include <napi.h>
#include <iostream>

#include "Graph.h"
#include "ML.h"
#include "Utils.h"

Napi::FunctionReference node::Context::constructor;

@@ -90,10 +92,37 @@ namespace node {

Napi::Object Context::Initialize(Napi::Env env, Napi::Object exports) {
Napi::HandleScope scope(env);
Napi::Function func = DefineClass(env, "MLContext", {});
Napi::Function func = DefineClass(
env, "MLContext", {InstanceMethod("compute", &Context::Compute, napi_enumerable)});
constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
exports.Set("MLContext", func);
return exports;
}

Napi::Value Context::Compute(const Napi::CallbackInfo& info) {
// status compute(Graph graph, NamedInputs inputs, NamedOutputs outputs);
WEBNN_NODE_ASSERT(info.Length() == 3, "The number of arguments is invalid.");
Napi::Object object = info[0].As<Napi::Object>();
node::Graph* jsGraph = Napi::ObjectWrap<node::Graph>::Unwrap(object);

std::map<std::string, Input> inputs;
WEBNN_NODE_ASSERT(GetNamedInputs(info[1], inputs), "The inputs parameter is invalid.");

std::map<std::string, wnn::Resource> outputs;
WEBNN_NODE_ASSERT(GetNamedOutputs(info[2], outputs), "The outputs parameter is invalid.");

wnn::NamedInputs namedInputs = wnn::CreateNamedInputs();
for (auto& input : inputs) {
namedInputs.Set(input.first.data(), input.second.AsPtr());
}
wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs();
for (auto& output : outputs) {
namedOutputs.Set(output.first.data(), &output.second);
}
mImpl.ComputeSync(jsGraph->GetImpl(), namedInputs, namedOutputs);

return Napi::Number::New(info.Env(), 0);
}

} // namespace node
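For reference, the native flow that the new Context::Compute wraps is now context-based rather than graph-based. A condensed sketch using only the wnn C++ calls visible in this diff; resource setup and error handling are omitted, and the `input`/`output` names plus `inputPtr`/`outputResource` are placeholders:

    wnn::NamedInputs namedInputs = wnn::CreateNamedInputs();
    namedInputs.Set("input", inputPtr);            // inputPtr: a prepared const wnn::Input*, assumed
    wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs();
    namedOutputs.Set("output", &outputResource);   // outputResource: a prepared wnn::Resource, assumed
    // Previously this was graph.Compute(namedInputs, namedOutputs);
    context.ComputeSync(graph, namedInputs, namedOutputs);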
1 change: 1 addition & 0 deletions node/src/Context.h
@@ -31,6 +31,7 @@ namespace node {
wnn::Context GetImpl();

private:
Napi::Value Compute(const Napi::CallbackInfo& info);
wnn::Context mImpl;
};

121 changes: 4 additions & 117 deletions node/src/Graph.cpp
@@ -13,135 +13,22 @@
// limitations under the License.

#include "Graph.h"

#include <iostream>
#include <map>

#include "Utils.h"

namespace node {

struct Input {
public:
wnn::ArrayBufferView bufferView;
std::vector<int32_t> dimensions;

const wnn::Input* AsPtr() {
mInput.resource.arrayBufferView = bufferView;
mInput.resource.gpuBufferView = {};
if (!dimensions.empty()) {
mInput.dimensions = dimensions.data();
mInput.dimensionsCount = dimensions.size();
}
return &mInput;
}

private:
wnn::Input mInput;
};

bool GetNamedInputs(const Napi::Value& jsValue, std::map<std::string, Input>& namedInputs) {
if (!jsValue.IsObject()) {
return false;
}
Napi::Object jsNamedInputs = jsValue.As<Napi::Object>();
Napi::Array names = jsNamedInputs.GetPropertyNames();
if (names.Length() == 0) {
return false;
}
// typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource;
// dictionary MLInput {
// required MLResource resource;
// required sequence<long> dimensions;
// };
// typedef record<DOMString, (MLResource or MLInput)> MLNamedInputs;
for (size_t i = 0; i < names.Length(); ++i) {
Input input = {};
std::string name = names.Get(i).As<Napi::String>().Utf8Value();
// FIXME: validate the type of typed array.
Napi::TypedArray jsTypedArray;
if (jsNamedInputs.Get(name).IsTypedArray()) {
jsTypedArray = jsNamedInputs.Get(name).As<Napi::TypedArray>();
} else {
Napi::Object jsInput = jsNamedInputs.Get(name).As<Napi::Object>();
if (!jsInput.Has("resource") || !jsInput.Has("dimensions")) {
// Input resource and dimensions are required.
return false;
}
if (!jsInput.Get("resource").IsTypedArray()) {
return false;
}
jsTypedArray = jsInput.Get("resource").As<Napi::TypedArray>();

if (!GetArray(jsInput.Get("dimensions"), input.dimensions)) {
return false;
}
if (SizeOfShape(input.dimensions) != jsTypedArray.ElementSize()) {
return false;
}
}
if (!GetArrayBufferView(jsTypedArray, input.bufferView)) {
return false;
}
namedInputs[name] = input;
}
return true;
Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Graph>(info) {
}

bool GetNamedOutputs(const Napi::Value& jsValue,
std::map<std::string, wnn::Resource>& namedOutputs) {
if (!jsValue.IsObject()) {
return false;
}
Napi::Object jsNamedOutputs = jsValue.As<Napi::Object>();
Napi::Array names = jsNamedOutputs.GetPropertyNames();
if (names.Length() == 0) {
return false;
}
// typedef (MLBufferView or WebGLTexture or GPUTexture) MLResource;
// typedef record<DOMString, MLResource> MLNamedOutputs;
for (size_t i = 0; i < names.Length(); ++i) {
wnn::ArrayBufferView arrayBuffer = {};
std::string name = names.Get(i).As<Napi::String>().Utf8Value();
if (!GetArrayBufferView(jsNamedOutputs.Get(name), arrayBuffer)) {
return false;
}
namedOutputs[name] = {arrayBuffer, {}};
}
return true;
wnn::Graph Graph::GetImpl() {
return mImpl;
}

Napi::FunctionReference Graph::constructor;

Graph::Graph(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Graph>(info) {
}

Napi::Value Graph::Compute(const Napi::CallbackInfo& info) {
// status compute(NamedInputs inputs, NamedOutputs outputs);
WEBNN_NODE_ASSERT(info.Length() == 2, "The number of arguments is invalid.");
std::map<std::string, Input> inputs;
WEBNN_NODE_ASSERT(GetNamedInputs(info[0], inputs), "The inputs parameter is invalid.");

std::map<std::string, wnn::Resource> outputs;
WEBNN_NODE_ASSERT(GetNamedOutputs(info[1], outputs), "The outputs parameter is invalid.");

wnn::NamedInputs namedInputs = wnn::CreateNamedInputs();
for (auto& input : inputs) {
namedInputs.Set(input.first.data(), input.second.AsPtr());
}
wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs();
for (auto& output : outputs) {
namedOutputs.Set(output.first.data(), &output.second);
}
mImpl.Compute(namedInputs, namedOutputs);

return Napi::Number::New(info.Env(), 0);
}

Napi::Object Graph::Initialize(Napi::Env env, Napi::Object exports) {
Napi::HandleScope scope(env);
Napi::Function func = DefineClass(
env, "MLGraph", {InstanceMethod("compute", &Graph::Compute, napi_enumerable)});
Napi::Function func = DefineClass(env, "MLGraph", {});
constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
exports.Set("MLGraph", func);
5 changes: 2 additions & 3 deletions node/src/Graph.h
@@ -32,14 +32,13 @@ namespace node {
explicit Graph(const Napi::CallbackInfo& info);
~Graph() = default;

wnn::Graph GetImpl();

private:
friend BuildGraphWorker;
friend GraphBuilder;

Napi::Value Compute(const Napi::CallbackInfo& info);

wnn::Graph mImpl;
std::vector<std::string> mOutputNames;
};

} // namespace node
1 change: 0 additions & 1 deletion node/src/GraphBuilder.cpp
@@ -306,7 +306,6 @@ namespace node {
Napi::Object object = node::Graph::constructor.New({});
node::Graph* jsGraph = Napi::ObjectWrap<node::Graph>::Unwrap(object);
jsGraph->mImpl = graph;
jsGraph->mOutputNames = names;
return object;
}
