From aa537d1db67494299858cf48fe43d601b3d280ba Mon Sep 17 00:00:00 2001 From: fujunwei Date: Thu, 15 Apr 2021 09:28:56 +0800 Subject: [PATCH 1/2] Implement new spec api --- examples/LeNet/LeNet.cpp | 83 +++++----- examples/LeNet/LeNet.h | 10 +- examples/LeNet/Main.cpp | 8 +- examples/SampleUtils.cpp | 77 ++++----- examples/SampleUtils.h | 37 ++--- generator/templates/webnn.h | 14 +- generator/templates/webnn_cpp.cpp | 6 +- generator/templates/webnn_cpp.h | 3 +- .../templates/webnn_native/ProcTable.cpp | 17 +- .../webnn_native/ValidationUtils.cpp | 8 +- .../templates/webnn_native/ValidationUtils.h | 2 +- .../templates/webnn_native/webnn_structs.h | 4 +- generator/templates/webnn_proc.c | 10 +- generator/templates/webnn_proc_table.h | 1 + generator/webnn_json_generator.py | 8 +- src/common/BUILD.gn | 2 +- src/include/webnn_native/WebnnNative.h | 2 +- src/tests/BUILD.gn | 4 +- src/tests/WebnnTest.cpp | 12 +- src/tests/WebnnTest.h | 8 +- src/tests/end2end/AddTests.cpp | 51 +++--- src/tests/end2end/Conv2dTests.cpp | 68 ++++---- src/tests/end2end/MatMulTests.cpp | 135 +++++++-------- src/tests/end2end/MulTests.cpp | 49 +++--- src/tests/end2end/Pool2dTests.cpp | 104 ++++++------ src/tests/end2end/ReluTests.cpp | 13 +- src/tests/end2end/ReshapeTests.cpp | 13 +- src/tests/end2end/SoftmaxTests.cpp | 13 +- src/tests/end2end/TransposeTests.cpp | 15 +- src/tests/unittests/ObjectBaseTests.cpp | 2 +- .../validation/BinaryValidationTests.cpp | 19 +-- .../validation/Conv2dValidationTests.cpp | 43 +++-- .../validation/ErrorScopeValidationTests.cpp | 33 ++-- .../validation/GraphValidationTests.cpp | 74 +++++++++ .../validation/ModelValidationTests.cpp | 78 --------- .../validation/PoolValidationTests.cpp | 36 ++-- .../validation/ReshapeValidationTests.cpp | 10 +- .../validation/TransposeValidationTests.cpp | 22 +-- .../validation/UnaryValidationTests.cpp | 22 +-- .../unittests/validation/ValidationTest.cpp | 10 +- .../unittests/validation/ValidationTest.h | 6 +- 
src/webnn_native/BUILD.gn | 42 ++--- src/webnn_native/Compilation.cpp | 30 ---- src/webnn_native/Compilation.h | 46 ------ .../{NeuralNetworkContext.cpp => Context.cpp} | 23 +-- .../{NeuralNetworkContext.h => Context.h} | 26 +-- src/webnn_native/Error.cpp | 18 +- src/webnn_native/Error.h | 4 +- src/webnn_native/ErrorData.h | 4 +- src/webnn_native/ErrorScope.cpp | 34 ++-- src/webnn_native/ErrorScope.h | 14 +- src/webnn_native/Forward.h | 4 +- src/webnn_native/Graph.cpp | 74 +++++++++ src/webnn_native/{Model.h => Graph.h} | 29 ++-- .../{ModelBuilder.cpp => GraphBuilder.cpp} | 68 +++++--- .../{ModelBuilder.h => GraphBuilder.h} | 13 +- src/webnn_native/Model.cpp | 92 ----------- src/webnn_native/ObjectBase.cpp | 7 +- src/webnn_native/ObjectBase.h | 10 +- src/webnn_native/Operand.cpp | 15 +- src/webnn_native/Operand.h | 14 +- src/webnn_native/WebnnNative.cpp | 17 +- src/webnn_native/dml/CompilationDML.cpp | 107 ------------ src/webnn_native/dml/CompilationDML.h | 48 ------ ...alNetworkContextDML.cpp => ContextDML.cpp} | 24 ++- ...NeuralNetworkContextDML.h => ContextDML.h} | 19 ++- .../dml/{ModelDML.cpp => GraphDML.cpp} | 131 +++++++++++---- .../dml/{ModelDML.h => GraphDML.h} | 16 +- src/webnn_native/dml/ModelBuilderDML.cpp | 30 ---- src/webnn_native/dml/ModelBuilderDML.h | 33 ---- src/webnn_native/null/ContextNull.cpp | 86 ++++++++++ ...uralNetworkContextNull.h => ContextNull.h} | 60 +++---- .../null/NeuralNetworkContextNull.cpp | 102 ------------ src/webnn_native/openvino/CompilationIE.cpp | 123 -------------- src/webnn_native/openvino/CompilationIE.h | 45 ----- ...uralNetworkContextIE.cpp => ContextIE.cpp} | 19 ++- .../{ModelBuilderIE.h => ContextIE.h} | 19 ++- .../openvino/{ModelIE.cpp => GraphIE.cpp} | 155 +++++++++++++----- .../openvino/{ModelIE.h => GraphIE.h} | 22 +-- src/webnn_native/openvino/ModelBuilderIE.cpp | 29 ---- .../openvino/NeuralNetworkContextIE.h | 34 ---- src/webnn_native/ops/Binary.h | 6 +- src/webnn_native/ops/Constant.h | 6 +- 
src/webnn_native/ops/Conv2d.cpp | 6 +- src/webnn_native/ops/Conv2d.h | 6 +- src/webnn_native/ops/Input.h | 6 +- src/webnn_native/ops/Pool2d.cpp | 6 +- src/webnn_native/ops/Pool2d.h | 6 +- src/webnn_native/ops/Reshape.h | 6 +- src/webnn_native/ops/Transpose.h | 6 +- src/webnn_native/ops/Unary.h | 6 +- webnn.json | 88 +++++----- 92 files changed, 1238 insertions(+), 1728 deletions(-) create mode 100644 src/tests/unittests/validation/GraphValidationTests.cpp delete mode 100644 src/tests/unittests/validation/ModelValidationTests.cpp delete mode 100644 src/webnn_native/Compilation.cpp delete mode 100644 src/webnn_native/Compilation.h rename src/webnn_native/{NeuralNetworkContext.cpp => Context.cpp} (68%) rename src/webnn_native/{NeuralNetworkContext.h => Context.h} (66%) create mode 100644 src/webnn_native/Graph.cpp rename src/webnn_native/{Model.h => Graph.h} (72%) rename src/webnn_native/{ModelBuilder.cpp => GraphBuilder.cpp} (72%) rename src/webnn_native/{ModelBuilder.h => GraphBuilder.h} (86%) delete mode 100644 src/webnn_native/Model.cpp delete mode 100644 src/webnn_native/dml/CompilationDML.cpp delete mode 100644 src/webnn_native/dml/CompilationDML.h rename src/webnn_native/dml/{NeuralNetworkContextDML.cpp => ContextDML.cpp} (63%) rename src/webnn_native/dml/{NeuralNetworkContextDML.h => ContextDML.h} (68%) rename src/webnn_native/dml/{ModelDML.cpp => GraphDML.cpp} (84%) rename src/webnn_native/dml/{ModelDML.h => GraphDML.h} (86%) delete mode 100644 src/webnn_native/dml/ModelBuilderDML.cpp delete mode 100644 src/webnn_native/dml/ModelBuilderDML.h create mode 100644 src/webnn_native/null/ContextNull.cpp rename src/webnn_native/null/{NeuralNetworkContextNull.h => ContextNull.h} (52%) delete mode 100644 src/webnn_native/null/NeuralNetworkContextNull.cpp delete mode 100644 src/webnn_native/openvino/CompilationIE.cpp delete mode 100644 src/webnn_native/openvino/CompilationIE.h rename src/webnn_native/openvino/{NeuralNetworkContextIE.cpp => ContextIE.cpp} (70%) rename 
src/webnn_native/openvino/{ModelBuilderIE.h => ContextIE.h} (64%) rename src/webnn_native/openvino/{ModelIE.cpp => GraphIE.cpp} (58%) rename src/webnn_native/openvino/{ModelIE.h => GraphIE.h} (83%) delete mode 100644 src/webnn_native/openvino/ModelBuilderIE.cpp delete mode 100644 src/webnn_native/openvino/NeuralNetworkContextIE.h diff --git a/examples/LeNet/LeNet.cpp b/examples/LeNet/LeNet.cpp index 6b9022230..4ebc31c8e 100644 --- a/examples/LeNet/LeNet.cpp +++ b/examples/LeNet/LeNet.cpp @@ -22,10 +22,10 @@ const size_t WEIGHTS_LENGTH = 1724336; LeNet::LeNet() { - mContext = CreateCppNeuralNetworkContext(); + mContext = CreateCppContext(); mContext.SetUncapturedErrorCallback( - [](WebnnErrorType type, char const* message, void* userData) { - if (type != WebnnErrorType_NoError) { + [](MLErrorType type, char const* message, void* userData) { + if (type != MLErrorType_NoError) { dawn::ErrorLog() << "Error type is " << type << ", message is " << message; } }, @@ -48,55 +48,55 @@ bool LeNet::Load(const std::string& weigthsPath) { return false; } - const webnn::ModelBuilder builder = mContext.CreateModelBuilder(); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(mContext); uint32_t byteOffset = 0; - const webnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 28, 28}); + const ml::Operand input = utils::BuildInput(builder, "input", {1, 1, 28, 28}); const std::vector conv2d1FilterShape = {20, 1, 5, 5}; const float* conv2d1FilterData = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t conv2d1FilterDataLength = product(conv2d1FilterShape) * sizeof(float); byteOffset += conv2d1FilterDataLength; - const webnn::Operand conv2d1FilterConstant = utils::BuildConstant( + const ml::Operand conv2d1FilterConstant = utils::BuildConstant( builder, conv2d1FilterShape, conv2d1FilterData, conv2d1FilterDataLength); - const webnn::Operand conv1 = builder.Conv2d(input, conv2d1FilterConstant); + const ml::Operand conv1 = builder.Conv2d(input, 
conv2d1FilterConstant); const std::vector add1BiasShape = {1, 20, 1, 1}; const float* add1BiasData = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t add1BiasDataLength = product(add1BiasShape) * sizeof(float); byteOffset += add1BiasDataLength; - const webnn::Operand add1BiasConstant = + const ml::Operand add1BiasConstant = utils::BuildConstant(builder, add1BiasShape, add1BiasData, add1BiasDataLength); - const webnn::Operand add1 = builder.Add(conv1, add1BiasConstant); + const ml::Operand add1 = builder.Add(conv1, add1BiasConstant); utils::Pool2dOptions pool1Options; pool1Options.windowDimensions = {2, 2}; pool1Options.strides = {2, 2}; - const webnn::Operand pool1 = builder.MaxPool2d(add1, pool1Options.AsPtr()); + const ml::Operand pool1 = builder.MaxPool2d(add1, pool1Options.AsPtr()); const std::vector conv2d2FilterShape = {50, 20, 5, 5}; const float* conv2d2FilterData = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t conv2d2FilterDataLength = product(conv2d2FilterShape) * sizeof(float); byteOffset += conv2d2FilterDataLength; - const webnn::Operand conv2d2FilterConstant = utils::BuildConstant( + const ml::Operand conv2d2FilterConstant = utils::BuildConstant( builder, conv2d2FilterShape, conv2d2FilterData, conv2d2FilterDataLength); - const webnn::Operand conv2 = builder.Conv2d(pool1, conv2d2FilterConstant); + const ml::Operand conv2 = builder.Conv2d(pool1, conv2d2FilterConstant); const std::vector add2BiasShape = {1, 50, 1, 1}; const float* add2BiasData = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t add2BiasDataLength = product(add2BiasShape) * sizeof(float); byteOffset += add2BiasDataLength; - const webnn::Operand add2BiasConstant = + const ml::Operand add2BiasConstant = utils::BuildConstant(builder, add2BiasShape, add2BiasData, add2BiasDataLength); - const webnn::Operand add2 = builder.Add(conv2, add2BiasConstant); + const ml::Operand add2 = builder.Add(conv2, add2BiasConstant); utils::Pool2dOptions 
pool2Options; pool2Options.windowDimensions = {2, 2}; pool2Options.strides = {2, 2}; - const webnn::Operand pool2 = builder.MaxPool2d(add2, pool2Options.AsPtr()); + const ml::Operand pool2 = builder.MaxPool2d(add2, pool2Options.AsPtr()); const std::vector newShape = {1, -1}; - const webnn::Operand reshape1 = builder.Reshape(pool2, newShape.data(), newShape.size()); + const ml::Operand reshape1 = builder.Reshape(pool2, newShape.data(), newShape.size()); // skip the new shape, 2 int64 values byteOffset += 2 * 8; @@ -104,56 +104,47 @@ bool LeNet::Load(const std::string& weigthsPath) { const float* matmul1Data = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t matmul1DataLength = product(matmul1Shape) * sizeof(float); byteOffset += matmul1DataLength; - const webnn::Operand matmul1Weights = + const ml::Operand matmul1Weights = utils::BuildConstant(builder, matmul1Shape, matmul1Data, matmul1DataLength); - const webnn::Operand matmul1WeightsTransposed = builder.Transpose(matmul1Weights); - const webnn::Operand matmul1 = builder.Matmul(reshape1, matmul1WeightsTransposed); + const ml::Operand matmul1WeightsTransposed = builder.Transpose(matmul1Weights); + const ml::Operand matmul1 = builder.Matmul(reshape1, matmul1WeightsTransposed); const std::vector add3BiasShape = {1, 500}; const float* add3BiasData = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t add3BiasDataLength = product(add3BiasShape) * sizeof(float); byteOffset += add3BiasDataLength; - const webnn::Operand add3BiasConstant = + const ml::Operand add3BiasConstant = utils::BuildConstant(builder, add3BiasShape, add3BiasData, add3BiasDataLength); - const webnn::Operand add3 = builder.Add(matmul1, add3BiasConstant); + const ml::Operand add3 = builder.Add(matmul1, add3BiasConstant); - const webnn::Operand relu = builder.Relu(add3); + const ml::Operand relu = builder.Relu(add3); const std::vector newShape2 = {1, -1}; - const webnn::Operand reshape2 = builder.Reshape(relu, 
newShape2.data(), newShape2.size()); + const ml::Operand reshape2 = builder.Reshape(relu, newShape2.data(), newShape2.size()); const std::vector matmul2Shape = {10, 500}; const float* matmul2Data = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t matmul2DataLength = product(matmul2Shape) * sizeof(float); byteOffset += matmul2DataLength; - const webnn::Operand matmul2Weights = + const ml::Operand matmul2Weights = utils::BuildConstant(builder, matmul2Shape, matmul2Data, matmul2DataLength); - const webnn::Operand matmul2WeightsTransposed = builder.Transpose(matmul2Weights); - const webnn::Operand matmul2 = builder.Matmul(reshape2, matmul2WeightsTransposed); + const ml::Operand matmul2WeightsTransposed = builder.Transpose(matmul2Weights); + const ml::Operand matmul2 = builder.Matmul(reshape2, matmul2WeightsTransposed); const std::vector add4BiasShape = {1, 10}; const float* add4BiasData = reinterpret_cast(weightsData.get() + byteOffset); const uint32_t add4BiasDataLength = product(add4BiasShape) * sizeof(float); byteOffset += add4BiasDataLength; - const webnn::Operand add4BiasConstant = + const ml::Operand add4BiasConstant = utils::BuildConstant(builder, add4BiasShape, add4BiasData, add4BiasDataLength); - const webnn::Operand add4 = builder.Add(matmul2, add4BiasConstant); + const ml::Operand add4 = builder.Add(matmul2, add4BiasConstant); - const webnn::Operand softmax = builder.Softmax(add4); + const ml::Operand softmax = builder.Softmax(add4); - mModel = utils::CreateModel(builder, {{"output", softmax}}); - return true; -} - -bool LeNet::Compile(webnn::CompilationOptions const* options) { - if (!mModel) { - dawn::ErrorLog() << "Model is not ready."; - return false; - } const std::chrono::time_point startTime = std::chrono::high_resolution_clock::now(); - mCompilation = utils::AwaitCompile(mModel, options); - if (!mCompilation) { + mGraph = utils::AwaitBuild(builder, {{"output", softmax}}); + if (!mGraph) { return false; } const std::chrono::duration 
elapsedTime = @@ -162,16 +153,16 @@ bool LeNet::Compile(webnn::CompilationOptions const* options) { return true; } -webnn::Result LeNet::Compute(const void* inputData, size_t inputLength) { - if (!mCompilation) { - dawn::ErrorLog() << "Compilation is not ready."; - return webnn::Result(); +ml::Result LeNet::Compute(const void* inputData, size_t inputLength) { + if (!mGraph) { + dawn::ErrorLog() << "Graph is not ready."; + return ml::Result(); } const std::chrono::time_point startTime = std::chrono::high_resolution_clock::now(); - mResults = utils::AwaitCompute(mCompilation, {{"input", {inputData, inputLength}}}); + mResults = utils::AwaitCompute(mGraph, {{"input", {inputData, inputLength}}}); if (!mResults) { - return webnn::Result(); + return ml::Result(); } const std::chrono::duration elapsedTime = std::chrono::high_resolution_clock::now() - startTime; diff --git a/examples/LeNet/LeNet.h b/examples/LeNet/LeNet.h index 988a044e3..32295bb74 100644 --- a/examples/LeNet/LeNet.h +++ b/examples/LeNet/LeNet.h @@ -26,12 +26,10 @@ class LeNet { ~LeNet() = default; bool Load(const std::string& weigthsPath); - bool Compile(webnn::CompilationOptions const* options = nullptr); - webnn::Result Compute(const void* inputData, size_t inputLength); + ml::Result Compute(const void* inputData, size_t inputLength); private: - webnn::NeuralNetworkContext mContext; - webnn::Model mModel; - webnn::Compilation mCompilation; - webnn::NamedResults mResults; + ml::Context mContext; + ml::Graph mGraph; + ml::NamedResults mResults; }; diff --git a/examples/LeNet/Main.cpp b/examples/LeNet/Main.cpp index 74ce2cd0c..91e4a3cbf 100644 --- a/examples/LeNet/Main.cpp +++ b/examples/LeNet/Main.cpp @@ -44,7 +44,7 @@ void SelectTopKData(std::vector outputData, } } -void PrintResult(webnn::Result output) { +void PrintResult(ml::Result output) { const float* outputBuffer = static_cast(output.Buffer()); std::vector outputData(outputBuffer, outputBuffer + output.BufferSize() / sizeof(float)); std::vector 
topKIndex(TOP_NUMBER); @@ -116,12 +116,8 @@ int main(int argc, const char* argv[]) { dawn::ErrorLog() << "Failed to load LeNet."; return -1; } - if (!lenet.Compile()) { - dawn::ErrorLog() << "Failed to compile LeNet."; - return -1; - } std::vector input(reader.GetData().get(), reader.GetData().get() + reader.Size()); - webnn::Result result = lenet.Compute(input.data(), input.size() * sizeof(float)); + ml::Result result = lenet.Compute(input.data(), input.size() * sizeof(float)); if (!result) { dawn::ErrorLog() << "Failed to compute LeNet."; return -1; diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index 8f99963e8..2fe8b1962 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -32,14 +32,14 @@ uint32_t product(const std::vector& dims) { return prod; } -webnn::NeuralNetworkContext CreateCppNeuralNetworkContext() { +ml::Context CreateCppContext() { WebnnProcTable backendProcs = webnn_native::GetProcs(); webnnProcSetProcs(&backendProcs); - WebnnNeuralNetworkContext context = webnn_native::CreateNeuralNetworkContext(); + MLContext context = webnn_native::CreateContext(); if (context) { - return webnn::NeuralNetworkContext::Acquire(context); + return ml::Context::Acquire(context); } - return webnn::NeuralNetworkContext(); + return ml::Context(); } void DumpMemoryLeaks() { @@ -63,77 +63,72 @@ bool Expected(float output, float expected) { namespace utils { - webnn::Operand BuildInput(const webnn::ModelBuilder& builder, + ml::Operand BuildInput(const ml::GraphBuilder& builder, std::string name, const std::vector& dimensions, - webnn::OperandType type) { - webnn::OperandDescriptor desc = {type, dimensions.data(), (uint32_t)dimensions.size()}; + ml::OperandType type) { + ml::OperandDescriptor desc = {type, dimensions.data(), (uint32_t)dimensions.size()}; return builder.Input(name.c_str(), &desc); } - webnn::Operand BuildConstant(const webnn::ModelBuilder& builder, + ml::Operand BuildConstant(const ml::GraphBuilder& builder, const 
std::vector& dimensions, const void* value, size_t size, - webnn::OperandType type) { - webnn::OperandDescriptor desc = {type, dimensions.data(), (uint32_t)dimensions.size()}; + ml::OperandType type) { + ml::OperandDescriptor desc = {type, dimensions.data(), (uint32_t)dimensions.size()}; return builder.Constant(&desc, value, size); } - webnn::Model CreateModel(const webnn::ModelBuilder& builder, + ml::Graph AwaitBuild(const ml::GraphBuilder& builder, const std::vector& outputs) { - webnn::NamedOperands namedOperands = webnn::CreateNamedOperands(); - for (auto& output : outputs) { - namedOperands.Set(output.name.c_str(), output.operand); - } - return builder.CreateModel(namedOperands); - } - - webnn::Compilation AwaitCompile(const webnn::Model& model, - webnn::CompilationOptions const* options) { typedef struct { Async async; - webnn::Compilation compilation; - } CompilationData; + ml::Graph graph; + } BuildData; - CompilationData compilationData; - model.Compile( - [](WebnnCompileStatus status, WebnnCompilation impl, char const* message, + BuildData buildData; + ml::NamedOperands namedOperands = ml::CreateNamedOperands(); + for (auto& output : outputs) { + namedOperands.Set(output.name.c_str(), output.operand); + } + builder.Build(namedOperands, + [](MLBuildStatus status, MLGraph impl, char const* message, void* userData) { - CompilationData* compilationDataPtr = reinterpret_cast(userData); - DAWN_ASSERT(compilationDataPtr); - if (status != WebnnCompileStatus_Success) { - dawn::ErrorLog() << "Compile failed: " << message; + BuildData* buildDataPtr = reinterpret_cast(userData); + DAWN_ASSERT(buildDataPtr); + if (status != MLBuildStatus_Success) { + dawn::ErrorLog() << "Build failed: " << message; } else { - compilationDataPtr->compilation = compilationDataPtr->compilation.Acquire(impl); + buildDataPtr->graph = buildDataPtr->graph.Acquire(impl); } - compilationDataPtr->async.Finish(); + buildDataPtr->async.Finish(); return; }, - &compilationData, options); - 
compilationData.async.Wait(); - return compilationData.compilation; + &buildData); + buildData.async.Wait(); + return buildData.graph; } - webnn::NamedResults AwaitCompute(const webnn::Compilation& compilation, + ml::NamedResults AwaitCompute(const ml::Graph& graph, const std::vector& inputs) { typedef struct { Async async; - webnn::NamedResults results; + ml::NamedResults results; } ComputeData; ComputeData computeData; - webnn::NamedInputs namedInputs = webnn::CreateNamedInputs(); + ml::NamedInputs namedInputs = ml::CreateNamedInputs(); for (auto& input : inputs) { namedInputs.Set(input.name.c_str(), &input.input); } - compilation.Compute( + graph.Compute( namedInputs, - [](WebnnComputeStatus status, WebnnNamedResults impl, char const* message, + [](MLComputeStatus status, MLNamedResults impl, char const* message, void* userData) { ComputeData* computeDataPtr = reinterpret_cast(userData); DAWN_ASSERT(computeDataPtr); - if (status != WebnnComputeStatus_Success) { + if (status != MLComputeStatus_Success) { dawn::ErrorLog() << "Compute failed: " << message; } else { computeDataPtr->results = computeDataPtr->results.Acquire(impl); @@ -146,7 +141,7 @@ namespace utils { return computeData.results; } - bool CheckShape(const webnn::Result& result, const std::vector& expectedShape) { + bool CheckShape(const ml::Result& result, const std::vector& expectedShape) { if (expectedShape.size() != result.DimensionsSize()) { dawn::ErrorLog() << "The output rank is expected as " << expectedShape.size() << ", but got " << result.DimensionsSize(); diff --git a/examples/SampleUtils.h b/examples/SampleUtils.h index 6bf240a0f..d74d3db87 100644 --- a/examples/SampleUtils.h +++ b/examples/SampleUtils.h @@ -26,7 +26,7 @@ uint32_t product(const std::vector& dims); -webnn::NeuralNetworkContext CreateCppNeuralNetworkContext(); +ml::Context CreateCppContext(); void DumpMemoryLeaks(); @@ -34,16 +34,16 @@ bool Expected(float output, float expected); namespace utils { - webnn::Operand 
BuildInput(const webnn::ModelBuilder& builder, + ml::Operand BuildInput(const ml::GraphBuilder& builder, std::string name, const std::vector& dimensions, - webnn::OperandType type = webnn::OperandType::Float32); + ml::OperandType type = ml::OperandType::Float32); - webnn::Operand BuildConstant(const webnn::ModelBuilder& builder, + ml::Operand BuildConstant(const ml::GraphBuilder& builder, const std::vector& dimensions, const void* value, size_t size, - webnn::OperandType type = webnn::OperandType::Float32); + ml::OperandType type = ml::OperandType::Float32); struct Conv2dOptions { public: @@ -51,9 +51,9 @@ namespace utils { std::vector strides; std::vector dilations; int32_t groups = 1; - webnn::OperandLayout layout = webnn::OperandLayout::Nchw; + ml::InputOperandLayout layout = ml::InputOperandLayout::Nchw; - const webnn::Conv2dOptions* AsPtr() { + const ml::Conv2dOptions* AsPtr() { if (!padding.empty()) { mOptions.paddingCount = padding.size(); mOptions.padding = padding.data(); @@ -72,7 +72,7 @@ namespace utils { } private: - webnn::Conv2dOptions mOptions; + ml::Conv2dOptions mOptions; }; struct Pool2dOptions { @@ -81,9 +81,9 @@ namespace utils { std::vector padding; std::vector strides; std::vector dilations; - webnn::OperandLayout layout = webnn::OperandLayout::Nchw; + ml::InputOperandLayout layout = ml::InputOperandLayout::Nchw; - const webnn::Pool2dOptions* AsPtr() { + const ml::Pool2dOptions* AsPtr() { if (!windowDimensions.empty()) { mOptions.windowDimensionsCount = windowDimensions.size(); mOptions.windowDimensions = windowDimensions.data(); @@ -105,32 +105,29 @@ namespace utils { } private: - webnn::Pool2dOptions mOptions; + ml::Pool2dOptions mOptions; }; typedef struct { const std::string& name; - const webnn::Operand& operand; + const ml::Operand& operand; } NamedOutput; - webnn::Model CreateModel(const webnn::ModelBuilder& builder, + ml::Graph AwaitBuild(const ml::GraphBuilder& builder, const std::vector& outputs); - webnn::Compilation 
AwaitCompile(const webnn::Model& model, - webnn::CompilationOptions const* options = nullptr); - typedef struct { const std::string& name; - const webnn::Input& input; + const ml::Input& input; } NamedInput; - webnn::NamedResults AwaitCompute(const webnn::Compilation& compilation, + ml::NamedResults AwaitCompute(const ml::Graph& compilation, const std::vector& inputs); - bool CheckShape(const webnn::Result& result, const std::vector& expectedShape); + bool CheckShape(const ml::Result& result, const std::vector& expectedShape); template - bool CheckValue(const webnn::Result& result, const std::vector& expectedValue) { + bool CheckValue(const ml::Result& result, const std::vector& expectedValue) { size_t size = result.BufferSize() / sizeof(T); if (size != expectedValue.size()) { dawn::ErrorLog() << "The size of output data is expected as " << expectedValue.size() diff --git a/generator/templates/webnn.h b/generator/templates/webnn.h index d70fc7d04..33bcfa5b1 100644 --- a/generator/templates/webnn.h +++ b/generator/templates/webnn.h @@ -101,9 +101,10 @@ typedef void (*WebnnProc)(void); #if !defined(WEBNN_SKIP_PROCS) -typedef WebnnNamedInputs (*WebnnProcCreateNamedInputs)(); -typedef WebnnNamedOperands (*WebnnProcCreateNamedOperands)(); -typedef WebnnNamedOutputs (*WebnnProcCreateNamedOutputs)(); +typedef MLGraphBuilder (*WebnnProcCreateGraphBuilder)(MLContext context); +typedef MLNamedInputs (*WebnnProcCreateNamedInputs)(); +typedef MLNamedOperands (*WebnnProcCreateNamedOperands)(); +typedef MLNamedOutputs (*WebnnProcCreateNamedOutputs)(); {% for type in by_category["object"] if len(c_methods(type)) > 0 %} // Procs of {{type.name.CamelCase()}} @@ -121,9 +122,10 @@ typedef WebnnNamedOutputs (*WebnnProcCreateNamedOutputs)(); #if !defined(WEBNN_SKIP_DECLARATIONS) -WEBNN_EXPORT WebnnNamedInputs webnnCreateNamedInputs(); -WEBNN_EXPORT WebnnNamedOperands webnnCreateNamedOperands(); -WEBNN_EXPORT WebnnNamedOutputs webnnCreateNamedOutputs(); +WEBNN_EXPORT MLGraphBuilder 
webnnCreateGraphBuilder(MLContext context); +WEBNN_EXPORT MLNamedInputs webnnCreateNamedInputs(); +WEBNN_EXPORT MLNamedOperands webnnCreateNamedOperands(); +WEBNN_EXPORT MLNamedOutputs webnnCreateNamedOutputs(); {% for type in by_category["object"] if len(c_methods(type)) > 0 %} // Methods of {{type.name.CamelCase()}} diff --git a/generator/templates/webnn_cpp.cpp b/generator/templates/webnn_cpp.cpp index 71c413ad9..dceee7365 100644 --- a/generator/templates/webnn_cpp.cpp +++ b/generator/templates/webnn_cpp.cpp @@ -14,7 +14,7 @@ //* limitations under the License. #include "webnn/webnn_cpp.h" -namespace webnn { +namespace ml { {% for type in by_category["enum"] %} {% set CppType = as_cppType(type.name) %} {% set CType = as_cType(type.name) %} @@ -130,6 +130,10 @@ namespace webnn { } {% endfor %} + GraphBuilder CreateGraphBuilder(Context context) { + return GraphBuilder::Acquire(webnnCreateGraphBuilder(context.GetHandle())); + } + NamedInputs CreateNamedInputs() { return NamedInputs::Acquire(webnnCreateNamedInputs()); } diff --git a/generator/templates/webnn_cpp.h b/generator/templates/webnn_cpp.h index a87413e2f..a55cd6ba0 100644 --- a/generator/templates/webnn_cpp.h +++ b/generator/templates/webnn_cpp.h @@ -18,7 +18,7 @@ #include "webnn/webnn.h" #include "webnn/EnumClassBitmasks.h" -namespace webnn { +namespace ml { {% for type in by_category["enum"] %} enum class {{as_cppType(type.name)}} : uint32_t { @@ -216,6 +216,7 @@ namespace webnn { {% endfor %} + GraphBuilder CreateGraphBuilder(Context context); NamedInputs CreateNamedInputs(); NamedOperands CreateNamedOperands(); NamedOutputs CreateNamedOutputs(); diff --git a/generator/templates/webnn_native/ProcTable.cpp b/generator/templates/webnn_native/ProcTable.cpp index 344fa637e..271fc7579 100644 --- a/generator/templates/webnn_native/ProcTable.cpp +++ b/generator/templates/webnn_native/ProcTable.cpp @@ -94,19 +94,24 @@ namespace webnn_native { return result; } - WebnnNamedInputs NativeCreateNamedInputs() { - 
return reinterpret_cast(new NamedInputsBase()); + MLGraphBuilder NativeCreateGraphBuilder(MLContext context) { + return reinterpret_cast(new GraphBuilderBase(reinterpret_cast(context))); } - WebnnNamedOperands NativeCreateNamedOperands() { - return reinterpret_cast(new NamedOperandsBase()); + MLNamedInputs NativeCreateNamedInputs() { + return reinterpret_cast(new NamedInputsBase()); } - WebnnNamedOutputs NativeCreateNamedOutputs() { - return reinterpret_cast(new NamedOutputsBase()); + MLNamedOperands NativeCreateNamedOperands() { + return reinterpret_cast(new NamedOperandsBase()); + } + + MLNamedOutputs NativeCreateNamedOutputs() { + return reinterpret_cast(new NamedOutputsBase()); } static WebnnProcTable gProcTable = { + NativeCreateGraphBuilder, NativeCreateNamedInputs, NativeCreateNamedOperands, NativeCreateNamedOutputs, diff --git a/generator/templates/webnn_native/ValidationUtils.cpp b/generator/templates/webnn_native/ValidationUtils.cpp index 748a95aa9..df5132e2e 100644 --- a/generator/templates/webnn_native/ValidationUtils.cpp +++ b/generator/templates/webnn_native/ValidationUtils.cpp @@ -18,10 +18,10 @@ namespace webnn_native { {% for type in by_category["enum"] %} - MaybeError Validate{{type.name.CamelCase()}}(webnn::{{as_cppType(type.name)}} value) { + MaybeError Validate{{type.name.CamelCase()}}(ml::{{as_cppType(type.name)}} value) { switch (value) { {% for value in type.values if value.valid %} - case webnn::{{as_cppType(type.name)}}::{{as_cppEnum(value.name)}}: + case ml::{{as_cppType(type.name)}}::{{as_cppEnum(value.name)}}: return {}; {% endfor %} default: @@ -32,8 +32,8 @@ namespace webnn_native { {% endfor %} {% for type in by_category["bitmask"] %} - MaybeError Validate{{type.name.CamelCase()}}(webnn::{{as_cppType(type.name)}} value) { - if ((value & static_cast(~{{type.full_mask}})) == 0) { + MaybeError Validate{{type.name.CamelCase()}}(ml::{{as_cppType(type.name)}} value) { + if ((value & static_cast(~{{type.full_mask}})) == 0) { return {}; } 
return DAWN_VALIDATION_ERROR("Invalid value for {{as_cType(type.name)}}"); diff --git a/generator/templates/webnn_native/ValidationUtils.h b/generator/templates/webnn_native/ValidationUtils.h index 21eec25d3..103144a1c 100644 --- a/generator/templates/webnn_native/ValidationUtils.h +++ b/generator/templates/webnn_native/ValidationUtils.h @@ -24,7 +24,7 @@ namespace webnn_native { // Helper functions to check the value of enums and bitmasks {% for type in by_category["enum"] + by_category["bitmask"] %} - MaybeError Validate{{type.name.CamelCase()}}(webnn::{{as_cppType(type.name)}} value); + MaybeError Validate{{type.name.CamelCase()}}(ml::{{as_cppType(type.name)}} value); {% endfor %} } // namespace webnn_native diff --git a/generator/templates/webnn_native/webnn_structs.h b/generator/templates/webnn_native/webnn_structs.h index d3fb86eb3..322041f53 100644 --- a/generator/templates/webnn_native/webnn_structs.h +++ b/generator/templates/webnn_native/webnn_structs.h @@ -25,7 +25,7 @@ namespace webnn_native { {%- if member.annotation in ["*", "const*", "const*const*"] and member.optional -%} {{" "}}= nullptr {%- elif member.type.category in ["enum", "bitmask"] and member.default_value != None -%} - {{" "}}= webnn::{{as_cppType(member.type.name)}}::{{as_cppEnum(Name(member.default_value))}} + {{" "}}= ml::{{as_cppType(member.type.name)}}::{{as_cppEnum(Name(member.default_value))}} {%- elif member.type.category == "native" and member.default_value != None -%} {{" "}}= {{member.default_value}} {%- else -%} @@ -41,7 +41,7 @@ namespace webnn_native { {% if type.chained %} struct {{as_cppType(type.name)}} : ChainedStruct { {{as_cppType(type.name)}}() { - sType = webnn::SType::{{type.name.CamelCase()}}; + sType = ml::SType::{{type.name.CamelCase()}}; } {% else %} struct {{as_cppType(type.name)}} { diff --git a/generator/templates/webnn_proc.c b/generator/templates/webnn_proc.c index 3625cb027..7904bf530 100644 --- a/generator/templates/webnn_proc.c +++ 
b/generator/templates/webnn_proc.c @@ -27,15 +27,19 @@ void webnnProcSetProcs(const WebnnProcTable* procs_) { } } -WebnnNamedInputs webnnCreateNamedInputs() { +MLGraphBuilder webnnCreateGraphBuilder(MLContext context) { + return procs.createGraphBuilder(context); +} + +MLNamedInputs webnnCreateNamedInputs() { return procs.createNamedInputs(); } -WebnnNamedOperands webnnCreateNamedOperands() { +MLNamedOperands webnnCreateNamedOperands() { return procs.createNamedOperands(); } -WebnnNamedOutputs webnnCreateNamedOutputs() { +MLNamedOutputs webnnCreateNamedOutputs() { return procs.createNamedOutputs(); } diff --git a/generator/templates/webnn_proc_table.h b/generator/templates/webnn_proc_table.h index 32125ab60..6114b58c1 100644 --- a/generator/templates/webnn_proc_table.h +++ b/generator/templates/webnn_proc_table.h @@ -19,6 +19,7 @@ #include "webnn/webnn.h" typedef struct WebnnProcTable { + WebnnProcCreateGraphBuilder createGraphBuilder; WebnnProcCreateNamedInputs createNamedInputs; WebnnProcCreateNamedOperands createNamedOperands; WebnnProcCreateNamedOutputs createNamedOutputs; diff --git a/generator/webnn_json_generator.py b/generator/webnn_json_generator.py index 4ac3db51e..e88355f8a 100644 --- a/generator/webnn_json_generator.py +++ b/generator/webnn_json_generator.py @@ -46,7 +46,7 @@ def as_cType(name): if name.native: return name.concatcase() else: - return 'Webnn' + name.CamelCase() + return 'ML' + name.CamelCase() def as_cTypeDawn(name): @@ -119,7 +119,7 @@ def annotated(typ, arg): def as_cEnum(type_name, value_name): assert not type_name.native and not value_name.native - return 'Webnn' + type_name.CamelCase() + '_' + value_name.CamelCase() + return 'ML' + type_name.CamelCase() + '_' + value_name.CamelCase() def as_cEnumDawn(type_name, value_name): @@ -137,7 +137,7 @@ def as_cppEnum(value_name): def as_cMethod(type_name, method_name): assert not type_name.native and not method_name.native - return 'webnn' + type_name.CamelCase() + method_name.CamelCase() + 
return 'ml' + type_name.CamelCase() + method_name.CamelCase() def as_cMethodDawn(type_name, method_name): @@ -164,7 +164,7 @@ def as_frontendType(typ): if typ.category == 'object': return typ.name.CamelCase() + 'Base*' elif typ.category in ['bitmask', 'enum']: - return 'webnn::' + typ.name.CamelCase() + return 'ml::' + typ.name.CamelCase() elif typ.category == 'structure': return as_cppType(typ.name) else: diff --git a/src/common/BUILD.gn b/src/common/BUILD.gn index 94e3fdf41..06e6ac5ce 100644 --- a/src/common/BUILD.gn +++ b/src/common/BUILD.gn @@ -153,7 +153,7 @@ if (is_win || is_linux || is_chromeos || is_mac || is_fuchsia || is_android) { sources = [ "//third_party/dawn/src/common/Assert.cpp", "//third_party/dawn/src/common/Assert.h", - "//third_party/dawn/src/common/Compiler.h", + "//third_party/dawn/src/common/Compiler.h", "//third_party/dawn/src/common/Log.cpp", "//third_party/dawn/src/common/Log.h", "//third_party/dawn/src/common/Math.cpp", diff --git a/src/include/webnn_native/WebnnNative.h b/src/include/webnn_native/WebnnNative.h index 0bfa891aa..cd120c62f 100644 --- a/src/include/webnn_native/WebnnNative.h +++ b/src/include/webnn_native/WebnnNative.h @@ -27,7 +27,7 @@ namespace webnn_native { // Backend-agnostic API for webnn_native WEBNN_NATIVE_EXPORT const WebnnProcTable& GetProcs(); - WEBNN_NATIVE_EXPORT WebnnNeuralNetworkContext CreateNeuralNetworkContext(); + WEBNN_NATIVE_EXPORT MLContext CreateContext(MLContextOptions const* options = nullptr); } // namespace webnn_native diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index ad667cb74..4be968191 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -121,13 +121,13 @@ test("webnn_unittests") { sources = get_target_outputs(":mock_webnn_gen") sources += [ + "//third_party/dawn/src/tests/unittests/ResultTests.cpp", "unittests/ErrorTests.cpp", "unittests/ObjectBaseTests.cpp", - "//third_party/dawn/src/tests/unittests/ResultTests.cpp", "unittests/validation/BinaryValidationTests.cpp", 
"unittests/validation/Conv2dValidationTests.cpp", "unittests/validation/ErrorScopeValidationTests.cpp", - "unittests/validation/ModelValidationTests.cpp", + "unittests/validation/GraphValidationTests.cpp", "unittests/validation/PoolValidationTests.cpp", "unittests/validation/ReshapeValidationTests.cpp", "unittests/validation/TransposeValidationTests.cpp", diff --git a/src/tests/WebnnTest.cpp b/src/tests/WebnnTest.cpp index a7c080f36..15b9fb0ce 100644 --- a/src/tests/WebnnTest.cpp +++ b/src/tests/WebnnTest.cpp @@ -21,12 +21,12 @@ void InitWebnnEnd2EndTestEnvironment() { testing::AddGlobalTestEnvironment(gTestEnv); } -const webnn::NeuralNetworkContext& WebnnTest::GetContext() { +const ml::Context& WebnnTest::GetContext() { return gTestEnv->GetContext(); } void WebnnTest::SetUp() { - const webnn::NeuralNetworkContext& context = GetContext(); + const ml::Context& context = GetContext(); context.SetUncapturedErrorCallback(ErrorCallback, this); } @@ -50,8 +50,8 @@ std::string WebnnTest::GetLastErrorMessage() const { return mErrorMessage; } -void WebnnTest::ErrorCallback(WebnnErrorType type, char const* message, void* userdata) { - ASSERT(type != WebnnErrorType_NoError); +void WebnnTest::ErrorCallback(MLErrorType type, char const* message, void* userdata) { + ASSERT(type != MLErrorType_NoError); auto self = static_cast(userdata); self->mErrorMessage = message; @@ -61,10 +61,10 @@ void WebnnTest::ErrorCallback(WebnnErrorType type, char const* message, void* us } void WebnnTestEnvironment::SetUp() { - mContext = CreateCppNeuralNetworkContext(); + mContext = CreateCppContext(); DAWN_ASSERT(mContext); } -const webnn::NeuralNetworkContext& WebnnTestEnvironment::GetContext() { +const ml::Context& WebnnTestEnvironment::GetContext() { return mContext; } \ No newline at end of file diff --git a/src/tests/WebnnTest.h b/src/tests/WebnnTest.h index f50f54fa4..e785ac08b 100644 --- a/src/tests/WebnnTest.h +++ b/src/tests/WebnnTest.h @@ -25,13 +25,13 @@ class WebnnTest : public 
testing::Test { void SetUp() override; void TearDown() override; - const webnn::NeuralNetworkContext& GetContext(); + const ml::Context& GetContext(); void StartExpectContextError(); bool EndExpectContextError(); std::string GetLastErrorMessage() const; private: - static void ErrorCallback(WebnnErrorType type, const char* message, void* userdata); + static void ErrorCallback(MLErrorType type, const char* message, void* userdata); std::string mErrorMessage; bool mExpectError = false; bool mError = false; @@ -43,10 +43,10 @@ class WebnnTestEnvironment : public testing::Environment { public: void SetUp() override; - const webnn::NeuralNetworkContext& GetContext(); + const ml::Context& GetContext(); protected: - webnn::NeuralNetworkContext mContext; + ml::Context mContext; }; #endif // TESTS_WEBNN_TEST_H_ diff --git a/src/tests/end2end/AddTests.cpp b/src/tests/end2end/AddTests.cpp index 8a2bf620c..77ede8af3 100644 --- a/src/tests/end2end/AddTests.cpp +++ b/src/tests/end2end/AddTests.cpp @@ -17,8 +17,8 @@ class AddTests : public WebnnTest {}; TEST_F(AddTests, AddConstantAndInput) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); const std::vector bData = { -0.5781865, -0.49248728, -0.2162451, -0.13176449, -0.52118045, 1.9125274, 0.6508799, 0.71873736, -2.3154447, 0.8080079, 0.3022368, 0.21394566, -0.6511544, 0.20001237, @@ -29,11 +29,10 @@ TEST_F(AddTests, AddConstantAndInput) { 0.48171726, 0.34308678, -0.90550417, 0.203841, 0.02521433, -1.7966009, -1.4287543, 0.3222213, 1.0590587, -1.7948701, -1.7195907, -0.9120889, -0.9391962, -0.2566791, -0.5464537, 1.4351872, 0.5705938, -0.30327085}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {3, 4, 5}, bData.data(), bData.size() * sizeof(float)); - const 
webnn::Operand c = builder.Add(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Add(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector dataA = { 0.08939514, -1.5887482, 0.8545348, 0.20523034, -0.41728342, 1.01752, 0.19677015, 0.5398451, 0.56893295, 1.2511084, 2.0092728, 1.0606714, 0.4893267, 0.09536829, @@ -44,8 +43,8 @@ TEST_F(AddTests, AddConstantAndInput) { -0.03482966, -0.7343786, -0.76851964, 0.9446942, -0.35489243, 0.44452578, 0.00648887, -0.55656946, -0.735903, 0.22050636, -0.5008282, -1.3132697, 1.6642882, -0.48397836, 0.20099205, -0.28786168, 1.3315053, -0.41619393}; - const webnn::Input input = {dataA.data(), dataA.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {dataA.data(), dataA.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); const std::vector expectedValue( {-0.48879138, -2.0812354, 0.6382897, 0.07346585, -0.93846387, 2.9300475, 0.84765005, @@ -61,12 +60,11 @@ TEST_F(AddTests, AddConstantAndInput) { } TEST_F(AddTests, AddTwoInputs) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); - const webnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); - const webnn::Operand c = builder.Add(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + const ml::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); + const ml::Operand c = builder.Add(a, b); + const 
ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector dataA = { 0.08939514, -1.5887482, 0.8545348, 0.20523034, -0.41728342, 1.01752, 0.19677015, 0.5398451, 0.56893295, 1.2511084, 2.0092728, 1.0606714, 0.4893267, 0.09536829, @@ -77,7 +75,7 @@ TEST_F(AddTests, AddTwoInputs) { -0.03482966, -0.7343786, -0.76851964, 0.9446942, -0.35489243, 0.44452578, 0.00648887, -0.55656946, -0.735903, 0.22050636, -0.5008282, -1.3132697, 1.6642882, -0.48397836, 0.20099205, -0.28786168, 1.3315053, -0.41619393}; - const webnn::Input inputA = {dataA.data(), dataA.size() * sizeof(float)}; + const ml::Input inputA = {dataA.data(), dataA.size() * sizeof(float)}; const std::vector dataB = { -0.5781865, -0.49248728, -0.2162451, -0.13176449, -0.52118045, 1.9125274, 0.6508799, 0.71873736, -2.3154447, 0.8080079, 0.3022368, 0.21394566, -0.6511544, 0.20001237, @@ -88,9 +86,8 @@ TEST_F(AddTests, AddTwoInputs) { 0.48171726, 0.34308678, -0.90550417, 0.203841, 0.02521433, -1.7966009, -1.4287543, 0.3222213, 1.0590587, -1.7948701, -1.7195907, -0.9120889, -0.9391962, -0.2566791, -0.5464537, 1.4351872, 0.5705938, -0.30327085}; - const webnn::Input inputB = {dataB.data(), dataB.size() * sizeof(float)}; - const webnn::Result result = - utils::AwaitCompute(compiledModel, {{"a", inputA}, {"b", inputB}}).Get("c"); + const ml::Input inputB = {dataB.data(), dataB.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", inputA}, {"b", inputB}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); const std::vector expectedValue( {-0.48879138, -2.0812354, 0.6382897, 0.07346585, -0.93846387, 2.9300475, 0.84765005, @@ -106,12 +103,11 @@ TEST_F(AddTests, AddTwoInputs) { } TEST_F(AddTests, AddBroadcast) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); - const webnn::Operand b = utils::BuildInput(builder, "b", {5}); - const webnn::Operand c = builder.Add(a, b); - 
const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + const ml::Operand b = utils::BuildInput(builder, "b", {5}); + const ml::Operand c = builder.Add(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector dataA = { -0.08539673, 0.11800674, -1.2358714, 0.30089188, -0.73443925, 1.4894297, 0.16823359, -2.2034893, 1.0740992, -0.35457978, 0.61524934, 0.462153, 0.5992003, -0.81047946, @@ -122,13 +118,12 @@ TEST_F(AddTests, AddBroadcast) { 0.3092435, -1.311751, -0.6659017, 0.8815683, -0.31157655, 0.57511795, -1.1924151, -1.8408557, -0.85080767, -1.3341717, 0.54687303, -0.14426671, -0.15728855, 0.323939, 1.167636, 0.03020451, 0.91373825, 1.0675793}; - const webnn::Input inputA = {dataA.data(), dataA.size() * sizeof(float)}; + const ml::Input inputA = {dataA.data(), dataA.size() * sizeof(float)}; const std::vector dataB = { 0.6338172, 1.630534, -1.3819867, -1.0427561, 1.058136, }; - const webnn::Input inputB = {dataB.data(), dataB.size() * sizeof(float)}; - const webnn::Result result = - utils::AwaitCompute(compiledModel, {{"a", inputA}, {"b", inputB}}).Get("c"); + const ml::Input inputB = {dataB.data(), dataB.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", inputA}, {"b", inputB}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); const std::vector expectedValue( {0.5484205, 1.7485408, -2.6178582, -0.7418642, 0.32369673, 2.123247, 1.7987677, diff --git a/src/tests/end2end/Conv2dTests.cpp b/src/tests/end2end/Conv2dTests.cpp index 3db69b247..a85824d79 100644 --- a/src/tests/end2end/Conv2dTests.cpp +++ b/src/tests/end2end/Conv2dTests.cpp @@ -17,20 +17,19 @@ class Conv2dTests : public WebnnTest {}; TEST_F(Conv2dTests, Conv2dWithPadding) { - const webnn::ModelBuilder 
builder = GetContext().CreateModelBuilder(); - const webnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand input = utils::BuildInput(builder, "input", {1, 1, 5, 5}); const std::vector filterData(9, 1); - const webnn::Operand filter = utils::BuildConstant(builder, {1, 1, 3, 3}, filterData.data(), - filterData.size() * sizeof(float)); + const ml::Operand filter = utils::BuildConstant(builder, {1, 1, 3, 3}, filterData.data(), + filterData.size() * sizeof(float)); utils::Conv2dOptions options; options.padding = {1, 1, 1, 1}; - const webnn::Operand output = builder.Conv2d(input, filter, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"output", output}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand output = builder.Conv2d(input, filter, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"output", output}}); const std::vector inputData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; - const webnn::Result result = - utils::AwaitCompute(compiledModel, + const ml::Result result = + utils::AwaitCompute(graph, {{"input", {inputData.data(), inputData.size() * sizeof(float)}}}) .Get("output"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 5, 5})); @@ -41,18 +40,17 @@ TEST_F(Conv2dTests, Conv2dWithPadding) { } TEST_F(Conv2dTests, Conv2dWithoutPadding) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand input = utils::BuildInput(builder, "input", {1, 1, 5, 5}); const std::vector filterData(9, 1); - const webnn::Operand filter = utils::BuildConstant(builder, {1, 1, 3, 3}, filterData.data(), - filterData.size() * sizeof(float)); - const 
webnn::Operand output = builder.Conv2d(input, filter); - const webnn::Model model = utils::CreateModel(builder, {{"output", output}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand filter = utils::BuildConstant(builder, {1, 1, 3, 3}, filterData.data(), + filterData.size() * sizeof(float)); + const ml::Operand output = builder.Conv2d(input, filter); + const ml::Graph graph = utils::AwaitBuild(builder, {{"output", output}}); const std::vector inputData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; - const webnn::Result result = - utils::AwaitCompute(compiledModel, + const ml::Result result = + utils::AwaitCompute(graph, {{"input", {inputData.data(), inputData.size() * sizeof(float)}}}) .Get("output"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 3, 3})); @@ -61,22 +59,21 @@ TEST_F(Conv2dTests, Conv2dWithoutPadding) { } TEST_F(Conv2dTests, Conv2dWithStrides2AndPadding) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 7, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand input = utils::BuildInput(builder, "input", {1, 1, 7, 5}); const std::vector filterData(9, 1); - const webnn::Operand filter = utils::BuildConstant(builder, {1, 1, 3, 3}, filterData.data(), - filterData.size() * sizeof(float)); + const ml::Operand filter = utils::BuildConstant(builder, {1, 1, 3, 3}, filterData.data(), + filterData.size() * sizeof(float)); utils::Conv2dOptions options; options.padding = {1, 1, 1, 1}; options.strides = {2, 2}; - const webnn::Operand output = builder.Conv2d(input, filter, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"output", output}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand output = builder.Conv2d(input, filter, options.AsPtr()); + const ml::Graph graph 
= utils::AwaitBuild(builder, {{"output", output}}); const std::vector inputData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}; - const webnn::Result result = - utils::AwaitCompute(compiledModel, + const ml::Result result = + utils::AwaitCompute(graph, {{"input", {inputData.data(), inputData.size() * sizeof(float)}}}) .Get("output"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 4, 3})); @@ -86,21 +83,20 @@ TEST_F(Conv2dTests, Conv2dWithStrides2AndPadding) { } TEST_F(Conv2dTests, Conv2dWithStrides2AndAsymetricPadding) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand input = utils::BuildInput(builder, "input", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand input = utils::BuildInput(builder, "input", {1, 1, 5, 5}); const std::vector filterData(8, 1); - const webnn::Operand filter = utils::BuildConstant(builder, {1, 1, 4, 2}, filterData.data(), - filterData.size() * sizeof(float)); + const ml::Operand filter = utils::BuildConstant(builder, {1, 1, 4, 2}, filterData.data(), + filterData.size() * sizeof(float)); utils::Conv2dOptions options; options.padding = {1, 2, 0, 1}; options.strides = {2, 2}; - const webnn::Operand output = builder.Conv2d(input, filter, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"output", output}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand output = builder.Conv2d(input, filter, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"output", output}}); const std::vector inputData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; - const webnn::Result result = - utils::AwaitCompute(compiledModel, + const ml::Result result = + utils::AwaitCompute(graph, {{"input", {inputData.data(), inputData.size() * 
sizeof(float)}}}) .Get("output"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 3, 3})); diff --git a/src/tests/end2end/MatMulTests.cpp b/src/tests/end2end/MatMulTests.cpp index 2992b174b..8c5c26fa2 100644 --- a/src/tests/end2end/MatMulTests.cpp +++ b/src/tests/end2end/MatMulTests.cpp @@ -17,78 +17,74 @@ class MatMulTests : public WebnnTest {}; TEST_F(MatMulTests, MatMul1d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {4}); const std::vector bData = {0.8782074, 0.22533207, 0.7134056, 0.04190519}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = {0.9025404, 0.89538723, 0.16789329, 0.7440875}; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {1})); const std::vector expectedValue = {1.1453342}; EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(MatMulTests, MatMul1dx2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, 
"a", {4}); const std::vector bData = { 0.3093976, -1.2924036, -0.64339244, 1.1423386, 1.5052135, 1.8182521, -1.825652, -0.39694095, -0.90111053, 0.7807154, -1.9163561, -0.13988003, }; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = {0.1309212, 0.9090703, 0.62183434, 0.9195683}; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {1, 3})); const std::vector expectedValue = {0.6616409, -0.80990994, 0.8797145}; EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(MatMulTests, MatMul2dx1d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4}); const std::vector bData = {0.25528687, 0.2126722, 0.26320502, 0.8297401}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, 
{{"c", c}}); const std::vector aData = { 0.3582649, 0.83665735, 0.30253866, 0.6446781, 0.4684662, 0.94761264, 0.4122941, 0.6787481, 0.15072346, 0.2820577, 0.67296237, 0.3856028, }; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 1})); const std::vector expectedValue = {0.8839391, 0.9928265, 0.5955407}; EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(MatMulTests, MatMul2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4}); const std::vector bData = {0.17467105, -1.2045133, -0.02621938, 0.6096196, 1.4499376, 1.3465316, 0.03289436, 1.0754977, -0.61485314, 0.94857556, -0.36462623, 1.402278}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = {0.9602246, 0.97682184, -0.33201018, 0.8248904, 0.40872088, 0.18995902, 0.69355214, -0.37210146, 0.18104352, 3.270753, -0.803097, -0.7268995}; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; 
+ const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 3})); const std::vector expectedValue = {1.5347629, -0.3981255, 2.6510081, -0.14295794, 0.6647107, -0.70315295, @@ -97,26 +93,25 @@ TEST_F(MatMulTests, MatMul2d) { } TEST_F(MatMulTests, MatMul3d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); const std::vector bData = {-2.7142005, 0.41909233, 0.80572236, 0.19983047, -1.9361104, 1.1919757, 0.61684674, 0.23732206, 0.74679494, 0.4595843, -0.90667343, 0.7676448, 0.48643762, 0.41120672, 1.1319419, 1.9692143, -0.44463134, 0.17005378, 1.1589569, -0.4333597, -0.47976026, 0.01067371, -0.79455626, -1.4024538}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {2, 4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = { 0.19521078, 0.11637875, 0.54684865, 0.13257395, -0.05654722, -0.64351636, -1.0019655, -1.6156989, 0.01625126, 1.2386297, -0.1242797, 0.40350053, -0.5883816, 0.93452644, -0.01409106, -0.7825521, -1.2281458, -1.2388189, 0.7644939, -0.8567167, 0.3942727, -0.772506, -0.06412488, -0.9848109, }; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); 
EXPECT_TRUE(utils::CheckShape(result, {2, 3, 3})); const std::vector expectedValue = { -0.10833447, -0.13393278, 0.8061598, -1.3357227, 2.449343, -2.801163, @@ -127,23 +122,22 @@ TEST_F(MatMulTests, MatMul3d) { } TEST_F(MatMulTests, MatMul3dx2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {2, 3, 4}); const std::vector bData = {-0.38534147, -0.18395364, -2.548874, 0.4525641, -0.41875792, 0.57480955, -0.41603103, 0.6973883, 0.9531734, 1.3292471, -1.003955, -0.7639869}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = { -0.57675153, -0.40231872, 0.10705414, -0.66516143, 0.3206562, 0.43695804, -1.8614748, 0.77510875, -1.2424866, -0.58930343, 0.40949076, 0.5517746, 0.09809388, 0.5084747, 0.76594603, 0.8050488, -0.03979152, 2.4019558, -0.54937273, -0.1696853, -1.223669, 1.0791223, -0.61921734, 2.1074235}; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {2, 3, 3})); const std::vector expectedValue = { -0.8885305, 1.0170201, 1.8490261, 1.8789318, -2.3183105, -2.9326258, @@ -153,21 +147,20 @@ TEST_F(MatMulTests, MatMul3dx2d) { } 
TEST_F(MatMulTests, MatMul3dx2dGet3d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {1, 3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {1, 3, 4}); const std::vector bData = {0.2545374, -1.6150205, -0.64508885, -0.3454305, 0.38700557, 1.3147515, -0.3379386, 1.1804152, 1.9414345, -1.5912915, 0.40443325, -0.23596671}; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = {0.25500464, -1.105212, -0.5368534, -0.01583702, 0.9875369, 1.3744136, 0.61079186, 0.74018836, -0.56111795, -0.16432828, 1.3176169, -0.249416}; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {1, 3, 3})); const std::vector expectedValue = { 0.6533069, -1.4796758, -2.6561086, -1.607665, -0.04264185, @@ -177,27 +170,26 @@ TEST_F(MatMulTests, MatMul3dx2dGet3d) { } TEST_F(MatMulTests, MatMul4d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); const std::vector bData = { -0.45605758, 
-0.43318668, 0.61509126, -2.2228749, 0.50257015, -0.29311436, -0.64561933, -0.6439757, 1.6211574, -0.28852704, -0.46247238, 0.5082442, 1.2357981, -0.82043344, -0.926581, -0.8955289, 0.74586314, -0.8022598, -0.5360306, -0.08719682, 0.72717273, 1.1277325, 2.0261378, -1.4311641, }; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {1, 2, 4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = { -0.8074054, -0.72524256, 0.4510249, 1.6203358, 1.9851393, 0.501528, 1.3975041, -2.3231244, 0.70866925, 0.24667543, -0.6271161, -0.9634111, -0.5911732, -0.09888726, -1.0926677, 0.47262478, 0.6141726, -0.634484, -0.07425678, -1.2638812, -1.1002079, -1.5324054, -1.1643038, -0.05644368, }; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {1, 2, 3, 3})); const std::vector expectedValue = { 1.2216457, -1.0545375, 1.2706597, -2.2521434, -0.4334606, 2.1588962, @@ -208,25 +200,24 @@ TEST_F(MatMulTests, MatMul4d) { } TEST_F(MatMulTests, MatMul4dx2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {1, 2, 3, 4}); const std::vector bData = { 0.01829041, -0.73948264, -0.95898634, -0.5105271, 2.1705306, 1.2495605, 
-1.9865801, -0.58367056, -0.80371356, -0.583849, -1.2323712, 1.3314632, }; - const webnn::Operand b = + const ml::Operand b = utils::BuildConstant(builder, {4, 3}, bData.data(), bData.size() * sizeof(float)); - const webnn::Operand c = builder.Matmul(a, b); - const webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand c = builder.Matmul(a, b); + const ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); const std::vector aData = { -0.40162078, -0.5607968, -1.4350457, -0.22855183, -0.1357853, -1.3434876, 1.0602195, -0.17137937, 0.44751146, 0.78427273, -0.49435133, -0.9062699, -0.6109297, 0.645001, 0.6632162, 0.903104, 2.4085212, 0.7805757, -0.9099179, -0.6195976, 0.38710263, 0.5102191, -0.03610202, 1.2280966, }; - const webnn::Input input = {aData.data(), aData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + const ml::Input input = {aData.data(), aData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {1, 2, 3, 3})); const std::vector expectedValue = { 3.2632291, 0.19901966, 0.5334567, -1.3227482, -3.223286, -2.628851, diff --git a/src/tests/end2end/MulTests.cpp b/src/tests/end2end/MulTests.cpp index cbd69d0ed..583f2a0da 100644 --- a/src/tests/end2end/MulTests.cpp +++ b/src/tests/end2end/MulTests.cpp @@ -17,8 +17,8 @@ class MulTests : public WebnnTest {}; TEST_F(MulTests, MulInputAndConstant) { - webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); std::vector dataB = { 2.0435283, 0.07213961, -1.1644137, -1.2209045, 0.8982674, 0.21796915, 0.27658972, 0.7744382, -0.52159035, -0.969913, 0.6081186, 
-0.04225572, 0.3275312, -0.06443629, @@ -29,11 +29,10 @@ TEST_F(MulTests, MulInputAndConstant) { 0.78826934, -0.18788454, 0.38178417, 0.9748209, 1.0242884, 0.7939937, 0.24449475, -1.3840157, 1.9665064, 0.35833818, -0.87076694, -0.76727265, 0.6157508, -0.5558823, 0.18417479, -0.93904793, -0.00859687, 0.5034271}; - webnn::Operand b = + ml::Operand b = utils::BuildConstant(builder, {3, 4, 5}, dataB.data(), dataB.size() * sizeof(float)); - webnn::Operand c = builder.Mul(a, b); - webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - webnn::Compilation compiledModel = utils::AwaitCompile(model); + ml::Operand c = builder.Mul(a, b); + ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); std::vector dataA = { 5.6232101e-01, 1.3117781e-01, -1.4161869e+00, 2.0386910e-02, 9.1077393e-01, 7.4952751e-01, -2.8509337e-01, -1.6272701e+00, 1.0271618e+00, 4.2815253e-01, @@ -47,8 +46,8 @@ TEST_F(MulTests, MulInputAndConstant) { 3.4815140e-04, -5.6024802e-01, 1.0848801e+00, -5.1780093e-01, -3.8996863e-01, 5.3133094e-01, 2.3897937e-01, -1.3832775e+00, 6.3414145e-01, 1.0691971e+00, 5.7040757e-01, 3.0711100e-01, 8.8405716e-01, -2.1583509e+00, 4.3243581e-01}; - webnn::Input input = {dataA.data(), dataA.size() * sizeof(float)}; - webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + ml::Input input = {dataA.data(), dataA.size() * sizeof(float)}; + ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); std::vector expectedData = { 1.1491189e+00, 9.4631165e-03, 1.6490275e+00, -2.4890469e-02, 8.1811851e-01, @@ -67,12 +66,11 @@ TEST_F(MulTests, MulInputAndConstant) { } TEST_F(MulTests, MulTwoInputs) { - webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); - webnn::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); - webnn::Operand c = builder.Mul(a, b); - webnn::Model model = 
utils::CreateModel(builder, {{"c", c}}); - webnn::Compilation compiledModel = utils::AwaitCompile(model); + ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + ml::Operand b = utils::BuildInput(builder, "b", {3, 4, 5}); + ml::Operand c = builder.Mul(a, b); + ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); std::vector dataA = { 5.6232101e-01, 1.3117781e-01, -1.4161869e+00, 2.0386910e-02, 9.1077393e-01, 7.4952751e-01, -2.8509337e-01, -1.6272701e+00, 1.0271618e+00, 4.2815253e-01, @@ -96,10 +94,9 @@ TEST_F(MulTests, MulTwoInputs) { 0.78826934, -0.18788454, 0.38178417, 0.9748209, 1.0242884, 0.7939937, 0.24449475, -1.3840157, 1.9665064, 0.35833818, -0.87076694, -0.76727265, 0.6157508, -0.5558823, 0.18417479, -0.93904793, -0.00859687, 0.5034271}; - webnn::Input inputA = {dataA.data(), dataA.size() * sizeof(float)}; - webnn::Input inputB = {dataB.data(), dataB.size() * sizeof(float)}; - webnn::Result result = - utils::AwaitCompute(compiledModel, {{"a", inputA}, {"b", inputB}}).Get("c"); + ml::Input inputA = {dataA.data(), dataA.size() * sizeof(float)}; + ml::Input inputB = {dataB.data(), dataB.size() * sizeof(float)}; + ml::Result result = utils::AwaitCompute(graph, {{"a", inputA}, {"b", inputB}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); std::vector expectedData = { 1.1491189e+00, 9.4631165e-03, 1.6490275e+00, -2.4890469e-02, 8.1811851e-01, @@ -118,16 +115,14 @@ TEST_F(MulTests, MulTwoInputs) { } TEST_F(MulTests, MulBroadcast) { - webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); std::vector dataB = { 0.6338172, 1.630534, -1.3819867, -1.0427561, 1.058136, }; - webnn::Operand b = - utils::BuildConstant(builder, {5}, dataB.data(), dataB.size() * sizeof(float)); - 
webnn::Operand c = builder.Mul(a, b); - webnn::Model model = utils::CreateModel(builder, {{"c", c}}); - webnn::Compilation compiledModel = utils::AwaitCompile(model); + ml::Operand b = utils::BuildConstant(builder, {5}, dataB.data(), dataB.size() * sizeof(float)); + ml::Operand c = builder.Mul(a, b); + ml::Graph graph = utils::AwaitBuild(builder, {{"c", c}}); std::vector dataA = { -0.08539673, 0.11800674, -1.2358714, 0.30089188, -0.73443925, 1.4894297, 0.16823359, -2.2034893, 1.0740992, -0.35457978, 0.61524934, 0.462153, 0.5992003, -0.81047946, @@ -139,8 +134,8 @@ TEST_F(MulTests, MulBroadcast) { -1.8408557, -0.85080767, -1.3341717, 0.54687303, -0.14426671, -0.15728855, 0.323939, 1.167636, 0.03020451, 0.91373825, 1.0675793, }; - webnn::Input input = {dataA.data(), dataA.size() * sizeof(float)}; - webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("c"); + ml::Input input = {dataA.data(), dataA.size() * sizeof(float)}; + ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("c"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); std::vector expectedData = { -0.05412592, 0.192414, 1.707958, -0.31375682, -0.7771366, 0.9440262, 0.2743106, diff --git a/src/tests/end2end/Pool2dTests.cpp b/src/tests/end2end/Pool2dTests.cpp index 923e85180..6049abb3c 100644 --- a/src/tests/end2end/Pool2dTests.cpp +++ b/src/tests/end2end/Pool2dTests.cpp @@ -17,51 +17,48 @@ class Pool2dTests : public WebnnTest {}; TEST_F(Pool2dTests, MaxPool2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; - const webnn::Operand y = builder.MaxPool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const 
webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.MaxPool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 2, 2})); const std::vector expectedValue({11, 12, 15, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(Pool2dTests, MaxPool2dDilations) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; options.dilations = {2, 2}; - const webnn::Operand y = builder.MaxPool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.MaxPool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 2, 2})); const 
std::vector expectedValue({11, 12, 15, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(Pool2dTests, MaxPool2dPads) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; options.padding = {2, 2, 2, 2}; - const webnn::Operand y = builder.MaxPool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.MaxPool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 5, 5})); const std::vector expectedValue({13, 14, 15, 15, 15, 18, 19, 20, 20, 20, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25, 23, 24, 25, 25, 25}); @@ -69,52 +66,49 @@ TEST_F(Pool2dTests, MaxPool2dPads) { } TEST_F(Pool2dTests, MaxPool2dStrides) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; options.strides = {2, 2}; - const webnn::Operand 
y = builder.MaxPool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.MaxPool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 2, 2})); const std::vector expectedValue({7, 9, 17, 19}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(Pool2dTests, AveragePool2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 1, 4, 4}); utils::Pool2dOptions options; options.windowDimensions = {3, 3}; - const webnn::Operand y = builder.AveragePool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.AveragePool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result 
result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 2, 2})); const std::vector expectedValue({6, 7, 10, 11}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(Pool2dTests, AveragePool2dPads) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {5, 5}; options.padding = {2, 2, 2, 2}; - const webnn::Operand y = builder.AveragePool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.AveragePool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 5, 5})); const std::vector expectedValue({7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, 15, 15.5, @@ -123,29 +117,27 @@ TEST_F(Pool2dTests, AveragePool2dPads) { } TEST_F(Pool2dTests, AveragePool2dStrides) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 1, 5, 5}); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 
1, 5, 5}); utils::Pool2dOptions options; options.windowDimensions = {2, 2}; options.strides = {2, 2}; - const webnn::Operand y = builder.AveragePool2d(x, options.AsPtr()); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::Operand y = builder.AveragePool2d(x, options.AsPtr()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 1, 2, 2})); const std::vector expectedValue({4, 6, 14, 16}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } TEST_F(Pool2dTests, GlobalAveragePool2d) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand x = utils::BuildInput(builder, "x", {1, 3, 5, 5}); - const webnn::Operand y = builder.AveragePool2d(x); - const webnn::Model model = utils::CreateModel(builder, {{"y", y}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand x = utils::BuildInput(builder, "x", {1, 3, 5, 5}); + const ml::Operand y = builder.AveragePool2d(x); + const ml::Graph graph = utils::AwaitBuild(builder, {{"y", y}}); const std::vector dataX = { -1.1289884, 0.34016284, 0.497431, 2.1915932, 0.42038894, -0.18261199, -0.15769927, -0.26465914, 0.03877424, 0.39492005, -0.33410737, 0.74918455, -1.3542547, -0.0222946, @@ -158,8 +150,8 @@ TEST_F(Pool2dTests, GlobalAveragePool2d) { -0.9641269, 0.6065926, -0.5830042, 
-0.81138134, 1.3569402, 1.2891295, 0.2508177, 0.20211531, 0.8832168, -0.19886094, -0.61088, 0.682026, -0.5253442, 1.5022339, 1.0256356, 1.0642492, -0.4169051, -0.8740329, 1.1494869}; - const webnn::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"x", inputX}}).Get("y"); + const ml::Input inputX = {dataX.data(), dataX.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"x", inputX}}).Get("y"); EXPECT_TRUE(utils::CheckShape(result, {1, 3, 1, 1})); const std::vector expectedValue({0.07170041, 0.05194739, 0.07117923}); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); diff --git a/src/tests/end2end/ReluTests.cpp b/src/tests/end2end/ReluTests.cpp index f59dd0666..f6f01e5be 100644 --- a/src/tests/end2end/ReluTests.cpp +++ b/src/tests/end2end/ReluTests.cpp @@ -17,11 +17,10 @@ class ReluTests : public WebnnTest {}; TEST_F(ReluTests, Relu) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); - const webnn::Operand b = builder.Relu(a); - const webnn::Model model = utils::CreateModel(builder, {{"b", b}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4, 5}); + const ml::Operand b = builder.Relu(a); + const ml::Graph graph = utils::AwaitBuild(builder, {{"b", b}}); const std::vector inputData = { -1.483762, 0.6447428, -1.2266507, -1.7132527, 0.9777725, -0.34438756, -0.99921757, -1.2882805, 1.3725083, -0.06386258, -0.44738683, -0.6776338, 0.5027815, -1.0428967, @@ -32,8 +31,8 @@ TEST_F(ReluTests, Relu) { 0.7993208, -0.31359985, 0.9019325, -0.02042965, 0.5222995, 1.3394557, -1.0482218, 1.1774449, 0.8999488, -1.1143959, 1.0122099, -0.48604885, -0.06009902, -0.1766853, 1.4515465, -0.7182982, 2.0361354, 0.7899623}; - const 
webnn::Input input = {inputData.data(), inputData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("b"); + const ml::Input input = {inputData.data(), inputData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("b"); EXPECT_TRUE(utils::CheckShape(result, {3, 4, 5})); const std::vector expectedData( {0., 0.6447428, 0., 0., 0.9777725, 0., 0., 0., diff --git a/src/tests/end2end/ReshapeTests.cpp b/src/tests/end2end/ReshapeTests.cpp index 75559dd78..3f5da0351 100644 --- a/src/tests/end2end/ReshapeTests.cpp +++ b/src/tests/end2end/ReshapeTests.cpp @@ -19,15 +19,14 @@ class ReshapeTests : public WebnnTest { void TestReshape(const std::vector& oldShape, const std::vector& newShape, const std::vector& expectedShape) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", oldShape); - const webnn::Operand b = builder.Reshape(a, newShape.data(), newShape.size()); - const webnn::Model model = utils::CreateModel(builder, {{"b", b}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", oldShape); + const ml::Operand b = builder.Reshape(a, newShape.data(), newShape.size()); + const ml::Graph graph = utils::AwaitBuild(builder, {{"b", b}}); const std::vector inputData = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; - const webnn::Input input = {inputData.data(), inputData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("b"); + const ml::Input input = {inputData.data(), inputData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("b"); EXPECT_TRUE(utils::CheckShape(result, expectedShape)); 
EXPECT_TRUE(utils::CheckValue(result, inputData)); } diff --git a/src/tests/end2end/SoftmaxTests.cpp b/src/tests/end2end/SoftmaxTests.cpp index 590dde954..75955607d 100644 --- a/src/tests/end2end/SoftmaxTests.cpp +++ b/src/tests/end2end/SoftmaxTests.cpp @@ -17,16 +17,15 @@ class SoftmaxTests : public WebnnTest {}; TEST_F(SoftmaxTests, Softmax) { - const webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", {3, 4}); - const webnn::Operand b = builder.Softmax(a); - const webnn::Model model = utils::CreateModel(builder, {{"b", b}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", {3, 4}); + const ml::Operand b = builder.Softmax(a); + const ml::Graph graph = utils::AwaitBuild(builder, {{"b", b}}); const std::vector inputData = {0.4301911, 0.54719144, -1.1637765, 0.18390046, 0.58390397, 0.1735679, 0.539724, -0.953514, -0.59202826, -0.17344485, 0.14395015, -0.37920907}; - const webnn::Input input = {inputData.data(), inputData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("b"); + const ml::Input input = {inputData.data(), inputData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("b"); EXPECT_TRUE(utils::CheckShape(result, {3, 4})); const std::vector expectedData = {0.32165375, 0.36157736, 0.0653337, 0.25143513, 0.35271573, 0.23400122, 0.33747196, 0.07581109, diff --git a/src/tests/end2end/TransposeTests.cpp b/src/tests/end2end/TransposeTests.cpp index 0a8f4e1dc..946f495fb 100644 --- a/src/tests/end2end/TransposeTests.cpp +++ b/src/tests/end2end/TransposeTests.cpp @@ -21,16 +21,15 @@ class TransposeTests : public WebnnTest { const std::vector& expectedShape, const std::vector& expectedValue, const std::vector& permutation = {}) { - const 
webnn::ModelBuilder builder = GetContext().CreateModelBuilder(); - const webnn::Operand a = utils::BuildInput(builder, "a", inputShape); - webnn::TransposeOptions options; + const ml::GraphBuilder builder = ml::CreateGraphBuilder(GetContext()); + const ml::Operand a = utils::BuildInput(builder, "a", inputShape); + ml::TransposeOptions options; options.permutation = permutation.data(); options.permutationCount = permutation.size(); - const webnn::Operand b = builder.Transpose(a, &options); - const webnn::Model model = utils::CreateModel(builder, {{"b", b}}); - const webnn::Compilation compiledModel = utils::AwaitCompile(model); - const webnn::Input input = {inputData.data(), inputData.size() * sizeof(float)}; - const webnn::Result result = utils::AwaitCompute(compiledModel, {{"a", input}}).Get("b"); + const ml::Operand b = builder.Transpose(a, &options); + const ml::Graph graph = utils::AwaitBuild(builder, {{"b", b}}); + const ml::Input input = {inputData.data(), inputData.size() * sizeof(float)}; + const ml::Result result = utils::AwaitCompute(graph, {{"a", input}}).Get("b"); EXPECT_TRUE(utils::CheckShape(result, expectedShape)); EXPECT_TRUE(utils::CheckValue(result, expectedValue)); } diff --git a/src/tests/unittests/ObjectBaseTests.cpp b/src/tests/unittests/ObjectBaseTests.cpp index 1ffdd3da2..a1160c2ab 100644 --- a/src/tests/unittests/ObjectBaseTests.cpp +++ b/src/tests/unittests/ObjectBaseTests.cpp @@ -18,7 +18,7 @@ #include "webnn/webnn_cpp.h" -class Object : public webnn::ObjectBase { +class Object : public ml::ObjectBase { public: using ObjectBase::ObjectBase; using ObjectBase::operator=; diff --git a/src/tests/unittests/validation/BinaryValidationTests.cpp b/src/tests/unittests/validation/BinaryValidationTests.cpp index 9d879c02b..c1097b33e 100644 --- a/src/tests/unittests/validation/BinaryValidationTests.cpp +++ b/src/tests/unittests/validation/BinaryValidationTests.cpp @@ -22,23 +22,22 @@ class BinaryValidationTest : public ValidationTest {}; 
TEST_F(BinaryValidationTest, InputsType) { std::vector shape = {2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); // success { std::vector data(4, 1); - webnn::Operand b = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); - webnn::Operand add = mBuilder.Add(a, b); - webnn::Operand mul = mBuilder.Mul(a, b); - webnn::Operand matmul = mBuilder.Matmul(a, b); + ml::Operand b = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); + ml::Operand add = mBuilder.Add(a, b); + ml::Operand mul = mBuilder.Mul(a, b); + ml::Operand matmul = mBuilder.Matmul(a, b); } // inputs types are inconsistent { std::vector data(4, 1); - inputDesc = {webnn::OperandType::Int32, shape.data(), (uint32_t)shape.size()}; - webnn::Operand b = - mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(int32_t)); + inputDesc = {ml::OperandType::Int32, shape.data(), (uint32_t)shape.size()}; + ml::Operand b = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(int32_t)); ASSERT_CONTEXT_ERROR(mBuilder.Add(a, b)); ASSERT_CONTEXT_ERROR(mBuilder.Mul(a, b)); ASSERT_CONTEXT_ERROR(mBuilder.Matmul(a, b)); diff --git a/src/tests/unittests/validation/Conv2dValidationTests.cpp b/src/tests/unittests/validation/Conv2dValidationTests.cpp index 7b40b1207..d88c3fa53 100644 --- a/src/tests/unittests/validation/Conv2dValidationTests.cpp +++ b/src/tests/unittests/validation/Conv2dValidationTests.cpp @@ -23,48 +23,47 @@ class Conv2dValidationTest : public ValidationTest { void SetUp() override { ValidationTest::SetUp(); std::vector shape = {1, 1, 5, 5}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; + ml::OperandDescriptor 
inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; mInput = mBuilder.Input("input", &inputDesc); shape = {1, 1, 3, 3}; std::vector data(9, 1); - inputDesc = {webnn::OperandType::Float32, shape.data(), (uint32_t)shape.size()}; + inputDesc = {ml::OperandType::Float32, shape.data(), (uint32_t)shape.size()}; mFilter = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); } - webnn::Operand mInput; - webnn::Operand mFilter; + ml::Operand mInput; + ml::Operand mFilter; }; TEST_F(Conv2dValidationTest, CreateByDefaultOptions) { // Success { // using default value for options - webnn::Conv2dOptions conv2dOptions = {}; - webnn::Operand conv = mBuilder.Conv2d(mInput, mFilter, &conv2dOptions); + ml::Conv2dOptions conv2dOptions = {}; + ml::Operand conv = mBuilder.Conv2d(mInput, mFilter, &conv2dOptions); } - { webnn::Operand conv = mBuilder.Conv2d(mInput, mFilter); } + { ml::Operand conv = mBuilder.Conv2d(mInput, mFilter); } } TEST_F(Conv2dValidationTest, DifferentTypeError) { // input type is fp32 while filter type is int32 std::vector shape = {1, 1, 3, 3}; std::vector data(9, 1); - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Int32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand filter = - mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(int32_t)); - webnn::Conv2dOptions conv2dOptions = {}; + ml::OperandDescriptor inputDesc = {ml::OperandType::Int32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand filter = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(int32_t)); + ml::Conv2dOptions conv2dOptions = {}; ASSERT_CONTEXT_ERROR(mBuilder.Conv2d(mInput, filter, &conv2dOptions)); } TEST_F(Conv2dValidationTest, InvalidInputDimsError) { // input rank is not 4 std::vector shape = {1, 1, 5}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand input = mBuilder.Input("input", &inputDesc); - 
webnn::Conv2dOptions conv2dOptions = {}; + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand input = mBuilder.Input("input", &inputDesc); + ml::Conv2dOptions conv2dOptions = {}; ASSERT_CONTEXT_ERROR(mBuilder.Conv2d(input, mFilter, &conv2dOptions)); } @@ -72,15 +71,15 @@ TEST_F(Conv2dValidationTest, InvalidFilterDimsError) { // filter rank is 3 std::vector shape = {1, 1, 3}; std::vector data(3, 1); - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand filter = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); - webnn::Conv2dOptions conv2dOptions = {}; + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand filter = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); + ml::Conv2dOptions conv2dOptions = {}; ASSERT_CONTEXT_ERROR(mBuilder.Conv2d(mInput, filter, &conv2dOptions)); } TEST_F(Conv2dValidationTest, InvalidOptions) { - webnn::Conv2dOptions options = {}; + ml::Conv2dOptions options = {}; { // invalid paddingCount std::vector padding = {1, 1, 1}; diff --git a/src/tests/unittests/validation/ErrorScopeValidationTests.cpp b/src/tests/unittests/validation/ErrorScopeValidationTests.cpp index d13785094..66b4b99dc 100644 --- a/src/tests/unittests/validation/ErrorScopeValidationTests.cpp +++ b/src/tests/unittests/validation/ErrorScopeValidationTests.cpp @@ -21,11 +21,11 @@ using namespace testing; class MockContextPopErrorScopeCallback { public: - MOCK_METHOD(void, Call, (WebnnErrorType type, const char* message, void* userdata)); + MOCK_METHOD(void, Call, (MLErrorType type, const char* message, void* userdata)); }; static std::unique_ptr mockContextPopErrorScopeCallback; -static void ToMockContextPopErrorScopeCallback(WebnnErrorType type, +static void ToMockContextPopErrorScopeCallback(MLErrorType type, const char* message, 
void* userdata) { mockContextPopErrorScopeCallback->Call(type, message, userdata); @@ -48,38 +48,37 @@ class ErrorScopeValidationTest : public ValidationTest { // Test the simple success case. TEST_F(ErrorScopeValidationTest, Success) { - mContext.PushErrorScope(webnn::ErrorFilter::Validation); + mContext.PushErrorScope(ml::ErrorFilter::Validation); - EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(WebnnErrorType_NoError, _, this)).Times(1); + EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(MLErrorType_NoError, _, this)).Times(1); mContext.PopErrorScope(ToMockContextPopErrorScopeCallback, this); } // Test the simple case where the error scope catches an error. TEST_F(ErrorScopeValidationTest, CatchesError) { - mContext.PushErrorScope(webnn::ErrorFilter::Validation); + mContext.PushErrorScope(ml::ErrorFilter::Validation); std::vector shape = {2, 2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); mBuilder.Softmax(a); - EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(WebnnErrorType_Validation, _, this)) - .Times(1); + EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(MLErrorType_Validation, _, this)).Times(1); mContext.PopErrorScope(ToMockContextPopErrorScopeCallback, this); } // Test that if no error scope handles an error, it goes to the context UncapturedError callback TEST_F(ErrorScopeValidationTest, UnhandledErrorsMatchUncapturedErrorCallback) { - mContext.PushErrorScope(webnn::ErrorFilter::OutOfMemory); + mContext.PushErrorScope(ml::ErrorFilter::OutOfMemory); std::vector shape = {2, 2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); + 
ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); ASSERT_CONTEXT_ERROR(mBuilder.Softmax(a)); - EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(WebnnErrorType_NoError, _, this)).Times(1); + EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(MLErrorType_NoError, _, this)).Times(1); mContext.PopErrorScope(ToMockContextPopErrorScopeCallback, this); } @@ -90,9 +89,9 @@ TEST_F(ErrorScopeValidationTest, PushPopBalanced) { // Too many pops { - mContext.PushErrorScope(webnn::ErrorFilter::Validation); + mContext.PushErrorScope(ml::ErrorFilter::Validation); - EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(WebnnErrorType_NoError, _, this + 1)) + EXPECT_CALL(*mockContextPopErrorScopeCallback, Call(MLErrorType_NoError, _, this + 1)) .Times(1); mContext.PopErrorScope(ToMockContextPopErrorScopeCallback, this + 1); diff --git a/src/tests/unittests/validation/GraphValidationTests.cpp b/src/tests/unittests/validation/GraphValidationTests.cpp new file mode 100644 index 000000000..155f31e75 --- /dev/null +++ b/src/tests/unittests/validation/GraphValidationTests.cpp @@ -0,0 +1,74 @@ +// Copyright 2021 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "examples/SampleUtils.h" +#include "tests/unittests/validation/ValidationTest.h" + +#include + +using namespace testing; + +class MockGraphBuildCallback { + public: + MOCK_METHOD(void, + Call, + (MLBuildStatus status, MLGraph impl, const char* message, void* userdata)); +}; + +static std::unique_ptr mockGraphBuildCallback; +static void ToMockGraphBuildCallback(MLBuildStatus status, + MLGraph impl, + const char* message, + void* userdata) { + mockGraphBuildCallback->Call(status, impl, message, userdata); +} + +class GraphValidationTest : public ValidationTest { + protected: + void SetUp() override { + ValidationTest::SetUp(); + mockGraphBuildCallback = std::make_unique(); + std::vector shape = {2, 2}; + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); + std::vector data(4, 1); + ml::Operand b = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); + mOutput = mBuilder.Add(a, b); + } + + void TearDown() override { + ValidationTest::TearDown(); + + // Delete mocks so that expectations are checked + mockGraphBuildCallback = nullptr; + } + + ml::Operand mOutput; +}; + +// Test the simple success case. 
+TEST_F(GraphValidationTest, BuildCallBackSuccess) { + ml::NamedOperands namedOperands = ml::CreateNamedOperands(); + namedOperands.Set("output", mOutput); + mBuilder.Build(namedOperands, ToMockGraphBuildCallback, this); + EXPECT_CALL(*mockGraphBuildCallback, Call(MLBuildStatus_Success, _, nullptr, this)).Times(1); +} + +// Create model with null nameOperands +TEST_F(GraphValidationTest, BuildCallBackError) { + ml::NamedOperands namedOperands = ml::CreateNamedOperands(); + mBuilder.Build(namedOperands, ToMockGraphBuildCallback, this); + EXPECT_CALL(*mockGraphBuildCallback, Call(MLBuildStatus_Error, _, _, this)).Times(1); +} diff --git a/src/tests/unittests/validation/ModelValidationTests.cpp b/src/tests/unittests/validation/ModelValidationTests.cpp deleted file mode 100644 index abbb13376..000000000 --- a/src/tests/unittests/validation/ModelValidationTests.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "examples/SampleUtils.h" -#include "tests/unittests/validation/ValidationTest.h" - -#include - -using namespace testing; - -class MockModelCompileCallback { - public: - MOCK_METHOD( - void, - Call, - (WebnnCompileStatus status, WebnnCompilation impl, const char* message, void* userdata)); -}; - -static std::unique_ptr mockModelCompileCallback; -static void ToMockModelCompileCallback(WebnnCompileStatus status, - WebnnCompilation impl, - const char* message, - void* userdata) { - mockModelCompileCallback->Call(status, impl, message, userdata); -} - -class ModelValidationTest : public ValidationTest { - protected: - void SetUp() override { - ValidationTest::SetUp(); - mockModelCompileCallback = std::make_unique(); - std::vector shape = {2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); - std::vector data(4, 1); - webnn::Operand b = mBuilder.Constant(&inputDesc, data.data(), data.size() * sizeof(float)); - mOutput = mBuilder.Add(a, b); - } - - void TearDown() override { - ValidationTest::TearDown(); - - // Delete mocks so that expectations are checked - mockModelCompileCallback = nullptr; - } - - webnn::Operand mOutput; -}; - -// Test the simple success case. 
-TEST_F(ModelValidationTest, CompileCallBackSuccess) { - webnn::NamedOperands namedOperands = webnn::CreateNamedOperands(); - namedOperands.Set("output", mOutput); - webnn::Model model = mBuilder.CreateModel(namedOperands); - EXPECT_CALL(*mockModelCompileCallback, Call(WebnnCompileStatus_Success, _, nullptr, this)) - .Times(1); - model.Compile(ToMockModelCompileCallback, this); -} - -// Create model with null nameOperands -TEST_F(ModelValidationTest, CompileCallBackError) { - webnn::NamedOperands namedOperands = webnn::CreateNamedOperands(); - webnn::Model model = mBuilder.CreateModel(namedOperands); - EXPECT_CALL(*mockModelCompileCallback, Call(WebnnCompileStatus_Error, _, _, this)).Times(1); - model.Compile(ToMockModelCompileCallback, this); -} diff --git a/src/tests/unittests/validation/PoolValidationTests.cpp b/src/tests/unittests/validation/PoolValidationTests.cpp index 6d31df5d3..423826c2c 100644 --- a/src/tests/unittests/validation/PoolValidationTests.cpp +++ b/src/tests/unittests/validation/PoolValidationTests.cpp @@ -23,27 +23,27 @@ class PoolValidationTest : public ValidationTest {}; TEST_F(PoolValidationTest, CreateByDefaultOptions) { // Success std::vector shape = {1, 100, 1000, 1000}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand input = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand input = mBuilder.Input("input", &inputDesc); { // using default value for options - webnn::Pool2dOptions pool2dOptions = {}; - webnn::Operand pool = mBuilder.AveragePool2d(input, &pool2dOptions); + ml::Pool2dOptions pool2dOptions = {}; + ml::Operand pool = mBuilder.AveragePool2d(input, &pool2dOptions); } - { webnn::Operand pool = mBuilder.MaxPool2d(input); } + { ml::Operand pool = mBuilder.MaxPool2d(input); } } TEST_F(PoolValidationTest, InputDimsError) { // input is not a 4D tensor 
std::vector shape = {1, 100, 1000, 1000, 1}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand input = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand input = mBuilder.Input("input", &inputDesc); - webnn::Pool2dOptions pool2dOptions = {}; - webnn::Operand pool; + ml::Pool2dOptions pool2dOptions = {}; + ml::Operand pool; ASSERT_CONTEXT_ERROR(pool = mBuilder.MaxPool2d(input, &pool2dOptions)); // input variable pool is not valid ASSERT_CONTEXT_ERROR(mBuilder.MaxPool2d(pool)); @@ -51,12 +51,12 @@ TEST_F(PoolValidationTest, InputDimsError) { TEST_F(PoolValidationTest, FilterCountError) { std::vector shape = {1, 100, 1000, 1000}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand input = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand input = mBuilder.Input("input", &inputDesc); // windowDimensionsCount is incorrect { - webnn::Pool2dOptions options; + ml::Pool2dOptions options; std::vector windowDimensions = {2, 2, 1}; options.windowDimensions = windowDimensions.data(); options.windowDimensionsCount = 3; @@ -67,7 +67,7 @@ TEST_F(PoolValidationTest, FilterCountError) { } // paddingCount is incorrect { - webnn::Pool2dOptions options; + ml::Pool2dOptions options; options.windowDimensions = nullptr; options.strides = nullptr; std::vector padding = {1, 1}; @@ -78,7 +78,7 @@ TEST_F(PoolValidationTest, FilterCountError) { } // stridesCount is incorrect { - webnn::Pool2dOptions options; + ml::Pool2dOptions options; options.windowDimensions = nullptr; std::vector strides = {1}; options.strides = strides.data(); @@ -89,7 +89,7 @@ TEST_F(PoolValidationTest, FilterCountError) { } // dilationsCount is incorrect { - 
webnn::Pool2dOptions options; + ml::Pool2dOptions options; options.windowDimensions = nullptr; options.strides = nullptr; options.padding = nullptr; diff --git a/src/tests/unittests/validation/ReshapeValidationTests.cpp b/src/tests/unittests/validation/ReshapeValidationTests.cpp index be6388ad5..38b0a8c28 100644 --- a/src/tests/unittests/validation/ReshapeValidationTests.cpp +++ b/src/tests/unittests/validation/ReshapeValidationTests.cpp @@ -22,18 +22,18 @@ class ReshapeValidationTest : public ValidationTest {}; TEST_F(ReshapeValidationTest, InputsType) { std::vector shape = {2, 3, 4}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); // success { std::vector newShape = {1, 2, 3, 4}; - webnn::Operand reshape = mBuilder.Reshape(a, newShape.data(), newShape.size()); + ml::Operand reshape = mBuilder.Reshape(a, newShape.data(), newShape.size()); } // success - Only one component of newShape can be the special value of -1. 
{ std::vector newShape = {-1, 2, 3, 4}; - webnn::Operand reshape = mBuilder.Reshape(a, newShape.data(), newShape.size()); + ml::Operand reshape = mBuilder.Reshape(a, newShape.data(), newShape.size()); } // two component both be -1 { diff --git a/src/tests/unittests/validation/TransposeValidationTests.cpp b/src/tests/unittests/validation/TransposeValidationTests.cpp index 03bdd015d..bed41f3de 100644 --- a/src/tests/unittests/validation/TransposeValidationTests.cpp +++ b/src/tests/unittests/validation/TransposeValidationTests.cpp @@ -23,19 +23,19 @@ class TransposeValidationTest : public ValidationTest { void SetUp() override { ValidationTest::SetUp(); std::vector shape = {2, 3, 4}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; mInput = mBuilder.Input("input", &inputDesc); } - webnn::Operand mInput; + ml::Operand mInput; }; TEST_F(TransposeValidationTest, CreateByDefaultOptions) { // success - { webnn::Operand transpose = mBuilder.Transpose(mInput); } + { ml::Operand transpose = mBuilder.Transpose(mInput); } { - webnn::TransposeOptions options = {}; - webnn::Operand transpose = mBuilder.Transpose(mInput, &options); + ml::TransposeOptions options = {}; + ml::Operand transpose = mBuilder.Transpose(mInput, &options); } } @@ -43,15 +43,15 @@ TEST_F(TransposeValidationTest, InvalidOptions) { // success { std::vector permutation = {2, 0, 1}; - webnn::TransposeOptions options; + ml::TransposeOptions options; options.permutation = permutation.data(); options.permutationCount = permutation.size(); - webnn::Operand transpose = mBuilder.Transpose(mInput, &options); + ml::Operand transpose = mBuilder.Transpose(mInput, &options); } // permutation size is invalid { std::vector permutation = {2, 0, 1, 3}; - webnn::TransposeOptions options; + ml::TransposeOptions options; options.permutation = permutation.data(); 
options.permutationCount = permutation.size(); ASSERT_CONTEXT_ERROR(mBuilder.Transpose(mInput, &options)); @@ -59,7 +59,7 @@ TEST_F(TransposeValidationTest, InvalidOptions) { // permutation value is invalid { std::vector permutation = {3, 2, 2}; - webnn::TransposeOptions options; + ml::TransposeOptions options; options.permutation = permutation.data(); options.permutationCount = permutation.size(); ASSERT_CONTEXT_ERROR(mBuilder.Transpose(mInput, &options)); @@ -67,7 +67,7 @@ TEST_F(TransposeValidationTest, InvalidOptions) { // permutation value is invalid { std::vector permutation = {3, 2, 4}; - webnn::TransposeOptions options; + ml::TransposeOptions options; options.permutation = permutation.data(); options.permutationCount = permutation.size(); ASSERT_CONTEXT_ERROR(mBuilder.Transpose(mInput, &options)); diff --git a/src/tests/unittests/validation/UnaryValidationTests.cpp b/src/tests/unittests/validation/UnaryValidationTests.cpp index 4d48003ca..9344cdbe3 100644 --- a/src/tests/unittests/validation/UnaryValidationTests.cpp +++ b/src/tests/unittests/validation/UnaryValidationTests.cpp @@ -24,17 +24,17 @@ TEST_F(UnaryValidationTest, SoftmaxValidation) { // success { std::vector shape = {2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); - webnn::Operand softmax = mBuilder.Softmax(a); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); + ml::Operand softmax = mBuilder.Softmax(a); } // Input dimensions is incorrect { std::vector shape = {2, 2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = 
mBuilder.Input("input", &inputDesc); ASSERT_CONTEXT_ERROR(mBuilder.Softmax(a)); } } @@ -43,9 +43,9 @@ TEST_F(UnaryValidationTest, ReluValidation) { // success { std::vector shape = {2, 2}; - webnn::OperandDescriptor inputDesc = {webnn::OperandType::Float32, shape.data(), - (uint32_t)shape.size()}; - webnn::Operand a = mBuilder.Input("input", &inputDesc); - webnn::Operand relu = mBuilder.Relu(a); + ml::OperandDescriptor inputDesc = {ml::OperandType::Float32, shape.data(), + (uint32_t)shape.size()}; + ml::Operand a = mBuilder.Input("input", &inputDesc); + ml::Operand relu = mBuilder.Relu(a); } } diff --git a/src/tests/unittests/validation/ValidationTest.cpp b/src/tests/unittests/validation/ValidationTest.cpp index 2d4abfe9a..709a69403 100644 --- a/src/tests/unittests/validation/ValidationTest.cpp +++ b/src/tests/unittests/validation/ValidationTest.cpp @@ -23,12 +23,12 @@ void ValidationTest::SetUp() { WebnnProcTable backendProcs = webnn_native::GetProcs(); ASSERT_NE(&backendProcs, nullptr); webnnProcSetProcs(&backendProcs); - WebnnNeuralNetworkContext context = webnn_native::CreateNeuralNetworkContext(); + MLContext context = webnn_native::CreateContext(); // GTest will not run test body if fail to create context. 
ASSERT_TRUE(context != nullptr); - mContext = webnn::NeuralNetworkContext::Acquire(context); + mContext = ml::Context::Acquire(context); mContext.SetUncapturedErrorCallback(ErrorCallback, this); - mBuilder = mContext.CreateModelBuilder(); + mBuilder = ml::CreateGraphBuilder(mContext); } ValidationTest::~ValidationTest() { @@ -51,8 +51,8 @@ std::string ValidationTest::GetLastErrorMessage() const { return mErrorMessage; } -void ValidationTest::ErrorCallback(WebnnErrorType type, char const* message, void* userdata) { - ASSERT(type != WebnnErrorType_NoError); +void ValidationTest::ErrorCallback(MLErrorType type, char const* message, void* userdata) { + ASSERT(type != MLErrorType_NoError); auto self = static_cast(userdata); self->mErrorMessage = message; diff --git a/src/tests/unittests/validation/ValidationTest.h b/src/tests/unittests/validation/ValidationTest.h index e199952ff..f4f3d8dfd 100644 --- a/src/tests/unittests/validation/ValidationTest.h +++ b/src/tests/unittests/validation/ValidationTest.h @@ -42,11 +42,11 @@ class ValidationTest : public testing::Test { std::string GetLastErrorMessage() const; protected: - webnn::NeuralNetworkContext mContext; - webnn::ModelBuilder mBuilder; + ml::Context mContext; + ml::GraphBuilder mBuilder; private: - static void ErrorCallback(WebnnErrorType type, const char* message, void* userdata); + static void ErrorCallback(MLErrorType type, const char* message, void* userdata); std::string mErrorMessage; bool mExpectError = false; bool mError = false; diff --git a/src/webnn_native/BUILD.gn b/src/webnn_native/BUILD.gn index 15772c4e5..9d93d7745 100644 --- a/src/webnn_native/BUILD.gn +++ b/src/webnn_native/BUILD.gn @@ -89,24 +89,22 @@ source_set("webnn_native_sources") { sources = get_target_outputs(":webnn_native_utils_gen") sources += [ - "Compilation.cpp", - "Compilation.h", + "Context.cpp", + "Context.h", "Error.cpp", "Error.h", "ErrorData.cpp", "ErrorData.h", "ErrorScope.cpp", "ErrorScope.h", - "Model.cpp", - "Model.h", - 
"ModelBuilder.cpp", - "ModelBuilder.h", + "Graph.cpp", + "Graph.h", + "GraphBuilder.cpp", + "GraphBuilder.h", "NamedInputs.h", "NamedOutputs.h", "NamedRecords.h", "NamedResults.h", - "NeuralNetworkContext.cpp", - "NeuralNetworkContext.h", "ObjectBase.cpp", "ObjectBase.h", "Operand.cpp", @@ -136,23 +134,19 @@ source_set("webnn_native_sources") { if (webnn_enable_null) { sources += [ - "null/NeuralNetworkContextNull.cpp", - "null/NeuralNetworkContextNull.h", + "null/ContextNull.cpp", + "null/ContextNull.h", ] } if (webnn_enable_openvino) { sources += [ - "openvino/CompilationIE.cpp", - "openvino/CompilationIE.h", + "openvino/ContextIE.cpp", + "openvino/ContextIE.h", "openvino/ErrorIE.cpp", "openvino/ErrorIE.h", - "openvino/ModelBuilderIE.cpp", - "openvino/ModelBuilderIE.h", - "openvino/ModelIE.cpp", - "openvino/ModelIE.h", - "openvino/NeuralNetworkContextIE.cpp", - "openvino/NeuralNetworkContextIE.h", + "openvino/GraphIE.cpp", + "openvino/GraphIE.h", ] sources += [ @@ -180,14 +174,10 @@ source_set("webnn_native_sources") { ] sources += [ - "dml/CompilationDML.cpp", - "dml/CompilationDML.h", - "dml/ModelBuilderDML.cpp", - "dml/ModelBuilderDML.h", - "dml/ModelDML.cpp", - "dml/ModelDML.h", - "dml/NeuralNetworkContextDML.cpp", - "dml/NeuralNetworkContextDML.h", + "dml/ContextDML.cpp", + "dml/ContextDML.h", + "dml/GraphDML.cpp", + "dml/GraphDML.h", ] include_dirs = [ diff --git a/src/webnn_native/Compilation.cpp b/src/webnn_native/Compilation.cpp deleted file mode 100644 index 36b4d2fe0..000000000 --- a/src/webnn_native/Compilation.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "webnn_native/Operand.h" - -#include "common/Assert.h" -#include "common/Log.h" -#include "webnn_native/NamedResults.h" - -namespace webnn_native { - - void CompilationBase::Compute(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs) { - ComputeImpl(inputs, callback, userdata, outputs); - } - -} // namespace webnn_native diff --git a/src/webnn_native/Compilation.h b/src/webnn_native/Compilation.h deleted file mode 100644 index f57eb57ea..000000000 --- a/src/webnn_native/Compilation.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef WEBNN_NATIVE_COMPILATION_H_ -#define WEBNN_NATIVE_COMPILATION_H_ - -#include "common/RefCounted.h" -#include "webnn_native/Forward.h" -#include "webnn_native/NamedInputs.h" -#include "webnn_native/NamedOutputs.h" -#include "webnn_native/ObjectBase.h" -#include "webnn_native/webnn_platform.h" - -namespace webnn_native { - - class CompilationBase : public RefCounted { - public: - CompilationBase() = default; - virtual ~CompilationBase() = default; - - // Dawn API - void Compute(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs = nullptr); - - private: - virtual void ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs) = 0; - }; -} // namespace webnn_native - -#endif // WEBNN_NATIVE_COMPILATION_H_ \ No newline at end of file diff --git a/src/webnn_native/NeuralNetworkContext.cpp b/src/webnn_native/Context.cpp similarity index 68% rename from src/webnn_native/NeuralNetworkContext.cpp rename to src/webnn_native/Context.cpp index 225936d5d..c42d37109 100644 --- a/src/webnn_native/NeuralNetworkContext.cpp +++ b/src/webnn_native/Context.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "webnn_native/NeuralNetworkContext.h" +#include "webnn_native/Context.h" #include @@ -21,27 +21,23 @@ namespace webnn_native { - NeuralNetworkContextBase::NeuralNetworkContextBase() { + ContextBase::ContextBase() { mRootErrorScope = AcquireRef(new ErrorScope()); mCurrentErrorScope = mRootErrorScope.Get(); } - ModelBuilderBase* NeuralNetworkContextBase::CreateModelBuilder() { - return CreateModelBuilderImpl(); + GraphBase* ContextBase::CreateGraph() { + return CreateGraphImpl(); } - ModelBuilderBase* NeuralNetworkContextBase::CreateModelBuilderImpl() { - UNREACHABLE(); - } - - void NeuralNetworkContextBase::PushErrorScope(webnn::ErrorFilter filter) { + void ContextBase::PushErrorScope(ml::ErrorFilter filter) { if (ConsumedError(ValidateErrorFilter(filter))) { return; } mCurrentErrorScope = AcquireRef(new ErrorScope(filter, mCurrentErrorScope.Get())); } - bool NeuralNetworkContextBase::PopErrorScope(webnn::ErrorCallback callback, void* userdata) { + bool ContextBase::PopErrorScope(ml::ErrorCallback callback, void* userdata) { if (DAWN_UNLIKELY(mCurrentErrorScope.Get() == mRootErrorScope.Get())) { return false; } @@ -51,12 +47,11 @@ namespace webnn_native { return true; } - void NeuralNetworkContextBase::SetUncapturedErrorCallback(webnn::ErrorCallback callback, - void* userdata) { + void ContextBase::SetUncapturedErrorCallback(ml::ErrorCallback callback, void* userdata) { mRootErrorScope->SetCallback(callback, userdata); } - void NeuralNetworkContextBase::HandleError(std::unique_ptr error) { + void ContextBase::HandleError(std::unique_ptr error) { ASSERT(error != nullptr); std::ostringstream ss; ss << error->GetMessage(); @@ -67,7 +62,7 @@ namespace webnn_native { // Still forward device loss and internal errors to the error scopes so they // all reject. 
- mCurrentErrorScope->HandleError(ToWebnnErrorType(error->GetType()), ss.str().c_str()); + mCurrentErrorScope->HandleError(ToMLErrorType(error->GetType()), ss.str().c_str()); } } // namespace webnn_native diff --git a/src/webnn_native/NeuralNetworkContext.h b/src/webnn_native/Context.h similarity index 66% rename from src/webnn_native/NeuralNetworkContext.h rename to src/webnn_native/Context.h index 814b27804..650687f46 100644 --- a/src/webnn_native/NeuralNetworkContext.h +++ b/src/webnn_native/Context.h @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef WEBNN_NATIVE_NEURAL_NETWORK_CONTEXT_H_ -#define WEBNN_NATIVE_NEURAL_NETWORK_CONTEXT_H_ +#ifndef WEBNN_NATIVE_CONTEXT_H_ +#define WEBNN_NATIVE_CONTEXT_H_ #include "common/RefCounted.h" #include "webnn_native/Error.h" #include "webnn_native/ErrorScope.h" #include "webnn_native/webnn_platform.h" +class WebGLRenderingContext; namespace webnn_native { - class NeuralNetworkContextBase : public RefCounted { + class ContextBase : public RefCounted { public: - NeuralNetworkContextBase(); - virtual ~NeuralNetworkContextBase() = default; + ContextBase(); + virtual ~ContextBase() = default; bool ConsumedError(MaybeError maybeError) { if (DAWN_UNLIKELY(maybeError.IsError())) { @@ -35,15 +36,18 @@ namespace webnn_native { return false; } + GraphBase* CreateGraph(); + // Dawn API - ModelBuilderBase* CreateModelBuilder(); - void PushErrorScope(webnn::ErrorFilter filter); - bool PopErrorScope(webnn::ErrorCallback callback, void* userdata); - void SetUncapturedErrorCallback(webnn::ErrorCallback callback, void* userdata); + void PushErrorScope(ml::ErrorFilter filter); + bool PopErrorScope(ml::ErrorCallback callback, void* userdata); + void SetUncapturedErrorCallback(ml::ErrorCallback callback, void* userdata); private: + // Create concrete model. 
+ virtual GraphBase* CreateGraphImpl() = 0; + void HandleError(std::unique_ptr error); - virtual ModelBuilderBase* CreateModelBuilderImpl(); Ref mRootErrorScope; Ref mCurrentErrorScope; @@ -51,4 +55,4 @@ namespace webnn_native { } // namespace webnn_native -#endif // WEBNN_NATIVE_NEURAL_NETWORK_CONTEXT_H_ +#endif // WEBNN_NATIVE_CONTEXT_H_ diff --git a/src/webnn_native/Error.cpp b/src/webnn_native/Error.cpp index eb9fc6eb8..8488a8676 100644 --- a/src/webnn_native/Error.cpp +++ b/src/webnn_native/Error.cpp @@ -31,32 +31,32 @@ namespace webnn_native { } } - webnn::ErrorType ToWebnnErrorType(InternalErrorType type) { + ml::ErrorType ToMLErrorType(InternalErrorType type) { switch (type) { case InternalErrorType::Validation: - return webnn::ErrorType::Validation; + return ml::ErrorType::Validation; case InternalErrorType::OutOfMemory: - return webnn::ErrorType::OutOfMemory; + return ml::ErrorType::OutOfMemory; // There is no equivalent of Internal errors in the WebGPU API. Internal // errors cause the device at the API level to be lost, so treat it like a // DeviceLost error. 
case InternalErrorType::Internal: case InternalErrorType::DeviceLost: - return webnn::ErrorType::DeviceLost; + return ml::ErrorType::DeviceLost; default: - return webnn::ErrorType::Unknown; + return ml::ErrorType::Unknown; } } - InternalErrorType FromWebnnErrorType(webnn::ErrorType type) { + InternalErrorType FromMLErrorType(ml::ErrorType type) { switch (type) { - case webnn::ErrorType::Validation: + case ml::ErrorType::Validation: return InternalErrorType::Validation; - case webnn::ErrorType::OutOfMemory: + case ml::ErrorType::OutOfMemory: return InternalErrorType::OutOfMemory; - case webnn::ErrorType::DeviceLost: + case ml::ErrorType::DeviceLost: return InternalErrorType::DeviceLost; default: return InternalErrorType::Internal; diff --git a/src/webnn_native/Error.h b/src/webnn_native/Error.h index bf3c814b0..d7afa9e83 100644 --- a/src/webnn_native/Error.h +++ b/src/webnn_native/Error.h @@ -117,8 +117,8 @@ namespace webnn_native { // Assert that errors are device loss so that we can continue with destruction void IgnoreErrors(MaybeError maybeError); - webnn::ErrorType ToWebnnErrorType(InternalErrorType type); - InternalErrorType FromWebnnErrorType(webnn::ErrorType type); + ml::ErrorType ToMLErrorType(InternalErrorType type); + InternalErrorType FromMLErrorType(ml::ErrorType type); } // namespace webnn_native diff --git a/src/webnn_native/ErrorData.h b/src/webnn_native/ErrorData.h index e02c7fd53..0e5f16d6d 100644 --- a/src/webnn_native/ErrorData.h +++ b/src/webnn_native/ErrorData.h @@ -23,12 +23,12 @@ #include #include -namespace webnn { +namespace ml { enum class ErrorType : uint32_t; } namespace dawn { - using ErrorType = webnn::ErrorType; + using ErrorType = ml::ErrorType; } namespace webnn_native { diff --git a/src/webnn_native/ErrorScope.cpp b/src/webnn_native/ErrorScope.cpp index ef3b0c7df..2e269a3ab 100644 --- a/src/webnn_native/ErrorScope.cpp +++ b/src/webnn_native/ErrorScope.cpp @@ -22,7 +22,7 @@ namespace webnn_native { ErrorScope::ErrorScope() : 
mIsRoot(true) { } - ErrorScope::ErrorScope(webnn::ErrorFilter errorFilter, ErrorScope* parent) + ErrorScope::ErrorScope(ml::ErrorFilter errorFilter, ErrorScope* parent) : RefCounted(), mErrorFilter(errorFilter), mParent(parent), mIsRoot(false) { ASSERT(mParent.Get() != nullptr); } @@ -33,7 +33,7 @@ namespace webnn_native { } } - void ErrorScope::SetCallback(webnn::ErrorCallback callback, void* userdata) { + void ErrorScope::SetCallback(ml::ErrorCallback callback, void* userdata) { mCallback = callback; mUserdata = userdata; } @@ -51,12 +51,12 @@ namespace webnn_native { if (mCallback != nullptr) { // For non-root error scopes, the callback can run at most once. - mCallback(static_cast(mErrorType), mErrorMessage.c_str(), mUserdata); + mCallback(static_cast(mErrorType), mErrorMessage.c_str(), mUserdata); mCallback = nullptr; } } - void ErrorScope::HandleError(webnn::ErrorType type, const char* message) { + void ErrorScope::HandleError(ml::ErrorType type, const char* message) { HandleErrorImpl(this, type, message); } @@ -65,25 +65,23 @@ namespace webnn_native { } // static - void ErrorScope::HandleErrorImpl(ErrorScope* scope, - webnn::ErrorType type, - const char* message) { + void ErrorScope::HandleErrorImpl(ErrorScope* scope, ml::ErrorType type, const char* message) { ErrorScope* currentScope = scope; for (; !currentScope->IsRoot(); currentScope = currentScope->GetParent()) { ASSERT(currentScope != nullptr); bool consumed = false; switch (type) { - case webnn::ErrorType::Validation: - if (currentScope->mErrorFilter != webnn::ErrorFilter::Validation) { + case ml::ErrorType::Validation: + if (currentScope->mErrorFilter != ml::ErrorFilter::Validation) { // Error filter does not match. Move on to the next scope. 
continue; } consumed = true; break; - case webnn::ErrorType::OutOfMemory: - if (currentScope->mErrorFilter != webnn::ErrorFilter::OutOfMemory) { + case ml::ErrorType::OutOfMemory: + if (currentScope->mErrorFilter != ml::ErrorFilter::OutOfMemory) { // Error filter does not match. Move on to the next scope. continue; } @@ -92,18 +90,18 @@ namespace webnn_native { // Unknown and DeviceLost are fatal. All error scopes capture them. // |consumed| is false because these should bubble to all scopes. - case webnn::ErrorType::Unknown: - case webnn::ErrorType::DeviceLost: + case ml::ErrorType::Unknown: + case ml::ErrorType::DeviceLost: consumed = false; break; - case webnn::ErrorType::NoError: + case ml::ErrorType::NoError: UNREACHABLE(); return; } // Record the error if the scope doesn't have one yet. - if (currentScope->mErrorType == webnn::ErrorType::NoError) { + if (currentScope->mErrorType == ml::ErrorType::NoError) { currentScope->mErrorType = type; currentScope->mErrorMessage = message; } @@ -116,7 +114,7 @@ namespace webnn_native { // The root error scope captures all uncaptured errors. ASSERT(currentScope->IsRoot()); if (currentScope->mCallback) { - currentScope->mCallback(static_cast(type), message, + currentScope->mCallback(static_cast(type), message, currentScope->mUserdata); } } @@ -132,8 +130,8 @@ namespace webnn_native { ASSERT(parentScope.Get() != nullptr); // On shutdown, error scopes that have yet to have a status get Unknown. 
- if (currentScope->mErrorType == webnn::ErrorType::NoError) { - currentScope->mErrorType = webnn::ErrorType::Unknown; + if (currentScope->mErrorType == ml::ErrorType::NoError) { + currentScope->mErrorType = ml::ErrorType::Unknown; currentScope->mErrorMessage = "Error scope destroyed"; } diff --git a/src/webnn_native/ErrorScope.h b/src/webnn_native/ErrorScope.h index 95dd00972..d38b3c70a 100644 --- a/src/webnn_native/ErrorScope.h +++ b/src/webnn_native/ErrorScope.h @@ -39,12 +39,12 @@ namespace webnn_native { class ErrorScope final : public RefCounted { public: ErrorScope(); // Constructor for the root error scope. - ErrorScope(webnn::ErrorFilter errorFilter, ErrorScope* parent); + ErrorScope(ml::ErrorFilter errorFilter, ErrorScope* parent); - void SetCallback(webnn::ErrorCallback callback, void* userdata); + void SetCallback(ml::ErrorCallback callback, void* userdata); ErrorScope* GetParent(); - void HandleError(webnn::ErrorType type, const char* message); + void HandleError(ml::ErrorType type, const char* message); void UnlinkForShutdown(); private: @@ -52,17 +52,17 @@ namespace webnn_native { bool IsRoot() const; void RunNonRootCallback(); - static void HandleErrorImpl(ErrorScope* scope, webnn::ErrorType type, const char* message); + static void HandleErrorImpl(ErrorScope* scope, ml::ErrorType type, const char* message); static void UnlinkForShutdownImpl(ErrorScope* scope); - webnn::ErrorFilter mErrorFilter = webnn::ErrorFilter::None; + ml::ErrorFilter mErrorFilter = ml::ErrorFilter::None; Ref mParent = nullptr; bool mIsRoot; - webnn::ErrorCallback mCallback = nullptr; + ml::ErrorCallback mCallback = nullptr; void* mUserdata = nullptr; - webnn::ErrorType mErrorType = webnn::ErrorType::NoError; + ml::ErrorType mErrorType = ml::ErrorType::NoError; std::string mErrorMessage = ""; }; diff --git a/src/webnn_native/Forward.h b/src/webnn_native/Forward.h index df70e71ff..fe10ba15d 100644 --- a/src/webnn_native/Forward.h +++ b/src/webnn_native/Forward.h @@ -24,8 +24,8 
@@ class Ref; namespace webnn_native { class CompilationBase; - class ModelBase; - class ModelBuilderBase; + class GraphBase; + class GraphBuilderBase; class NamedInputsBase; class NamedOperandsBase; class NamedOutputsBase; diff --git a/src/webnn_native/Graph.cpp b/src/webnn_native/Graph.cpp new file mode 100644 index 000000000..1d9135c5f --- /dev/null +++ b/src/webnn_native/Graph.cpp @@ -0,0 +1,74 @@ +// Copyright 2021 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "webnn_native/Graph.h" + +#include + +#include "common/Assert.h" +#include "common/RefCounted.h" + +namespace webnn_native { + + GraphBase::GraphBase(ContextBase* context) : ObjectBase(context) { + } + + void GraphBase::Compute(NamedInputsBase* inputs, + MLComputeCallback callback, + void* userdata, + NamedOutputsBase* outputs) { + ComputeImpl(inputs, callback, userdata, outputs); + } + + MaybeError GraphBase::AddConstant(const op::Constant* constant) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddInput(const op::Input* input) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddOutput(const std::string& name, const OperandBase* output) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddBinary(const op::Binary* binary) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddConv2d(const op::Conv2d* conv2d) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddPool2d(const op::Pool2d* pool2d) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddReshape(const op::Reshape* relu) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddTranspose(const op::Transpose* transpose) { + UNREACHABLE(); + } + + MaybeError GraphBase::AddUnary(const op::Unary* unary) { + UNREACHABLE(); + } + + MaybeError GraphBase::Finish() { + UNREACHABLE(); + } + +} // namespace webnn_native diff --git a/src/webnn_native/Model.h b/src/webnn_native/Graph.h similarity index 72% rename from src/webnn_native/Model.h rename to src/webnn_native/Graph.h index da38d883c..1f4f3fb99 100644 --- a/src/webnn_native/Model.h +++ b/src/webnn_native/Graph.h @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef WEBNN_NATIVE_MODEL_H_ -#define WEBNN_NATIVE_MODEL_H_ +#ifndef WEBNN_NATIVE_GRAPH_H_ +#define WEBNN_NATIVE_GRAPH_H_ #include "common/RefCounted.h" -#include "webnn_native/Compilation.h" +#include "webnn_native/Context.h" #include "webnn_native/Error.h" #include "webnn_native/Forward.h" -#include "webnn_native/ModelBuilder.h" #include "webnn_native/ObjectBase.h" #include "webnn_native/Operand.h" #include "webnn_native/webnn_platform.h" @@ -37,18 +36,16 @@ namespace webnn_native { class Unary; } // namespace op - class ModelBase : public ObjectBase { + class GraphBase : public ObjectBase { public: - explicit ModelBase(ModelBuilderBase* modelBuilder); - virtual ~ModelBase() = default; + explicit GraphBase(ContextBase* context); + virtual ~GraphBase() = default; - // static - static ModelBase* MakeError(ModelBuilderBase* modelBuilder); - - // Dawn API - void Compile(WebnnCompileCallback callback, + // Webnn API + void Compute(NamedInputsBase* inputs, + MLComputeCallback callback, void* userdata, - CompilationOptions const* options); + NamedOutputsBase* outputs = nullptr); virtual MaybeError AddConstant(const op::Constant* constant); virtual MaybeError AddInput(const op::Input* input); @@ -62,10 +59,10 @@ namespace webnn_native { virtual MaybeError Finish(); private: - ModelBase(ModelBuilderBase* modelBuilder, ObjectBase::ErrorTag tag); - virtual void CompileImpl(WebnnCompileCallback callback, + virtual void ComputeImpl(NamedInputsBase* inputs, + MLComputeCallback callback, void* userdata, - CompilationOptions const* options); + NamedOutputsBase* outputs) = 0; }; } // namespace webnn_native diff --git a/src/webnn_native/ModelBuilder.cpp b/src/webnn_native/GraphBuilder.cpp similarity index 72% rename from src/webnn_native/ModelBuilder.cpp rename to src/webnn_native/GraphBuilder.cpp index 5ca1273eb..af5cd8ce5 100644 --- a/src/webnn_native/ModelBuilder.cpp +++ b/src/webnn_native/GraphBuilder.cpp @@ -12,7 +12,7 @@ // See the License for the specific language 
governing permissions and // limitations under the License. -#include "webnn_native/ModelBuilder.h" +#include "webnn_native/GraphBuilder.h" #include #include @@ -20,9 +20,10 @@ #include #include "common/Assert.h" +#include "common/Log.h" #include "common/RefCounted.h" -#include "webnn_native/Model.h" -#include "webnn_native/NeuralNetworkContext.h" +#include "webnn_native/Context.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" #include "webnn_native/ops/Binary.h" #include "webnn_native/ops/Constant.h" @@ -42,92 +43,105 @@ for (;;) \ break +#define BUILD_ERROR_AND_CALLBACK(message) \ + do { \ + callback(MLBuildStatus_Error, nullptr, message, userdata); \ + return; \ + } while (0) + namespace webnn_native { - ModelBuilderBase::ModelBuilderBase(NeuralNetworkContextBase* context) : ObjectBase(context) { + GraphBuilderBase::GraphBuilderBase(ContextBase* context) : ObjectBase(context) { } - OperandBase* ModelBuilderBase::Constant(OperandDescriptor const* desc, + OperandBase* GraphBuilderBase::Constant(OperandDescriptor const* desc, void const* value, size_t size) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Constant(this, desc, value, size)); } - OperandBase* ModelBuilderBase::Input(char const* name, OperandDescriptor const* desc) { + OperandBase* GraphBuilderBase::Input(char const* name, OperandDescriptor const* desc) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Input(this, std::string(name), desc)); } - OperandBase* ModelBuilderBase::Matmul(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::Matmul(OperandBase* a, OperandBase* b) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Binary(this, op::BinaryOpType::kMatMul, a, b)); } - OperandBase* ModelBuilderBase::Add(OperandBase* a, OperandBase* b) { + OperandBase* GraphBuilderBase::Add(OperandBase* a, OperandBase* b) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Binary(this, op::BinaryOpType::kAdd, a, b)); } - OperandBase* ModelBuilderBase::Mul(OperandBase* a, OperandBase* b) { + OperandBase* 
GraphBuilderBase::Mul(OperandBase* a, OperandBase* b) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Binary(this, op::BinaryOpType::kMul, a, b)); } - OperandBase* ModelBuilderBase::Conv2d(OperandBase* input, + OperandBase* GraphBuilderBase::Conv2d(OperandBase* input, OperandBase* filter, Conv2dOptions const* options) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Conv2d(this, input, filter, options)); } - OperandBase* ModelBuilderBase::AveragePool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::AveragePool2d(OperandBase* input, Pool2dOptions const* options) { DAWN_VALIDATE_AND_INFER_TYPES( new op::Pool2d(this, op::Pool2dType::kAveragePool2d, input, options)); } - OperandBase* ModelBuilderBase::MaxPool2d(OperandBase* input, Pool2dOptions const* options) { + OperandBase* GraphBuilderBase::MaxPool2d(OperandBase* input, Pool2dOptions const* options) { DAWN_VALIDATE_AND_INFER_TYPES( new op::Pool2d(this, op::Pool2dType::kMaxPool2d, input, options)); } - OperandBase* ModelBuilderBase::Relu(OperandBase* input) { + OperandBase* GraphBuilderBase::Relu(OperandBase* input) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Unary(this, op::UnaryOpType::kRelu, input)); } - OperandBase* ModelBuilderBase::Reshape(OperandBase* input, + OperandBase* GraphBuilderBase::Reshape(OperandBase* input, int32_t const* new_shape, size_t new_shape_count) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Reshape(this, input, new_shape, new_shape_count)); } - OperandBase* ModelBuilderBase::Softmax(OperandBase* input) { + OperandBase* GraphBuilderBase::Softmax(OperandBase* input) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Unary(this, op::UnaryOpType::kSoftmax, input)); } - OperandBase* ModelBuilderBase::Transpose(OperandBase* input, TransposeOptions const* options) { + OperandBase* GraphBuilderBase::Transpose(OperandBase* input, TransposeOptions const* options) { DAWN_VALIDATE_AND_INFER_TYPES(new op::Transpose(this, input, options)); } - ModelBase* ModelBuilderBase::CreateModel(NamedOperandsBase 
const* namedOperands) { - Ref model = AcquireRef(CreateModelImpl()); + void GraphBuilderBase::Build(NamedOperandsBase const* namedOperands, + MLBuildCallback callback, + void* userdata) { + if (DAWN_UNLIKELY(this->IsError())) { + BUILD_ERROR_AND_CALLBACK("This Graph object is an error"); + } + std::vector outputs; if (namedOperands->GetRecords().empty()) { - return ModelBase::MakeError(this); + BUILD_ERROR_AND_CALLBACK("The output named operands are empty."); } for (auto& namedOutput : namedOperands->GetRecords()) { outputs.push_back(namedOutput.second); } std::vector sorted_operands = TopologicalSort(outputs); + Ref graph = AcquireRef(GetContext()->CreateGraph()); for (auto& op : sorted_operands) { - if (op->IsError() || GetContext()->ConsumedError(op->AddToModel(model.Get()))) { - return ModelBase::MakeError(this); + if (op->IsError() || GetContext()->ConsumedError(op->AddToGraph(graph.Get()))) { + BUILD_ERROR_AND_CALLBACK("Failed to add the operand when building graph."); } } for (auto& namedOutput : namedOperands->GetRecords()) { if (GetContext()->ConsumedError( - model->AddOutput(namedOutput.first, namedOutput.second))) { - return ModelBase::MakeError(this); + graph->AddOutput(namedOutput.first, namedOutput.second))) { + BUILD_ERROR_AND_CALLBACK("Failed to add output when building graph."); } } - if (GetContext()->ConsumedError(model->Finish())) { - return ModelBase::MakeError(this); + if (GetContext()->ConsumedError(graph->Finish())) { + BUILD_ERROR_AND_CALLBACK("Failed to finish building graph."); } - return model.Detach(); + callback(MLBuildStatus_Success, reinterpret_cast(graph.Detach()), nullptr, + userdata); } // The implementation derives from nGraph topological_sort in @@ -148,7 +162,7 @@ namespace webnn_native { // See the License for the specific language governing permissions and // limitations under the License. 
//***************************************************************************** - std::vector ModelBuilderBase::TopologicalSort( + std::vector GraphBuilderBase::TopologicalSort( std::vector& rootNodes) { std::stack nodesToDo; std::unordered_set nodesDone; diff --git a/src/webnn_native/ModelBuilder.h b/src/webnn_native/GraphBuilder.h similarity index 86% rename from src/webnn_native/ModelBuilder.h rename to src/webnn_native/GraphBuilder.h index 45c9074da..70610faed 100644 --- a/src/webnn_native/ModelBuilder.h +++ b/src/webnn_native/GraphBuilder.h @@ -25,10 +25,10 @@ namespace webnn_native { - class ModelBuilderBase : public ObjectBase { + class GraphBuilderBase : public ObjectBase { public: - ModelBuilderBase(NeuralNetworkContextBase* context); - virtual ~ModelBuilderBase() = default; + GraphBuilderBase(ContextBase* context); + virtual ~GraphBuilderBase() = default; // WebNN API OperandBase* Constant(OperandDescriptor const* desc, void const* value, size_t size); @@ -43,12 +43,11 @@ namespace webnn_native { OperandBase* Reshape(OperandBase*, int32_t const*, size_t); OperandBase* Softmax(OperandBase*); OperandBase* Transpose(OperandBase*, TransposeOptions const* options); - ModelBase* CreateModel(NamedOperandsBase const* named_operands); + void Build(NamedOperandsBase const* named_operands, + MLBuildCallback callback, + void* userdata); private: - // Create concrete model. - virtual ModelBase* CreateModelImpl() = 0; - // Topological sort of nodes needed to compute rootNodes std::vector TopologicalSort(std::vector& rootNodes); }; diff --git a/src/webnn_native/Model.cpp b/src/webnn_native/Model.cpp deleted file mode 100644 index 39691db1d..000000000 --- a/src/webnn_native/Model.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "webnn_native/Model.h" - -#include - -#include "common/Assert.h" -#include "common/RefCounted.h" - -namespace webnn_native { - - ModelBase::ModelBase(ModelBuilderBase* modelBuilder) : ObjectBase(modelBuilder->GetContext()) { - } - - void ModelBase::Compile(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options) { - if (DAWN_UNLIKELY(this->IsError())) { - callback(WebnnCompileStatus_Error, nullptr, "This Model object is an error", userdata); - return; - } - CompileImpl(callback, userdata, options); - } - - ModelBase::ModelBase(ModelBuilderBase* modelBuilder, ObjectBase::ErrorTag tag) - : ObjectBase(modelBuilder->GetContext(), tag) { - } - - // static - ModelBase* ModelBase::MakeError(ModelBuilderBase* modelBuilder) { - return new ModelBase(modelBuilder, ObjectBase::kError); - } - - MaybeError ModelBase::AddConstant(const op::Constant* constant) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddInput(const op::Input* input) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddOutput(const std::string& name, const OperandBase* output) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddBinary(const op::Binary* binary) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddConv2d(const op::Conv2d* conv2d) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddPool2d(const op::Pool2d* pool2d) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddReshape(const op::Reshape* relu) { - UNREACHABLE(); - } - - MaybeError ModelBase::AddTranspose(const op::Transpose* transpose) { - UNREACHABLE(); - } - - MaybeError 
ModelBase::AddUnary(const op::Unary* unary) { - UNREACHABLE(); - } - - MaybeError ModelBase::Finish() { - UNREACHABLE(); - } - - void ModelBase::CompileImpl(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options) { - UNREACHABLE(); - } - -} // namespace webnn_native diff --git a/src/webnn_native/ObjectBase.cpp b/src/webnn_native/ObjectBase.cpp index 23c6077ec..4993a6517 100644 --- a/src/webnn_native/ObjectBase.cpp +++ b/src/webnn_native/ObjectBase.cpp @@ -18,15 +18,14 @@ namespace webnn_native { static constexpr uint64_t kErrorPayload = 0; static constexpr uint64_t kNotErrorPayload = 1; - ObjectBase::ObjectBase(NeuralNetworkContextBase* context) - : RefCounted(kNotErrorPayload), mContext(context) { + ObjectBase::ObjectBase(ContextBase* context) : RefCounted(kNotErrorPayload), mContext(context) { } - ObjectBase::ObjectBase(NeuralNetworkContextBase* context, ErrorTag) + ObjectBase::ObjectBase(ContextBase* context, ErrorTag) : RefCounted(kErrorPayload), mContext(context) { } - NeuralNetworkContextBase* ObjectBase::GetContext() const { + ContextBase* ObjectBase::GetContext() const { return mContext; } diff --git a/src/webnn_native/ObjectBase.h b/src/webnn_native/ObjectBase.h index 3f7d7d4e5..48cda1da7 100644 --- a/src/webnn_native/ObjectBase.h +++ b/src/webnn_native/ObjectBase.h @@ -15,7 +15,7 @@ #ifndef WEBNN_NATIVE_OBJECT_BASE_H_ #define WEBNN_NATIVE_OBJECT_BASE_H_ -#include "webnn_native/NeuralNetworkContext.h" +#include "webnn_native/Context.h" namespace webnn_native { @@ -24,17 +24,17 @@ namespace webnn_native { struct ErrorTag {}; static constexpr ErrorTag kError = {}; - explicit ObjectBase(NeuralNetworkContextBase* context); - ObjectBase(NeuralNetworkContextBase* context, ErrorTag tag); + explicit ObjectBase(ContextBase* context); + ObjectBase(ContextBase* context, ErrorTag tag); - NeuralNetworkContextBase* GetContext() const; + ContextBase* GetContext() const; bool IsError() const; protected: ~ObjectBase() override = default; 
private: - NeuralNetworkContextBase* mContext; + ContextBase* mContext; }; } // namespace webnn_native diff --git a/src/webnn_native/Operand.cpp b/src/webnn_native/Operand.cpp index c019bb85b..2c3200e74 100644 --- a/src/webnn_native/Operand.cpp +++ b/src/webnn_native/Operand.cpp @@ -17,23 +17,24 @@ #include "common/Assert.h" #include "common/Log.h" +#include "webnn_native/GraphBuilder.h" namespace webnn_native { - OperandBase::OperandBase(ModelBuilderBase* modelBuilder, std::vector> inputs) - : ObjectBase(modelBuilder->GetContext()), mInputs(std::move(inputs)) { + OperandBase::OperandBase(GraphBuilderBase* graphBuilder, std::vector> inputs) + : ObjectBase(graphBuilder->GetContext()), mInputs(std::move(inputs)) { } - OperandBase::OperandBase(ModelBuilderBase* modelBuilder, ObjectBase::ErrorTag tag) - : ObjectBase(modelBuilder->GetContext(), tag) { + OperandBase::OperandBase(GraphBuilderBase* graphBuilder, ObjectBase::ErrorTag tag) + : ObjectBase(graphBuilder->GetContext(), tag) { } // static - OperandBase* OperandBase::MakeError(ModelBuilderBase* modelBuilder) { - return new OperandBase(modelBuilder, ObjectBase::kError); + OperandBase* OperandBase::MakeError(GraphBuilderBase* GraphBuilder) { + return new OperandBase(GraphBuilder, ObjectBase::kError); } - MaybeError OperandBase::AddToModel(ModelBase* model) const { + MaybeError OperandBase::AddToGraph(GraphBase* model) const { DAWN_UNREACHABLE(); } diff --git a/src/webnn_native/Operand.h b/src/webnn_native/Operand.h index 92b908940..a1fe9a771 100644 --- a/src/webnn_native/Operand.h +++ b/src/webnn_native/Operand.h @@ -19,7 +19,7 @@ #include #include "webnn_native/Forward.h" -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/ObjectBase.h" #include "webnn_native/webnn_platform.h" @@ -27,34 +27,34 @@ namespace webnn_native { class OperandBase : public ObjectBase { public: - explicit OperandBase(ModelBuilderBase* modelBuilder, std::vector> = {}); + explicit 
OperandBase(GraphBuilderBase* GraphBuilder, std::vector> = {}); virtual ~OperandBase() = default; // It's used for getting inputs when traversaling model tree. const std::vector>& Inputs() const; // Add the operand to model for specific backend. - virtual MaybeError AddToModel(ModelBase* model) const; + virtual MaybeError AddToGraph(GraphBase* model) const; - webnn::OperandType Type() const { + ml::OperandType Type() const { return mType; } int32_t Rank() const { return mRank; } - static OperandBase* MakeError(ModelBuilderBase* modelBuilder); + static OperandBase* MakeError(GraphBuilderBase* GraphBuilder); virtual MaybeError ValidateAndInferTypes() { UNREACHABLE(); } private: - OperandBase(ModelBuilderBase* modelBuilder, ObjectBase::ErrorTag tag); + OperandBase(GraphBuilderBase* GraphBuilder, ObjectBase::ErrorTag tag); protected: // The inputs of operand. std::vector> mInputs; // The operand type. - webnn::OperandType mType; + ml::OperandType mType; // only set rank for dimensions int32_t mRank; }; diff --git a/src/webnn_native/WebnnNative.cpp b/src/webnn_native/WebnnNative.cpp index ad0371b8c..c085e6fec 100644 --- a/src/webnn_native/WebnnNative.cpp +++ b/src/webnn_native/WebnnNative.cpp @@ -18,8 +18,7 @@ #include #include "common/Assert.h" -#include "webnn_native/Compilation.h" -#include "webnn_native/ModelBuilder.h" +#include "webnn_native/GraphBuilder.h" // Contains the entry-points into webnn_native namespace webnn_native { @@ -30,23 +29,23 @@ namespace webnn_native { } namespace null { - NeuralNetworkContextBase* Create(); + ContextBase* Create(MLContextOptions const* options); } namespace ie { - NeuralNetworkContextBase* Create(); + ContextBase* Create(MLContextOptions const* options); } namespace dml { - NeuralNetworkContextBase* Create(); + ContextBase* Create(MLContextOptions const* options); } // Should put the default null backend at the end. 
- WebnnNeuralNetworkContext CreateNeuralNetworkContext() { + MLContext CreateContext(MLContextOptions const* options) { #if defined(WEBNN_ENABLE_BACKEND_OPENVINO) - return reinterpret_cast(ie::Create()); + return reinterpret_cast(ie::Create(options)); #elif defined(WEBNN_ENABLE_BACKEND_DML) - return reinterpret_cast(dml::Create()); + return reinterpret_cast(dml::Create(options)); #elif defined(WEBNN_ENABLE_BACKEND_NULL) - return reinterpret_cast(null::Create()); + return reinterpret_cast(null::Create(options)); #else return nullptr; #endif diff --git a/src/webnn_native/dml/CompilationDML.cpp b/src/webnn_native/dml/CompilationDML.cpp deleted file mode 100644 index 775ff59d9..000000000 --- a/src/webnn_native/dml/CompilationDML.cpp +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "webnn_native/dml/CompilationDML.h" - -#include - -#include "common/Log.h" -#include "webnn_native/NamedResults.h" -#include "webnn_native/Operand.h" -#include "webnn_native/Result.h" -#include "webnn_native/dml/deps/src/precomp.h" - -namespace webnn_native { namespace dml { - - class Result : public ResultBase { - public: - explicit Result(void* buffer, uint32_t buffer_size, std::vector& dimensions) - : ResultBase(buffer, buffer_size, dimensions) { - } - ~Result() { - free(mBuffer); - } - }; - - Compilation::Compilation(const Ref& model) : mModel(model) { - std::vector<::dml::Expression> outputs; - for (auto& output : mModel->mOutputs) { - outputs.push_back(output.second); - } - // TODO(nhu): investigate other execution flag, - // e.g. DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION - mCompiledModel.reset( - new pydml::CompiledModel(*(mModel->mGraph), DML_EXECUTION_FLAG_NONE, outputs)); - } - - void Compilation::ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs) { - for (auto& input : inputs->GetRecords()) { - ::pydml::Binding* inputBinding = mModel->mInputs.at(input.first); - inputBinding->data.buffer = const_cast(input.second->buffer); - inputBinding->data.size = input.second->size; - } - std::vector inputBindings; - for (auto& binding : mModel->mBindings) { - inputBindings.push_back(binding.get()); - } - std::vector<::dml::Expression*> outputExpressions; - std::vector outputNames; - if (outputs != nullptr) { - for (auto& output : outputs->GetRecords()) { - outputNames.push_back(output.first); - outputExpressions.push_back(&(mModel->mOutputs.at(output.first))); - } - } else { - for (auto& output : mModel->mOutputs) { - outputNames.push_back(output.first); - outputExpressions.push_back(&(output.second)); - } - } - std::vector outputTensors; - if (FAILED(mModel->mDevice->DispatchOperator(mCompiledModel->op.Get(), inputBindings, - outputExpressions, outputTensors))) { - 
callback(WebnnComputeStatus_Error, nullptr, "Failed to dispatch operator", userdata); - return; - } - - Ref results = AcquireRef(new NamedResultsBase()); - for (size_t i = 0; i < outputNames.size(); ++i) { - std::string outputName = outputNames[i]; - pydml::TensorData* tensor = outputTensors[i]; - void* outputBuffer = tensor->Get(); - size_t bufferLength = tensor->Size(); - std::vector dimensions; - for (auto size : tensor->Desc()->sizes) { - // convert from uint32_t to int32_t. - dimensions.push_back(static_cast(size)); - } - Ref result = AcquireRef(new Result(outputBuffer, bufferLength, dimensions)); - results->Set(outputName.c_str(), result.Detach()); - if (outputs != nullptr) { - const Output* output = outputs->GetRecords().at(outputName); - if (output->size >= bufferLength) { - memcpy(output->buffer, outputBuffer, bufferLength); - } - } - delete tensor; - } - callback(WebnnComputeStatus_Success, reinterpret_cast(results.Detach()), - nullptr, userdata); - return; - } - -}} // namespace webnn_native::dml diff --git a/src/webnn_native/dml/CompilationDML.h b/src/webnn_native/dml/CompilationDML.h deleted file mode 100644 index f1d82c1c9..000000000 --- a/src/webnn_native/dml/CompilationDML.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef WEBNN_NATIVE_DML_COMPILATION_DML_H_ -#define WEBNN_NATIVE_DML_COMPILATION_DML_H_ - -#include "webnn_native/Compilation.h" -#include "webnn_native/dml/ModelDML.h" - -namespace pydml { - struct CompiledModel; -} - -namespace webnn_native { namespace dml { - - class Compilation : public CompilationBase { - public: - explicit Compilation(const Ref& model); - ~Compilation() override = default; - - IDMLCompiledOperator* GetCompiledOperator() { - return mCompiledModel->op.Get(); - } - - private: - void ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs = nullptr) override; - - Ref mModel; - std::unique_ptr mCompiledModel; - }; - -}} // namespace webnn_native::dml - -#endif // WEBNN_NATIVE_DML_COMPILATION_DML_H_ diff --git a/src/webnn_native/dml/NeuralNetworkContextDML.cpp b/src/webnn_native/dml/ContextDML.cpp similarity index 63% rename from src/webnn_native/dml/NeuralNetworkContextDML.cpp rename to src/webnn_native/dml/ContextDML.cpp index e1fba77d2..5bf0ea58f 100644 --- a/src/webnn_native/dml/NeuralNetworkContextDML.cpp +++ b/src/webnn_native/dml/ContextDML.cpp @@ -12,23 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "webnn_native/dml/NeuralNetworkContextDML.h" +#include "webnn_native/dml/ContextDML.h" #include "common/RefCounted.h" -#include "webnn_native/dml/ModelBuilderDML.h" +#include "webnn_native/dml/GraphDML.h" namespace webnn_native { namespace dml { - NeuralNetworkContextBase* Create() { - Ref context = AcquireRef(new NeuralNetworkContext()); - if (FAILED(reinterpret_cast(context.Get())->CreateDevice())) { + ContextBase* Create(MLContextOptions const* options) { + Ref context = + AcquireRef(new Context(reinterpret_cast(options))); + if (FAILED(reinterpret_cast(context.Get())->CreateDevice())) { dawn::ErrorLog() << "Failed to create DirectML device."; return nullptr; } return context.Detach(); } - HRESULT NeuralNetworkContext::CreateDevice() { + Context::Context(ContextOptions const* options) { + if (options == nullptr) { + return; + } + mOptions = *options; + } + + HRESULT Context::CreateDevice() { #if defined(_DEBUG) mDevice.reset(new ::pydml::Device(true, true)); #else @@ -37,8 +45,8 @@ namespace webnn_native { namespace dml { return mDevice->Init(); } - ModelBuilderBase* NeuralNetworkContext::CreateModelBuilderImpl() { - return new ModelBuilder(this); + GraphBase* Context::CreateGraphImpl() { + return new Graph(this); } }} // namespace webnn_native::dml diff --git a/src/webnn_native/dml/NeuralNetworkContextDML.h b/src/webnn_native/dml/ContextDML.h similarity index 68% rename from src/webnn_native/dml/NeuralNetworkContextDML.h rename to src/webnn_native/dml/ContextDML.h index 79a3b905b..83aaf0426 100644 --- a/src/webnn_native/dml/NeuralNetworkContextDML.h +++ b/src/webnn_native/dml/ContextDML.h @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef WEBNN_NATIVE_DML_NEURAL_NETWORK_CONTEXT_DML_H_ -#define WEBNN_NATIVE_DML_NEURAL_NETWORK_CONTEXT_DML_H_ +#ifndef WEBNN_NATIVE_DML_CONTEXT_DML_H_ +#define WEBNN_NATIVE_DML_CONTEXT_DML_H_ -#include "webnn_native/NeuralNetworkContext.h" +#include "webnn_native/Context.h" +#include "webnn_native/Graph.h" #include "webnn_native/dml/deps/src/precomp.h" namespace webnn_native { namespace dml { - class NeuralNetworkContext : public NeuralNetworkContextBase { + class Context : public ContextBase { public: - NeuralNetworkContext() = default; - ~NeuralNetworkContext() override = default; - - ModelBuilderBase* CreateModelBuilderImpl() override; + Context(ContextOptions const* options); + ~Context() override = default; HRESULT CreateDevice(); + GraphBase* CreateGraphImpl() override; std::shared_ptr<::pydml::Device> GetDevice() { return mDevice; @@ -35,8 +35,9 @@ namespace webnn_native { namespace dml { private: std::shared_ptr<::pydml::Device> mDevice; + ContextOptions mOptions; }; }} // namespace webnn_native::dml -#endif // WEBNN_NATIVE_DML_NEURAL_NETWORK_CONTEXT_DML_H_ +#endif // WEBNN_NATIVE_DML_CONTEXT_DML_H_ diff --git a/src/webnn_native/dml/ModelDML.cpp b/src/webnn_native/dml/GraphDML.cpp similarity index 84% rename from src/webnn_native/dml/ModelDML.cpp rename to src/webnn_native/dml/GraphDML.cpp index 0de3decf5..b3d21d77f 100644 --- a/src/webnn_native/dml/ModelDML.cpp +++ b/src/webnn_native/dml/GraphDML.cpp @@ -12,26 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "webnn_native/dml/ModelDML.h" +#include "webnn_native/dml/GraphDML.h" #include "common/Assert.h" #include "common/Log.h" #include "webnn_native/ErrorData.h" -#include "webnn_native/dml/CompilationDML.h" -#include "webnn_native/dml/NeuralNetworkContextDML.h" +#include "webnn_native/NamedInputs.h" +#include "webnn_native/NamedOutputs.h" +#include "webnn_native/NamedResults.h" +#include "webnn_native/Operand.h" +#include "webnn_native/Result.h" +#include "webnn_native/dml/ContextDML.h" +#include "webnn_native/dml/deps/src/precomp.h" namespace webnn_native { namespace dml { + class Result : public ResultBase { + public: + explicit Result(void* buffer, uint32_t buffer_size, std::vector& dimensions) + : ResultBase(buffer, buffer_size, dimensions) { + } + ~Result() { + free(mBuffer); + } + }; namespace { - bool GetDmlTensorDataType(webnn::OperandType operandType, + bool GetDmlTensorDataType(ml::OperandType operandType, DML_TENSOR_DATA_TYPE& dmlTensorDataType) { - if (operandType == webnn::OperandType::Float32) { + if (operandType == ml::OperandType::Float32) { dmlTensorDataType = DML_TENSOR_DATA_TYPE_FLOAT32; - } else if (operandType == webnn::OperandType::Float16) { + } else if (operandType == ml::OperandType::Float16) { dmlTensorDataType = DML_TENSOR_DATA_TYPE_FLOAT16; - } else if (operandType == webnn::OperandType::Int32) { + } else if (operandType == ml::OperandType::Int32) { dmlTensorDataType = DML_TENSOR_DATA_TYPE_INT32; - } else if (operandType == webnn::OperandType::Uint32) { + } else if (operandType == ml::OperandType::Uint32) { dmlTensorDataType = DML_TENSOR_DATA_TYPE_UINT32; } else { return false; @@ -230,12 +244,12 @@ namespace webnn_native { namespace dml { return std::to_string(type); } - Model::Model(ModelBuilder* modelBuilder) : ModelBase(modelBuilder) { - mDevice = reinterpret_cast(modelBuilder->GetContext())->GetDevice(); + Graph::Graph(Context* context) : GraphBase(context) { + mDevice = context->GetDevice(); mGraph.reset(new 
::dml::Graph(mDevice->GetDevice())); } - MaybeError Model::AddConstant(const op::Constant* constant) { + MaybeError Graph::AddConstant(const op::Constant* constant) { const OperandDescriptor* desc = constant->GetOperandDescriptor(); DML_TENSOR_DATA_TYPE dmlTensorType; if (!GetDmlTensorDataType(desc->type, dmlTensorType)) { @@ -260,7 +274,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddInput(const op::Input* input) { + MaybeError Graph::AddInput(const op::Input* input) { const OperandDescriptor* desc = input->GetOperandDescriptor(); DML_TENSOR_DATA_TYPE dmlTensorType; if (!GetDmlTensorDataType(desc->type, dmlTensorType)) { @@ -280,14 +294,14 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddOutput(const std::string& name, const OperandBase* output) { + MaybeError Graph::AddOutput(const std::string& name, const OperandBase* output) { DAWN_ASSERT(mExpression.find(output) != mExpression.end()); ::dml::Expression dmlOutput = mExpression.at(output); mOutputs.insert(std::make_pair(name, dmlOutput)); return {}; } - MaybeError Model::AddBinary(const op::Binary* binary) { + MaybeError Graph::AddBinary(const op::Binary* binary) { DAWN_ASSERT(binary->Inputs().size() == 2); DAWN_ASSERT(mExpression.find(binary->Inputs()[0].Get()) != mExpression.end()); ::dml::Expression a = mExpression.at(binary->Inputs()[0].Get()); @@ -393,7 +407,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddConv2d(const op::Conv2d* conv2d) { + MaybeError Graph::AddConv2d(const op::Conv2d* conv2d) { DAWN_ASSERT(conv2d->Inputs().size() == 2); const OperandBase* inputOperand = conv2d->Inputs()[0].Get(); DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end()); @@ -428,7 +442,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddPool2d(const op::Pool2d* pool2d) { + MaybeError Graph::AddPool2d(const op::Pool2d* pool2d) { DAWN_ASSERT(pool2d->Inputs().size() == 1); const OperandBase* 
inputOperand = pool2d->Inputs()[0].Get(); DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end()); @@ -476,7 +490,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddReshape(const op::Reshape* reshape) { + MaybeError Graph::AddReshape(const op::Reshape* reshape) { DAWN_ASSERT(reshape->Inputs().size() == 1); const OperandBase* inputOperand = reshape->Inputs()[0].Get(); DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end()); @@ -519,7 +533,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddTranspose(const op::Transpose* transpose) { + MaybeError Graph::AddTranspose(const op::Transpose* transpose) { DAWN_ASSERT(transpose->Inputs().size() == 1); const OperandBase* inputOperand = transpose->Inputs()[0].Get(); DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end()); @@ -577,7 +591,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::AddUnary(const op::Unary* unary) { + MaybeError Graph::AddUnary(const op::Unary* unary) { DAWN_ASSERT(unary->Inputs().size() == 1); const OperandBase* inputOperand = unary->Inputs()[0].Get(); DAWN_ASSERT(mExpression.find(inputOperand) != mExpression.end()); @@ -597,7 +611,7 @@ namespace webnn_native { namespace dml { return {}; } - MaybeError Model::Finish() { + MaybeError Graph::Finish() { if (mOutputs.size() == 1) { auto output = mOutputs.begin(); if (output->second.Impl()->GetNode().type == ::dml::detail::NodeType::Reinterpret) { @@ -608,25 +622,84 @@ namespace webnn_native { namespace dml { mOutputs[name] = ::dml::ActivationIdentity(reshape); } } + + // FIXME(nhu): implement async + std::vector<::dml::Expression> outputs; + for (auto& output : mOutputs) { + outputs.push_back(output.second); + } + // TODO(nhu): investigate other execution flag, + // e.g. 
DML_EXECUTION_FLAG_ALLOW_HALF_PRECISION_COMPUTATION + mCompiledModel.reset(new pydml::CompiledModel(*(mGraph), DML_EXECUTION_FLAG_NONE, outputs)); + + std::vector inputBindings; + for (auto& binding : mBindings) { + inputBindings.push_back(binding.get()); + } + if (FAILED(mDevice->InitializeOperator(mCompiledModel->op.Get(), inputBindings))) { + return DAWN_INTERNAL_ERROR("Failed to initialize operator"); + } + return {}; } - void Model::CompileImpl(WebnnCompileCallback callback, + void Graph::ComputeImpl(NamedInputsBase* inputs, + MLComputeCallback callback, void* userdata, - CompilationOptions const* options) { - // FIXME(nhu): implement async - WebnnCompileStatus status = WebnnCompileStatus_Success; - Compilation* compilation = new Compilation(this); + NamedOutputsBase* outputs) { + for (auto& input : inputs->GetRecords()) { + ::pydml::Binding* inputBinding = mInputs.at(input.first); + inputBinding->data.buffer = const_cast(input.second->buffer); + inputBinding->data.size = input.second->size; + } std::vector inputBindings; for (auto& binding : mBindings) { inputBindings.push_back(binding.get()); } - if (FAILED( - mDevice->InitializeOperator(compilation->GetCompiledOperator(), inputBindings))) { - callback(WebnnCompileStatus_Error, nullptr, "Failed to initialize operator", userdata); + std::vector<::dml::Expression*> outputExpressions; + std::vector outputNames; + if (outputs != nullptr) { + for (auto& output : outputs->GetRecords()) { + outputNames.push_back(output.first); + outputExpressions.push_back(&(mOutputs.at(output.first))); + } } else { - callback(status, reinterpret_cast(compilation), nullptr, userdata); + for (auto& output : mOutputs) { + outputNames.push_back(output.first); + outputExpressions.push_back(&(output.second)); + } + } + std::vector outputTensors; + if (FAILED(mDevice->DispatchOperator(mCompiledModel->op.Get(), inputBindings, + outputExpressions, outputTensors))) { + callback(MLComputeStatus_Error, nullptr, "Failed to dispatch operator", 
userdata); + return; + } + + Ref results = AcquireRef(new NamedResultsBase()); + for (size_t i = 0; i < outputNames.size(); ++i) { + std::string outputName = outputNames[i]; + pydml::TensorData* tensor = outputTensors[i]; + void* outputBuffer = tensor->Get(); + size_t bufferLength = tensor->Size(); + std::vector dimensions; + for (auto size : tensor->Desc()->sizes) { + // convert from uint32_t to int32_t. + dimensions.push_back(static_cast(size)); + } + Ref result = AcquireRef(new Result(outputBuffer, bufferLength, dimensions)); + results->Set(outputName.c_str(), result.Detach()); + if (outputs != nullptr) { + const Output* output = outputs->GetRecords().at(outputName); + if (output->size >= bufferLength) { + memcpy(output->buffer, outputBuffer, bufferLength); + } + } + delete tensor; } + callback(MLComputeStatus_Success, reinterpret_cast(results.Detach()), + nullptr, userdata); + return; } }} // namespace webnn_native::dml diff --git a/src/webnn_native/dml/ModelDML.h b/src/webnn_native/dml/GraphDML.h similarity index 86% rename from src/webnn_native/dml/ModelDML.h rename to src/webnn_native/dml/GraphDML.h index e5e819970..f43663a2a 100644 --- a/src/webnn_native/dml/ModelDML.h +++ b/src/webnn_native/dml/GraphDML.h @@ -18,9 +18,9 @@ #include #include -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" -#include "webnn_native/dml/ModelBuilderDML.h" +#include "webnn_native/dml/ContextDML.h" #include "webnn_native/dml/deps/src/precomp.h" #include "webnn_native/ops/Binary.h" #include "webnn_native/ops/Constant.h" @@ -36,10 +36,10 @@ namespace webnn_native { namespace dml { std::string DmlTensorDimensionsToString(const ::dml::TensorDimensions&); std::string DmlTensorDataTypeToString(DML_TENSOR_DATA_TYPE type); - class Model : public ModelBase { + class Graph : public GraphBase { public: - explicit Model(ModelBuilder* model_builder); - ~Model() override = default; + explicit Graph(Context* context); + ~Graph() override = 
default; virtual MaybeError AddConstant(const op::Constant* constant) override; virtual MaybeError AddInput(const op::Input* input) override; @@ -55,9 +55,10 @@ namespace webnn_native { namespace dml { friend class Compilation; private: - void CompileImpl(WebnnCompileCallback callback, + void ComputeImpl(NamedInputsBase* inputs, + MLComputeCallback callback, void* userdata, - CompilationOptions const* options) override; + NamedOutputsBase* outputs) override; std::shared_ptr<::pydml::Device> mDevice; std::unique_ptr<::dml::Graph> mGraph; @@ -66,6 +67,7 @@ namespace webnn_native { namespace dml { std::vector> mConstantBuffers; std::map mInputs; std::map mOutputs; + std::unique_ptr mCompiledModel; }; }} // namespace webnn_native::dml diff --git a/src/webnn_native/dml/ModelBuilderDML.cpp b/src/webnn_native/dml/ModelBuilderDML.cpp deleted file mode 100644 index 8812005e7..000000000 --- a/src/webnn_native/dml/ModelBuilderDML.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "webnn_native/dml/ModelBuilderDML.h" - -#include "common/Log.h" -#include "webnn_native/dml/ModelDML.h" -#include "webnn_native/dml/deps/src/precomp.h" - -namespace webnn_native { namespace dml { - - ModelBuilder::ModelBuilder(NeuralNetworkContextBase* context) : ModelBuilderBase(context) { - } - - ModelBase* ModelBuilder::CreateModelImpl() { - return new Model(this); - } - -}} // namespace webnn_native::dml diff --git a/src/webnn_native/dml/ModelBuilderDML.h b/src/webnn_native/dml/ModelBuilderDML.h deleted file mode 100644 index 152c16489..000000000 --- a/src/webnn_native/dml/ModelBuilderDML.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef WEBNN_NATIVE_DML_MODEL_BUILDER_DML_H_ -#define WEBNN_NATIVE_DML_MODEL_BUILDER_DML_H_ - -#include "webnn_native/ModelBuilder.h" - -namespace webnn_native { namespace dml { - - class ModelBuilder : public ModelBuilderBase { - public: - explicit ModelBuilder(NeuralNetworkContextBase* context); - ~ModelBuilder() override = default; - - private: - ModelBase* CreateModelImpl() override; - }; - -}} // namespace webnn_native::dml - -#endif // WEBNN_NATIVE_DML_MODEL_BUILDER_DML_H_ diff --git a/src/webnn_native/null/ContextNull.cpp b/src/webnn_native/null/ContextNull.cpp new file mode 100644 index 000000000..0cf3a052a --- /dev/null +++ b/src/webnn_native/null/ContextNull.cpp @@ -0,0 +1,86 @@ +// Copyright 2021 The WebNN-native Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "webnn_native/null/ContextNull.h" +#include "common/RefCounted.h" + +namespace webnn_native { namespace null { + + // Context + ContextBase* Create(MLContextOptions const* options) { + return new Context(reinterpret_cast(options)); + } + + Context::Context(ContextOptions const* options) { + } + + GraphBase* Context::CreateGraphImpl() { + return new Graph(this); + } + + // GraphBuilder + GraphBuilder::GraphBuilder(ContextBase* context) : GraphBuilderBase(context) { + } + + // Graph + Graph::Graph(Context* context) : GraphBase(context) { + } + + void Graph::ComputeImpl(NamedInputsBase* inputs, + MLComputeCallback callback, + void* userdata, + NamedOutputsBase* outputs) { + } + + MaybeError Graph::AddConstant(const op::Constant* constant) { + return {}; + } + + MaybeError Graph::AddInput(const op::Input* input) { + return {}; + } + + MaybeError Graph::AddOutput(const std::string& name, const OperandBase* output) { + return {}; + } + + MaybeError Graph::AddBinary(const op::Binary* binary) { + return {}; + } + + MaybeError Graph::AddConv2d(const op::Conv2d* conv2d) { + return {}; + } + + MaybeError Graph::AddPool2d(const op::Pool2d* pool2d) { + return {}; + } + + MaybeError Graph::AddReshape(const op::Reshape* relu) { + return {}; + } + + MaybeError Graph::AddTranspose(const op::Transpose* transpose) { + return {}; + } + + MaybeError Graph::AddUnary(const op::Unary* unary) { + return {}; + } + + MaybeError Graph::Finish() { + return {}; + } + +}} // namespace webnn_native::null diff --git a/src/webnn_native/null/NeuralNetworkContextNull.h b/src/webnn_native/null/ContextNull.h similarity index 52% rename from src/webnn_native/null/NeuralNetworkContextNull.h rename to src/webnn_native/null/ContextNull.h index b80017a75..7545f0df5 100644 --- a/src/webnn_native/null/NeuralNetworkContextNull.h +++ b/src/webnn_native/null/ContextNull.h @@ -12,41 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef WEBNN_NATIVE_NULL_NEURAL_NETWORK_CONTEXT_NULL_H_ -#define WEBNN_NATIVE_NULL_NEURAL_NETWORK_CONTEXT_NULL_H_ +#ifndef WEBNN_NATIVE_NULL_CONTEXT_NULL_H_ +#define WEBNN_NATIVE_NULL_CONTEXT_NULL_H_ -#include "webnn_native/Compilation.h" -#include "webnn_native/Model.h" -#include "webnn_native/ModelBuilder.h" -#include "webnn_native/NeuralNetworkContext.h" +#include "webnn_native/Context.h" +#include "webnn_native/Graph.h" +#include "webnn_native/GraphBuilder.h" namespace webnn_native { namespace null { - // NeuralNetworkContext - class NeuralNetworkContext : public NeuralNetworkContextBase { + // Context + class Context : public ContextBase { public: - NeuralNetworkContext() = default; - ~NeuralNetworkContext() override = default; + explicit Context(ContextOptions const* options); + ~Context() override = default; private: - ModelBuilderBase* CreateModelBuilderImpl() override; + GraphBase* CreateGraphImpl() override; }; - // ModelBuilder - class ModelBuilder : public ModelBuilderBase { + // GraphBuilder + class GraphBuilder : public GraphBuilderBase { public: - explicit ModelBuilder(NeuralNetworkContextBase* context); - ~ModelBuilder() override = default; - - private: - ModelBase* CreateModelImpl() override; + explicit GraphBuilder(ContextBase* context); + ~GraphBuilder() override = default; }; - // Model - class Model : public ModelBase { + // Graph + class Graph : public GraphBase { public: - explicit Model(ModelBuilder* model_builder); - ~Model() override = default; + explicit Graph(Context* context); + ~Graph() override = default; virtual MaybeError AddConstant(const op::Constant* constant) override; virtual MaybeError AddInput(const op::Input* input) override; virtual MaybeError AddOutput(const std::string& name, const OperandBase* ouput) override; @@ -57,30 +53,14 @@ namespace webnn_native { namespace null { virtual MaybeError AddTranspose(const op::Transpose* transpose) override; virtual MaybeError AddUnary(const op::Unary* unary) override; virtual 
MaybeError Finish() override; - friend class Compilation; - - private: - void CompileImpl(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options) override; - }; - - // Compilation - class Compilation : public CompilationBase { - public: - Compilation() = default; - ~Compilation() override = default; - void Compile(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options); private: void ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, + MLComputeCallback callback, void* userdata, NamedOutputsBase* outputs = nullptr) override; }; }} // namespace webnn_native::null -#endif // WEBNN_NATIVE_NULL_NEURAL_NETWORK_CONTEXT_NULL_H_ +#endif // WEBNN_NATIVE_NULL_CONTEXT_NULL_H_ diff --git a/src/webnn_native/null/NeuralNetworkContextNull.cpp b/src/webnn_native/null/NeuralNetworkContextNull.cpp deleted file mode 100644 index f98c16418..000000000 --- a/src/webnn_native/null/NeuralNetworkContextNull.cpp +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "webnn_native/null/NeuralNetworkContextNull.h" -#include "common/RefCounted.h" - -namespace webnn_native { namespace null { - - // NeuralNetworkContext - NeuralNetworkContextBase* Create() { - return new NeuralNetworkContext(); - } - - ModelBuilderBase* NeuralNetworkContext::CreateModelBuilderImpl() { - return new ModelBuilder(this); - } - - // ModelBuilder - ModelBuilder::ModelBuilder(NeuralNetworkContextBase* context) : ModelBuilderBase(context) { - } - - ModelBase* ModelBuilder::CreateModelImpl() { - return new Model(this); - } - - // Model - Model::Model(ModelBuilder* model_builder) : ModelBase(model_builder) { - } - - void Model::CompileImpl(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options) { - Compilation* compilation = new Compilation(); - compilation->Compile(callback, userdata, options); - } - - MaybeError Model::AddConstant(const op::Constant* constant) { - return {}; - } - - MaybeError Model::AddInput(const op::Input* input) { - return {}; - } - - MaybeError Model::AddOutput(const std::string& name, const OperandBase* output) { - return {}; - } - - MaybeError Model::AddBinary(const op::Binary* binary) { - return {}; - } - - MaybeError Model::AddConv2d(const op::Conv2d* conv2d) { - return {}; - } - - MaybeError Model::AddPool2d(const op::Pool2d* pool2d) { - return {}; - } - - MaybeError Model::AddReshape(const op::Reshape* relu) { - return {}; - } - - MaybeError Model::AddTranspose(const op::Transpose* transpose) { - return {}; - } - - MaybeError Model::AddUnary(const op::Unary* unary) { - return {}; - } - - MaybeError Model::Finish() { - return {}; - } - - // Compilation - void Compilation::Compile(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options) { - callback(WebnnCompileStatus_Success, reinterpret_cast(this), nullptr, - userdata); - } - - void Compilation::ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* 
outputs) { - } - -}} // namespace webnn_native::null diff --git a/src/webnn_native/openvino/CompilationIE.cpp b/src/webnn_native/openvino/CompilationIE.cpp deleted file mode 100644 index 48ab16a73..000000000 --- a/src/webnn_native/openvino/CompilationIE.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "webnn_native/openvino/CompilationIE.h" - -#include - -#include "common/Log.h" -#include "webnn_native/Error.h" -#include "webnn_native/ErrorData.h" -#include "webnn_native/NamedResults.h" -#include "webnn_native/Operand.h" -#include "webnn_native/Result.h" -#include "webnn_native/openvino/ErrorIE.h" -#include "webnn_native/openvino/ienn_symbol_table/ienn_symbol_table.h" - -#define DAWN_CALLBACK_TRY(code, messages) \ - { \ - MaybeError maybeError = CheckStatusCode(code, messages); \ - if (maybeError.IsError()) { \ - std::unique_ptr error = maybeError.AcquireError(); \ - callback(status, nullptr, error->GetMessage().c_str(), userdata); \ - return; \ - } \ - } \ - for (;;) \ - break - -namespace webnn_native { namespace ie { - - class Result : public ResultBase { - public: - using ResultBase::Reference; - ~Result() { - ie_compilation_free_buffer(&mBuffer); - } - }; - - Compilation::Compilation(Ref model) : mModel(model) { - } - - Compilation::~Compilation() { - IE(ie_compilation_free)(mIeCompilation); - } - - void Compilation::Compile(WebnnCompileCallback callback, - void* 
userdata, - CompilationOptions const* options) { - // We may leverage https://dawn-review.googlesource.com/c/dawn/+/36360 to - // implement async compilation as standle-alone component. - WebnnCompileStatus status = WebnnCompileStatus_Error; - // Create compilation for IE backend. - IEStatusCode code = - IE(ie_create_compilation)(mModel->GetInferenceEngineModel(), &mIeCompilation); - DAWN_CALLBACK_TRY(code, "IE create compilation"); - status = WebnnCompileStatus_Success; - callback(status, reinterpret_cast(this), nullptr, userdata); - } - - void Compilation::ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs) { - WebnnComputeStatus status = WebnnComputeStatus_Error; - // Set input data to nGraph. - for (auto& input : inputs->GetRecords()) { - ie_operand_t ieOperand; - ieOperand.name = const_cast(mModel->mInputIdMap[input.first].c_str()); - IEStatusCode code = IE(ie_compilation_set_input)( - mIeCompilation, &ieOperand, input.second->buffer, input.second->size); - DAWN_CALLBACK_TRY(code, "IE set input"); - } - - // Compute the compiled model. - IEStatusCode code = IE(ie_compilation_compute)(mIeCompilation); - DAWN_CALLBACK_TRY(code, "IE compute model"); - // Get Data from nGraph with output. 
- Ref results = AcquireRef(new NamedResultsBase()); - size_t outputNumber = mModel->GetOutputsNumber(); - for (size_t i = 0; i < outputNumber; ++i) { - std::string outputId = mModel->GetOutputId(i); - void* outputBuffer; - size_t bufferLength; - IEStatusCode code = IE(ie_compilation_get_buffer)(mIeCompilation, outputId.data(), - &outputBuffer, &bufferLength); - DAWN_CALLBACK_TRY(code, "IE get buffer"); - ie_dimensions_t ieDimensions; - code = - IE(ie_compilation_get_dimensions)(mIeCompilation, outputId.data(), &ieDimensions); - DAWN_CALLBACK_TRY(code, "IE get dimensions"); - std::vector dimensions(ieDimensions.dims, - ieDimensions.dims + ieDimensions.ranks); - code = IE(ie_compilation_free_dimensions)(&ieDimensions); - Ref result = - AcquireRef(new Result::ResultBase(outputBuffer, bufferLength, dimensions)); - std::string outputName = mModel->mOutputNameMap[outputId]; - results->Set(outputName.c_str(), result.Detach()); - if (outputs != nullptr) { - const Output* output = outputs->GetRecords().at(outputName); - ie_operand_t ieOperand; - ieOperand.name = const_cast(outputId.c_str()); - IEStatusCode code = IE(ie_compilation_get_output)(mIeCompilation, &ieOperand, - output->buffer, output->size); - DAWN_CALLBACK_TRY(code, "IE get output"); - } - } - status = WebnnComputeStatus_Success; - callback(status, reinterpret_cast(results.Detach()), nullptr, userdata); - return; - } - -}} // namespace webnn_native::ie diff --git a/src/webnn_native/openvino/CompilationIE.h b/src/webnn_native/openvino/CompilationIE.h deleted file mode 100644 index 0f3c07e02..000000000 --- a/src/webnn_native/openvino/CompilationIE.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef WEBNN_NATIVE_IE_COMPILATION_IE_H_ -#define WEBNN_NATIVE_IE_COMPILATION_IE_H_ - -#include "webnn_native/Compilation.h" -#include "webnn_native/openvino/ModelIE.h" -#include "webnn_native/openvino/ienn/src/ie_nn_c_api.h" - -namespace webnn_native { namespace ie { - - class Compilation : public CompilationBase { - public: - Compilation(Ref model); - ~Compilation() override; - - void Compile(WebnnCompileCallback callback, - void* userdata, - CompilationOptions const* options); - - private: - void ComputeImpl(NamedInputsBase* inputs, - WebnnComputeCallback callback, - void* userdata, - NamedOutputsBase* outputs) override; - - Ref mModel; - ie_compilation_t* mIeCompilation; - }; - -}} // namespace webnn_native::ie - -#endif // WEBNN_NATIVE_IE_COMPILATION_IE_H_ diff --git a/src/webnn_native/openvino/NeuralNetworkContextIE.cpp b/src/webnn_native/openvino/ContextIE.cpp similarity index 70% rename from src/webnn_native/openvino/NeuralNetworkContextIE.cpp rename to src/webnn_native/openvino/ContextIE.cpp index fa83cc0bd..f165e220a 100644 --- a/src/webnn_native/openvino/NeuralNetworkContextIE.cpp +++ b/src/webnn_native/openvino/ContextIE.cpp @@ -12,27 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "webnn_native/openvino/NeuralNetworkContextIE.h" +#include "webnn_native/openvino/ContextIE.h" #include "common/Log.h" #include "common/RefCounted.h" -#include "webnn_native/openvino/ModelBuilderIE.h" +#include "webnn_native/openvino/GraphIE.h" #include "webnn_native/openvino/ienn_symbol_table/ienn_symbol_table.h" namespace webnn_native { namespace ie { - NeuralNetworkContextBase* Create() { + ContextBase* Create(MLContextOptions const* options) { // Load ienn_c_api library. if (!GetIESymbolTable()->Load()) { dawn::ErrorLog() << "Failed to load the OpenVINO libraries, please make sure the " "OpenVINO environment variables are set."; return nullptr; } - return new NeuralNetworkContext(); + return new Context(reinterpret_cast(options)); } - ModelBuilderBase* NeuralNetworkContext::CreateModelBuilderImpl() { - return new ModelBuilder(this); + Context::Context(ContextOptions const* options) { + if (options == nullptr) { + return; + } + mOptions = *options; + } + + GraphBase* Context::CreateGraphImpl() { + return new Graph(this); } }} // namespace webnn_native::ie diff --git a/src/webnn_native/openvino/ModelBuilderIE.h b/src/webnn_native/openvino/ContextIE.h similarity index 64% rename from src/webnn_native/openvino/ModelBuilderIE.h rename to src/webnn_native/openvino/ContextIE.h index 543353321..3e9bf70cc 100644 --- a/src/webnn_native/openvino/ModelBuilderIE.h +++ b/src/webnn_native/openvino/ContextIE.h @@ -12,22 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef WEBNN_NATIVE_IE_MODEL_BUILDER_IE_H_ -#define WEBNN_NATIVE_IE_MODEL_BUILDER_IE_H_ +#ifndef WEBNN_NATIVE_IE_CONTEXT_IE_H_ +#define WEBNN_NATIVE_IE_CONTEXT_IE_H_ -#include "webnn_native/ModelBuilder.h" +#include "webnn_native/Context.h" +#include "webnn_native/Graph.h" namespace webnn_native { namespace ie { - class ModelBuilder : public ModelBuilderBase { + class Context : public ContextBase { public: - explicit ModelBuilder(NeuralNetworkContextBase* context); - ~ModelBuilder() override = default; + explicit Context(ContextOptions const* options); + ~Context() override = default; private: - ModelBase* CreateModelImpl() override; + GraphBase* CreateGraphImpl() override; + + ContextOptions mOptions; }; }} // namespace webnn_native::ie -#endif // WEBNN_NATIVE_IE_MODEL_BUILDER_IE_H_ +#endif // WEBNN_NATIVE_IE_CONTEXT_IE_H_ diff --git a/src/webnn_native/openvino/ModelIE.cpp b/src/webnn_native/openvino/GraphIE.cpp similarity index 58% rename from src/webnn_native/openvino/ModelIE.cpp rename to src/webnn_native/openvino/GraphIE.cpp index 8dbbf6ecf..7d8da8774 100644 --- a/src/webnn_native/openvino/ModelIE.cpp +++ b/src/webnn_native/openvino/GraphIE.cpp @@ -12,36 +12,73 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "webnn_native/openvino/ModelIE.h" +#include "webnn_native/openvino/GraphIE.h" #include #include "common/Assert.h" #include "common/Log.h" #include "webnn_native/ErrorData.h" +#include "webnn_native/NamedInputs.h" #include "webnn_native/NamedOperands.h" -#include "webnn_native/openvino/CompilationIE.h" +#include "webnn_native/NamedOutputs.h" +#include "webnn_native/NamedResults.h" +#include "webnn_native/Result.h" #include "webnn_native/openvino/ErrorIE.h" #include "webnn_native/openvino/ienn_symbol_table/ienn_symbol_table.h" +#define COMPUTE_CALLBACK_TRY(code, messages) \ + { \ + MaybeError maybeError = CheckStatusCode(code, messages); \ + if (maybeError.IsError()) { \ + std::unique_ptr error = maybeError.AcquireError(); \ + callback(MLComputeStatus_Error, nullptr, error->GetMessage().c_str(), userdata); \ + return; \ + } \ + } \ + for (;;) \ + break + namespace webnn_native { namespace ie { + class Result : public ResultBase { + public: + using ResultBase::Reference; + ~Result() { + ie_compilation_free_buffer(&mBuffer); + } + }; namespace { + + std::string GetOutputId(ie_model_t* model, size_t index) { + char* outputName; + IEStatusCode code = IE(ie_model_get_output_name)(model, index, &outputName); + if (code != IEStatusCode::OK) { + dawn::ErrorLog() << "Failing to get output name for IE."; + return std::string(); + } + std::string name(outputName); + // The name has been kept in outputs object, so it can be free. 
+ IE(ie_model_free_name)(&outputName); + + return name; + } + ie_operand_descriptor ConvertTo(OperandDescriptor const* desc) { ie_operand_descriptor ieDesc; ieDesc.dimensions = desc->dimensions; ieDesc.dimensionsCount = desc->dimensionsCount; switch (desc->type) { - case webnn::OperandType::Float32: + case ml::OperandType::Float32: ieDesc.type = ie_operand_type::Float32; break; - case webnn::OperandType::Int32: + case ml::OperandType::Int32: ieDesc.type = ie_operand_type::Int32; break; - case webnn::OperandType::Float16: + case ml::OperandType::Float16: ieDesc.type = ie_operand_type::Float16; break; - case webnn::OperandType::Uint32: + case ml::OperandType::Uint32: ieDesc.type = ie_operand_type::Uint32; break; default: @@ -81,7 +118,7 @@ namespace webnn_native { namespace ie { } // namespace - Model::Model(ModelBuilder* modelBuilder) : ModelBase(modelBuilder) { + Graph::Graph(Context* context) : GraphBase(context) { // Create model. IEStatusCode code = IE(ie_create_model)(&mIeModel); if (code != IEStatusCode::OK) { @@ -90,11 +127,12 @@ namespace webnn_native { namespace ie { } } - Model::~Model() { + Graph::~Graph() { IE(ie_model_free)(mIeModel); + IE(ie_compilation_free)(mIeCompilation); } - MaybeError Model::AddConstant(const op::Constant* constant) { + MaybeError Graph::AddConstant(const op::Constant* constant) { ie_operand_descriptor ieDesc = ConvertTo(constant->GetOperandDescriptor()); ie_operand_t* ieOperand; IEStatusCode code = IE(ie_model_add_constant)(mIeModel, &ieDesc, constant->GetValue(), @@ -105,7 +143,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddInput(const op::Input* input) { + MaybeError Graph::AddInput(const op::Input* input) { ie_operand_descriptor ieDesc = ConvertTo(input->GetOperandDescriptor()); ie_operand_t* ieOperand; IEStatusCode code = IE(ie_model_add_input)(mIeModel, &ieDesc, &ieOperand); @@ -116,7 +154,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddOutput(const 
std::string& name, const OperandBase* output) { + MaybeError Graph::AddOutput(const std::string& name, const OperandBase* output) { ie_operand_t ieOperand; ieOperand.name = const_cast(mOperandIdMap[output].c_str()); IEStatusCode code = IE(ie_model_add_output)(mIeModel, &ieOperand); @@ -126,7 +164,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddBinary(const op::Binary* binary) { + MaybeError Graph::AddBinary(const op::Binary* binary) { auto inputs = binary->Inputs(); ie_operand_t primary; primary.name = const_cast(mOperandIdMap[inputs[0].Get()].c_str()); @@ -146,7 +184,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddConv2d(const op::Conv2d* conv2d) { + MaybeError Graph::AddConv2d(const op::Conv2d* conv2d) { auto inputs = conv2d->Inputs(); ie_operand_t input; input.name = const_cast(mOperandIdMap[inputs[0].Get()].c_str()); @@ -162,7 +200,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddPool2d(const op::Pool2d* pool2d) { + MaybeError Graph::AddPool2d(const op::Pool2d* pool2d) { auto inputs = pool2d->Inputs(); ie_operand_t input; input.name = const_cast(mOperandIdMap[inputs[0].Get()].c_str()); @@ -176,7 +214,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddUnary(const op::Unary* unary) { + MaybeError Graph::AddUnary(const op::Unary* unary) { auto inputs = unary->Inputs(); ie_operand_t input; input.name = const_cast(mOperandIdMap[inputs[0].Get()].c_str()); @@ -193,7 +231,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddReshape(const op::Reshape* reshape) { + MaybeError Graph::AddReshape(const op::Reshape* reshape) { auto inputs = reshape->Inputs(); ie_operand_t input; input.name = const_cast(mOperandIdMap[inputs[0].Get()].c_str()); @@ -206,7 +244,7 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::AddTranspose(const op::Transpose* transpose) { + MaybeError Graph::AddTranspose(const 
op::Transpose* transpose) { auto inputs = transpose->Inputs(); ie_operand_t input; input.name = const_cast(mOperandIdMap[inputs[0].Get()].c_str()); @@ -219,44 +257,71 @@ namespace webnn_native { namespace ie { return {}; } - MaybeError Model::Finish() { + MaybeError Graph::Finish() { IEStatusCode code = IE(ie_model_finish)(mIeModel); DAWN_TRY(CheckStatusCode(code, "IE finish creating model")); + + // We may leverage https://dawn-review.googlesource.com/c/dawn/+/36360 to + // implement async compilation as standle-alone component. + // Create compilation for IE backend. + code = IE(ie_create_compilation)(mIeModel, &mIeCompilation); + DAWN_TRY(CheckStatusCode(code, "IE create compilation")); + return {}; } - void Model::CompileImpl(WebnnCompileCallback callback, + void Graph::ComputeImpl(NamedInputsBase* inputs, + MLComputeCallback callback, void* userdata, - CompilationOptions const* options) { - Compilation* compilation = new Compilation(this); - compilation->Compile(callback, userdata, options); - } + NamedOutputsBase* outputs) { + // Set input data to nGraph. + for (auto& input : inputs->GetRecords()) { + ie_operand_t ieOperand; + ieOperand.name = const_cast(mInputIdMap[input.first].c_str()); + IEStatusCode code = IE(ie_compilation_set_input)( + mIeCompilation, &ieOperand, input.second->buffer, input.second->size); + COMPUTE_CALLBACK_TRY(code, "IE set input"); + } - ie_model_t* Model::GetInferenceEngineModel() { - return mIeModel; - } + // Compute the compiled model. + IEStatusCode code = IE(ie_compilation_compute)(mIeCompilation); + COMPUTE_CALLBACK_TRY(code, "IE compute model"); + // Get Data from nGraph with output. 
+ Ref results = AcquireRef(new NamedResultsBase()); - size_t Model::GetOutputsNumber() { size_t outputNumber = 0; - IEStatusCode code = IE(ie_model_get_outputs_number)(mIeModel, &outputNumber); - if (code != IEStatusCode::OK) { - dawn::ErrorLog() << "Failing to get output number for IE."; - } - return outputNumber; - } - - std::string Model::GetOutputId(size_t index) { - char* outputName; - IEStatusCode code = IE(ie_model_get_output_name)(mIeModel, index, &outputName); - if (code != IEStatusCode::OK) { - dawn::ErrorLog() << "Failing to get output name for IE."; - return std::string(); + code = IE(ie_model_get_outputs_number)(mIeModel, &outputNumber); + COMPUTE_CALLBACK_TRY(code, "Failing to get output number for IE."); + for (size_t i = 0; i < outputNumber; ++i) { + std::string outputId = GetOutputId(mIeModel, i); + void* outputBuffer; + size_t bufferLength; + IEStatusCode code = IE(ie_compilation_get_buffer)(mIeCompilation, outputId.data(), + &outputBuffer, &bufferLength); + COMPUTE_CALLBACK_TRY(code, "IE get buffer"); + ie_dimensions_t ieDimensions; + code = + IE(ie_compilation_get_dimensions)(mIeCompilation, outputId.data(), &ieDimensions); + COMPUTE_CALLBACK_TRY(code, "IE get dimensions"); + std::vector dimensions(ieDimensions.dims, + ieDimensions.dims + ieDimensions.ranks); + code = IE(ie_compilation_free_dimensions)(&ieDimensions); + Ref result = + AcquireRef(new Result::ResultBase(outputBuffer, bufferLength, dimensions)); + std::string outputName = mOutputNameMap[outputId]; + results->Set(outputName.c_str(), result.Detach()); + if (outputs != nullptr) { + const Output* output = outputs->GetRecords().at(outputName); + ie_operand_t ieOperand; + ieOperand.name = const_cast(outputId.c_str()); + IEStatusCode code = IE(ie_compilation_get_output)(mIeCompilation, &ieOperand, + output->buffer, output->size); + COMPUTE_CALLBACK_TRY(code, "IE get output"); + } } - std::string name(outputName); - // The name has been kept in outputs object, so it can be free. 
- IE(ie_model_free_name)(&outputName); - - return name; + callback(MLComputeStatus_Success, reinterpret_cast(results.Detach()), + nullptr, userdata); + return; } }} // namespace webnn_native::ie diff --git a/src/webnn_native/openvino/ModelIE.h b/src/webnn_native/openvino/GraphIE.h similarity index 83% rename from src/webnn_native/openvino/ModelIE.h rename to src/webnn_native/openvino/GraphIE.h index 6b635ae2e..391ed6b63 100644 --- a/src/webnn_native/openvino/ModelIE.h +++ b/src/webnn_native/openvino/GraphIE.h @@ -19,9 +19,9 @@ #include #include "webnn_native/Error.h" -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" -#include "webnn_native/openvino/ModelBuilderIE.h" +#include "webnn_native/openvino/ContextIE.h" #include "webnn_native/openvino/ienn/src/ie_nn_c_api.h" #include "webnn_native/ops/Binary.h" #include "webnn_native/ops/Constant.h" @@ -34,10 +34,10 @@ namespace webnn_native { namespace ie { - class Model : public ModelBase { + class Graph : public GraphBase { public: - explicit Model(ModelBuilder* model_builder); - ~Model() override; + explicit Graph(Context* context); + ~Graph() override; virtual MaybeError AddConstant(const op::Constant* constant) override; virtual MaybeError AddInput(const op::Input* input) override; @@ -50,18 +50,14 @@ namespace webnn_native { namespace ie { virtual MaybeError AddUnary(const op::Unary* unary) override; virtual MaybeError Finish() override; - ie_model_t* GetInferenceEngineModel(); - size_t GetOutputsNumber(); - std::string GetOutputId(size_t index); - - friend class Compilation; - private: - void CompileImpl(WebnnCompileCallback callback, + void ComputeImpl(NamedInputsBase* inputs, + MLComputeCallback callback, void* userdata, - CompilationOptions const* options) override; + NamedOutputsBase* outputs) override; ie_model_t* mIeModel; + ie_compilation_t* mIeCompilation; // Map the input name to IE internal id std::map mInputIdMap; diff --git 
a/src/webnn_native/openvino/ModelBuilderIE.cpp b/src/webnn_native/openvino/ModelBuilderIE.cpp deleted file mode 100644 index 26ad38f56..000000000 --- a/src/webnn_native/openvino/ModelBuilderIE.cpp +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "webnn_native/openvino/ModelBuilderIE.h" - -#include "common/Log.h" -#include "webnn_native/openvino/ModelIE.h" - -namespace webnn_native { namespace ie { - - ModelBuilder::ModelBuilder(NeuralNetworkContextBase* context) : ModelBuilderBase(context) { - } - - ModelBase* ModelBuilder::CreateModelImpl() { - return new Model(this); - } - -}} // namespace webnn_native::ie diff --git a/src/webnn_native/openvino/NeuralNetworkContextIE.h b/src/webnn_native/openvino/NeuralNetworkContextIE.h deleted file mode 100644 index 8c48db171..000000000 --- a/src/webnn_native/openvino/NeuralNetworkContextIE.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 The WebNN-native Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef WEBNN_NATIVE_IE_NEURAL_NETWORK_CONTEXT_IE_H_ -#define WEBNN_NATIVE_IE_NEURAL_NETWORK_CONTEXT_IE_H_ - -#include "webnn_native/NeuralNetworkContext.h" - -namespace webnn_native { namespace ie { - - class NeuralNetworkContext : public NeuralNetworkContextBase { - public: - NeuralNetworkContext() = default; - ~NeuralNetworkContext() override = default; - - ModelBuilderBase* CreateModelBuilderImpl() override; - - private: - }; - -}} // namespace webnn_native::ie - -#endif // WEBNN_NATIVE_IE_NEURAL_NETWORK_CONTEXT_IE_H_ diff --git a/src/webnn_native/ops/Binary.h b/src/webnn_native/ops/Binary.h index f74c31bda..9c851d451 100644 --- a/src/webnn_native/ops/Binary.h +++ b/src/webnn_native/ops/Binary.h @@ -15,7 +15,7 @@ #ifndef WEBNN_NATIVE_OPS_BINARY_H_ #define WEBNN_NATIVE_OPS_BINARY_H_ -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { @@ -32,12 +32,12 @@ namespace webnn_native { namespace op { class Binary final : public OperandBase { public: - Binary(ModelBuilderBase* builder, BinaryOpType opType, OperandBase* a, OperandBase* b) + Binary(GraphBuilderBase* builder, BinaryOpType opType, OperandBase* a, OperandBase* b) : OperandBase(builder, {a, b}), mOpType(opType) { } ~Binary() override = default; - MaybeError AddToModel(ModelBase* model) const override { + MaybeError AddToGraph(GraphBase* model) const override { return model->AddBinary(this); } BinaryOpType GetType() const { diff --git a/src/webnn_native/ops/Constant.h b/src/webnn_native/ops/Constant.h index 760d08f43..e6337c3d9 100644 --- a/src/webnn_native/ops/Constant.h +++ b/src/webnn_native/ops/Constant.h @@ -15,14 +15,14 @@ #ifndef WEBNN_NATIVE_OPS_CONSTANT_H_ #define WEBNN_NATIVE_OPS_CONSTANT_H_ -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace 
webnn_native { namespace op { class Constant final : public OperandBase { public: - Constant(ModelBuilderBase* builder, + Constant(GraphBuilderBase* builder, const OperandDescriptor* desc, void const* value, size_t size) @@ -36,7 +36,7 @@ namespace webnn_native { namespace op { } ~Constant() override = default; - MaybeError AddToModel(ModelBase* model) const override { + MaybeError AddToGraph(GraphBase* model) const override { return model->AddConstant(this); } diff --git a/src/webnn_native/ops/Conv2d.cpp b/src/webnn_native/ops/Conv2d.cpp index 4feba3d3f..2dc6053e0 100644 --- a/src/webnn_native/ops/Conv2d.cpp +++ b/src/webnn_native/ops/Conv2d.cpp @@ -19,7 +19,7 @@ namespace webnn_native { namespace op { - Conv2d::Conv2d(ModelBuilderBase* builder, + Conv2d::Conv2d(GraphBuilderBase* builder, OperandBase* input, OperandBase* filter, Conv2dOptions const* options) @@ -55,13 +55,13 @@ namespace webnn_native { namespace op { } if (options == nullptr) { - mOptions.layout = webnn::OperandLayout::Nchw; + mOptions.layout = ml::InputOperandLayout::Nchw; } else { mOptions.layout = options->layout; } } - MaybeError Conv2d::AddToModel(ModelBase* model) const { + MaybeError Conv2d::AddToGraph(GraphBase* model) const { return model->AddConv2d(this); } diff --git a/src/webnn_native/ops/Conv2d.h b/src/webnn_native/ops/Conv2d.h index 52266137b..a18dcd3ae 100644 --- a/src/webnn_native/ops/Conv2d.h +++ b/src/webnn_native/ops/Conv2d.h @@ -15,20 +15,20 @@ #ifndef WEBNN_NATIVE_OPS_CONV2D_H_ #define WEBNN_NATIVE_OPS_CONV2D_H_ -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { class Conv2d final : public OperandBase { public: - Conv2d(ModelBuilderBase* builder, + Conv2d(GraphBuilderBase* builder, OperandBase* input, OperandBase* filter, Conv2dOptions const* options); ~Conv2d() override = default; - MaybeError AddToModel(ModelBase* model) const override; + MaybeError AddToGraph(GraphBase* model) const 
override; MaybeError ValidateAndInferTypes() override; Conv2dOptions const* GetOptions() const; diff --git a/src/webnn_native/ops/Input.h b/src/webnn_native/ops/Input.h index 53a47ce7b..eecbd9652 100644 --- a/src/webnn_native/ops/Input.h +++ b/src/webnn_native/ops/Input.h @@ -18,14 +18,14 @@ #include #include -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { class Input final : public OperandBase { public: - Input(ModelBuilderBase* builder, const std::string& name, const OperandDescriptor* desc) + Input(GraphBuilderBase* builder, const std::string& name, const OperandDescriptor* desc) : OperandBase(builder), mName(name) { mDescriptor.type = desc->type; mType = desc->type; @@ -36,7 +36,7 @@ namespace webnn_native { namespace op { } ~Input() override = default; - MaybeError AddToModel(ModelBase* model) const override { + MaybeError AddToGraph(GraphBase* model) const override { return model->AddInput(this); } MaybeError ValidateAndInferTypes() override; diff --git a/src/webnn_native/ops/Pool2d.cpp b/src/webnn_native/ops/Pool2d.cpp index c0d30bf62..a1e8227ff 100644 --- a/src/webnn_native/ops/Pool2d.cpp +++ b/src/webnn_native/ops/Pool2d.cpp @@ -19,7 +19,7 @@ namespace webnn_native { namespace op { - Pool2d::Pool2d(ModelBuilderBase* builder, + Pool2d::Pool2d(GraphBuilderBase* builder, Pool2dType opType, OperandBase* input, Pool2dOptions const* options) @@ -61,13 +61,13 @@ namespace webnn_native { namespace op { mOptions.dilationsCount = mDilations.size(); if (options == nullptr) { - mOptions.layout = webnn::OperandLayout::Nchw; + mOptions.layout = ml::InputOperandLayout::Nchw; } else { mOptions.layout = options->layout; } } - MaybeError Pool2d::AddToModel(ModelBase* model) const { + MaybeError Pool2d::AddToGraph(GraphBase* model) const { return model->AddPool2d(this); } diff --git a/src/webnn_native/ops/Pool2d.h b/src/webnn_native/ops/Pool2d.h index 6ec3a5f8f..e17db4312 100644 --- 
a/src/webnn_native/ops/Pool2d.h +++ b/src/webnn_native/ops/Pool2d.h @@ -15,7 +15,7 @@ #ifndef WEBNN_NATIVE_OPS_POOL2d_H_ #define WEBNN_NATIVE_OPS_POOL2d_H_ -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { @@ -28,13 +28,13 @@ namespace webnn_native { namespace op { class Pool2d final : public OperandBase { public: - Pool2d(ModelBuilderBase* builder, + Pool2d(GraphBuilderBase* builder, Pool2dType type, OperandBase* input, Pool2dOptions const* options); ~Pool2d() override = default; - MaybeError AddToModel(ModelBase* model) const override; + MaybeError AddToGraph(GraphBase* model) const override; MaybeError ValidateAndInferTypes() override; Pool2dOptions const* GetOptions() const; diff --git a/src/webnn_native/ops/Reshape.h b/src/webnn_native/ops/Reshape.h index 1b4f3bb42..afccbf7c0 100644 --- a/src/webnn_native/ops/Reshape.h +++ b/src/webnn_native/ops/Reshape.h @@ -15,14 +15,14 @@ #ifndef WEBNN_NATIVE_OPS_RESHAPE_H_ #define WEBNN_NATIVE_OPS_RESHAPE_H_ -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { class Reshape final : public OperandBase { public: - Reshape(ModelBuilderBase* builder, + Reshape(GraphBuilderBase* builder, OperandBase* input, int32_t const* newShape, size_t newShapeCount) @@ -31,7 +31,7 @@ namespace webnn_native { namespace op { } ~Reshape() override = default; - MaybeError AddToModel(ModelBase* model) const override { + MaybeError AddToGraph(GraphBase* model) const override { return model->AddReshape(this); } MaybeError ValidateAndInferTypes() override; diff --git a/src/webnn_native/ops/Transpose.h b/src/webnn_native/ops/Transpose.h index 9c4db460a..94eda4179 100644 --- a/src/webnn_native/ops/Transpose.h +++ b/src/webnn_native/ops/Transpose.h @@ -15,14 +15,14 @@ #ifndef WEBNN_NATIVE_OPS_TRANSPOSE_H_ #define WEBNN_NATIVE_OPS_TRANSPOSE_H_ -#include "webnn_native/Model.h" 
+#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { class Transpose final : public OperandBase { public: - Transpose(ModelBuilderBase* builder, OperandBase* input, TransposeOptions const* options) + Transpose(GraphBuilderBase* builder, OperandBase* input, TransposeOptions const* options) : OperandBase(builder, {input}) { if (options == nullptr || options->permutation == nullptr) { int32_t rank = input->Rank(); @@ -39,7 +39,7 @@ namespace webnn_native { namespace op { } ~Transpose() override = default; - MaybeError AddToModel(ModelBase* model) const override { + MaybeError AddToGraph(GraphBase* model) const override { return model->AddTranspose(this); } MaybeError ValidateAndInferTypes() override; diff --git a/src/webnn_native/ops/Unary.h b/src/webnn_native/ops/Unary.h index a0c8b7bd0..5d606ab4e 100644 --- a/src/webnn_native/ops/Unary.h +++ b/src/webnn_native/ops/Unary.h @@ -15,7 +15,7 @@ #ifndef WEBNN_NATIVE_OPS_UNARY_H_ #define WEBNN_NATIVE_OPS_UNARY_H_ -#include "webnn_native/Model.h" +#include "webnn_native/Graph.h" #include "webnn_native/Operand.h" namespace webnn_native { namespace op { @@ -27,12 +27,12 @@ namespace webnn_native { namespace op { class Unary final : public OperandBase { public: - Unary(ModelBuilderBase* builder, UnaryOpType opType, OperandBase* input) + Unary(GraphBuilderBase* builder, UnaryOpType opType, OperandBase* input) : OperandBase(builder, {input}), mOpType(opType) { } ~Unary() override = default; - MaybeError AddToModel(ModelBase* model) const override { + MaybeError AddToGraph(GraphBase* model) const override { return model->AddUnary(this); } MaybeError ValidateAndInferTypes() override; diff --git a/webnn.json b/webnn.json index 2726a0188..32d605010 100644 --- a/webnn.json +++ b/webnn.json @@ -68,13 +68,23 @@ {"name": "userdata", "type": "void", "annotation": "*"} ] }, - "neural network context": { + "power preference": { + "category": "enum", + "values": [ + {"value": 0, "name": 
"default"}, + {"value": 1, "name": "high_performance"}, + {"value": 2, "name": "low_power"} + ] + }, + "context options": { + "category": "structure", + "members": [ + {"name": "power preference", "type": "power preference", "default": "default"} + ] + }, + "context": { "category": "object", "methods": [ - { - "name": "create model builder", - "returns": "model builder" - }, { "name": "set uncaptured error callback", "args": [ @@ -98,13 +108,22 @@ } ] }, + "input operand layout": { + "category": "enum", + "values": [ + {"value": 0, "name": "nchw"}, + {"value": 1, "name": "nhwc"} + ] + }, "operand type": { "category": "enum", "values": [ {"value": 0, "name": "float32"}, {"value": 1, "name": "float16"}, {"value": 2, "name": "int32"}, - {"value": 3, "name": "uint32"} + {"value": 3, "name": "uint32"}, + {"value": 4, "name": "int8"}, + {"value": 5, "name": "uint8"} ] }, "operand descriptor": { @@ -118,20 +137,6 @@ "operand": { "category": "object" }, - "power preference": { - "category": "enum", - "values": [ - {"value": 0, "name": "default"}, - {"value": 1, "name": "low_power"}, - {"value": 2, "name": "high_performance"} - ] - }, - "compilation options": { - "category": "structure", - "members": [ - {"name": "power preference", "type": "power preference"} - ] - }, "named operands": { "category": "object", "methods": [ @@ -144,13 +149,6 @@ } ] }, - "operand layout": { - "category": "enum", - "values": [ - {"value": 0, "name": "nchw"}, - {"value": 1, "name": "nhwc"} - ] - }, "conv2d options": { "category": "structure", "members": [ @@ -161,7 +159,7 @@ {"name": "dilations count", "type": "uint32_t", "default": 0}, {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, {"name": "groups", "type": "int32_t", "default": 1}, - {"name": "layout", "type": "operand layout", "default": "nchw"} + {"name": "layout", "type": "input operand layout", "default": "nchw"} ] }, "pool2d options": { @@ -175,7 +173,7 @@ {"name": 
"strides", "type": "int32_t", "annotation": "const*", "length": "strides count", "optional": true}, {"name": "dilations count", "type": "uint32_t", "default": 0}, {"name": "dilations", "type": "int32_t", "annotation": "const*", "length": "dilations count", "optional": true}, - {"name": "layout", "type": "operand layout", "default": "nchw"} + {"name": "layout", "type": "input operand layout", "default": "nchw"} ] }, "transpose options": { @@ -185,7 +183,7 @@ {"name": "permutation", "type": "int32_t", "annotation": "const*", "length": "permutation count", "optional": true} ] }, - "model builder": { + "graph builder": { "category": "object", "methods": [ { @@ -286,24 +284,25 @@ ] }, { - "name": "createModel", - "returns": "model", + "name": "build", "args": [ - {"name": "named operands", "type": "named operands"} + {"name": "named operands", "type": "named operands"}, + {"name": "callback", "type": "build callback"}, + {"name": "userdata", "type": "void", "annotation": "*"} ] } ] }, - "compile callback": { + "build callback": { "category": "callback", "args": [ - {"name": "status", "type": "compile status"}, - {"name": "compilation", "type": "compilation"}, + {"name": "status", "type": "build status"}, + {"name": "graph", "type": "graph"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "userdata", "type": "void", "annotation": "*"} ] }, - "compile status": { + "build status": { "category": "enum", "values": [ {"value": 0, "name": "success"}, @@ -312,19 +311,6 @@ {"value": 3, "name": "unknown"} ] }, - "model": { - "category": "object", - "methods": [ - { - "name": "compile", - "args": [ - {"name": "callback", "type": "compile callback"}, - {"name": "userdata", "type": "void", "annotation": "*"}, - {"name": "options", "type": "compilation options", "annotation": "const*", "optional": true} - ] - } - ] - }, "input": { "category": "structure", "members": [ @@ -418,7 +404,7 @@ {"value": 3, "name": "unknown"} ] }, - "compilation": 
{ + "graph": { "category": "object", "methods": [ { From 7de99fdc5169e4bb63d672aad29f887040afd418 Mon Sep 17 00:00:00 2001 From: fujunwei Date: Tue, 20 Apr 2021 20:40:14 +0800 Subject: [PATCH 2/2] Keep the private method to compile the graph --- examples/SampleUtils.cpp | 8 +-- examples/SampleUtils.h | 19 ++++--- src/common/BUILD.gn | 2 +- .../validation/GraphValidationTests.cpp | 9 ++-- src/webnn_native/Graph.cpp | 6 ++- src/webnn_native/Graph.h | 7 ++- src/webnn_native/GraphBuilder.cpp | 20 +++++--- src/webnn_native/GraphBuilder.h | 6 ++- src/webnn_native/dml/GraphDML.cpp | 20 +++++--- src/webnn_native/dml/GraphDML.h | 5 +- src/webnn_native/null/ContextNull.cpp | 5 +- src/webnn_native/null/ContextNull.h | 3 +- src/webnn_native/openvino/GraphIE.cpp | 49 ++++++++++--------- src/webnn_native/openvino/GraphIE.h | 3 +- webnn.json | 16 +++--- 15 files changed, 103 insertions(+), 75 deletions(-) diff --git a/examples/SampleUtils.cpp b/examples/SampleUtils.cpp index 2fe8b1962..a7298eb47 100644 --- a/examples/SampleUtils.cpp +++ b/examples/SampleUtils.cpp @@ -93,11 +93,11 @@ namespace utils { namedOperands.Set(output.name.c_str(), output.operand); } builder.Build(namedOperands, - [](MLBuildStatus status, MLGraph impl, char const* message, + [](MLBuildGraphStatus status, MLGraph impl, char const* message, void* userData) { BuildData* buildDataPtr = reinterpret_cast(userData); DAWN_ASSERT(buildDataPtr); - if (status != MLBuildStatus_Success) { + if (status != MLBuildGraphStatus_Success) { dawn::ErrorLog() << "Compute failed: " << message; } else { buildDataPtr->graph = buildDataPtr->graph.Acquire(impl); @@ -124,11 +124,11 @@ namespace utils { } graph.Compute( namedInputs, - [](MLComputeStatus status, MLNamedResults impl, char const* message, + [](MLComputeGraphStatus status, MLNamedResults impl, char const* message, void* userData) { ComputeData* computeDataPtr = reinterpret_cast(userData); DAWN_ASSERT(computeDataPtr); - if (status != MLComputeStatus_Success) { + if (status 
!= MLComputeGraphStatus_Success) { dawn::ErrorLog() << "Compute failed: " << message; } else { computeDataPtr->results = computeDataPtr->results.Acquire(impl); diff --git a/examples/SampleUtils.h b/examples/SampleUtils.h index d74d3db87..67056e267 100644 --- a/examples/SampleUtils.h +++ b/examples/SampleUtils.h @@ -35,15 +35,15 @@ bool Expected(float output, float expected); namespace utils { ml::Operand BuildInput(const ml::GraphBuilder& builder, - std::string name, - const std::vector& dimensions, - ml::OperandType type = ml::OperandType::Float32); + std::string name, + const std::vector& dimensions, + ml::OperandType type = ml::OperandType::Float32); ml::Operand BuildConstant(const ml::GraphBuilder& builder, - const std::vector& dimensions, - const void* value, - size_t size, - ml::OperandType type = ml::OperandType::Float32); + const std::vector& dimensions, + const void* value, + size_t size, + ml::OperandType type = ml::OperandType::Float32); struct Conv2dOptions { public: @@ -113,8 +113,7 @@ namespace utils { const ml::Operand& operand; } NamedOutput; - ml::Graph AwaitBuild(const ml::GraphBuilder& builder, - const std::vector& outputs); + ml::Graph AwaitBuild(const ml::GraphBuilder& builder, const std::vector& outputs); typedef struct { const std::string& name; @@ -122,7 +121,7 @@ namespace utils { } NamedInput; ml::NamedResults AwaitCompute(const ml::Graph& compilation, - const std::vector& inputs); + const std::vector& inputs); bool CheckShape(const ml::Result& result, const std::vector& expectedShape); diff --git a/src/common/BUILD.gn b/src/common/BUILD.gn index 06e6ac5ce..94e3fdf41 100644 --- a/src/common/BUILD.gn +++ b/src/common/BUILD.gn @@ -153,7 +153,7 @@ if (is_win || is_linux || is_chromeos || is_mac || is_fuchsia || is_android) { sources = [ "//third_party/dawn/src/common/Assert.cpp", "//third_party/dawn/src/common/Assert.h", - "//third_party/dawn/src/common/Computer.h", + "//third_party/dawn/src/common/Compiler.h", 
"//third_party/dawn/src/common/Log.cpp", "//third_party/dawn/src/common/Log.h", "//third_party/dawn/src/common/Math.cpp", diff --git a/src/tests/unittests/validation/GraphValidationTests.cpp b/src/tests/unittests/validation/GraphValidationTests.cpp index 155f31e75..0b2aadfb8 100644 --- a/src/tests/unittests/validation/GraphValidationTests.cpp +++ b/src/tests/unittests/validation/GraphValidationTests.cpp @@ -23,11 +23,11 @@ class MockGraphBuildCallback { public: MOCK_METHOD(void, Call, - (MLBuildStatus status, MLGraph impl, const char* message, void* userdata)); + (MLBuildGraphStatus status, MLGraph impl, const char* message, void* userdata)); }; static std::unique_ptr mockGraphBuildCallback; -static void ToMockGraphBuildCallback(MLBuildStatus status, +static void ToMockGraphBuildCallback(MLBuildGraphStatus status, MLGraph impl, const char* message, void* userdata) { @@ -63,12 +63,13 @@ TEST_F(GraphValidationTest, BuildCallBackSuccess) { ml::NamedOperands namedOperands = ml::CreateNamedOperands(); namedOperands.Set("output", mOutput); mBuilder.Build(namedOperands, ToMockGraphBuildCallback, this); - EXPECT_CALL(*mockGraphBuildCallback, Call(MLBuildStatus_Success, _, nullptr, this)).Times(1); + EXPECT_CALL(*mockGraphBuildCallback, Call(MLBuildGraphStatus_Success, _, nullptr, this)) + .Times(1); } // Create model with null nameOperands TEST_F(GraphValidationTest, BuildCallBackError) { ml::NamedOperands namedOperands = ml::CreateNamedOperands(); mBuilder.Build(namedOperands, ToMockGraphBuildCallback, this); - EXPECT_CALL(*mockGraphBuildCallback, Call(MLBuildStatus_Error, _, _, this)).Times(1); + EXPECT_CALL(*mockGraphBuildCallback, Call(MLBuildGraphStatus_Error, _, _, this)).Times(1); } diff --git a/src/webnn_native/Graph.cpp b/src/webnn_native/Graph.cpp index 1d9135c5f..e37e26254 100644 --- a/src/webnn_native/Graph.cpp +++ b/src/webnn_native/Graph.cpp @@ -25,7 +25,7 @@ namespace webnn_native { } void GraphBase::Compute(NamedInputsBase* inputs, - MLComputeCallback 
callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs) { ComputeImpl(inputs, callback, userdata, outputs); @@ -71,4 +71,8 @@ namespace webnn_native { UNREACHABLE(); } + void GraphBase::Compile(BuildGraphCallbackDelgate delgate) { + CompileImpl(delgate); + } + } // namespace webnn_native diff --git a/src/webnn_native/Graph.h b/src/webnn_native/Graph.h index 1f4f3fb99..cdacf9685 100644 --- a/src/webnn_native/Graph.h +++ b/src/webnn_native/Graph.h @@ -19,6 +19,7 @@ #include "webnn_native/Context.h" #include "webnn_native/Error.h" #include "webnn_native/Forward.h" +#include "webnn_native/GraphBuilder.h" #include "webnn_native/ObjectBase.h" #include "webnn_native/Operand.h" #include "webnn_native/webnn_platform.h" @@ -43,7 +44,7 @@ namespace webnn_native { // Webnn API void Compute(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs = nullptr); @@ -57,10 +58,12 @@ namespace webnn_native { virtual MaybeError AddTranspose(const op::Transpose* transpose); virtual MaybeError AddUnary(const op::Unary* unary); virtual MaybeError Finish(); + virtual void Compile(BuildGraphCallbackDelgate delgate); private: + virtual void CompileImpl(BuildGraphCallbackDelgate delgate) = 0; virtual void ComputeImpl(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs) = 0; }; diff --git a/src/webnn_native/GraphBuilder.cpp b/src/webnn_native/GraphBuilder.cpp index af5cd8ce5..7935924c4 100644 --- a/src/webnn_native/GraphBuilder.cpp +++ b/src/webnn_native/GraphBuilder.cpp @@ -43,10 +43,10 @@ for (;;) \ break -#define BUILD_ERROR_AND_CALLBACK(message) \ - do { \ - callback(MLBuildStatus_Error, nullptr, message, userdata); \ - return; \ +#define BUILD_ERROR_AND_CALLBACK(message) \ + do { \ + callback(MLBuildGraphStatus_Error, nullptr, message, userdata); \ + return; \ } while (0) namespace webnn_native { @@ 
-111,7 +111,7 @@ namespace webnn_native { } void GraphBuilderBase::Build(NamedOperandsBase const* namedOperands, - MLBuildCallback callback, + MLBuildGraphCallback callback, void* userdata) { if (DAWN_UNLIKELY(this->IsError())) { BUILD_ERROR_AND_CALLBACK("This Graph object is an error"); @@ -140,8 +140,14 @@ namespace webnn_native { if (GetContext()->ConsumedError(graph->Finish())) { BUILD_ERROR_AND_CALLBACK("Failed to finish building graph."); } - callback(MLBuildStatus_Success, reinterpret_cast(graph.Detach()), nullptr, - userdata); + graph.Detach()->Compile([callback, userdata](MLBuildGraphStatus status, GraphBase* graph) { + if (status == MLBuildGraphStatus_Success) { + callback(status, reinterpret_cast(graph), nullptr, userdata); + } else { + delete graph; + callback(status, nullptr, "Failed to compile graph", userdata); + } + }); } // The implementation derives from nGraph topological_sort in diff --git a/src/webnn_native/GraphBuilder.h b/src/webnn_native/GraphBuilder.h index 70610faed..218b714ee 100644 --- a/src/webnn_native/GraphBuilder.h +++ b/src/webnn_native/GraphBuilder.h @@ -21,10 +21,14 @@ #include "webnn_native/ObjectBase.h" #include "webnn_native/webnn_platform.h" +#include #include namespace webnn_native { + using BuildGraphCallbackDelgate = + std::function; + class GraphBuilderBase : public ObjectBase { public: GraphBuilderBase(ContextBase* context); @@ -44,7 +48,7 @@ namespace webnn_native { OperandBase* Softmax(OperandBase*); OperandBase* Transpose(OperandBase*, TransposeOptions const* options); void Build(NamedOperandsBase const* named_operands, - MLBuildCallback callback, + MLBuildGraphCallback callback, void* userdata); private: diff --git a/src/webnn_native/dml/GraphDML.cpp b/src/webnn_native/dml/GraphDML.cpp index b3d21d77f..3742c14d3 100644 --- a/src/webnn_native/dml/GraphDML.cpp +++ b/src/webnn_native/dml/GraphDML.cpp @@ -623,6 +623,10 @@ namespace webnn_native { namespace dml { } } + return {}; + } + + void 
Graph::CompileImpl(BuildGraphCallbackDelgate delgate) { // FIXME(nhu): implement async std::vector<::dml::Expression> outputs; for (auto& output : mOutputs) { @@ -636,15 +640,15 @@ namespace webnn_native { namespace dml { for (auto& binding : mBindings) { inputBindings.push_back(binding.get()); } - if (FAILED(mDevice->InitializeOperator(mCompiledModel->op.Get(), inputBindings))) { - return DAWN_INTERNAL_ERROR("Failed to initialize operator"); - } - - return {}; + MLBuildGraphStatus status = + FAILED(mDevice->InitializeOperator(mCompiledModel->op.Get(), inputBindings)) + ? MLBuildGraphStatus_Error + : MLBuildGraphStatus_Success; + delgate(status, this); } void Graph::ComputeImpl(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs) { for (auto& input : inputs->GetRecords()) { @@ -672,7 +676,7 @@ namespace webnn_native { namespace dml { std::vector outputTensors; if (FAILED(mDevice->DispatchOperator(mCompiledModel->op.Get(), inputBindings, outputExpressions, outputTensors))) { - callback(MLComputeStatus_Error, nullptr, "Failed to dispatch operator", userdata); + callback(MLComputeGraphStatus_Error, nullptr, "Failed to dispatch operator", userdata); return; } @@ -697,7 +701,7 @@ namespace webnn_native { namespace dml { } delete tensor; } - callback(MLComputeStatus_Success, reinterpret_cast(results.Detach()), + callback(MLComputeGraphStatus_Success, reinterpret_cast(results.Detach()), nullptr, userdata); return; } diff --git a/src/webnn_native/dml/GraphDML.h b/src/webnn_native/dml/GraphDML.h index f43663a2a..8ebc5ef57 100644 --- a/src/webnn_native/dml/GraphDML.h +++ b/src/webnn_native/dml/GraphDML.h @@ -52,11 +52,10 @@ namespace webnn_native { namespace dml { virtual MaybeError AddUnary(const op::Unary* unary) override; virtual MaybeError Finish() override; - friend class Compilation; - private: + void CompileImpl(BuildGraphCallbackDelgate delgate) override; void ComputeImpl(NamedInputsBase* 
inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs) override; diff --git a/src/webnn_native/null/ContextNull.cpp b/src/webnn_native/null/ContextNull.cpp index 0cf3a052a..163e047eb 100644 --- a/src/webnn_native/null/ContextNull.cpp +++ b/src/webnn_native/null/ContextNull.cpp @@ -37,8 +37,11 @@ namespace webnn_native { namespace null { Graph::Graph(Context* context) : GraphBase(context) { } + void Graph::CompileImpl(BuildGraphCallbackDelgate delgate) { + } + void Graph::ComputeImpl(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs) { } diff --git a/src/webnn_native/null/ContextNull.h b/src/webnn_native/null/ContextNull.h index 7545f0df5..bc9a67ed5 100644 --- a/src/webnn_native/null/ContextNull.h +++ b/src/webnn_native/null/ContextNull.h @@ -55,8 +55,9 @@ namespace webnn_native { namespace null { virtual MaybeError Finish() override; private: + void CompileImpl(BuildGraphCallbackDelgate delgate) override; void ComputeImpl(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs = nullptr) override; }; diff --git a/src/webnn_native/openvino/GraphIE.cpp b/src/webnn_native/openvino/GraphIE.cpp index 7d8da8774..ae090e913 100644 --- a/src/webnn_native/openvino/GraphIE.cpp +++ b/src/webnn_native/openvino/GraphIE.cpp @@ -27,16 +27,16 @@ #include "webnn_native/openvino/ErrorIE.h" #include "webnn_native/openvino/ienn_symbol_table/ienn_symbol_table.h" -#define COMPUTE_CALLBACK_TRY(code, messages) \ - { \ - MaybeError maybeError = CheckStatusCode(code, messages); \ - if (maybeError.IsError()) { \ - std::unique_ptr error = maybeError.AcquireError(); \ - callback(MLComputeStatus_Error, nullptr, error->GetMessage().c_str(), userdata); \ - return; \ - } \ - } \ - for (;;) \ +#define COMPUTE_ERROR_CALLBACK(code, messages) \ + { \ + MaybeError maybeError = 
CheckStatusCode(code, messages); \ + if (maybeError.IsError()) { \ + std::unique_ptr error = maybeError.AcquireError(); \ + callback(MLComputeGraphStatus_Error, nullptr, error->GetMessage().c_str(), userdata); \ + return; \ + } \ + } \ + for (;;) \ break namespace webnn_native { namespace ie { @@ -261,17 +261,20 @@ namespace webnn_native { namespace ie { IEStatusCode code = IE(ie_model_finish)(mIeModel); DAWN_TRY(CheckStatusCode(code, "IE finish creating model")); - // We may leverage https://dawn-review.googlesource.com/c/dawn/+/36360 to + return {}; + } + + void Graph::CompileImpl(BuildGraphCallbackDelgate delgate) { + // TODO(junwei): We may leverage https://dawn-review.googlesource.com/c/dawn/+/36360 to // implement async compilation as standle-alone component. // Create compilation for IE backend. - code = IE(ie_create_compilation)(mIeModel, &mIeCompilation); - DAWN_TRY(CheckStatusCode(code, "IE create compilation")); - - return {}; + IEStatusCode code = IE(ie_create_compilation)(mIeModel, &mIeCompilation); + delgate(code == IEStatusCode::OK ? MLBuildGraphStatus_Success : MLBuildGraphStatus_Error, + this); } void Graph::ComputeImpl(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* outputs) { // Set input data to nGraph. @@ -280,29 +283,29 @@ namespace webnn_native { namespace ie { ieOperand.name = const_cast(mInputIdMap[input.first].c_str()); IEStatusCode code = IE(ie_compilation_set_input)( mIeCompilation, &ieOperand, input.second->buffer, input.second->size); - COMPUTE_CALLBACK_TRY(code, "IE set input"); + COMPUTE_ERROR_CALLBACK(code, "IE set input"); } // Compute the compiled model. IEStatusCode code = IE(ie_compilation_compute)(mIeCompilation); - COMPUTE_CALLBACK_TRY(code, "IE compute model"); + COMPUTE_ERROR_CALLBACK(code, "IE compute model"); // Get Data from nGraph with output. 
Ref results = AcquireRef(new NamedResultsBase()); size_t outputNumber = 0; code = IE(ie_model_get_outputs_number)(mIeModel, &outputNumber); - COMPUTE_CALLBACK_TRY(code, "Failing to get output number for IE."); + COMPUTE_ERROR_CALLBACK(code, "Failing to get output number for IE."); for (size_t i = 0; i < outputNumber; ++i) { std::string outputId = GetOutputId(mIeModel, i); void* outputBuffer; size_t bufferLength; IEStatusCode code = IE(ie_compilation_get_buffer)(mIeCompilation, outputId.data(), &outputBuffer, &bufferLength); - COMPUTE_CALLBACK_TRY(code, "IE get buffer"); + COMPUTE_ERROR_CALLBACK(code, "IE get buffer"); ie_dimensions_t ieDimensions; code = IE(ie_compilation_get_dimensions)(mIeCompilation, outputId.data(), &ieDimensions); - COMPUTE_CALLBACK_TRY(code, "IE get dimensions"); + COMPUTE_ERROR_CALLBACK(code, "IE get dimensions"); std::vector dimensions(ieDimensions.dims, ieDimensions.dims + ieDimensions.ranks); code = IE(ie_compilation_free_dimensions)(&ieDimensions); @@ -316,10 +319,10 @@ namespace webnn_native { namespace ie { ieOperand.name = const_cast(outputId.c_str()); IEStatusCode code = IE(ie_compilation_get_output)(mIeCompilation, &ieOperand, output->buffer, output->size); - COMPUTE_CALLBACK_TRY(code, "IE get output"); + COMPUTE_ERROR_CALLBACK(code, "IE get output"); } } - callback(MLComputeStatus_Success, reinterpret_cast(results.Detach()), + callback(MLComputeGraphStatus_Success, reinterpret_cast(results.Detach()), nullptr, userdata); return; } diff --git a/src/webnn_native/openvino/GraphIE.h b/src/webnn_native/openvino/GraphIE.h index 391ed6b63..73f11c909 100644 --- a/src/webnn_native/openvino/GraphIE.h +++ b/src/webnn_native/openvino/GraphIE.h @@ -51,8 +51,9 @@ namespace webnn_native { namespace ie { virtual MaybeError Finish() override; private: + void CompileImpl(BuildGraphCallbackDelgate delgate); void ComputeImpl(NamedInputsBase* inputs, - MLComputeCallback callback, + MLComputeGraphCallback callback, void* userdata, NamedOutputsBase* 
outputs) override; diff --git a/webnn.json b/webnn.json index 32d605010..a9a946c62 100644 --- a/webnn.json +++ b/webnn.json @@ -287,22 +287,22 @@ "name": "build", "args": [ {"name": "named operands", "type": "named operands"}, - {"name": "callback", "type": "build callback"}, + {"name": "callback", "type": "build graph callback"}, {"name": "userdata", "type": "void", "annotation": "*"} ] } ] }, - "build callback": { + "build graph callback": { "category": "callback", "args": [ - {"name": "status", "type": "build status"}, + {"name": "status", "type": "build graph status"}, {"name": "graph", "type": "graph"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "userdata", "type": "void", "annotation": "*"} ] }, - "build status": { + "build graph status": { "category": "enum", "values": [ {"value": 0, "name": "success"}, @@ -386,16 +386,16 @@ } ] }, - "compute callback": { + "compute graph callback": { "category": "callback", "args": [ - {"name": "status", "type": "compute status"}, + {"name": "status", "type": "compute graph status"}, {"name": "outputs", "type": "named results"}, {"name": "message", "type": "char", "annotation": "const*", "length": "strlen"}, {"name": "userdata", "type": "void", "annotation": "*"} ] }, - "compute status": { + "compute graph status": { "category": "enum", "values": [ {"value": 0, "name": "success"}, @@ -411,7 +411,7 @@ "name": "compute", "args": [ {"name": "inputs", "type": "named inputs"}, - {"name": "callback", "type": "compute callback"}, + {"name": "callback", "type": "compute graph callback"}, {"name": "userdata", "type": "void", "annotation": "*"}, {"name": "outputs", "type": "named outputs", "optional": true} ]