diff options
41 files changed, 531 insertions, 280 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 1b3f11e22..aff4ed4fc 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -3409,7 +3409,13 @@ BinaryenAddActiveElementSegment(BinaryenModuleRef module, auto segment = std::make_unique<ElementSegment>(table, (Expression*)offset); segment->setExplicitName(name); for (BinaryenIndex i = 0; i < numFuncNames; i++) { - segment->data.push_back(funcNames[i]); + auto* func = ((Module*)module)->getFunctionOrNull(funcNames[i]); + if (func == nullptr) { + Fatal() << "invalid function '" << funcNames[i] << "'."; + } + Type type(HeapType(func->sig), Nullable); + segment->data.push_back( + Builder(*(Module*)module).makeRefFunc(funcNames[i], type)); } return ((Module*)module)->addElementSegment(std::move(segment)); } @@ -3421,7 +3427,13 @@ BinaryenAddPassiveElementSegment(BinaryenModuleRef module, auto segment = std::make_unique<ElementSegment>(); segment->setExplicitName(name); for (BinaryenIndex i = 0; i < numFuncNames; i++) { - segment->data.push_back(funcNames[i]); + auto* func = ((Module*)module)->getFunctionOrNull(funcNames[i]); + if (func == nullptr) { + Fatal() << "invalid function '" << funcNames[i] << "'."; + } + Type type(HeapType(func->sig), Nullable); + segment->data.push_back( + Builder(*(Module*)module).makeRefFunc(funcNames[i], type)); } return ((Module*)module)->addElementSegment(std::move(segment)); } @@ -3460,7 +3472,13 @@ const char* BinaryenElementSegmentGetData(BinaryenElementSegmentRef elem, if (data.size() <= dataId) { Fatal() << "invalid segment data id."; } - return data[dataId].c_str(); + if (data[dataId]->is<RefNull>()) { + return NULL; + } else if (auto* get = data[dataId]->dynCast<RefFunc>()) { + return get->func.c_str(); + } else { + Fatal() << "invalid expression in segment data."; + } } // Memory. One per module diff --git a/src/ir/element-utils.h b/src/ir/element-utils.h new file mode 100644 index 000000000..adfd9955f --- /dev/null +++ b/src/ir/element-utils.h @@ -0,0 +1,52 @@ +/* + * Copyright 2021 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_element_h +#define wasm_ir_element_h + +#include "wasm-builder.h" +#include "wasm.h" + +namespace wasm { + +namespace ElementUtils { + +// iterate over functions referenced in an element segment +template<typename T> +inline void iterElementSegmentFunctionNames(ElementSegment* segment, + T visitor) { + // TODO(reference-types): return early if segment type is non-funcref + for (Index i = 0; i < segment->data.size(); i++) { + if (auto* get = segment->data[i]->dynCast<RefFunc>()) { + visitor(get->func, i); + } + } +} + +// iterate over functions referenced in all element segments of a module +template<typename T> +inline void iterAllElementFunctionNames(const Module* wasm, T visitor) { + for (auto& segment : wasm->elementSegments) { + iterElementSegmentFunctionNames( + segment.get(), [&](Name& name, Index i) { visitor(name); }); + } +} + +} // namespace ElementUtils + +} // namespace wasm + +#endif // wasm_ir_element_h diff --git a/src/ir/module-splitting.cpp b/src/ir/module-splitting.cpp index d4d03991e..5b86dd494 100644 --- a/src/ir/module-splitting.cpp +++ b/src/ir/module-splitting.cpp @@ -73,6 +73,7 @@ // complex code, so it is a good candidate for a follow up PR. #include "ir/module-splitting.h" +#include "ir/element-utils.h" #include "ir/manipulation.h" #include "ir/module-utils.h" #include "ir/names.h" @@ -95,12 +96,19 @@ template<class F> void forEachElement(Module& module, F f) { } else if (auto* g = segment->offset->dynCast<GlobalGet>()) { base = g->name; } - for (size_t i = 0; i < segment->data.size(); ++i) { - f(segment->table, base, offset + i, segment->data[i]); - } + ElementUtils::iterElementSegmentFunctionNames( + segment, [&](Name& entry, Index i) { + f(segment->table, base, offset + i, entry); + }); }); } +static RefFunc* makeRefFunc(Module& wasm, Function* func) { + // FIXME: make the type NonNullable when we support it! + return Builder(wasm).makeRefFunc(func->name, + Type(HeapType(func->sig), Nullable)); +} + struct TableSlotManager { struct Slot { Name tableName; @@ -124,7 +132,7 @@ struct TableSlotManager { Table* makeTable(); // Returns the table index for `func`, allocating a new index if necessary. - Slot getSlot(Name func); + Slot getSlot(RefFunc* entry); void addSlot(Name func, Slot slot); }; @@ -199,8 +207,8 @@ Table* TableSlotManager::makeTable() { return module.addTable(Builder::makeTable(Name::fromInt(0))); } -TableSlotManager::Slot TableSlotManager::getSlot(Name func) { - auto slotIt = funcIndices.find(func); +TableSlotManager::Slot TableSlotManager::getSlot(RefFunc* entry) { + auto slotIt = funcIndices.find(entry->func); if (slotIt != funcIndices.end()) { return slotIt->second; } @@ -227,8 +235,10 @@ TableSlotManager::Slot TableSlotManager::getSlot(Name func) { Slot newSlot = {activeBase.tableName, activeBase.global, activeBase.index + Index(activeSegment->data.size())}; - activeSegment->data.push_back(func); - addSlot(func, newSlot); + + activeSegment->data.push_back(entry); + + addSlot(entry->func, newSlot); if (activeTable->initial <= newSlot.index) { activeTable->initial = newSlot.index + 1; } @@ -372,15 +382,15 @@ void ModuleSplitter::thunkExportedSecondaryFunctions() { // We've already created a thunk for this function continue; } - auto tableSlot = tableManager.getSlot(secondaryFunc); auto func = std::make_unique<Function>(); - func->name = secondaryFunc; func->sig = secondary.getFunction(secondaryFunc)->sig; std::vector<Expression*> args; for (size_t i = 0, size = func->sig.params.size(); i < size; ++i) { args.push_back(builder.makeLocalGet(i, func->sig.params[i])); } + + auto tableSlot = tableManager.getSlot(makeRefFunc(primary, func.get())); func->body = builder.makeCallIndirect( tableSlot.tableName, tableSlot.makeExpr(primary), args, func->sig); primary.addFunction(std::move(func)); @@ -395,17 +405,21 @@ void ModuleSplitter::indirectCallsToSecondaryFunctions() { Builder builder; CallIndirector(ModuleSplitter& parent) : parent(parent), builder(parent.primary) {} + // Avoid visitRefFunc on element segment data + void walkElementSegment(ElementSegment* segment) {} void visitCall(Call* curr) { if (!parent.secondaryFuncs.count(curr->target)) { return; } - auto tableSlot = parent.tableManager.getSlot(curr->target); - replaceCurrent(builder.makeCallIndirect( - tableSlot.tableName, - tableSlot.makeExpr(parent.primary), - curr->operands, - parent.secondary.getFunction(curr->target)->sig, - curr->isReturn)); + auto func = parent.secondary.getFunction(curr->target); + auto tableSlot = + parent.tableManager.getSlot(makeRefFunc(parent.primary, func)); + replaceCurrent( + builder.makeCallIndirect(tableSlot.tableName, + tableSlot.makeExpr(parent.primary), + curr->operands, + func->sig, + curr->isReturn)); } void visitRefFunc(RefFunc* curr) { assert(false && "TODO: handle ref.func as well"); @@ -454,14 +468,14 @@ void ModuleSplitter::setupTablePatching() { return; } - std::map<Index, Name> replacedElems; + std::map<Index, Function*> replacedElems; // Replace table references to secondary functions with an imported // placeholder that encodes the table index in its name: // `importNamespace`.`index`. forEachElement(primary, [&](Name, Name, Index index, Name& elem) { if (secondaryFuncs.count(elem)) { - replacedElems[index] = elem; auto* secondaryFunc = secondary.getFunction(elem); + replacedElems[index] = secondaryFunc; auto placeholder = std::make_unique<Function>(); placeholder->module = config.placeholderNamespace; placeholder->base = std::to_string(index); @@ -495,8 +509,8 @@ void ModuleSplitter::setupTablePatching() { // to be imported into the second module. TODO: use better strategies here, // such as using ref.func in the start function or standardizing addition in // initializer expressions. - const ElementSegment* primarySeg = tableManager.activeTableSegments.front(); - std::vector<Name> secondaryElems; + ElementSegment* primarySeg = tableManager.activeTableSegments.front(); + std::vector<Expression*> secondaryElems; secondaryElems.reserve(primarySeg->data.size()); // Copy functions from the primary segment to the secondary segment, @@ -507,33 +521,35 @@ void ModuleSplitter::setupTablePatching() { ++i) { if (replacement->first == i) { // primarySeg->data[i] is a placeholder, so use the secondary function. - secondaryElems.push_back(replacement->second); + secondaryElems.push_back(makeRefFunc(secondary, replacement->second)); ++replacement; - } else { - exportImportFunction(primarySeg->data[i]); - secondaryElems.push_back(primarySeg->data[i]); + } else if (auto* get = primarySeg->data[i]->dynCast<RefFunc>()) { + exportImportFunction(get->func); + auto* copied = + ExpressionManipulator::copy(primarySeg->data[i], secondary); + secondaryElems.push_back(copied); } } auto offset = ExpressionManipulator::copy(primarySeg->offset, secondary); - auto secondaryElem = std::make_unique<ElementSegment>( + auto secondarySeg = std::make_unique<ElementSegment>( secondaryTable->name, offset, secondaryElems); - secondaryElem->setName(primarySeg->name, primarySeg->hasExplicitName); - secondary.addElementSegment(std::move(secondaryElem)); + secondarySeg->setName(primarySeg->name, primarySeg->hasExplicitName); + secondary.addElementSegment(std::move(secondarySeg)); return; } // Create active table segments in the secondary module to patch in the // original functions when it is instantiated. Index currBase = replacedElems.begin()->first; - std::vector<Name> currData; + std::vector<Expression*> currData; auto finishSegment = [&]() { auto* offset = Builder(secondary).makeConst(int32_t(currBase)); - auto secondaryElem = + auto secondarySeg = std::make_unique<ElementSegment>(secondaryTable->name, offset, currData); - secondaryElem->setName(Name::fromInt(secondary.elementSegments.size()), - false); - secondary.addElementSegment(std::move(secondaryElem)); + secondarySeg->setName(Name::fromInt(secondary.elementSegments.size()), + false); + secondary.addElementSegment(std::move(secondarySeg)); }; for (auto curr = replacedElems.begin(); curr != replacedElems.end(); ++curr) { if (curr->first != currBase + currData.size()) { @@ -541,7 +557,7 @@ void ModuleSplitter::setupTablePatching() { currBase = curr->first; currData.clear(); } - currData.push_back(curr->second); + currData.push_back(makeRefFunc(secondary, curr->second)); } if (currData.size()) { finishSegment(); diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index 9b1682b91..a87daee9e 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -17,6 +17,7 @@ #ifndef wasm_ir_module_h #define wasm_ir_module_h +#include "ir/element-utils.h" #include "ir/find_all.h" #include "ir/manipulation.h" #include "ir/properties.h" @@ -75,7 +76,10 @@ inline ElementSegment* copyElementSegment(const ElementSegment* segment, auto copy = [&](std::unique_ptr<ElementSegment>&& ret) { ret->name = segment->name; ret->hasExplicitName = segment->hasExplicitName; - ret->data = segment->data; + ret->data.reserve(segment->data.size()); + for (auto* item : segment->data) { + ret->data.push_back(ExpressionManipulator::copy(item, out)); + } return out.addElementSegment(std::move(ret)); }; @@ -161,11 +165,7 @@ template<typename T> inline void renameFunctions(Module& wasm, T& map) { } }; maybeUpdate(wasm.start); - for (auto& segment : wasm.elementSegments) { - for (auto& name : segment->data) { - maybeUpdate(name); - } - } + ElementUtils::iterAllElementFunctionNames(&wasm, maybeUpdate); for (auto& exp : wasm.exports) { if (exp->kind == ExternalKind::Function) { maybeUpdate(exp->value); diff --git a/src/ir/table-utils.cpp b/src/ir/table-utils.cpp index 639f8fbe6..80ef885f9 100644 --- a/src/ir/table-utils.cpp +++ b/src/ir/table-utils.cpp @@ -15,6 +15,7 @@ */ #include "table-utils.h" +#include "element-utils.h" #include "find_all.h" #include "module-utils.h" @@ -31,11 +32,8 @@ std::set<Name> getFunctionsNeedingElemDeclare(Module& wasm) { // Find all the names in the tables. std::unordered_set<Name> tableNames; - for (auto& segment : wasm.elementSegments) { - for (auto name : segment->data) { - tableNames.insert(name); - } - } + ElementUtils::iterAllElementFunctionNames( + &wasm, [&](Name name) { tableNames.insert(name); }); // Find all the names in ref.funcs. using Names = std::unordered_set<Name>; diff --git a/src/ir/table-utils.h b/src/ir/table-utils.h index e90b0ca72..2d91f0035 100644 --- a/src/ir/table-utils.h +++ b/src/ir/table-utils.h @@ -17,6 +17,7 @@ #ifndef wasm_ir_table_h #define wasm_ir_table_h +#include "ir/element-utils.h" #include "ir/literal-utils.h" #include "ir/module-utils.h" #include "wasm-traversal.h" @@ -45,9 +46,8 @@ struct FlatTable { if (end > names.size()) { names.resize(end); } - for (Index i = 0; i < segment->data.size(); i++) { - names[start + i] = segment->data[i]; - } + ElementUtils::iterElementSegmentFunctionNames( + segment, [&](Name entry, Index i) { names[start + i] = entry; }); }); } }; @@ -84,7 +84,12 @@ inline Index append(Table& table, Name name, Module& wasm) { } wasm.dylinkSection->tableSize++; } - segment->data.push_back(name); + + auto* func = wasm.getFunctionOrNull(name); + assert(func != nullptr && "Cannot append non-existing function to a table."); + // FIXME: make the type NonNullable when we support it! + auto type = Type(HeapType(func->sig), Nullable); + segment->data.push_back(Builder(wasm).makeRefFunc(name, type)); table.initial = table.initial + 1; return tableIndex; } @@ -94,8 +99,10 @@ inline Index append(Table& table, Name name, Module& wasm) { inline Index getOrAppend(Table& table, Name name, Module& wasm) { auto segment = getSingletonSegment(table, wasm); for (Index i = 0; i < segment->data.size(); i++) { - if (segment->data[i] == name) { - return i; + if (auto* get = segment->data[i]->dynCast<RefFunc>()) { + if (get->func == name) { + return i; + } } } return append(table, name, wasm); diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 6223c841d..c4de07d81 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -39,6 +39,7 @@ #include "cfg/cfg-traversal.h" #include "ir/effects.h" +#include "ir/element-utils.h" #include "ir/module-utils.h" #include "pass.h" #include "passes/opt-utils.h" @@ -284,11 +285,8 @@ struct DAE : public Pass { infoMap[curr->value].hasUnseenCalls = true; } } - for (auto& segment : module->elementSegments) { - for (auto name : segment->data) { - infoMap[name].hasUnseenCalls = true; - } - } + ElementUtils::iterAllElementFunctionNames( + module, [&](Name name) { infoMap[name].hasUnseenCalls = true; }); // Scan all the functions. DAEScanner(&infoMap).run(runner, module); // Combine all the info. diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index 753c986b8..e63375f2b 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -30,6 +30,7 @@ #include <string> +#include <ir/element-utils.h> #include <ir/literal-utils.h> #include <pass.h> #include <wasm-builder.h> @@ -173,18 +174,16 @@ struct FuncCastEmulation : public Pass { Signature ABIType(Type(std::vector<Type>(numParams, Type::i64)), Type::i64); // Add a thunk for each function in the table, and do the call through it. std::unordered_map<Name, Name> funcThunks; - for (auto& segment : module->elementSegments) { - for (auto& name : segment->data) { - auto iter = funcThunks.find(name); - if (iter == funcThunks.end()) { - auto thunk = makeThunk(name, module, numParams); - funcThunks[name] = thunk; - name = thunk; - } else { - name = iter->second; - } + ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) { + auto iter = funcThunks.find(name); + if (iter == funcThunks.end()) { + auto thunk = makeThunk(name, module, numParams); + funcThunks[name] = thunk; + name = thunk; + } else { + name = iter->second; } - } + }); // update call_indirects ParallelFuncCastEmulation(ABIType, numParams).run(runner, module); diff --git a/src/passes/GenerateDynCalls.cpp b/src/passes/GenerateDynCalls.cpp index dad992fea..8ed9ce8b4 100644 --- a/src/passes/GenerateDynCalls.cpp +++ b/src/passes/GenerateDynCalls.cpp @@ -24,6 +24,7 @@ #include "abi/js.h" #include "asm_v_wasm.h" +#include "ir/element-utils.h" #include "ir/import-utils.h" #include "pass.h" #include "support/debug.h" @@ -57,9 +58,10 @@ struct GenerateDynCalls : public WalkerPass<PostWalker<GenerateDynCalls>> { }); if (it != segments.end()) { std::vector<Name> tableSegmentData; - for (const auto& indirectFunc : it->get()->data) { - generateDynCallThunk(wasm->getFunction(indirectFunc)->sig); - } + ElementUtils::iterElementSegmentFunctionNames( + it->get(), [&](Name name, Index) { + generateDynCallThunk(wasm->getFunction(name)->sig); + }); } } diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 3f6263815..b69d8f7c2 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -31,6 +31,7 @@ #include <atomic> #include "ir/debug.h" +#include "ir/element-utils.h" #include "ir/literal-utils.h" #include "ir/module-utils.h" #include "ir/type-updating.h" @@ -350,11 +351,8 @@ struct Inlining : public Pass { infos[ex->value].usedGlobally = true; } } - for (auto& segment : module->elementSegments) { - for (auto name : segment->data) { - infos[name].usedGlobally = true; - } - } + ElementUtils::iterAllElementFunctionNames( + module, [&](Name name) { infos[name].usedGlobally = true; }); for (auto& global : module->globals) { if (!global->imported()) { diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 0a991caa5..a53481d65 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -31,6 +31,7 @@ // #include "asmjs/shared-constants.h" +#include "ir/element-utils.h" #include "ir/import-utils.h" #include "ir/literal-utils.h" #include "ir/utils.h" @@ -97,13 +98,11 @@ struct LegalizeJSInterface : public Pass { // we need to use the legalized version in the tables, as the import // from JS is legal for JS. Our stub makes it look like a native wasm // function. - for (auto& segment : module->elementSegments) { - for (auto& name : segment->data) { - if (name == im->name) { - name = funcName; - } + ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) { + if (name == im->name) { + name = funcName; } - } + }); } } diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 4346fe705..6dfd495c0 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2704,6 +2704,18 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { if (curr->data.empty()) { return; } + bool allElementsRefFunc = + std::all_of(curr->data.begin(), curr->data.end(), [](Expression* entry) { + return entry->is<RefFunc>(); + }); + auto printElemType = [&]() { + if (allElementsRefFunc) { + TypeNamePrinter(o, currModule).print(HeapType::func); + } else { + TypeNamePrinter(o, currModule).print(Type::funcref); + } + }; + doIndent(o, indent); o << '('; printMedium(o, "elem"); @@ -2714,7 +2726,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { if (curr->table.is()) { // TODO(reference-types): check for old-style based on the complete spec - if (currModule->tables.size() > 1) { + if (!allElementsRefFunc || currModule->tables.size() > 1) { // tableuse o << " (table "; printName(curr->table, o); @@ -2724,18 +2736,26 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { o << ' '; visit(curr->offset); - if (currModule->tables.size() > 1) { + if (!allElementsRefFunc || currModule->tables.size() > 1) { o << ' '; - TypeNamePrinter(o, currModule).print(HeapType::func); + printElemType(); } } else { o << ' '; - TypeNamePrinter(o, currModule).print(HeapType::func); + printElemType(); } - for (auto name : curr->data) { - o << ' '; - printName(name, o); + if (allElementsRefFunc) { + for (auto* entry : curr->data) { + auto* refFunc = entry->cast<RefFunc>(); + o << ' '; + printName(refFunc->func, o); + } + } else { + for (auto* entry : curr->data) { + o << ' '; + printExpression(entry, o); + } } o << ')' << maybeNewLine; } diff --git a/src/passes/PrintCallGraph.cpp b/src/passes/PrintCallGraph.cpp index d70e0d7f2..e4576b84b 100644 --- a/src/passes/PrintCallGraph.cpp +++ b/src/passes/PrintCallGraph.cpp @@ -22,6 +22,7 @@ #include <iomanip> #include <memory> +#include "ir/element-utils.h" #include "ir/module-utils.h" #include "ir/utils.h" #include "pass.h" @@ -96,12 +97,10 @@ struct PrintCallGraph : public Pass { CallPrinter printer(module); // Indirect Targets - for (auto& segment : module->elementSegments) { - for (auto& curr : segment->data) { - auto* func = module->getFunction(curr); - o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n"; - } - } + ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) { + auto* func = module->getFunction(name); + o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n"; + }); o << "}\n"; } diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index be41717f4..2a1c119fd 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -22,6 +22,7 @@ // look at all the rest of the code). // +#include "ir/element-utils.h" #include "ir/module-utils.h" #include "pass.h" #include "wasm.h" @@ -49,11 +50,8 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> { *curr, [&](Function* func) { names.push_back(func->name); }); // Do not remove names referenced in a table std::set<Name> indirectNames; - for (auto& segment : curr->elementSegments) { - for (auto& name : segment->data) { - indirectNames.insert(name); - } - } + ElementUtils::iterAllElementFunctionNames( + curr, [&](Name& name) { indirectNames.insert(name); }); for (auto& name : names) { if (indirectNames.find(name) == indirectNames.end()) { curr->removeFunction(name); diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index fa295dac5..b5cf2635a 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -22,6 +22,7 @@ #include <memory> +#include "ir/element-utils.h" #include "ir/module-utils.h" #include "ir/utils.h" #include "pass.h" @@ -194,11 +195,9 @@ struct RemoveUnusedModuleElements : public Pass { importsMemory = true; } // For now, all functions that can be called indirectly are marked as roots. - for (auto& segment : module->elementSegments) { - for (auto& curr : segment->data) { - roots.emplace_back(ModuleElementKind::Function, curr); - } - } + ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) { + roots.emplace_back(ModuleElementKind::Function, name); + }); // Compute reachability starting from the root set. ReachabilityAnalyzer analyzer(module, roots); // Remove unreachable elements. diff --git a/src/passes/ReorderFunctions.cpp b/src/passes/ReorderFunctions.cpp index 0c95101a5..66b8275ef 100644 --- a/src/passes/ReorderFunctions.cpp +++ b/src/passes/ReorderFunctions.cpp @@ -29,6 +29,7 @@ #include <memory> +#include <ir/element-utils.h> #include <pass.h> #include <wasm.h> @@ -70,11 +71,8 @@ struct ReorderFunctions : public Pass { for (auto& curr : module->exports) { counts[curr->value]++; } - for (auto& segment : module->elementSegments) { - for (auto& curr : segment->data) { - counts[curr]++; - } - } + ElementUtils::iterAllElementFunctionNames( + module, [&](Name& name) { counts[name]++; }); // sort std::sort(module->functions.begin(), module->functions.end(), diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h index 27a3c5c25..b333779f7 100644 --- a/src/passes/opt-utils.h +++ b/src/passes/opt-utils.h @@ -20,6 +20,7 @@ #include <functional> #include <unordered_set> +#include <ir/element-utils.h> #include <pass.h> #include <wasm.h> @@ -86,11 +87,7 @@ inline void replaceFunctions(PassRunner* runner, // replace direct calls FunctionRefReplacer(maybeReplace).run(runner, &module); // replace in table - for (auto& segment : module.elementSegments) { - for (auto& name : segment->data) { - maybeReplace(name); - } - } + ElementUtils::iterAllElementFunctionNames(&module, maybeReplace); // replace in start if (module.start.is()) { diff --git a/src/shell-interface.h b/src/shell-interface.h index 1cb8768cf..0ba4946dd 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -86,7 +86,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } } memory; - std::unordered_map<Name, std::vector<Name>> tables; + std::unordered_map<Name, std::vector<Literal>> tables; ShellExternalInterface() : memory() {} virtual ~ShellExternalInterface() = default; @@ -177,7 +177,10 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { if (index >= table.size()) { trap("callTable overflow"); } - auto* func = instance.wasm.getFunctionOrNull(table[index]); + Function* func = nullptr; + if (table[index].isFunction() && !table[index].isNull()) { + func = instance.wasm.getFunctionOrNull(table[index].getFunc()); + } if (!func) { trap("uninitialized table element"); } @@ -231,7 +234,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { memory.set<std::array<uint8_t, 16>>(addr, value); } - void tableStore(Name tableName, Address addr, Name entry) override { + void tableStore(Name tableName, Address addr, Literal entry) override { tables[tableName][addr] = entry; } diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index a0d5430d8..595652c94 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -723,7 +723,9 @@ private: while (oneIn(3) && !finishedInput) { auto& randomElem = wasm.elementSegments[upTo(wasm.elementSegments.size())]; - randomElem->data.push_back(func->name); + // FIXME: make the type NonNullable when we support it! + auto type = Type(HeapType(func->sig), Nullable); + randomElem->data.push_back(builder.makeRefFunc(func->name, type)); } numAddedFunctions++; return func; @@ -1436,11 +1438,13 @@ private: bool isReturn; while (1) { // TODO: handle unreachable - targetFn = wasm.getFunction(data[i]); - isReturn = type == Type::unreachable && wasm.features.hasTailCall() && - funcContext->func->sig.results == targetFn->sig.results; - if (targetFn->sig.results == type || isReturn) { - break; + if (auto* get = data[i]->dynCast<RefFunc>()) { + targetFn = wasm.getFunction(get->func); + isReturn = type == Type::unreachable && wasm.features.hasTailCall() && + funcContext->func->sig.results == targetFn->sig.results; + if (targetFn->sig.results == type || isReturn) { + break; + } } i++; if (i == data.size()) { diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index f39d8a34e..13a522135 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -251,19 +251,25 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { } auto end = start + segment->data.size(); if (start <= index && index < end) { - auto name = segment->data[index - start]; - // if this is one of our functions, we can call it; if it was imported, - // fail - auto* func = wasm->getFunction(name); - if (func->sig != sig) { - throw FailToEvalException( - std::string("callTable signature mismatch: ") + name.str); - } - if (!func->imported()) { - return instance.callFunctionInternal(name, arguments); + auto entry = segment->data[index - start]; + if (auto* get = entry->dynCast<RefFunc>()) { + auto name = get->func; + // if this is one of our functions, we can call it; if it was + // imported, fail + auto* func = wasm->getFunction(name); + if (func->sig != sig) { + throw FailToEvalException( + std::string("callTable signature mismatch: ") + name.str); + } + if (!func->imported()) { + return instance.callFunctionInternal(name, arguments); + } else { + throw FailToEvalException( + std::string("callTable on imported function: ") + name.str); + } } else { throw FailToEvalException( - std::string("callTable on imported function: ") + name.str); + std::string("callTable on uninitialized entry")); } } } @@ -295,7 +301,7 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { } // called during initialization, but we don't keep track of a table - void tableStore(Name tableName, Address addr, Name value) override {} + void tableStore(Name tableName, Address addr, Literal value) override {} bool growMemory(Address /*oldSize*/, Address newSize) override { throw FailToEvalException("grow memory"); diff --git a/src/tools/wasm-metadce.cpp b/src/tools/wasm-metadce.cpp index 71ea2f693..5ddf52dae 100644 --- a/src/tools/wasm-metadce.cpp +++ b/src/tools/wasm-metadce.cpp @@ -26,6 +26,7 @@ #include <memory> +#include "ir/element-utils.h" #include "ir/module-utils.h" #include "pass.h" #include "support/colors.h" @@ -216,13 +217,14 @@ struct MetaDCEGraph { ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) { // TODO: currently, all functions in the table are roots, but we // should add an option to refine that - for (auto& name : segment->data) { - if (!wasm.getFunction(name)->imported()) { - roots.insert(functionToDCENode[name]); - } else { - roots.insert(importIdToDCENode[getFunctionImportId(name)]); - } - } + ElementUtils::iterElementSegmentFunctionNames( + segment, [&](Name name, Index) { + if (!wasm.getFunction(name)->imported()) { + roots.insert(functionToDCENode[name]); + } else { + roots.insert(importIdToDCENode[getFunctionImportId(name)]); + } + }); rooter.walk(segment->offset); }); for (auto& segment : wasm.memory.segments) { diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index efe5c1caa..19afc2032 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -766,7 +766,31 @@ struct Reducer } // the "opposite" of shrinking: copy a 'zero' element for (auto& segment : curr->segments) { - reduceByZeroing(&segment, 0, 2, shrank); + reduceByZeroing( + &segment, 0, [](char item) { return item == 0; }, 2, shrank); + } + } + + template<typename T, typename U, typename C> + void + reduceByZeroing(T* segment, U zero, C isZero, size_t bonus, bool shrank) { + for (auto& item : segment->data) { + if (!shouldTryToReduce(bonus) || isZero(item)) { + continue; + } + auto save = item; + item = zero; + if (writeAndTestReduction()) { + std::cerr << "| zeroed elem segment\n"; + noteReduction(); + } else { + item = save; + } + if (shrank) { + // zeroing is fairly inefficient. if we are managing to shrink + // (which we do exponentially), just zero one per segment at most + break; + } } } @@ -803,37 +827,9 @@ struct Reducer return shrank; } - template<typename T, typename U> - void reduceByZeroing(T* segment, U zero, size_t bonus, bool shrank) { - if (segment->data.empty()) { - return; - } - for (auto& item : segment->data) { - if (!shouldTryToReduce(bonus)) { - continue; - } - if (item == zero) { - continue; - } - auto save = item; - item = zero; - if (writeAndTestReduction()) { - std::cerr << "| zeroed elem segment\n"; - noteReduction(); - } else { - item = save; - } - if (shrank) { - // zeroing is fairly inefficient. if we are managing to shrink - // (which we do exponentially), just zero one per segment at most - break; - } - } - } - void shrinkElementSegments(Module* module) { std::cerr << "| try to simplify elem segments\n"; - Name first; + Expression* first = nullptr; auto it = std::find_if_not(module->elementSegments.begin(), module->elementSegments.end(), @@ -842,6 +838,10 @@ struct Reducer if (it != module->elementSegments.end()) { first = it->get()->data[0]; } + if (first == nullptr) { + // The elements are all empty, nothing to shrink + return; + } // try to reduce to first function. first, shrink segment elements. // while we are shrinking successfully, keep going exponentially. @@ -851,7 +851,24 @@ struct Reducer } // the "opposite" of shrinking: copy a 'zero' element for (auto& segment : module->elementSegments) { - reduceByZeroing(segment.get(), first, 100, shrank); + reduceByZeroing( + segment.get(), + first, + [&](Expression* entry) { + if (entry->is<RefNull>()) { + // we don't need to replace a ref.null + return true; + } else if (first->is<RefNull>()) { + return false; + } else { + // Both are ref.func + auto* f = first->cast<RefFunc>(); + auto* e = entry->cast<RefFunc>(); + return f->func == e->func; + } + }, + 100, + shrank); } } diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index efa2c0ca4..6d2077e1c 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -23,6 +23,7 @@ #include <memory> #include "execution-results.h" +#include "ir/element-utils.h" #include "pass.h" #include "shell-interface.h" #include "support/command-line.h" @@ -166,18 +167,16 @@ run_asserts(Name moduleName, reportUnknownImport(import); } }); - for (auto& segment : wasm.elementSegments) { - for (auto name : segment->data) { - // spec tests consider it illegal to use spectest.print in a table - if (auto* import = wasm.getFunction(name)) { - if (import->imported() && import->module == SPECTEST && - import->base.startsWith(PRINT)) { - std::cerr << "cannot put spectest.print in table\n"; - invalid = true; - } + ElementUtils::iterAllElementFunctionNames(&wasm, [&](Name name) { + // spec tests consider it illegal to use spectest.print in a table + if (auto* import = wasm.getFunction(name)) { + if (import->imported() && import->module == SPECTEST && + import->base.startsWith(PRINT)) { + std::cerr << "cannot put spectest.print in table\n"; + invalid = true; } } - } + }); if (wasm.memory.imported()) { reportUnknownImport(&wasm.memory); } diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 1cf8186f8..ec89bb1da 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1531,9 +1531,6 @@ public: void readDataSegments(); void readDataCount(); - // A map from elem segment indexes to their entries - std::map<Index, std::vector<Index>> functionTable; - void readTableDeclarations(); void readElementSegments(); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index dcf9d2540..61d29ada1 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -2222,7 +2222,7 @@ public: WASM_UNREACHABLE("unimp"); } - virtual void tableStore(Name tableName, Address addr, Name entry) { + virtual void tableStore(Name tableName, Address addr, Literal entry) { WASM_UNREACHABLE("unimp"); } }; @@ -2309,21 +2309,22 @@ private: std::unordered_set<size_t> droppedSegments; void initializeTableContents() { - for (auto& segment : wasm.elementSegments) { - if (segment->table.isNull()) { - continue; - } - + ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) { Address offset = (uint32_t)InitializerExpressionRunner<GlobalManager>(globals, maxDepth) .visit(segment->offset) .getSingleValue() .geti32(); - for (size_t i = 0; i != segment->data.size(); ++i) { + + Function dummyFunc; + FunctionScope dummyScope(&dummyFunc, {}); + RuntimeExpressionRunner runner(*this, dummyScope, maxDepth); + for (Index i = 0; i < segment->data.size(); ++i) { + Flow ret = runner.visit(segment->data[i]); externalInterface->tableStore( - segment->table, offset + i, segment->data[i]); + segment->table, offset + i, ret.getSingleValue()); } - } + }); } void initializeMemoryContents() { diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index eaddd8c8b..03da961bd 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -317,7 +317,8 @@ private: void parseElem(Element& s, Table* table = nullptr); ElementSegment* parseElemFinish(Element& s, std::unique_ptr<ElementSegment>& segment, - Index i = 1); + Index i = 1, + bool usesExpressions = false); // Parses something like (func ..), (array ..), (struct) HeapType parseHeapType(Element& s); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 7dcf2d146..1f307c318 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -196,6 +196,9 @@ struct Walker : public VisitorType { if (segment->table.is()) { walk(segment->offset); } + for (auto* expr : segment->data) { + walk(expr); + } static_cast<SubType*>(this)->visitElementSegment(segment); } diff --git a/src/wasm.h b/src/wasm.h index 630e5b003..4747b48a3 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1722,12 +1722,12 @@ class ElementSegment : public Named { public: Name table; Expression* offset; - std::vector<Name> data; + std::vector<Expression*> data; ElementSegment() = default; ElementSegment(Name table, Expression* offset) : table(table), offset(offset) {} - ElementSegment(Name table, Expression* offset, std::vector<Name>& init) + ElementSegment(Name table, Expression* offset, std::vector<Expression*>& init) : table(table), offset(offset) { data.swap(init); } diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 09d3287b0..5fff58045 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -571,9 +571,11 @@ void WasmBinaryWriter::writeElementSegments() { Index tableIdx = 0; bool isPassive = segment->table.isNull(); - // TODO(reference-types): add support for writing expressions instead of - // function indices. - bool usesExpressions = false; + // if all items are ref.func, we can use the shorter form. + bool usesExpressions = + std::any_of(segment->data.begin(), + segment->data.end(), + [](Expression* curr) { return !curr->is<RefFunc>(); }); bool hasTableIndex = false; if (!isPassive) { @@ -600,13 +602,27 @@ void WasmBinaryWriter::writeElementSegments() { o << int8_t(BinaryConsts::End); } - if (!usesExpressions && (isPassive || hasTableIndex)) { - // elemKind funcref - o << U32LEB(0); + if (isPassive || hasTableIndex) { + if (usesExpressions) { + // elemType funcref + writeType(Type::funcref); + } else { + // elemKind funcref + o << U32LEB(0); + } } o << U32LEB(segment->data.size()); - for (auto& name : segment->data) { - o << U32LEB(getFunctionIndex(name)); + if (usesExpressions) { + for (auto* item : segment->data) { + writeExpression(item); + o << int8_t(BinaryConsts::End); + } + } else { + for (auto& item : segment->data) { + // We've ensured that all items are ref.func. + auto& name = item->cast<RefFunc>()->func; + o << U32LEB(getFunctionIndex(name)); + } } } @@ -2676,14 +2692,6 @@ void WasmBinaryBuilder::processNames() { } } - for (auto& pair : functionTable) { - auto i = pair.first; - auto& indices = pair.second; - for (auto j : indices) { - wasm.elementSegments[i]->data.push_back(getFunctionName(j)); - } - } - for (auto& iter : globalRefs) { size_t index = iter.first; auto& refs = iter.second; @@ -2787,10 +2795,6 @@ void WasmBinaryBuilder::readElementSegments() { continue; } - if (usesExpressions) { - throwError("Only elem segments with function indexes are supported."); - } - if (!isPassive) { Index tableIdx = 0; if (hasTableIdx) { @@ -2819,17 +2823,35 @@ void WasmBinaryBuilder::readElementSegments() { } if (isPassive || hasTableIdx) { - auto elemKind = getU32LEB(); - if (elemKind != 0x0) { - throwError("Only funcref elem kinds are valid."); + if (usesExpressions) { + auto type = getType(); + if (type != Type::funcref) { + throwError("Only funcref elem kinds are valid."); + } + } else { + auto elemKind = getU32LEB(); + if (elemKind != 0x0) { + throwError("Only funcref elem kinds are valid."); + } } } - size_t segmentIndex = functionTable.size(); - auto& indexSegment = functionTable[segmentIndex]; + auto& segmentData = elementSegments.back()->data; auto size = getU32LEB(); - for (Index j = 0; j < size; j++) { - indexSegment.push_back(getU32LEB()); + if (usesExpressions) { + for (Index j = 0; j < size; j++) { + segmentData.push_back(readExpression()); + } + } else { + for (Index j = 0; j < size; j++) { + Index index = getU32LEB(); + auto sig = getSignatureByFunctionIndex(index); + // Use a placeholder name for now + auto* refFunc = Builder(wasm).makeRefFunc( + Name::fromInt(index), Type(HeapType(sig), Nullable)); + functionRefs[index].push_back(refFunc); + segmentData.push_back(refFunc); + } } } } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 45c7c9161..7774eac3e 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -3268,12 +3268,14 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) { Index i = 1; Name name = Name::fromInt(elemCounter++); bool hasExplicitName = false; + bool isPassive = false; + bool usesExpressions = false; if (table) { Expression* offset = allocator.alloc<Const>()->set(Literal(int32_t(0))); auto segment = std::make_unique<ElementSegment>(table->name, offset); segment->setName(name, hasExplicitName); - parseElemFinish(s, segment, i); + parseElemFinish(s, segment, i, false); return; } @@ -3286,10 +3288,20 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) { return; } - if (s[i]->isStr() && s[i]->str() == FUNC) { + if (s[i]->isStr()) { + if (s[i]->str() == FUNC) { + isPassive = true; + usesExpressions = false; + } else if (s[i]->str() == FUNCREF) { + isPassive = true; + usesExpressions = true; + } + } + + if (isPassive) { auto segment = std::make_unique<ElementSegment>(); segment->setName(name, hasExplicitName); - parseElemFinish(s, segment, i + 1); + parseElemFinish(s, segment, i + 1, usesExpressions); return; } @@ -3326,11 +3338,11 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) { } if (!oldStyle) { - if (s[i]->str() != FUNC) { - throw ParseException( - "only the abbreviated form of elemList is supported."); + if (s[i]->str() == FUNCREF) { + usesExpressions = true; + } else if (s[i]->str() != FUNC) { + throw ParseException("expected func or funcref."); } - // ignore elemType for now i += 1; } @@ -3340,13 +3352,40 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) { auto segment = std::make_unique<ElementSegment>(table->name, offset); segment->setName(name, hasExplicitName); - parseElemFinish(s, segment, i); + parseElemFinish(s, segment, i, usesExpressions); } ElementSegment* SExpressionWasmBuilder::parseElemFinish( - Element& s, std::unique_ptr<ElementSegment>& segment, Index i) { - for (; i < s.size(); i++) { - segment->data.push_back(getFunctionName(*s[i])); + Element& s, + std::unique_ptr<ElementSegment>& segment, + Index i, + bool usesExpressions) { + + if (usesExpressions) { + for (; i < s.size(); i++) { + if (!s[i]->isList()) { + throw ParseException("expected a ref.* expression."); + } + auto& inner = *s[i]; + if (elementStartsWith(inner, ITEM)) { + if (inner[1]->isList()) { + // (item (ref.func $f)) + segment->data.push_back(parseExpression(inner[1])); + } else { + // (item ref.func $f) + inner.list().removeAt(0); + segment->data.push_back(parseExpression(inner)); + } + } else { + segment->data.push_back(parseExpression(inner)); + } + } + } else { + for (; i < s.size(); i++) { + auto func = getFunctionName(*s[i]); + segment->data.push_back(Builder(wasm).makeRefFunc( + func, Type(HeapType(functionSignatures[func]), Nullable))); + } } return wasm.addElementSegment(std::move(segment)); } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 71ecf208e..52ba36872 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -1960,7 +1960,10 @@ void FunctionValidator::visitMemoryGrow(MemoryGrow* curr) { } void FunctionValidator::visitRefNull(RefNull* curr) { - shouldBeTrue(getModule()->features.hasReferenceTypes(), + // If we are not in a function, this is a global location like a table. We + // allow RefNull there as we represent tables that way regardless of what + // features are enabled. + shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(), curr, "ref.null requires reference-types to be enabled"); shouldBeTrue( @@ -1978,7 +1981,10 @@ void FunctionValidator::visitRefIs(RefIs* curr) { } void FunctionValidator::visitRefFunc(RefFunc* curr) { - shouldBeTrue(getModule()->features.hasReferenceTypes(), + // If we are not in a function, this is a global location like a table. We + // allow RefFunc there as we represent tables that way regardless of what + // features are enabled. + shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(), curr, "ref.func requires reference-types to be enabled"); if (!info.validateGlobally) { @@ -2799,11 +2805,29 @@ static void validateMemory(Module& module, ValidationInfo& info) { } static void validateTables(Module& module, ValidationInfo& info) { + FunctionValidator validator(module, &info); + if (!module.features.hasReferenceTypes()) { info.shouldBeTrue(module.tables.size() <= 1, "table", "Only 1 table definition allowed in MVP (requires " "--enable-reference-types)"); + if (!module.tables.empty()) { + auto& table = module.tables.front(); + for (auto& segment : module.elementSegments) { + info.shouldBeTrue(segment->table == table->name, + "elem", + "all element segments should refer to a single table " + "in MVP."); + for (auto* expr : segment->data) { + info.shouldBeTrue( + expr->is<RefFunc>(), + expr, + "all table elements must be non-null funcrefs in MVP."); + validator.validate(expr); + } + } + } } for (auto& segment : module.elementSegments) { @@ -2820,11 +2844,17 @@ static void validateTables(Module& module, ValidationInfo& info) { table->initial * Table::kPageSize), segment->offset, "table segment offset should be reasonable"); - FunctionValidator(module, &info).validate(segment->offset); - } - for (auto name : segment->data) { - info.shouldBeTrue( - module.getFunctionOrNull(name), name, "segment name should be valid"); + validator.validate(segment->offset); + } + // Avoid double checking items + if (module.features.hasReferenceTypes()) { + for (auto* expr : segment->data) { + info.shouldBeTrue( + expr->is<RefFunc>() || expr->is<RefNull>(), + expr, + "element segment items must be either ref.func or ref.null func."); + validator.validate(expr); + } } } } diff --git a/src/wasm2js.h b/src/wasm2js.h index b3a4c7e18..6d3f0682c 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -32,6 +32,7 @@ #include "emscripten-optimizer/optimizer.h" #include "ir/branch-utils.h" #include "ir/effects.h" +#include "ir/element-utils.h" #include "ir/find_all.h" #include "ir/import-utils.h" #include "ir/load-utils.h" @@ -325,11 +326,8 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { functionsCallableFromOutside.insert(exp->value); } } - for (auto& segment : wasm->elementSegments) { - for (auto name : segment->data) { - functionsCallableFromOutside.insert(name); - } - } + ElementUtils::iterAllElementFunctionNames( + wasm, [&](Name name) { functionsCallableFromOutside.insert(name); }); // Ensure the scratch memory helpers. // If later on they aren't needed, we'll clean them up. @@ -680,26 +678,27 @@ void Wasm2JSBuilder::addTable(Ref ast, Module* wasm) { ModuleUtils::iterTableSegments( *wasm, table->name, [&](ElementSegment* segment) { auto offset = segment->offset; - for (Index i = 0; i < segment->data.size(); i++) { - Ref index; - if (auto* c = offset->dynCast<Const>()) { - index = ValueBuilder::makeInt(c->value.geti32() + i); - } else if (auto* get = offset->dynCast<GlobalGet>()) { - index = - ValueBuilder::makeBinary(ValueBuilder::makeName(stringToIString( - asmangle(get->name.str))), - PLUS, - ValueBuilder::makeNum(i)); - } else { - WASM_UNREACHABLE("unexpected expr type"); - } - ast->push_back(ValueBuilder::makeStatement(ValueBuilder::makeBinary( - ValueBuilder::makeSub(ValueBuilder::makeName(FUNCTION_TABLE), - index), - SET, - ValueBuilder::makeName( - fromName(segment->data[i], NameScope::Top))))); - } + ElementUtils::iterElementSegmentFunctionNames( + segment, [&](Name entry, Index i) { + Ref index; + if (auto* c = offset->dynCast<Const>()) { + index = ValueBuilder::makeInt(c->value.geti32() + i); + } else if (auto* get = offset->dynCast<GlobalGet>()) { + index = ValueBuilder::makeBinary( + ValueBuilder::makeName( + stringToIString(asmangle(get->name.str))), + PLUS, + ValueBuilder::makeNum(i)); + } else { + WASM_UNREACHABLE("unexpected expr type"); + } + ast->push_back( + ValueBuilder::makeStatement(ValueBuilder::makeBinary( + ValueBuilder::makeSub(ValueBuilder::makeName(FUNCTION_TABLE), + index), + SET, + ValueBuilder::makeName(fromName(entry, NameScope::Top))))); + }); }); } } diff --git a/test/multi-table.wast b/test/multi-table.wast index 378d4018a..cbfa369e0 100644 --- a/test/multi-table.wast +++ b/test/multi-table.wast @@ -1,6 +1,7 @@ (module (import "a" "b" (table $t1 1 10 funcref)) (table $t2 3 3 funcref) + (table $t3 4 4 funcref) ;; add to $t1 (elem (i32.const 0) $f) @@ -9,7 +10,10 @@ (elem (table $t2) (i32.const 0) func $f) (elem $activeNonZeroOffset (table $t2) (offset (i32.const 1)) func $f $g) - (elem $passive func $f $g) + (elem $e3-1 (table $t3) (i32.const 0) funcref (ref.func $f) (ref.null func)) + (elem $e3-2 (table $t3) (offset (i32.const 2)) funcref (item ref.func $f) (item (ref.func $g))) + + (elem $passive funcref (item ref.func $f) (item (ref.func $g)) (ref.null func)) (elem $empty func) (elem $declarative declare func $h) diff --git a/test/multi-table.wast.from-wast b/test/multi-table.wast.from-wast index f60bab8cb..80c081354 100644 --- a/test/multi-table.wast.from-wast +++ b/test/multi-table.wast.from-wast @@ -5,7 +5,10 @@ (table $t2 3 3 funcref) (elem (table $t2) (i32.const 0) func $f) (elem $activeNonZeroOffset (table $t2) (i32.const 1) func $f $g) - (elem $passive func $f $g) + (table $t3 4 4 funcref) + (elem $e3-1 (table $t3) (i32.const 0) funcref (ref.func $f) (ref.null func)) + (elem $e3-2 (table $t3) (i32.const 2) func $f $g) + (elem $passive funcref (ref.func $f) (ref.func $g) (ref.null func)) (elem declare func $h) (func $f (drop diff --git a/test/multi-table.wast.fromBinary b/test/multi-table.wast.fromBinary index 13604e69d..9918b6183 100644 --- a/test/multi-table.wast.fromBinary +++ b/test/multi-table.wast.fromBinary @@ -5,7 +5,10 @@ (table $t2 3 3 funcref) (elem (table $t2) (i32.const 0) func $f) (elem $activeNonZeroOffset (table $t2) (i32.const 1) func $f $g) - (elem $passive func $f $g) + (table $t3 4 4 funcref) + (elem $e3-1 (table $t3) (i32.const 0) funcref (ref.func $f) (ref.null func)) + (elem $e3-2 (table $t3) (i32.const 2) func $f $g) + (elem $passive funcref (ref.func $f) (ref.func $g) (ref.null func)) (elem declare func $h) (func $f (drop diff --git a/test/multi-table.wast.fromBinary.noDebugInfo b/test/multi-table.wast.fromBinary.noDebugInfo index 77e2ea439..638ee7329 100644 --- a/test/multi-table.wast.fromBinary.noDebugInfo +++ b/test/multi-table.wast.fromBinary.noDebugInfo @@ -5,7 +5,10 @@ (table $0 3 3 funcref) (elem (table $0) (i32.const 0) func $0) (elem (table $0) (i32.const 1) func $0 $1) - (elem func $0 $1) + (table $1 4 4 funcref) + (elem (table $1) (i32.const 0) funcref (ref.func $0) (ref.null func)) + (elem (table $1) (i32.const 2) func $0 $1) + (elem funcref (ref.func $0) (ref.func $1) (ref.null func)) (elem declare func $2) (func $0 (drop diff --git a/test/passes/converge_O3_metrics.bin.txt b/test/passes/converge_O3_metrics.bin.txt index 2df4a4b0b..4c1bed30d 100644 --- a/test/passes/converge_O3_metrics.bin.txt +++ b/test/passes/converge_O3_metrics.bin.txt @@ -7,7 +7,7 @@ total [memory-data] : 28 [table-data] : 429 [tables] : 0 - [total] : 129 + [total] : 558 [vars] : 4 Binary : 12 Block : 8 @@ -23,6 +23,7 @@ total LocalGet : 18 LocalSet : 7 Loop : 1 + RefFunc : 429 Store : 5 (module (type $i32_i32_i32_=>_i32 (func (param i32 i32 i32) (result i32))) @@ -249,7 +250,7 @@ total [memory-data] : 28 [table-data] : 429 [tables] : 0 - [total] : 129 + [total] : 558 [vars] : 4 Binary : 12 Block : 8 @@ -265,6 +266,7 @@ total LocalGet : 18 LocalSet : 7 Loop : 1 + RefFunc : 429 Store : 5 (module (type $i32_i32_i32_=>_i32 (func (param i32 i32 i32) (result i32))) diff --git a/test/passes/func-metrics.txt b/test/passes/func-metrics.txt index 1ca5d2fd9..61d97ddad 100644 --- a/test/passes/func-metrics.txt +++ b/test/passes/func-metrics.txt @@ -7,8 +7,9 @@ global [memory-data] : 9 [table-data] : 3 [tables] : 1 - [total] : 3 + [total] : 6 Const : 3 + RefFunc : 3 func: empty [binary-bytes] : 3 [total] : 1 diff --git a/test/passes/fuzz_metrics_noprint.bin.txt b/test/passes/fuzz_metrics_noprint.bin.txt index eb24adfa4..d2b5e37e9 100644 --- a/test/passes/fuzz_metrics_noprint.bin.txt +++ b/test/passes/fuzz_metrics_noprint.bin.txt @@ -7,7 +7,7 @@ total [memory-data] : 4 [table-data] : 17 [tables] : 1 - [total] : 5797 + [total] : 5814 [vars] : 160 Binary : 467 Block : 810 @@ -24,6 +24,7 @@ total LocalSet : 292 Loop : 127 Nop : 69 + RefFunc : 17 Return : 259 Select : 53 Store : 53 diff --git a/test/passes/metrics_all-features.txt b/test/passes/metrics_all-features.txt index ae746355d..14c1bc72d 100644 --- a/test/passes/metrics_all-features.txt +++ b/test/passes/metrics_all-features.txt @@ -7,13 +7,14 @@ total [memory-data] : 9 [table-data] : 3 [tables] : 1 - [total] : 27 + [total] : 30 [vars] : 1 Binary : 1 Block : 1 Const : 15 Drop : 6 If : 4 + RefFunc : 3 (module (type $0 (func (param i32))) (type $i32_i32_=>_none (func (param i32 i32))) diff --git a/test/spec/call_indirect_refnull.wast b/test/spec/call_indirect_refnull.wast new file mode 100644 index 000000000..f1b4f5978 --- /dev/null +++ b/test/spec/call_indirect_refnull.wast @@ -0,0 +1,12 @@ +(module + (table $t 1 1 funcref) + (elem (table $t) (i32.const 0) funcref (ref.null func)) + + (func $call-refnull (export "call-refnull") (result f32) + (call_indirect (result f32) (i32.const 0)) + ) +) +(assert_trap + (invoke "call-refnull") + "uninitialized table element" +)
\ No newline at end of file |