From f93df0cbf9f434ecaa7ea5d7d3f93e5d546ba861 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Tue, 8 Aug 2023 14:24:00 -0700 Subject: [PATCH 01/10] wasi-nn: refactor to allow `preview2` access This change refactors the `wasmtime-wasi-nn` crate to allow access from both `preview1` and `preview2` ABIs. Though the `wasi-nn` specification has included a WIT description for some time, here we use some in-tree files until https://github.com/WebAssembly/wasi-nn/pull/38 is landed. The `preview2` code is not exercised anywhere yet: ideally this would be wired up once component model `resource`s are fully implemented in Wasmtime. prtest:full --- Cargo.lock | 1 + crates/wasi-nn/Cargo.toml | 8 +- crates/wasi-nn/src/{api.rs => backend/mod.rs} | 29 ++- crates/wasi-nn/src/{ => backend}/openvino.rs | 57 ++---- crates/wasi-nn/src/ctx.rs | 46 +++-- crates/wasi-nn/src/impl.rs | 93 ---------- crates/wasi-nn/src/lib.rs | 10 +- crates/wasi-nn/src/preview1.rs | 174 ++++++++++++++++++ crates/wasi-nn/src/preview2.rs | 165 +++++++++++++++++ crates/wasi-nn/src/types.rs | 39 ++++ crates/wasi-nn/src/witx.rs | 30 --- crates/wasi-nn/wit/inference.wit | 24 +++ crates/wasi-nn/wit/types.wit | 88 +++++++++ crates/wasi-nn/wit/world.wit | 20 ++ 14 files changed, 597 insertions(+), 187 deletions(-) rename crates/wasi-nn/src/{api.rs => backend/mod.rs} (70%) rename crates/wasi-nn/src/{ => backend}/openvino.rs (67%) delete mode 100644 crates/wasi-nn/src/impl.rs create mode 100644 crates/wasi-nn/src/preview1.rs create mode 100644 crates/wasi-nn/src/preview2.rs create mode 100644 crates/wasi-nn/src/types.rs delete mode 100644 crates/wasi-nn/src/witx.rs create mode 100644 crates/wasi-nn/wit/inference.wit create mode 100644 crates/wasi-nn/wit/types.wit create mode 100644 crates/wasi-nn/wit/world.wit diff --git a/Cargo.lock b/Cargo.lock index a211324257ef..8da77b235cf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3774,6 +3774,7 @@ dependencies = [ "openvino", "thiserror", "walkdir", + "wasmtime", "wiggle", ] 
diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index b0977562638b..d48f769d3fdb 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -12,13 +12,19 @@ readme = "README.md" edition.workspace = true [dependencies] -# These dependencies are necessary for the witx-generation macros to work: +# These dependencies are necessary for the WITX-generation macros to work: anyhow = { workspace = true } wiggle = { workspace = true } +# This dependency is necessary for the WIT-generation macros to work: +wasmtime = { workspace = true, optional = true, features = ["component-model"] } + # These dependencies are necessary for the wasi-nn implementation: openvino = { version = "0.5.0", features = ["runtime-linking"] } thiserror = { workspace = true } [build-dependencies] walkdir = { workspace = true } + +[features] +preview2 = ["wasmtime"] diff --git a/crates/wasi-nn/src/api.rs b/crates/wasi-nn/src/backend/mod.rs similarity index 70% rename from crates/wasi-nn/src/api.rs rename to crates/wasi-nn/src/backend/mod.rs index 2ad6e0edf94e..9f9d925b735b 100644 --- a/crates/wasi-nn/src/api.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -1,17 +1,25 @@ //! Define the Rust interface a backend must implement in order to be used by -//! this crate. the `Box` types returned by these interfaces allow +//! this crate. The `Box` types returned by these interfaces allow //! implementations to maintain backend-specific state between calls. -use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor}; +mod openvino; + +use self::openvino::OpenvinoBackend; +use crate::types::{ExecutionTarget, Tensor}; use thiserror::Error; use wiggle::GuestError; +/// Return a list of all available backend frameworks. +pub(crate) fn list() -> Vec<(BackendKind, Box)> { + vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))] +} + /// A [Backend] contains the necessary state to load [BackendGraph]s. 
pub(crate) trait Backend: Send + Sync { fn name(&self) -> &str; fn load( &mut self, - builders: &GraphBuilderArray<'_>, + builders: &[&[u8]], target: ExecutionTarget, ) -> Result, BackendError>; } @@ -39,7 +47,20 @@ pub enum BackendError { #[error("Failed while accessing guest module")] GuestAccess(#[from] GuestError), #[error("The backend expects {0} buffers, passed {1}")] - InvalidNumberOfBuilders(u32, u32), + InvalidNumberOfBuilders(usize, usize), #[error("Not enough memory to copy tensor data of size: {0}")] NotEnoughMemory(usize), } + +#[derive(Hash, PartialEq, Eq, Clone, Copy)] +pub(crate) enum BackendKind { + OpenVINO, +} +impl From for BackendKind { + fn from(value: u8) -> Self { + match value { + 0 => BackendKind::OpenVINO, + _ => panic!("invalid backend"), + } + } +} diff --git a/crates/wasi-nn/src/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs similarity index 67% rename from crates/wasi-nn/src/openvino.rs rename to crates/wasi-nn/src/backend/openvino.rs index 9924326369f3..6f19c5208167 100644 --- a/crates/wasi-nn/src/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -1,13 +1,12 @@ -//! Implements the wasi-nn API. +//! Implements a `wasi-nn` [`Backend`] using OpenVINO. 
-use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; -use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor, TensorType}; +use super::{Backend, BackendError, BackendExecutionContext, BackendGraph}; +use crate::types::{ExecutionTarget, Tensor, TensorType}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; use std::sync::Arc; #[derive(Default)] pub(crate) struct OpenvinoBackend(Option); - unsafe impl Send for OpenvinoBackend {} unsafe impl Sync for OpenvinoBackend {} @@ -18,7 +17,7 @@ impl Backend for OpenvinoBackend { fn load( &mut self, - builders: &GraphBuilderArray<'_>, + builders: &[&[u8]], target: ExecutionTarget, ) -> Result, BackendError> { if builders.len() != 2 { @@ -34,16 +33,8 @@ impl Backend for OpenvinoBackend { } // Read the guest array. - let builders = builders.as_ptr(); - let xml = builders - .read()? - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - let weights = builders - .add(1)? - .read()? - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + let xml = &builders[0]; + let weights = &builders[1]; // Construct OpenVINO graph structures: `cnn_network` contains the graph // structure, `exec_network` can perform inference. @@ -53,8 +44,9 @@ impl Backend for OpenvinoBackend { .expect("openvino::Core was previously constructed"); let mut cnn_network = core.read_network_from_buffer(&xml, &weights)?; - // TODO this is a temporary workaround. We need a more eligant way to specify the layout in the long run. - // However, without this newer versions of OpenVINO will fail due to parameter mismatch. + // TODO: this is a temporary workaround. We need a more elegant way to + // specify the layout in the long run. However, without this newer + // versions of OpenVINO will fail due to parameter mismatch. 
for i in 0..cnn_network.get_inputs_len()? { let name = cnn_network.get_input_name(i)?; cnn_network.set_input_layout(&name, Layout::NHWC)?; @@ -85,27 +77,14 @@ impl BackendGraph for OpenvinoGraph { struct OpenvinoExecutionContext(Arc, openvino::InferRequest); impl BackendExecutionContext for OpenvinoExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError> { + fn set_input<'a>(&mut self, index: u32, tensor: &Tensor<'a>) -> Result<(), BackendError> { let input_name = self.0.get_input_name(index as usize)?; - // Construct the blob structure. - let dimensions = tensor - .dimensions - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)") - .iter() - .map(|d| *d as usize) - .collect::>(); - let precision = map_tensor_type_to_precision(tensor.type_); - - // TODO There must be some good way to discover the layout here; this - // should not have to default to NHWC. - let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision); - let data = tensor - .data - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - let blob = openvino::Blob::new(&desc, &data)?; + // Construct the blob structure. TODO: there must be some good way to + // discover the layout here; `desc` should not have to default to NHWC. + let precision = map_tensor_type_to_precision(tensor.ty); + let desc = TensorDesc::new(Layout::NHWC, tensor.dims, precision); + let blob = openvino::Blob::new(&desc, tensor.data)?; // Actually assign the blob to the request. self.1.set_blob(&input_name, &blob)?; @@ -147,9 +126,9 @@ impl From for BackendError { /// `ExecutionTarget` enum provided by wasi-nn. 
fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str { match target { - ExecutionTarget::Cpu => "CPU", - ExecutionTarget::Gpu => "GPU", - ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"), + ExecutionTarget::CPU => "CPU", + ExecutionTarget::GPU => "GPU", + ExecutionTarget::TPU => unimplemented!("OpenVINO does not support TPU execution targets"), } } diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs index 988bc27bcb03..c6891d11a8f5 100644 --- a/crates/wasi-nn/src/ctx.rs +++ b/crates/wasi-nn/src/ctx.rs @@ -1,31 +1,31 @@ -//! Implements the base structure (i.e. [WasiNnCtx]) that will provide the -//! implementation of the wasi-nn API. -use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; -use crate::openvino::OpenvinoBackend; -use crate::r#impl::UsageError; -use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext}; +//! Implements the host state for the `wasi-nn` API: [WasiNnCtx]. + +use crate::backend::{ + self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind, +}; +use crate::types::GraphEncoding; use std::collections::HashMap; use std::hash::Hash; use thiserror::Error; use wiggle::GuestError; +type GraphId = u32; +type GraphExecutionContextId = u32; + /// Capture the state necessary for calling into the backend ML libraries. pub struct WasiNnCtx { - pub(crate) backends: HashMap>, - pub(crate) graphs: Table>, - pub(crate) executions: Table>, + pub(crate) backends: HashMap>, + pub(crate) graphs: Table>, + pub(crate) executions: Table>, } impl WasiNnCtx { /// Make a new context from the default state. pub fn new() -> WasiNnResult { let mut backends = HashMap::new(); - backends.insert( - // This is necessary because Wiggle's variant types do not derive - // `Hash` and `Eq`. 
- GraphEncoding::Openvino.into(), - Box::new(OpenvinoBackend::default()) as Box, - ); + for (kind, backend) in backend::list() { + backends.insert(kind, backend); + } Ok(Self { backends, graphs: Table::default(), @@ -45,6 +45,22 @@ pub enum WasiNnError { UsageError(#[from] UsageError), } +#[derive(Debug, Error)] +pub enum UsageError { + #[error("Invalid context; has the load function been called?")] + InvalidContext, + #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] + InvalidEncoding(GraphEncoding), + #[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")] + InvalidNumberOfBuilders(u32), + #[error("Invalid graph handle; has it been loaded?")] + InvalidGraphHandle, + #[error("Invalid execution context handle; has it been initialized?")] + InvalidExecutionContextHandle, + #[error("Not enough memory to copy tensor data of size: {0}")] + NotEnoughMemory(u32), +} + pub(crate) type WasiNnResult = std::result::Result; /// Record handle entries in a table. diff --git a/crates/wasi-nn/src/impl.rs b/crates/wasi-nn/src/impl.rs deleted file mode 100644 index 0f8da5247a7b..000000000000 --- a/crates/wasi-nn/src/impl.rs +++ /dev/null @@ -1,93 +0,0 @@ -//! Implements the wasi-nn API. -use crate::ctx::WasiNnResult as Result; -use crate::witx::types::{ - ExecutionTarget, Graph, GraphBuilderArray, GraphEncoding, GraphExecutionContext, Tensor, -}; -use crate::witx::wasi_ephemeral_nn::WasiEphemeralNn; -use crate::WasiNnCtx; -use thiserror::Error; -use wiggle::GuestPtr; - -#[derive(Debug, Error)] -pub enum UsageError { - #[error("Invalid context; has the load function been called?")] - InvalidContext, - #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] - InvalidEncoding(GraphEncoding), - #[error("OpenVINO expects only two buffers (i.e. 
[ir, weights]), passed: {0}")] - InvalidNumberOfBuilders(u32), - #[error("Invalid graph handle; has it been loaded?")] - InvalidGraphHandle, - #[error("Invalid execution context handle; has it been initialized?")] - InvalidExecutionContextHandle, - #[error("Not enough memory to copy tensor data of size: {0}")] - NotEnoughMemory(u32), -} - -impl<'a> WasiEphemeralNn for WasiNnCtx { - fn load<'b>( - &mut self, - builders: &GraphBuilderArray<'_>, - encoding: GraphEncoding, - target: ExecutionTarget, - ) -> Result { - let encoding_id: u8 = encoding.into(); - let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) { - backend.load(builders, target)? - } else { - return Err(UsageError::InvalidEncoding(encoding).into()); - }; - let graph_id = self.graphs.insert(graph); - Ok(graph_id) - } - - fn init_execution_context(&mut self, graph_id: Graph) -> Result { - let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { - graph.init_execution_context()? - } else { - return Err(UsageError::InvalidGraphHandle.into()); - }; - - let exec_context_id = self.executions.insert(exec_context); - Ok(exec_context_id) - } - - fn set_input<'b>( - &mut self, - exec_context_id: GraphExecutionContext, - index: u32, - tensor: &Tensor<'b>, - ) -> Result<()> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - Ok(exec_context.set_input(index, tensor)?) - } else { - Err(UsageError::InvalidGraphHandle.into()) - } - } - - fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - Ok(exec_context.compute()?) 
- } else { - Err(UsageError::InvalidExecutionContextHandle.into()) - } - } - - fn get_output<'b>( - &mut self, - exec_context_id: GraphExecutionContext, - index: u32, - out_buffer: &GuestPtr<'_, u8>, - out_buffer_max_size: u32, - ) -> Result { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - let mut destination = out_buffer - .as_array(out_buffer_max_size) - .as_slice_mut()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - Ok(exec_context.get_output(index, &mut destination)?) - } else { - Err(UsageError::InvalidGraphHandle.into()) - } - } -} diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 7efae80f6159..541833a8ed60 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -1,8 +1,8 @@ -mod api; +mod backend; mod ctx; -mod r#impl; -mod openvino; -mod witx; pub use ctx::WasiNnCtx; -pub use witx::wasi_ephemeral_nn::add_to_linker; +pub mod preview1; +#[cfg(feature = "preview2")] +pub mod preview2; +pub mod types; diff --git a/crates/wasi-nn/src/preview1.rs b/crates/wasi-nn/src/preview1.rs new file mode 100644 index 000000000000..92cac17dc750 --- /dev/null +++ b/crates/wasi-nn/src/preview1.rs @@ -0,0 +1,174 @@ +//! Implements the `wasi-nn` API for a "preview1" ABI. +//! +//! Note that `wasi-nn` was never included in the official "preview1" snapshot, +//! but the naming here means that the `wasi-nn` imports can be called with the +//! original "preview1" ABI. +//! +//! The only export from this module is [`add_to_linker`]. To implement it, +//! this module proceeds in steps: +//! 1. generate all of the Wiggle glue code into a `witx::*` namespace +//! 2. wire up the `witx::*` glue to the context state, delegating actual +//! computation to a `Backend` +//! 3. wrap up with some conversions, i.e., from `witx::*` types to this crate's +//! [`types`]. +//! +//! 
[`types`]: crate::types + +use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; +use wiggle::GuestPtr; + +pub use witx::wasi_ephemeral_nn::add_to_linker; + +/// Generate the traits and types from the `wasi-nn` WITX specification. +mod witx { + use super::*; + wiggle::from_witx!({ + witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"], + errors: { nn_errno => WasiNnError } + }); + + /// Additionally, we must let Wiggle know which of our error codes + /// represents a successful operation. + impl wiggle::GuestErrorType for types::NnErrno { + fn success() -> Self { + Self::Success + } + } + + /// Convert the host errors to their WITX-generated type. + impl<'a> types::UserErrorConversion for WasiNnCtx { + fn nn_errno_from_wasi_nn_error( + &mut self, + e: WasiNnError, + ) -> anyhow::Result { + eprintln!("Host error: {:?}", e); + match e { + WasiNnError::BackendError(_) => unimplemented!(), + WasiNnError::GuestError(_) => unimplemented!(), + WasiNnError::UsageError(_) => unimplemented!(), + } + } + } +} + +/// Wire up the WITX-generated trait to the `wasi-nn` host state. +impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { + fn load<'b>( + &mut self, + builders: &witx::types::GraphBuilderArray<'_>, + encoding: witx::types::GraphEncoding, + target: witx::types::ExecutionTarget, + ) -> Result { + let encoding_id: u8 = encoding.into(); + let graph = if let Some(backend) = self.backends.get_mut(&encoding_id.into()) { + // Retrieve all of the "builder lists" from the Wasm memory (see + // $graph_builder_array) as slices for a backend to operate on. + let mut slices = vec![]; + for builder in builders.iter() { + let slice = builder? + .read()? + .as_slice()? + .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + slices.push(slice); + } + let slice_refs = slices.iter().map(|s| s.as_ref()).collect::>(); + backend.load(&slice_refs, target.into())? 
+ } else { + return Err(UsageError::InvalidEncoding(encoding.into()).into()); + }; + let graph_id = self.graphs.insert(graph); + Ok(graph_id.into()) + } + + fn init_execution_context( + &mut self, + graph_id: witx::types::Graph, + ) -> Result { + let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id.into()) { + graph.init_execution_context()? + } else { + return Err(UsageError::InvalidGraphHandle.into()); + }; + + let exec_context_id = self.executions.insert(exec_context); + Ok(exec_context_id.into()) + } + + fn set_input<'b>( + &mut self, + exec_context_id: witx::types::GraphExecutionContext, + index: u32, + tensor: &witx::types::Tensor<'b>, + ) -> Result<()> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { + let mut dims = vec![]; + for d in tensor.dimensions.iter() { + dims.push(d?.read()? as usize); + } + let ty = tensor.type_.into(); + let data_ = tensor.data + .as_slice()? + .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + let data = data_.as_ref(); + let dims = &dims; + Ok(exec_context.set_input(index, &crate::types::Tensor { dims, ty, data })?) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } + + fn compute(&mut self, exec_context_id: witx::types::GraphExecutionContext) -> Result<()> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { + Ok(exec_context.compute()?) + } else { + Err(UsageError::InvalidExecutionContextHandle.into()) + } + } + + fn get_output<'b>( + &mut self, + exec_context_id: witx::types::GraphExecutionContext, + index: u32, + out_buffer: &GuestPtr<'_, u8>, + out_buffer_max_size: u32, + ) -> Result { + if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { + let mut destination = out_buffer + .as_array(out_buffer_max_size) + .as_slice_mut()? 
+ .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); + Ok(exec_context.get_output(index, &mut destination)?) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } +} + +// Implement some conversion from `witx::types::*` to this crate's version. + +impl From for crate::types::ExecutionTarget { + fn from(value: witx::types::ExecutionTarget) -> Self { + match value { + witx::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, + witx::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, + witx::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, + } + } +} +impl From for crate::types::GraphEncoding { + fn from(value: witx::types::GraphEncoding) -> Self { + match value { + witx::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, + } + } +} +impl From for crate::types::TensorType { + fn from(value: witx::types::TensorType) -> Self { + match value { + witx::types::TensorType::F16 => crate::types::TensorType::F16, + witx::types::TensorType::F32 => crate::types::TensorType::F32, + witx::types::TensorType::U8 => crate::types::TensorType::U8, + witx::types::TensorType::I32 => crate::types::TensorType::I32, + } + } +} diff --git a/crates/wasi-nn/src/preview2.rs b/crates/wasi-nn/src/preview2.rs new file mode 100644 index 000000000000..9cd3e8a0cba5 --- /dev/null +++ b/crates/wasi-nn/src/preview2.rs @@ -0,0 +1,165 @@ +//! Implements the `wasi-nn` API for a "preview2" ABI. +//! +//! Note that `wasi-nn` is not yet included in an official "preview2" world +//! (though it could be) so by "preview2" here we mean that this can be called +//! with the component model's canonical ABI. +//! +//! The only export from this module is the [`ML`] object, which exposes +//! [`ML::add_to_linker`]. To implement it, this module proceeds in steps: +//! 1. generate all of the WIT glue code into a `wit::*` namespace +//! 2. 
wire up the `wit::*` glue to the context state, delegating actual +//! computation to a `Backend` +//! 3. wrap up with some conversions, i.e., from `wit::*` types to this crate's +//! [`types`]. +//! +//! [`Backend`]: crate::backend::Backend +//! [`types`]: crate::types + +use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx}; + +pub use wit_::Ml as ML; + +/// Generate the traits and types from the `wasi-nn` WIT specification. +mod wit_ { + wasmtime::component::bindgen!("ml"); +} +use wit_::wasi::nn as wit; // Shortcut to the module containing the types we need. + +impl wit::inference::Host for WasiNnCtx { + /// Load an opaque sequence of bytes to use for inference. + fn load( + &mut self, + builders: wit::types::GraphBuilderArray, + encoding: wit::types::GraphEncoding, + target: wit::types::ExecutionTarget, + ) -> wasmtime::Result> { + let backend_kind: BackendKind = encoding.try_into()?; + let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) { + let slices = builders.iter().map(|s| s.as_slice()).collect::>(); + backend.load(&slices, target.into())? + } else { + return Err(UsageError::InvalidEncoding(encoding.into()).into()); + }; + let graph_id = self.graphs.insert(graph); + Ok(Ok(graph_id)) + } + + /// Create an execution instance of a loaded graph. + /// + /// TODO: remove completely? + fn init_execution_context( + &mut self, + graph_id: wit::types::Graph, + ) -> wasmtime::Result> { + let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { + graph.init_execution_context()? + } else { + return Err(UsageError::InvalidGraphHandle.into()); + }; + + let exec_context_id = self.executions.insert(exec_context); + Ok(Ok(exec_context_id)) + } + + /// Define the inputs to use for inference. 
+ fn set_input( + &mut self, + exec_context_id: wit::types::GraphExecutionContext, + index: u32, + tensor: wit::types::Tensor, + ) -> wasmtime::Result> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id) { + let dims = &tensor + .dimensions + .iter() + .map(|d| *d as usize) + .collect::>(); + let ty = tensor.tensor_type.into(); + let data = tensor.data.as_slice(); + exec_context.set_input(index, &crate::types::Tensor { dims, ty, data })?; + Ok(Ok(())) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } + + /// Compute the inference on the given inputs. + /// + /// TODO: refactor to compute(list) -> result, error> + fn compute( + &mut self, + exec_context_id: wit::types::GraphExecutionContext, + ) -> wasmtime::Result> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id) { + exec_context.compute()?; + Ok(Ok(())) + } else { + Err(UsageError::InvalidExecutionContextHandle.into()) + } + } + + /// Extract the outputs after inference. + fn get_output( + &mut self, + exec_context_id: wit::types::GraphExecutionContext, + index: u32, + ) -> wasmtime::Result> { + if let Some(exec_context) = self.executions.get_mut(exec_context_id) { + // Read the output bytes. TODO: this involves a hard-coded upper + // limit on the tensor size that is necessary because there is no + // way to introspect the graph outputs + // (https://github.com/WebAssembly/wasi-nn/issues/37). 
+ let mut destination = vec![0; 1024 * 1024]; + let bytes_read = exec_context.get_output(index, &mut destination)?; + destination.truncate(bytes_read as usize); + Ok(Ok(destination)) + } else { + Err(UsageError::InvalidGraphHandle.into()) + } + } +} + +impl From for crate::types::GraphEncoding { + fn from(value: wit::types::GraphEncoding) -> Self { + match value { + wit::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, + wit::types::GraphEncoding::Onnx => crate::types::GraphEncoding::ONNX, + wit::types::GraphEncoding::Tensorflow => crate::types::GraphEncoding::Tensorflow, + wit::types::GraphEncoding::Pytorch => crate::types::GraphEncoding::PyTorch, + wit::types::GraphEncoding::Tensorflowlite => { + crate::types::GraphEncoding::TensorflowLite + } + } + } +} + +impl TryFrom for crate::backend::BackendKind { + type Error = UsageError; + fn try_from(value: wit::types::GraphEncoding) -> Result { + match value { + wit::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO), + _ => Err(UsageError::InvalidEncoding(value.into())), + } + } +} + +impl From for crate::types::ExecutionTarget { + fn from(value: wit::types::ExecutionTarget) -> Self { + match value { + wit::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, + wit::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, + wit::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, + } + } +} + +impl From for crate::types::TensorType { + fn from(value: wit::types::TensorType) -> Self { + match value { + wit::types::TensorType::Fp16 => crate::types::TensorType::F16, + wit::types::TensorType::Fp32 => crate::types::TensorType::F32, + wit::types::TensorType::U8 => crate::types::TensorType::U8, + wit::types::TensorType::I32 => crate::types::TensorType::I32, + } + } +} diff --git a/crates/wasi-nn/src/types.rs b/crates/wasi-nn/src/types.rs new file mode 100644 index 000000000000..ae23f0d8b443 --- /dev/null +++ b/crates/wasi-nn/src/types.rs 
@@ -0,0 +1,39 @@ +//! The `wasi-nn` types used internally in this crate. +//! +//! These types form a common "ground truth" for the [`preview1`] and +//! [`preview2`] types to be converted from and to. As such, these types should +//! be kept up to date with the WIT and WITX specifications; if anything changes +//! in the specifications, we should see compile errors in the conversion +//! functions (e.g., `impl From for `crate::...`). +//! +//! [`preview1`]: crate::preview1 +//! [`preview2`]: crate::preview2 + +pub struct Tensor<'a> { + pub dims: &'a [usize], + pub ty: TensorType, + pub data: &'a [u8], +} + +#[derive(Clone, Copy)] +pub enum TensorType { + F16, + F32, + U8, + I32, +} + +pub enum ExecutionTarget { + CPU, + GPU, + TPU, +} + +#[derive(Debug)] +pub enum GraphEncoding { + OpenVINO, + ONNX, + Tensorflow, + PyTorch, + TensorflowLite, +} diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs deleted file mode 100644 index e7c877bd907e..000000000000 --- a/crates/wasi-nn/src/witx.rs +++ /dev/null @@ -1,30 +0,0 @@ -//! Contains the macro-generated implementation of wasi-nn from the its witx definition file. -use crate::ctx::WasiNnCtx; -use crate::ctx::WasiNnError; -use anyhow::Result; - -// Generate the traits and types of wasi-nn in several Rust modules (e.g. `types`). -wiggle::from_witx!({ - witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"], - errors: { nn_errno => WasiNnError } -}); - -use types::NnErrno; - -impl<'a> types::UserErrorConversion for WasiNnCtx { - fn nn_errno_from_wasi_nn_error(&mut self, e: WasiNnError) -> Result { - eprintln!("Host error: {:?}", e); - match e { - WasiNnError::BackendError(_) => unimplemented!(), - WasiNnError::GuestError(_) => unimplemented!(), - WasiNnError::UsageError(_) => unimplemented!(), - } - } -} - -/// Additionally, we must let Wiggle know which of our error codes represents a successful operation. 
-impl wiggle::GuestErrorType for NnErrno {
-    fn success() -> Self {
-        Self::Success
-    }
-}
diff --git a/crates/wasi-nn/wit/inference.wit b/crates/wasi-nn/wit/inference.wit
new file mode 100644
index 000000000000..df754231f696
--- /dev/null
+++ b/crates/wasi-nn/wit/inference.wit
@@ -0,0 +1,24 @@
+interface inference {
+    use types.{graph-builder-array, graph-encoding, execution-target, graph,
+        tensor, tensor-data, error, graph-execution-context}
+
+    /// Load an opaque sequence of bytes to use for inference.
+    load: func(builder: graph-builder-array, encoding: graph-encoding,
+        target: execution-target) -> result<graph, error>
+
+    /// Create an execution instance of a loaded graph.
+    ///
+    /// TODO: remove completely?
+    init-execution-context: func(graph: graph) -> result<graph-execution-context, error>
+
+    /// Define the inputs to use for inference.
+    set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>
+
+    /// Compute the inference on the given inputs.
+    ///
+    /// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
+    compute: func(ctx: graph-execution-context) -> result<_, error>
+
+    /// Extract the outputs after inference.
+    get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>
+}
diff --git a/crates/wasi-nn/wit/types.wit b/crates/wasi-nn/wit/types.wit
new file mode 100644
index 000000000000..f134b730e1a8
--- /dev/null
+++ b/crates/wasi-nn/wit/types.wit
@@ -0,0 +1,88 @@
+interface types {
+    /// The dimensions of a tensor.
+    ///
+    /// The array length matches the tensor rank and each element in the array
+    /// describes the size of each dimension.
+    type tensor-dimensions = list<u32>
+
+    /// The type of the elements in a tensor.
+    enum tensor-type {
+        FP16,
+        FP32,
+        U8,
+        I32
+    }
+
+    /// The tensor data.
+    ///
+    /// Initially conceived as a sparse representation, each empty cell would be filled with zeros and
+    /// the array length must match the product of all of the dimensions and the number of bytes in the
+    /// type (e.g., a 2x2 tensor with 4-byte f32 elements would have a data array of length 16).
+    /// Naturally, this representation requires some knowledge of how to lay out data in memory--e.g.,
+    /// using row-major ordering--and could perhaps be improved.
+    type tensor-data = list<u8>
+
+    /// A tensor.
+    record tensor {
+        /// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
+        /// containing a single value, use `[1]` for the tensor dimensions.
+        dimensions: tensor-dimensions,
+
+        /// Describe the type of element in the tensor (e.g., f32).
+        tensor-type: tensor-type,
+
+        /// Contains the tensor data.
+        data: tensor-data,
+    }
+
+    /// The graph initialization data.
+    //
+    /// This consists of an array of buffers because implementing backends may encode their graph IR in
+    /// parts (e.g., OpenVINO stores its IR and weights separately).
+    type graph-builder = list<u8>
+    type graph-builder-array = list<graph-builder>
+
+    /// An execution graph for performing inference (i.e., a model).
+    ///
+    /// TODO: replace with `resource`
+    type graph = u32
+
+    /// Describes the encoding of the graph. This allows the API to be implemented by various backends
+    /// that encode (i.e., serialize) their graph IR with different formats.
+    enum graph-encoding {
+        openvino,
+        onnx,
+        tensorflow,
+        pytorch,
+        tensorflowlite
+    }
+
+    /// Define where the graph should be executed.
+    enum execution-target {
+        cpu,
+        gpu,
+        tpu
+    }
+
+    /// Bind a `graph` to the input and output tensors for an inference.
+    ///
+    /// TODO: replace with `resource`
+    /// TODO: remove execution contexts completely
+    type graph-execution-context = u32
+
+    /// Error codes returned by functions in this API.
+    enum error {
+        /// No error occurred.
+ success, + /// Caller module passed an invalid argument. + invalid-argument, + /// Invalid encoding. + invalid-encoding, + /// Caller module is missing a memory export. + missing-memory, + /// Device or resource busy. + busy, + /// Runtime Error. + runtime-error, + } +} diff --git a/crates/wasi-nn/wit/world.wit b/crates/wasi-nn/wit/world.wit new file mode 100644 index 000000000000..42bffb93420c --- /dev/null +++ b/crates/wasi-nn/wit/world.wit @@ -0,0 +1,20 @@ +/// `wasi-nn` API +/// +/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The +/// API is not (yet) capable of performing ML training. WebAssembly programs +/// that want to use a host's ML capabilities can access these capabilities +/// through `wasi-nn`'s core abstractions: _backends_, _graphs_, and _tensors_. +/// A user selects a _backend_ for inference and `load`s a model, instantiated +/// as a _graph_, to use in the _backend_. Then, the user passes _tensor_ inputs +/// to the _graph_, computes the inference, and retrieves the _tensor_ outputs. +/// +/// This module draws inspiration from the inference side of +/// [WebNN](https://webmachinelearning.github.io/webnn/#api). See the +/// [README](https://github.com/WebAssembly/wasi-nn/blob/main/README.md) for +/// more context about the design and goals of this API. 
+ +package wasi:nn + +world ml { + import inference +} From 34ef49552bd4ca2000297695a23e74089cd89f52 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Tue, 8 Aug 2023 15:10:08 -0700 Subject: [PATCH 02/10] wasi-nn: use `preview1` linkage prtest:full --- src/commands/run.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/run.rs b/src/commands/run.rs index 8a124291a0d0..3fc067e84d7b 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -722,7 +722,7 @@ fn populate_with_wasi( } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::add_to_linker(linker, |host| { + wasmtime_wasi_nn::preview1::add_to_linker(linker, |host| { // This WASI proposal is currently not protected against // concurrent access--i.e., when wasi-threads is actively // spawning new threads, we cannot (yet) safely allow access and From 45a03749722082d9a45b7ad5f611c62d21fc4bb2 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 9 Aug 2023 15:09:36 -0700 Subject: [PATCH 03/10] review: rename `preview*` to `wit*` This is based on @pchickey's [comments] on ABI naming. 
[comments]: https://bytecodealliance.zulipchat.com/#narrow/stream/266558-wasi-nn/topic/wasi-nn.20.2B.20preview2/near/383368292 --- crates/wasi-nn/Cargo.toml | 2 +- crates/wasi-nn/src/lib.rs | 6 +- crates/wasi-nn/src/{preview2.rs => wit.rs} | 84 ++++++++++----------- crates/wasi-nn/src/{preview1.rs => witx.rs} | 71 +++++++++-------- src/commands/run.rs | 2 +- 5 files changed, 82 insertions(+), 83 deletions(-) rename crates/wasi-nn/src/{preview2.rs => wit.rs} (63%) rename crates/wasi-nn/src/{preview1.rs => witx.rs} (67%) diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index d48f769d3fdb..0931138cf32b 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -27,4 +27,4 @@ thiserror = { workspace = true } walkdir = { workspace = true } [features] -preview2 = ["wasmtime"] +component-model = ["wasmtime"] diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 541833a8ed60..1b116609352a 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -2,7 +2,7 @@ mod backend; mod ctx; pub use ctx::WasiNnCtx; -pub mod preview1; -#[cfg(feature = "preview2")] -pub mod preview2; pub mod types; +#[cfg(feature = "component-model")] +pub mod wit; +pub mod witx; diff --git a/crates/wasi-nn/src/preview2.rs b/crates/wasi-nn/src/wit.rs similarity index 63% rename from crates/wasi-nn/src/preview2.rs rename to crates/wasi-nn/src/wit.rs index 9cd3e8a0cba5..7853de25de1a 100644 --- a/crates/wasi-nn/src/preview2.rs +++ b/crates/wasi-nn/src/wit.rs @@ -1,4 +1,4 @@ -//! Implements the `wasi-nn` API for a "preview2" ABI. +//! Implements the `wasi-nn` API for the WIT ("preview2") ABI. //! //! Note that `wasi-nn` is not yet included in an official "preview2" world //! (though it could be) so by "preview2" here we mean that this can be called @@ -6,10 +6,10 @@ //! //! The only export from this module is the [`ML`] object, which exposes //! [`ML::add_to_linker`]. To implement it, this module proceeds in steps: -//! 1. 
generate all of the WIT glue code into a `wit::*` namespace -//! 2. wire up the `wit::*` glue to the context state, delegating actual +//! 1. generate all of the WIT glue code into a `gen::*` namespace +//! 2. wire up the `gen::*` glue to the context state, delegating actual //! computation to a `Backend` -//! 3. wrap up with some conversions, i.e., from `wit::*` types to this crate's +//! 3. wrap up with some conversions, i.e., from `gen::*` types to this crate's //! [`types`]. //! //! [`Backend`]: crate::backend::Backend @@ -17,22 +17,22 @@ use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx}; -pub use wit_::Ml as ML; +pub use gen_::Ml as ML; /// Generate the traits and types from the `wasi-nn` WIT specification. -mod wit_ { +mod gen_ { wasmtime::component::bindgen!("ml"); } -use wit_::wasi::nn as wit; // Shortcut to the module containing the types we need. +use gen_::wasi::nn as wit; // Shortcut to the module containing the types we need. -impl wit::inference::Host for WasiNnCtx { +impl gen::inference::Host for WasiNnCtx { /// Load an opaque sequence of bytes to use for inference. fn load( &mut self, - builders: wit::types::GraphBuilderArray, - encoding: wit::types::GraphEncoding, - target: wit::types::ExecutionTarget, - ) -> wasmtime::Result> { + builders: gen::types::GraphBuilderArray, + encoding: gen::types::GraphEncoding, + target: gen::types::ExecutionTarget, + ) -> wasmtime::Result> { let backend_kind: BackendKind = encoding.try_into()?; let graph = if let Some(backend) = self.backends.get_mut(&backend_kind) { let slices = builders.iter().map(|s| s.as_slice()).collect::>(); @@ -49,8 +49,8 @@ impl wit::inference::Host for WasiNnCtx { /// TODO: remove completely? fn init_execution_context( &mut self, - graph_id: wit::types::Graph, - ) -> wasmtime::Result> { + graph_id: gen::types::Graph, + ) -> wasmtime::Result> { let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { graph.init_execution_context()? 
} else { @@ -64,10 +64,10 @@ impl wit::inference::Host for WasiNnCtx { /// Define the inputs to use for inference. fn set_input( &mut self, - exec_context_id: wit::types::GraphExecutionContext, + exec_context_id: gen::types::GraphExecutionContext, index: u32, - tensor: wit::types::Tensor, - ) -> wasmtime::Result> { + tensor: gen::types::Tensor, + ) -> wasmtime::Result> { if let Some(exec_context) = self.executions.get_mut(exec_context_id) { let dims = &tensor .dimensions @@ -88,8 +88,8 @@ impl wit::inference::Host for WasiNnCtx { /// TODO: refactor to compute(list) -> result, error> fn compute( &mut self, - exec_context_id: wit::types::GraphExecutionContext, - ) -> wasmtime::Result> { + exec_context_id: gen::types::GraphExecutionContext, + ) -> wasmtime::Result> { if let Some(exec_context) = self.executions.get_mut(exec_context_id) { exec_context.compute()?; Ok(Ok(())) @@ -101,9 +101,9 @@ impl wit::inference::Host for WasiNnCtx { /// Extract the outputs after inference. fn get_output( &mut self, - exec_context_id: wit::types::GraphExecutionContext, + exec_context_id: gen::types::GraphExecutionContext, index: u32, - ) -> wasmtime::Result> { + ) -> wasmtime::Result> { if let Some(exec_context) = self.executions.get_mut(exec_context_id) { // Read the output bytes. 
TODO: this involves a hard-coded upper // limit on the tensor size that is necessary because there is no @@ -119,47 +119,47 @@ impl wit::inference::Host for WasiNnCtx { } } -impl From for crate::types::GraphEncoding { - fn from(value: wit::types::GraphEncoding) -> Self { +impl From for crate::types::GraphEncoding { + fn from(value: gen::types::GraphEncoding) -> Self { match value { - wit::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, - wit::types::GraphEncoding::Onnx => crate::types::GraphEncoding::ONNX, - wit::types::GraphEncoding::Tensorflow => crate::types::GraphEncoding::Tensorflow, - wit::types::GraphEncoding::Pytorch => crate::types::GraphEncoding::PyTorch, - wit::types::GraphEncoding::Tensorflowlite => { + gen::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, + gen::types::GraphEncoding::Onnx => crate::types::GraphEncoding::ONNX, + gen::types::GraphEncoding::Tensorflow => crate::types::GraphEncoding::Tensorflow, + gen::types::GraphEncoding::Pytorch => crate::types::GraphEncoding::PyTorch, + gen::types::GraphEncoding::Tensorflowlite => { crate::types::GraphEncoding::TensorflowLite } } } } -impl TryFrom for crate::backend::BackendKind { +impl TryFrom for crate::backend::BackendKind { type Error = UsageError; - fn try_from(value: wit::types::GraphEncoding) -> Result { + fn try_from(value: gen::types::GraphEncoding) -> Result { match value { - wit::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO), + gen::types::GraphEncoding::Openvino => Ok(crate::backend::BackendKind::OpenVINO), _ => Err(UsageError::InvalidEncoding(value.into())), } } } -impl From for crate::types::ExecutionTarget { - fn from(value: wit::types::ExecutionTarget) -> Self { +impl From for crate::types::ExecutionTarget { + fn from(value: gen::types::ExecutionTarget) -> Self { match value { - wit::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, - wit::types::ExecutionTarget::Gpu => 
crate::types::ExecutionTarget::GPU, - wit::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, + gen::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, + gen::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, + gen::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, } } } -impl From for crate::types::TensorType { - fn from(value: wit::types::TensorType) -> Self { +impl From for crate::types::TensorType { + fn from(value: gen::types::TensorType) -> Self { match value { - wit::types::TensorType::Fp16 => crate::types::TensorType::F16, - wit::types::TensorType::Fp32 => crate::types::TensorType::F32, - wit::types::TensorType::U8 => crate::types::TensorType::U8, - wit::types::TensorType::I32 => crate::types::TensorType::I32, + gen::types::TensorType::Fp16 => crate::types::TensorType::F16, + gen::types::TensorType::Fp32 => crate::types::TensorType::F32, + gen::types::TensorType::U8 => crate::types::TensorType::U8, + gen::types::TensorType::I32 => crate::types::TensorType::I32, } } } diff --git a/crates/wasi-nn/src/preview1.rs b/crates/wasi-nn/src/witx.rs similarity index 67% rename from crates/wasi-nn/src/preview1.rs rename to crates/wasi-nn/src/witx.rs index 92cac17dc750..eaca434182f9 100644 --- a/crates/wasi-nn/src/preview1.rs +++ b/crates/wasi-nn/src/witx.rs @@ -1,15 +1,14 @@ -//! Implements the `wasi-nn` API for a "preview1" ABI. +//! Implements the `wasi-nn` API for the WITX ("preview1") ABI. //! -//! Note that `wasi-nn` was never included in the official "preview1" snapshot, -//! but the naming here means that the `wasi-nn` imports can be called with the -//! original "preview1" ABI. +//! `wasi-nn` was never included in the official "preview1" snapshot, but this +//! module implements the ABI that is compatible with "preview1". //! -//! The only export from this module is [`add_to_linker`]. To implement it, -//! this module proceeds in steps: -//! 1. 
generate all of the Wiggle glue code into a `witx::*` namespace -//! 2. wire up the `witx::*` glue to the context state, delegating actual +//! The only export from this module is [`add_to_linker`]. To implement it, this +//! module proceeds in steps: +//! 1. generate all of the Wiggle glue code into a `gen::*` namespace +//! 2. wire up the `gen::*` glue to the context state, delegating actual //! computation to a `Backend` -//! 3. wrap up with some conversions, i.e., from `witx::*` types to this crate's +//! 3. wrap up with some conversions, i.e., from `gen::*` types to this crate's //! [`types`]. //! //! [`types`]: crate::types @@ -17,10 +16,10 @@ use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; use wiggle::GuestPtr; -pub use witx::wasi_ephemeral_nn::add_to_linker; +pub use gen::wasi_ephemeral_nn::add_to_linker; /// Generate the traits and types from the `wasi-nn` WITX specification. -mod witx { +mod gen { use super::*; wiggle::from_witx!({ witx: ["$WASI_ROOT/phases/ephemeral/witx/wasi_ephemeral_nn.witx"], @@ -52,13 +51,13 @@ mod witx { } /// Wire up the WITX-generated trait to the `wasi-nn` host state. 
-impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { +impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { fn load<'b>( &mut self, - builders: &witx::types::GraphBuilderArray<'_>, - encoding: witx::types::GraphEncoding, - target: witx::types::ExecutionTarget, - ) -> Result { + builders: &gen::types::GraphBuilderArray<'_>, + encoding: gen::types::GraphEncoding, + target: gen::types::ExecutionTarget, + ) -> Result { let encoding_id: u8 = encoding.into(); let graph = if let Some(backend) = self.backends.get_mut(&encoding_id.into()) { // Retrieve all of the "builder lists" from the Wasm memory (see @@ -82,8 +81,8 @@ impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { fn init_execution_context( &mut self, - graph_id: witx::types::Graph, - ) -> Result { + graph_id: gen::types::Graph, + ) -> Result { let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id.into()) { graph.init_execution_context()? } else { @@ -96,9 +95,9 @@ impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { fn set_input<'b>( &mut self, - exec_context_id: witx::types::GraphExecutionContext, + exec_context_id: gen::types::GraphExecutionContext, index: u32, - tensor: &witx::types::Tensor<'b>, + tensor: &gen::types::Tensor<'b>, ) -> Result<()> { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { let mut dims = vec![]; @@ -117,7 +116,7 @@ impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { } } - fn compute(&mut self, exec_context_id: witx::types::GraphExecutionContext) -> Result<()> { + fn compute(&mut self, exec_context_id: gen::types::GraphExecutionContext) -> Result<()> { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { Ok(exec_context.compute()?) 
} else { @@ -127,7 +126,7 @@ impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { fn get_output<'b>( &mut self, - exec_context_id: witx::types::GraphExecutionContext, + exec_context_id: gen::types::GraphExecutionContext, index: u32, out_buffer: &GuestPtr<'_, u8>, out_buffer_max_size: u32, @@ -146,29 +145,29 @@ impl<'a> witx::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { // Implement some conversion from `witx::types::*` to this crate's version. -impl From for crate::types::ExecutionTarget { - fn from(value: witx::types::ExecutionTarget) -> Self { +impl From for crate::types::ExecutionTarget { + fn from(value: gen::types::ExecutionTarget) -> Self { match value { - witx::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, - witx::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, - witx::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, + gen::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, + gen::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, + gen::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, } } } -impl From for crate::types::GraphEncoding { - fn from(value: witx::types::GraphEncoding) -> Self { +impl From for crate::types::GraphEncoding { + fn from(value: gen::types::GraphEncoding) -> Self { match value { - witx::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, + gen::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, } } } -impl From for crate::types::TensorType { - fn from(value: witx::types::TensorType) -> Self { +impl From for crate::types::TensorType { + fn from(value: gen::types::TensorType) -> Self { match value { - witx::types::TensorType::F16 => crate::types::TensorType::F16, - witx::types::TensorType::F32 => crate::types::TensorType::F32, - witx::types::TensorType::U8 => crate::types::TensorType::U8, - witx::types::TensorType::I32 => crate::types::TensorType::I32, + 
gen::types::TensorType::F16 => crate::types::TensorType::F16, + gen::types::TensorType::F32 => crate::types::TensorType::F32, + gen::types::TensorType::U8 => crate::types::TensorType::U8, + gen::types::TensorType::I32 => crate::types::TensorType::I32, } } } diff --git a/src/commands/run.rs b/src/commands/run.rs index 3fc067e84d7b..7f9eddc9ff14 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -722,7 +722,7 @@ fn populate_with_wasi( } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::preview1::add_to_linker(linker, |host| { + wasmtime_wasi_nn::witx::add_to_linker(linker, |host| { // This WASI proposal is currently not protected against // concurrent access--i.e., when wasi-threads is actively // spawning new threads, we cannot (yet) safely allow access and From e38b61b8a3ee1ae297430a6c0a52e65712170e49 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 9 Aug 2023 15:11:11 -0700 Subject: [PATCH 04/10] review: update README --- crates/wasi-nn/README.md | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/crates/wasi-nn/README.md b/crates/wasi-nn/README.md index 165b5063d002..4933fcba339c 100644 --- a/crates/wasi-nn/README.md +++ b/crates/wasi-nn/README.md @@ -1,38 +1,45 @@ # wasmtime-wasi-nn -This crate enables support for the [wasi-nn] API in Wasmtime. Currently it contains an implementation of [wasi-nn] using -OpenVINO™ but in the future it could support multiple machine learning backends. Since the [wasi-nn] API is expected -to be an optional feature of WASI, this crate is currently separate from the [wasi-common] crate. This crate is -experimental and its API, functionality, and location could quickly change. +This crate enables support for the [wasi-nn] API in Wasmtime. Currently it +contains an implementation of [wasi-nn] using OpenVINO™ but in the future it +could support multiple machine learning backends. 
Since the [wasi-nn] API is +expected to be an optional feature of WASI, this crate is currently separate +from the [wasi-common] crate. This crate is experimental and its API, +functionality, and location could quickly change. [examples]: examples [openvino]: https://crates.io/crates/openvino [wasi-nn]: https://github.com/WebAssembly/wasi-nn [wasi-common]: ../wasi-common +[bindings]: https://crates.io/crates/wasi-nn ### Use -Use the Wasmtime APIs to instantiate a Wasm module and link in the `WasiNn` implementation as follows: +Use the Wasmtime APIs to instantiate a Wasm module and link in the `wasi-nn` +implementation as follows: -``` -let wasi_nn = WasiNn::new(&store, WasiNnCtx::new()?); -wasi_nn.add_to_linker(&mut linker)?; +```rust +let wasi_nn = WasiNnCtx::new()?; +wasmtime_wasi_nn::witx::add_to_linker(...); ``` ### Build -This crate should build as usual (i.e. `cargo build`) but note that using an existing installation of OpenVINO™, rather -than building from source, will drastically improve the build times. See the [openvino] crate for more information +```sh +$ cargo build +``` + +To use the WIT-based ABI, compile with `--features component-model` and use `wasmtime_wasi_nn::wit::add_to_linker`. ### Example An end-to-end example demonstrating ML classification is included in [examples]: - - `tests/wasi-nn-rust-bindings` contains ergonomic bindings for writing Rust code against the [wasi-nn] APIs - - `tests/classification-example` contains a standalone Rust project that uses the [wasi-nn] APIs and is compiled to the - `wasm32-wasi` target using the `wasi-nn-rust-bindings` +`examples/classification-example` contains a standalone Rust project that uses +the [wasi-nn] APIs and is compiled to the `wasm32-wasi` target using the +high-level `wasi-nn` [bindings]. 
Run the example from the Wasmtime project directory: -``` -ci/run-wasi-nn-example.sh +```sh +$ ci/run-wasi-nn-example.sh ``` From 3efa7f5b8550a50f7e216724ef320e2fd240ddd1 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 9 Aug 2023 15:12:34 -0700 Subject: [PATCH 05/10] fix: remove broken doc links --- crates/wasi-nn/src/types.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/crates/wasi-nn/src/types.rs b/crates/wasi-nn/src/types.rs index ae23f0d8b443..4c0bf6566e68 100644 --- a/crates/wasi-nn/src/types.rs +++ b/crates/wasi-nn/src/types.rs @@ -1,13 +1,10 @@ //! The `wasi-nn` types used internally in this crate. //! -//! These types form a common "ground truth" for the [`preview1`] and -//! [`preview2`] types to be converted from and to. As such, these types should -//! be kept up to date with the WIT and WITX specifications; if anything changes -//! in the specifications, we should see compile errors in the conversion -//! functions (e.g., `impl From for `crate::...`). -//! -//! [`preview1`]: crate::preview1 -//! [`preview2`]: crate::preview2 +//! These types form a common "ground truth" for the `witx` and `wit` ABI types +//! to be converted from and to. As such, these types should be kept up to date +//! with the WIT and WITX specifications; if anything changes in the +//! specifications, we should see compile errors in the conversion functions +//! (e.g., `impl From for `crate::...`). 
pub struct Tensor<'a> { pub dims: &'a [usize], From aed6471e640cd3ee1c9138e3db1a4f8279f98dfa Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 9 Aug 2023 15:27:06 -0700 Subject: [PATCH 06/10] fix: replace typo `wit` with `gen` --- crates/wasi-nn/src/wit.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index 7853de25de1a..e757dbbc7bee 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -23,7 +23,7 @@ pub use gen_::Ml as ML; mod gen_ { wasmtime::component::bindgen!("ml"); } -use gen_::wasi::nn as wit; // Shortcut to the module containing the types we need. +use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need. impl gen::inference::Host for WasiNnCtx { /// Load an opaque sequence of bytes to use for inference. From fd36048c6376a93d18d2ce7b6e5ec87985876213 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 16 Aug 2023 09:53:57 -0700 Subject: [PATCH 07/10] review: use `wit` types everywhere This removes the crate-specific types in order to use the WIT-generated types throughout the crate. The main effect of this is that the crate no longer optionally includes `wasmtime` with the `component-model` feature--now that is required. 
--- crates/wasi-nn/Cargo.toml | 5 +- .../classification-example/Cargo.lock | 4 +- crates/wasi-nn/src/backend/mod.rs | 4 +- crates/wasi-nn/src/backend/openvino.rs | 25 ++++---- crates/wasi-nn/src/ctx.rs | 2 +- crates/wasi-nn/src/lib.rs | 2 - crates/wasi-nn/src/types.rs | 36 ------------ crates/wasi-nn/src/wit.rs | 57 +++---------------- crates/wasi-nn/src/witx.rs | 41 ++++++------- 9 files changed, 48 insertions(+), 128 deletions(-) delete mode 100644 crates/wasi-nn/src/types.rs diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index 0931138cf32b..fcd5206a1463 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -17,7 +17,7 @@ anyhow = { workspace = true } wiggle = { workspace = true } # This dependency is necessary for the WIT-generation macros to work: -wasmtime = { workspace = true, optional = true, features = ["component-model"] } +wasmtime = { workspace = true, features = ["component-model"] } # These dependencies are necessary for the wasi-nn implementation: openvino = { version = "0.5.0", features = ["runtime-linking"] } @@ -25,6 +25,3 @@ thiserror = { workspace = true } [build-dependencies] walkdir = { workspace = true } - -[features] -component-model = ["wasmtime"] diff --git a/crates/wasi-nn/examples/classification-example/Cargo.lock b/crates/wasi-nn/examples/classification-example/Cargo.lock index 0a2414873852..a649a0429289 100644 --- a/crates/wasi-nn/examples/classification-example/Cargo.lock +++ b/crates/wasi-nn/examples/classification-example/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
+version = 3 + [[package]] name = "wasi-nn" version = "0.1.0" @@ -8,7 +10,7 @@ checksum = "0c909acded993dc129e02f64a7646eb7b53079f522a814024a88772f41558996" [[package]] name = "wasi-nn-example" -version = "0.19.0" +version = "0.0.0" dependencies = [ "wasi-nn", ] diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs index 9f9d925b735b..b5317e525120 100644 --- a/crates/wasi-nn/src/backend/mod.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -5,7 +5,7 @@ mod openvino; use self::openvino::OpenvinoBackend; -use crate::types::{ExecutionTarget, Tensor}; +use crate::wit::types::{ExecutionTarget, Tensor}; use thiserror::Error; use wiggle::GuestError; @@ -33,7 +33,7 @@ pub(crate) trait BackendGraph: Send + Sync { /// A [BackendExecutionContext] performs the actual inference; this is the /// backing implementation for a [crate::witx::types::GraphExecutionContext]. pub(crate) trait BackendExecutionContext: Send + Sync { - fn set_input(&mut self, index: u32, tensor: &Tensor<'_>) -> Result<(), BackendError>; + fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>; fn compute(&mut self) -> Result<(), BackendError>; fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result; } diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs index 6f19c5208167..d29bd0a1f4f0 100644 --- a/crates/wasi-nn/src/backend/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -1,7 +1,7 @@ //! Implements a `wasi-nn` [`Backend`] using OpenVINO. 
use super::{Backend, BackendError, BackendExecutionContext, BackendGraph}; -use crate::types::{ExecutionTarget, Tensor, TensorType}; +use crate::wit::types::{ExecutionTarget, Tensor, TensorType}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; use std::sync::Arc; @@ -77,14 +77,19 @@ impl BackendGraph for OpenvinoGraph { struct OpenvinoExecutionContext(Arc, openvino::InferRequest); impl BackendExecutionContext for OpenvinoExecutionContext { - fn set_input<'a>(&mut self, index: u32, tensor: &Tensor<'a>) -> Result<(), BackendError> { + fn set_input<'a>(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { let input_name = self.0.get_input_name(index as usize)?; // Construct the blob structure. TODO: there must be some good way to // discover the layout here; `desc` should not have to default to NHWC. - let precision = map_tensor_type_to_precision(tensor.ty); - let desc = TensorDesc::new(Layout::NHWC, tensor.dims, precision); - let blob = openvino::Blob::new(&desc, tensor.data)?; + let precision = map_tensor_type_to_precision(tensor.tensor_type); + let dimensions = tensor + .dimensions + .iter() + .map(|&d| d as usize) + .collect::>(); + let desc = TensorDesc::new(Layout::NHWC, &dimensions, precision); + let blob = openvino::Blob::new(&desc, &tensor.data)?; // Actually assign the blob to the request. self.1.set_blob(&input_name, &blob)?; @@ -126,9 +131,9 @@ impl From for BackendError { /// `ExecutionTarget` enum provided by wasi-nn. 
fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str { match target { - ExecutionTarget::CPU => "CPU", - ExecutionTarget::GPU => "GPU", - ExecutionTarget::TPU => unimplemented!("OpenVINO does not support TPU execution targets"), + ExecutionTarget::Cpu => "CPU", + ExecutionTarget::Gpu => "GPU", + ExecutionTarget::Tpu => unimplemented!("OpenVINO does not support TPU execution targets"), } } @@ -136,8 +141,8 @@ fn map_execution_target_to_string(target: ExecutionTarget) -> &'static str { /// wasi-nn. fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision { match tensor_type { - TensorType::F16 => Precision::FP16, - TensorType::F32 => Precision::FP32, + TensorType::Fp16 => Precision::FP16, + TensorType::Fp32 => Precision::FP32, TensorType::U8 => Precision::U8, TensorType::I32 => Precision::I32, } diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs index c6891d11a8f5..e2cea15f9654 100644 --- a/crates/wasi-nn/src/ctx.rs +++ b/crates/wasi-nn/src/ctx.rs @@ -3,7 +3,7 @@ use crate::backend::{ self, Backend, BackendError, BackendExecutionContext, BackendGraph, BackendKind, }; -use crate::types::GraphEncoding; +use crate::wit::types::GraphEncoding; use std::collections::HashMap; use std::hash::Hash; use thiserror::Error; diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 1b116609352a..2cf8d6e8e56b 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -2,7 +2,5 @@ mod backend; mod ctx; pub use ctx::WasiNnCtx; -pub mod types; -#[cfg(feature = "component-model")] pub mod wit; pub mod witx; diff --git a/crates/wasi-nn/src/types.rs b/crates/wasi-nn/src/types.rs deleted file mode 100644 index 4c0bf6566e68..000000000000 --- a/crates/wasi-nn/src/types.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! The `wasi-nn` types used internally in this crate. -//! -//! These types form a common "ground truth" for the `witx` and `wit` ABI types -//! to be converted from and to. 
As such, these types should be kept up to date -//! with the WIT and WITX specifications; if anything changes in the -//! specifications, we should see compile errors in the conversion functions -//! (e.g., `impl From for `crate::...`). - -pub struct Tensor<'a> { - pub dims: &'a [usize], - pub ty: TensorType, - pub data: &'a [u8], -} - -#[derive(Clone, Copy)] -pub enum TensorType { - F16, - F32, - U8, - I32, -} - -pub enum ExecutionTarget { - CPU, - GPU, - TPU, -} - -#[derive(Debug)] -pub enum GraphEncoding { - OpenVINO, - ONNX, - Tensorflow, - PyTorch, - TensorflowLite, -} diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index e757dbbc7bee..e5a823bde939 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -4,19 +4,20 @@ //! (though it could be) so by "preview2" here we mean that this can be called //! with the component model's canonical ABI. //! -//! The only export from this module is the [`ML`] object, which exposes -//! [`ML::add_to_linker`]. To implement it, this module proceeds in steps: +//! This module exports its [`types`] for use throughout the crate and the +//! [`ML`] object, which exposes [`ML::add_to_linker`]. To implement all of +//! this, this module proceeds in steps: //! 1. generate all of the WIT glue code into a `gen::*` namespace //! 2. wire up the `gen::*` glue to the context state, delegating actual -//! computation to a `Backend` -//! 3. wrap up with some conversions, i.e., from `gen::*` types to this crate's -//! [`types`]. +//! computation to a [`Backend`] +//! 3. convert some types //! //! [`Backend`]: crate::backend::Backend -//! [`types`]: crate::types +//! [`types`]: crate::wit::types use crate::{backend::BackendKind, ctx::UsageError, WasiNnCtx}; +pub use gen::types; pub use gen_::Ml as ML; /// Generate the traits and types from the `wasi-nn` WIT specification. 
@@ -69,14 +70,7 @@ impl gen::inference::Host for WasiNnCtx { tensor: gen::types::Tensor, ) -> wasmtime::Result> { if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - let dims = &tensor - .dimensions - .iter() - .map(|d| *d as usize) - .collect::>(); - let ty = tensor.tensor_type.into(); - let data = tensor.data.as_slice(); - exec_context.set_input(index, &crate::types::Tensor { dims, ty, data })?; + exec_context.set_input(index, &tensor)?; Ok(Ok(())) } else { Err(UsageError::InvalidGraphHandle.into()) @@ -119,20 +113,6 @@ impl gen::inference::Host for WasiNnCtx { } } -impl From for crate::types::GraphEncoding { - fn from(value: gen::types::GraphEncoding) -> Self { - match value { - gen::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, - gen::types::GraphEncoding::Onnx => crate::types::GraphEncoding::ONNX, - gen::types::GraphEncoding::Tensorflow => crate::types::GraphEncoding::Tensorflow, - gen::types::GraphEncoding::Pytorch => crate::types::GraphEncoding::PyTorch, - gen::types::GraphEncoding::Tensorflowlite => { - crate::types::GraphEncoding::TensorflowLite - } - } - } -} - impl TryFrom for crate::backend::BackendKind { type Error = UsageError; fn try_from(value: gen::types::GraphEncoding) -> Result { @@ -142,24 +122,3 @@ impl TryFrom for crate::backend::BackendKind { } } } - -impl From for crate::types::ExecutionTarget { - fn from(value: gen::types::ExecutionTarget) -> Self { - match value { - gen::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, - gen::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, - gen::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, - } - } -} - -impl From for crate::types::TensorType { - fn from(value: gen::types::TensorType) -> Self { - match value { - gen::types::TensorType::Fp16 => crate::types::TensorType::F16, - gen::types::TensorType::Fp32 => crate::types::TensorType::F32, - gen::types::TensorType::U8 => crate::types::TensorType::U8, - 
gen::types::TensorType::I32 => crate::types::TensorType::I32, - } - } -} diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index eaca434182f9..ae3a3165845f 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -11,7 +11,7 @@ //! 3. wrap up with some conversions, i.e., from `gen::*` types to this crate's //! [`types`]. //! -//! [`types`]: crate::types +//! [`types`]: crate::wit::types use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; use wiggle::GuestPtr; @@ -100,17 +100,12 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { tensor: &gen::types::Tensor<'b>, ) -> Result<()> { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { - let mut dims = vec![]; - for d in tensor.dimensions.iter() { - dims.push(d?.read()? as usize); - } - let ty = tensor.type_.into(); - let data_ = tensor.data - .as_slice()? - .expect("cannot use with shared memories; see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)"); - let data = data_.as_ref(); - let dims = &dims; - Ok(exec_context.set_input(index, &crate::types::Tensor { dims, ty, data })?) + let tensor = crate::wit::types::Tensor { + dimensions: tensor.dimensions.to_vec()?, + tensor_type: tensor.type_.into(), + data: tensor.data.to_vec()?, + }; + Ok(exec_context.set_input(index, &tensor)?) } else { Err(UsageError::InvalidGraphHandle.into()) } @@ -145,29 +140,29 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { // Implement some conversion from `witx::types::*` to this crate's version. 
-impl From for crate::types::ExecutionTarget { +impl From for crate::wit::types::ExecutionTarget { fn from(value: gen::types::ExecutionTarget) -> Self { match value { - gen::types::ExecutionTarget::Cpu => crate::types::ExecutionTarget::CPU, - gen::types::ExecutionTarget::Gpu => crate::types::ExecutionTarget::GPU, - gen::types::ExecutionTarget::Tpu => crate::types::ExecutionTarget::TPU, + gen::types::ExecutionTarget::Cpu => crate::wit::types::ExecutionTarget::Cpu, + gen::types::ExecutionTarget::Gpu => crate::wit::types::ExecutionTarget::Gpu, + gen::types::ExecutionTarget::Tpu => crate::wit::types::ExecutionTarget::Tpu, } } } -impl From for crate::types::GraphEncoding { +impl From for crate::wit::types::GraphEncoding { fn from(value: gen::types::GraphEncoding) -> Self { match value { - gen::types::GraphEncoding::Openvino => crate::types::GraphEncoding::OpenVINO, + gen::types::GraphEncoding::Openvino => crate::wit::types::GraphEncoding::Openvino, } } } -impl From for crate::types::TensorType { +impl From for crate::wit::types::TensorType { fn from(value: gen::types::TensorType) -> Self { match value { - gen::types::TensorType::F16 => crate::types::TensorType::F16, - gen::types::TensorType::F32 => crate::types::TensorType::F32, - gen::types::TensorType::U8 => crate::types::TensorType::U8, - gen::types::TensorType::I32 => crate::types::TensorType::I32, + gen::types::TensorType::F16 => crate::wit::types::TensorType::Fp16, + gen::types::TensorType::F32 => crate::wit::types::TensorType::Fp32, + gen::types::TensorType::U8 => crate::wit::types::TensorType::U8, + gen::types::TensorType::I32 => crate::wit::types::TensorType::I32, } } } From efc4b461f0fa05c7c8e8fb4a3b2565de1afe6262 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 16 Aug 2023 10:38:37 -0700 Subject: [PATCH 08/10] review: move `BackendKind` conversion into `witx.rs` --- crates/wasi-nn/src/backend/mod.rs | 8 -------- crates/wasi-nn/src/witx.rs | 11 +++++++++-- 2 files changed, 9 insertions(+), 10 
deletions(-) diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs index b5317e525120..19b6610f1581 100644 --- a/crates/wasi-nn/src/backend/mod.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -56,11 +56,3 @@ pub enum BackendError { pub(crate) enum BackendKind { OpenVINO, } -impl From for BackendKind { - fn from(value: u8) -> Self { - match value { - 0 => BackendKind::OpenVINO, - _ => panic!("invalid backend"), - } - } -} diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index ae3a3165845f..724b304aa7d9 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -13,6 +13,7 @@ //! //! [`types`]: crate::wit::types +use crate::backend::BackendKind; use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; use wiggle::GuestPtr; @@ -58,8 +59,7 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { encoding: gen::types::GraphEncoding, target: gen::types::ExecutionTarget, ) -> Result { - let encoding_id: u8 = encoding.into(); - let graph = if let Some(backend) = self.backends.get_mut(&encoding_id.into()) { + let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) { // Retrieve all of the "builder lists" from the Wasm memory (see // $graph_builder_array) as slices for a backend to operate on. let mut slices = vec![]; @@ -140,6 +140,13 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { // Implement some conversion from `witx::types::*` to this crate's version. 
+impl From for BackendKind { + fn from(value: gen::types::GraphEncoding) -> Self { + match value { + gen::types::GraphEncoding::Openvino => BackendKind::OpenVINO, + } + } +} impl From for crate::wit::types::ExecutionTarget { fn from(value: gen::types::ExecutionTarget) -> Self { match value { From 1b6c386700701339279764afad9b96aa15bd95c8 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 16 Aug 2023 10:57:06 -0700 Subject: [PATCH 09/10] review: remove `<'a>` --- crates/wasi-nn/src/backend/openvino.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs index d29bd0a1f4f0..d44236250760 100644 --- a/crates/wasi-nn/src/backend/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -77,7 +77,7 @@ impl BackendGraph for OpenvinoGraph { struct OpenvinoExecutionContext(Arc, openvino::InferRequest); impl BackendExecutionContext for OpenvinoExecutionContext { - fn set_input<'a>(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { + fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { let input_name = self.0.get_input_name(index as usize)?; // Construct the blob structure. 
TODO: there must be some good way to From 7a07592cc13001c1e21e8318a08b12dcec90a418 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 16 Aug 2023 11:01:05 -0700 Subject: [PATCH 10/10] review: use `tracing` crate instead of `eprintln!` --- Cargo.lock | 1 + crates/wasi-nn/Cargo.toml | 1 + crates/wasi-nn/src/witx.rs | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 8da77b235cf5..ed8aeefc252f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3773,6 +3773,7 @@ dependencies = [ "anyhow", "openvino", "thiserror", + "tracing", "walkdir", "wasmtime", "wiggle", diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index fcd5206a1463..13f1501583e4 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -20,6 +20,7 @@ wiggle = { workspace = true } wasmtime = { workspace = true, features = ["component-model"] } # These dependencies are necessary for the wasi-nn implementation: +tracing = { workspace = true } openvino = { version = "0.5.0", features = ["runtime-linking"] } thiserror = { workspace = true } diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index 724b304aa7d9..7371017d0424 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -41,7 +41,7 @@ mod gen { &mut self, e: WasiNnError, ) -> anyhow::Result { - eprintln!("Host error: {:?}", e); + tracing::debug!("host error: {:?}", e); match e { WasiNnError::BackendError(_) => unimplemented!(), WasiNnError::GuestError(_) => unimplemented!(),