diff --git a/src/asm2wasm.h b/src/asm2wasm.h index c0f13ff8f3c..f266c6b03c3 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -660,11 +660,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) { // TODO: when not using aliasing function pointers, we could merge them by noticing that // index 0 in each table is the null func, and each other index should only have one // non-null func. However, that breaks down when function pointer casts are emulated. - functionTableStarts[name] = wasm.table.names.size(); // this table starts here + functionTableStarts[name] = wasm.getDefaultTable()->values.size(); // this table starts here Ref contents = value[1]; for (unsigned k = 0; k < contents->size(); k++) { IString curr = contents[k][1]->getIString(); - wasm.table.names.push_back(curr); + wasm.getDefaultTable()->values.push_back(curr); } } else { abort_on("invalid var element", pair); diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index b7fc4347a2c..8a1b35d3f26 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -586,7 +586,7 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio } BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) { auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value); - + if (tracing) { auto id = noteExpression(ret); std::cout << " expressions[" << id << "] = BinaryenReturn(the_module, expressions[" << expressions[value] << "]);\n"; @@ -730,7 +730,7 @@ void BinaryenSetFunctionTable(BinaryenModuleRef module, BinaryenFunctionRef* fun auto* wasm = (Module*)module; for (BinaryenIndex i = 0; i < numFuncs; i++) { - wasm->table.names.push_back(((Function*)funcs[i])->name); + wasm->getDefaultTable()->values.push_back(((Function*)funcs[i])->name); } } diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp index 961d26ba5a3..af5b133cffb 100644 --- a/src/passes/DuplicateFunctionElimination.cpp +++ b/src/passes/DuplicateFunctionElimination.cpp @@ -123,10 +123,12 @@ struct DuplicateFunctionElimination : public Pass { replacerRunner.add(&replacements); replacerRunner.run(); // replace in table - for (auto& name : module->table.names) { - auto iter = replacements.find(name); - if (iter != replacements.end()) { - name = iter->second; + for (auto& curr : module->tables) { + for (auto& name : curr->values) { + auto iter = replacements.find(name); + if (iter != replacements.end()) { + name = iter->second; + } } } // replace in start diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 47e503fa4c4..1e3161cc15f 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -576,7 +576,7 @@ struct PrintSExpression : public Visitor { } void visitTable(Table *curr) { printOpening(o, "table"); - for (auto name : curr->names) { + for (auto name : curr->values) { o << ' '; printName(name); } @@ -647,9 +647,9 @@ struct PrintSExpression : public Visitor { visitGlobal(child.get()); o << maybeNewLine; } - if (curr->table.names.size() > 0) { + for (auto& child : curr->tables) { doIndent(o, indent); - visitTable(&curr->table); + visitTable(child.get()); o << maybeNewLine; } for (auto& child : curr->functions) { diff --git a/src/passes/RemoveUnusedFunctions.cpp b/src/passes/RemoveUnusedFunctions.cpp index a2941aff662..e73b6f897d5 100644 --- a/src/passes/RemoveUnusedFunctions.cpp +++ b/src/passes/RemoveUnusedFunctions.cpp @@ -39,8 +39,10 @@ struct RemoveUnusedFunctions : public Pass { root.push_back(module->getFunction(curr->value)); } // For now, all functions that can be called indirectly are marked as roots. - for (auto& curr : module->table.names) { - root.push_back(module->getFunction(curr)); + for (auto& child : module->tables) { + for (auto& curr : child->values) { + root.push_back(module->getFunction(curr)); + } } // Compute function reachability starting from the root set. DirectCallGraphAnalyzer analyzer(module, root); diff --git a/src/passes/ReorderFunctions.cpp b/src/passes/ReorderFunctions.cpp index 38ef98afb7a..c4aeae9bd71 100644 --- a/src/passes/ReorderFunctions.cpp +++ b/src/passes/ReorderFunctions.cpp @@ -38,8 +38,10 @@ struct ReorderFunctions : public WalkerPassexports) { counts[curr->value]++; } - for (auto& curr : module->table.names) { - counts[curr]++; + for (auto& child : module->tables) { + for (auto& curr : child->values) { + counts[curr]++; + } } std::sort(module->functions.begin(), module->functions.end(), [this]( const std::unique_ptr& a, diff --git a/src/wasm-binary.h b/src/wasm-binary.h index aad142d4459..3106617622a 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -480,7 +480,7 @@ class WasmBinaryWriter : public Visitor { writeSignatures(); writeImports(); writeFunctionSignatures(); - writeFunctionTable(); + writeFunctionTables(); writeMemory(); writeGlobals(); writeExports(); @@ -556,7 +556,7 @@ class WasmBinaryWriter : public Visitor { finishSection(start); } - int32_t getFunctionTypeIndex(Name type) { + Index getFunctionTypeIndex(Name type) { // TODO: optimize for (size_t i = 0; i < wasm->functionTypes.size(); i++) { if (wasm->functionTypes[i]->name == type) return i; @@ -685,7 +685,7 @@ class WasmBinaryWriter : public Visitor { void writeExports() { if (wasm->exports.size() == 0) return; - if (debug) std::cerr << "== writeexports" << std::endl; + if (debug) std::cerr << "== writeExports" << std::endl; auto start = startSection(BinaryConsts::Section::ExportTable); o << U32LEB(wasm->exports.size()); for (auto& curr : wasm->exports) { @@ -724,8 +724,8 @@ class WasmBinaryWriter : public Visitor { assert(mappedImports.count(name)); return mappedImports[name]; } - - std::map mappedFunctions; // name of the Function => index + + std::map mappedFunctions; // name of the Function => entry index uint32_t getFunctionIndex(Name name) { if (!mappedFunctions.size()) { // Create name => index mapping. @@ -738,13 +738,18 @@ class WasmBinaryWriter : public Visitor { return mappedFunctions[name]; } - void writeFunctionTable() { - if (wasm->table.names.size() == 0) return; - if (debug) std::cerr << "== writeFunctionTable" << std::endl; + void writeFunctionTables() { + if (wasm->tables.size() == 0) return; + if (debug) std::cerr << "== writeFunctionTables" << std::endl; auto start = startSection(BinaryConsts::Section::FunctionTable); - o << U32LEB(wasm->table.names.size()); - for (auto name : wasm->table.names) { - o << U32LEB(getFunctionIndex(name)); + assert(wasm->tables.size() == 1); + // o << U32LEB(wasm->tables.size()); + for (auto& curr : wasm->tables) { + if (debug) std::cerr << "write one" << std::endl; + o << U32LEB(curr->values.size()); + for (auto name : curr->values) { + o << U32LEB(getFunctionIndex(name)); + } } finishSection(start); } @@ -1256,7 +1261,7 @@ class WasmBinaryBuilder { else if (match(BinaryConsts::Section::ExportTable)) readExports(); else if (match(BinaryConsts::Section::Globals)) readGlobals(); else if (match(BinaryConsts::Section::DataSegments)) readDataSegments(); - else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTable(); + else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTables(); else if (match(BinaryConsts::Section::Names)) readNames(); else { std::cerr << "unfamiliar section: "; @@ -1638,9 +1643,10 @@ class WasmBinaryBuilder { } } - for (size_t index : functionTable) { - assert(index < wasm.functions.size()); - wasm.table.names.push_back(wasm.functions[index]->name); + for (auto& pair : functionTable) { + assert(pair.first < wasm.tables.size()); + assert(pair.second < wasm.functions.size()); + wasm.tables[pair.first]->values.push_back(wasm.functions[pair.second]->name); } } @@ -1660,14 +1666,20 @@ class WasmBinaryBuilder { } } - std::vector functionTable; + std::vector> functionTable; - void readFunctionTable() { - if (debug) std::cerr << "== readFunctionTable" << std::endl; - auto num = getU32LEB(); - for (size_t i = 0; i < num; i++) { - auto index = getU32LEB(); - functionTable.push_back(index); + void readFunctionTables() { + if (debug) std::cerr << "== readFunctionTables" << std::endl; + size_t numTables = 1; // getU32LEB() + for (size_t i = 0; i < numTables; i++) { + if (debug) std::cerr << "read one" << std::endl; + auto curr = new Table; + auto size = getU32LEB(); + for (size_t j = 0; j < size; j++) { + auto index = getU32LEB(); + functionTable.push_back(std::make_pair<>(i, index)); + } + wasm.addTable(curr); } } diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index adaee0b6fca..5b21fc75680 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -666,9 +666,10 @@ class ModuleInstance { LiteralList arguments; Flow flow = generateArguments(curr->operands, arguments); if (flow.breaking()) return flow; + Table *table = instance.wasm.getDefaultTable(); size_t index = target.value.geti32(); - if (index >= instance.wasm.table.names.size()) trap("callIndirect: overflow"); - Name name = instance.wasm.table.names[index]; + if (index >= table->values.size()) trap("callIndirect: overflow"); + Name name = table->values[index]; Function *func = instance.wasm.getFunction(name); if (func->type.is() && func->type != curr->fullType) trap("callIndirect: bad type"); if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments"); diff --git a/src/wasm-linker.cpp b/src/wasm-linker.cpp index 45615fd59c1..c17f15e807c 100644 --- a/src/wasm-linker.cpp +++ b/src/wasm-linker.cpp @@ -117,22 +117,20 @@ void Linker::layout() { // Pre-assign the function indexes for (auto& pair : out.indirectIndexes) { - if (functionIndexes.count(pair.first) != 0 || - functionNames.count(pair.second) != 0) { - Fatal() << "Function " << pair.first << " already has an index " << - functionIndexes[pair.first] << " while setting index " << pair.second; + Index tableIndex = Table::kDefault; + Table *table = out.getIndirectTable(tableIndex); + if (functionIndexes.count(pair.second) != 0) { + Fatal() << "Function " << pair.second << " already has an index " << + functionIndexes[pair.second].first << " while setting index " << pair.first; } if (debug) { - std::cerr << "pre-assigned function index: " << pair.first << ": " - << pair.second << '\n'; + std::cerr << "pre-assigned function index: " << pair.second << ": " + << pair.first << '\n'; } - functionIndexes[pair.first] = pair.second; - functionNames[pair.second] = pair.first; - } - - // Emit the pre-assigned function names in sorted order - for (const auto& P : functionNames) { - out.wasm.table.names.push_back(P.second); + assert(table->values.size() == pair.first); + table->values.push_back(pair.second); + auto indexes = std::make_pair(tableIndex, pair.first); + functionIndexes[pair.second] = indexes; } for (auto& relocation : out.relocations) { @@ -170,6 +168,15 @@ void Linker::layout() { } } } + + // Create the actual tables in the underlying module. This is delayed because + // table references may be out of order, and the underlying object is a vector. + Index counter = 0; + for (auto& pair : out.tables) { + if (pair.first != counter++) Fatal() << "Tables are nonconsecutive!" << '\n'; + out.wasm.addTable(pair.second); + } + if (!!startFunction) { if (out.symbolInfo.implementedFunctions.count(startFunction) == 0) { Fatal() << "Unknown start function: `" << startFunction << "`\n"; @@ -206,9 +213,11 @@ void Linker::layout() { } // ensure an explicit function type for indirect call targets - for (auto& name : out.wasm.table.names) { - auto* func = out.wasm.getFunction(name); - func->type = ensureFunctionType(getSig(func), &out.wasm)->name; + for (auto& table : out.wasm.tables) { + for (auto& name : table->values) { + auto* func = out.wasm.getFunction(name); + func->type = ensureFunctionType(getSig(func), &out.wasm)->name; + } } } @@ -371,14 +380,17 @@ void Linker::emscriptenGlue(std::ostream& o) { Index Linker::getFunctionIndex(Name name) { if (!functionIndexes.count(name)) { - functionIndexes[name] = out.wasm.table.names.size(); - out.wasm.table.names.push_back(name); + Index tableIndex = Table::kDefault; + Table *table = out.getIndirectTable(tableIndex); + functionIndexes[name] = std::make_pair(tableIndex, table->values.size()); + table->values.push_back(name); if (debug) { std::cerr << "function index: " << name << ": " - << functionIndexes[name] << '\n'; + << functionIndexes[name].first << " " + << functionIndexes[name].second << '\n'; } } - return functionIndexes[name]; + return functionIndexes[name].second; } bool hasI64ResultOrParam(FunctionType* ft) { @@ -390,7 +402,7 @@ bool hasI64ResultOrParam(FunctionType* ft) { } void Linker::makeDummyFunction() { - assert(out.wasm.table.names.empty()); + assert(out.wasm.tables.empty()); bool create = false; // Check if there are address-taken functions for (auto& relocation : out.relocations) { @@ -410,27 +422,29 @@ void Linker::makeDummyFunction() { void Linker::makeDynCallThunks() { std::unordered_set sigs; wasm::Builder wasmBuilder(out.wasm); - for (const auto& indirectFunc : out.wasm.table.names) { - // Skip generating thunks for the dummy function - if (indirectFunc == dummyFunction) continue; - std::string sig(getSig(out.wasm.getFunction(indirectFunc))); - auto* funcType = ensureFunctionType(sig, &out.wasm); - if (hasI64ResultOrParam(funcType)) continue; // Can't export i64s on the web. - if (!sigs.insert(sig).second) continue; // Sig is already in the set - std::vector params; - params.emplace_back("fptr", i32); // function pointer param - int p = 0; - for (const auto& ty : funcType->params) params.emplace_back(std::to_string(p++), ty); - Function* f = wasmBuilder.makeFunction(std::string("dynCall_") + sig, std::move(params), funcType->result, {}); - Expression* fptr = wasmBuilder.makeGetLocal(0, i32); - std::vector args; - for (unsigned i = 0; i < funcType->params.size(); ++i) { - args.push_back(wasmBuilder.makeGetLocal(i + 1, funcType->params[i])); + for (const auto& table : out.wasm.tables) { + for (const auto& indirectFunc : table->values) { + // Skip generating thunks for the dummy function + if (indirectFunc == dummyFunction) continue; + std::string sig(getSig(out.wasm.getFunction(indirectFunc))); + auto* funcType = ensureFunctionType(sig, &out.wasm); + if (hasI64ResultOrParam(funcType)) continue; // Can't export i64s on the web. + if (!sigs.insert(sig).second) continue; // Sig is already in the set + std::vector params; + params.emplace_back("fptr", i32); // function pointer param + int p = 0; + for (const auto& ty : funcType->params) params.emplace_back(std::to_string(p++), ty); + Function* f = wasmBuilder.makeFunction(std::string("dynCall_") + sig, std::move(params), funcType->result, {}); + Expression* fptr = wasmBuilder.makeGetLocal(0, i32); + std::vector args; + for (unsigned i = 0; i < funcType->params.size(); ++i) { + args.push_back(wasmBuilder.makeGetLocal(i + 1, funcType->params[i])); + } + Expression* call = wasmBuilder.makeCallIndirect(funcType, fptr, args); + f->body = call; + out.wasm.addFunction(f); + exportFunction(f->name, true); } - Expression* call = wasmBuilder.makeCallIndirect(funcType, fptr, args); - f->body = call; - out.wasm.addFunction(f); - exportFunction(f->name, true); } } diff --git a/src/wasm-linker.h b/src/wasm-linker.h index 3f1e8c7aca6..6a6e4a84cae 100644 --- a/src/wasm-linker.h +++ b/src/wasm-linker.h @@ -111,6 +111,24 @@ class LinkerObject { return nullptr; } + // Create a table locally, because insertion into the underlying wasm vector + // needs to be delayed until all tables have been encountered. + Table *getIndirectTable(Index index) { + if (tables.count(index)) + return tables[index]; + + // Add the first default table, if it is missing and another table is + // being requested. + if (index && !tables.count(Table::kDefault)) { + getIndirectTable(Table::kDefault); + } + + // Otherwise, proceed and create the requested table. + assert(index == Table::kDefault); + tables[index] = Table::createDefaultTable(); + return tables[index]; + } + // Add an initializer segment for the named static variable. void addSegment(Name name, const char* data, Address size) { segments[name] = wasm.memory.segments.size(); @@ -142,8 +160,8 @@ class LinkerObject { } void addIndirectIndex(Name name, Address index) { - assert(!indirectIndexes.count(name)); - indirectIndexes[name] = index; + assert(!indirectIndexes.count(index)); + indirectIndexes[index] = name; } bool isEmpty() { @@ -177,9 +195,10 @@ class LinkerObject { std::unordered_map externTypesMap; std::map segments; // name => segment index (in wasm module) + std::map tables; // index => table index (in wasm module) // preassigned indexes for functions called indirectly - std::map indirectIndexes; + std::map indirectIndexes; std::vector initializerFunctions; @@ -328,8 +347,7 @@ class Linker { std::unordered_map staticAddresses; // name => address std::unordered_map segmentsByAddress; // address => segment index - std::unordered_map functionIndexes; - std::map functionNames; + std::unordered_map> functionIndexes; // name => table, entry indexes }; diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 750d485fbe1..bf927b0a208 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -1403,7 +1403,7 @@ class SExpressionWasmBuilder { void parseTable(Element& s) { for (size_t i = 1; i < s.size(); i++) { - wasm.table.names.push_back(getFunctionName(*s[i])); + wasm.getDefaultTable()->values.push_back(getFunctionName(*s[i])); } } diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index b50ca0fb292..cddb5be899c 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -205,7 +205,9 @@ struct Walker : public VisitorType { for (auto& curr : module->functions) { self->walkFunction(curr.get()); } - self->visitTable(&module->table); + for (auto &curr : module->tables) { + self->visitTable(curr.get()); + } self->visitMemory(&module->memory); } diff --git a/src/wasm.h b/src/wasm.h index 38a030e1a26..2e25a64b038 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1432,7 +1432,16 @@ class Export { class Table { public: - std::vector names; + Name name; + std::vector values; + + enum { kDefault = 0 }; + + static Table* createDefaultTable() { + Table *table = new Table(); + table->name = Name::fromInt(kDefault); + return table; + } }; class Memory { @@ -1475,8 +1484,8 @@ class Module { std::vector> exports; std::vector> functions; std::vector> globals; + std::vector> tables; - Table table; Memory memory; Name start; @@ -1489,6 +1498,7 @@ class Module { std::map exportsMap; std::map functionsMap; std::map globalsMap; + std::map tablesMap; public: Module() {}; @@ -1498,24 +1508,28 @@ class Module { Export* getExport(Index i) { assert(i < exports.size()); return exports[i].get(); } Function* getFunction(Index i) { assert(i < functions.size()); return functions[i].get(); } Global* getGlobal(Index i) { assert(i < globals.size()); return globals[i].get(); } + Table* getTable(Index i) { assert(i < tables.size()); return tables[i].get(); } FunctionType* getFunctionType(Name name) { assert(functionTypesMap.count(name)); return functionTypesMap[name]; } Import* getImport(Name name) { assert(importsMap.count(name)); return importsMap[name]; } Export* getExport(Name name) { assert(exportsMap.count(name)); return exportsMap[name]; } Function* getFunction(Name name) { assert(functionsMap.count(name)); return functionsMap[name]; } Global* getGlobal(Name name) { assert(globalsMap.count(name)); return globalsMap[name]; } + Table* getTable(Name name) { assert(tablesMap.count(name)); return tablesMap[name]; } FunctionType* checkFunctionType(Name name) { if (!functionTypesMap.count(name)) return nullptr; return functionTypesMap[name]; } Import* checkImport(Name name) { if (!importsMap.count(name)) return nullptr; return importsMap[name]; } Export* checkExport(Name name) { if (!exportsMap.count(name)) return nullptr; return exportsMap[name]; } Function* checkFunction(Name name) { if (!functionsMap.count(name)) return nullptr; return functionsMap[name]; } Global* checkGlobal(Name name) { if (!globalsMap.count(name)) return nullptr; return globalsMap[name]; } + Table* checkTable(Name name) { if (!tablesMap.count(name)) return nullptr; return tablesMap[name]; } FunctionType* checkFunctionType(Index i) { if (i >= functionTypes.size()) return nullptr; return functionTypes[i].get(); } Import* checkImport(Index i) { if (i >= imports.size()) return nullptr; return imports[i].get(); } Export* checkExport(Index i) { if (i >= exports.size()) return nullptr; return exports[i].get(); } Function* checkFunction(Index i) { if (i >= functions.size()) return nullptr; return functions[i].get(); } Global* checkGlobal(Index i) { if (i >= globals.size()) return nullptr; return globals[i].get(); } + Table* checkTable(Index i) { if (i >= tables.size()) return nullptr; return tables[i].get(); } void addFunctionType(FunctionType* curr) { Name numericName = Name::fromInt(functionTypes.size()); // TODO: remove all these, assert on names already existing, do numeric stuff in wasm-s-parser etc. @@ -1562,11 +1576,27 @@ class Module { globalsMap[curr->name] = curr; globalsMap[numericName] = curr; } - + void addTable(Table *curr) { + Name numericName = Name::fromInt(tables.size()); + if (curr->name.isNull()) { + curr->name = numericName; + } + tables.push_back(std::unique_ptr(curr)); + tablesMap[curr->name] = curr; + tablesMap[numericName] = curr; + } void addStart(const Name &s) { start = s; } + Table* getDefaultTable() { + Table *def = checkTable(Name::fromInt(Table::kDefault)); + if (def) return def; + def = Table::createDefaultTable(); + addTable(def); + return def; + } + void removeImport(Name name) { for (size_t i = 0; i < imports.size(); i++) { if (imports[i]->name == name) {