summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/binaryen-c.cpp24
-rw-r--r--src/ir/element-utils.h52
-rw-r--r--src/ir/module-splitting.cpp84
-rw-r--r--src/ir/module-utils.h12
-rw-r--r--src/ir/table-utils.cpp8
-rw-r--r--src/ir/table-utils.h19
-rw-r--r--src/passes/DeadArgumentElimination.cpp8
-rw-r--r--src/passes/FuncCastEmulation.cpp21
-rw-r--r--src/passes/GenerateDynCalls.cpp8
-rw-r--r--src/passes/Inlining.cpp8
-rw-r--r--src/passes/LegalizeJSInterface.cpp11
-rw-r--r--src/passes/Print.cpp34
-rw-r--r--src/passes/PrintCallGraph.cpp11
-rw-r--r--src/passes/RemoveImports.cpp8
-rw-r--r--src/passes/RemoveUnusedModuleElements.cpp9
-rw-r--r--src/passes/ReorderFunctions.cpp8
-rw-r--r--src/passes/opt-utils.h7
-rw-r--r--src/shell-interface.h9
-rw-r--r--src/tools/fuzzing.h16
-rw-r--r--src/tools/wasm-ctor-eval.cpp30
-rw-r--r--src/tools/wasm-metadce.cpp16
-rw-r--r--src/tools/wasm-reduce.cpp79
-rw-r--r--src/tools/wasm-shell.cpp19
-rw-r--r--src/wasm-binary.h3
-rw-r--r--src/wasm-interpreter.h19
-rw-r--r--src/wasm-s-parser.h3
-rw-r--r--src/wasm-traversal.h3
-rw-r--r--src/wasm.h4
-rw-r--r--src/wasm/wasm-binary.cpp76
-rw-r--r--src/wasm/wasm-s-parser.cpp61
-rw-r--r--src/wasm/wasm-validator.cpp44
-rw-r--r--src/wasm2js.h49
-rw-r--r--test/multi-table.wast6
-rw-r--r--test/multi-table.wast.from-wast5
-rw-r--r--test/multi-table.wast.fromBinary5
-rw-r--r--test/multi-table.wast.fromBinary.noDebugInfo5
-rw-r--r--test/passes/converge_O3_metrics.bin.txt6
-rw-r--r--test/passes/func-metrics.txt3
-rw-r--r--test/passes/fuzz_metrics_noprint.bin.txt3
-rw-r--r--test/passes/metrics_all-features.txt3
-rw-r--r--test/spec/call_indirect_refnull.wast12
41 files changed, 531 insertions, 280 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index 1b3f11e22..aff4ed4fc 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -3409,7 +3409,13 @@ BinaryenAddActiveElementSegment(BinaryenModuleRef module,
auto segment = std::make_unique<ElementSegment>(table, (Expression*)offset);
segment->setExplicitName(name);
for (BinaryenIndex i = 0; i < numFuncNames; i++) {
- segment->data.push_back(funcNames[i]);
+ auto* func = ((Module*)module)->getFunctionOrNull(funcNames[i]);
+ if (func == nullptr) {
+ Fatal() << "invalid function '" << funcNames[i] << "'.";
+ }
+ Type type(HeapType(func->sig), Nullable);
+ segment->data.push_back(
+ Builder(*(Module*)module).makeRefFunc(funcNames[i], type));
}
return ((Module*)module)->addElementSegment(std::move(segment));
}
@@ -3421,7 +3427,13 @@ BinaryenAddPassiveElementSegment(BinaryenModuleRef module,
auto segment = std::make_unique<ElementSegment>();
segment->setExplicitName(name);
for (BinaryenIndex i = 0; i < numFuncNames; i++) {
- segment->data.push_back(funcNames[i]);
+ auto* func = ((Module*)module)->getFunctionOrNull(funcNames[i]);
+ if (func == nullptr) {
+ Fatal() << "invalid function '" << funcNames[i] << "'.";
+ }
+ Type type(HeapType(func->sig), Nullable);
+ segment->data.push_back(
+ Builder(*(Module*)module).makeRefFunc(funcNames[i], type));
}
return ((Module*)module)->addElementSegment(std::move(segment));
}
@@ -3460,7 +3472,13 @@ const char* BinaryenElementSegmentGetData(BinaryenElementSegmentRef elem,
if (data.size() <= dataId) {
Fatal() << "invalid segment data id.";
}
- return data[dataId].c_str();
+ if (data[dataId]->is<RefNull>()) {
+ return NULL;
+ } else if (auto* get = data[dataId]->dynCast<RefFunc>()) {
+ return get->func.c_str();
+ } else {
+ Fatal() << "invalid expression in segment data.";
+ }
}
// Memory. One per module
diff --git a/src/ir/element-utils.h b/src/ir/element-utils.h
new file mode 100644
index 000000000..adfd9955f
--- /dev/null
+++ b/src/ir/element-utils.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2021 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ir_element_h
+#define wasm_ir_element_h
+
+#include "wasm-builder.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace ElementUtils {
+
+// iterate over functions referenced in an element segment
+template<typename T>
+inline void iterElementSegmentFunctionNames(ElementSegment* segment,
+ T visitor) {
+ // TODO(reference-types): return early if segment type is non-funcref
+ for (Index i = 0; i < segment->data.size(); i++) {
+ if (auto* get = segment->data[i]->dynCast<RefFunc>()) {
+ visitor(get->func, i);
+ }
+ }
+}
+
+// iterate over functions referenced in all element segments of a module
+template<typename T>
+inline void iterAllElementFunctionNames(const Module* wasm, T visitor) {
+ for (auto& segment : wasm->elementSegments) {
+ iterElementSegmentFunctionNames(
+ segment.get(), [&](Name& name, Index i) { visitor(name); });
+ }
+}
+
+} // namespace ElementUtils
+
+} // namespace wasm
+
+#endif // wasm_ir_element_h
diff --git a/src/ir/module-splitting.cpp b/src/ir/module-splitting.cpp
index d4d03991e..5b86dd494 100644
--- a/src/ir/module-splitting.cpp
+++ b/src/ir/module-splitting.cpp
@@ -73,6 +73,7 @@
// complex code, so it is a good candidate for a follow up PR.
#include "ir/module-splitting.h"
+#include "ir/element-utils.h"
#include "ir/manipulation.h"
#include "ir/module-utils.h"
#include "ir/names.h"
@@ -95,12 +96,19 @@ template<class F> void forEachElement(Module& module, F f) {
} else if (auto* g = segment->offset->dynCast<GlobalGet>()) {
base = g->name;
}
- for (size_t i = 0; i < segment->data.size(); ++i) {
- f(segment->table, base, offset + i, segment->data[i]);
- }
+ ElementUtils::iterElementSegmentFunctionNames(
+ segment, [&](Name& entry, Index i) {
+ f(segment->table, base, offset + i, entry);
+ });
});
}
+static RefFunc* makeRefFunc(Module& wasm, Function* func) {
+ // FIXME: make the type NonNullable when we support it!
+ return Builder(wasm).makeRefFunc(func->name,
+ Type(HeapType(func->sig), Nullable));
+}
+
struct TableSlotManager {
struct Slot {
Name tableName;
@@ -124,7 +132,7 @@ struct TableSlotManager {
Table* makeTable();
// Returns the table index for `func`, allocating a new index if necessary.
- Slot getSlot(Name func);
+ Slot getSlot(RefFunc* entry);
void addSlot(Name func, Slot slot);
};
@@ -199,8 +207,8 @@ Table* TableSlotManager::makeTable() {
return module.addTable(Builder::makeTable(Name::fromInt(0)));
}
-TableSlotManager::Slot TableSlotManager::getSlot(Name func) {
- auto slotIt = funcIndices.find(func);
+TableSlotManager::Slot TableSlotManager::getSlot(RefFunc* entry) {
+ auto slotIt = funcIndices.find(entry->func);
if (slotIt != funcIndices.end()) {
return slotIt->second;
}
@@ -227,8 +235,10 @@ TableSlotManager::Slot TableSlotManager::getSlot(Name func) {
Slot newSlot = {activeBase.tableName,
activeBase.global,
activeBase.index + Index(activeSegment->data.size())};
- activeSegment->data.push_back(func);
- addSlot(func, newSlot);
+
+ activeSegment->data.push_back(entry);
+
+ addSlot(entry->func, newSlot);
if (activeTable->initial <= newSlot.index) {
activeTable->initial = newSlot.index + 1;
}
@@ -372,15 +382,15 @@ void ModuleSplitter::thunkExportedSecondaryFunctions() {
// We've already created a thunk for this function
continue;
}
- auto tableSlot = tableManager.getSlot(secondaryFunc);
auto func = std::make_unique<Function>();
-
func->name = secondaryFunc;
func->sig = secondary.getFunction(secondaryFunc)->sig;
std::vector<Expression*> args;
for (size_t i = 0, size = func->sig.params.size(); i < size; ++i) {
args.push_back(builder.makeLocalGet(i, func->sig.params[i]));
}
+
+ auto tableSlot = tableManager.getSlot(makeRefFunc(primary, func.get()));
func->body = builder.makeCallIndirect(
tableSlot.tableName, tableSlot.makeExpr(primary), args, func->sig);
primary.addFunction(std::move(func));
@@ -395,17 +405,21 @@ void ModuleSplitter::indirectCallsToSecondaryFunctions() {
Builder builder;
CallIndirector(ModuleSplitter& parent)
: parent(parent), builder(parent.primary) {}
+ // Avoid visitRefFunc on element segment data
+ void walkElementSegment(ElementSegment* segment) {}
void visitCall(Call* curr) {
if (!parent.secondaryFuncs.count(curr->target)) {
return;
}
- auto tableSlot = parent.tableManager.getSlot(curr->target);
- replaceCurrent(builder.makeCallIndirect(
- tableSlot.tableName,
- tableSlot.makeExpr(parent.primary),
- curr->operands,
- parent.secondary.getFunction(curr->target)->sig,
- curr->isReturn));
+ auto func = parent.secondary.getFunction(curr->target);
+ auto tableSlot =
+ parent.tableManager.getSlot(makeRefFunc(parent.primary, func));
+ replaceCurrent(
+ builder.makeCallIndirect(tableSlot.tableName,
+ tableSlot.makeExpr(parent.primary),
+ curr->operands,
+ func->sig,
+ curr->isReturn));
}
void visitRefFunc(RefFunc* curr) {
assert(false && "TODO: handle ref.func as well");
@@ -454,14 +468,14 @@ void ModuleSplitter::setupTablePatching() {
return;
}
- std::map<Index, Name> replacedElems;
+ std::map<Index, Function*> replacedElems;
// Replace table references to secondary functions with an imported
// placeholder that encodes the table index in its name:
// `importNamespace`.`index`.
forEachElement(primary, [&](Name, Name, Index index, Name& elem) {
if (secondaryFuncs.count(elem)) {
- replacedElems[index] = elem;
auto* secondaryFunc = secondary.getFunction(elem);
+ replacedElems[index] = secondaryFunc;
auto placeholder = std::make_unique<Function>();
placeholder->module = config.placeholderNamespace;
placeholder->base = std::to_string(index);
@@ -495,8 +509,8 @@ void ModuleSplitter::setupTablePatching() {
// to be imported into the second module. TODO: use better strategies here,
// such as using ref.func in the start function or standardizing addition in
// initializer expressions.
- const ElementSegment* primarySeg = tableManager.activeTableSegments.front();
- std::vector<Name> secondaryElems;
+ ElementSegment* primarySeg = tableManager.activeTableSegments.front();
+ std::vector<Expression*> secondaryElems;
secondaryElems.reserve(primarySeg->data.size());
// Copy functions from the primary segment to the secondary segment,
@@ -507,33 +521,35 @@ void ModuleSplitter::setupTablePatching() {
++i) {
if (replacement->first == i) {
// primarySeg->data[i] is a placeholder, so use the secondary function.
- secondaryElems.push_back(replacement->second);
+ secondaryElems.push_back(makeRefFunc(secondary, replacement->second));
++replacement;
- } else {
- exportImportFunction(primarySeg->data[i]);
- secondaryElems.push_back(primarySeg->data[i]);
+ } else if (auto* get = primarySeg->data[i]->dynCast<RefFunc>()) {
+ exportImportFunction(get->func);
+ auto* copied =
+ ExpressionManipulator::copy(primarySeg->data[i], secondary);
+ secondaryElems.push_back(copied);
}
}
auto offset = ExpressionManipulator::copy(primarySeg->offset, secondary);
- auto secondaryElem = std::make_unique<ElementSegment>(
+ auto secondarySeg = std::make_unique<ElementSegment>(
secondaryTable->name, offset, secondaryElems);
- secondaryElem->setName(primarySeg->name, primarySeg->hasExplicitName);
- secondary.addElementSegment(std::move(secondaryElem));
+ secondarySeg->setName(primarySeg->name, primarySeg->hasExplicitName);
+ secondary.addElementSegment(std::move(secondarySeg));
return;
}
// Create active table segments in the secondary module to patch in the
// original functions when it is instantiated.
Index currBase = replacedElems.begin()->first;
- std::vector<Name> currData;
+ std::vector<Expression*> currData;
auto finishSegment = [&]() {
auto* offset = Builder(secondary).makeConst(int32_t(currBase));
- auto secondaryElem =
+ auto secondarySeg =
std::make_unique<ElementSegment>(secondaryTable->name, offset, currData);
- secondaryElem->setName(Name::fromInt(secondary.elementSegments.size()),
- false);
- secondary.addElementSegment(std::move(secondaryElem));
+ secondarySeg->setName(Name::fromInt(secondary.elementSegments.size()),
+ false);
+ secondary.addElementSegment(std::move(secondarySeg));
};
for (auto curr = replacedElems.begin(); curr != replacedElems.end(); ++curr) {
if (curr->first != currBase + currData.size()) {
@@ -541,7 +557,7 @@ void ModuleSplitter::setupTablePatching() {
currBase = curr->first;
currData.clear();
}
- currData.push_back(curr->second);
+ currData.push_back(makeRefFunc(secondary, curr->second));
}
if (currData.size()) {
finishSegment();
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index 9b1682b91..a87daee9e 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -17,6 +17,7 @@
#ifndef wasm_ir_module_h
#define wasm_ir_module_h
+#include "ir/element-utils.h"
#include "ir/find_all.h"
#include "ir/manipulation.h"
#include "ir/properties.h"
@@ -75,7 +76,10 @@ inline ElementSegment* copyElementSegment(const ElementSegment* segment,
auto copy = [&](std::unique_ptr<ElementSegment>&& ret) {
ret->name = segment->name;
ret->hasExplicitName = segment->hasExplicitName;
- ret->data = segment->data;
+ ret->data.reserve(segment->data.size());
+ for (auto* item : segment->data) {
+ ret->data.push_back(ExpressionManipulator::copy(item, out));
+ }
return out.addElementSegment(std::move(ret));
};
@@ -161,11 +165,7 @@ template<typename T> inline void renameFunctions(Module& wasm, T& map) {
}
};
maybeUpdate(wasm.start);
- for (auto& segment : wasm.elementSegments) {
- for (auto& name : segment->data) {
- maybeUpdate(name);
- }
- }
+ ElementUtils::iterAllElementFunctionNames(&wasm, maybeUpdate);
for (auto& exp : wasm.exports) {
if (exp->kind == ExternalKind::Function) {
maybeUpdate(exp->value);
diff --git a/src/ir/table-utils.cpp b/src/ir/table-utils.cpp
index 639f8fbe6..80ef885f9 100644
--- a/src/ir/table-utils.cpp
+++ b/src/ir/table-utils.cpp
@@ -15,6 +15,7 @@
*/
#include "table-utils.h"
+#include "element-utils.h"
#include "find_all.h"
#include "module-utils.h"
@@ -31,11 +32,8 @@ std::set<Name> getFunctionsNeedingElemDeclare(Module& wasm) {
// Find all the names in the tables.
std::unordered_set<Name> tableNames;
- for (auto& segment : wasm.elementSegments) {
- for (auto name : segment->data) {
- tableNames.insert(name);
- }
- }
+ ElementUtils::iterAllElementFunctionNames(
+ &wasm, [&](Name name) { tableNames.insert(name); });
// Find all the names in ref.funcs.
using Names = std::unordered_set<Name>;
diff --git a/src/ir/table-utils.h b/src/ir/table-utils.h
index e90b0ca72..2d91f0035 100644
--- a/src/ir/table-utils.h
+++ b/src/ir/table-utils.h
@@ -17,6 +17,7 @@
#ifndef wasm_ir_table_h
#define wasm_ir_table_h
+#include "ir/element-utils.h"
#include "ir/literal-utils.h"
#include "ir/module-utils.h"
#include "wasm-traversal.h"
@@ -45,9 +46,8 @@ struct FlatTable {
if (end > names.size()) {
names.resize(end);
}
- for (Index i = 0; i < segment->data.size(); i++) {
- names[start + i] = segment->data[i];
- }
+ ElementUtils::iterElementSegmentFunctionNames(
+ segment, [&](Name entry, Index i) { names[start + i] = entry; });
});
}
};
@@ -84,7 +84,12 @@ inline Index append(Table& table, Name name, Module& wasm) {
}
wasm.dylinkSection->tableSize++;
}
- segment->data.push_back(name);
+
+ auto* func = wasm.getFunctionOrNull(name);
+ assert(func != nullptr && "Cannot append non-existing function to a table.");
+ // FIXME: make the type NonNullable when we support it!
+ auto type = Type(HeapType(func->sig), Nullable);
+ segment->data.push_back(Builder(wasm).makeRefFunc(name, type));
table.initial = table.initial + 1;
return tableIndex;
}
@@ -94,8 +99,10 @@ inline Index append(Table& table, Name name, Module& wasm) {
inline Index getOrAppend(Table& table, Name name, Module& wasm) {
auto segment = getSingletonSegment(table, wasm);
for (Index i = 0; i < segment->data.size(); i++) {
- if (segment->data[i] == name) {
- return i;
+ if (auto* get = segment->data[i]->dynCast<RefFunc>()) {
+ if (get->func == name) {
+ return i;
+ }
}
}
return append(table, name, wasm);
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 6223c841d..c4de07d81 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -39,6 +39,7 @@
#include "cfg/cfg-traversal.h"
#include "ir/effects.h"
+#include "ir/element-utils.h"
#include "ir/module-utils.h"
#include "pass.h"
#include "passes/opt-utils.h"
@@ -284,11 +285,8 @@ struct DAE : public Pass {
infoMap[curr->value].hasUnseenCalls = true;
}
}
- for (auto& segment : module->elementSegments) {
- for (auto name : segment->data) {
- infoMap[name].hasUnseenCalls = true;
- }
- }
+ ElementUtils::iterAllElementFunctionNames(
+ module, [&](Name name) { infoMap[name].hasUnseenCalls = true; });
// Scan all the functions.
DAEScanner(&infoMap).run(runner, module);
// Combine all the info.
diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp
index 753c986b8..e63375f2b 100644
--- a/src/passes/FuncCastEmulation.cpp
+++ b/src/passes/FuncCastEmulation.cpp
@@ -30,6 +30,7 @@
#include <string>
+#include <ir/element-utils.h>
#include <ir/literal-utils.h>
#include <pass.h>
#include <wasm-builder.h>
@@ -173,18 +174,16 @@ struct FuncCastEmulation : public Pass {
Signature ABIType(Type(std::vector<Type>(numParams, Type::i64)), Type::i64);
// Add a thunk for each function in the table, and do the call through it.
std::unordered_map<Name, Name> funcThunks;
- for (auto& segment : module->elementSegments) {
- for (auto& name : segment->data) {
- auto iter = funcThunks.find(name);
- if (iter == funcThunks.end()) {
- auto thunk = makeThunk(name, module, numParams);
- funcThunks[name] = thunk;
- name = thunk;
- } else {
- name = iter->second;
- }
+ ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
+ auto iter = funcThunks.find(name);
+ if (iter == funcThunks.end()) {
+ auto thunk = makeThunk(name, module, numParams);
+ funcThunks[name] = thunk;
+ name = thunk;
+ } else {
+ name = iter->second;
}
- }
+ });
// update call_indirects
ParallelFuncCastEmulation(ABIType, numParams).run(runner, module);
diff --git a/src/passes/GenerateDynCalls.cpp b/src/passes/GenerateDynCalls.cpp
index dad992fea..8ed9ce8b4 100644
--- a/src/passes/GenerateDynCalls.cpp
+++ b/src/passes/GenerateDynCalls.cpp
@@ -24,6 +24,7 @@
#include "abi/js.h"
#include "asm_v_wasm.h"
+#include "ir/element-utils.h"
#include "ir/import-utils.h"
#include "pass.h"
#include "support/debug.h"
@@ -57,9 +58,10 @@ struct GenerateDynCalls : public WalkerPass<PostWalker<GenerateDynCalls>> {
});
if (it != segments.end()) {
std::vector<Name> tableSegmentData;
- for (const auto& indirectFunc : it->get()->data) {
- generateDynCallThunk(wasm->getFunction(indirectFunc)->sig);
- }
+ ElementUtils::iterElementSegmentFunctionNames(
+ it->get(), [&](Name name, Index) {
+ generateDynCallThunk(wasm->getFunction(name)->sig);
+ });
}
}
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index 3f6263815..b69d8f7c2 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -31,6 +31,7 @@
#include <atomic>
#include "ir/debug.h"
+#include "ir/element-utils.h"
#include "ir/literal-utils.h"
#include "ir/module-utils.h"
#include "ir/type-updating.h"
@@ -350,11 +351,8 @@ struct Inlining : public Pass {
infos[ex->value].usedGlobally = true;
}
}
- for (auto& segment : module->elementSegments) {
- for (auto name : segment->data) {
- infos[name].usedGlobally = true;
- }
- }
+ ElementUtils::iterAllElementFunctionNames(
+ module, [&](Name name) { infos[name].usedGlobally = true; });
for (auto& global : module->globals) {
if (!global->imported()) {
diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp
index 0a991caa5..a53481d65 100644
--- a/src/passes/LegalizeJSInterface.cpp
+++ b/src/passes/LegalizeJSInterface.cpp
@@ -31,6 +31,7 @@
//
#include "asmjs/shared-constants.h"
+#include "ir/element-utils.h"
#include "ir/import-utils.h"
#include "ir/literal-utils.h"
#include "ir/utils.h"
@@ -97,13 +98,11 @@ struct LegalizeJSInterface : public Pass {
// we need to use the legalized version in the tables, as the import
// from JS is legal for JS. Our stub makes it look like a native wasm
// function.
- for (auto& segment : module->elementSegments) {
- for (auto& name : segment->data) {
- if (name == im->name) {
- name = funcName;
- }
+ ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
+ if (name == im->name) {
+ name = funcName;
}
- }
+ });
}
}
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 4346fe705..6dfd495c0 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -2704,6 +2704,18 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
if (curr->data.empty()) {
return;
}
+ bool allElementsRefFunc =
+ std::all_of(curr->data.begin(), curr->data.end(), [](Expression* entry) {
+ return entry->is<RefFunc>();
+ });
+ auto printElemType = [&]() {
+ if (allElementsRefFunc) {
+ TypeNamePrinter(o, currModule).print(HeapType::func);
+ } else {
+ TypeNamePrinter(o, currModule).print(Type::funcref);
+ }
+ };
+
doIndent(o, indent);
o << '(';
printMedium(o, "elem");
@@ -2714,7 +2726,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
if (curr->table.is()) {
// TODO(reference-types): check for old-style based on the complete spec
- if (currModule->tables.size() > 1) {
+ if (!allElementsRefFunc || currModule->tables.size() > 1) {
// tableuse
o << " (table ";
printName(curr->table, o);
@@ -2724,18 +2736,26 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
o << ' ';
visit(curr->offset);
- if (currModule->tables.size() > 1) {
+ if (!allElementsRefFunc || currModule->tables.size() > 1) {
o << ' ';
- TypeNamePrinter(o, currModule).print(HeapType::func);
+ printElemType();
}
} else {
o << ' ';
- TypeNamePrinter(o, currModule).print(HeapType::func);
+ printElemType();
}
- for (auto name : curr->data) {
- o << ' ';
- printName(name, o);
+ if (allElementsRefFunc) {
+ for (auto* entry : curr->data) {
+ auto* refFunc = entry->cast<RefFunc>();
+ o << ' ';
+ printName(refFunc->func, o);
+ }
+ } else {
+ for (auto* entry : curr->data) {
+ o << ' ';
+ printExpression(entry, o);
+ }
}
o << ')' << maybeNewLine;
}
diff --git a/src/passes/PrintCallGraph.cpp b/src/passes/PrintCallGraph.cpp
index d70e0d7f2..e4576b84b 100644
--- a/src/passes/PrintCallGraph.cpp
+++ b/src/passes/PrintCallGraph.cpp
@@ -22,6 +22,7 @@
#include <iomanip>
#include <memory>
+#include "ir/element-utils.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "pass.h"
@@ -96,12 +97,10 @@ struct PrintCallGraph : public Pass {
CallPrinter printer(module);
// Indirect Targets
- for (auto& segment : module->elementSegments) {
- for (auto& curr : segment->data) {
- auto* func = module->getFunction(curr);
- o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n";
- }
- }
+ ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
+ auto* func = module->getFunction(name);
+ o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n";
+ });
o << "}\n";
}
diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp
index be41717f4..2a1c119fd 100644
--- a/src/passes/RemoveImports.cpp
+++ b/src/passes/RemoveImports.cpp
@@ -22,6 +22,7 @@
// look at all the rest of the code).
//
+#include "ir/element-utils.h"
#include "ir/module-utils.h"
#include "pass.h"
#include "wasm.h"
@@ -49,11 +50,8 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> {
*curr, [&](Function* func) { names.push_back(func->name); });
// Do not remove names referenced in a table
std::set<Name> indirectNames;
- for (auto& segment : curr->elementSegments) {
- for (auto& name : segment->data) {
- indirectNames.insert(name);
- }
- }
+ ElementUtils::iterAllElementFunctionNames(
+ curr, [&](Name& name) { indirectNames.insert(name); });
for (auto& name : names) {
if (indirectNames.find(name) == indirectNames.end()) {
curr->removeFunction(name);
diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp
index fa295dac5..b5cf2635a 100644
--- a/src/passes/RemoveUnusedModuleElements.cpp
+++ b/src/passes/RemoveUnusedModuleElements.cpp
@@ -22,6 +22,7 @@
#include <memory>
+#include "ir/element-utils.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "pass.h"
@@ -194,11 +195,9 @@ struct RemoveUnusedModuleElements : public Pass {
importsMemory = true;
}
// For now, all functions that can be called indirectly are marked as roots.
- for (auto& segment : module->elementSegments) {
- for (auto& curr : segment->data) {
- roots.emplace_back(ModuleElementKind::Function, curr);
- }
- }
+ ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
+ roots.emplace_back(ModuleElementKind::Function, name);
+ });
// Compute reachability starting from the root set.
ReachabilityAnalyzer analyzer(module, roots);
// Remove unreachable elements.
diff --git a/src/passes/ReorderFunctions.cpp b/src/passes/ReorderFunctions.cpp
index 0c95101a5..66b8275ef 100644
--- a/src/passes/ReorderFunctions.cpp
+++ b/src/passes/ReorderFunctions.cpp
@@ -29,6 +29,7 @@
#include <memory>
+#include <ir/element-utils.h>
#include <pass.h>
#include <wasm.h>
@@ -70,11 +71,8 @@ struct ReorderFunctions : public Pass {
for (auto& curr : module->exports) {
counts[curr->value]++;
}
- for (auto& segment : module->elementSegments) {
- for (auto& curr : segment->data) {
- counts[curr]++;
- }
- }
+ ElementUtils::iterAllElementFunctionNames(
+ module, [&](Name& name) { counts[name]++; });
// sort
std::sort(module->functions.begin(),
module->functions.end(),
diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h
index 27a3c5c25..b333779f7 100644
--- a/src/passes/opt-utils.h
+++ b/src/passes/opt-utils.h
@@ -20,6 +20,7 @@
#include <functional>
#include <unordered_set>
+#include <ir/element-utils.h>
#include <pass.h>
#include <wasm.h>
@@ -86,11 +87,7 @@ inline void replaceFunctions(PassRunner* runner,
// replace direct calls
FunctionRefReplacer(maybeReplace).run(runner, &module);
// replace in table
- for (auto& segment : module.elementSegments) {
- for (auto& name : segment->data) {
- maybeReplace(name);
- }
- }
+ ElementUtils::iterAllElementFunctionNames(&module, maybeReplace);
// replace in start
if (module.start.is()) {
diff --git a/src/shell-interface.h b/src/shell-interface.h
index 1cb8768cf..0ba4946dd 100644
--- a/src/shell-interface.h
+++ b/src/shell-interface.h
@@ -86,7 +86,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
}
} memory;
- std::unordered_map<Name, std::vector<Name>> tables;
+ std::unordered_map<Name, std::vector<Literal>> tables;
ShellExternalInterface() : memory() {}
virtual ~ShellExternalInterface() = default;
@@ -177,7 +177,10 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
if (index >= table.size()) {
trap("callTable overflow");
}
- auto* func = instance.wasm.getFunctionOrNull(table[index]);
+ Function* func = nullptr;
+ if (table[index].isFunction() && !table[index].isNull()) {
+ func = instance.wasm.getFunctionOrNull(table[index].getFunc());
+ }
if (!func) {
trap("uninitialized table element");
}
@@ -231,7 +234,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
memory.set<std::array<uint8_t, 16>>(addr, value);
}
- void tableStore(Name tableName, Address addr, Name entry) override {
+ void tableStore(Name tableName, Address addr, Literal entry) override {
tables[tableName][addr] = entry;
}
diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h
index a0d5430d8..595652c94 100644
--- a/src/tools/fuzzing.h
+++ b/src/tools/fuzzing.h
@@ -723,7 +723,9 @@ private:
while (oneIn(3) && !finishedInput) {
auto& randomElem =
wasm.elementSegments[upTo(wasm.elementSegments.size())];
- randomElem->data.push_back(func->name);
+ // FIXME: make the type NonNullable when we support it!
+ auto type = Type(HeapType(func->sig), Nullable);
+ randomElem->data.push_back(builder.makeRefFunc(func->name, type));
}
numAddedFunctions++;
return func;
@@ -1436,11 +1438,13 @@ private:
bool isReturn;
while (1) {
// TODO: handle unreachable
- targetFn = wasm.getFunction(data[i]);
- isReturn = type == Type::unreachable && wasm.features.hasTailCall() &&
- funcContext->func->sig.results == targetFn->sig.results;
- if (targetFn->sig.results == type || isReturn) {
- break;
+ if (auto* get = data[i]->dynCast<RefFunc>()) {
+ targetFn = wasm.getFunction(get->func);
+ isReturn = type == Type::unreachable && wasm.features.hasTailCall() &&
+ funcContext->func->sig.results == targetFn->sig.results;
+ if (targetFn->sig.results == type || isReturn) {
+ break;
+ }
}
i++;
if (i == data.size()) {
diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp
index f39d8a34e..13a522135 100644
--- a/src/tools/wasm-ctor-eval.cpp
+++ b/src/tools/wasm-ctor-eval.cpp
@@ -251,19 +251,25 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface {
}
auto end = start + segment->data.size();
if (start <= index && index < end) {
- auto name = segment->data[index - start];
- // if this is one of our functions, we can call it; if it was imported,
- // fail
- auto* func = wasm->getFunction(name);
- if (func->sig != sig) {
- throw FailToEvalException(
- std::string("callTable signature mismatch: ") + name.str);
- }
- if (!func->imported()) {
- return instance.callFunctionInternal(name, arguments);
+ auto entry = segment->data[index - start];
+ if (auto* get = entry->dynCast<RefFunc>()) {
+ auto name = get->func;
+ // if this is one of our functions, we can call it; if it was
+ // imported, fail
+ auto* func = wasm->getFunction(name);
+ if (func->sig != sig) {
+ throw FailToEvalException(
+ std::string("callTable signature mismatch: ") + name.str);
+ }
+ if (!func->imported()) {
+ return instance.callFunctionInternal(name, arguments);
+ } else {
+ throw FailToEvalException(
+ std::string("callTable on imported function: ") + name.str);
+ }
} else {
throw FailToEvalException(
- std::string("callTable on imported function: ") + name.str);
+ std::string("callTable on uninitialized entry"));
}
}
}
@@ -295,7 +301,7 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface {
}
// called during initialization, but we don't keep track of a table
- void tableStore(Name tableName, Address addr, Name value) override {}
+ void tableStore(Name tableName, Address addr, Literal value) override {}
bool growMemory(Address /*oldSize*/, Address newSize) override {
throw FailToEvalException("grow memory");
diff --git a/src/tools/wasm-metadce.cpp b/src/tools/wasm-metadce.cpp
index 71ea2f693..5ddf52dae 100644
--- a/src/tools/wasm-metadce.cpp
+++ b/src/tools/wasm-metadce.cpp
@@ -26,6 +26,7 @@
#include <memory>
+#include "ir/element-utils.h"
#include "ir/module-utils.h"
#include "pass.h"
#include "support/colors.h"
@@ -216,13 +217,14 @@ struct MetaDCEGraph {
ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) {
// TODO: currently, all functions in the table are roots, but we
// should add an option to refine that
- for (auto& name : segment->data) {
- if (!wasm.getFunction(name)->imported()) {
- roots.insert(functionToDCENode[name]);
- } else {
- roots.insert(importIdToDCENode[getFunctionImportId(name)]);
- }
- }
+ ElementUtils::iterElementSegmentFunctionNames(
+ segment, [&](Name name, Index) {
+ if (!wasm.getFunction(name)->imported()) {
+ roots.insert(functionToDCENode[name]);
+ } else {
+ roots.insert(importIdToDCENode[getFunctionImportId(name)]);
+ }
+ });
rooter.walk(segment->offset);
});
for (auto& segment : wasm.memory.segments) {
diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp
index efe5c1caa..19afc2032 100644
--- a/src/tools/wasm-reduce.cpp
+++ b/src/tools/wasm-reduce.cpp
@@ -766,7 +766,31 @@ struct Reducer
}
// the "opposite" of shrinking: copy a 'zero' element
for (auto& segment : curr->segments) {
- reduceByZeroing(&segment, 0, 2, shrank);
+ reduceByZeroing(
+ &segment, 0, [](char item) { return item == 0; }, 2, shrank);
+ }
+ }
+
+ template<typename T, typename U, typename C>
+ void
+ reduceByZeroing(T* segment, U zero, C isZero, size_t bonus, bool shrank) {
+ for (auto& item : segment->data) {
+ if (!shouldTryToReduce(bonus) || isZero(item)) {
+ continue;
+ }
+ auto save = item;
+ item = zero;
+ if (writeAndTestReduction()) {
+ std::cerr << "| zeroed elem segment\n";
+ noteReduction();
+ } else {
+ item = save;
+ }
+ if (shrank) {
+ // zeroing is fairly inefficient. if we are managing to shrink
+ // (which we do exponentially), just zero one per segment at most
+ break;
+ }
}
}
@@ -803,37 +827,9 @@ struct Reducer
return shrank;
}
- template<typename T, typename U>
- void reduceByZeroing(T* segment, U zero, size_t bonus, bool shrank) {
- if (segment->data.empty()) {
- return;
- }
- for (auto& item : segment->data) {
- if (!shouldTryToReduce(bonus)) {
- continue;
- }
- if (item == zero) {
- continue;
- }
- auto save = item;
- item = zero;
- if (writeAndTestReduction()) {
- std::cerr << "| zeroed elem segment\n";
- noteReduction();
- } else {
- item = save;
- }
- if (shrank) {
- // zeroing is fairly inefficient. if we are managing to shrink
- // (which we do exponentially), just zero one per segment at most
- break;
- }
- }
- }
-
void shrinkElementSegments(Module* module) {
std::cerr << "| try to simplify elem segments\n";
- Name first;
+ Expression* first = nullptr;
auto it =
std::find_if_not(module->elementSegments.begin(),
module->elementSegments.end(),
@@ -842,6 +838,10 @@ struct Reducer
if (it != module->elementSegments.end()) {
first = it->get()->data[0];
}
+ if (first == nullptr) {
+ // The elements are all empty, nothing to shrink
+ return;
+ }
// try to reduce to first function. first, shrink segment elements.
// while we are shrinking successfully, keep going exponentially.
@@ -851,7 +851,24 @@ struct Reducer
}
// the "opposite" of shrinking: copy a 'zero' element
for (auto& segment : module->elementSegments) {
- reduceByZeroing(segment.get(), first, 100, shrank);
+ reduceByZeroing(
+ segment.get(),
+ first,
+ [&](Expression* entry) {
+ if (entry->is<RefNull>()) {
+ // we don't need to replace a ref.null
+ return true;
+ } else if (first->is<RefNull>()) {
+ return false;
+ } else {
+ // Both are ref.func
+ auto* f = first->cast<RefFunc>();
+ auto* e = entry->cast<RefFunc>();
+ return f->func == e->func;
+ }
+ },
+ 100,
+ shrank);
}
}
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp
index efa2c0ca4..6d2077e1c 100644
--- a/src/tools/wasm-shell.cpp
+++ b/src/tools/wasm-shell.cpp
@@ -23,6 +23,7 @@
#include <memory>
#include "execution-results.h"
+#include "ir/element-utils.h"
#include "pass.h"
#include "shell-interface.h"
#include "support/command-line.h"
@@ -166,18 +167,16 @@ run_asserts(Name moduleName,
reportUnknownImport(import);
}
});
- for (auto& segment : wasm.elementSegments) {
- for (auto name : segment->data) {
- // spec tests consider it illegal to use spectest.print in a table
- if (auto* import = wasm.getFunction(name)) {
- if (import->imported() && import->module == SPECTEST &&
- import->base.startsWith(PRINT)) {
- std::cerr << "cannot put spectest.print in table\n";
- invalid = true;
- }
+ ElementUtils::iterAllElementFunctionNames(&wasm, [&](Name name) {
+ // spec tests consider it illegal to use spectest.print in a table
+ if (auto* import = wasm.getFunction(name)) {
+ if (import->imported() && import->module == SPECTEST &&
+ import->base.startsWith(PRINT)) {
+ std::cerr << "cannot put spectest.print in table\n";
+ invalid = true;
}
}
- }
+ });
if (wasm.memory.imported()) {
reportUnknownImport(&wasm.memory);
}
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 1cf8186f8..ec89bb1da 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -1531,9 +1531,6 @@ public:
void readDataSegments();
void readDataCount();
- // A map from elem segment indexes to their entries
- std::map<Index, std::vector<Index>> functionTable;
-
void readTableDeclarations();
void readElementSegments();
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index dcf9d2540..61d29ada1 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -2222,7 +2222,7 @@ public:
WASM_UNREACHABLE("unimp");
}
- virtual void tableStore(Name tableName, Address addr, Name entry) {
+ virtual void tableStore(Name tableName, Address addr, Literal entry) {
WASM_UNREACHABLE("unimp");
}
};
@@ -2309,21 +2309,22 @@ private:
std::unordered_set<size_t> droppedSegments;
void initializeTableContents() {
- for (auto& segment : wasm.elementSegments) {
- if (segment->table.isNull()) {
- continue;
- }
-
+ ModuleUtils::iterActiveElementSegments(wasm, [&](ElementSegment* segment) {
Address offset =
(uint32_t)InitializerExpressionRunner<GlobalManager>(globals, maxDepth)
.visit(segment->offset)
.getSingleValue()
.geti32();
- for (size_t i = 0; i != segment->data.size(); ++i) {
+
+ Function dummyFunc;
+ FunctionScope dummyScope(&dummyFunc, {});
+ RuntimeExpressionRunner runner(*this, dummyScope, maxDepth);
+ for (Index i = 0; i < segment->data.size(); ++i) {
+ Flow ret = runner.visit(segment->data[i]);
externalInterface->tableStore(
- segment->table, offset + i, segment->data[i]);
+ segment->table, offset + i, ret.getSingleValue());
}
- }
+ });
}
void initializeMemoryContents() {
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index eaddd8c8b..03da961bd 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -317,7 +317,8 @@ private:
void parseElem(Element& s, Table* table = nullptr);
ElementSegment* parseElemFinish(Element& s,
std::unique_ptr<ElementSegment>& segment,
- Index i = 1);
+ Index i = 1,
+ bool usesExpressions = false);
// Parses something like (func ..), (array ..), (struct)
HeapType parseHeapType(Element& s);
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 7dcf2d146..1f307c318 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -196,6 +196,9 @@ struct Walker : public VisitorType {
if (segment->table.is()) {
walk(segment->offset);
}
+ for (auto* expr : segment->data) {
+ walk(expr);
+ }
static_cast<SubType*>(this)->visitElementSegment(segment);
}
diff --git a/src/wasm.h b/src/wasm.h
index 630e5b003..4747b48a3 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1722,12 +1722,12 @@ class ElementSegment : public Named {
public:
Name table;
Expression* offset;
- std::vector<Name> data;
+ std::vector<Expression*> data;
ElementSegment() = default;
ElementSegment(Name table, Expression* offset)
: table(table), offset(offset) {}
- ElementSegment(Name table, Expression* offset, std::vector<Name>& init)
+ ElementSegment(Name table, Expression* offset, std::vector<Expression*>& init)
: table(table), offset(offset) {
data.swap(init);
}
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 09d3287b0..5fff58045 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -571,9 +571,11 @@ void WasmBinaryWriter::writeElementSegments() {
Index tableIdx = 0;
bool isPassive = segment->table.isNull();
- // TODO(reference-types): add support for writing expressions instead of
- // function indices.
- bool usesExpressions = false;
+ // if all items are ref.func, we can use the shorter form.
+ bool usesExpressions =
+ std::any_of(segment->data.begin(),
+ segment->data.end(),
+ [](Expression* curr) { return !curr->is<RefFunc>(); });
bool hasTableIndex = false;
if (!isPassive) {
@@ -600,13 +602,27 @@ void WasmBinaryWriter::writeElementSegments() {
o << int8_t(BinaryConsts::End);
}
- if (!usesExpressions && (isPassive || hasTableIndex)) {
- // elemKind funcref
- o << U32LEB(0);
+ if (isPassive || hasTableIndex) {
+ if (usesExpressions) {
+ // elemType funcref
+ writeType(Type::funcref);
+ } else {
+ // elemKind funcref
+ o << U32LEB(0);
+ }
}
o << U32LEB(segment->data.size());
- for (auto& name : segment->data) {
- o << U32LEB(getFunctionIndex(name));
+ if (usesExpressions) {
+ for (auto* item : segment->data) {
+ writeExpression(item);
+ o << int8_t(BinaryConsts::End);
+ }
+ } else {
+ for (auto& item : segment->data) {
+ // We've ensured that all items are ref.func.
+ auto& name = item->cast<RefFunc>()->func;
+ o << U32LEB(getFunctionIndex(name));
+ }
}
}
@@ -2676,14 +2692,6 @@ void WasmBinaryBuilder::processNames() {
}
}
- for (auto& pair : functionTable) {
- auto i = pair.first;
- auto& indices = pair.second;
- for (auto j : indices) {
- wasm.elementSegments[i]->data.push_back(getFunctionName(j));
- }
- }
-
for (auto& iter : globalRefs) {
size_t index = iter.first;
auto& refs = iter.second;
@@ -2787,10 +2795,6 @@ void WasmBinaryBuilder::readElementSegments() {
continue;
}
- if (usesExpressions) {
- throwError("Only elem segments with function indexes are supported.");
- }
-
if (!isPassive) {
Index tableIdx = 0;
if (hasTableIdx) {
@@ -2819,17 +2823,35 @@ void WasmBinaryBuilder::readElementSegments() {
}
if (isPassive || hasTableIdx) {
- auto elemKind = getU32LEB();
- if (elemKind != 0x0) {
- throwError("Only funcref elem kinds are valid.");
+ if (usesExpressions) {
+ auto type = getType();
+ if (type != Type::funcref) {
+ throwError("Only funcref elem kinds are valid.");
+ }
+ } else {
+ auto elemKind = getU32LEB();
+ if (elemKind != 0x0) {
+ throwError("Only funcref elem kinds are valid.");
+ }
}
}
- size_t segmentIndex = functionTable.size();
- auto& indexSegment = functionTable[segmentIndex];
+ auto& segmentData = elementSegments.back()->data;
auto size = getU32LEB();
- for (Index j = 0; j < size; j++) {
- indexSegment.push_back(getU32LEB());
+ if (usesExpressions) {
+ for (Index j = 0; j < size; j++) {
+ segmentData.push_back(readExpression());
+ }
+ } else {
+ for (Index j = 0; j < size; j++) {
+ Index index = getU32LEB();
+ auto sig = getSignatureByFunctionIndex(index);
+ // Use a placeholder name for now
+ auto* refFunc = Builder(wasm).makeRefFunc(
+ Name::fromInt(index), Type(HeapType(sig), Nullable));
+ functionRefs[index].push_back(refFunc);
+ segmentData.push_back(refFunc);
+ }
}
}
}
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 45c7c9161..7774eac3e 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -3268,12 +3268,14 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) {
Index i = 1;
Name name = Name::fromInt(elemCounter++);
bool hasExplicitName = false;
+ bool isPassive = false;
+ bool usesExpressions = false;
if (table) {
Expression* offset = allocator.alloc<Const>()->set(Literal(int32_t(0)));
auto segment = std::make_unique<ElementSegment>(table->name, offset);
segment->setName(name, hasExplicitName);
- parseElemFinish(s, segment, i);
+ parseElemFinish(s, segment, i, false);
return;
}
@@ -3286,10 +3288,20 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) {
return;
}
- if (s[i]->isStr() && s[i]->str() == FUNC) {
+ if (s[i]->isStr()) {
+ if (s[i]->str() == FUNC) {
+ isPassive = true;
+ usesExpressions = false;
+ } else if (s[i]->str() == FUNCREF) {
+ isPassive = true;
+ usesExpressions = true;
+ }
+ }
+
+ if (isPassive) {
auto segment = std::make_unique<ElementSegment>();
segment->setName(name, hasExplicitName);
- parseElemFinish(s, segment, i + 1);
+ parseElemFinish(s, segment, i + 1, usesExpressions);
return;
}
@@ -3326,11 +3338,11 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) {
}
if (!oldStyle) {
- if (s[i]->str() != FUNC) {
- throw ParseException(
- "only the abbreviated form of elemList is supported.");
+ if (s[i]->str() == FUNCREF) {
+ usesExpressions = true;
+ } else if (s[i]->str() != FUNC) {
+ throw ParseException("expected func or funcref.");
}
- // ignore elemType for now
i += 1;
}
@@ -3340,13 +3352,40 @@ void SExpressionWasmBuilder::parseElem(Element& s, Table* table) {
auto segment = std::make_unique<ElementSegment>(table->name, offset);
segment->setName(name, hasExplicitName);
- parseElemFinish(s, segment, i);
+ parseElemFinish(s, segment, i, usesExpressions);
}
ElementSegment* SExpressionWasmBuilder::parseElemFinish(
- Element& s, std::unique_ptr<ElementSegment>& segment, Index i) {
- for (; i < s.size(); i++) {
- segment->data.push_back(getFunctionName(*s[i]));
+ Element& s,
+ std::unique_ptr<ElementSegment>& segment,
+ Index i,
+ bool usesExpressions) {
+
+ if (usesExpressions) {
+ for (; i < s.size(); i++) {
+ if (!s[i]->isList()) {
+ throw ParseException("expected a ref.* expression.");
+ }
+ auto& inner = *s[i];
+ if (elementStartsWith(inner, ITEM)) {
+ if (inner[1]->isList()) {
+ // (item (ref.func $f))
+ segment->data.push_back(parseExpression(inner[1]));
+ } else {
+ // (item ref.func $f)
+ inner.list().removeAt(0);
+ segment->data.push_back(parseExpression(inner));
+ }
+ } else {
+ segment->data.push_back(parseExpression(inner));
+ }
+ }
+ } else {
+ for (; i < s.size(); i++) {
+ auto func = getFunctionName(*s[i]);
+ segment->data.push_back(Builder(wasm).makeRefFunc(
+ func, Type(HeapType(functionSignatures[func]), Nullable)));
+ }
}
return wasm.addElementSegment(std::move(segment));
}
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 71ecf208e..52ba36872 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -1960,7 +1960,10 @@ void FunctionValidator::visitMemoryGrow(MemoryGrow* curr) {
}
void FunctionValidator::visitRefNull(RefNull* curr) {
- shouldBeTrue(getModule()->features.hasReferenceTypes(),
+ // If we are not in a function, this is a global location like a table. We
+ // allow RefNull there as we represent tables that way regardless of what
+ // features are enabled.
+ shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(),
curr,
"ref.null requires reference-types to be enabled");
shouldBeTrue(
@@ -1978,7 +1981,10 @@ void FunctionValidator::visitRefIs(RefIs* curr) {
}
void FunctionValidator::visitRefFunc(RefFunc* curr) {
- shouldBeTrue(getModule()->features.hasReferenceTypes(),
+ // If we are not in a function, this is a global location like a table. We
+ // allow RefFunc there as we represent tables that way regardless of what
+ // features are enabled.
+ shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(),
curr,
"ref.func requires reference-types to be enabled");
if (!info.validateGlobally) {
@@ -2799,11 +2805,29 @@ static void validateMemory(Module& module, ValidationInfo& info) {
}
static void validateTables(Module& module, ValidationInfo& info) {
+ FunctionValidator validator(module, &info);
+
if (!module.features.hasReferenceTypes()) {
info.shouldBeTrue(module.tables.size() <= 1,
"table",
"Only 1 table definition allowed in MVP (requires "
"--enable-reference-types)");
+ if (!module.tables.empty()) {
+ auto& table = module.tables.front();
+ for (auto& segment : module.elementSegments) {
+ info.shouldBeTrue(segment->table == table->name,
+ "elem",
+ "all element segments should refer to a single table "
+ "in MVP.");
+ for (auto* expr : segment->data) {
+ info.shouldBeTrue(
+ expr->is<RefFunc>(),
+ expr,
+ "all table elements must be non-null funcrefs in MVP.");
+ validator.validate(expr);
+ }
+ }
+ }
}
for (auto& segment : module.elementSegments) {
@@ -2820,11 +2844,17 @@ static void validateTables(Module& module, ValidationInfo& info) {
table->initial * Table::kPageSize),
segment->offset,
"table segment offset should be reasonable");
- FunctionValidator(module, &info).validate(segment->offset);
- }
- for (auto name : segment->data) {
- info.shouldBeTrue(
- module.getFunctionOrNull(name), name, "segment name should be valid");
+ validator.validate(segment->offset);
+ }
+ // Avoid double checking items
+ if (module.features.hasReferenceTypes()) {
+ for (auto* expr : segment->data) {
+ info.shouldBeTrue(
+ expr->is<RefFunc>() || expr->is<RefNull>(),
+ expr,
+ "element segment items must be either ref.func or ref.null func.");
+ validator.validate(expr);
+ }
}
}
}
diff --git a/src/wasm2js.h b/src/wasm2js.h
index b3a4c7e18..6d3f0682c 100644
--- a/src/wasm2js.h
+++ b/src/wasm2js.h
@@ -32,6 +32,7 @@
#include "emscripten-optimizer/optimizer.h"
#include "ir/branch-utils.h"
#include "ir/effects.h"
+#include "ir/element-utils.h"
#include "ir/find_all.h"
#include "ir/import-utils.h"
#include "ir/load-utils.h"
@@ -325,11 +326,8 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) {
functionsCallableFromOutside.insert(exp->value);
}
}
- for (auto& segment : wasm->elementSegments) {
- for (auto name : segment->data) {
- functionsCallableFromOutside.insert(name);
- }
- }
+ ElementUtils::iterAllElementFunctionNames(
+ wasm, [&](Name name) { functionsCallableFromOutside.insert(name); });
// Ensure the scratch memory helpers.
// If later on they aren't needed, we'll clean them up.
@@ -680,26 +678,27 @@ void Wasm2JSBuilder::addTable(Ref ast, Module* wasm) {
ModuleUtils::iterTableSegments(
*wasm, table->name, [&](ElementSegment* segment) {
auto offset = segment->offset;
- for (Index i = 0; i < segment->data.size(); i++) {
- Ref index;
- if (auto* c = offset->dynCast<Const>()) {
- index = ValueBuilder::makeInt(c->value.geti32() + i);
- } else if (auto* get = offset->dynCast<GlobalGet>()) {
- index =
- ValueBuilder::makeBinary(ValueBuilder::makeName(stringToIString(
- asmangle(get->name.str))),
- PLUS,
- ValueBuilder::makeNum(i));
- } else {
- WASM_UNREACHABLE("unexpected expr type");
- }
- ast->push_back(ValueBuilder::makeStatement(ValueBuilder::makeBinary(
- ValueBuilder::makeSub(ValueBuilder::makeName(FUNCTION_TABLE),
- index),
- SET,
- ValueBuilder::makeName(
- fromName(segment->data[i], NameScope::Top)))));
- }
+ ElementUtils::iterElementSegmentFunctionNames(
+ segment, [&](Name entry, Index i) {
+ Ref index;
+ if (auto* c = offset->dynCast<Const>()) {
+ index = ValueBuilder::makeInt(c->value.geti32() + i);
+ } else if (auto* get = offset->dynCast<GlobalGet>()) {
+ index = ValueBuilder::makeBinary(
+ ValueBuilder::makeName(
+ stringToIString(asmangle(get->name.str))),
+ PLUS,
+ ValueBuilder::makeNum(i));
+ } else {
+ WASM_UNREACHABLE("unexpected expr type");
+ }
+ ast->push_back(
+ ValueBuilder::makeStatement(ValueBuilder::makeBinary(
+ ValueBuilder::makeSub(ValueBuilder::makeName(FUNCTION_TABLE),
+ index),
+ SET,
+ ValueBuilder::makeName(fromName(entry, NameScope::Top)))));
+ });
});
}
}
diff --git a/test/multi-table.wast b/test/multi-table.wast
index 378d4018a..cbfa369e0 100644
--- a/test/multi-table.wast
+++ b/test/multi-table.wast
@@ -1,6 +1,7 @@
(module
(import "a" "b" (table $t1 1 10 funcref))
(table $t2 3 3 funcref)
+ (table $t3 4 4 funcref)
;; add to $t1
(elem (i32.const 0) $f)
@@ -9,7 +10,10 @@
(elem (table $t2) (i32.const 0) func $f)
(elem $activeNonZeroOffset (table $t2) (offset (i32.const 1)) func $f $g)
- (elem $passive func $f $g)
+ (elem $e3-1 (table $t3) (i32.const 0) funcref (ref.func $f) (ref.null func))
+ (elem $e3-2 (table $t3) (offset (i32.const 2)) funcref (item ref.func $f) (item (ref.func $g)))
+
+ (elem $passive funcref (item ref.func $f) (item (ref.func $g)) (ref.null func))
(elem $empty func)
(elem $declarative declare func $h)
diff --git a/test/multi-table.wast.from-wast b/test/multi-table.wast.from-wast
index f60bab8cb..80c081354 100644
--- a/test/multi-table.wast.from-wast
+++ b/test/multi-table.wast.from-wast
@@ -5,7 +5,10 @@
(table $t2 3 3 funcref)
(elem (table $t2) (i32.const 0) func $f)
(elem $activeNonZeroOffset (table $t2) (i32.const 1) func $f $g)
- (elem $passive func $f $g)
+ (table $t3 4 4 funcref)
+ (elem $e3-1 (table $t3) (i32.const 0) funcref (ref.func $f) (ref.null func))
+ (elem $e3-2 (table $t3) (i32.const 2) func $f $g)
+ (elem $passive funcref (ref.func $f) (ref.func $g) (ref.null func))
(elem declare func $h)
(func $f
(drop
diff --git a/test/multi-table.wast.fromBinary b/test/multi-table.wast.fromBinary
index 13604e69d..9918b6183 100644
--- a/test/multi-table.wast.fromBinary
+++ b/test/multi-table.wast.fromBinary
@@ -5,7 +5,10 @@
(table $t2 3 3 funcref)
(elem (table $t2) (i32.const 0) func $f)
(elem $activeNonZeroOffset (table $t2) (i32.const 1) func $f $g)
- (elem $passive func $f $g)
+ (table $t3 4 4 funcref)
+ (elem $e3-1 (table $t3) (i32.const 0) funcref (ref.func $f) (ref.null func))
+ (elem $e3-2 (table $t3) (i32.const 2) func $f $g)
+ (elem $passive funcref (ref.func $f) (ref.func $g) (ref.null func))
(elem declare func $h)
(func $f
(drop
diff --git a/test/multi-table.wast.fromBinary.noDebugInfo b/test/multi-table.wast.fromBinary.noDebugInfo
index 77e2ea439..638ee7329 100644
--- a/test/multi-table.wast.fromBinary.noDebugInfo
+++ b/test/multi-table.wast.fromBinary.noDebugInfo
@@ -5,7 +5,10 @@
(table $0 3 3 funcref)
(elem (table $0) (i32.const 0) func $0)
(elem (table $0) (i32.const 1) func $0 $1)
- (elem func $0 $1)
+ (table $1 4 4 funcref)
+ (elem (table $1) (i32.const 0) funcref (ref.func $0) (ref.null func))
+ (elem (table $1) (i32.const 2) func $0 $1)
+ (elem funcref (ref.func $0) (ref.func $1) (ref.null func))
(elem declare func $2)
(func $0
(drop
diff --git a/test/passes/converge_O3_metrics.bin.txt b/test/passes/converge_O3_metrics.bin.txt
index 2df4a4b0b..4c1bed30d 100644
--- a/test/passes/converge_O3_metrics.bin.txt
+++ b/test/passes/converge_O3_metrics.bin.txt
@@ -7,7 +7,7 @@ total
[memory-data] : 28
[table-data] : 429
[tables] : 0
- [total] : 129
+ [total] : 558
[vars] : 4
Binary : 12
Block : 8
@@ -23,6 +23,7 @@ total
LocalGet : 18
LocalSet : 7
Loop : 1
+ RefFunc : 429
Store : 5
(module
(type $i32_i32_i32_=>_i32 (func (param i32 i32 i32) (result i32)))
@@ -249,7 +250,7 @@ total
[memory-data] : 28
[table-data] : 429
[tables] : 0
- [total] : 129
+ [total] : 558
[vars] : 4
Binary : 12
Block : 8
@@ -265,6 +266,7 @@ total
LocalGet : 18
LocalSet : 7
Loop : 1
+ RefFunc : 429
Store : 5
(module
(type $i32_i32_i32_=>_i32 (func (param i32 i32 i32) (result i32)))
diff --git a/test/passes/func-metrics.txt b/test/passes/func-metrics.txt
index 1ca5d2fd9..61d97ddad 100644
--- a/test/passes/func-metrics.txt
+++ b/test/passes/func-metrics.txt
@@ -7,8 +7,9 @@ global
[memory-data] : 9
[table-data] : 3
[tables] : 1
- [total] : 3
+ [total] : 6
Const : 3
+ RefFunc : 3
func: empty
[binary-bytes] : 3
[total] : 1
diff --git a/test/passes/fuzz_metrics_noprint.bin.txt b/test/passes/fuzz_metrics_noprint.bin.txt
index eb24adfa4..d2b5e37e9 100644
--- a/test/passes/fuzz_metrics_noprint.bin.txt
+++ b/test/passes/fuzz_metrics_noprint.bin.txt
@@ -7,7 +7,7 @@ total
[memory-data] : 4
[table-data] : 17
[tables] : 1
- [total] : 5797
+ [total] : 5814
[vars] : 160
Binary : 467
Block : 810
@@ -24,6 +24,7 @@ total
LocalSet : 292
Loop : 127
Nop : 69
+ RefFunc : 17
Return : 259
Select : 53
Store : 53
diff --git a/test/passes/metrics_all-features.txt b/test/passes/metrics_all-features.txt
index ae746355d..14c1bc72d 100644
--- a/test/passes/metrics_all-features.txt
+++ b/test/passes/metrics_all-features.txt
@@ -7,13 +7,14 @@ total
[memory-data] : 9
[table-data] : 3
[tables] : 1
- [total] : 27
+ [total] : 30
[vars] : 1
Binary : 1
Block : 1
Const : 15
Drop : 6
If : 4
+ RefFunc : 3
(module
(type $0 (func (param i32)))
(type $i32_i32_=>_none (func (param i32 i32)))
diff --git a/test/spec/call_indirect_refnull.wast b/test/spec/call_indirect_refnull.wast
new file mode 100644
index 000000000..f1b4f5978
--- /dev/null
+++ b/test/spec/call_indirect_refnull.wast
@@ -0,0 +1,12 @@
+(module
+ (table $t 1 1 funcref)
+ (elem (table $t) (i32.const 0) funcref (ref.null func))
+
+ (func $call-refnull (export "call-refnull") (result f32)
+ (call_indirect (result f32) (i32.const 0))
+ )
+)
+(assert_trap
+ (invoke "call-refnull")
+ "uninitialized table element"
+) \ No newline at end of file