diff options
Diffstat (limited to 'src')
39 files changed, 998 insertions, 695 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index e3780ce15..b9bf9752f 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -3374,47 +3374,13 @@ BinaryenExportRef BinaryenGetExportByIndex(BinaryenModuleRef module, return exports[index].get(); } -// TODO(reference-types): maybe deprecate this function? -void BinaryenSetFunctionTable(BinaryenModuleRef module, - BinaryenIndex initial, - BinaryenIndex maximum, - const char** funcNames, - BinaryenIndex numFuncNames, - BinaryenExpressionRef offset) { - auto* wasm = (Module*)module; - if (wasm->tables.empty()) { - wasm->addTable(Builder::makeTable(Name::fromInt(0))); - } - - auto& table = wasm->tables.front(); - table->initial = initial; - table->max = maximum; - - Table::Segment segment((Expression*)offset); - for (BinaryenIndex i = 0; i < numFuncNames; i++) { - segment.data.push_back(funcNames[i]); - } - table->segments.push_back(segment); -} - BinaryenTableRef BinaryenAddTable(BinaryenModuleRef module, const char* name, BinaryenIndex initial, - BinaryenIndex maximum, - const char** funcNames, - BinaryenIndex numFuncNames, - BinaryenExpressionRef offset) { + BinaryenIndex maximum) { auto table = Builder::makeTable(name, initial, maximum); table->hasExplicitName = true; - - Table::Segment segment((Expression*)offset); - for (BinaryenIndex i = 0; i < numFuncNames; i++) { - segment.data.push_back(funcNames[i]); - } - table->segments.push_back(segment); - ((Module*)module)->addTable(std::move(table)); - - return ((Module*)module)->getTable(name); + return ((Module*)module)->addTable(std::move(table)); } void BinaryenRemoveTable(BinaryenModuleRef module, const char* table) { ((Module*)module)->removeTable(table); @@ -3433,59 +3399,68 @@ BinaryenTableRef BinaryenGetTableByIndex(BinaryenModuleRef module, } return tables[index].get(); } - -int BinaryenIsFunctionTableImported(BinaryenModuleRef module) { - if (((Module*)module)->tables.size() > 0) { - return ((Module*)module)->tables[0]->imported(); +BinaryenElementSegmentRef +BinaryenAddActiveElementSegment(BinaryenModuleRef module, + const char* table, + const char* name, + const char** funcNames, + BinaryenIndex numFuncNames, + BinaryenExpressionRef offset) { + auto segment = std::make_unique<ElementSegment>(table, (Expression*)offset); + segment->setExplicitName(name); + for (BinaryenIndex i = 0; i < numFuncNames; i++) { + segment->data.push_back(funcNames[i]); } - - return false; + return ((Module*)module)->addElementSegment(std::move(segment)); +} +BinaryenElementSegmentRef +BinaryenAddPassiveElementSegment(BinaryenModuleRef module, + const char* name, + const char** funcNames, + BinaryenIndex numFuncNames) { + auto segment = std::make_unique<ElementSegment>(); + segment->setExplicitName(name); + for (BinaryenIndex i = 0; i < numFuncNames; i++) { + segment->data.push_back(funcNames[i]); + } + return ((Module*)module)->addElementSegment(std::move(segment)); +} +void BinaryenRemoveElementSegment(BinaryenModuleRef module, const char* name) { + ((Module*)module)->removeElementSegment(name); +} +BinaryenElementSegmentRef BinaryenGetElementSegment(BinaryenModuleRef module, + const char* name) { + return ((Module*)module)->getElementSegmentOrNull(name); } -BinaryenIndex BinaryenGetNumFunctionTableSegments(BinaryenModuleRef module) { - if (((Module*)module)->tables.size() > 0) { - return ((Module*)module)->tables[0]->segments.size(); +BinaryenElementSegmentRef +BinaryenGetElementSegmentByIndex(BinaryenModuleRef module, + BinaryenIndex index) { + const auto& elementSegments = ((Module*)module)->elementSegments; + if (elementSegments.size() <= index) { + Fatal() << "invalid table index."; } - - return 0; + return elementSegments[index].get(); +} +BinaryenIndex BinaryenGetNumElementSegments(BinaryenModuleRef module) { + return ((Module*)module)->elementSegments.size(); } BinaryenExpressionRef -BinaryenGetFunctionTableSegmentOffset(BinaryenModuleRef module, - BinaryenIndex segmentId) { - if (((Module*)module)->tables.empty()) { - Fatal() << "module has no tables."; +BinaryenElementSegmentGetOffset(BinaryenElementSegmentRef elem) { + if (((ElementSegment*)elem)->table.isNull()) { + Fatal() << "elem segment is passive."; } - - const auto& segments = ((Module*)module)->tables[0]->segments; - if (segments.size() <= segmentId) { - Fatal() << "invalid function table segment id."; - } - return segments[segmentId].offset; + return ((ElementSegment*)elem)->offset; } -BinaryenIndex BinaryenGetFunctionTableSegmentLength(BinaryenModuleRef module, - BinaryenIndex segmentId) { - if (((Module*)module)->tables.empty()) { - Fatal() << "module has no tables."; - } - - const auto& segments = ((Module*)module)->tables[0]->segments; - if (segments.size() <= segmentId) { - Fatal() << "invalid function table segment id."; - } - return segments[segmentId].data.size(); +BinaryenIndex BinaryenElementSegmentGetLength(BinaryenElementSegmentRef elem) { + return ((ElementSegment*)elem)->data.size(); } -const char* BinaryenGetFunctionTableSegmentData(BinaryenModuleRef module, - BinaryenIndex segmentId, - BinaryenIndex dataId) { - if (((Module*)module)->tables.empty()) { - Fatal() << "module has no tables."; +const char* BinaryenElementSegmentGetData(BinaryenElementSegmentRef elem, + BinaryenIndex dataId) { + const auto& data = ((ElementSegment*)elem)->data; + if (data.size() <= dataId) { + Fatal() << "invalid segment data id."; } - - const auto& segments = ((Module*)module)->tables[0]->segments; - if (segments.size() <= segmentId || - segments[segmentId].data.size() <= dataId) { - Fatal() << "invalid function table segment or data id."; - } - return segments[segmentId].data[dataId].c_str(); + return data[dataId].c_str(); } // Memory. One per module @@ -3977,6 +3952,27 @@ void BinaryenTableSetMax(BinaryenTableRef table, BinaryenIndex max) { } // +// =========== ElementSegment operations =========== +// +const char* BinaryenElementSegmentGetName(BinaryenElementSegmentRef elem) { + return ((ElementSegment*)elem)->name.c_str(); +} +void BinaryenElementSegmentSetName(BinaryenElementSegmentRef elem, + const char* name) { + ((ElementSegment*)elem)->name = name; +} +const char* BinaryenElementSegmentGetTable(BinaryenElementSegmentRef elem) { + return ((ElementSegment*)elem)->table.c_str(); +} +void BinaryenElementSegmentSetTable(BinaryenElementSegmentRef elem, + const char* table) { + ((ElementSegment*)elem)->table = table; +} +int BinayenElementSegmentIsPassive(BinaryenElementSegmentRef elem) { + return ((ElementSegment*)elem)->table.isNull(); +} + +// // =========== Global operations =========== // diff --git a/src/binaryen-c.h b/src/binaryen-c.h index 24018e4d8..881b0fddc 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -2069,25 +2069,6 @@ BINARYEN_API BinaryenEventRef BinaryenGetEvent(BinaryenModuleRef module, BINARYEN_API void BinaryenRemoveEvent(BinaryenModuleRef module, const char* name); -// Function table. One per module - -// TODO: Add support for multiple segments in BinaryenSetFunctionTable. -BINARYEN_API void BinaryenSetFunctionTable(BinaryenModuleRef module, - BinaryenIndex initial, - BinaryenIndex maximum, - const char** funcNames, - BinaryenIndex numFuncNames, - BinaryenExpressionRef offset); -BINARYEN_API int BinaryenIsFunctionTableImported(BinaryenModuleRef module); -BINARYEN_API BinaryenIndex -BinaryenGetNumFunctionTableSegments(BinaryenModuleRef module); -BINARYEN_API BinaryenExpressionRef BinaryenGetFunctionTableSegmentOffset( - BinaryenModuleRef module, BinaryenIndex segmentId); -BINARYEN_API BinaryenIndex BinaryenGetFunctionTableSegmentLength( - BinaryenModuleRef module, BinaryenIndex segmentId); -BINARYEN_API const char* BinaryenGetFunctionTableSegmentData( - BinaryenModuleRef module, BinaryenIndex segmentId, BinaryenIndex dataId); - // Tables BINARYEN_REF(Table); @@ -2095,10 +2076,7 @@ BINARYEN_REF(Table); BINARYEN_API BinaryenTableRef BinaryenAddTable(BinaryenModuleRef module, const char* table, BinaryenIndex initial, - BinaryenIndex maximum, - const char** funcNames, - BinaryenIndex numFuncNames, - BinaryenExpressionRef offset); + BinaryenIndex maximum); BINARYEN_API void BinaryenRemoveTable(BinaryenModuleRef module, const char* table); BINARYEN_API BinaryenIndex BinaryenGetNumTables(BinaryenModuleRef module); @@ -2107,6 +2085,31 @@ BINARYEN_API BinaryenTableRef BinaryenGetTable(BinaryenModuleRef module, BINARYEN_API BinaryenTableRef BinaryenGetTableByIndex(BinaryenModuleRef module, BinaryenIndex index); +// Elem segments + +BINARYEN_REF(ElementSegment); + +BINARYEN_API BinaryenElementSegmentRef +BinaryenAddActiveElementSegment(BinaryenModuleRef module, + const char* table, + const char* name, + const char** funcNames, + BinaryenIndex numFuncNames, + BinaryenExpressionRef offset); +BINARYEN_API BinaryenElementSegmentRef +BinaryenAddPassiveElementSegment(BinaryenModuleRef module, + const char* name, + const char** funcNames, + BinaryenIndex numFuncNames); +BINARYEN_API void BinaryenRemoveElementSegment(BinaryenModuleRef module, + const char* name); +BINARYEN_API BinaryenIndex +BinaryenGetNumElementSegments(BinaryenModuleRef module); +BINARYEN_API BinaryenElementSegmentRef +BinaryenGetElementSegment(BinaryenModuleRef module, const char* name); +BINARYEN_API BinaryenElementSegmentRef +BinaryenGetElementSegmentByIndex(BinaryenModuleRef module, BinaryenIndex index); + // Memory. One per module // Each memory has data in segments, a start offset in segmentOffsets, and a @@ -2420,6 +2423,35 @@ BINARYEN_API void BinaryenTableSetMax(BinaryenTableRef table, BinaryenIndex max); // +// ========== Elem Segment Operations ========== +// + +// Gets the name of the specified `ElementSegment`. +BINARYEN_API const char* +BinaryenElementSegmentGetName(BinaryenElementSegmentRef elem); +// Sets the name of the specified `ElementSegment`. +BINARYEN_API void BinaryenElementSegmentSetName(BinaryenElementSegmentRef elem, + const char* name); +// Gets the table name of the specified `ElementSegment`. +BINARYEN_API const char* +BinaryenElementSegmentGetTable(BinaryenElementSegmentRef elem); +// Sets the table name of the specified `ElementSegment`. +BINARYEN_API void BinaryenElementSegmentSetTable(BinaryenElementSegmentRef elem, + const char* table); +// Gets the segment offset in case of active segments +BINARYEN_API BinaryenExpressionRef +BinaryenElementSegmentGetOffset(BinaryenElementSegmentRef elem); +// Gets the length of items in the segment +BINARYEN_API BinaryenIndex +BinaryenElementSegmentGetLength(BinaryenElementSegmentRef elem); +// Gets the item at the specified index +BINARYEN_API const char* +BinaryenElementSegmentGetData(BinaryenElementSegmentRef elem, + BinaryenIndex dataId); +// Returns true if the specified elem segment is passive +BINARYEN_API int BinayenElementSegmentIsPassive(BinaryenElementSegmentRef elem); + +// // ========== Global Operations ========== // diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp index de4184596..a0166381b 100644 --- a/src/ir/ReFinalize.cpp +++ b/src/ir/ReFinalize.cpp @@ -178,6 +178,9 @@ void ReFinalize::visitFunction(Function* curr) { void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitGlobal(Global* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitTable(Table* curr) { WASM_UNREACHABLE("unimp"); } +void ReFinalize::visitElementSegment(ElementSegment* curr) { + WASM_UNREACHABLE("unimp"); +} void ReFinalize::visitMemory(Memory* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitEvent(Event* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitModule(Module* curr) { WASM_UNREACHABLE("unimp"); } diff --git a/src/ir/module-splitting.cpp b/src/ir/module-splitting.cpp index 7215629e4..d4d03991e 100644 --- a/src/ir/module-splitting.cpp +++ b/src/ir/module-splitting.cpp @@ -87,20 +87,18 @@ namespace ModuleSplitting { namespace { template<class F> void forEachElement(Module& module, F f) { - for (auto& table : module.tables) { - for (auto& segment : table->segments) { - Name base = ""; - Index offset = 0; - if (auto* c = segment.offset->dynCast<Const>()) { - offset = c->value.geti32(); - } else if (auto* g = segment.offset->dynCast<GlobalGet>()) { - base = g->name; - } - for (size_t i = 0; i < segment.data.size(); ++i) { - f(table->name, base, offset + i, segment.data[i]); - } + ModuleUtils::iterActiveElementSegments(module, [&](ElementSegment* segment) { + Name base = ""; + Index offset = 0; + if (auto* c = segment->offset->dynCast<Const>()) { + offset = c->value.geti32(); + } else if (auto* g = segment->offset->dynCast<GlobalGet>()) { + base = g->name; } - } + for (size_t i = 0; i < segment->data.size(); ++i) { + f(segment->table, base, offset + i, segment->data[i]); + } + }); } struct TableSlotManager { @@ -116,9 +114,10 @@ struct TableSlotManager { }; Module& module; Table* activeTable = nullptr; - Table::Segment* activeSegment = nullptr; + ElementSegment* activeSegment = nullptr; Slot activeBase; std::map<Name, Slot> funcIndices; + std::vector<ElementSegment*> activeTableSegments; TableSlotManager(Module& module); @@ -148,22 +147,29 @@ void TableSlotManager::addSlot(Name func, Slot slot) { } TableSlotManager::TableSlotManager(Module& module) : module(module) { + // TODO: Reject or handle passive element segments + if (module.tables.empty()) { return; } activeTable = module.tables.front().get(); + ModuleUtils::iterTableSegments( + module, activeTable->name, [&](ElementSegment* segment) { + activeTableSegments.push_back(segment); + }); + // If there is exactly one table segment and that segment has a non-constant // offset, append new items to the end of that segment. In all other cases, // append new items at constant offsets after all existing items at constant // offsets. - if (activeTable->segments.size() == 1 && - !activeTable->segments[0].offset->is<Const>()) { - assert(activeTable->segments[0].offset->is<GlobalGet>() && + if (activeTableSegments.size() == 1 && + !activeTableSegments[0]->offset->is<Const>()) { + assert(activeTableSegments[0]->offset->is<GlobalGet>() && "Unexpected initializer instruction"); - activeSegment = &activeTable->segments[0]; + activeSegment = activeTableSegments[0]; activeBase = {activeTable->name, - activeTable->segments[0].offset->cast<GlobalGet>()->name, + activeTableSegments[0]->offset->cast<GlobalGet>()->name, 0}; } else { // Finds the segment with the highest occupied table slot so that new items @@ -171,13 +177,13 @@ TableSlotManager::TableSlotManager(Module& module) : module(module) { // overwriting any other items. TODO: be more clever about filling gaps in // the table, if that is ever useful. Index maxIndex = 0; - for (auto& segment : activeTable->segments) { - assert(segment.offset->is<Const>() && + for (auto& segment : activeTableSegments) { + assert(segment->offset->is<Const>() && "Unexpected non-const segment offset with multiple segments"); - Index segmentBase = segment.offset->cast<Const>()->value.geti32(); - if (segmentBase + segment.data.size() >= maxIndex) { - maxIndex = segmentBase + segment.data.size(); - activeSegment = &segment; + Index segmentBase = segment->offset->cast<Const>()->value.geti32(); + if (segmentBase + segment->data.size() >= maxIndex) { + maxIndex = segmentBase + segment->data.size(); + activeSegment = segment; activeBase = {activeTable->name, "", segmentBase}; } } @@ -190,9 +196,7 @@ TableSlotManager::TableSlotManager(Module& module) : module(module) { } Table* TableSlotManager::makeTable() { - module.addTable(Builder::makeTable(Name::fromInt(0))); - - return module.tables.front().get(); + return module.addTable(Builder::makeTable(Name::fromInt(0))); } TableSlotManager::Slot TableSlotManager::getSlot(Name func) { @@ -208,9 +212,16 @@ TableSlotManager::Slot TableSlotManager::getSlot(Name func) { activeBase = {activeTable->name, "", 0}; } - assert(activeTable->segments.size() == 0); - activeTable->segments.emplace_back(Builder(module).makeConst(int32_t(0))); - activeSegment = &activeTable->segments.back(); + assert(std::all_of(module.elementSegments.begin(), + module.elementSegments.end(), + [&](std::unique_ptr<ElementSegment>& segment) { + return segment->table != activeTable->name; + })); + auto segment = std::make_unique<ElementSegment>( + activeTable->name, Builder(module).makeConst(int32_t(0))); + segment->setName(Name::fromInt(0), false); + activeSegment = segment.get(); + module.addElementSegment(std::move(segment)); } Slot newSlot = {activeBase.tableName, @@ -470,13 +481,12 @@ void ModuleSplitter::setupTablePatching() { } auto secondaryTable = - ModuleUtils::copyTableWithoutSegments(tableManager.activeTable, secondary); + ModuleUtils::copyTable(tableManager.activeTable, secondary); if (tableManager.activeBase.global.size()) { - assert(tableManager.activeTable->segments.size() == 1 && + assert(tableManager.activeTableSegments.size() == 1 && "Unexpected number of segments with non-const base"); - assert(secondary.tables.size() == 1 && - secondary.tables.front()->segments.empty()); + assert(secondary.tables.size() == 1 && secondary.elementSegments.empty()); // Since addition is not currently allowed in initializer expressions, we // need to start the new secondary segment where the primary segment starts. // The secondary segment will contain the same primary functions as the @@ -485,29 +495,31 @@ void ModuleSplitter::setupTablePatching() { // to be imported into the second module. TODO: use better strategies here, // such as using ref.func in the start function or standardizing addition in // initializer expressions. - const Table::Segment& primarySeg = - tableManager.activeTable->segments.front(); + const ElementSegment* primarySeg = tableManager.activeTableSegments.front(); std::vector<Name> secondaryElems; - secondaryElems.reserve(primarySeg.data.size()); + secondaryElems.reserve(primarySeg->data.size()); // Copy functions from the primary segment to the secondary segment, // replacing placeholders and creating new exports and imports as necessary. auto replacement = replacedElems.begin(); for (Index i = 0; - i < primarySeg.data.size() && replacement != replacedElems.end(); + i < primarySeg->data.size() && replacement != replacedElems.end(); ++i) { if (replacement->first == i) { - // primarySeg.data[i] is a placeholder, so use the secondary function. + // primarySeg->data[i] is a placeholder, so use the secondary function. secondaryElems.push_back(replacement->second); ++replacement; } else { - exportImportFunction(primarySeg.data[i]); - secondaryElems.push_back(primarySeg.data[i]); + exportImportFunction(primarySeg->data[i]); + secondaryElems.push_back(primarySeg->data[i]); } } - auto offset = ExpressionManipulator::copy(primarySeg.offset, secondary); - secondaryTable->segments.emplace_back(offset, secondaryElems); + auto offset = ExpressionManipulator::copy(primarySeg->offset, secondary); + auto secondaryElem = std::make_unique<ElementSegment>( + secondaryTable->name, offset, secondaryElems); + secondaryElem->setName(primarySeg->name, primarySeg->hasExplicitName); + secondary.addElementSegment(std::move(secondaryElem)); return; } @@ -517,7 +529,11 @@ void ModuleSplitter::setupTablePatching() { std::vector<Name> currData; auto finishSegment = [&]() { auto* offset = Builder(secondary).makeConst(int32_t(currBase)); - secondaryTable->segments.emplace_back(offset, currData); + auto secondaryElem = + std::make_unique<ElementSegment>(secondaryTable->name, offset, currData); + secondaryElem->setName(Name::fromInt(secondary.elementSegments.size()), + false); + secondary.addElementSegment(std::move(secondaryElem)); }; for (auto curr = replacedElems.begin(); curr != replacedElems.end(); ++curr) { if (curr->first != currBase + currData.size()) { @@ -577,8 +593,7 @@ void ModuleSplitter::shareImportableItems() { for (auto& table : primary.tables) { auto secondaryTable = secondary.getTableOrNull(table->name); if (!secondaryTable) { - secondaryTable = - ModuleUtils::copyTableWithoutSegments(table.get(), secondary); + secondaryTable = ModuleUtils::copyTable(table.get(), secondary); } makeImportExport(*table, *secondaryTable, "table", ExternalKind::Table); diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index 371239e19..22ed89b37 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -70,7 +70,25 @@ inline Event* copyEvent(Event* event, Module& out) { return ret; } -inline Table* copyTableWithoutSegments(Table* table, Module& out) { +inline ElementSegment* copyElementSegment(const ElementSegment* segment, + Module& out) { + auto copy = [&](std::unique_ptr<ElementSegment>&& ret) { + ret->name = segment->name; + ret->hasExplicitName = segment->hasExplicitName; + ret->data = segment->data; + + return out.addElementSegment(std::move(ret)); + }; + + if (segment->table.isNull()) { + return copy(std::make_unique<ElementSegment>()); + } else { + auto offset = ExpressionManipulator::copy(segment->offset, out); + return copy(std::make_unique<ElementSegment>(segment->table, offset)); + } +} + +inline Table* copyTable(Table* table, Module& out) { auto ret = std::make_unique<Table>(); ret->name = table->name; ret->module = table->module; @@ -82,17 +100,6 @@ inline Table* copyTableWithoutSegments(Table* table, Module& out) { return out.addTable(std::move(ret)); } -inline Table* copyTable(Table* table, Module& out) { - auto ret = copyTableWithoutSegments(table, out); - - for (auto segment : table->segments) { - segment.offset = ExpressionManipulator::copy(segment.offset, out); - ret->segments.push_back(segment); - } - - return ret; -} - inline void copyModule(const Module& in, Module& out) { // we use names throughout, not raw pointers, so simple copying is fine // for everything *but* expressions @@ -108,9 +115,13 @@ inline void copyModule(const Module& in, Module& out) { for (auto& curr : in.events) { copyEvent(curr.get(), out); } + for (auto& curr : in.elementSegments) { + copyElementSegment(curr.get(), out); + } for (auto& curr : in.tables) { copyTable(curr.get(), out); } + out.memory = in.memory; for (auto& segment : out.memory.segments) { segment.offset = ExpressionManipulator::copy(segment.offset, out); @@ -148,11 +159,9 @@ template<typename T> inline void renameFunctions(Module& wasm, T& map) { } }; maybeUpdate(wasm.start); - for (auto& table : wasm.tables) { - for (auto& segment : table->segments) { - for (auto& name : segment.data) { - maybeUpdate(name); - } + for (auto& segment : wasm.elementSegments) { + for (auto& name : segment->data) { + maybeUpdate(name); } } for (auto& exp : wasm.exports) { @@ -208,6 +217,28 @@ template<typename T> inline void iterDefinedTables(Module& wasm, T visitor) { } } +template<typename T> +inline void iterTableSegments(Module& wasm, Name table, T visitor) { + // Just a precaution so that we don't iterate over passive elem segments by + // accident + assert(table.is() && "Table name must not be null"); + + for (auto& segment : wasm.elementSegments) { + if (segment->table == table) { + visitor(segment.get()); + } + } +} + +template<typename T> +inline void iterActiveElementSegments(Module& wasm, T visitor) { + for (auto& segment : wasm.elementSegments) { + if (segment->table.is()) { + visitor(segment.get()); + } + } +} + template<typename T> inline void iterImportedGlobals(Module& wasm, T visitor) { for (auto& import : wasm.globals) { if (import->imported()) { diff --git a/src/ir/table-utils.cpp b/src/ir/table-utils.cpp index ef89e50f3..639f8fbe6 100644 --- a/src/ir/table-utils.cpp +++ b/src/ir/table-utils.cpp @@ -31,11 +31,9 @@ std::set<Name> getFunctionsNeedingElemDeclare(Module& wasm) { // Find all the names in the tables. std::unordered_set<Name> tableNames; - for (auto& table : wasm.tables) { - for (auto& segment : table->segments) { - for (auto name : segment.data) { - tableNames.insert(name); - } + for (auto& segment : wasm.elementSegments) { + for (auto name : segment->data) { + tableNames.insert(name); } } diff --git a/src/ir/table-utils.h b/src/ir/table-utils.h index 80ffc0c06..e90b0ca72 100644 --- a/src/ir/table-utils.h +++ b/src/ir/table-utils.h @@ -18,6 +18,7 @@ #define wasm_ir_table_h #include "ir/literal-utils.h" +#include "ir/module-utils.h" #include "wasm-traversal.h" #include "wasm.h" @@ -29,32 +30,38 @@ struct FlatTable { std::vector<Name> names; bool valid; - FlatTable(Table& table) { + FlatTable(Module& wasm, Table& table) { valid = true; - for (auto& segment : table.segments) { - auto offset = segment.offset; - if (!offset->is<Const>()) { - // TODO: handle some non-constant segments - valid = false; - return; - } - Index start = offset->cast<Const>()->value.geti32(); - Index end = start + segment.data.size(); - if (end > names.size()) { - names.resize(end); - } - for (Index i = 0; i < segment.data.size(); i++) { - names[start + i] = segment.data[i]; - } - } + ModuleUtils::iterTableSegments( + wasm, table.name, [&](ElementSegment* segment) { + auto offset = segment->offset; + if (!offset->is<Const>()) { + // TODO: handle some non-constant segments + valid = false; + return; + } + Index start = offset->cast<Const>()->value.geti32(); + Index end = start + segment->data.size(); + if (end > names.size()) { + names.resize(end); + } + for (Index i = 0; i < segment->data.size(); i++) { + names[start + i] = segment->data[i]; + } + }); } }; -inline Table::Segment& getSingletonSegment(Table& table, Module& wasm) { - if (table.segments.size() != 1) { +inline ElementSegment* getSingletonSegment(Table& table, Module& wasm) { + std::vector<ElementSegment*> tableSegments; + ModuleUtils::iterTableSegments( + wasm, table.name, [&](ElementSegment* segment) { + tableSegments.push_back(segment); + }); + if (tableSegments.size() != 1) { Fatal() << "Table doesn't have a singleton segment."; } - return table.segments[0]; + return tableSegments[0]; } // Appends a name to the table. This assumes the table has 0 or 1 segments, @@ -65,10 +72,10 @@ inline Table::Segment& getSingletonSegment(Table& table, Module& wasm) { // module has a single table segment, and that the dylink section indicates // we can validly append to that segment, see the check below. inline Index append(Table& table, Name name, Module& wasm) { - auto& segment = getSingletonSegment(table, wasm); - auto tableIndex = segment.data.size(); + auto* segment = getSingletonSegment(table, wasm); + auto tableIndex = segment->data.size(); if (wasm.dylinkSection) { - if (segment.data.size() != wasm.dylinkSection->tableSize) { + if (segment->data.size() != wasm.dylinkSection->tableSize) { Fatal() << "Appending to the table in a module with a dylink section " "that has tableSize which indicates it wants to reserve more " "table space than the actual table elements in the module. " @@ -77,7 +84,7 @@ inline Index append(Table& table, Name name, Module& wasm) { } wasm.dylinkSection->tableSize++; } - segment.data.push_back(name); + segment->data.push_back(name); table.initial = table.initial + 1; return tableIndex; } @@ -85,9 +92,9 @@ inline Index append(Table& table, Name name, Module& wasm) { // Checks if a function is already in the table. Returns that index if so, // otherwise appends it. inline Index getOrAppend(Table& table, Name name, Module& wasm) { - auto& segment = getSingletonSegment(table, wasm); - for (Index i = 0; i < segment.data.size(); i++) { - if (segment.data[i] == name) { + auto segment = getSingletonSegment(table, wasm); + for (Index i = 0; i < segment->data.size(); i++) { + if (segment->data[i] == name) { return i; } } diff --git a/src/ir/utils.h b/src/ir/utils.h index 424298bb3..f06a68fff 100644 --- a/src/ir/utils.h +++ b/src/ir/utils.h @@ -121,6 +121,7 @@ struct ReFinalize void visitExport(Export* curr); void visitGlobal(Global* curr); void visitTable(Table* curr); + void visitElementSegment(ElementSegment* curr); void visitMemory(Memory* curr); void visitEvent(Event* curr); void visitModule(Module* curr); @@ -144,6 +145,7 @@ struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { void visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); } void visitGlobal(Global* curr) { WASM_UNREACHABLE("unimp"); } void visitTable(Table* curr) { WASM_UNREACHABLE("unimp"); } + void visitElementSegment(ElementSegment* curr) { WASM_UNREACHABLE("unimp"); } void visitMemory(Memory* curr) { WASM_UNREACHABLE("unimp"); } void visitEvent(Event* curr) { WASM_UNREACHABLE("unimp"); } void visitModule(Module* curr) { WASM_UNREACHABLE("unimp"); } diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index 785c56106..bafbaa480 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -2214,23 +2214,55 @@ function wrapModule(module, self = {}) { self['getGlobal'] = function(name) { return preserveStack(() => Module['_BinaryenGetGlobal'](module, strToStack(name))); }; - self['addTable'] = function(table, initial, maximum, funcNames, offset = self['i32']['const'](0)) { - return preserveStack(() => Module['_BinaryenAddTable'](module, - strToStack(table), initial, maximum, - i32sToStack(funcNames.map(strToStack)), - funcNames.length, - offset) - ); + self['addTable'] = function(table, initial, maximum) { + return preserveStack(() => Module['_BinaryenAddTable'](module, strToStack(table), initial, maximum)); } self['getTable'] = function(name) { return preserveStack(() => Module['_BinaryenGetTable'](module, strToStack(name))); }; + self['addActiveElementSegment'] = function(table, name, funcNames, offset = self['i32']['const'](0)) { + return preserveStack(() => Module['_BinaryenAddActiveElementSegment']( + module, + strToStack(table), + strToStack(name), + i32sToStack(funcNames.map(strToStack)), + funcNames.length, + offset + )); + }; + self['addPassiveElementSegment'] = function(name, funcNames) { + return preserveStack(() => Module['_BinaryenAddPassiveElementSegment']( + module, + strToStack(name), + i32sToStack(funcNames.map(strToStack)), + funcNames.length + )); + }; + self['getElementSegment'] = function(name) { + return preserveStack(() => Module['_BinaryenGetElementSegment'](module, strToStack(name))); + }; + self['getTableSegments'] = function(table) { + var numElementSegments = Module['_BinaryenGetNumElementSegments'](module); + var tableName = UTF8ToString(Module['_BinaryenTableGetName'](table)); + var ret = []; + for (var i = 0; i < numElementSegments; i++) { + var segment = Module['_BinaryenGetElementSegmentByIndex'](module, i); + var elemTableName = UTF8ToString(Module['_BinaryenElementSegmentGetTable'](segment)); + if (tableName === elemTableName) { + ret.push(segment); + } + } + return ret; + } self['removeGlobal'] = function(name) { return preserveStack(() => Module['_BinaryenRemoveGlobal'](module, strToStack(name))); } self['removeTable'] = function(name) { return preserveStack(() => Module['_BinaryenRemoveTable'](module, strToStack(name))); - } + }; + self['removeElementSegment'] = function(name) { + return preserveStack(() => Module['_BinaryenRemoveElementSegment'](module, strToStack(name))); + }; self['addEvent'] = function(name, attribute, params, results) { return preserveStack(() => Module['_BinaryenAddEvent'](module, strToStack(name), attribute, params, results)); }; @@ -2284,37 +2316,6 @@ function wrapModule(module, self = {}) { self['removeExport'] = function(externalName) { return preserveStack(() => Module['_BinaryenRemoveExport'](module, strToStack(externalName))); }; - self['setFunctionTable'] = function(initial, maximum, funcNames, offset = self['i32']['const'](0)) { - return preserveStack(() => { - return Module['_BinaryenSetFunctionTable'](module, initial, maximum, - i32sToStack(funcNames.map(strToStack)), - funcNames.length, - offset - ); - }); - }; - self['getFunctionTable'] = function() { - return { - 'imported': Boolean(Module['_BinaryenIsFunctionTableImported'](module)), - 'segments': (function() { - const numSegments = Module['_BinaryenGetNumFunctionTableSegments'](module) - const arr = new Array(numSegments); - for (let i = 0; i !== numSegments; ++i) { - const segmentLength = Module['_BinaryenGetFunctionTableSegmentLength'](module, i); - const names = new Array(segmentLength); - for (let j = 0; j !== segmentLength; ++j) { - const ptr = Module['_BinaryenGetFunctionTableSegmentData'](module, i, j); - names[j] = UTF8ToString(ptr); - } - arr[i] = { - 'offset': Module['_BinaryenGetFunctionTableSegmentOffset'](module, i), - 'names': names - }; - } - return arr; - })() - }; - }; self['setMemory'] = function(initial, maximum, exportName, segments = [], shared = false) { // segments are assumed to be { passive: bool, offset: expression ref, data: array of 8-bit data } return preserveStack(() => { @@ -2394,12 +2395,18 @@ function wrapModule(module, self = {}) { self['getNumTables'] = function() { return Module['_BinaryenGetNumTables'](module); }; + self['getNumElementSegments'] = function() { + return Module['_BinaryenGetNumElementSegments'](module); + }; self['getGlobalByIndex'] = function(index) { return Module['_BinaryenGetGlobalByIndex'](module, index); }; self['getTableByIndex'] = function(index) { return Module['_BinaryenGetTableByIndex'](module, index); }; + self['getElementSegmentByIndex'] = function(index) { + return Module['_BinaryenGetElementSegmentByIndex'](module, index); + }; self['emitText'] = function() { const old = out; let ret = ''; @@ -3029,8 +3036,8 @@ Module['getTableInfo'] = function(table) { 'name': UTF8ToString(Module['_BinaryenTableGetName'](table)), 'module': UTF8ToString(Module['_BinaryenTableImportGetModule'](table)), 'base': UTF8ToString(Module['_BinaryenTableImportGetBase'](table)), - 'initial': Module['_BinaryenTableGetInitial'](table) - }; + 'initial': Module['_BinaryenTableGetInitial'](table), + } if (hasMax) { tableInfo.max = Module['_BinaryenTableGetMax'](table); @@ -3039,6 +3046,22 @@ Module['getTableInfo'] = function(table) { return tableInfo; }; +Module['getElementSegmentInfo'] = function(segment) { + var segmentLength = Module['_BinaryenElementSegmentGetLength'](segment); + var names = new Array(segmentLength); + for (let j = 0; j !== segmentLength; ++j) { + var ptr = Module['_BinaryenElementSegmentGetData'](segment, j); + names[j] = UTF8ToString(ptr); + } + + return { + 'name': UTF8ToString(Module['_BinaryenElementSegmentGetName'](segment)), + 'table': UTF8ToString(Module['_BinaryenElementSegmentGetTable'](segment)), + 'offset': Module['_BinaryenElementSegmentGetOffset'](segment), + 'data': names + } +} + // Obtains information about a 'Event' Module['getEventInfo'] = function(event_) { return { diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 1a404dbcf..6223c841d 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -284,11 +284,9 @@ struct DAE : public Pass { infoMap[curr->value].hasUnseenCalls = true; } } - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - for (auto name : segment.data) { - infoMap[name].hasUnseenCalls = true; - } + for (auto& segment : module->elementSegments) { + for (auto name : segment->data) { + infoMap[name].hasUnseenCalls = true; } } // Scan all the functions. diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 6ee976b39..3a7ef1024 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -124,7 +124,7 @@ struct Directize : public Pass { } if (canOptimizeCallIndirect) { - TableUtils::FlatTable flatTable(*table); + TableUtils::FlatTable flatTable(*module, *table); if (flatTable.valid) { validTables.emplace(table->name, flatTable); } diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index b91d7ba4f..753c986b8 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -173,17 +173,15 @@ struct FuncCastEmulation : public Pass { Signature ABIType(Type(std::vector<Type>(numParams, Type::i64)), Type::i64); // Add a thunk for each function in the table, and do the call through it. std::unordered_map<Name, Name> funcThunks; - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - for (auto& name : segment.data) { - auto iter = funcThunks.find(name); - if (iter == funcThunks.end()) { - auto thunk = makeThunk(name, module, numParams); - funcThunks[name] = thunk; - name = thunk; - } else { - name = iter->second; - } + for (auto& segment : module->elementSegments) { + for (auto& name : segment->data) { + auto iter = funcThunks.find(name); + if (iter == funcThunks.end()) { + auto thunk = makeThunk(name, module, numParams); + funcThunks[name] = thunk; + name = thunk; + } else { + name = iter->second; } } } diff --git a/src/passes/GenerateDynCalls.cpp b/src/passes/GenerateDynCalls.cpp index 827dccea0..dad992fea 100644 --- a/src/passes/GenerateDynCalls.cpp +++ b/src/passes/GenerateDynCalls.cpp @@ -45,10 +45,20 @@ struct GenerateDynCalls : public WalkerPass<PostWalker<GenerateDynCalls>> { void visitTable(Table* table) { // Generate dynCalls for functions in the table - if (table->segments.size() > 0) { + Module* wasm = getModule(); + auto& segments = wasm->elementSegments; + + // Find a single elem segment for the table. We only care about one, since + // wasm-ld emits only one table with a single segment. + auto it = std::find_if(segments.begin(), + segments.end(), + [&](std::unique_ptr<ElementSegment>& segment) { + return segment->table == table->name; + }); + if (it != segments.end()) { std::vector<Name> tableSegmentData; - for (const auto& indirectFunc : table->segments[0].data) { - generateDynCallThunk(getModule()->getFunction(indirectFunc)->sig); + for (const auto& indirectFunc : it->get()->data) { + generateDynCallThunk(wasm->getFunction(indirectFunc)->sig); } } } diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 727e94d72..41481b310 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -348,11 +348,9 @@ struct Inlining : public Pass { infos[ex->value].usedGlobally = true; } } - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - for (auto name : segment.data) { - infos[name].usedGlobally = true; - } + for (auto& segment : module->elementSegments) { + for (auto name : segment->data) { + infos[name].usedGlobally = true; } } diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index e8a604326..0a991caa5 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -97,12 +97,10 @@ struct LegalizeJSInterface : public Pass { // we need to use the legalized version in the tables, as the import // from JS is legal for JS. Our stub makes it look like a native wasm // function. - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - for (auto& name : segment.data) { - if (name == im->name) { - name = funcName; - } + for (auto& segment : module->elementSegments) { + for (auto& name : segment->data) { + if (name == im->name) { + name = funcName; } } } diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index 09dcc5445..231ad928c 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -74,11 +74,13 @@ struct Metrics } Index size = 0; + ModuleUtils::iterActiveElementSegments( + *module, [&](ElementSegment* segment) { size += segment->data.size(); }); for (auto& table : module->tables) { walkTable(table.get()); - for (auto& segment : table->segments) { - size += segment.data.size(); - } + } + for (auto& segment : module->elementSegments) { + walkElementSegment(segment.get()); } if (!module->tables.empty()) { counts["[table-data]"] = size; diff --git a/src/passes/PostEmscripten.cpp b/src/passes/PostEmscripten.cpp index 35e2b9a8b..403fb998b 100644 --- a/src/passes/PostEmscripten.cpp +++ b/src/passes/PostEmscripten.cpp @@ -67,7 +67,7 @@ struct PostEmscripten : public Pass { // Next, see if the Table is flat, which we need in order to see where // invokes go statically. (In dynamic linking, the table is not flat, // and we can't do this.) - TableUtils::FlatTable flatTable(*module->tables[0]); + TableUtils::FlatTable flatTable(*module, *module->tables[0]); if (!flatTable.valid) { return; } diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 993eddb99..05a639158 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2665,36 +2665,57 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { printTableHeader(curr); o << maybeNewLine; } - for (auto& segment : curr->segments) { - // Don't print empty segments - if (segment.data.empty()) { - continue; - } - doIndent(o, indent); - o << '('; - printMedium(o, "elem "); + ModuleUtils::iterTableSegments( + *currModule, curr->name, [&](ElementSegment* segment) { + printElementSegment(segment); + }); + } + void visitElementSegment(ElementSegment* curr) { + if (curr->table.is()) { + return; + } + printElementSegment(curr); + } + void printElementSegment(ElementSegment* curr) { + // Don't print empty segments + if (curr->data.empty()) { + return; + } + doIndent(o, indent); + o << '('; + printMedium(o, "elem"); + if (curr->hasExplicitName) { + o << ' '; + printName(curr->name, o); + } + + if (curr->table.is()) { // TODO(reference-types): check for old-style based on the complete spec if (currModule->tables.size() > 1) { // tableuse - o << '('; - printMedium(o, "table "); - printName(curr->name, o); - o << ") "; + o << " (table "; + printName(curr->table, o); + o << ")"; } - visit(segment.offset); + o << ' '; + visit(curr->offset); if (currModule->tables.size() > 1) { - o << " func"; - } - - for (auto name : segment.data) { o << ' '; - printName(name, o); + TypeNamePrinter(o, currModule).print(HeapType::func); } - o << ')' << maybeNewLine; + } else { + o << ' '; + TypeNamePrinter(o, currModule).print(HeapType::func); } + + for (auto name : curr->data) { + o << ' '; + printName(name, o); + } + o << ')' << maybeNewLine; } void printMemoryHeader(Memory* curr) { o << '('; @@ -2833,6 +2854,9 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { *curr, [&](Memory* memory) { visitMemory(memory); }); ModuleUtils::iterDefinedTables(*curr, [&](Table* table) { visitTable(table); }); + for (auto& segment : curr->elementSegments) { + visitElementSegment(segment.get()); + } auto elemDeclareNames = TableUtils::getFunctionsNeedingElemDeclare(*curr); if (!elemDeclareNames.empty()) { doIndent(o, indent); diff --git a/src/passes/PrintCallGraph.cpp b/src/passes/PrintCallGraph.cpp index 49e056312..d70e0d7f2 100644 --- a/src/passes/PrintCallGraph.cpp +++ b/src/passes/PrintCallGraph.cpp @@ -96,12 +96,10 @@ struct PrintCallGraph : public Pass { CallPrinter printer(module); // Indirect Targets - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - for (auto& curr : segment.data) { - auto* func = module->getFunction(curr); - o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n"; - } + for (auto& segment : module->elementSegments) { + for (auto& curr : segment->data) { + auto* func = module->getFunction(curr); + o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n"; } } diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 10a885586..be41717f4 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -49,11 +49,9 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> { *curr, [&](Function* func) { names.push_back(func->name); }); // Do not remove names referenced in a table std::set<Name> indirectNames; - for (auto& table : curr->tables) { - for (auto& segment : table->segments) { - for (auto& name : segment.data) { - indirectNames.insert(name); - } + for (auto& segment : curr->elementSegments) { + for (auto& name : segment->data) { + indirectNames.insert(name); } } for (auto& name : names) { diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index d12fdb487..fa295dac5 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -29,7 +29,7 @@ namespace wasm { -enum class ModuleElementKind { Function, Global, Event, Table }; +enum class ModuleElementKind { Function, Global, Event, Table, ElementSegment }; typedef std::pair<ModuleElementKind, Name> ModuleElement; @@ -51,9 +51,9 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { walk(segment.offset); } } - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - walk(segment.offset); + for (auto& segment : module->elementSegments) { + if (segment->table.is()) { + walk(segment->offset); } } @@ -76,10 +76,10 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { walk(global->init); } } else if (curr.first == ModuleElementKind::Table) { - auto* table = module->getTable(curr.second); - for (auto& segment : table->segments) { - walk(segment.offset); - } + ModuleUtils::iterTableSegments( + *module, curr.second, [&](ElementSegment* segment) { + walk(segment->offset); + }); } } } @@ -96,8 +96,12 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { } void visitCallIndirect(CallIndirect* curr) { assert(!module->tables.empty() && "call-indirect to undefined table."); - maybeAdd(ModuleElement(ModuleElementKind::Table, curr->table)); + ModuleUtils::iterTableSegments( + *module, curr->table, [&](ElementSegment* segment) { + maybeAdd( + ModuleElement(ModuleElementKind::ElementSegment, segment->name)); + }); } void visitGlobalGet(GlobalGet* curr) { @@ -157,6 +161,13 @@ struct RemoveUnusedModuleElements : public Pass { roots.emplace_back(ModuleElementKind::Function, func->name); }); } + ModuleUtils::iterActiveElementSegments( + *module, [&](ElementSegment* segment) { + auto table = module->getTable(segment->table); + if (table->imported() && !segment->data.empty()) { + roots.emplace_back(ModuleElementKind::ElementSegment, segment->name); + } + }); // Exports are roots. bool exportsMemory = false; for (auto& curr : module->exports) { @@ -168,6 +179,11 @@ struct RemoveUnusedModuleElements : public Pass { roots.emplace_back(ModuleElementKind::Event, curr->value); } else if (curr->kind == ExternalKind::Table) { roots.emplace_back(ModuleElementKind::Table, curr->value); + ModuleUtils::iterTableSegments( + *module, curr->value, [&](ElementSegment* segment) { + roots.emplace_back(ModuleElementKind::ElementSegment, + segment->name); + }); } else if (curr->kind == ExternalKind::Memory) { exportsMemory = true; } @@ -178,12 +194,9 @@ struct RemoveUnusedModuleElements : public Pass { importsMemory = true; } // For now, all functions that can be called indirectly are marked as roots. - for (auto& table : module->tables) { - // TODO(reference-types): Check whether table's datatype is funcref. - for (auto& segment : table->segments) { - for (auto& curr : segment.data) { - roots.emplace_back(ModuleElementKind::Function, curr); - } + for (auto& segment : module->elementSegments) { + for (auto& curr : segment->data) { + roots.emplace_back(ModuleElementKind::Function, curr); } } // Compute reachability starting from the root set. @@ -201,20 +214,24 @@ struct RemoveUnusedModuleElements : public Pass { return analyzer.reachable.count( ModuleElement(ModuleElementKind::Event, curr->name)) == 0; }); - - for (auto& table : module->tables) { - table->segments.erase( - std::remove_if(table->segments.begin(), - table->segments.end(), - [&](auto& seg) { return seg.data.empty(); }), - table->segments.end()); - } + module->removeElementSegments([&](ElementSegment* curr) { + return curr->data.empty() || + analyzer.reachable.count(ModuleElement( + ModuleElementKind::ElementSegment, curr->name)) == 0; + }); + // Since we've removed all empty element segments, here we mark all tables + // that have a segment left. + std::unordered_set<Name> nonemptyTables; + ModuleUtils::iterActiveElementSegments( + *module, + [&](ElementSegment* segment) { nonemptyTables.insert(segment->table); }); module->removeTables([&](Table* curr) { - return (curr->segments.empty() || !curr->imported()) && + return (nonemptyTables.count(curr->name) == 0 || !curr->imported()) && analyzer.reachable.count( ModuleElement(ModuleElementKind::Table, curr->name)) == 0; }); - // Handle the memory and table + + // Handle the memory if (!exportsMemory && !analyzer.usesMemory) { if (!importsMemory) { // The memory is unobservable to the outside, we can remove the diff --git a/src/passes/ReorderFunctions.cpp b/src/passes/ReorderFunctions.cpp index 4d02616f0..0c95101a5 100644 --- a/src/passes/ReorderFunctions.cpp +++ b/src/passes/ReorderFunctions.cpp @@ -70,11 +70,9 @@ struct ReorderFunctions : public Pass { for (auto& curr : module->exports) { counts[curr->value]++; } - for (auto& table : module->tables) { - for (auto& segment : table->segments) { - for (auto& curr : segment.data) { - counts[curr]++; - } + for (auto& segment : module->elementSegments) { + for (auto& curr : segment->data) { + counts[curr]++; } } // sort diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h index 0a67a7e1e..27a3c5c25 100644 --- a/src/passes/opt-utils.h +++ b/src/passes/opt-utils.h @@ -86,13 +86,12 @@ inline void replaceFunctions(PassRunner* runner, // replace direct calls FunctionRefReplacer(maybeReplace).run(runner, &module); // replace in table - for (auto& table : module.tables) { - for (auto& segment : table->segments) { - for (auto& name : segment.data) { - maybeReplace(name); - } + for (auto& segment : module.elementSegments) { + for (auto& name : segment->data) { + maybeReplace(name); } } + // replace in start if (module.start.is()) { maybeReplace(module.start); diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index 1cc3635dc..288fdc23c 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -195,7 +195,7 @@ public: if (allowMemory) { setupMemory(); } - setupTable(); + setupTables(); setupGlobals(); if (wasm.features.hasExceptionHandling()) { setupEvents(); @@ -424,16 +424,23 @@ private: } // TODO(reference-types): allow the fuzzer to create multiple tables - void setupTable() { + void setupTables() { if (wasm.tables.size() > 0) { auto& table = wasm.tables[0]; table->initial = table->max = 0; - table->segments.emplace_back(builder.makeConst(int32_t(0))); + + auto segment = std::make_unique<ElementSegment>( + table->name, builder.makeConst(int32_t(0))); + segment->setName(Name::fromInt(0), false); + wasm.addElementSegment(std::move(segment)); } else { auto table = builder.makeTable( Names::getValidTableName(wasm, "fuzzing_table"), 0, 0); table->hasExplicitName = true; - table->segments.emplace_back(builder.makeConst(int32_t(0))); + auto segment = std::make_unique<ElementSegment>( + table->name, builder.makeConst(int32_t(0))); + segment->setName(Name::fromInt(0), false); + wasm.addElementSegment(std::move(segment)); wasm.addTable(std::move(table)); } } @@ -532,22 +539,23 @@ private: void finalizeTable() { for (auto& table : wasm.tables) { - for (auto& segment : table->segments) { - // If the offset is a global that was imported (which is ok) but no - // longer is (not ok) we need to change that. - if (auto* offset = segment.offset->dynCast<GlobalGet>()) { - if (!wasm.getGlobal(offset->name)->imported()) { - // TODO: the segments must not overlap... - segment.offset = - builder.makeConst(Literal::makeFromInt32(0, Type::i32)); + ModuleUtils::iterTableSegments( + wasm, table->name, [&](ElementSegment* segment) { + // If the offset is a global that was imported (which is ok) but no + // longer is (not ok) we need to change that. + if (auto* offset = segment->offset->dynCast<GlobalGet>()) { + if (!wasm.getGlobal(offset->name)->imported()) { + // TODO: the segments must not overlap... + segment->offset = + builder.makeConst(Literal::makeFromInt32(0, Type::i32)); + } } - } - Address maxOffset = segment.data.size(); - if (auto* offset = segment.offset->dynCast<Const>()) { - maxOffset = maxOffset + offset->value.getInteger(); - } - table->initial = std::max(table->initial, maxOffset); - } + Address maxOffset = segment->data.size(); + if (auto* offset = segment->offset->dynCast<Const>()) { + maxOffset = maxOffset + offset->value.getInteger(); + } + table->initial = std::max(table->initial, maxOffset); + }); table->max = oneIn(2) ? Address(Table::kUnlimitedSize) : table->initial; // Avoid an imported table (which the fuzz harness would need to handle). table->module = table->base = Name(); @@ -713,9 +721,11 @@ private: export_->kind = ExternalKind::Function; wasm.addExport(export_); } - // add some to the table + // add some to an elem segment while (oneIn(3) && !finishedInput) { - wasm.tables[0]->segments[0].data.push_back(func->name); + auto& randomElem = + wasm.elementSegments[upTo(wasm.elementSegments.size())]; + randomElem->data.push_back(func->name); } numAddedFunctions++; return func; @@ -1435,7 +1445,8 @@ private: } Expression* makeCallIndirect(Type type) { - auto& data = wasm.tables[0]->segments[0].data; + auto& randomElem = wasm.elementSegments[upTo(wasm.elementSegments.size())]; + auto& data = randomElem->data; if (data.empty()) { return make(type); } diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index 9e3915073..f39d8a34e 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -231,23 +231,27 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { // we assume the table is not modified (hmm) // look through the segments, try to find the function - for (auto& segment : table->segments) { + for (auto& segment : wasm->elementSegments) { + if (segment->table != tableName) { + continue; + } + Index start; // look for the index in this segment. if it has a constant offset, we // look in the proper range. if it instead gets a global, we rely on the // fact that when not dynamically linking then the table is loaded at // offset 0. - if (auto* c = segment.offset->dynCast<Const>()) { + if (auto* c = segment->offset->dynCast<Const>()) { start = c->value.getInteger(); - } else if (segment.offset->is<GlobalGet>()) { + } else if (segment->offset->is<GlobalGet>()) { start = 0; } else { // wasm spec only allows const and global.get there WASM_UNREACHABLE("invalid expr type"); } - auto end = start + segment.data.size(); + auto end = start + segment->data.size(); if (start <= index && index < end) { - auto name = segment.data[index - start]; + auto name = segment->data[index - start]; // if this is one of our functions, we can call it; if it was imported, // fail auto* func = wasm->getFunction(name); diff --git a/src/tools/wasm-metadce.cpp b/src/tools/wasm-metadce.cpp index 46b66c986..71ea2f693 100644 --- a/src/tools/wasm-metadce.cpp +++ b/src/tools/wasm-metadce.cpp @@ -213,20 +213,18 @@ struct MetaDCEGraph { // we can't remove segments, so root what they need InitScanner rooter(this, Name()); rooter.setModule(&wasm); - for (auto& table : wasm.tables) { - for (auto& segment : table->segments) { - // TODO: currently, all functions in the table are roots, but we - // should add an option to refine that - for (auto& name : segment.data) { - if (!wasm.getFunction(name)->imported()) { - roots.insert(functionToDCENode[name]); - } else { - roots.insert(importIdToDCENode[getFunctionImportId(name)]); - } + ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) { + // TODO: currently, all functions in the table are roots, but we + // should add an option to refine that + for (auto& name : segment->data) { + if (!wasm.getFunction(name)->imported()) { + roots.insert(functionToDCENode[name]); + } else { + roots.insert(importIdToDCENode[getFunctionImportId(name)]); } - rooter.walk(segment.offset); } - } + rooter.walk(segment->offset); + }); for (auto& segment : wasm.memory.segments) { if (!segment.isPassive) { rooter.walk(segment.offset); diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 4bb4b6f9a..5cd58301e 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -741,47 +741,39 @@ struct Reducer // TODO: bisection on segment shrinking? - void visitTable(Table* curr) { - std::cerr << "| try to simplify table\n"; - Name first; - for (auto& segment : curr->segments) { - for (auto item : segment.data) { - first = item; - break; - } - if (!first.isNull()) { - break; - } - } - visitSegmented(curr, first, 100); - } - void visitMemory(Memory* curr) { std::cerr << "| try to simplify memory\n"; - visitSegmented(curr, 0, 2); + + // try to reduce to first function. first, shrink segment elements. + // while we are shrinking successfully, keep going exponentially. + bool shrank = false; + for (auto& segment : curr->segments) { + shrank = shrinkByReduction(&segment, 2); + } + // the "opposite" of shrinking: copy a 'zero' element + for (auto& segment : curr->segments) { + reduceByZeroing(&segment, 0, 2, shrank); + } } - template<typename T, typename U> - void visitSegmented(T* curr, U zero, size_t bonus) { + template<typename T> bool shrinkByReduction(T* segment, size_t bonus) { // try to reduce to first function. first, shrink segment elements. // while we are shrinking successfully, keep going exponentially. bool justShrank = false; bool shrank = false; - for (auto& segment : curr->segments) { - auto& data = segment.data; - // when we succeed, try to shrink by more and more, similar to bisection - size_t skip = 1; - for (size_t i = 0; i < data.size() && !data.empty(); i++) { - if (!justShrank && !shouldTryToReduce(bonus)) { - continue; - } + + auto& data = segment->data; + // when we succeed, try to shrink by more and more, similar to bisection + size_t skip = 1; + for (size_t i = 0; i < data.size() && !data.empty(); i++) { + if (justShrank || shouldTryToReduce(bonus)) { auto save = data; for (size_t j = 0; j < skip; j++) { if (!data.empty()) { data.pop_back(); } } - auto justShrank = writeAndTestReduction(); + justShrank = writeAndTestReduction(); if (justShrank) { std::cerr << "| shrank segment (skip: " << skip << ")\n"; shrank = true; @@ -793,37 +785,67 @@ struct Reducer } } } - // the "opposite" of shrinking: copy a 'zero' element - for (auto& segment : curr->segments) { - if (segment.data.empty()) { + + return shrank; + } + + template<typename T, typename U> + void reduceByZeroing(T* segment, U zero, size_t bonus, bool shrank) { + if (segment->data.empty()) { + return; + } + for (auto& item : segment->data) { + if (!shouldTryToReduce(bonus)) { continue; } - for (auto& item : segment.data) { - if (!shouldTryToReduce(bonus)) { - continue; - } - if (item == zero) { - continue; - } - auto save = item; - item = zero; - if (writeAndTestReduction()) { - std::cerr << "| zeroed segment\n"; - noteReduction(); - } else { - item = save; - } - if (shrank) { - // zeroing is fairly inefficient. if we are managing to shrink - // (which we do exponentially), just zero one per segment at most - break; - } + if (item == zero) { + continue; + } + auto save = item; + item = zero; + if (writeAndTestReduction()) { + std::cerr << "| zeroed elem segment\n"; + noteReduction(); + } else { + item = save; + } + if (shrank) { + // zeroing is fairly inefficient. if we are managing to shrink + // (which we do exponentially), just zero one per segment at most + break; } } } + void shrinkElementSegments(Module* module) { + std::cerr << "| try to simplify elem segments\n"; + Name first; + auto it = + std::find_if_not(module->elementSegments.begin(), + module->elementSegments.end(), + [&](auto& segment) { return segment->data.empty(); }); + + if (it != module->elementSegments.end()) { + first = it->get()->data[0]; + } + + // try to reduce to first function. first, shrink segment elements. + // while we are shrinking successfully, keep going exponentially. + bool shrank = false; + for (auto& segment : module->elementSegments) { + shrank = shrinkByReduction(segment.get(), 100); + } + // the "opposite" of shrinking: copy a 'zero' element + for (auto& segment : module->elementSegments) { + reduceByZeroing(segment.get(), first, 100, shrank); + } + } + void visitModule(Module* curr) { assert(curr == module.get()); + + shrinkElementSegments(curr); + // try to remove functions std::cerr << "| try to remove functions\n"; std::vector<Name> functionNames; @@ -898,12 +920,11 @@ struct Reducer } // If we are left with a single function that is not exported or used in // a table, that is useful as then we can change the return type. - bool allTablesEmpty = std::all_of( - module->tables.begin(), module->tables.end(), [&](auto& table) { - return std::all_of(table->segments.begin(), - table->segments.end(), - [&](auto& segment) { return segment.data.empty(); }); - }); + bool allTablesEmpty = + std::all_of(module->elementSegments.begin(), + module->elementSegments.end(), + [&](auto& segment) { return segment->data.empty(); }); + if (module->functions.size() == 1 && module->exports.empty() && allTablesEmpty) { auto* func = module->functions[0].get(); @@ -964,30 +985,6 @@ struct Reducer exportsToRemove.push_back(curr->name); } } - void visitTable(Table* curr) { - Name other; - for (auto& segment : curr->segments) { - for (auto name : segment.data) { - if (!names.count(name)) { - other = name; - break; - } - } - if (!other.isNull()) { - break; - } - } - if (other.isNull()) { - return; // we failed to find a replacement - } - for (auto& segment : curr->segments) { - for (auto& name : segment.data) { - if (names.count(name)) { - name = other; - } - } - } - } void doWalkModule(Module* module) { PostWalker<FunctionReferenceRemover>::doWalkModule(module); for (auto name : exportsToRemove) { diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 09aaa023e..efa2c0ca4 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -166,20 +166,18 @@ run_asserts(Name moduleName, reportUnknownImport(import); } }); - ModuleUtils::iterDefinedTables(wasm, [&](Table* table) { - for (auto& segment : table->segments) { - for (auto name : segment.data) { - // spec tests consider it illegal to use spectest.print in a table - if (auto* import = wasm.getFunction(name)) { - if (import->imported() && import->module == SPECTEST && - import->base.startsWith(PRINT)) { - std::cerr << "cannot put spectest.print in table\n"; - invalid = true; - } + for (auto& segment : wasm.elementSegments) { + for (auto name : segment->data) { + // spec tests consider it illegal to use spectest.print in a table + if (auto* import = wasm.getFunction(name)) { + if (import->imported() && import->module == SPECTEST && + import->base.startsWith(PRINT)) { + std::cerr << "cannot put spectest.print in table\n"; + invalid = true; } } } - }); + } if (wasm.memory.imported()) { reportUnknownImport(&wasm.memory); } diff --git a/src/wasm-binary.h b/src/wasm-binary.h index d490a2a96..8ca07e7d2 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1106,6 +1106,7 @@ class WasmBinaryWriter { std::unordered_map<Name, Index> eventIndexes; std::unordered_map<Name, Index> globalIndexes; std::unordered_map<Name, Index> tableIndexes; + std::unordered_map<Name, Index> elemIndexes; BinaryIndexes(Module& wasm) { auto addIndexes = [&](auto& source, auto& indexes) { @@ -1128,6 +1129,11 @@ class WasmBinaryWriter { addIndexes(wasm.events, eventIndexes); addIndexes(wasm.tables, tableIndexes); + for (auto& curr : wasm.elementSegments) { + auto index = elemIndexes.size(); + elemIndexes[curr->name] = index; + } + // Globals may have tuple types in the IR, in which case they lower to // multiple globals, one for each tuple element, in the binary. Tuple // globals therefore occupy multiple binary indices, and we have to take @@ -1205,7 +1211,7 @@ public: uint32_t getTypeIndex(HeapType type) const; void writeTableDeclarations(); - void writeTableElements(); + void writeElementSegments(); void writeNames(); void writeSourceMapUrl(); void writeSymbolMap(); @@ -1398,6 +1404,12 @@ public: // at index i we have all references to the table i std::map<Index, std::vector<Expression*>> tableRefs; + std::map<Index, Name> elemTables; + + // we store elems here after being read from binary, until when we know their + // names + std::vector<std::unique_ptr<ElementSegment>> elementSegments; + // we store globals here before wasm.addGlobal after we know their names std::vector<std::unique_ptr<Global>> globals; // we store global imports here before wasm.addGlobalImport after we know @@ -1504,11 +1516,11 @@ public: void readDataSegments(); void readDataCount(); - // A map from table indexes to the map of segment indexes to their elements - std::map<Index, std::map<Index, std::vector<Index>>> functionTable; + // A map from elem segment indexes to their entries + std::map<Index, std::vector<Index>> functionTable; - void readFunctionTableDeclaration(); - void readTableElements(); + void readTableDeclarations(); + void readElementSegments(); void readEvents(); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 5032b7277..ac095e0c0 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -2309,20 +2309,19 @@ private: std::unordered_set<size_t> droppedSegments; void initializeTableContents() { - for (auto& table : wasm.tables) { - for (auto& segment : table->segments) { - Address offset = (uint32_t)InitializerExpressionRunner<GlobalManager>( - globals, maxDepth) - .visit(segment.offset) - .getSingleValue() - .geti32(); - if (offset + segment.data.size() > table->initial) { - externalInterface->trap("invalid offset when initializing table"); - } - for (size_t i = 0; i != segment.data.size(); ++i) { - externalInterface->tableStore( - table->name, offset + i, segment.data[i]); - } + for (auto& segment : wasm.elementSegments) { + if (segment->table.isNull()) { + continue; + } + + Address offset = + (uint32_t)InitializerExpressionRunner<GlobalManager>(globals, maxDepth) + .visit(segment->offset) + .getSingleValue() + .geti32(); + for (size_t i = 0; i != segment->data.size(); ++i) { + externalInterface->tableStore( + segment->table, offset + i, segment->data[i]); } } } diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 6ccc7c36c..eaddd8c8b 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -132,6 +132,7 @@ class SExpressionWasmBuilder { int globalCounter = 0; int eventCounter = 0; int tableCounter = 0; + int elemCounter = 0; int memoryCounter = 0; // we need to know function return types before we parse their contents std::map<Name, Signature> functionSignatures; @@ -313,11 +314,10 @@ private: void parseImport(Element& s); void parseGlobal(Element& s, bool preParseImport = false); void parseTable(Element& s, bool preParseImport = false); - void parseElem(Element& s); - void parseInnerElem(Table* table, - Element& s, - Index i = 1, - Expression* offset = nullptr); + void parseElem(Element& s, Table* table = nullptr); + ElementSegment* parseElemFinish(Element& s, + std::unique_ptr<ElementSegment>& segment, + Index i = 1); // Parses something like (func ..), (array ..), (struct) HeapType parseHeapType(Element& s); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 803fa98aa..7dcf2d146 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -48,6 +48,7 @@ template<typename SubType, typename ReturnType = void> struct Visitor { ReturnType visitGlobal(Global* curr) { return ReturnType(); } ReturnType visitFunction(Function* curr) { return ReturnType(); } ReturnType visitTable(Table* curr) { return ReturnType(); } + ReturnType visitElementSegment(ElementSegment* curr) { return ReturnType(); } ReturnType visitMemory(Memory* curr) { return ReturnType(); } ReturnType visitEvent(Event* curr) { return ReturnType(); } ReturnType visitModule(Module* curr) { return ReturnType(); } @@ -191,10 +192,14 @@ struct Walker : public VisitorType { // override this to provide custom functionality void doWalkFunction(Function* func) { walk(func->body); } - void walkTable(Table* table) { - for (auto& segment : table->segments) { - walk(segment.offset); + void walkElementSegment(ElementSegment* segment) { + if (segment->table.is()) { + walk(segment->offset); } + static_cast<SubType*>(this)->visitElementSegment(segment); + } + + void walkTable(Table* table) { static_cast<SubType*>(this)->visitTable(table); } @@ -244,7 +249,10 @@ struct Walker : public VisitorType { } for (auto& curr : module->tables) { self->walkTable(curr.get()); - }; + } + for (auto& curr : module->elementSegments) { + self->walkElementSegment(curr.get()); + } self->walkMemory(&module->memory); } diff --git a/src/wasm.h b/src/wasm.h index 9bf89f828..698de9520 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1550,7 +1550,7 @@ public: // Globals -struct Importable { +struct Named { Name name; // Explicit names are ones that we read from the input file and @@ -1559,11 +1559,6 @@ struct Importable { // use only and will not be written the name section. bool hasExplicitName = false; - // If these are set, then this is an import, as module.base - Name module, base; - - bool imported() const { return module.is(); } - void setName(Name name_, bool hasExplicitName_) { name = name_; hasExplicitName = hasExplicitName_; @@ -1572,6 +1567,13 @@ struct Importable { void setExplicitName(Name name_) { setName(name_, true); } }; +struct Importable : Named { + // If these are set, then this is an import, as module.base + Name module, base; + + bool imported() const { return module.is(); } +}; + class Function; // Represents an offset into a wasm binary file. This is used for debug info. @@ -1720,6 +1722,21 @@ public: ExternalKind kind; }; +class ElementSegment : public Named { +public: + Name table; + Expression* offset; + std::vector<Name> data; + + ElementSegment() = default; + ElementSegment(Name table, Expression* offset) + : table(table), offset(offset) {} + ElementSegment(Name table, Expression* offset, std::vector<Name>& init) + : table(table), offset(offset) { + data.swap(init); + } +}; + class Table : public Importable { public: static const Address::address32_t kPageSize = 1; @@ -1727,26 +1744,14 @@ public: // In wasm32/64, the maximum table size is limited by a 32-bit pointer: 4GB static const Index kMaxSize = Index(-1); - struct Segment { - Expression* offset; - std::vector<Name> data; - Segment() = default; - Segment(Expression* offset) : offset(offset) {} - Segment(Expression* offset, std::vector<Name>& init) : offset(offset) { - data.swap(init); - } - }; - Address initial = 0; Address max = kMaxSize; - std::vector<Segment> segments; bool hasMax() { return max != kUnlimitedSize; } void clear() { name = ""; initial = 0; max = kMaxSize; - segments.clear(); } }; @@ -1847,7 +1852,7 @@ public: std::vector<std::unique_ptr<Function>> functions; std::vector<std::unique_ptr<Global>> globals; std::vector<std::unique_ptr<Event>> events; - + std::vector<std::unique_ptr<ElementSegment>> elementSegments; std::vector<std::unique_ptr<Table>> tables; Memory memory; @@ -1890,6 +1895,7 @@ private: std::unordered_map<Name, Export*> exportsMap; std::unordered_map<Name, Function*> functionsMap; std::unordered_map<Name, Table*> tablesMap; + std::unordered_map<Name, ElementSegment*> elementSegmentsMap; std::unordered_map<Name, Global*> globalsMap; std::unordered_map<Name, Event*> eventsMap; @@ -1899,11 +1905,13 @@ public: Export* getExport(Name name); Function* getFunction(Name name); Table* getTable(Name name); + ElementSegment* getElementSegment(Name name); Global* getGlobal(Name name); Event* getEvent(Name name); Export* getExportOrNull(Name name); Table* getTableOrNull(Name name); + ElementSegment* getElementSegmentOrNull(Name name); Function* getFunctionOrNull(Name name); Global* getGlobalOrNull(Name name); Event* getEventOrNull(Name name); @@ -1916,6 +1924,7 @@ public: Export* addExport(std::unique_ptr<Export>&& curr); Function* addFunction(std::unique_ptr<Function>&& curr); Table* addTable(std::unique_ptr<Table>&& curr); + ElementSegment* addElementSegment(std::unique_ptr<ElementSegment>&& curr); Global* addGlobal(std::unique_ptr<Global>&& curr); Event* addEvent(std::unique_ptr<Event>&& curr); @@ -1924,12 +1933,14 @@ public: void removeExport(Name name); void removeFunction(Name name); void removeTable(Name name); + void removeElementSegment(Name name); void removeGlobal(Name name); void removeEvent(Name name); void removeExports(std::function<bool(Export*)> pred); void removeFunctions(std::function<bool(Function*)> pred); void removeTables(std::function<bool(Table*)> pred); + void removeElementSegments(std::function<bool(ElementSegment*)> pred); void removeGlobals(std::function<bool(Global*)> pred); void removeEvents(std::function<bool(Event*)> pred); diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 0f40cfd58..dacb6edbb 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -55,7 +55,7 @@ void WasmBinaryWriter::write() { writeGlobals(); writeExports(); writeStart(); - writeTableElements(); + writeElementSegments(); writeDataCount(); writeFunctions(); writeDataSegments(); @@ -552,11 +552,8 @@ void WasmBinaryWriter::writeTableDeclarations() { finishSection(start); } -void WasmBinaryWriter::writeTableElements() { - size_t elemCount = 0; - for (auto& table : wasm->tables) { - elemCount += table->segments.size(); - } +void WasmBinaryWriter::writeElementSegments() { + size_t elemCount = wasm->elementSegments.size(); auto needingElemDecl = TableUtils::getFunctionsNeedingElemDeclare(*wasm); if (!needingElemDecl.empty()) { elemCount++; @@ -565,40 +562,50 @@ void WasmBinaryWriter::writeTableElements() { return; } - BYN_TRACE("== writeTableElements\n"); + BYN_TRACE("== writeElementSegments\n"); auto start = startSection(BinaryConsts::Section::Element); o << U32LEB(elemCount); - for (auto& table : wasm->tables) { - for (auto& segment : table->segments) { - Index tableIdx = getTableIndex(table->name); - // No support for passive element segments yet as they don't belong to a - // table. - bool isPassive = false; - bool isDeclarative = false; - bool hasTableIndex = tableIdx > 0; - bool usesExpressions = false; - - uint32_t flags = - (isPassive ? BinaryConsts::IsPassive | - (isDeclarative ? BinaryConsts::IsDeclarative : 0) - : (hasTableIndex ? BinaryConsts::HasIndex : 0)) | - (usesExpressions ? BinaryConsts::UsesExpressions : 0); - - o << U32LEB(flags); + for (auto& segment : wasm->elementSegments) { + Index tableIdx = 0; + + bool isPassive = segment->table.isNull(); + // TODO(reference-types): add support for writing expressions instead of + // function indices. + bool usesExpressions = false; + + bool hasTableIndex = false; + if (!isPassive) { + tableIdx = getTableIndex(segment->table); + hasTableIndex = tableIdx > 0; + } + + uint32_t flags = 0; + if (usesExpressions) { + flags |= BinaryConsts::UsesExpressions; + } + if (isPassive) { + flags |= BinaryConsts::IsPassive; + } else if (hasTableIndex) { + flags |= BinaryConsts::HasIndex; + } + + o << U32LEB(flags); + if (!isPassive) { if (hasTableIndex) { o << U32LEB(tableIdx); } - writeExpression(segment.offset); + writeExpression(segment->offset); o << int8_t(BinaryConsts::End); - if (!usesExpressions && (isPassive || hasTableIndex)) { - // elemKind funcref - o << U32LEB(0); - } - o << U32LEB(segment.data.size()); - for (auto name : segment.data) { - o << U32LEB(getFunctionIndex(name)); - } + } + + if (!usesExpressions && (isPassive || hasTableIndex)) { + // elemKind funcref + o << U32LEB(0); + } + o << U32LEB(segment->data.size()); + for (auto& name : segment->data) { + o << U32LEB(getFunctionIndex(name)); } } @@ -765,6 +772,32 @@ void WasmBinaryWriter::writeNames() { } } + // elem names + { + std::vector<std::pair<Index, ElementSegment*>> elemsWithNames; + Index checked = 0; + for (auto& curr : wasm->elementSegments) { + if (curr->hasExplicitName) { + elemsWithNames.push_back({checked, curr.get()}); + } + checked++; + } + assert(checked == indexes.elemIndexes.size()); + + if (elemsWithNames.size() > 0) { + auto substart = + startSubsection(BinaryConsts::UserSections::Subsection::NameElem); + o << U32LEB(elemsWithNames.size()); + + for (auto& indexedElem : elemsWithNames) { + o << U32LEB(indexedElem.first); + writeEscapedName(indexedElem.second->name.str); + } + + finishSubsection(substart); + } + } + // memory names if (wasm->memory.exists && wasm->memory.hasExplicitName) { auto substart = @@ -1337,7 +1370,7 @@ void WasmBinaryBuilder::read() { readExports(); break; case BinaryConsts::Section::Element: - readTableElements(); + readElementSegments(); break; case BinaryConsts::Section::Global: readGlobals(); @@ -1349,7 +1382,7 @@ void WasmBinaryBuilder::read() { readDataCount(); break; case BinaryConsts::Section::Table: - readFunctionTableDeclaration(); + readTableDeclarations(); break; case BinaryConsts::Section::Event: readEvents(); @@ -2549,6 +2582,9 @@ void WasmBinaryBuilder::processNames() { for (auto& table : tables) { wasm.addTable(std::move(table)); } + for (auto& segment : elementSegments) { + wasm.addElementSegment(std::move(segment)); + } // now that we have names, apply things @@ -2607,14 +2643,11 @@ void WasmBinaryBuilder::processNames() { } } - for (auto& table_pair : functionTable) { - for (auto& pair : table_pair.second) { - auto i = pair.first; - auto& indices = pair.second; - for (auto j : indices) { - wasm.tables[table_pair.first]->segments[i].data.push_back( - getFunctionName(j)); - } + for (auto& pair : functionTable) { + auto i = pair.first; + auto& indices = pair.second; + for (auto j : indices) { + wasm.elementSegments[i]->data.push_back(getFunctionName(j)); } } @@ -2670,8 +2703,8 @@ void WasmBinaryBuilder::readDataSegments() { } } -void WasmBinaryBuilder::readFunctionTableDeclaration() { - BYN_TRACE("== readFunctionTableDeclaration\n"); +void WasmBinaryBuilder::readTableDeclarations() { + BYN_TRACE("== readTableDeclarations\n"); auto numTables = getU32LEB(); for (size_t i = 0; i < numTables; i++) { @@ -2695,8 +2728,8 @@ void WasmBinaryBuilder::readFunctionTableDeclaration() { } } -void WasmBinaryBuilder::readTableElements() { - BYN_TRACE("== readTableElements\n"); +void WasmBinaryBuilder::readElementSegments() { + BYN_TRACE("== readElementSegments\n"); auto numSegments = getU32LEB(); if (numSegments >= Table::kMaxSize) { throwError("Too many segments"); @@ -2704,55 +2737,63 @@ void WasmBinaryBuilder::readTableElements() { for (size_t i = 0; i < numSegments; i++) { auto flags = getU32LEB(); bool isPassive = (flags & BinaryConsts::IsPassive) != 0; - bool hasTableIdx = (flags & BinaryConsts::HasIndex) != 0; + bool hasTableIdx = !isPassive && ((flags & BinaryConsts::HasIndex) != 0); + bool isDeclarative = + isPassive && ((flags & BinaryConsts::IsDeclarative) != 0); bool usesExpressions = (flags & BinaryConsts::UsesExpressions) != 0; - if (isPassive) { - bool isDeclarative = (flags & BinaryConsts::IsDeclarative) != 0; - if (isDeclarative) { - // "elem declare" is needed in wasm text and binary, but not in Binaryen - // IR; read and ignore the contents. - auto type = getU32LEB(); - WASM_UNUSED(type); - auto num = getU32LEB(); - for (Index i = 0; i < num; i++) { - getU32LEB(); - } - continue; + if (isDeclarative) { + // Declared segments are needed in wasm text and binary, but not in + // Binaryen IR; skip over the segment + auto type = getU32LEB(); + WASM_UNUSED(type); + auto num = getU32LEB(); + for (Index i = 0; i < num; i++) { + getU32LEB(); } - - throwError("Only active elem segments are supported."); + continue; } if (usesExpressions) { throwError("Only elem segments with function indexes are supported."); } - Index tableIdx = 0; - if (hasTableIdx) { - tableIdx = getU32LEB(); - } + if (!isPassive) { + Index tableIdx = 0; + if (hasTableIdx) { + tableIdx = getU32LEB(); + } - auto numTableImports = tableImports.size(); - if (tableIdx < numTableImports) { - auto table = tableImports[tableIdx]; - table->segments.emplace_back(readExpression()); - } else if (tableIdx - numTableImports < tables.size()) { - auto table = tables[tableIdx - numTableImports].get(); - table->segments.emplace_back(readExpression()); + auto makeActiveElem = [&](Table* table) { + auto segment = + std::make_unique<ElementSegment>(table->name, readExpression()); + segment->setName(Name::fromInt(i), false); + elementSegments.push_back(std::move(segment)); + }; + + auto numTableImports = tableImports.size(); + if (tableIdx < numTableImports) { + makeActiveElem(tableImports[tableIdx]); + } else if (tableIdx - numTableImports < tables.size()) { + makeActiveElem(tables[tableIdx - numTableImports].get()); + } else { + throwError("Table index out of range."); + } } else { - throwError("Table index out of range."); + auto segment = std::make_unique<ElementSegment>(); + segment->setName(Name::fromInt(i), false); + elementSegments.push_back(std::move(segment)); } - if (hasTableIdx) { + if (isPassive || hasTableIdx) { auto elemKind = getU32LEB(); if (elemKind != 0x0) { throwError("Only funcref elem kinds are valid."); } } - size_t segmentIndex = functionTable[tableIdx].size(); - auto& indexSegment = functionTable[tableIdx][segmentIndex]; + size_t segmentIndex = functionTable.size(); + auto& indexSegment = functionTable[segmentIndex]; auto size = getU32LEB(); for (Index j = 0; j < size; j++) { indexSegment.push_back(getU32LEB()); @@ -2921,10 +2962,19 @@ void WasmBinaryBuilder::readNames(size_t payloadLen) { auto rawName = getInlineString(); auto name = processor.process(rawName); auto numTableImports = tableImports.size(); + auto setTableName = [&](Table* table) { + for (auto& segment : elementSegments) { + if (segment->table == table->name) { + segment->table = name; + } + } + table->setExplicitName(name); + }; + if (index < numTableImports) { - tableImports[index]->setExplicitName(name); + setTableName(tableImports[index]); } else if (index - numTableImports < tables.size()) { - tables[index - numTableImports]->setExplicitName(name); + setTableName(tables[index - numTableImports].get()); } else { std::cerr << "warning: table index out of bounds in name section, " "table subsection: " @@ -2932,6 +2982,23 @@ void WasmBinaryBuilder::readNames(size_t payloadLen) { << std::to_string(index) << std::endl; } } + } else if (nameType == BinaryConsts::UserSections::Subsection::NameElem) { + auto num = getU32LEB(); + NameProcessor processor; + for (size_t i = 0; i < num; i++) { + auto index = getU32LEB(); + auto rawName = getInlineString(); + auto name = processor.process(rawName); + + if (index < elementSegments.size()) { + elementSegments[index]->setExplicitName(name); + } else { + std::cerr << "warning: elem index out of bounds in name section, " + "elem subsection: " + << std::string(rawName.str) << " at index " + << std::to_string(index) << std::endl; + } + } } else if (nameType == BinaryConsts::UserSections::Subsection::NameMemory) { auto num = getU32LEB(); for (size_t i = 0; i < num; i++) { diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 339214f43..d97121544 100644 --- a/src/wasm/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp @@ -23,7 +23,6 @@ #include "ir/import-utils.h" #include "ir/literal-utils.h" #include "ir/module-utils.h" -#include "ir/table-utils.h" #include "shared-constants.h" #include "support/debug.h" #include "wasm-builder.h" diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 31ae188d0..91ed5056f 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -51,7 +51,7 @@ int unhex(char c) { namespace wasm { static Name STRUCT("struct"), FIELD("field"), ARRAY("array"), I8("i8"), - I16("i16"), RTT("rtt"), DECLARE("declare"); + I16("i16"), RTT("rtt"), DECLARE("declare"), ITEM("item"), OFFSET("offset"); static Address getAddress(const Element* s) { return atoll(s->c_str()); } @@ -3215,15 +3215,25 @@ void SExpressionWasmBuilder::parseTable(Element& s, bool preParseImport) { wasm.addTable(std::move(table)); return; } + + auto parseTableElem = [&](Table* table, Element& s) { + parseElem(s, table); + auto it = std::find_if(wasm.elementSegments.begin(), + wasm.elementSegments.end(), + [&](std::unique_ptr<ElementSegment>& segment) { + return segment->table == table->name; + }); + if (it != wasm.elementSegments.end()) { + table->initial = table->max = it->get()->data.size(); + } else { + table->initial = table->max = 0; + } + }; + if (!s[i]->dollared()) { if (s[i]->str() == FUNCREF) { // (table type (elem ..)) - parseInnerElem(table.get(), *s[i + 1]); - if (table->segments.size() > 0) { - table->initial = table->max = table->segments[0].data.size(); - } else { - table->initial = table->max = 0; - } + parseTableElem(table.get(), *s[i + 1]); wasm.addTable(std::move(table)); return; } @@ -3243,13 +3253,7 @@ void SExpressionWasmBuilder::parseTable(Element& s, bool preParseImport) { } } // old notation (table func1 func2 ..) - parseInnerElem(table.get(), s, i); - if (table->segments.size() > 0) { - table->initial = table->max = table->segments[0].data.size(); - } else { - table->initial = table->max = 0; - } - + parseTableElem(table.get(), s); wasm.addTable(std::move(table)); } @@ -3257,49 +3261,77 @@ void SExpressionWasmBuilder::parseTable(Element& s, bool preParseImport) { // elem ::= (elem (expr) vec(funcidx)) // | (elem (offset (expr)) func vec(funcidx)) // | (elem (table tableidx) (offset (expr)) func vec(funcidx)) -// | (elem declare func $foo) +// | (elem func vec(funcidx)) +// | (elem declare func vec(funcidx)) // // abbreviation: // (offset (expr)) ≡ (expr) // (elem (expr) vec(funcidx)) ≡ (elem (table 0) (offset (expr)) func // vec(funcidx)) // -void SExpressionWasmBuilder::parseElem(Element& s) { +void SExpressionWasmBuilder::parseElem(Element& s, Table* table) { Index i = 1; - Table* table = nullptr; - Expression* offset = nullptr; + Name name = Name::fromInt(elemCounter++); + bool hasExplicitName = false; + + if (table) { + Expression* offset = allocator.alloc<Const>()->set(Literal(int32_t(0))); + auto segment = std::make_unique<ElementSegment>(table->name, offset); + segment->setName(name, hasExplicitName); + parseElemFinish(s, segment, i); + return; + } - if (!s[i]->isList()) { - // optional segment id OR 'declare' OR start of elemList - if (s[i]->str() == DECLARE) { - // "elem declare" is needed in wasm text and binary, but not in Binaryen - // IR; ignore the contents. - return; - } - i += 1; + if (s[i]->isStr() && s[i]->dollared()) { + name = s[i++]->str(); + hasExplicitName = true; + } + if (s[i]->isStr() && s[i]->str() == DECLARE) { + // We don't store declared segments in the IR + return; } - // old style refers to the pre-reftypes form of (elem (expr) vec(funcidx)) + if (s[i]->isStr() && s[i]->str() == FUNC) { + auto segment = std::make_unique<ElementSegment>(); + segment->setName(name, hasExplicitName); + parseElemFinish(s, segment, i + 1); + return; + } + + // old style refers to the pre-reftypes form of (elem 0? (expr) vec(funcidx)) bool oldStyle = true; - while (1) { + // At this point, we know that we're parsing an active element segment. A + // table will be mandatory now. + if (wasm.tables.empty()) { + throw ParseException("elem without table", s.line, s.col); + } + + // Old style table index (elem 0 (i32.const 0) ...) + if (s[i]->isStr()) { + i += 1; + } + + if (s[i]->isList() && elementStartsWith(s[i], TABLE)) { + oldStyle = false; auto& inner = *s[i++]; - if (elementStartsWith(inner, TABLE)) { + Name tableName = getTableName(*inner[1]); + table = wasm.getTable(tableName); + } + + Expression* offset = nullptr; + if (s[i]->isList()) { + auto& inner = *s[i++]; + if (elementStartsWith(inner, OFFSET)) { + offset = parseExpression(inner[1]); oldStyle = false; - Name tableName = getTableName(*inner[1]); - table = wasm.getTable(tableName); } else { - if (elementStartsWith(inner, "offset")) { - offset = parseExpression(inner[1]); - } else { - offset = parseExpression(inner); - } - break; + offset = parseExpression(inner); } } if (!oldStyle) { - if (strcmp(s[i]->c_str(), "func") != 0) { + if (s[i]->str() != FUNC) { throw ParseException( "only the abbreviated form of elemList is supported."); } @@ -3307,27 +3339,21 @@ void SExpressionWasmBuilder::parseElem(Element& s) { i += 1; } - if (wasm.tables.empty()) { - throw ParseException("elem without table", s.line, s.col); - } else if (!table) { - table = wasm.tables[0].get(); + if (!table) { + table = wasm.tables.front().get(); } - parseInnerElem(table, s, i, offset); + auto segment = std::make_unique<ElementSegment>(table->name, offset); + segment->setName(name, hasExplicitName); + parseElemFinish(s, segment, i); } -void SExpressionWasmBuilder::parseInnerElem(Table* table, - Element& s, - Index i, - Expression* offset) { - if (!offset) { - offset = allocator.alloc<Const>()->set(Literal(int32_t(0))); - } - Table::Segment segment(offset); +ElementSegment* SExpressionWasmBuilder::parseElemFinish( + Element& s, std::unique_ptr<ElementSegment>& segment, Index i) { for (; i < s.size(); i++) { - segment.data.push_back(getFunctionName(*s[i])); + segment->data.push_back(getFunctionName(*s[i])); } - table->segments.push_back(segment); + return wasm.addElementSegment(std::move(segment)); } HeapType SExpressionWasmBuilder::parseHeapType(Element& s) { diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index e95950bb5..656c5e832 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2882,21 +2882,25 @@ static void validateTables(Module& module, ValidationInfo& info) { "Only 1 table definition allowed in MVP (requires " "--enable-reference-types)"); } - for (auto& curr : module.tables) { - for (auto& segment : curr->segments) { - info.shouldBeEqual(segment.offset->type, + + for (auto& segment : module.elementSegments) { + if (segment->table.is()) { + auto table = module.getTableOrNull(segment->table); + info.shouldBeTrue( + table != nullptr, "elem", "elem segment must have a valid table name"); + info.shouldBeEqual(segment->offset->type, Type(Type::i32), - segment.offset, - "segment offset should be i32"); - info.shouldBeTrue(checkSegmentOffset(segment.offset, - segment.data.size(), - curr->initial * Table::kPageSize), - segment.offset, + segment->offset, + "elem segment offset should be i32"); + info.shouldBeTrue(checkSegmentOffset(segment->offset, + segment->data.size(), + table->initial * Table::kPageSize), + segment->offset, "table segment offset should be reasonable"); - for (auto name : segment.data) { - info.shouldBeTrue( - module.getFunctionOrNull(name), name, "segment name should be valid"); - } + } + for (auto name : segment->data) { + info.shouldBeTrue( + module.getFunctionOrNull(name), name, "segment name should be valid"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 92e03f579..ef8cfb2f7 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -70,6 +70,9 @@ Name EXPORT("export"); Name IMPORT("import"); Name TABLE("table"); Name ELEM("elem"); +Name DECLARE("declare"); +Name OFFSET("offset"); +Name ITEM("item"); Name LOCAL("local"); Name TYPE("type"); Name REF("ref"); @@ -802,12 +805,9 @@ void MemoryGrow::finalize() { void RefNull::finalize(HeapType heapType) { type = Type(heapType, Nullable); } -void RefNull::finalize(Type type_) { - type = type_; -} +void RefNull::finalize(Type type_) { type = type_; } -void RefNull::finalize() { -} +void RefNull::finalize() {} void RefIs::finalize() { if (value->type == Type::unreachable) { @@ -1169,6 +1169,10 @@ Table* Module::getTable(Name name) { return getModuleElement(tablesMap, name, "getTable"); } +ElementSegment* Module::getElementSegment(Name name) { + return getModuleElement(elementSegmentsMap, name, "getElementSegment"); +} + Global* Module::getGlobal(Name name) { return getModuleElement(globalsMap, name, "getGlobal"); } @@ -1198,6 +1202,10 @@ Table* Module::getTableOrNull(Name name) { return getModuleElementOrNull(tablesMap, name); } +ElementSegment* Module::getElementSegmentOrNull(Name name) { + return getModuleElementOrNull(elementSegmentsMap, name); +} + Global* Module::getGlobalOrNull(Name name) { return getModuleElementOrNull(globalsMap, name); } @@ -1267,6 +1275,12 @@ Table* Module::addTable(std::unique_ptr<Table>&& curr) { return addModuleElement(tables, tablesMap, std::move(curr), "addTable"); } +ElementSegment* +Module::addElementSegment(std::unique_ptr<ElementSegment>&& curr) { + return addModuleElement( + elementSegments, elementSegmentsMap, std::move(curr), "addElementSegment"); +} + Global* Module::addGlobal(std::unique_ptr<Global>&& curr) { return addModuleElement(globals, globalsMap, std::move(curr), "addGlobal"); } @@ -1297,6 +1311,9 @@ void Module::removeFunction(Name name) { void Module::removeTable(Name name) { removeModuleElement(tables, tablesMap, name); } +void Module::removeElementSegment(Name name) { + removeModuleElement(elementSegments, elementSegmentsMap, name); +} void Module::removeGlobal(Name name) { removeModuleElement(globals, globalsMap, name); } @@ -1329,6 +1346,9 @@ void Module::removeFunctions(std::function<bool(Function*)> pred) { void Module::removeTables(std::function<bool(Table*)> pred) { removeModuleElements(tables, tablesMap, pred); } +void Module::removeElementSegments(std::function<bool(ElementSegment*)> pred) { + removeModuleElements(elementSegments, elementSegmentsMap, pred); +} void Module::removeGlobals(std::function<bool(Global*)> pred) { removeModuleElements(globals, globalsMap, pred); } @@ -1349,6 +1369,10 @@ void Module::updateMaps() { for (auto& curr : tables) { tablesMap[curr->name] = curr.get(); } + elementSegmentsMap.clear(); + for (auto& curr : elementSegments) { + elementSegmentsMap[curr->name] = curr.get(); + } globalsMap.clear(); for (auto& curr : globals) { globalsMap[curr->name] = curr.get(); diff --git a/src/wasm2js.h b/src/wasm2js.h index b1cc13898..b3a4c7e18 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -325,11 +325,9 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { functionsCallableFromOutside.insert(exp->value); } } - for (auto& table : wasm->tables) { - for (auto& segment : table->segments) { - for (auto name : segment.data) { - functionsCallableFromOutside.insert(name); - } + for (auto& segment : wasm->elementSegments) { + for (auto name : segment->data) { + functionsCallableFromOutside.insert(name); } } @@ -635,7 +633,7 @@ void Wasm2JSBuilder::addTable(Ref ast, Module* wasm) { Ref theArray = ValueBuilder::makeArray(); for (auto& table : wasm->tables) { if (!table->imported()) { - TableUtils::FlatTable flat(*table); + TableUtils::FlatTable flat(*wasm, *table); if (flat.valid) { Name null("null"); for (auto& name : flat.names) { @@ -679,28 +677,30 @@ void Wasm2JSBuilder::addTable(Ref ast, Module* wasm) { if (perElementInit) { // TODO: optimize for size - for (auto& segment : table->segments) { - auto offset = segment.offset; - for (Index i = 0; i < segment.data.size(); i++) { - Ref index; - if (auto* c = offset->dynCast<Const>()) { - index = ValueBuilder::makeInt(c->value.geti32() + i); - } else if (auto* get = offset->dynCast<GlobalGet>()) { - index = ValueBuilder::makeBinary( - ValueBuilder::makeName(stringToIString(asmangle(get->name.str))), - PLUS, - ValueBuilder::makeNum(i)); - } else { - WASM_UNREACHABLE("unexpected expr type"); + ModuleUtils::iterTableSegments( + *wasm, table->name, [&](ElementSegment* segment) { + auto offset = segment->offset; + for (Index i = 0; i < segment->data.size(); i++) { + Ref index; + if (auto* c = offset->dynCast<Const>()) { + index = ValueBuilder::makeInt(c->value.geti32() + i); + } else if (auto* get = offset->dynCast<GlobalGet>()) { + index = + ValueBuilder::makeBinary(ValueBuilder::makeName(stringToIString( + asmangle(get->name.str))), + PLUS, + ValueBuilder::makeNum(i)); + } else { + WASM_UNREACHABLE("unexpected expr type"); + } + ast->push_back(ValueBuilder::makeStatement(ValueBuilder::makeBinary( + ValueBuilder::makeSub(ValueBuilder::makeName(FUNCTION_TABLE), + index), + SET, + ValueBuilder::makeName( + fromName(segment->data[i], NameScope::Top))))); } - ast->push_back(ValueBuilder::makeStatement(ValueBuilder::makeBinary( - ValueBuilder::makeSub(ValueBuilder::makeName(FUNCTION_TABLE), - index), - SET, - ValueBuilder::makeName( - fromName(segment.data[i], NameScope::Top))))); - } - } + }); } } } |