Skip to content

Commit

Permalink
Create and initialize Device object: the first PR for rewriting DML b…
Browse files Browse the repository at this point in the history
…ackend
  • Loading branch information
mingmingtasd committed Aug 5, 2022
1 parent 294de6b commit 9697ccc
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/webnn/native/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,19 @@ source_set("sources") {
}
}

if (webnn_enable_dml) {
sources += [
"dml/BackendDML.cpp",
"dml/BackendDML.h",
"dml/ContextDML.cpp",
"dml/ContextDML.h",
"dml/ExecutionContextDML.cpp",
"dml/ExecutionContextDML.h",
"dml/GraphDML.cpp",
"dml/GraphDML.h",
]
}

if (webnn_enable_dmlx) {
if (webnn_enable_gpu_buffer == false) {
sources += [
Expand Down
46 changes: 46 additions & 0 deletions src/webnn/native/dml/BackendDML.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2019 The Dawn Authors
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "webnn/native/dml/BackendDML.h"

#include "webnn/native/Instance.h"
#include "webnn/native/dml/ContextDML.h"

namespace webnn::native::dml {

Backend::Backend(InstanceBase* instance)
: BackendConnection(instance, wnn::BackendType::DirectML) {
}

MaybeError Backend::Initialize() {
return {};
}

ContextBase* Backend::CreateContext(ContextOptions const* options) {
return new Context(options);
}

BackendConnection* Connect(InstanceBase* instance) {
Backend* backend = new Backend(instance);

if (instance->ConsumedError(backend->Initialize())) {
delete backend;
return nullptr;
}

return backend;
}

} // namespace webnn::native::dml
38 changes: 38 additions & 0 deletions src/webnn/native/dml/BackendDML.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2019 The Dawn Authors
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef WEBNN_NATIVE_DML_BACKEND_DML_H_
#define WEBNN_NATIVE_DML_BACKEND_DML_H_

#include <memory>
#include "webnn/native/BackendConnection.h"
#include "webnn/native/Context.h"
#include "webnn/native/Error.h"

namespace webnn::native::dml {

class Backend : public BackendConnection {
public:
Backend(InstanceBase* instance);

MaybeError Initialize();
ContextBase* CreateContext(ContextOptions const* options = nullptr) override;

private:
};

} // namespace webnn::native::dml

#endif // WEBNN_NATIVE_DML_BACKEND_DML_H_
29 changes: 29 additions & 0 deletions src/webnn/native/dml/ContextDML.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "webnn/native/dml/ContextDML.h"

#include "common/RefCounted.h"
#include "webnn/native/dml/GraphDML.h"

namespace webnn::native::dml {

Context::Context(ContextOptions const* options) : ContextBase(options) {
}

GraphBase* Context::CreateGraphImpl() {
return new Graph(this);
}

} // namespace webnn::native::dml
34 changes: 34 additions & 0 deletions src/webnn/native/dml/ContextDML.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef WEBNN_NATIVE_DML_CONTEXT_DML_H_
#define WEBNN_NATIVE_DML_CONTEXT_DML_H_

#include "webnn/native/Context.h"
#include "webnn/native/Graph.h"

namespace webnn::native::dml {

class Context : public ContextBase {
public:
explicit Context(ContextOptions const* options);
~Context() override = default;

private:
GraphBase* CreateGraphImpl() override;
};

} // namespace webnn::native::dml

#endif // WEBNN_NATIVE_DML_CONTEXT_DML_H_
106 changes: 106 additions & 0 deletions src/webnn/native/dml/ExecutionContextDML.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ExecutionContextDML.h"

namespace webnn::native::dml {

// An adapter called the "Microsoft Basic Render Driver" is always present. This adapter is a
// render-only device that has no display outputs.
inline bool IsSoftwareAdapter(IDXGIAdapter1* pAdapter) {
DXGI_ADAPTER_DESC1 pDesc;
pAdapter->GetDesc1(&pDesc);
// See here for documentation on filtering WARP adapter:
// https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
return pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE ||
(pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c);
}

HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference,
bool useGpu,
ComPtr<IDXGIAdapter1> adapter) {
ComPtr<IDXGIFactory6> dxgiFactory;
RETURN_IF_FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory)));
if (useGpu) {
UINT adapterIndex = 0;
while (dxgiFactory->EnumAdapterByGpuPreference(adapterIndex++, gpuPreference,
IID_PPV_ARGS(&adapter)) !=
DXGI_ERROR_NOT_FOUND) {
if (!IsSoftwareAdapter(adapter.Get())) {
break;
}
}
} else {
RETURN_IF_FAILED(dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter)));
}
return S_OK;
}

ExecutionContext::ExecutionContext(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer)
: mAdapter(std::move(adapter)), mUseDebugLayer(useDebugLayer) {
}

// static
std::unique_ptr<ExecutionContext> ExecutionContext::Create(ComPtr<IDXGIAdapter1> adapter,
bool useDebugLayer) {
std::unique_ptr<ExecutionContext> executionContext(
new ExecutionContext(adapter, useDebugLayer));
if (FAILED(executionContext->Initialize())) {
dawn::ErrorLog() << "Failed to initialize Device.";
return nullptr;
}
return executionContext;
}

HRESULT ExecutionContext::Initialize() {
if (mUseDebugLayer) {
ComPtr<ID3D12Debug> debug;
if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug)))) {
debug->EnableDebugLayer();
}
}
RETURN_IF_FAILED(
D3D12CreateDevice(mAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&mD3D12Device)));
D3D12_COMMAND_QUEUE_DESC commandQueueDesc{};
commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
RETURN_IF_FAILED(
mD3D12Device->CreateCommandQueue(&commandQueueDesc, IID_PPV_ARGS(&mCommandQueue)));
RETURN_IF_FAILED(mD3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT,
IID_PPV_ARGS(&mCommandAllocator)));
RETURN_IF_FAILED(mD3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT,
mCommandAllocator.Get(), nullptr,
IID_PPV_ARGS(&mCommandList)));

// Create the DirectML device.
DML_CREATE_DEVICE_FLAGS dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_NONE;
#if defined(_DEBUG)
dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_DEBUG;
#endif
if (dmlCreateDeviceFlags == DML_CREATE_DEVICE_FLAG_DEBUG) {
if (FAILED(DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags,
IID_PPV_ARGS(&mDevice)))) {
dawn::WarningLog() << "Failed to create a DirectML device with debug flag, "
"will fall back to use none flag.";
RETURN_IF_FAILED(DMLCreateDevice(mD3D12Device.Get(), DML_CREATE_DEVICE_FLAG_NONE,
IID_PPV_ARGS(&mDevice)));
}
} else {
RETURN_IF_FAILED(
DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags, IID_PPV_ARGS(&mDevice)));
}
return S_OK;
};

} // namespace webnn::native::dml
63 changes: 63 additions & 0 deletions src/webnn/native/dml/ExecutionContextDML.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2022 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_
#define WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_

#include <webnn/webnn_cpp.h>
#include <unordered_map>

#include "common/Log.h"
#include "dml_platform.h"
#include "webnn/native/NamedOutputs.h"
#include "webnn/native/webnn_platform.h"

#define RETURN_IF_FAILED(EXPR) \
do { \
auto HR = EXPR; \
if (FAILED(HR)) { \
dawn::ErrorLog() << "Failed to do " << #EXPR << " Return HRESULT " << std::hex << HR; \
return HR; \
} \
} while (0)

namespace webnn::native::dml {

HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference,
bool useGpu,
ComPtr<IDXGIAdapter1> adapter);

class ExecutionContext {
public:
static std::unique_ptr<ExecutionContext> Create(ComPtr<IDXGIAdapter1> adapter,
bool useDebugLayer);

private:
ExecutionContext(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer);
HRESULT Initialize();

ComPtr<IDMLDevice> mDevice;
ComPtr<ID3D12Device> mD3D12Device;
ComPtr<IDMLCommandRecorder> mCommandRecorder;
ComPtr<ID3D12CommandQueue> mCommandQueue;
ComPtr<ID3D12CommandAllocator> mCommandAllocator;
ComPtr<ID3D12GraphicsCommandList> mCommandList;

ComPtr<IDXGIAdapter1> mAdapter;
bool mUseDebugLayer = false;
};

} // namespace webnn::native::dml

#endif // WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_
Loading

0 comments on commit 9697ccc

Please sign in to comment.