Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 75 additions & 19 deletions crates/wit-component/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ use crate::metadata::{self, Bindgen, ModuleMetadata};
use crate::validation::{ValidatedModule, BARE_FUNC_MODULE_NAME, MAIN_MODULE_IMPORT_NAME};
use crate::StringEncoding;
use anyhow::{anyhow, bail, Context, Result};
use indexmap::IndexMap;
use indexmap::{IndexMap, IndexSet};
use std::collections::HashMap;
use std::hash::Hash;
use wasm_encoder::*;
use wasmparser::{Validator, WasmFeatures};
use wit_parser::{
abi::{AbiVariant, WasmSignature, WasmType},
Function, InterfaceId, Resolve, Type, TypeDefKind, TypeId, TypeOwner, WorldId, WorldItem,
Function, InterfaceId, Resolve, Type, TypeDefKind, TypeId, TypeOwner, WorldItem,
};

const INDIRECT_TABLE_NAME: &str = "$imports";
Expand Down Expand Up @@ -457,7 +457,10 @@ impl<'a> EncodingState<'a> {
// will forward them through.
if let Some(live) = encoder.state.info.live_types.get(&interface_id) {
for ty in live {
log::trace!("encoding extra type {ty:?}");
log::trace!(
"encoding extra type {ty:?} name={:?}",
resolve.types[*ty].name
);
encoder.encode_valtype(resolve, &Type::Id(*ty))?;
}
}
Expand Down Expand Up @@ -593,11 +596,20 @@ impl<'a> EncodingState<'a> {
Some(core_wasm_name)
};
let import = &self.info.import_map[&interface];
let required_imports = match for_module {
CustomModule::Main => &self.info.info.required_imports[core_wasm_name],
CustomModule::Adapter(name) => {
&self.info.adapters[name].0.required_imports[core_wasm_name]
}
};
let mut exports = Vec::with_capacity(import.direct.len() + import.indirect.len());

// Add an entry for all indirect lowerings which come as an export of
// the shim module.
for (i, lowering) in import.indirect.iter().enumerate() {
if !required_imports.contains(&lowering.name) {
continue;
}
let encoding =
metadata.import_encodings[&(core_wasm_name.to_string(), lowering.name.to_string())];
let index = self.component.alias_core_item(
Expand All @@ -617,6 +629,9 @@ impl<'a> EncodingState<'a> {
// All direct lowerings can be `canon lower`'d here immediately and
// passed as arguments.
for lowering in &import.direct {
if !required_imports.contains(&lowering.name) {
continue;
}
let func_index = match &import.interface {
Some((interface, _url)) => {
let instance_index = self.imported_instances[interface];
Expand All @@ -633,13 +648,13 @@ impl<'a> EncodingState<'a> {

fn encode_exports(&mut self, module: CustomModule) -> Result<()> {
let resolve = &self.info.encoder.metadata.resolve;
let world = match module {
CustomModule::Main => self.info.encoder.metadata.world,
CustomModule::Adapter(name) => self.info.encoder.adapters[name].2,
let exports = match module {
CustomModule::Main => &self.info.encoder.main_module_exports,
CustomModule::Adapter(name) => &self.info.encoder.adapters[name].2,
};
let world = &resolve.worlds[world];
for (export_name, export) in world.exports.iter() {
match export {
let world = &resolve.worlds[self.info.encoder.metadata.world];
for export_name in exports {
match &world.exports[export_name] {
WorldItem::Function(func) => {
let mut enc = self.root_type_encoder(None);
let ty = enc.encode_func_type(resolve, func)?;
Expand Down Expand Up @@ -857,7 +872,7 @@ impl<'a> EncodingState<'a> {

// For all interfaces imported into the main module record all of their
// indirect lowerings into `Shims`.
for core_wasm_name in info.required_imports.keys() {
for (core_wasm_name, required) in info.required_imports.iter() {
let import_name = if *core_wasm_name == BARE_FUNC_MODULE_NAME {
None
} else {
Expand All @@ -868,6 +883,7 @@ impl<'a> EncodingState<'a> {
core_wasm_name,
CustomModule::Main,
import,
required,
info.metadata,
&mut signatures,
);
Expand All @@ -877,12 +893,13 @@ impl<'a> EncodingState<'a> {
// function and additionally a set of shims are created for the
// interface imported into the shim module itself.
for (adapter, (info, _wasm)) in self.info.adapters.iter() {
for (name, _) in info.required_imports.iter() {
for (name, required) in info.required_imports.iter() {
let import = &self.info.import_map[&Some(*name)];
ret.append_indirect(
name,
CustomModule::Adapter(adapter),
import,
required,
info.metadata,
&mut signatures,
);
Expand Down Expand Up @@ -916,7 +933,8 @@ impl<'a> EncodingState<'a> {
}

for shim in ret.list.iter() {
ret.shim_names.insert(shim.kind, shim.name.clone());
let prev = ret.shim_names.insert(shim.kind, shim.name.clone());
assert!(prev.is_none());
}

assert!(self.shim_instance_index.is_none());
Expand Down Expand Up @@ -1304,6 +1322,7 @@ impl<'a> Shims<'a> {
core_wasm_module: &'a str,
for_module: CustomModule<'a>,
import: &ImportedInterface<'a>,
required: &IndexSet<&str>,
metadata: &ModuleMetadata,
sigs: &mut Vec<WasmSignature>,
) {
Expand All @@ -1313,6 +1332,9 @@ impl<'a> Shims<'a> {
Some(core_wasm_module)
};
for (indirect_index, lowering) in import.indirect.iter().enumerate() {
if !required.contains(&lowering.name) {
continue;
}
let shim_name = self.list.len().to_string();
log::debug!(
"shim {shim_name} is import `{core_wasm_module}` lowering {indirect_index} `{}`",
Expand Down Expand Up @@ -1342,15 +1364,17 @@ pub struct ComponentEncoder {
module: Vec<u8>,
metadata: Bindgen,
validate: bool,
main_module_exports: IndexSet<String>,

// This is a map from the name of the adapter to a pair of:
//
// * The wasm of the adapter itself, with `component-type` sections
// stripped.
// * the metadata for the adapter, verified to have no exports and only
// imports.
// * The world within `self.metadata.doc` which the adapter works with.
adapters: IndexMap<String, (Vec<u8>, ModuleMetadata, WorldId)>,
// * The set of exports from the final world which are defined by this
// adapter.
adapters: IndexMap<String, (Vec<u8>, ModuleMetadata, IndexSet<String>)>,
}

impl ComponentEncoder {
Expand All @@ -1361,7 +1385,15 @@ impl ComponentEncoder {
/// core module.
pub fn module(mut self, module: &[u8]) -> Result<Self> {
let (wasm, metadata) = metadata::decode(module)?;
self.metadata.merge(metadata)?;
self.main_module_exports.extend(
metadata.resolve.worlds[metadata.world]
.exports
.keys()
.cloned(),
);
self.metadata
.merge(metadata)
.context("failed merge WIT package sets together")?;
self.module = if let Some(producers) = &self.metadata.producers {
producers.add_to_wasm(&wasm)?
} else {
Expand Down Expand Up @@ -1396,11 +1428,35 @@ impl ComponentEncoder {
pub fn adapter(mut self, name: &str, bytes: &[u8]) -> Result<Self> {
let (wasm, metadata) = metadata::decode(bytes)?;
// Merge the adapter's document into our own document to have one large
// document, but the adapter's world isn't merged in to our world so
// retain it separately.
let world = self.metadata.resolve.merge(metadata.resolve).worlds[metadata.world.index()];
// document, and then afterwards merge worlds as well.
//
// The first `merge` operation will interleave equivalent packages from
// each adapter into packages that are stored within our own resolve.
// The second `merge_worlds` operation will then ensure that both the
// adapter and the main module have compatible worlds, meaning that they
// either import the same items or they import disjoint items, for
// example.
let world = self
.metadata
.resolve
.merge(metadata.resolve)
.with_context(|| {
format!("failed to merge WIT packages of adapter `{name}` into main packages")
})?
.worlds[metadata.world.index()];
self.metadata
.resolve
.merge_worlds(world, self.metadata.world)
.with_context(|| {
format!("failed to merge WIT world of adapter `{name}` into main package")
})?;
let exports = self.metadata.resolve.worlds[world]
.exports
.keys()
.cloned()
.collect();
self.adapters
.insert(name.to_string(), (wasm, metadata.metadata, world));
.insert(name.to_string(), (wasm, metadata.metadata, exports));
Ok(self)
}

Expand Down
1 change: 1 addition & 0 deletions crates/wit-component/src/encoding/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub trait ValtypeEncoder<'a> {
// If this type is imported from another interface then return
// it as it was bound here with an alias.
let ty = &resolve.types[id];
log::trace!("encode type name={:?}", ty.name);
if let Some(index) = self.maybe_import_type(resolve, id) {
self.type_map().insert(id, index);
return Ok(ComponentValType::Type(index));
Expand Down
Loading