diff --git a/Cargo.lock b/Cargo.lock index a211324257ef..ed8aeefc252f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3773,7 +3773,9 @@ dependencies = [ "anyhow", "openvino", "thiserror", + "tracing", "walkdir", + "wasmtime", "wiggle", ] diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index b0977562638b..13f1501583e4 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -12,11 +12,15 @@ readme = "README.md" edition.workspace = true [dependencies] -# These dependencies are necessary for the witx-generation macros to work: +# These dependencies are necessary for the WITX-generation macros to work: anyhow = { workspace = true } wiggle = { workspace = true } +# This dependency is necessary for the WIT-generation macros to work: +wasmtime = { workspace = true, features = ["component-model"] } + # These dependencies are necessary for the wasi-nn implementation: +tracing = { workspace = true } openvino = { version = "0.5.0", features = ["runtime-linking"] } thiserror = { workspace = true } diff --git a/crates/wasi-nn/README.md b/crates/wasi-nn/README.md index 165b5063d002..4933fcba339c 100644 --- a/crates/wasi-nn/README.md +++ b/crates/wasi-nn/README.md @@ -1,38 +1,45 @@ # wasmtime-wasi-nn -This crate enables support for the [wasi-nn] API in Wasmtime. Currently it contains an implementation of [wasi-nn] using -OpenVINO™ but in the future it could support multiple machine learning backends. Since the [wasi-nn] API is expected -to be an optional feature of WASI, this crate is currently separate from the [wasi-common] crate. This crate is -experimental and its API, functionality, and location could quickly change. +This crate enables support for the [wasi-nn] API in Wasmtime. Currently it +contains an implementation of [wasi-nn] using OpenVINO™ but in the future it +could support multiple machine learning backends. Since the [wasi-nn] API is +expected to be an optional feature of WASI, this crate is currently separate +from the [wasi-common] crate. This crate is experimental and its API, +functionality, and location could quickly change. [examples]: examples [openvino]: https://crates.io/crates/openvino [wasi-nn]: https://github.com/WebAssembly/wasi-nn [wasi-common]: ../wasi-common +[bindings]: https://crates.io/crates/wasi-nn ### Use -Use the Wasmtime APIs to instantiate a Wasm module and link in the `WasiNn` implementation as follows: +Use the Wasmtime APIs to instantiate a Wasm module and link in the `wasi-nn` +implementation as follows: -``` -let wasi_nn = WasiNn::new(&store, WasiNnCtx::new()?); -wasi_nn.add_to_linker(&mut linker)?; +```rust +let wasi_nn = WasiNnCtx::new()?; +wasmtime_wasi_nn::witx::add_to_linker(...); ``` ### Build -This crate should build as usual (i.e. `cargo build`) but note that using an existing installation of OpenVINO™, rather -than building from source, will drastically improve the build times. See the [openvino] crate for more information +```sh +$ cargo build +``` + +To use the WIT-based ABI, compile with `--features component-model` and use `wasmtime_wasi_nn::wit::add_to_linker`. ### Example An end-to-end example demonstrating ML classification is included in [examples]: - - `tests/wasi-nn-rust-bindings` contains ergonomic bindings for writing Rust code against the [wasi-nn] APIs - - `tests/classification-example` contains a standalone Rust project that uses the [wasi-nn] APIs and is compiled to the - `wasm32-wasi` target using the `wasi-nn-rust-bindings` +`examples/classification-example` contains a standalone Rust project that uses +the [wasi-nn] APIs and is compiled to the `wasm32-wasi` target using the +high-level `wasi-nn` [bindings]. Run the example from the Wasmtime project directory: -``` -ci/run-wasi-nn-example.sh +```sh +$ ci/run-wasi-nn-example.sh ``` diff --git a/crates/wasi-nn/examples/classification-example/Cargo.lock b/crates/wasi-nn/examples/classification-example/Cargo.lock index 0a2414873852..a649a0429289 100644 --- a/crates/wasi-nn/examples/classification-example/Cargo.lock +++ b/crates/wasi-nn/examples/classification-example/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +version = 3 + [[package]] name = "wasi-nn" version = "0.1.0" @@ -8,7 +10,7 @@ checksum = "0c909acded993dc129e02f64a7646eb7b53079f522a814024a88772f41558996" [[package]] name = "wasi-nn-example" -version = "0.19.0" +version = "0.0.0" dependencies = [ "wasi-nn", ] diff --git a/crates/wasi-nn/src/api.rs b/crates/wasi-nn/src/backend/mod.rs similarity index 72% rename from crates/wasi-nn/src/api.rs rename to crates/wasi-nn/src/backend/mod.rs index 2ad6e0edf94e..19b6610f1581 100644 --- a/crates/wasi-nn/src/api.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -1,17 +1,25 @@ //! Define the Rust interface a backend must implement in order to be used by -//! this crate. the `Box` types returned by these interfaces allow +//! this crate. The `Box` types returned by these interfaces allow //! implementations to maintain backend-specific state between calls. -use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor}; +mod openvino; + +use self::openvino::OpenvinoBackend; +use crate::wit::types::{ExecutionTarget, Tensor}; use thiserror::Error; use wiggle::GuestError; +/// Return a list of all available backend frameworks. +pub(crate) fn list() -> Vec<(BackendKind, Box)> { + vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))] +} + /// A [Backend] contains the necessary state to load [BackendGraph]s. pub(crate) trait Backend: Send + Sync { fn name(&self) -> &str; fn load( &mut self, - builders: &GraphBuilderArray<'_>, + builders: &[&[u8]], target: ExecutionTarget, ) -> Result, BackendError>; } @@ -25,7 +33,7 @@ pub(crate) trait BackendGraph: Send + Sync { /// A [BackendExecutionContext] performs the actual inference; this is the /// backing implementation for a [crate::witx::types::GraphExecutionContext]. pub(crate) trait BackendExecutionContext: Send + Sync { - fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>; + fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>; fn compute(&mut self) -> Result<(), BackendError>; fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result; } @@ -39,7 +47,12 @@ pub enum BackendError { #[error("Failed while accessing guest module")] GuestAccess(#[from] GuestError), #[error("The backend expects {0} buffers, passed {1}")] - InvalidNumberOfBuilders(u32, u32), + InvalidNumberOfBuilders(usize, usize), #[error("Not enough memory to copy tensor data of size: {0}")] NotEnoughMemory(usize), } + +#[derive(Hash, PartialEq, Eq, Clone, Copy)] +pub(crate) enum BackendKind { + OpenVINO, +} diff --git a/crates/wasi-nn/src/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs similarity index 71% rename from crates/wasi-nn/src/openvino.rs rename to crates/wasi-nn/src/backend/openvino.rs index 9924326369f3..d44236250760 100644 --- a/crates/wasi-nn/src/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -1,13 +1,12 @@ -//! Implements the wasi-nn API. +//! Implements a `wasi-nn` [`Backend`] using OpenVINO. -use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; -use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor, TensorType}; +use super::{Backend, BackendError, BackendExecutionContext, BackendGraph}; +use crate::wit::types::{ExecutionTarget, Tensor, TensorType}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; use std::sync::Arc; #[derive(Default)] pub(crate) struct OpenvinoBackend(Option); - unsafe impl Send for OpenvinoBackend {} unsafe impl Sync for OpenvinoBackend {} @@ -18,7 +17,7 @@ impl Backend for OpenvinoBackend { fn load( &mut self, - builders: &GraphBuilderArray<'_>, + builders: &[&[u8]], target: ExecutionTarget, ) -> Result, BackendError> { if builders.len() != 2 { @@ -34,16 +33,8 @@ impl Backend for OpenvinoBackend { } // Read the guest array. - let builders = builders.as_ptr(); - let xml = builders - .read()? - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - let weights = builders - .add(1)? - .read()? - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + let xml = &builders[0]; + let weights = &builders[1]; // Construct OpenVINO graph structures: `cnn_network` contains the graph // structure, `exec_network` can perform inference. @@ -53,8 +44,9 @@ impl Backend for OpenvinoBackend { .expect("openvino::Core was previously constructed"); let mut cnn_network = core.read_network_from_buffer(&xml, &weights)?; - // TODO this is a temporary workaround. We need a more eligant way to specify the layout in the long run. - // However, without this newer versions of OpenVINO will fail due to parameter mismatch. + // TODO: this is a temporary workaround. We need a more elegant way to + // specify the layout in the long run. However, without this newer + // versions of OpenVINO will fail due to parameter mismatch. for i in 0..cnn_network.get_inputs_len()? { let name = cnn_network.get_input_name(i)?; cnn_network.set_input_layout(&name, Layout::NHWC)?; @@ -85,27 +77,19 @@ impl BackendGraph for OpenvinoGraph { struct OpenvinoExecutionContext(Arc, openvino::InferRequest); impl BackendExecutionContext for OpenvinoExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError> { + fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { let input_name = self.0.get_input_name(index as usize)?; - // Construct the blob structure. + // Construct the blob structure. TODO: there must be some good way to + // discover the layout here; `desc` should not have to default to NHWC. + let precision = map_tensor_type_to_precision(tensor.tensor_type); let dimensions = tensor .dimensions - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)") .iter() - .map(|d| *d as usize) + .map(|&d| d as usize) .collect::>(); - let precision = map_tensor_type_to_precision(tensor.type_); - - // TODO There must be some good way to discover the layout here; this - // should not have to default to NHWC. let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision); - let data = tensor - .data - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - let blob = openvino::Blob::new(&desc, &data)?; + let blob = openvino::Blob::new(&desc, &tensor.data)?; // Actually assign the blob to the request. self.1.set_blob(&input_name, &blob)?; @@ -157,8 +141,8 @@ fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str { /// wasi-nn. fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision { match tensor_type { - TensorType::F16 => Precision::FP16, - TensorType::F32 => Precision::FP32, + TensorType::Fp16 => Precision::FP16, + TensorType::Fp32 => Precision::FP32, TensorType::U8 => Precision::U8, TensorType::I32 => Precision::I32, } diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs index 988bc27bcb03..e2cea15f9654 100644 --- a/crates/wasi-nn/src/ctx.rs +++ b/crates/wasi-nn/src/ctx.rs @@ -1,31 +1,31 @@ -//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the -//! implementation of the wasi-nn API. -use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; -use crate::openvino::OpenvinoBackend; -use crate::r#impl::UsageError; -use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext}; +//! Implements the host state for the `wasi-nn` API: [WasiNnCtx]. + +use crate::backend::{ + self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind, +}; +use crate::wit::types::GraphEncoding; use std::collections::HashMap; use std::hash::Hash; use thiserror::Error; use wiggle::GuestError; +type GraphId = u32; +type GraphExecutionContextId = u32; + /// Capture the state necessary for calling into the backend ML libraries. pub struct WasiNnCtx { - pub(crate) backends: HashMap>, - pub(crate) graphs: Table>, - pub(crate) executions: Table>, + pub(crate) backends: HashMap>, + pub(crate) graphs: Table>, + pub(crate) executions: Table>, } impl WasiNnCtx { /// Make a new context from the default state. pub fn new() -> WasiNnResult { let mut backends = HashMap::new(); - backends.insert( - // This is necessary because Wiggle's variant types do not derive - // `Hash` and `Eq`. - GraphEncoding::Openvino.into(), - Box::new(OpenvinoBackend::default()) as Box, - ); + for (kind, backend) in backend::list() { + backends.insert(kind, backend); + } Ok(Self { backends, graphs: Table::default(), @@ -45,6 +45,22 @@ pub enum WasiNnError { UsageError(#[from] UsageError), } +#[derive(Debug, Error)] +pub enum UsageError { + #[error("Invalid context; has the load function been called?")] + InvalidContext, + #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] + InvalidEncoding(GraphEncoding), + #[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")] + InvalidNumberOfBuilders(u32), + #[error("Invalid graph handle; has it been loaded?")] + InvalidGraphHandle, + #[error("Invalid execution context handle; has it been initialized?")] + InvalidExecutionContextHandle, + #[error("Not enough memory to copy tensor data of size: {0}")] + NotEnoughMemory(u32), +} + pub(crate) type WasiNnResult = std::result::Result; /// Record handle entries in a table. diff --git a/crates/wasi-nn/src/impl.rs b/crates/wasi-nn/src/impl.rs deleted file mode 100644 index 0f8da5247a7b..000000000000 --- a/crates/wasi-nn/src/impl.rs +++ /dev/null @@ -1,93 +0,0 @@ -//! Implements the wasi-nn API. -use crate::ctx::WasiNnResult as Result; -use crate::witx::types::{ - ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor, -}; -use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn; -use crate::WasiNnCtx; -use thiserror::Error; -use wiggle::GuestPtr; - -#[derive(Debug, Error)] -pub enum UsageError { - #[error("Invalid context; has the load function been called?")] - InvalidContext, - #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] - InvalidEncoding(GraphEncoding), - #[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")] - InvalidNumberOfBuilders(u32), - #[error("Invalid graph handle; has it been loaded?")] - InvalidGraphHandle, - #[error("Invalid execution context handle; has it been initialized?")] - InvalidExecutionContextHandle, - #[error("Not enough memory to copy tensor data of size: {0}")] - NotEnoughMemory(u32), -} - -impl<'a> WasiEphemeralNn for WasiNnCtx { - fn load<'b>( - &mut self, - builders: &GraphBuilderArray<'_>, - encoding: GraphEncoding, - target: ExecutionTarget, - ) -> Result { - let encoding_id: u8 = encoding.into(); - let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) { - backend.load(builders, target)? - } else { - return Err(UsageError::InvalidEncoding(encoding).into()); - }; - let graph_id = self.graphs.insert(graph); - Ok(graph_id) - } - - fn init_execution_context(&mut self, graph_id: Graph) -> Result { - let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { - graph.init_execution_context()? - } else { - return Err(UsageError::InvalidGraphHandle.into()); - }; - - let exec_context_id = self.executions.insert(exec_context); - Ok(exec_context_id) - } - - fn set_input<'b>( - &mut self, - exec_context_id: GraphExecutionContext, - index: u32, - tensor: &Tensor<'b>, - ) -> Result<()> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - Ok(exec_context.set_input(index, tensor)?) - } else { - Err(UsageError::InvalidGraphHandle.into()) - } - } - - fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - Ok(exec_context.compute()?) - } else { - Err(UsageError::InvalidExecutionContextHandle.into()) - } - } - - fn get_output<'b>( - &mut self, - exec_context_id: GraphExecutionContext, - index: u32, - out_buffer: &GuestPtr<'_, u8>, - out_buffer_max_size: u32, - ) -> Result { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - let mut destination = out_buffer - .as_array(out_buffer_max_size) - .as_slice_mut()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - Ok(exec_context.get_output(index, &mut destination)?) - } else { - Err(UsageError::InvalidGraphHandle.into()) - } - } -} diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 7efae80f6159..2cf8d6e8e56b 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -1,8 +1,6 @@ -mod api; +mod backend; mod ctx; -mod r#impl; -mod openvino; -mod witx; pub use ctx::WasiNnCtx; -pub use witx::wasi_ephemeral_nn::add_to_linker; +pub mod wit; +pub mod witx; diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs new file mode 100644 index 000000000000..e5a823bde939 --- /dev/null +++ b/crates/wasi-nn/src/wit.rs @@ -0,0 +1,124 @@ +//! Implements the `wasi-nn` API for the WIT ("preview2") ABI. +//! +//! Note that `wasi-nn` is not yet included in an official "preview2" world +//! (though it could be) so by "preview2" here we mean that this can be called +//! with the component model's canonical ABI. +//! +//! This module exports its [`types`] for use throughout the crate and the +//! [`ML`] object, which exposes [`ML::add_to_linker`]. To implement all of +//! this, this module proceeds in steps: +//! 1. generate all of the WIT glue code into a `gen::*` namespace +//! 2. wire up the `gen::*` glue to the context state, delegating actual +//! computation to a [`Backend`] +//! 3. convert some types +//! +//! [`Backend`]: crate::backend::Backend +//! [`types`]: crate::wit::types + +use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx}; + +pub use gen::types; +pub use gen_::Ml as ML; + +/// Generate the traits and types from the `wasi-nn` WIT specification. +mod gen_ { + wasmtime::component::bindgen!("ml"); +} +use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need. + +impl gen::inference::Host for WasiNnCtx { + /// Load an opaque sequence of bytes to use for inference. + fn load( + &mut self, + builders: gen::types::GraphBuilderArray, + encoding: gen::types::GraphEncoding, + target: gen::types::ExecutionTarget, + ) -> wasmtime::Result> { + let backend_kind: BackendKind = encoding.try_into()?; + let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) { + let slices = builders.iter().map(|s| s.as_slice()).collect::>(); + backend.load(&slices, target.into())? + } else { + return Err(UsageError::InvalidEncoding(encoding.into()).into()); + }; + let graph_id = self.graphs.insert(graph); + Ok(Ok(graph_id)) + } + + /// Create an execution instance of a loaded graph. + /// + /// TODO: remove completely? + fn init_execution_context( + &mut self, + graph_id: gen::types::Graph, + ) -> wasmtime::Result> { + let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { + graph.init_execution_context()? + } else { + return Err(UsageError::InvalidGraphHandle.into()); + }; + + let exec_context_id = self.executions.insert(exec_context); + Ok(Ok(exec_context_id)) + } + + /// Define the inputs to use for inference. + fn set_input( + &mut self, + exec_context_id: gen::types::GraphExecutionContext, + index: u32, + tensor: gen::types::Tensor, + ) -> wasmtime::Result> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id) { + exec_context.set_input(index, &tensor)?; + Ok(Ok(())) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } + + /// Compute the inference on the given inputs. + /// + /// TODO: refactor to compute(list) -> result, error> + fn compute( + &mut self, + exec_context_id: gen::types::GraphExecutionContext, + ) -> wasmtime::Result> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id) { + exec_context.compute()?; + Ok(Ok(())) + } else { + Err(UsageError::InvalidExecutionContextHandle.into()) + } + } + + /// Extract the outputs after inference. + fn get_output( + &mut self, + exec_context_id: gen::types::GraphExecutionContext, + index: u32, + ) -> wasmtime::Result> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id) { + // Read the output bytes. TODO: this involves a hard-coded upper + // limit on the tensor size that is necessary because there is no + // way to introspect the graph outputs + // (https://github.com/WebAssembly/wasi-nn/issues/37). + let mut destination = vec![0; 1024 * 1024]; + let bytes_read = exec_context.get_output(index, &mut destination)?; + destination.truncate(bytes_read as usize); + Ok(Ok(destination)) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } +} + +impl TryFrom for crate::backend::BackendKind { + type Error = UsageError; + fn try_from(value: gen::types::GraphEncoding) -> Result { + match value { + gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO), + _ => Err(UsageError::InvalidEncoding(value.into())), + } + } +} diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index e7c877bd907e..7371017d0424 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -1,30 +1,175 @@ -//! Contains the macro-generated implementation of wasi-nn from the its witx definition file. -use crate::ctx::WasiNnCtx; -use crate::ctx::WasiNnError; -use anyhow::Result; +//! Implements the `wasi-nn` API for the WITX ("preview1") ABI. +//! +//! `wasi-nn` was never included in the official "preview1" snapshot, but this +//! module implements the ABI that is compatible with "preview1". +//! +//! The only export from this module is [`add_to_linker`]. To implement it, this +//! module proceeds in steps: +//! 1. generate all of the Wiggle glue code into a `gen::*` namespace +//! 2. wire up the `gen::*` glue to the context state, delegating actual +//! computation to a `Backend` +//! 3. wrap up with some conversions, i.e., from `gen::*` types to this crate's +//! [`types`]. +//! +//! [`types`]: crate::wit::types -// Generate the traits and types of wasi-nn in several Rust modules (e.g. `types`). -wiggle::from_witx!({ - witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"], - errors: { nn_errno => WasiNnError } -}); +use crate::backend::BackendKind; +use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; +use wiggle::GuestPtr; -use types::NnErrno; +pub use gen::wasi_ephemeral_nn::add_to_linker; -impl<'a> types::UserErrorConversion for WasiNnCtx { - fn nn_errno_from_wasi_nn_error(&mut self, e: WasiNnError) -> Result { - eprintln!("Host error: {:?}", e); - match e { - WasiNnError::BackendError(_) => unimplemented!(), - WasiNnError::GuestError(_) => unimplemented!(), - WasiNnError::UsageError(_) => unimplemented!(), +/// Generate the traits and types from the `wasi-nn` WITX specification. +mod gen { + use super::*; + wiggle::from_witx!({ + witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"], + errors: { nn_errno => WasiNnError } + }); + + /// Additionally, we must let Wiggle know which of our error codes + /// represents a successful operation. + impl wiggle::GuestErrorType for types::NnErrno { + fn success() -> Self { + Self::Success + } + } + + /// Convert the host errors to their WITX-generated type. + impl<'a> types::UserErrorConversion for WasiNnCtx { + fn nn_errno_from_wasi_nn_error( + &mut self, + e: WasiNnError, + ) -> anyhow::Result { + tracing::debug!("host error: {:?}", e); + match e { + WasiNnError::BackendError(_) => unimplemented!(), + WasiNnError::GuestError(_) => unimplemented!(), + WasiNnError::UsageError(_) => unimplemented!(), + } + } + } +} + +/// Wire up the WITX-generated trait to the `wasi-nn` host state. +impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { + fn load<'b>( + &mut self, + builders: &gen::types::GraphBuilderArray<'_>, + encoding: gen::types::GraphEncoding, + target: gen::types::ExecutionTarget, + ) -> Result { + let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) { + // Retrieve all of the "builder lists" from the Wasm memory (see + // $graph_builder_array) as slices for a backend to operate on. + let mut slices = vec![]; + for builder in builders.iter() { + let slice = builder? + .read()? + .as_slice()? + .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + slices.push(slice); + } + let slice_refs = slices.iter().map(|s| s.as_ref()).collect::>(); + backend.load(&slice_refs, target.into())? + } else { + return Err(UsageError::InvalidEncoding(encoding.into()).into()); + }; + let graph_id = self.graphs.insert(graph); + Ok(graph_id.into()) + } + + fn init_execution_context( + &mut self, + graph_id: gen::types::Graph, + ) -> Result { + let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id.into()) { + graph.init_execution_context()? + } else { + return Err(UsageError::InvalidGraphHandle.into()); + }; + + let exec_context_id = self.executions.insert(exec_context); + Ok(exec_context_id.into()) + } + + fn set_input<'b>( + &mut self, + exec_context_id: gen::types::GraphExecutionContext, + index: u32, + tensor: &gen::types::Tensor<'b>, + ) -> Result<()> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { + let tensor = crate::wit::types::Tensor { + dimensions: tensor.dimensions.to_vec()?, + tensor_type: tensor.type_.into(), + data: tensor.data.to_vec()?, + }; + Ok(exec_context.set_input(index, &tensor)?) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } + + fn compute(&mut self, exec_context_id: gen::types::GraphExecutionContext) -> Result<()> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { + Ok(exec_context.compute()?) + } else { + Err(UsageError::InvalidExecutionContextHandle.into()) + } + } + + fn get_output<'b>( + &mut self, + exec_context_id: gen::types::GraphExecutionContext, + index: u32, + out_buffer: &GuestPtr<'_, u8>, + out_buffer_max_size: u32, + ) -> Result { + if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { + let mut destination = out_buffer + .as_array(out_buffer_max_size) + .as_slice_mut()? + .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + Ok(exec_context.get_output(index, &mut destination)?) + } else { + Err(UsageError::InvalidGraphHandle.into()) } } } -/// Additionally, we must let Wiggle know which of our error codes represents a successful operation. -impl wiggle::GuestErrorType for NnErrno { - fn success() -> Self { - Self::Success +// Implement some conversion from `witx::types::*` to this crate's version. + +impl From for BackendKind { + fn from(value: gen::types::GraphEncoding) -> Self { + match value { + gen::types::GraphEncoding::Openvino => BackendKind::OpenVINO, + } + } +} +impl From for crate::wit::types::ExecutionTarget { + fn from(value: gen::types::ExecutionTarget) -> Self { + match value { + gen::types::ExecutionTarget::Cpu => crate::wit::types::ExecutionTarget::Cpu, + gen::types::ExecutionTarget::Gpu => crate::wit::types::ExecutionTarget::Gpu, + gen::types::ExecutionTarget::Tpu => crate::wit::types::ExecutionTarget::Tpu, + } + } +} +impl From for crate::wit::types::GraphEncoding { + fn from(value: gen::types::GraphEncoding) -> Self { + match value { + gen::types::GraphEncoding::Openvino => crate::wit::types::GraphEncoding::Openvino, + } + } +} +impl From for crate::wit::types::TensorType { + fn from(value: gen::types::TensorType) -> Self { + match value { + gen::types::TensorType::F16 => crate::wit::types::TensorType::Fp16, + gen::types::TensorType::F32 => crate::wit::types::TensorType::Fp32, + gen::types::TensorType::U8 => crate::wit::types::TensorType::U8, + gen::types::TensorType::I32 => crate::wit::types::TensorType::I32, + } } } diff --git a/crates/wasi-nn/wit/inference.wit b/crates/wasi-nn/wit/inference.wit new file mode 100644 index 000000000000..df754231f696 --- /dev/null +++ b/crates/wasi-nn/wit/inference.wit @@ -0,0 +1,24 @@ +interface inference { + use types.{graph-builder-array, graph-encoding, execution-target, graph, + tensor, tensor-data, error, graph-execution-context} + + /// Load an opaque sequence of bytes to use for inference. + load: func(builder: graph-builder-array, encoding: graph-encoding, + target: execution-target) -> result + + /// Create an execution instance of a loaded graph. + /// + /// TODO: remove completely? + init-execution-context: func(graph: graph) -> result + + /// Define the inputs to use for inference. + set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error> + + /// Compute the inference on the given inputs. + /// + /// TODO: refactor to compute(list) -> result, error> + compute: func(ctx: graph-execution-context) -> result<_, error> + + /// Extract the outputs after inference. + get-output: func(ctx: graph-execution-context, index: u32) -> result +} diff --git a/crates/wasi-nn/wit/types.wit b/crates/wasi-nn/wit/types.wit new file mode 100644 index 000000000000..f134b730e1a8 --- /dev/null +++ b/crates/wasi-nn/wit/types.wit @@ -0,0 +1,88 @@ +interface types { + /// The dimensions of a tensor. + /// + /// The array length matches the tensor rank and each element in the array + /// describes the size of each dimension. + type tensor-dimensions = list + + /// The type of the elements in a tensor. + enum tensor-type { + FP16, + FP32, + U8, + I32 + } + + /// The tensor data. + /// + /// Initially conceived as a sparse representation, each empty cell would be filled with zeros and + /// the array length must match the product of all of the dimensions and the number of bytes in the + /// type (e.g., a 2x2 tensor with 4-byte f32 elements would have a data array of length 16). + /// Naturally, this representation requires some knowledge of how to lay out data in memory--e.g., + /// using row-major ordering--and could perhaps be improved. + type tensor-data = list + + /// A tensor. + record tensor { + /// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor + /// containing a single value, use `[1]` for the tensor dimensions. + dimensions: tensor-dimensions, + + /// Describe the type of element in the tensor (e.g., f32). + tensor-type: tensor-type, + + /// Contains the tensor data. + data: tensor-data, + } + + /// The graph initialization data. + // + /// This consists of an array of buffers because implementing backends may encode their graph IR in + /// parts (e.g., OpenVINO stores its IR and weights separately). + type graph-builder = list + type graph-builder-array = list + + /// An execution graph for performing inference (i.e., a model). + /// + /// TODO: replace with `resource` + type graph = u32 + + /// Describes the encoding of the graph. This allows the API to be implemented by various backends + /// that encode (i.e., serialize) their graph IR with different formats. + enum graph-encoding { + openvino, + onnx, + tensorflow, + pytorch, + tensorflowlite + } + + /// Define where the graph should be executed. + enum execution-target { + cpu, + gpu, + tpu + } + + /// Bind a `graph` to the input and output tensors for an inference. + /// + /// TODO: replace with `resource` + /// TODO: remove execution contexts completely + type graph-execution-context = u32 + + /// Error codes returned by functions in this API. + enum error { + /// No error occurred. + success, + /// Caller module passed an invalid argument. + invalid-argument, + /// Invalid encoding. + invalid-encoding, + /// Caller module is missing a memory export. + missing-memory, + /// Device or resource busy. + busy, + /// Runtime Error. + runtime-error, + } +} diff --git a/crates/wasi-nn/wit/world.wit b/crates/wasi-nn/wit/world.wit new file mode 100644 index 000000000000..42bffb93420c --- /dev/null +++ b/crates/wasi-nn/wit/world.wit @@ -0,0 +1,20 @@ +/// `wasi-nn` API +/// +/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The +/// API is not (yet) capable of performing ML training. WebAssembly programs +/// that want to use a host's ML capabilities can access these capabilities +/// through `wasi-nn`'s core abstractions: _backends_, _graphs_, and _tensors_. +/// A user selects a _backend_ for inference and `load`s a model, instantiated +/// as a _graph_, to use in the _backend_. Then, the user passes _tensor_ inputs +/// to the _graph_, computes the inference, and retrieves the _tensor_ outputs. +/// +/// This module draws inspiration from the inference side of +/// [WebNN](https://webmachinelearning.github.io/webnn/#api). See the +/// [README](https://github.com/WebAssembly/wasi-nn/blob/main/README.md) for +/// more context about the design and goals of this API. + +package wasi:nn + +world ml { + import inference +} diff --git a/src/commands/run.rs b/src/commands/run.rs index 8a124291a0d0..7f9eddc9ff14 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -722,7 +722,7 @@ fn populate_with_wasi( } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::add_to_linker(linker, |host| { + wasmtime_wasi_nn::witx::add_to_linker(linker, |host| { // This WASI proposal is currently not protected against // concurrent access--i.e., when wasi-threads is actively // spawning new threads, we cannot (yet) safely allow access and