summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/asm2wasm.h291
-rw-r--r--src/ast_utils.h111
-rw-r--r--src/binaryen-c.cpp63
-rw-r--r--src/binaryen-c.h14
-rw-r--r--src/cfg/Relooper.cpp2
-rw-r--r--src/cfg/cfg-traversal.h70
-rw-r--r--src/js/wasm.js-post.js66
-rw-r--r--src/passes/CMakeLists.txt2
-rw-r--r--src/passes/CoalesceLocals.cpp12
-rw-r--r--src/passes/DeadCodeElimination.cpp58
-rw-r--r--src/passes/DropReturnValues.cpp83
-rw-r--r--src/passes/LowerIfElse.cpp67
-rw-r--r--src/passes/MergeBlocks.cpp69
-rw-r--r--src/passes/NameManager.cpp3
-rw-r--r--src/passes/Print.cpp148
-rw-r--r--src/passes/RemoveImports.cpp2
-rw-r--r--src/passes/RemoveUnusedBrs.cpp2
-rw-r--r--src/passes/RemoveUnusedNames.cpp28
-rw-r--r--src/passes/SimplifyLocals.cpp42
-rw-r--r--src/passes/Vacuum.cpp48
-rw-r--r--src/passes/pass.cpp5
-rw-r--r--src/passes/passes.h1
-rw-r--r--src/s2wasm.h12
-rw-r--r--src/shell-interface.h13
-rw-r--r--src/support/colors.h9
-rw-r--r--src/tools/asm2wasm.cpp28
-rw-r--r--src/tools/wasm-shell.cpp37
-rw-r--r--src/wasm-binary.h178
-rw-r--r--src/wasm-builder.h53
-rw-r--r--src/wasm-interpreter.h102
-rw-r--r--src/wasm-js.cpp123
-rw-r--r--src/wasm-linker.cpp13
-rw-r--r--src/wasm-linker.h1
-rw-r--r--src/wasm-module-building.h6
-rw-r--r--src/wasm-printing.h2
-rw-r--r--src/wasm-s-parser.h361
-rw-r--r--src/wasm-traversal.h125
-rw-r--r--src/wasm-validator.h208
-rw-r--r--src/wasm.cpp10
-rw-r--r--src/wasm.h74
40 files changed, 1649 insertions, 893 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index ce3c0749d..53fa78db4 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -143,15 +143,13 @@ class Asm2WasmBuilder {
// globals
- unsigned nextGlobal; // next place to put a global
- unsigned maxGlobal; // highest address we can put a global
struct MappedGlobal {
- unsigned address;
WasmType type;
bool import; // if true, this is an import - we should read the value, not just set a zero
IString module, base;
- MappedGlobal() : address(0), type(none), import(false) {}
- MappedGlobal(unsigned address, WasmType type, bool import, IString module, IString base) : address(address), type(type), import(import), module(module), base(base) {}
+ MappedGlobal() : type(none), import(false) {}
+ MappedGlobal(WasmType type) : type(type), import(false) {}
+ MappedGlobal(WasmType type, bool import, IString module, IString base) : type(type), import(import), module(module), base(base) {}
};
// function table
@@ -165,31 +163,20 @@ class Asm2WasmBuilder {
public:
std::map<IString, MappedGlobal> mappedGlobals;
- // the global mapping info is not present in the output wasm. We need to save it on the side
- // if we intend to load and run this module's wasm.
- void serializeMappedGlobals(const char *filename) {
- FILE *f = fopen(filename, "w");
- assert(f);
- fprintf(f, "{\n");
- bool first = true;
- for (auto& pair : mappedGlobals) {
- auto name = pair.first;
- auto& global = pair.second;
- if (first) first = false;
- else fprintf(f, ",");
- fprintf(f, "\"%s\": { \"address\": %d, \"type\": %d, \"import\": %d, \"module\": \"%s\", \"base\": \"%s\" }\n",
- name.str, global.address, global.type, global.import, global.module.str, global.base.str);
- }
- fprintf(f, "}");
- fclose(f);
- }
-
private:
- void allocateGlobal(IString name, WasmType type, bool import, IString module = IString(), IString base = IString()) {
+ void allocateGlobal(IString name, WasmType type) {
assert(mappedGlobals.find(name) == mappedGlobals.end());
- mappedGlobals.emplace(name, MappedGlobal(nextGlobal, type, import, module, base));
- nextGlobal += 8;
- assert(nextGlobal < maxGlobal);
+ mappedGlobals.emplace(name, MappedGlobal(type));
+ auto global = new Global();
+ global->name = name;
+ global->type = type;
+ Literal value;
+ if (type == i32) value = Literal(uint32_t(0));
+ else if (type == f32) value = Literal(float(0));
+ else if (type == f64) value = Literal(double(0));
+ else WASM_UNREACHABLE();
+ global->init = wasm.allocator.alloc<Const>()->set(value);
+ wasm.addGlobal(global);
}
struct View {
@@ -237,12 +224,6 @@ private:
// if we already saw this signature, verify it's the same (or else handle that)
if (importedFunctionTypes.find(importName) != importedFunctionTypes.end()) {
FunctionType* previous = importedFunctionTypes[importName].get();
-#if 0
- std::cout << "compare " << importName.str << "\nfirst: ";
- type.print(std::cout, 0);
- std::cout << "\nsecond: ";
- previous.print(std::cout, 0) << ".\n";
-#endif
if (*type != *previous) {
// merge it in. we'll add on extra 0 parameters for ones not actually used, and upgrade types to
// double where there is a conflict (which is ok since in JS, double can contain everything
@@ -260,6 +241,8 @@ private:
}
if (previous->result == none) {
previous->result = type->result; // use a more concrete type
+ } else if (previous->result != type->result) {
+ previous->result = f64; // overloaded return type, make it a double
}
}
} else {
@@ -278,8 +261,6 @@ public:
: wasm(wasm),
allocator(wasm.allocator),
builder(wasm),
- nextGlobal(8),
- maxGlobal(1000),
memoryGrowth(memoryGrowth),
debug(debug),
imprecise(imprecise),
@@ -416,9 +397,9 @@ private:
}
void fixCallType(Expression* call, WasmType type) {
- if (call->is<Call>()) call->type = type;
- if (call->is<CallImport>()) call->type = type;
- else if (call->is<CallIndirect>()) call->type = type;
+ if (call->is<Call>()) call->cast<Call>()->type = type;
+ if (call->is<CallImport>()) call->cast<CallImport>()->type = type;
+ else if (call->is<CallIndirect>()) call->cast<CallIndirect>()->type = type;
}
FunctionType* getBuiltinFunctionType(Name module, Name base, ExpressionList* operands = nullptr) {
@@ -520,12 +501,13 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
type = WasmType::f64;
}
if (type != WasmType::none) {
- // wasm has no imported constants, so allocate a global, and we need to write the value into that
- allocateGlobal(name, type, true, import->module, import->base);
- delete import;
+ import->kind = Import::Global;
+ import->globalType = type;
+ mappedGlobals.emplace(name, type);
} else {
- wasm.addImport(import);
+ import->kind = Import::Function;
}
+ wasm.addImport(import);
};
IString Int8Array, Int16Array, Int32Array, UInt8Array, UInt16Array, UInt32Array, Float32Array, Float64Array;
@@ -537,7 +519,10 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
for (unsigned i = 1; i < body->size(); i++) {
if (body[i][0] == DEFUN) numFunctions++;
}
- optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions);
+ optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions, [&](PassRunner& passRunner) {
+ // run autodrop first, before optimizations
+ passRunner.add<AutoDrop>();
+ });
}
// first pass - do almost everything, but function imports and indirect calls
@@ -553,7 +538,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
if (value[0] == NUM) {
// global int
assert(value[1]->getNumber() == 0);
- allocateGlobal(name, WasmType::i32, false);
+ allocateGlobal(name, WasmType::i32);
} else if (value[0] == BINARY) {
// int import
assert(value[1] == OR && value[3][0] == NUM && value[3][1]->getNumber() == 0);
@@ -566,14 +551,14 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
if (import[0] == NUM) {
// global
assert(import[1]->getNumber() == 0);
- allocateGlobal(name, WasmType::f64, false);
+ allocateGlobal(name, WasmType::f64);
} else {
// import
addImport(name, import, WasmType::f64);
}
} else if (value[0] == CALL) {
assert(value[1][0] == NAME && value[1][1] == Math_fround && value[2][0][0] == NUM && value[2][0][1]->getNumber() == 0);
- allocateGlobal(name, WasmType::f32, false);
+ allocateGlobal(name, WasmType::f32);
} else if (value[0] == DOT) {
// simple module.base import. can be a view, or a function.
if (value[1][0] == NAME) {
@@ -710,6 +695,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
auto* export_ = new Export;
export_->name = key;
export_->value = value;
+ export_->kind = Export::Function;
wasm.addExport(export_);
exported[key] = export_;
}
@@ -719,11 +705,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
if (optimize) {
optimizingBuilder->finish();
- if (maxGlobal < 1024) {
- PassRunner passRunner(&wasm);
- passRunner.add("post-emscripten");
- passRunner.run();
- }
+ PassRunner passRunner(&wasm);
+ passRunner.add("post-emscripten");
+ passRunner.run();
}
// second pass. first, function imports
@@ -731,15 +715,16 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
std::vector<IString> toErase;
for (auto& import : wasm.imports) {
+ if (import->kind != Import::Function) continue;
IString name = import->name;
if (importedFunctionTypes.find(name) != importedFunctionTypes.end()) {
// special math builtins
FunctionType* builtin = getBuiltinFunctionType(import->module, import->base);
if (builtin) {
- import->type = builtin;
+ import->functionType = builtin;
continue;
}
- import->type = ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm);
+ import->functionType = ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm);
} else if (import->module != ASM2WASM) { // special-case the special module
// never actually used
toErase.push_back(name);
@@ -750,7 +735,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
wasm.removeImport(curr);
}
- // Finalize indirect calls and import calls
+ // Finalize calls now that everything is known and generated
struct FinalizeCalls : public WalkerPass<PostWalker<FinalizeCalls, Visitor<FinalizeCalls>>> {
bool isFunctionParallel() override { return true; }
@@ -761,6 +746,14 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
FinalizeCalls(Asm2WasmBuilder* parent) : parent(parent) {}
+ void visitCall(Call* curr) {
+ assert(getModule()->checkFunction(curr->target) ? true : (std::cerr << curr->target << '\n', false));
+ auto result = getModule()->getFunction(curr->target)->result;
+ if (curr->type != result) {
+ curr->type = result;
+ }
+ }
+
void visitCallImport(CallImport* curr) {
// fill out call_import - add extra params as needed, etc. asm tolerates ffi overloading, wasm does not
auto iter = parent->importedFunctionTypes.find(curr->target);
@@ -782,7 +775,25 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
}
}
}
+ auto importResult = getModule()->getImport(curr->target)->functionType->result;
+ if (curr->type != importResult) {
+ if (importResult == f64) {
+ // we use a JS f64 value which is the most general, and convert to it
+ switch (curr->type) {
+ case i32: replaceCurrent(parent->builder.makeUnary(TruncSFloat64ToInt32, curr)); break;
+ case f32: replaceCurrent(parent->builder.makeUnary(DemoteFloat64, curr)); break;
+ case none: replaceCurrent(parent->builder.makeDrop(curr)); break;
+ default: WASM_UNREACHABLE();
+ }
+ } else {
+ assert(curr->type == none);
+ // we don't want a return value here, but the import does provide one
+ replaceCurrent(parent->builder.makeDrop(curr));
+ }
+ curr->type = importResult;
+ }
}
+
void visitCallIndirect(CallIndirect* curr) {
// we already call into target = something + offset, where offset is a callImport with the name of the table. replace that with the table offset
auto add = curr->target->cast<Binary>();
@@ -793,6 +804,12 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
};
PassRunner passRunner(&wasm);
passRunner.add<FinalizeCalls>(this);
+ passRunner.add<ReFinalize>(); // FinalizeCalls changes call types, need to percolate
+ passRunner.add<AutoDrop>(); // FinalizeCalls may cause us to require additional drops
+ if (optimize) {
+ passRunner.add("vacuum"); // autodrop can add some garbage
+ passRunner.add("remove-unused-brs"); // vacuum may open up more opportunities
+ }
passRunner.run();
// apply memory growth, if relevant
@@ -802,7 +819,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
wasm.addFunction(builder.makeFunction(
GROW_WASM_MEMORY,
{ { NEW_SIZE, i32 } },
- none,
+ i32,
{},
builder.makeHost(
GrowMemory,
@@ -812,10 +829,57 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
));
auto export_ = new Export;
export_->name = export_->value = GROW_WASM_MEMORY;
+ export_->kind = Export::Function;
wasm.addExport(export_);
}
- wasm.memory.exportName = MEMORY;
+#if 0
+ // export memory
+ auto memoryExport = make_unique<Export>();
+ memoryExport->name = MEMORY;
+ memoryExport->value = Name::fromInt(0);
+ memoryExport->kind = Export::Memory;
+ wasm.addExport(memoryExport.release());
+#else
+ // import memory
+ auto memoryImport = make_unique<Import>();
+ memoryImport->name = MEMORY;
+ memoryImport->module = ENV;
+ memoryImport->base = MEMORY;
+ memoryImport->kind = Import::Memory;
+ wasm.addImport(memoryImport.release());
+
+ // import table
+ auto tableImport = make_unique<Import>();
+ tableImport->name = TABLE;
+ tableImport->module = ENV;
+ tableImport->base = TABLE;
+ tableImport->kind = Import::Table;
+ wasm.addImport(tableImport.release());
+
+ // Import memory offset
+ {
+ auto* import = new Import;
+ import->name = Name("memoryBase");
+ import->module = Name("env");
+ import->base = Name("memoryBase");
+ import->kind = Import::Global;
+ import->globalType = i32;
+ wasm.addImport(import);
+ }
+
+ // Import table offset
+ {
+ auto* import = new Import;
+ import->name = Name("tableBase");
+ import->module = Name("env");
+ import->base = Name("tableBase");
+ import->kind = Import::Global;
+ import->globalType = i32;
+ wasm.addImport(import);
+ }
+
+#endif
#if 0 // enable asm2wasm i64 optimizations when browsers have consistent i64 support in wasm
if (udivmoddi4.is() && getTempRet0.is()) {
@@ -1009,20 +1073,16 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
auto ret = allocator.alloc<SetLocal>();
ret->index = function->getLocalIndex(ast[2][1]->getIString());
ret->value = process(ast[3]);
- ret->type = ret->value->type;
+ ret->setTee(false);
+ ret->finalize();
return ret;
}
- // global var, do a store to memory
+ // global var
assert(mappedGlobals.find(name) != mappedGlobals.end());
- MappedGlobal global = mappedGlobals[name];
- auto ret = allocator.alloc<Store>();
- ret->bytes = getWasmTypeSize(global.type);
- ret->offset = 0;
- ret->align = ret->bytes;
- ret->ptr = builder.makeConst(Literal(int32_t(global.address)));
- ret->value = process(ast[3]);
- ret->type = global.type;
- return ret;
+ auto* ret = builder.makeSetGlobal(name, process(ast[3]));
+ // set_global does not return; if our value is trivially not used, don't emit a load (if nontrivially not used, opts get it later)
+ if (astStackHelper.getParent()[0] == STAT) return ret;
+ return builder.makeSequence(ret, builder.makeGetGlobal(name, ret->value->type));
} else if (ast[2][0] == SUB) {
Ref target = ast[2];
assert(target[1][0] == NAME);
@@ -1035,10 +1095,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
ret->align = view.bytes;
ret->ptr = processUnshifted(target[2], view.bytes);
ret->value = process(ast[3]);
- ret->type = asmToWasmType(view.type);
- if (ret->type != ret->value->type) {
+ ret->valueType = asmToWasmType(view.type);
+ ret->finalize();
+ if (ret->valueType != ret->value->type) {
// in asm.js we have some implicit coercions that we must do explicitly here
- if (ret->type == f32 && ret->value->type == f64) {
+ if (ret->valueType == f32 && ret->value->type == f64) {
auto conv = allocator.alloc<Unary>();
conv->op = DemoteFloat64;
conv->value = ret->value;
@@ -1076,7 +1137,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
import->name = F64_REM;
import->module = ASM2WASM;
import->base = F64_REM;
- import->type = ensureFunctionType("ddd", &wasm);
+ import->functionType = ensureFunctionType("ddd", &wasm);
+ import->kind = Import::Function;
wasm.addImport(import);
}
return call;
@@ -1101,7 +1163,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
import->name = call->target;
import->module = ASM2WASM;
import->base = call->target;
- import->type = ensureFunctionType("iii", &wasm);
+ import->functionType = ensureFunctionType("iii", &wasm);
+ import->kind = Import::Function;
wasm.addImport(import);
}
return call;
@@ -1139,22 +1202,16 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
import->name = DEBUGGER;
import->module = ASM2WASM;
import->base = DEBUGGER;
- import->type = ensureFunctionType("v", &wasm);
+ import->functionType = ensureFunctionType("v", &wasm);
+ import->kind = Import::Function;
wasm.addImport(import);
}
return call;
}
- // global var, do a load from memory
- assert(mappedGlobals.find(name) != mappedGlobals.end());
- MappedGlobal global = mappedGlobals[name];
- auto ret = allocator.alloc<Load>();
- ret->bytes = getWasmTypeSize(global.type);
- ret->signed_ = true; // but doesn't matter
- ret->offset = 0;
- ret->align = ret->bytes;
- ret->ptr = builder.makeConst(Literal(int32_t(global.address)));
- ret->type = global.type;
- return ret;
+ // global var
+ assert(mappedGlobals.find(name) != mappedGlobals.end() ? true : (std::cerr << name.str << '\n', false));
+ MappedGlobal& global = mappedGlobals[name];
+ return builder.makeGetGlobal(name, global.type);
} else if (what == SUB) {
Ref target = ast[1];
assert(target[0] == NAME);
@@ -1251,7 +1308,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
import->name = F64_TO_INT;
import->module = ASM2WASM;
import->base = F64_TO_INT;
- import->type = ensureFunctionType("id", &wasm);
+ import->functionType = ensureFunctionType("id", &wasm);
+ import->kind = Import::Function;
wasm.addImport(import);
}
return ret;
@@ -1273,11 +1331,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
}
abort_on("bad unary", ast);
} else if (what == IF) {
- auto ret = allocator.alloc<If>();
- ret->condition = process(ast[1]);
- ret->ifTrue = process(ast[2]);
- ret->ifFalse = !!ast[3] ? process(ast[3]) : nullptr;
- return ret;
+ auto* condition = process(ast[1]);
+ auto* ifTrue = process(ast[2]);
+ return builder.makeIf(condition, ifTrue, !!ast[3] ? process(ast[3]) : nullptr);
} else if (what == CALL) {
if (ast[1][0] == NAME) {
IString name = ast[1][1]->getIString();
@@ -1330,9 +1386,10 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
// No wasm support, so use a temp local
ensureI32Temp();
auto set = allocator.alloc<SetLocal>();
+ set->setTee(false);
set->index = function->getLocalIndex(I32_TEMP);
set->value = value;
- set->type = i32;
+ set->finalize();
auto get = [&]() {
auto ret = allocator.alloc<GetLocal>();
ret->index = function->getLocalIndex(I32_TEMP);
@@ -1477,8 +1534,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
out = getNextId("while-out");
in = getNextId("while-in");
}
- ret->out = out;
- ret->in = in;
+ ret->name = in;
breakStack.push_back(out);
continueStack.push_back(in);
if (forever) {
@@ -1489,6 +1545,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
If *condition = allocator.alloc<If>();
condition->condition = builder.makeUnary(EqZInt32, process(ast[1]));
condition->ifTrue = breakOut;
+ condition->finalize();
auto body = allocator.alloc<Block>();
body->list.push_back(condition);
body->list.push_back(process(ast[2]));
@@ -1496,11 +1553,13 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
ret->body = body;
}
// loops do not automatically loop, add a branch back
- Block* block = blockify(ret->body);
+ Block* block = builder.blockifyWithName(ret->body, out);
auto continuer = allocator.alloc<Break>();
- continuer->name = ret->in;
+ continuer->name = ret->name;
block->list.push_back(continuer);
+ block->finalize();
ret->body = block;
+ ret->finalize();
continueStack.pop_back();
breakStack.pop_back();
return ret;
@@ -1526,19 +1585,22 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
if (breakSeeker.found == 0) {
auto block = allocator.alloc<Block>();
block->list.push_back(child);
+ if (isConcreteWasmType(child->type)) {
+ block->list.push_back(builder.makeNop()); // ensure a nop at the end, so the block has guaranteed none type and no values fall through
+ }
block->name = stop;
block->finalize();
return block;
} else {
auto loop = allocator.alloc<Loop>();
loop->body = child;
- loop->out = stop;
- loop->in = more;
- return loop;
+ loop->name = more;
+ loop->finalize();
+ return builder.blockifyWithName(loop, stop);
}
}
// general do-while loop
- auto ret = allocator.alloc<Loop>();
+ auto loop = allocator.alloc<Loop>();
IString out, in;
if (!parentLabel.isNull()) {
out = getBreakLabelName(parentLabel);
@@ -1548,20 +1610,19 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
out = getNextId("do-out");
in = getNextId("do-in");
}
- ret->out = out;
- ret->in = in;
+ loop->name = in;
breakStack.push_back(out);
continueStack.push_back(in);
- ret->body = process(ast[2]);
+ loop->body = process(ast[2]);
continueStack.pop_back();
breakStack.pop_back();
Break *continuer = allocator.alloc<Break>();
continuer->name = in;
continuer->condition = process(ast[1]);
- Block *block = blockify(ret->body);
- block->list.push_back(continuer);
- ret->body = block;
- return ret;
+ Block *block = builder.blockifyWithName(loop->body, out, continuer);
+ loop->body = block;
+ loop->finalize();
+ return loop;
} else if (what == FOR) {
Ref finit = ast[1],
fcond = ast[2],
@@ -1577,8 +1638,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
out = getNextId("for-out");
in = getNextId("for-in");
}
- ret->out = out;
- ret->in = in;
+ ret->name = in;
breakStack.push_back(out);
continueStack.push_back(in);
Break *breakOut = allocator.alloc<Break>();
@@ -1586,6 +1646,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
If *condition = allocator.alloc<If>();
condition->condition = builder.makeUnary(EqZInt32, process(fcond));
condition->ifTrue = breakOut;
+ condition->finalize();
auto body = allocator.alloc<Block>();
body->list.push_back(condition);
body->list.push_back(process(fbody));
@@ -1593,11 +1654,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
body->finalize();
ret->body = body;
// loops do not automatically loop, add a branch back
- Block* block = blockify(ret->body);
auto continuer = allocator.alloc<Break>();
- continuer->name = ret->in;
- block->list.push_back(continuer);
+ continuer->name = ret->name;
+ Block* block = builder.blockifyWithName(ret->body, out, continuer);
ret->body = block;
+ ret->finalize();
continueStack.pop_back();
breakStack.pop_back();
Block *outer = allocator.alloc<Block>();
@@ -1615,7 +1676,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
ret->condition = process(ast[1]);
ret->ifTrue = process(ast[2]);
ret->ifFalse = process(ast[3]);
- ret->type = ret->ifTrue->type;
+ ret->finalize();
return ret;
} else if (what == SEQ) {
// Some (x, y) patterns can be optimized, like bitcasts,
@@ -1774,7 +1835,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
// if there is a shift, we can just look through it, etc.
processUnshifted = [&](Ref ptr, unsigned bytes) {
auto shifts = bytesToShift(bytes);
- if (ptr[0] == BINARY && ptr[1] == RSHIFT && ptr[3][0] == NUM && ptr[3][1]->getInteger() == shifts) {
+ // HEAP?[addr >> ?], or HEAP8[x | 0]
+ if ((ptr[0] == BINARY && ptr[1] == RSHIFT && ptr[3][0] == NUM && ptr[3][1]->getInteger() == shifts) ||
+ (bytes == 1 && ptr[0] == BINARY && ptr[1] == OR && ptr[3][0] == NUM && ptr[3][1]->getInteger() == 0)) {
return process(ptr[2]); // look through it
} else if (ptr[0] == NUM) {
// constant, apply a shift (e.g. HEAP32[1] is address 4)
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 3e45d0e33..9b2ff10cd 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -21,6 +21,7 @@
#include "wasm.h"
#include "wasm-traversal.h"
#include "wasm-builder.h"
+#include "pass.h"
namespace wasm {
@@ -129,6 +130,9 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
bool checkPost(Expression* curr) {
visit(curr);
+ if (curr->is<Loop>()) {
+ branches = true;
+ }
return hasAnything();
}
@@ -147,8 +151,7 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks
}
void visitLoop(Loop* curr) {
- if (curr->in.is()) breakNames.erase(curr->in); // these were internal breaks
- if (curr->out.is()) breakNames.erase(curr->out); // these were internal breaks
+ if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks
}
void visitCall(Call *curr) { calls = true; }
@@ -244,7 +247,7 @@ struct ExpressionManipulator {
return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
}
Expression* visitLoop(Loop *curr) {
- return builder.makeLoop(curr->out, curr->in, copy(curr->body));
+ return builder.makeLoop(curr->name, copy(curr->body));
}
Expression* visitBreak(Break *curr) {
return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition));
@@ -277,19 +280,23 @@ struct ExpressionManipulator {
return builder.makeGetLocal(curr->index, curr->type);
}
Expression* visitSetLocal(SetLocal *curr) {
- return builder.makeSetLocal(curr->index, copy(curr->value));
+ if (curr->isTee()) {
+ return builder.makeTeeLocal(curr->index, copy(curr->value));
+ } else {
+ return builder.makeSetLocal(curr->index, copy(curr->value));
+ }
}
Expression* visitGetGlobal(GetGlobal *curr) {
- return builder.makeGetGlobal(curr->index, curr->type);
+ return builder.makeGetGlobal(curr->name, curr->type);
}
Expression* visitSetGlobal(SetGlobal *curr) {
- return builder.makeSetGlobal(curr->index, copy(curr->value));
+ return builder.makeSetGlobal(curr->name, copy(curr->value));
}
Expression* visitLoad(Load *curr) {
return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type);
}
Expression* visitStore(Store *curr) {
- return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value));
+ return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType);
}
Expression* visitConst(Const *curr) {
return builder.makeConst(curr->value);
@@ -303,6 +310,9 @@ struct ExpressionManipulator {
Expression* visitSelect(Select *curr) {
return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
}
+ Expression* visitDrop(Drop *curr) {
+ return builder.makeDrop(copy(curr->value));
+ }
Expression* visitReturn(Return *curr) {
return builder.makeReturn(copy(curr->value));
}
@@ -340,7 +350,7 @@ struct ExpressionAnalyzer {
for (int i = int(stack.size()) - 2; i >= 0; i--) {
auto* curr = stack[i];
auto* above = stack[i + 1];
- // only if and block can drop values
+ // only if and block can drop values (pre-drop expression was added) FIXME
if (curr->is<Block>()) {
auto* block = curr->cast<Block>();
for (size_t j = 0; j < block->list.size() - 1; j++) {
@@ -355,6 +365,7 @@ struct ExpressionAnalyzer {
assert(above == iff->ifTrue || above == iff->ifFalse);
// continue down
} else {
+ if (curr->is<Drop>()) return false;
return true; // all other node types use the result
}
}
@@ -429,8 +440,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::LoopId: {
- if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false;
- if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false;
+ if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false;
PUSH(Loop, body);
break;
}
@@ -481,15 +491,16 @@ struct ExpressionAnalyzer {
}
case Expression::Id::SetLocalId: {
CHECK(SetLocal, index);
+ CHECK(SetLocal, type); // for tee/set
PUSH(SetLocal, value);
break;
}
case Expression::Id::GetGlobalId: {
- CHECK(GetGlobal, index);
+ CHECK(GetGlobal, name);
break;
}
case Expression::Id::SetGlobalId: {
- CHECK(SetGlobal, index);
+ CHECK(SetGlobal, name);
PUSH(SetGlobal, value);
break;
}
@@ -505,6 +516,7 @@ struct ExpressionAnalyzer {
CHECK(Store, bytes);
CHECK(Store, offset);
CHECK(Store, align);
+ CHECK(Store, valueType);
PUSH(Store, ptr);
PUSH(Store, value);
break;
@@ -530,6 +542,10 @@ struct ExpressionAnalyzer {
PUSH(Select, condition);
break;
}
+ case Expression::Id::DropId: {
+ PUSH(Drop, value);
+ break;
+ }
case Expression::Id::ReturnId: {
PUSH(Return, value);
break;
@@ -640,8 +656,7 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::LoopId: {
- noteName(curr->cast<Loop>()->out);
- noteName(curr->cast<Loop>()->in);
+ noteName(curr->cast<Loop>()->name);
PUSH(Loop, body);
break;
}
@@ -696,11 +711,11 @@ struct ExpressionAnalyzer {
break;
}
case Expression::Id::GetGlobalId: {
- HASH(GetGlobal, index);
+ HASH_NAME(GetGlobal, name);
break;
}
case Expression::Id::SetGlobalId: {
- HASH(SetGlobal, index);
+ HASH_NAME(SetGlobal, name);
PUSH(SetGlobal, value);
break;
}
@@ -716,6 +731,7 @@ struct ExpressionAnalyzer {
HASH(Store, bytes);
HASH(Store, offset);
HASH(Store, align);
+ HASH(Store, valueType);
PUSH(Store, ptr);
PUSH(Store, value);
break;
@@ -742,6 +758,10 @@ struct ExpressionAnalyzer {
PUSH(Select, condition);
break;
}
+ case Expression::Id::DropId: {
+ PUSH(Drop, value);
+ break;
+ }
case Expression::Id::ReturnId: {
PUSH(Return, value);
break;
@@ -770,6 +790,65 @@ struct ExpressionAnalyzer {
}
};
+// Adds drop() operations where necessary. This lets you not worry about adding drop when
+// generating code.
+struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<AutoDrop>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new AutoDrop; }
+
+ void visitBlock(Block* curr) {
+ if (curr->list.size() == 0) return;
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ auto* child = curr->list[i];
+ if (isConcreteWasmType(child->type)) {
+ curr->list[i] = Builder(*getModule()).makeDrop(child);
+ }
+ }
+ auto* last = curr->list.back();
+ expressionStack.push_back(last);
+ if (isConcreteWasmType(last->type) && !ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) {
+ curr->list.back() = Builder(*getModule()).makeDrop(last);
+ }
+ expressionStack.pop_back();
+ curr->finalize(); // we may have changed our type
+ }
+
+ void visitFunction(Function* curr) {
+ if (curr->result == none && isConcreteWasmType(curr->body->type)) {
+ curr->body = Builder(*getModule()).makeDrop(curr->body);
+ }
+ }
+};
+
+// Finalizes a node
+
+struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, Visitor<ReFinalize>>> {
+ void visitBlock(Block *curr) { curr->finalize(); }
+ void visitIf(If *curr) { curr->finalize(); }
+ void visitLoop(Loop *curr) { curr->finalize(); }
+ void visitBreak(Break *curr) { curr->finalize(); }
+ void visitSwitch(Switch *curr) { curr->finalize(); }
+ void visitCall(Call *curr) { curr->finalize(); }
+ void visitCallImport(CallImport *curr) { curr->finalize(); }
+ void visitCallIndirect(CallIndirect *curr) { curr->finalize(); }
+ void visitGetLocal(GetLocal *curr) { curr->finalize(); }
+ void visitSetLocal(SetLocal *curr) { curr->finalize(); }
+ void visitGetGlobal(GetGlobal *curr) { curr->finalize(); }
+ void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
+ void visitLoad(Load *curr) { curr->finalize(); }
+ void visitStore(Store *curr) { curr->finalize(); }
+ void visitConst(Const *curr) { curr->finalize(); }
+ void visitUnary(Unary *curr) { curr->finalize(); }
+ void visitBinary(Binary *curr) { curr->finalize(); }
+ void visitSelect(Select *curr) { curr->finalize(); }
+ void visitDrop(Drop *curr) { curr->finalize(); }
+ void visitReturn(Return *curr) { curr->finalize(); }
+ void visitHost(Host *curr) { curr->finalize(); }
+ void visitNop(Nop *curr) { curr->finalize(); }
+ void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+};
+
} // namespace wasm
#endif // wasm_ast_utils_h
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index 83c626eb1..3ce24bdbd 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -338,16 +338,13 @@ BinaryenExpressionRef BinaryenIf(BinaryenModuleRef module, BinaryenExpressionRef
return static_cast<Expression*>(ret);
}
-BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* out, const char* in, BinaryenExpressionRef body) {
- if (out && !in) abort();
- auto* ret = Builder(*((Module*)module)).makeLoop(out ? Name(out) : Name(), in ? Name(in) : Name(), (Expression*)body);
+BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* name, BinaryenExpressionRef body) {
+ auto* ret = Builder(*((Module*)module)).makeLoop(name ? Name(name) : Name(), (Expression*)body);
if (tracing) {
auto id = noteExpression(ret);
std::cout << " expressions[" << id << "] = BinaryenLoop(the_module, ";
- traceNameOrNULL(out);
- std::cout << ", ";
- traceNameOrNULL(in);
+ traceNameOrNULL(name);
std::cout << ", expressions[" << expressions[body] << "]);\n";
}
@@ -488,6 +485,21 @@ BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex i
ret->index = index;
ret->value = (Expression*)value;
+ ret->setTee(false);
+ ret->finalize();
+ return static_cast<Expression*>(ret);
+}
+BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value) {
+ auto* ret = ((Module*)module)->allocator.alloc<SetLocal>();
+
+ if (tracing) {
+ auto id = noteExpression(ret);
+ std::cout << " expressions[" << id << "] = BinaryenTeeLocal(the_module, " << index << ", expressions[" << expressions[value] << "]);\n";
+ }
+
+ ret->index = index;
+ ret->value = (Expression*)value;
+ ret->setTee(true);
ret->finalize();
return static_cast<Expression*>(ret);
}
@@ -508,12 +520,12 @@ BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int
ret->finalize();
return static_cast<Expression*>(ret);
}
-BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value) {
+BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type) {
auto* ret = ((Module*)module)->allocator.alloc<Store>();
if (tracing) {
auto id = noteExpression(ret);
- std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "]);\n";
+ std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "], " << type << ");\n";
}
ret->bytes = bytes;
@@ -521,6 +533,7 @@ BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, ui
ret->align = align ? align : bytes;
ret->ptr = (Expression*)ptr;
ret->value = (Expression*)value;
+ ret->valueType = WasmType(type);
ret->finalize();
return static_cast<Expression*>(ret);
}
@@ -584,6 +597,18 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio
ret->finalize();
return static_cast<Expression*>(ret);
}
+BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value) {
+ auto* ret = ((Module*)module)->allocator.alloc<Drop>();
+
+ if (tracing) {
+ auto id = noteExpression(ret);
+ std::cout << " expressions[" << id << "] = BinaryenDrop(the_module, expressions[" << expressions[value] << "]);\n";
+ }
+
+ ret->value = (Expression*)value;
+ ret->finalize();
+ return static_cast<Expression*>(ret);
+}
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) {
auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value);
@@ -692,7 +717,8 @@ BinaryenImportRef BinaryenAddImport(BinaryenModuleRef module, const char* intern
ret->name = internalName;
ret->module = externalModuleName;
ret->base = externalBaseName;
- ret->type = (FunctionType*)type;
+ ret->functionType = (FunctionType*)type;
+ ret->kind = Import::Function;
wasm->addImport(ret);
return ret;
}
@@ -780,7 +806,13 @@ void BinaryenSetMemory(BinaryenModuleRef module, BinaryenIndex initial, Binaryen
auto* wasm = (Module*)module;
wasm->memory.initial = initial;
wasm->memory.max = maximum;
- if (exportName) wasm->memory.exportName = exportName;
+ if (exportName) {
+ auto memoryExport = make_unique<Export>();
+ memoryExport->name = exportName;
+ memoryExport->value = Name::fromInt(0);
+ memoryExport->kind = Export::Memory;
+ wasm->addExport(memoryExport.release());
+ }
for (BinaryenIndex i = 0; i < numSegments; i++) {
wasm->memory.segments.emplace_back((Expression*)segmentOffsets[i], segments[i], segmentSizes[i]);
}
@@ -829,6 +861,17 @@ void BinaryenModuleOptimize(BinaryenModuleRef module) {
passRunner.run();
}
+void BinaryenModuleAutoDrop(BinaryenModuleRef module) {
+ if (tracing) {
+ std::cout << " BinaryenModuleAutoDrop(the_module);\n";
+ }
+
+ Module* wasm = (Module*)module;
+ PassRunner passRunner(wasm);
+ passRunner.add<AutoDrop>();
+ passRunner.run();
+}
+
size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize) {
if (tracing) {
std::cout << " // BinaryenModuleWrite\n";
diff --git a/src/binaryen-c.h b/src/binaryen-c.h
index e2086072b..9ecb3f409 100644
--- a/src/binaryen-c.h
+++ b/src/binaryen-c.h
@@ -265,8 +265,7 @@ typedef void* BinaryenExpressionRef;
BinaryenExpressionRef BinaryenBlock(BinaryenModuleRef module, const char* name, BinaryenExpressionRef* children, BinaryenIndex numChildren);
// If: ifFalse can be NULL
BinaryenExpressionRef BinaryenIf(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, BinaryenExpressionRef ifFalse);
-// Loop: both out and in can be NULL, or just out can be NULL
-BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* out, const char* in, BinaryenExpressionRef body);
+BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* in, BinaryenExpressionRef body);
// Break: value and condition can be NULL
BinaryenExpressionRef BinaryenBreak(BinaryenModuleRef module, const char* name, BinaryenExpressionRef condition, BinaryenExpressionRef value);
// Switch: value can be NULL
@@ -294,14 +293,16 @@ BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExp
// for more details.
BinaryenExpressionRef BinaryenGetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenType type);
BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value);
+BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value);
// Load: align can be 0, in which case it will be the natural alignment (equal to bytes)
BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int8_t signed_, uint32_t offset, uint32_t align, BinaryenType type, BinaryenExpressionRef ptr);
// Store: align can be 0, in which case it will be the natural alignment (equal to bytes)
-BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value);
+BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type);
BinaryenExpressionRef BinaryenConst(BinaryenModuleRef module, struct BinaryenLiteral value);
BinaryenExpressionRef BinaryenUnary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef value);
BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef left, BinaryenExpressionRef right);
BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, BinaryenExpressionRef ifFalse);
+BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value);
// Return: value can be NULL
BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value);
// Host: name may be NULL
@@ -366,8 +367,13 @@ int BinaryenModuleValidate(BinaryenModuleRef module);
// Run the standard optimization passes on the module.
void BinaryenModuleOptimize(BinaryenModuleRef module);
+// Auto-generate drop() operations where needed. This lets you generate code without
+// worrying about where they are needed. (It is more efficient to do it yourself,
+// but simpler to use autodrop).
+void BinaryenModuleAutoDrop(BinaryenModuleRef module);
+
// Serialize a module into binary form.
-// @return how many bytes were written. This will be less than or equal to bufferSize
+// @return how many bytes were written. This will be less than or equal to outputSize
size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize);
// Deserialize a module from binary form.
diff --git a/src/cfg/Relooper.cpp b/src/cfg/Relooper.cpp
index e75adbce5..8a5337ed0 100644
--- a/src/cfg/Relooper.cpp
+++ b/src/cfg/Relooper.cpp
@@ -383,7 +383,7 @@ wasm::Expression* MultipleShape::Render(RelooperBuilder& Builder, bool InLoop) {
// LoopShape
wasm::Expression* LoopShape::Render(RelooperBuilder& Builder, bool InLoop) {
- wasm::Expression* Ret = Builder.makeLoop(wasm::Name(), Builder.getShapeContinueName(Id), Inner->Render(Builder, true));
+ wasm::Expression* Ret = Builder.makeLoop(Builder.getShapeContinueName(Id), Inner->Render(Builder, true));
Ret = HandleFollowupMultiples(Ret, this, Builder, InLoop);
if (Next) {
Ret = Builder.makeSequence(Ret, Next->Render(Builder, InLoop));
diff --git a/src/cfg/cfg-traversal.h b/src/cfg/cfg-traversal.h
index d690de4aa..5bc593691 100644
--- a/src/cfg/cfg-traversal.h
+++ b/src/cfg/cfg-traversal.h
@@ -36,7 +36,7 @@
namespace wasm {
template<typename SubType, typename VisitorType, typename Contents>
-struct CFGWalker : public PostWalker<SubType, VisitorType> {
+struct CFGWalker : public ControlFlowWalker<SubType, VisitorType> {
// public interface
@@ -57,7 +57,7 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
// traversal state
BasicBlock* currBasicBlock; // the current block in play during traversal
- std::map<Name, std::vector<BasicBlock*>> branches;
+ std::map<Expression*, std::vector<BasicBlock*>> branches; // a block or loop => its branches
std::vector<BasicBlock*> ifStack;
std::vector<BasicBlock*> loopStack;
@@ -74,7 +74,7 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
static void doEndBlock(SubType* self, Expression** currp) {
auto* curr = (*currp)->cast<Block>();
if (!curr->name.is()) return;
- auto iter = self->branches.find(curr->name);
+ auto iter = self->branches.find(curr);
if (iter == self->branches.end()) return;
auto& origins = iter->second;
if (origins.size() == 0) return;
@@ -83,12 +83,10 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
doStartBasicBlock(self, currp);
self->link(last, self->currBasicBlock); // fallthrough
// branches to the new one
- if (curr->name.is()) {
- for (auto* origin : origins) {
- self->link(origin, self->currBasicBlock);
- }
- self->branches.erase(curr->name);
+ for (auto* origin : origins) {
+ self->link(origin, self->currBasicBlock);
}
+ self->branches.erase(curr);
}
static void doStartIfTrue(SubType* self, Expression** currp) {
@@ -130,30 +128,22 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
auto* last = self->currBasicBlock;
doStartBasicBlock(self, currp);
self->link(last, self->currBasicBlock); // fallthrough
- // branches to the new one
auto* curr = (*currp)->cast<Loop>();
- if (curr->out.is()) {
- auto& origins = self->branches[curr->out];
- for (auto* origin : origins) {
- self->link(origin, self->currBasicBlock);
- }
- self->branches.erase(curr->out);
- }
// branches to the top of the loop
- if (curr->in.is()) {
+ if (curr->name.is()) {
auto* loopStart = self->loopStack.back();
- auto& origins = self->branches[curr->in];
+ auto& origins = self->branches[curr];
for (auto* origin : origins) {
self->link(origin, loopStart);
}
- self->branches.erase(curr->in);
+ self->branches.erase(curr);
}
self->loopStack.pop_back();
}
static void doEndBreak(SubType* self, Expression** currp) {
auto* curr = (*currp)->cast<Break>();
- self->branches[curr->name].push_back(self->currBasicBlock); // branch to the target
+ self->branches[self->findBreakTarget(curr->name)].push_back(self->currBasicBlock); // branch to the target
auto* last = self->currBasicBlock;
doStartBasicBlock(self, currp);
if (curr->condition) {
@@ -166,12 +156,12 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
std::set<Name> seen; // we might see the same label more than once; do not spam branches
for (Name target : curr->targets) {
if (!seen.count(target)) {
- self->branches[target].push_back(self->currBasicBlock); // branch to the target
+ self->branches[self->findBreakTarget(target)].push_back(self->currBasicBlock); // branch to the target
seen.insert(target);
}
}
if (!seen.count(curr->default_)) {
- self->branches[curr->default_].push_back(self->currBasicBlock); // branch to the target
+ self->branches[self->findBreakTarget(curr->default_)].push_back(self->currBasicBlock); // branch to the target
}
doStartBasicBlock(self, currp);
}
@@ -182,14 +172,10 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
switch (curr->_id) {
case Expression::Id::BlockId: {
self->pushTask(SubType::doEndBlock, currp);
- self->pushTask(SubType::doVisitBlock, currp);
- auto& list = curr->cast<Block>()->list;
- for (int i = int(list.size()) - 1; i >= 0; i--) {
- self->pushTask(SubType::scan, &list[i]);
- }
break;
}
case Expression::Id::IfId: {
+ self->pushTask(SubType::doPostVisitControlFlow, currp);
self->pushTask(SubType::doEndIf, currp);
self->pushTask(SubType::doVisitIf, currp);
auto* ifFalse = curr->cast<If>()->ifFalse;
@@ -200,44 +186,40 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue);
self->pushTask(SubType::doStartIfTrue, currp);
self->pushTask(SubType::scan, &curr->cast<If>()->condition);
- break;
+ self->pushTask(SubType::doPreVisitControlFlow, currp);
+ return; // don't do anything else
}
case Expression::Id::LoopId: {
self->pushTask(SubType::doEndLoop, currp);
- self->pushTask(SubType::doVisitLoop, currp);
- self->pushTask(SubType::scan, &curr->cast<Loop>()->body);
- self->pushTask(SubType::doStartLoop, currp);
break;
}
case Expression::Id::BreakId: {
self->pushTask(SubType::doEndBreak, currp);
- self->pushTask(SubType::doVisitBreak, currp);
- self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition);
- self->maybePushTask(SubType::scan, &curr->cast<Break>()->value);
break;
}
case Expression::Id::SwitchId: {
self->pushTask(SubType::doEndSwitch, currp);
- self->pushTask(SubType::doVisitSwitch, currp);
- self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value);
- self->pushTask(SubType::scan, &curr->cast<Switch>()->condition);
break;
}
case Expression::Id::ReturnId: {
self->pushTask(SubType::doStartBasicBlock, currp);
- self->pushTask(SubType::doVisitReturn, currp);
- self->maybePushTask(SubType::scan, &curr->cast<Return>()->value);
break;
}
case Expression::Id::UnreachableId: {
self->pushTask(SubType::doStartBasicBlock, currp);
- self->pushTask(SubType::doVisitUnreachable, currp);
break;
}
- default: {
- // other node types do not have control flow, use regular post-order
- PostWalker<SubType, VisitorType>::scan(self, currp);
+ default: {}
+ }
+
+ ControlFlowWalker<SubType, VisitorType>::scan(self, currp);
+
+ switch (curr->_id) {
+ case Expression::Id::LoopId: {
+ self->pushTask(SubType::doStartLoop, currp);
+ break;
}
+ default: {}
}
}
@@ -246,7 +228,7 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> {
doStartBasicBlock(static_cast<SubType*>(this), nullptr);
entry = currBasicBlock;
- PostWalker<SubType, VisitorType>::doWalkFunction(func);
+ ControlFlowWalker<SubType, VisitorType>::doWalkFunction(func);
assert(branches.size() == 0);
assert(ifStack.size() == 0);
diff --git a/src/js/wasm.js-post.js b/src/js/wasm.js-post.js
index ae42ef175..1c5e6b0b1 100644
--- a/src/js/wasm.js-post.js
+++ b/src/js/wasm.js-post.js
@@ -29,6 +29,7 @@ function integrateWasmJS(Module) {
// inputs
var method = Module['wasmJSMethod'] || {{{ wasmJSMethod }}} || 'native-wasm,interpret-s-expr'; // by default, try native and then .wast
+ Module['wasmJSMethod'] = method;
var wasmTextFile = Module['wasmTextFile'] || {{{ wasmTextFile }}};
var wasmBinaryFile = Module['wasmBinaryFile'] || {{{ wasmBinaryFile }}};
@@ -100,19 +101,15 @@ function integrateWasmJS(Module) {
}
var oldView = new Int8Array(oldBuffer);
var newView = new Int8Array(newBuffer);
- if ({{{ WASM_BACKEND }}}) {
- // memory segments arrived in the wast, do not trample them
+
+ // If we have a mem init file, do not trample it
+ if (!memoryInitializer) {
oldView.set(newView.subarray(STATIC_BASE, STATIC_BASE + STATIC_BUMP), STATIC_BASE);
}
+
newView.set(oldView);
updateGlobalBuffer(newBuffer);
updateGlobalBufferViews();
- Module['reallocBuffer'] = function(size) {
- size = Math.ceil(size / wasmPageSize) * wasmPageSize; // round up to wasm page size
- var old = Module['buffer'];
- exports['__growWasmMemory'](size / wasmPageSize); // tiny wasm method that just does grow_memory
- return Module['buffer'] !== old ? Module['buffer'] : null; // if it was reallocated, it changed
- };
}
var WasmTypes = {
@@ -123,26 +120,6 @@ function integrateWasmJS(Module) {
f64: 4
};
- // wasm lacks globals, so asm2wasm maps them into locations in memory. that information cannot
- // be present in the wasm output of asm2wasm, so we store it in a side file. If we load asm2wasm
- // output, either generated ahead of time or on the client, we need to apply those mapped
- // globals after loading the module.
- function applyMappedGlobals(globalsFileBase) {
- var mappedGlobals = JSON.parse(Module['read'](globalsFileBase + '.mappedGlobals'));
- for (var name in mappedGlobals) {
- var global = mappedGlobals[name];
- if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here
- var value = lookupImport(global.module, global.base);
- var address = global.address;
- switch (global.type) {
- case WasmTypes.i32: Module['HEAP32'][address >> 2] = value; break;
- case WasmTypes.f32: Module['HEAPF32'][address >> 2] = value; break;
- case WasmTypes.f64: Module['HEAPF64'][address >> 3] = value; break;
- default: abort();
- }
- }
- }
-
function fixImports(imports) {
if (!{{{ WASM_BACKEND }}}) return imports;
var ret = {};
@@ -206,9 +183,7 @@ function integrateWasmJS(Module) {
return false;
}
exports = instance.exports;
- mergeMemory(exports.memory);
-
- applyMappedGlobals(wasmBinaryFile);
+ if (exports.memory) mergeMemory(exports.memory);
Module["usingWasm"] = true;
@@ -237,6 +212,13 @@ function integrateWasmJS(Module) {
info.global = global;
info.env = env;
+ if (!('memoryBase' in env)) {
+ env['memoryBase'] = STATIC_BASE; // tell the memory segments where to place themselves
+ }
+ if (!('tableBase' in env)) {
+ env['tableBase'] = 0; // tell the memory segments where to place themselves
+ }
+
wasmJS['providedTotalMemory'] = Module['buffer'].byteLength;
// Prepare to generate wasm, using either asm2wasm or s-exprs
@@ -271,12 +253,6 @@ function integrateWasmJS(Module) {
Module['newBuffer'] = null;
}
- if (method == 'interpret-s-expr') {
- applyMappedGlobals(wasmTextFile);
- } else if (method == 'interpret-binary') {
- applyMappedGlobals(wasmBinaryFile);
- }
-
exports = wasmJS['asmExports'];
return exports;
@@ -285,6 +261,14 @@ function integrateWasmJS(Module) {
// We may have a preloaded value in Module.asm, save it
Module['asmPreload'] = Module['asm'];
+ // Memory growth integration code
+ Module['reallocBuffer'] = function(size) {
+ size = Math.ceil(size / wasmPageSize) * wasmPageSize; // round up to wasm page size
+ var old = Module['buffer'];
+ exports['__growWasmMemory'](size / wasmPageSize); // tiny wasm method that just does grow_memory
+ return Module['buffer'] !== old ? Module['buffer'] : null; // if it was reallocated, it changed
+ };
+
// Provide an "asm.js function" for the application, called to "link" the asm.js module. We instantiate
// the wasm module at that time, and it receives imports and provides exports and so forth, the app
// doesn't need to care that it is wasm or olyfilled wasm or asm.js.
@@ -293,6 +277,14 @@ function integrateWasmJS(Module) {
global = fixImports(global);
env = fixImports(env);
+ // import memory and table
+ if (!env['memory']) {
+ env['memory'] = providedBuffer;
+ }
+ if (!env['table']) {
+ env['table'] = new Array(1024);
+ }
+
// try the methods. each should return the exports if it succeeded
var exports;
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 2ccf5f040..1b4c65562 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -2,9 +2,7 @@ SET(passes_SOURCES
pass.cpp
CoalesceLocals.cpp
DeadCodeElimination.cpp
- DropReturnValues.cpp
DuplicateFunctionElimination.cpp
- LowerIfElse.cpp
MergeBlocks.cpp
Metrics.cpp
NameManager.cpp
diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp
index 2c707b614..2063a20db 100644
--- a/src/passes/CoalesceLocals.cpp
+++ b/src/passes/CoalesceLocals.cpp
@@ -485,9 +485,19 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root)
// in addition, we can optimize out redundant copies and ineffective sets
GetLocal* get;
if ((get = set->value->dynCast<GetLocal>()) && get->index == set->index) {
- *action.origin = get; // further optimizations may get rid of the get, if this is in a place where the output does not matter
+ if (set->isTee()) {
+ *action.origin = get;
+ } else {
+ ExpressionManipulator::nop(set);
+ }
} else if (!action.effective) {
*action.origin = set->value; // value may have no side effects, further optimizations can eliminate it
+ if (!set->isTee()) {
+ // we need to drop it
+ Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(set);
+ drop->value = *action.origin;
+ *action.origin = drop;
+ }
}
}
}
diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp
index de1191131..b30b8ffbd 100644
--- a/src/passes/DeadCodeElimination.cpp
+++ b/src/passes/DeadCodeElimination.cpp
@@ -31,6 +31,7 @@
#include <wasm.h>
#include <pass.h>
#include <ast_utils.h>
+#include <wasm-builder.h>
namespace wasm {
@@ -131,12 +132,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
}
void visitLoop(Loop* curr) {
- if (curr->in.is()) {
- reachableBreaks.erase(curr->in);
- }
- if (curr->out.is()) {
- reachable = reachable || reachableBreaks.count(curr->out);
- reachableBreaks.erase(curr->out);
+ if (curr->name.is()) {
+ reachableBreaks.erase(curr->name);
}
if (isDead(curr->body)) {
replaceCurrent(curr->body);
@@ -191,6 +188,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
case Expression::Id::UnaryId: DELEGATE(Unary);
case Expression::Id::BinaryId: DELEGATE(Binary);
case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::DropId: DELEGATE(Drop);
case Expression::Id::ReturnId: DELEGATE(Return);
case Expression::Id::HostId: DELEGATE(Host);
case Expression::Id::NopId: DELEGATE(Nop);
@@ -226,46 +224,52 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
// other things
+ Expression* drop(Expression* toDrop) {
+ if (toDrop->is<Unreachable>()) return toDrop;
+ return Builder(*getModule()).makeDrop(toDrop);
+ }
+
template<typename T>
- void handleCall(T* curr, Expression* initial) {
+ Expression* handleCall(T* curr) {
for (Index i = 0; i < curr->operands.size(); i++) {
if (isDead(curr->operands[i])) {
- if (i > 0 || initial != nullptr) {
+ if (i > 0) {
auto* block = getModule()->allocator.alloc<Block>();
- Index newSize = i + 1 + (initial ? 1 : 0);
+ Index newSize = i + 1;
block->list.resize(newSize);
Index j = 0;
- if (initial) {
- block->list[j] = initial;
- j++;
- }
for (; j < newSize; j++) {
- block->list[j] = curr->operands[j - (initial ? 1 : 0)];
+ block->list[j] = drop(curr->operands[j]);
}
block->finalize();
- replaceCurrent(block);
+ return replaceCurrent(block);
} else {
- replaceCurrent(curr->operands[i]);
+ return replaceCurrent(curr->operands[i]);
}
- return;
}
}
+ return curr;
}
void visitCall(Call* curr) {
- handleCall(curr, nullptr);
+ handleCall(curr);
}
void visitCallImport(CallImport* curr) {
- handleCall(curr, nullptr);
+ handleCall(curr);
}
void visitCallIndirect(CallIndirect* curr) {
+ if (handleCall(curr) != curr) return;
if (isDead(curr->target)) {
- replaceCurrent(curr->target);
- return;
+ auto* block = getModule()->allocator.alloc<Block>();
+ for (auto* operand : curr->operands) {
+ block->list.push_back(drop(operand));
+ }
+ block->list.push_back(curr->target);
+ block->finalize();
+ replaceCurrent(block);
}
- handleCall(curr, curr->target);
}
void visitSetLocal(SetLocal* curr) {
@@ -288,7 +292,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->value)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
- block->list[0] = curr->ptr;
+ block->list[0] = drop(curr->ptr);
block->list[1] = curr->value;
block->finalize();
replaceCurrent(block);
@@ -309,7 +313,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->right)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
- block->list[0] = curr->left;
+ block->list[0] = drop(curr->left);
block->list[1] = curr->right;
block->finalize();
replaceCurrent(block);
@@ -324,7 +328,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->ifFalse)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(2);
- block->list[0] = curr->ifTrue;
+ block->list[0] = drop(curr->ifTrue);
block->list[1] = curr->ifFalse;
block->finalize();
replaceCurrent(block);
@@ -333,8 +337,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V
if (isDead(curr->condition)) {
auto* block = getModule()->allocator.alloc<Block>();
block->list.resize(3);
- block->list[0] = curr->ifTrue;
- block->list[1] = curr->ifFalse;
+ block->list[0] = drop(curr->ifTrue);
+ block->list[1] = drop(curr->ifFalse);
block->list[2] = curr->condition;
block->finalize();
replaceCurrent(block);
diff --git a/src/passes/DropReturnValues.cpp b/src/passes/DropReturnValues.cpp
deleted file mode 100644
index 8715f3f61..000000000
--- a/src/passes/DropReturnValues.cpp
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Copyright 2016 WebAssembly Community Group participants
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-//
-// Stops using return values from set_local and store nodes.
-//
-
-#include <wasm.h>
-#include <pass.h>
-#include <ast_utils.h>
-#include <wasm-builder.h>
-
-namespace wasm {
-
-struct DropReturnValues : public WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>> {
- bool isFunctionParallel() override { return true; }
-
- Pass* create() override { return new DropReturnValues; }
-
- std::vector<Expression*> expressionStack;
-
- void visitSetLocal(SetLocal* curr) {
- if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) {
- Builder builder(*getModule());
- replaceCurrent(builder.makeSequence(
- curr,
- builder.makeGetLocal(curr->index, curr->type)
- ));
- }
- }
-
- void visitStore(Store* curr) {
- if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) {
- Index index = getFunction()->getNumLocals();
- getFunction()->vars.emplace_back(curr->type);
- Builder builder(*getModule());
- replaceCurrent(builder.makeSequence(
- builder.makeSequence(
- builder.makeSetLocal(index, curr->value),
- curr
- ),
- builder.makeGetLocal(index, curr->type)
- ));
- curr->value = builder.makeGetLocal(index, curr->type);
- }
- }
-
- static void visitPre(DropReturnValues* self, Expression** currp) {
- self->expressionStack.push_back(*currp);
- }
-
- static void visitPost(DropReturnValues* self, Expression** currp) {
- self->expressionStack.pop_back();
- }
-
- static void scan(DropReturnValues* self, Expression** currp) {
- self->pushTask(visitPost, currp);
-
- WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>>::scan(self, currp);
-
- self->pushTask(visitPre, currp);
- }
-};
-
-Pass *createDropReturnValuesPass() {
- return new DropReturnValues();
-}
-
-} // namespace wasm
-
diff --git a/src/passes/LowerIfElse.cpp b/src/passes/LowerIfElse.cpp
deleted file mode 100644
index b566e8207..000000000
--- a/src/passes/LowerIfElse.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Copyright 2015 WebAssembly Community Group participants
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-//
-// Lowers if (x) y else z into
-//
-// L: {
-// if (x) break (y) L
-// z
-// }
-//
-// This is useful for investigating how beneficial if_else is.
-//
-
-#include <memory>
-
-#include <wasm.h>
-#include <pass.h>
-
-namespace wasm {
-
-struct LowerIfElse : public WalkerPass<PostWalker<LowerIfElse, Visitor<LowerIfElse>>> {
- MixedArena* allocator;
- std::unique_ptr<NameManager> namer;
-
- void prepare(PassRunner* runner, Module *module) override {
- allocator = runner->allocator;
- namer = make_unique<NameManager>();
- namer->run(runner, module);
- }
-
- void visitIf(If *curr) {
- if (curr->ifFalse) {
- auto block = allocator->alloc<Block>();
- auto name = namer->getUnique("L"); // TODO: getUniqueInFunction
- block->name = name;
- block->list.push_back(curr);
- block->list.push_back(curr->ifFalse);
- block->finalize();
- curr->ifFalse = nullptr;
- auto break_ = allocator->alloc<Break>();
- break_->name = name;
- break_->value = curr->ifTrue;
- curr->ifTrue = break_;
- replaceCurrent(block);
- }
- }
-};
-
-Pass *createLowerIfElsePass() {
- return new LowerIfElse();
-}
-
-} // namespace wasm
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index 686bb5d75..bde9397a8 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -64,9 +64,40 @@
#include <wasm.h>
#include <pass.h>
#include <ast_utils.h>
+#include <wasm-builder.h>
namespace wasm {
+struct SwitchFinder : public ControlFlowWalker<SwitchFinder, Visitor<SwitchFinder>> {
+ Expression* origin;
+ bool found = false;
+
+ void visitSwitch(Switch* curr) {
+ if (findBreakTarget(curr->default_) == origin) {
+ found = true;
+ return;
+ }
+ for (auto& target : curr->targets) {
+ if (findBreakTarget(target) == origin) {
+ found = true;
+ return;
+ }
+ }
+ }
+};
+
+struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper, Visitor<BreakValueDropper>> {
+ Expression* origin;
+
+ void visitBreak(Break* curr) {
+ if (curr->value && findBreakTarget(curr->name) == origin) {
+ Builder builder(*getModule());
+ replaceCurrent(builder.makeSequence(builder.makeDrop(curr->value), curr));
+ curr->value = nullptr;
+ }
+ }
+};
+
struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBlocks>>> {
bool isFunctionParallel() override { return true; }
@@ -74,10 +105,46 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc
void visitBlock(Block *curr) {
bool more = true;
+ bool changed = false;
while (more) {
more = false;
for (size_t i = 0; i < curr->list.size(); i++) {
Block* child = curr->list[i]->dynCast<Block>();
+ if (!child) {
+ // if we have a child that is (drop (block ..)) then we can move the drop into the block, and remove br values. this allows more merging,
+ auto* drop = curr->list[i]->dynCast<Drop>();
+ if (drop) {
+ child = drop->value->dynCast<Block>();
+ if (child) {
+ if (child->name.is()) {
+ Expression* expression = child;
+ // if there is a switch targeting us, we can't do it - we can't remove the value from other targets too
+ SwitchFinder finder;
+ finder.origin = child;
+ finder.walk(expression);
+ if (finder.found) {
+ child = nullptr;
+ } else {
+ // fix up breaks
+ BreakValueDropper fixer;
+ fixer.origin = child;
+ fixer.setModule(getModule());
+ fixer.walk(expression);
+ }
+ }
+ if (child) {
+ // we can do it!
+ // reuse the drop
+ drop->value = child->list.back();
+ child->list.back() = drop;
+ child->finalize();
+ curr->list[i] = child;
+ more = true;
+ changed = true;
+ }
+ }
+ }
+ }
if (!child) continue;
if (child->name.is()) continue; // named blocks can have breaks to them (and certainly do, if we ran RemoveUnusedNames and RemoveUnusedBrs)
ExpressionList merged(getModule()->allocator);
@@ -92,9 +159,11 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc
}
curr->list = merged;
more = true;
+ changed = true;
break;
}
}
+ if (changed) curr->finalize();
}
Block* optimize(Expression* curr, Expression*& child, Block* outer = nullptr, Expression** dependency1 = nullptr, Expression** dependency2 = nullptr) {
diff --git a/src/passes/NameManager.cpp b/src/passes/NameManager.cpp
index df8b34557..9f0198c2f 100644
--- a/src/passes/NameManager.cpp
+++ b/src/passes/NameManager.cpp
@@ -37,8 +37,7 @@ void NameManager::visitBlock(Block* curr) {
names.insert(curr->name);
}
void NameManager::visitLoop(Loop* curr) {
- names.insert(curr->out);
- names.insert(curr->in);
+ names.insert(curr->name);
}
void NameManager::visitBreak(Break* curr) {
names.insert(curr->name);
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 5eea38bdc..43ddc954c 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -32,14 +32,17 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
const char *maybeSpace;
const char *maybeNewLine;
- bool fullAST = false; // whether to not elide nodes in output when possible
- // (like implicit blocks)
+ bool full = false; // whether to not elide nodes in output when possible
+ // (like implicit blocks) and to emit types
Module* currModule = nullptr;
Function* currFunction = nullptr;
PrintSExpression(std::ostream& o) : o(o) {
setMinify(false);
+ if (getenv("BINARYEN_PRINT_FULL")) {
+ full = std::stoi(getenv("BINARYEN_PRINT_FULL"));
+ }
}
void setMinify(bool minify_) {
@@ -48,7 +51,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
maybeNewLine = minify ? "" : "\n";
}
- void setFullAST(bool fullAST_) { fullAST = fullAST_; }
+ void setFull(bool full_) { full = full_; }
void incIndent() {
if (minify) return;
@@ -64,6 +67,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
void printFullLine(Expression *expression) {
!minify && doIndent(o, indent);
+ if (full) {
+ o << "[" << printWasmType(expression->type) << "] ";
+ }
visit(expression);
o << maybeNewLine;
}
@@ -79,10 +85,6 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
return name;
}
- Name printableGlobal(Index index) {
- return currModule->getGlobal(index)->name;
- }
-
std::ostream& printName(Name name) {
// we need to quote names if they have tricky chars
if (strpbrk(name.str, "()")) {
@@ -99,6 +101,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
while (1) {
if (stack.size() > 0) doIndent(o, indent);
stack.push_back(curr);
+ if (full) {
+ o << "[" << printWasmType(curr->type) << "] ";
+ }
printOpening(o, "block");
if (curr->name.is()) {
o << ' ';
@@ -135,13 +140,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
incIndent();
printFullLine(curr->condition);
// ifTrue and False have implict blocks, avoid printing them if possible
- if (!fullAST && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) {
+ if (!full && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) {
printFullLine(curr->ifTrue->dynCast<Block>()->list.back());
} else {
printFullLine(curr->ifTrue);
}
if (curr->ifFalse) {
- if (!fullAST && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) {
+ if (!full && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) {
printFullLine(curr->ifFalse->dynCast<Block>()->list.back());
} else {
printFullLine(curr->ifFalse);
@@ -151,16 +156,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
void visitLoop(Loop *curr) {
printOpening(o, "loop");
- if (curr->out.is()) {
- o << ' ' << curr->out;
- assert(curr->in.is()); // if just one is printed, it must be the in
- }
- if (curr->in.is()) {
- o << ' ' << curr->in;
+ if (curr->name.is()) {
+ o << ' ' << curr->name;
}
incIndent();
auto block = curr->body->dynCast<Block>();
- if (!fullAST && block && block->name.isNull()) {
+ if (!full && block && block->name.isNull()) {
// wasm spec has loops containing children directly, while our ast
// has a single child for simplicity. print out the optimal form.
for (auto expression : block->list) {
@@ -229,26 +230,33 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
void visitCallIndirect(CallIndirect *curr) {
printOpening(o, "call_indirect ") << curr->fullType;
incIndent();
- printFullLine(curr->target);
for (auto operand : curr->operands) {
printFullLine(operand);
}
+ printFullLine(curr->target);
decIndent();
}
void visitGetLocal(GetLocal *curr) {
printOpening(o, "get_local ") << printableLocal(curr->index) << ')';
}
void visitSetLocal(SetLocal *curr) {
- printOpening(o, "set_local ") << printableLocal(curr->index);
+ if (curr->isTee()) {
+ printOpening(o, "tee_local ");
+ } else {
+ printOpening(o, "set_local ");
+ }
+ o << printableLocal(curr->index);
incIndent();
printFullLine(curr->value);
decIndent();
}
void visitGetGlobal(GetGlobal *curr) {
- printOpening(o, "get_global ") << printableGlobal(curr->index) << ')';
+ printOpening(o, "get_global ");
+ printName(curr->name) << ')';
}
void visitSetGlobal(SetGlobal *curr) {
- printOpening(o, "set_global ") << printableGlobal(curr->index);
+ printOpening(o, "set_global ");
+ printName(curr->name);
incIndent();
printFullLine(curr->value);
decIndent();
@@ -281,7 +289,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
void visitStore(Store *curr) {
o << '(';
- prepareColor(o) << printWasmType(curr->type) << ".store";
+ prepareColor(o) << printWasmType(curr->valueType) << ".store";
if (curr->bytes < 4 || (curr->type == i64 && curr->bytes < 8)) {
if (curr->bytes == 1) {
o << '8';
@@ -466,9 +474,16 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printFullLine(curr->condition);
decIndent();
}
+ void visitDrop(Drop *curr) {
+ o << '(';
+ prepareColor(o) << "drop";
+ incIndent();
+ printFullLine(curr->value);
+ decIndent();
+ }
void visitReturn(Return *curr) {
printOpening(o, "return");
- if (!curr->value || curr->value->is<Nop>()) {
+ if (!curr->value) {
// avoid a new line just for the parens
o << ')';
return;
@@ -499,11 +514,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
printMinorOpening(o, "unreachable") << ')';
}
// Module-level visitors
- void visitFunctionType(FunctionType *curr, bool full=false) {
- if (full) {
- printOpening(o, "type") << ' ';
- printName(curr->name) << " (func";
- }
+ void visitFunctionType(FunctionType *curr, Name* internalName = nullptr) {
+ o << "(func";
+ if (internalName) o << ' ' << *internalName;
if (curr->params.size() > 0) {
o << maybeSpace;
printMinorOpening(o, "param");
@@ -516,27 +529,39 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
o << maybeSpace;
printMinorOpening(o, "result ") << printWasmType(curr->result) << ')';
}
- if (full) {
- o << "))";
- }
+ o << ")";
}
void visitImport(Import *curr) {
printOpening(o, "import ");
- printName(curr->name) << ' ';
printText(o, curr->module.str) << ' ';
- printText(o, curr->base.str);
- if (curr->type) visitFunctionType(curr->type);
+ printText(o, curr->base.str) << ' ';
+ switch (curr->kind) {
+ case Export::Function: if (curr->functionType) visitFunctionType(curr->functionType, &curr->name); break;
+ case Export::Table: o << "(table " << curr->name << ")"; break;
+ case Export::Memory: o << "(memory " << curr->name << ")"; break;
+ case Export::Global: o << "(global " << curr->name << ' ' << printWasmType(curr->globalType) << ")"; break;
+ default: WASM_UNREACHABLE();
+ }
o << ')';
}
void visitExport(Export *curr) {
printOpening(o, "export ");
- printText(o, curr->name.str) << ' ';
- printName(curr->value) << ')';
+ printText(o, curr->name.str) << " (";
+ switch (curr->kind) {
+ case Export::Function: o << "func"; break;
+ case Export::Table: o << "table"; break;
+ case Export::Memory: o << "memory"; break;
+ case Export::Global: o << "global"; break;
+ default: WASM_UNREACHABLE();
+ }
+ o << ' ';
+ printName(curr->value) << "))";
}
void visitGlobal(Global *curr) {
printOpening(o, "global ");
- printName(curr->name) << ' ' << printWasmType(curr->type);
- printFullLine(curr->init);
+ printName(curr->name) << ' ';
+ o << printWasmType(curr->type) << ' ';
+ visit(curr->init);
o << ')';
}
void visitFunction(Function *curr) {
@@ -564,7 +589,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
// It is ok to emit a block here, as a function can directly contain a list, even if our
// ast avoids that for simplicity. We can just do that optimization here..
- if (!fullAST && curr->body->is<Block>() && curr->body->cast<Block>()->name.isNull()) {
+ if (!full && curr->body->is<Block>() && curr->body->cast<Block>()->name.isNull()) {
Block* block = curr->body->cast<Block>();
for (auto item : block->list) {
printFullLine(item);
@@ -575,7 +600,8 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
decIndent();
}
void visitTable(Table *curr) {
- printOpening(o, "table") << ' ' << curr->initial;
+ printOpening(o, "table") << ' ';
+ o << curr->initial;
if (curr->max && curr->max != Table::kMaxSize) o << ' ' << curr->max;
o << " anyfunc)\n";
doIndent(o, indent);
@@ -589,15 +615,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
o << ')';
}
}
- void visitModule(Module *curr) {
- currModule = curr;
- printOpening(o, "module", true);
- incIndent();
- doIndent(o, indent);
- printOpening(o, "memory") << ' ' << curr->memory.initial;
- if (curr->memory.max && curr->memory.max != Memory::kMaxSize) o << ' ' << curr->memory.max;
+ void visitMemory(Memory* curr) {
+ printOpening(o, "memory") << ' ';
+ o << curr->initial;
+ if (curr->max && curr->max != Memory::kMaxSize) o << ' ' << curr->max;
o << ")\n";
- for (auto segment : curr->memory.segments) {
+ for (auto segment : curr->segments) {
doIndent(o, indent);
printOpening(o, "data ", true);
visit(segment.offset);
@@ -624,12 +647,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
o << "\")\n";
}
- if (curr->memory.exportName.is()) {
- doIndent(o, indent);
- printOpening(o, "export ");
- printText(o, curr->memory.exportName.str) << " memory)";
- o << maybeNewLine;
- }
+ }
+ void visitModule(Module *curr) {
+ currModule = curr;
+ printOpening(o, "module", true);
+ incIndent();
+ doIndent(o, indent);
+ visitMemory(&curr->memory);
if (curr->start.is()) {
doIndent(o, indent);
printOpening(o, "start") << ' ' << curr->start << ')';
@@ -637,8 +661,10 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
}
for (auto& child : curr->functionTypes) {
doIndent(o, indent);
- visitFunctionType(child.get(), true);
- o << maybeNewLine;
+ printOpening(o, "type") << ' ';
+ printName(child->name) << ' ';
+ visitFunctionType(child.get());
+ o << ")" << maybeNewLine;
}
for (auto& child : curr->imports) {
doIndent(o, indent);
@@ -707,7 +733,7 @@ public:
void run(PassRunner* runner, Module* module) override {
PrintSExpression print(o);
- print.setFullAST(true);
+ print.setFull(true);
print.visitModule(module);
}
};
@@ -718,9 +744,17 @@ Pass *createFullPrinterPass() {
// Print individual expressions
-std::ostream& WasmPrinter::printExpression(Expression* expression, std::ostream& o, bool minify) {
+std::ostream& WasmPrinter::printExpression(Expression* expression, std::ostream& o, bool minify, bool full) {
+ if (!expression) {
+ o << "(null expression)";
+ return o;
+ }
PrintSExpression print(o);
print.setMinify(minify);
+ if (full) {
+ print.setFull(true);
+ o << "[" << printWasmType(expression->type) << "] ";
+ }
print.visit(expression);
return o;
}
diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp
index 0b3f50049..19d6c3eb1 100644
--- a/src/passes/RemoveImports.cpp
+++ b/src/passes/RemoveImports.cpp
@@ -37,7 +37,7 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports, Visitor<Remov
}
void visitCallImport(CallImport *curr) {
- WasmType type = module->getImport(curr->target)->type->result;
+ WasmType type = module->getImport(curr->target)->functionType->result;
if (type == none) {
replaceCurrent(allocator->alloc<Nop>());
} else {
diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp
index 263cea655..59d3af6fc 100644
--- a/src/passes/RemoveUnusedBrs.cpp
+++ b/src/passes/RemoveUnusedBrs.cpp
@@ -189,7 +189,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
// finally, we may have simplified ifs enough to turn them into selects
struct Selectifier : public WalkerPass<PostWalker<Selectifier, Visitor<Selectifier>>> {
void visitIf(If* curr) {
- if (curr->ifFalse) {
+ if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) {
// if with else, consider turning it into a select if there is no control flow
// TODO: estimate cost
EffectAnalyzer condition(curr->condition);
diff --git a/src/passes/RemoveUnusedNames.cpp b/src/passes/RemoveUnusedNames.cpp
index 9c6743479..8e24f9549 100644
--- a/src/passes/RemoveUnusedNames.cpp
+++ b/src/passes/RemoveUnusedNames.cpp
@@ -76,34 +76,12 @@ struct RemoveUnusedNames : public WalkerPass<PostWalker<RemoveUnusedNames, Visit
}
}
handleBreakTarget(curr->name);
- if (curr->name.is() && curr->list.size() == 1) {
- auto* child = curr->list[0]->dynCast<Loop>();
- if (child && !child->out.is()) {
- // we have just one child, this loop, and it lacks an out label. So this block's name is doing just that!
- child->out = curr->name;
- replaceCurrent(child);
- }
- }
}
void visitLoop(Loop *curr) {
- handleBreakTarget(curr->in);
- // Loops can have just 'in', but cannot have just 'out'
- auto out = curr->out;
- handleBreakTarget(curr->out);
- if (curr->out.is() && !curr->in.is()) {
- auto* block = getModule()->allocator.alloc<Block>();
- block->name = out;
- block->list.push_back(curr->body);
- replaceCurrent(block);
- }
- if (curr->in.is() && !curr->out.is()) {
- auto* child = curr->body->dynCast<Block>();
- if (child && child->name.is()) {
- // we have just one child, this block, and we lack an out label. So we can take the block's!
- curr->out = child->name;
- child->name = Name();
- }
+ handleBreakTarget(curr->name);
+ if (!curr->name.is()) {
+ replaceCurrent(curr->body);
}
}
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index d9786e62c..5315edac4 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -55,7 +55,13 @@ struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemo
void visitSetLocal(SetLocal *curr) {
if ((*numGetLocals)[curr->index] == 0) {
- replaceCurrent(curr->value);
+ auto* value = curr->value;
+ if (curr->isTee()) {
+ replaceCurrent(value);
+ } else {
+ Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(curr);
+ drop->value = value;
+ }
}
}
};
@@ -180,7 +186,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
auto found = sinkables.find(curr->index);
if (found != sinkables.end()) {
// sink it, and nop the origin
- replaceCurrent(*found->second.item);
+ auto* set = (*found->second.item)->cast<SetLocal>();
+ replaceCurrent(set);
+ assert(!set->isTee());
+ set->setTee(true);
// reuse the getlocal that is dying
*found->second.item = curr;
ExpressionManipulator::nop(curr);
@@ -189,6 +198,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
}
}
+ void visitDrop(Drop* curr) {
+ // collapse drop-tee into set, which can occur if a get was sunk into a tee
+ auto* set = curr->value->dynCast<SetLocal>();
+ if (set) {
+ assert(set->isTee());
+ set->setTee(false);
+ replaceCurrent(set);
+ }
+ }
+
void checkInvalidations(EffectAnalyzer& effects) {
// TODO: this is O(bad)
std::vector<Index> invalidated;
@@ -225,7 +244,11 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
// store is dead, leave just the value
auto found = self->sinkables.find(set->index);
if (found != self->sinkables.end()) {
- *found->second.item = (*found->second.item)->cast<SetLocal>()->value;
+ auto* previous = (*found->second.item)->cast<SetLocal>();
+ assert(!previous->isTee());
+ auto* previousValue = previous->value;
+ Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(previous);
+ drop->value = previousValue;
self->sinkables.erase(found);
self->anotherCycle = true;
}
@@ -236,15 +259,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
self->checkInvalidations(effects);
}
- if (set) {
- // we may be a replacement for the current node, update the stack
- self->expressionStack.pop_back();
- self->expressionStack.push_back(set);
- if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) {
- Index index = set->index;
- assert(self->sinkables.count(index) == 0);
- self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
- }
+ if (set && !set->isTee()) {
+ Index index = set->index;
+ assert(self->sinkables.count(index) == 0);
+ self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
}
self->expressionStack.pop_back();
diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp
index 0195427d7..db42a994a 100644
--- a/src/passes/Vacuum.cpp
+++ b/src/passes/Vacuum.cpp
@@ -25,13 +25,11 @@
namespace wasm {
-struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
+struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>> {
bool isFunctionParallel() override { return true; }
Pass* create() override { return new Vacuum; }
- std::vector<Expression*> expressionStack;
-
// returns nullptr if curr is dead, curr if it must stay as is, or another node if it can be replaced
Expression* optimize(Expression* curr, bool resultUsed) {
while (1) {
@@ -41,6 +39,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
case Expression::Id::BlockId: return curr; // not always needed, but handled in visitBlock()
case Expression::Id::IfId: return curr; // not always needed, but handled in visitIf()
case Expression::Id::LoopId: return curr; // not always needed, but handled in visitLoop()
+ case Expression::Id::DropId: return curr; // not always needed, but handled in visitDrop()
case Expression::Id::BreakId:
case Expression::Id::SwitchId:
@@ -51,6 +50,8 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
case Expression::Id::LoadId:
case Expression::Id::StoreId:
case Expression::Id::ReturnId:
+ case Expression::Id::GetGlobalId:
+ case Expression::Id::SetGlobalId:
case Expression::Id::HostId:
case Expression::Id::UnreachableId: return curr; // always needed
@@ -189,7 +190,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
// no else
if (curr->ifTrue->is<Nop>()) {
// no nothing
- replaceCurrent(curr->condition);
+ replaceCurrent(Builder(*getModule()).makeDrop(curr->condition));
}
}
}
@@ -198,21 +199,30 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
if (curr->body->is<Nop>()) ExpressionManipulator::nop(curr);
}
- static void visitPre(Vacuum* self, Expression** currp) {
- self->expressionStack.push_back(*currp);
- }
-
- static void visitPost(Vacuum* self, Expression** currp) {
- self->expressionStack.pop_back();
- }
-
- // override scan to add a pre and a post check task to all nodes
- static void scan(Vacuum* self, Expression** currp) {
- self->pushTask(visitPost, currp);
-
- WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>>::scan(self, currp);
-
- self->pushTask(visitPre, currp);
+ void visitDrop(Drop* curr) {
+ // if the drop input has no side effects, it can be wiped out
+ if (!EffectAnalyzer(curr->value).hasSideEffects()) {
+ ExpressionManipulator::nop(curr);
+ return;
+ }
+ // sink a drop into an arm of an if-else if the other arm ends in an unreachable, as it if is a branch, this can make that branch optimizable and more vaccuming possible
+ auto* iff = curr->value->dynCast<If>();
+ if (iff && iff->ifFalse && isConcreteWasmType(iff->type)) {
+ // reuse the drop in both cases
+ if (iff->ifTrue->type == unreachable) {
+ assert(isConcreteWasmType(iff->ifFalse->type));
+ curr->value = iff->ifFalse;
+ iff->ifFalse = curr;
+ iff->type = none;
+ replaceCurrent(iff);
+ } else if (iff->ifFalse->type == unreachable) {
+ assert(isConcreteWasmType(iff->ifTrue->type));
+ curr->value = iff->ifTrue;
+ iff->ifTrue = curr;
+ iff->type = none;
+ replaceCurrent(iff);
+ }
+ }
}
void visitFunction(Function* curr) {
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index f27eccf6a..437b023bc 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -64,9 +64,7 @@ void PassRegistry::registerPasses() {
registerPass("coalesce-locals", "reduce # of locals by coalescing", createCoalesceLocalsPass);
registerPass("coalesce-locals-learning", "reduce # of locals by coalescing and learning", createCoalesceLocalsWithLearningPass);
registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass);
- registerPass("drop-return-values", "stops relying on return values from set_local and store", createDropReturnValuesPass);
registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass);
- registerPass("lower-if-else", "lowers if-elses into ifs, blocks and branches", createLowerIfElsePass);
registerPass("merge-blocks", "merges blocks to their parents", createMergeBlocksPass);
registerPass("metrics", "reports metrics", createMetricsPass);
registerPass("nm", "name list", createNameListPass);
@@ -211,6 +209,9 @@ void PassRunner::run() {
}
void PassRunner::runFunction(Function* func) {
+ if (debug) {
+ std::cerr << "[PassRunner] running passes on function " << func->name << std::endl;
+ }
for (auto* pass : passes) {
runPassOnFunction(pass, func);
}
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 731536d7d..4bb76edad 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -25,7 +25,6 @@ class Pass;
Pass *createCoalesceLocalsPass();
Pass *createCoalesceLocalsWithLearningPass();
Pass *createDeadCodeEliminationPass();
-Pass *createDropReturnValuesPass();
Pass *createDuplicateFunctionEliminationPass();
Pass *createLowerIfElsePass();
Pass *createMergeBlocksPass();
diff --git a/src/s2wasm.h b/src/s2wasm.h
index a23cbb549..8c04f7891 100644
--- a/src/s2wasm.h
+++ b/src/s2wasm.h
@@ -753,6 +753,7 @@ class S2WasmBuilder {
set->index = func->getLocalIndex(assign);
set->value = curr;
set->type = curr->type;
+ set->setTee(false);
addToBlock(set);
}
};
@@ -834,7 +835,7 @@ class S2WasmBuilder {
auto makeStore = [&](WasmType type) {
skipComma();
auto curr = allocator->alloc<Store>();
- curr->type = type;
+ curr->valueType = type;
int32_t bytes = getInt() / CHAR_BIT;
curr->bytes = bytes > 0 ? bytes : getWasmTypeSize(type);
Name assign = getAssign();
@@ -849,6 +850,7 @@ class S2WasmBuilder {
curr->align = 1U << getInt(attributes[0] + 8);
}
curr->value = inputs[1];
+ curr->finalize();
setOutput(curr, assign);
};
auto makeSelect = [&](WasmType type) {
@@ -1062,7 +1064,7 @@ class S2WasmBuilder {
if (target->is<Block>()) {
return target->cast<Block>()->name;
} else {
- return target->cast<Loop>()->in;
+ return target->cast<Loop>()->name;
}
};
// fixups
@@ -1097,10 +1099,9 @@ class S2WasmBuilder {
} else if (match("loop")) {
auto curr = allocator->alloc<Loop>();
addToBlock(curr);
- curr->in = getNextLabel();
- curr->out = getNextLabel();
+ curr->name = getNextLabel();
auto block = allocator->alloc<Block>();
- block->name = curr->out; // temporary, fake - this way, on bstack we have the right label at the right offset for a br
+ block->name = getNextLabel();
curr->body = block;
loopBlocks.push_back(block);
bstack.push_back(block);
@@ -1146,6 +1147,7 @@ class S2WasmBuilder {
Name assign = getAssign();
skipComma();
auto curr = allocator->alloc<SetLocal>();
+ curr->setTee(true);
curr->index = func->getLocalIndex(getAssign());
skipComma();
curr->value = getInput();
diff --git a/src/shell-interface.h b/src/shell-interface.h
index f332307ad..ee9ff166a 100644
--- a/src/shell-interface.h
+++ b/src/shell-interface.h
@@ -90,11 +90,11 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
ShellExternalInterface() : memory() {}
- void init(Module& wasm) override {
+ void init(Module& wasm, ModuleInstance& instance) override {
memory.resize(wasm.memory.initial * wasm::Memory::kPageSize);
// apply memory segments
for (auto& segment : wasm.memory.segments) {
- Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32();
+ Address offset = ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32();
assert(offset + segment.data.size() <= wasm.memory.initial * wasm::Memory::kPageSize);
for (size_t i = 0; i != segment.data.size(); ++i) {
memory.set(offset + i, segment.data[i]);
@@ -103,7 +103,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
table.resize(wasm.table.initial);
for (auto& segment : wasm.table.segments) {
- Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32();
+ Address offset = ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32();
assert(offset + segment.data.size() <= wasm.table.initial);
for (size_t i = 0; i != segment.data.size(); ++i) {
table[offset + i] = segment.data[i];
@@ -111,6 +111,8 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
}
}
+ void importGlobals(std::map<Name, Literal>& globals, Module& wasm) override {}
+
Literal callImport(Import *import, LiteralList& arguments) override {
if (import->module == SPECTEST && import->base == PRINT) {
for (auto argument : arguments) {
@@ -126,10 +128,9 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
abort();
}
- Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) override {
+ Literal callTable(Index index, LiteralList& arguments, WasmType result, ModuleInstance& instance) override {
if (index >= table.size()) trap("callTable overflow");
auto* func = instance.wasm.getFunction(table[index]);
- if (func->type.is() && func->type != type) trap("callIndirect: bad type");
if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
for (size_t i = 0; i < func->params.size(); i++) {
if (func->params[i] != arguments[i].type) {
@@ -167,7 +168,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
}
void store(Store* store, Address addr, Literal value) override {
- switch (store->type) {
+ switch (store->valueType) {
case i32: {
switch (store->bytes) {
case 1: memory.set<int8_t>(addr, value.geti32()); break;
diff --git a/src/support/colors.h b/src/support/colors.h
index d81ecbc16..fb5267ce1 100644
--- a/src/support/colors.h
+++ b/src/support/colors.h
@@ -42,6 +42,15 @@ inline void grey(std::ostream& stream) { outputColorCode(stream, 0x08); }
inline void green(std::ostream& stream) { outputColorCode(stream, 0x02); }
inline void blue(std::ostream& stream) { outputColorCode(stream, 0x09); }
inline void bold(std::ostream& stream) { /* Do nothing */ }
+#else
+inline void normal(std::ostream& stream) {}
+inline void red(std::ostream& stream) {}
+inline void magenta(std::ostream& stream) {}
+inline void orange(std::ostream& stream) {}
+inline void grey(std::ostream& stream) {}
+inline void green(std::ostream& stream) {}
+inline void blue(std::ostream& stream) {}
+inline void bold(std::ostream& stream) {}
#endif
};
diff --git a/src/tools/asm2wasm.cpp b/src/tools/asm2wasm.cpp
index 7f1a07d1d..031b1fd23 100644
--- a/src/tools/asm2wasm.cpp
+++ b/src/tools/asm2wasm.cpp
@@ -21,6 +21,7 @@
#include "support/colors.h"
#include "support/command-line.h"
#include "support/file.h"
+#include "wasm-builder.h"
#include "wasm-printing.h"
#include "asm2wasm.h"
@@ -40,9 +41,13 @@ int main(int argc, const char *argv[]) {
o->extra["output"] = argument;
Colors::disable();
})
- .add("--mapped-globals", "-m", "Mapped globals", Options::Arguments::One,
+ .add("--mapped-globals", "-n", "Mapped globals", Options::Arguments::One,
[](Options *o, const std::string &argument) {
- o->extra["mapped globals"] = argument;
+ std::cerr << "warning: the --mapped-globals/-m option is deprecated (a mapped globals file is no longer needed as we use wasm globals)" << std::endl;
+ })
+ .add("--mem-init", "-t", "Import a memory initialization file into the output module", Options::Arguments::One,
+ [](Options *o, const std::string &argument) {
+ o->extra["mem init"] = argument;
})
.add("--total-memory", "-m", "Total memory size", Options::Arguments::One,
[](Options *o, const std::string &argument) {
@@ -62,10 +67,6 @@ int main(int argc, const char *argv[]) {
});
options.parse(argc, argv);
- const auto &mg_it = options.extra.find("mapped globals");
- const char *mappedGlobals =
- mg_it == options.extra.end() ? nullptr : mg_it->second.c_str();
-
const auto &tm_it = options.extra.find("total memory");
size_t totalMemory =
tm_it == options.extra.end() ? 16 * 1024 * 1024 : atoi(tm_it->second.c_str());
@@ -95,15 +96,18 @@ int main(int argc, const char *argv[]) {
Asm2WasmBuilder asm2wasm(wasm, pre.memoryGrowth, options.debug, imprecise, opts);
asm2wasm.processAsm(asmjs);
+ // import mem init file, if provided
+ const auto &memInit = options.extra.find("mem init");
+ if (memInit != options.extra.end()) {
+ auto filename = memInit->second.c_str();
+ auto data(read_file<std::vector<char>>(filename, Flags::Binary, options.debug ? Flags::Debug : Flags::Release));
+ // create the memory segment
+ wasm.memory.segments.emplace_back(Builder(wasm).makeGetGlobal(Name("memoryBase"), i32), data);
+ }
+
if (options.debug) std::cerr << "printing..." << std::endl;
Output output(options.extra["output"], Flags::Text, options.debug ? Flags::Debug : Flags::Release);
WasmPrinter::printModule(&wasm, output.getStream());
- if (mappedGlobals) {
- if (options.debug)
- std::cerr << "serializing mapped globals..." << std::endl;
- asm2wasm.serializeMappedGlobals(mappedGlobals);
- }
-
if (options.debug) std::cerr << "done." << std::endl;
}
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp
index 72e8e13ba..a95ced87f 100644
--- a/src/tools/wasm-shell.cpp
+++ b/src/tools/wasm-shell.cpp
@@ -125,14 +125,19 @@ static void run_asserts(size_t* i, bool* checked, Module* wasm,
// maybe parsed ok, but otherwise incorrect
invalid = !WasmValidator().validate(wasm);
}
- assert(invalid);
+ if (!invalid) {
+ Colors::red(std::cerr);
+ std::cerr << "[should have been invalid]\n";
+ Colors::normal(std::cerr);
+ std::cerr << &wasm << '\n';
+ abort();
+ }
} else if (id == INVOKE) {
assert(wasm);
Invocation invocation(curr, instance.get(), *builder->get());
invocation.invoke();
- } else {
+ } else if (wasm) { // if no wasm, we skipped the module
// an invoke test
- assert(wasm);
bool trapped = false;
WASM_UNUSED(trapped);
Literal result;
@@ -169,6 +174,7 @@ static void run_asserts(size_t* i, bool* checked, Module* wasm,
int main(int argc, const char* argv[]) {
Name entry;
+ std::set<size_t> skipped;
Options options("wasm-shell", "Execute .wast files");
options
@@ -176,6 +182,21 @@ int main(int argc, const char* argv[]) {
"--entry", "-e", "call the entry point after parsing the module",
Options::Arguments::One,
[&entry](Options*, const std::string& argument) { entry = argument; })
+ .add(
+ "--skip", "-s", "skip input on certain lines (comma-separated-list)",
+ Options::Arguments::One,
+ [&skipped](Options*, const std::string& argument) {
+ size_t i = 0;
+ while (i < argument.size()) {
+ auto ending = argument.find(',', i);
+ if (ending == std::string::npos) {
+ ending = argument.size();
+ }
+ auto sub = argument.substr(i, ending - i);
+ skipped.insert(atoi(sub.c_str()));
+ i = ending + 1;
+ }
+ })
.add_positional("INFILE", Options::Arguments::One,
[](Options* o, const std::string& argument) {
o->extra["infile"] = argument;
@@ -195,9 +216,19 @@ int main(int argc, const char* argv[]) {
size_t i = 0;
while (i < root.size()) {
Element& curr = *root[i];
+ if (skipped.count(curr.line) > 0) {
+ Colors::green(std::cerr);
+ std::cerr << "SKIPPING [line: " << curr.line << "]\n";
+ Colors::normal(std::cerr);
+ i++;
+ continue;
+ }
IString id = curr[0]->str();
if (id == MODULE) {
if (options.debug) std::cerr << "parsing s-expressions to wasm...\n";
+ Colors::green(std::cerr);
+ std::cerr << "BUILDING MODULE [line: " << curr.line << "]\n";
+ Colors::normal(std::cerr);
Module wasm;
std::unique_ptr<SExpressionWasmBuilder> builder;
builder = wasm::make_unique<SExpressionWasmBuilder>(wasm, *root[i]);
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 04bebddc5..40671b082 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -234,7 +234,7 @@ namespace BinaryConsts {
enum Meta {
Magic = 0x6d736100,
- Version = 11
+ Version = 0x0c
};
namespace Section {
@@ -412,6 +412,7 @@ enum ASTNodes {
CallFunction = 0x16,
CallIndirect = 0x17,
CallImport = 0x18,
+ TeeLocal = 0x19,
GetGlobal = 0x1a,
SetGlobal = 0x1b,
@@ -426,6 +427,7 @@ enum ASTNodes {
TableSwitch = 0x08,
Return = 0x09,
Unreachable = 0x0a,
+ Drop = 0x0b,
End = 0x0f
};
@@ -529,8 +531,7 @@ public:
if (debug) std::cerr << "== writeMemory" << std::endl;
auto start = startSection(BinaryConsts::Section::Memory);
o << U32LEB(wasm->memory.initial)
- << U32LEB(wasm->memory.max)
- << int8_t(wasm->memory.exportName.is()); // export memory
+ << U32LEB(wasm->memory.max);
finishSection(start);
}
@@ -571,7 +572,14 @@ public:
o << U32LEB(wasm->imports.size());
for (auto& import : wasm->imports) {
if (debug) std::cerr << "write one" << std::endl;
- o << U32LEB(getFunctionTypeIndex(import->type->name));
+ o << U32LEB(import->kind);
+ switch (import->kind) {
+ case Export::Function: o << U32LEB(getFunctionTypeIndex(import->functionType->name));
+ case Export::Table: break;
+ case Export::Memory: break;
+ case Export::Global: o << binaryWasmType(import->globalType);break;
+ default: WASM_UNREACHABLE();
+ }
writeInlineString(import->module.str);
writeInlineString(import->base.str);
}
@@ -690,7 +698,14 @@ public:
o << U32LEB(wasm->exports.size());
for (auto& curr : wasm->exports) {
if (debug) std::cerr << "write one" << std::endl;
- o << U32LEB(getFunctionIndex(curr->value));
+ o << U32LEB(curr->kind);
+ switch (curr->kind) {
+ case Export::Function: o << U32LEB(getFunctionIndex(curr->value)); break;
+ case Export::Table: o << U32LEB(0); break;
+ case Export::Memory: o << U32LEB(0); break;
+ case Export::Global: o << U32LEB(getGlobalIndex(curr->value)); break;
+ default: WASM_UNREACHABLE();
+ }
writeInlineString(curr->name.str);
}
finishSection(start);
@@ -726,7 +741,7 @@ public:
return mappedImports[name];
}
- std::map<Name, uint32_t> mappedFunctions; // name of the Function => index
+ std::map<Name, uint32_t> mappedFunctions; // name of the Function => index
uint32_t getFunctionIndex(Name name) {
if (!mappedFunctions.size()) {
// Create name => index mapping.
@@ -739,6 +754,26 @@ public:
return mappedFunctions[name];
}
+ std::map<Name, uint32_t> mappedGlobals; // name of the Global => index. first imported globals, then internal globals
+ uint32_t getGlobalIndex(Name name) {
+ if (!mappedGlobals.size()) {
+ // Create name => index mapping.
+ for (auto& import : wasm->imports) {
+ if (import->kind != Import::Global) continue;
+ assert(mappedGlobals.count(import->name) == 0);
+ auto index = mappedGlobals.size();
+ mappedGlobals[import->name] = index;
+ }
+ for (size_t i = 0; i < wasm->globals.size(); i++) {
+ assert(mappedGlobals.count(wasm->globals[i]->name) == 0);
+ auto index = mappedGlobals.size();
+ mappedGlobals[wasm->globals[i]->name] = index;
+ }
+ }
+ assert(mappedGlobals.count(name));
+ return mappedGlobals[name];
+ }
+
void writeFunctionTable() {
if (wasm->table.segments.size() == 0) return;
if (debug) std::cerr << "== writeFunctionTable" << std::endl;
@@ -873,11 +908,9 @@ public:
void visitLoop(Loop *curr) {
if (debug) std::cerr << "zz node: Loop" << std::endl;
o << int8_t(BinaryConsts::Loop);
- breakStack.push_back(curr->out);
- breakStack.push_back(curr->in);
+ breakStack.push_back(curr->name);
recursePossibleBlockContents(curr->body);
breakStack.pop_back();
- breakStack.pop_back();
o << int8_t(BinaryConsts::End);
}
@@ -939,18 +972,18 @@ public:
o << int8_t(BinaryConsts::GetLocal) << U32LEB(mappedLocals[curr->index]);
}
void visitSetLocal(SetLocal *curr) {
- if (debug) std::cerr << "zz node: SetLocal" << std::endl;
+ if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl;
recurse(curr->value);
- o << int8_t(BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]);
+ o << int8_t(curr->isTee() ? BinaryConsts::TeeLocal : BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]);
}
void visitGetGlobal(GetGlobal *curr) {
if (debug) std::cerr << "zz node: GetGlobal " << (o.size() + 1) << std::endl;
- o << int8_t(BinaryConsts::GetGlobal) << U32LEB(curr->index);
+ o << int8_t(BinaryConsts::GetGlobal) << U32LEB(getGlobalIndex(curr->name));
}
void visitSetGlobal(SetGlobal *curr) {
if (debug) std::cerr << "zz node: SetGlobal" << std::endl;
recurse(curr->value);
- o << int8_t(BinaryConsts::SetGlobal) << U32LEB(curr->index);
+ o << int8_t(BinaryConsts::SetGlobal) << U32LEB(getGlobalIndex(curr->name));
}
void emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset) {
@@ -991,7 +1024,7 @@ public:
if (debug) std::cerr << "zz node: Store" << std::endl;
recurse(curr->ptr);
recurse(curr->value);
- switch (curr->type) {
+ switch (curr->valueType) {
case i32: {
switch (curr->bytes) {
case 1: o << int8_t(BinaryConsts::I32StoreMem8); break;
@@ -1219,6 +1252,11 @@ public:
if (debug) std::cerr << "zz node: Unreachable" << std::endl;
o << int8_t(BinaryConsts::Unreachable);
}
+ void visitDrop(Drop *curr) {
+ if (debug) std::cerr << "zz node: Drop" << std::endl;
+ recurse(curr->value);
+ o << int8_t(BinaryConsts::Drop);
+ }
};
class WasmBinaryBuilder {
@@ -1432,10 +1470,6 @@ public:
if (debug) std::cerr << "== readMemory" << std::endl;
wasm.memory.initial = getU32LEB();
wasm.memory.max = getU32LEB();
- auto exportMemory = getInt8();
- if (exportMemory) {
- wasm.memory.exportName = Name("memory");
- }
}
void readSignatures() {
@@ -1472,10 +1506,20 @@ public:
if (debug) std::cerr << "read one" << std::endl;
auto curr = new Import;
curr->name = Name(std::string("import$") + std::to_string(i));
- auto index = getU32LEB();
- assert(index < wasm.functionTypes.size());
- curr->type = wasm.getFunctionType(index);
- assert(curr->type->name.is());
+ curr->kind = (Import::Kind)getU32LEB();
+ switch (curr->kind) {
+ case Export::Function: {
+ auto index = getU32LEB();
+ assert(index < wasm.functionTypes.size());
+ curr->functionType = wasm.getFunctionType(index);
+ assert(curr->functionType->name.is());
+ break;
+ }
+ case Export::Table: break;
+ case Export::Memory: break;
+ case Export::Global: curr->globalType = getWasmType(); break;
+ default: WASM_UNREACHABLE();
+ }
curr->module = getInlineString();
curr->base = getInlineString();
wasm.addImport(curr);
@@ -1573,8 +1617,8 @@ public:
for (size_t i = 0; i < num; i++) {
if (debug) std::cerr << "read one" << std::endl;
auto curr = new Export;
+ curr->kind = (Export::Kind)getU32LEB();
auto index = getU32LEB();
- assert(index < functionTypes.size());
curr->name = getInlineString();
exportIndexes[curr] = index;
}
@@ -1627,6 +1671,24 @@ public:
return ret;
}
+ std::map<Index, Name> mappedGlobals; // index of the Global => name. first imported globals, then internal globals
+ Name getGlobalName(Index index) {
+ if (!mappedGlobals.size()) {
+ // Create name => index mapping.
+ for (auto& import : wasm.imports) {
+ if (import->kind != Import::Global) continue;
+ auto index = mappedGlobals.size();
+ mappedGlobals[index] = import->name;
+ }
+ for (size_t i = 0; i < wasm.globals.size(); i++) {
+ auto index = mappedGlobals.size();
+ mappedGlobals[index] = wasm.globals[i]->name;
+ }
+ }
+ assert(mappedGlobals.count(index));
+ return mappedGlobals[index];
+ }
+
void processFunctions() {
for (auto& func : functions) {
wasm.addFunction(func);
@@ -1639,7 +1701,13 @@ public:
for (auto& iter : exportIndexes) {
Export* curr = iter.first;
- curr->value = wasm.functions[iter.second]->name;
+ switch (curr->kind) {
+ case Export::Function: curr->value = wasm.functions[iter.second]->name; break;
+ case Export::Table: curr->value = Name::fromInt(0); break;
+ case Export::Memory: curr->value = Name::fromInt(0); break;
+ case Export::Global: curr->value = getGlobalName(iter.second); break;
+ default: WASM_UNREACHABLE();
+ }
wasm.addExport(curr);
}
@@ -1728,13 +1796,15 @@ public:
case BinaryConsts::CallImport: visitCallImport((curr = allocator.alloc<CallImport>())->cast<CallImport>()); break;
case BinaryConsts::CallIndirect: visitCallIndirect((curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break;
case BinaryConsts::GetLocal: visitGetLocal((curr = allocator.alloc<GetLocal>())->cast<GetLocal>()); break;
- case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>()); break;
+ case BinaryConsts::TeeLocal:
+ case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>(), code); break;
case BinaryConsts::GetGlobal: visitGetGlobal((curr = allocator.alloc<GetGlobal>())->cast<GetGlobal>()); break;
case BinaryConsts::SetGlobal: visitSetGlobal((curr = allocator.alloc<SetGlobal>())->cast<SetGlobal>()); break;
case BinaryConsts::Select: visitSelect((curr = allocator.alloc<Select>())->cast<Select>()); break;
case BinaryConsts::Return: visitReturn((curr = allocator.alloc<Return>())->cast<Return>()); break;
case BinaryConsts::Nop: visitNop((curr = allocator.alloc<Nop>())->cast<Nop>()); break;
case BinaryConsts::Unreachable: visitUnreachable((curr = allocator.alloc<Unreachable>())->cast<Unreachable>()); break;
+ case BinaryConsts::Drop: visitDrop((curr = allocator.alloc<Drop>())->cast<Drop>()); break;
case BinaryConsts::End:
case BinaryConsts::Else: curr = nullptr; break;
default: {
@@ -1832,13 +1902,10 @@ public:
}
void visitLoop(Loop *curr) {
if (debug) std::cerr << "zz node: Loop" << std::endl;
- curr->out = getNextLabel();
- curr->in = getNextLabel();
- breakStack.push_back(curr->out);
- breakStack.push_back(curr->in);
+ curr->name = getNextLabel();
+ breakStack.push_back(curr->name);
curr->body = getMaybeBlock();
breakStack.pop_back();
- breakStack.pop_back();
curr->finalize();
}
@@ -1890,7 +1957,7 @@ public:
WASM_UNUSED(arity);
auto import = wasm.getImport(getU32LEB());
curr->target = import->name;
- auto type = import->type;
+ auto type = import->functionType;
assert(type);
auto num = type->params.size();
assert(num == arity);
@@ -1922,25 +1989,35 @@ public:
assert(curr->index < currFunction->getNumLocals());
curr->type = currFunction->getLocalType(curr->index);
}
- void visitSetLocal(SetLocal *curr) {
- if (debug) std::cerr << "zz node: SetLocal" << std::endl;
+ void visitSetLocal(SetLocal *curr, uint8_t code) {
+ if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl;
curr->index = getU32LEB();
assert(curr->index < currFunction->getNumLocals());
curr->value = popExpression();
curr->type = curr->value->type;
+ curr->setTee(code == BinaryConsts::TeeLocal);
}
void visitGetGlobal(GetGlobal *curr) {
if (debug) std::cerr << "zz node: GetGlobal " << pos << std::endl;
- curr->index = getU32LEB();
- assert(curr->index < wasm.globals.size());
- curr->type = wasm.globals[curr->index]->type;
+ auto index = getU32LEB();
+ curr->name = getGlobalName(index);
+ auto* global = wasm.checkGlobal(curr->name);
+ if (global) {
+ curr->type = global->type;
+ return;
+ }
+ auto* import = wasm.checkImport(curr->name);
+ if (import && import->kind == Import::Global) {
+ curr->type = import->globalType;
+ return;
+ }
+ throw ParseException("bad get_global");
}
void visitSetGlobal(SetGlobal *curr) {
if (debug) std::cerr << "zz node: SetGlobal" << std::endl;
- curr->index = getU32LEB();
- assert(curr->index < wasm.globals.size());
+ auto index = getU32LEB();
+ curr->name = getGlobalName(index);
curr->value = popExpression();
- curr->type = curr->value->type;
}
void readMemoryAccess(Address& alignment, size_t bytes, Address& offset) {
@@ -1976,21 +2053,22 @@ public:
bool maybeVisitStore(Expression*& out, uint8_t code) {
Store* curr;
switch (code) {
- case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i32; break;
- case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i32; break;
- case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i32; break;
- case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i64; break;
- case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i64; break;
- case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i64; break;
- case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = i64; break;
- case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = f32; break;
- case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = f64; break;
+ case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i32; break;
+ case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i32; break;
+ case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i32; break;
+ case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i64; break;
+ case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i64; break;
+ case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i64; break;
+ case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = i64; break;
+ case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = f32; break;
+ case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = f64; break;
default: return false;
}
if (debug) std::cerr << "zz node: Store" << std::endl;
readMemoryAccess(curr->align, curr->bytes, curr->offset);
curr->value = popExpression();
curr->ptr = popExpression();
+ curr->finalize();
out = curr;
return true;
}
@@ -2175,6 +2253,10 @@ public:
void visitUnreachable(Unreachable *curr) {
if (debug) std::cerr << "zz node: Unreachable" << std::endl;
}
+ void visitDrop(Drop *curr) {
+ if (debug) std::cerr << "zz node: Drop" << std::endl;
+ curr->value = popExpression();
+ }
};
} // namespace wasm
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 22e1f9a00..2841585e6 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -82,9 +82,9 @@ public:
ret->finalize();
return ret;
}
- Loop* makeLoop(Name out, Name in, Expression* body) {
+ Loop* makeLoop(Name name, Expression* body) {
auto* ret = allocator.alloc<Loop>();
- ret->out = out; ret->in = in; ret->body = body;
+ ret->name = name; ret->body = body;
ret->finalize();
return ret;
}
@@ -142,20 +142,26 @@ public:
auto* ret = allocator.alloc<SetLocal>();
ret->index = index;
ret->value = value;
+ ret->type = none;
+ return ret;
+ }
+ SetLocal* makeTeeLocal(Index index, Expression* value) {
+ auto* ret = allocator.alloc<SetLocal>();
+ ret->index = index;
+ ret->value = value;
ret->type = value->type;
return ret;
}
- GetGlobal* makeGetGlobal(Index index, WasmType type) {
+ GetGlobal* makeGetGlobal(Name name, WasmType type) {
auto* ret = allocator.alloc<GetGlobal>();
- ret->index = index;
+ ret->name = name;
ret->type = type;
return ret;
}
- SetGlobal* makeSetGlobal(Index index, Expression* value) {
+ SetGlobal* makeSetGlobal(Name name, Expression* value) {
auto* ret = allocator.alloc<SetGlobal>();
- ret->index = index;
+ ret->name = name;
ret->value = value;
- ret->type = value->type;
return ret;
}
Load* makeLoad(unsigned bytes, bool signed_, uint32_t offset, unsigned align, Expression *ptr, WasmType type) {
@@ -164,10 +170,11 @@ public:
ret->type = type;
return ret;
}
- Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value) {
+ Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value, WasmType type) {
auto* ret = allocator.alloc<Store>();
- ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value;
- ret->type = value->type;
+ ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value; ret->valueType = type;
+ ret->finalize();
+ assert(isConcreteWasmType(ret->value->type) ? ret->value->type == type : true);
return ret;
}
Const* makeConst(Literal value) {
@@ -205,12 +212,22 @@ public:
ret->op = op;
ret->nameOperand = nameOperand;
ret->operands.set(operands);
+ ret->finalize();
return ret;
}
Unreachable* makeUnreachable() {
return allocator.alloc<Unreachable>();
}
+ // Additional helpers
+
+ Drop* makeDrop(Expression *value) {
+ auto* ret = allocator.alloc<Drop>();
+ ret->value = value;
+ ret->finalize();
+ return ret;
+ }
+
// Additional utility functions for building on top of nodes
static Index addParam(Function* func, Name name, WasmType type) {
@@ -247,7 +264,21 @@ public:
if (!block) block = makeBlock(any);
if (append) {
block->list.push_back(append);
- block->finalize();
+ block->finalize(); // TODO: move out of if
+ }
+ return block;
+ }
+
+ // ensure a node is a block, if it isn't already, and optionally append to the block
+ // this variant sets a name for the block, so it will not reuse a block already named
+ Block* blockifyWithName(Expression* any, Name name, Expression* append = nullptr) {
+ Block* block = nullptr;
+ if (any) block = any->dynCast<Block>();
+ if (!block || block->name.is()) block = makeBlock(any);
+ block->name = name;
+ if (append) {
+ block->list.push_back(append);
+ block->finalize(); // TODO: move out of if
}
return block;
}
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index 28a3e8e6d..292d6a521 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -31,8 +31,6 @@
#ifdef WASM_INTERPRETER_DEBUG
#include "wasm-printing.h"
-
-int indent = 0;
#endif
@@ -77,32 +75,11 @@ typedef std::vector<Literal> LiteralList;
// Debugging helpers
#ifdef WASM_INTERPRETER_DEBUG
-struct IndentHandler {
- const char *name;
- IndentHandler(const char *name, Expression *expression) : name(name) {
- doIndent(std::cout, indent);
- std::cout << "visit " << name << " :\n";
- indent++;
-#if WASM_INTERPRETER_DEBUG == 2
- doIndent(std::cout, indent);
- if (expression) std::cout << "\n" << expression << '\n';
- indent++;
-#endif
- }
- ~IndentHandler() {
-#if WASM_INTERPRETER_DEBUG == 2
- indent--;
-#endif
- indent--;
- doIndent(std::cout, indent);
- std::cout << "exit " << name << '\n';
- }
-};
-#define NOTE_ENTER(x) IndentHandler indentHandler(x, curr);
-#define NOTE_ENTER_(x) IndentHandler indentHandler(x, nullptr);
-#define NOTE_NAME(p0) { doIndent(std::cout, indent); std::cout << "name in " << indentHandler.name << '(' << Name(p0) << ")\n"; }
-#define NOTE_EVAL1(p0) { doIndent(std::cout, indent); std::cout << "eval in " << indentHandler.name << '(' << p0 << ")\n"; }
-#define NOTE_EVAL2(p0, p1) { doIndent(std::cout, indent); std::cout << "eval in " << indentHandler.name << '(' << p0 << ", " << p1 << ")\n"; }
+#define NOTE_ENTER(x) { std::cout << "visit " << x << " : " << curr << "\n"; }
+#define NOTE_ENTER_(x) { std::cout << "visit " << x << "\n"; }
+#define NOTE_NAME(p0) { std::cout << "name " << '(' << Name(p0) << ")\n"; }
+#define NOTE_EVAL1(p0) { std::cout << "eval " << '(' << p0 << ")\n"; }
+#define NOTE_EVAL2(p0, p1) { std::cout << "eval " << '(' << p0 << ", " << p1 << ")\n"; }
#else // WASM_INTERPRETER_DEBUG
#define NOTE_ENTER(x)
#define NOTE_ENTER_(x)
@@ -170,8 +147,7 @@ public:
while (1) {
Flow flow = visit(curr->body);
if (flow.breaking()) {
- if (flow.breakTo == curr->in) continue; // lol
- flow.clearIf(curr->out);
+ if (flow.breakTo == curr->name) continue; // lol
}
return flow; // loop does not loop automatically, only continue achieves that
}
@@ -435,6 +411,12 @@ public:
NOTE_EVAL1(condition.value);
return condition.value.geti32() ? ifTrue : ifFalse; // ;-)
}
+ Flow visitDrop(Drop *curr) {
+ NOTE_ENTER("Drop");
+ Flow value = visit(curr->value);
+ if (value.breaking()) return value;
+ return Flow();
+ }
Flow visitReturn(Return *curr) {
NOTE_ENTER("Return");
Flow flow;
@@ -503,14 +485,17 @@ public:
// Execute an constant expression in a global init or memory offset
class ConstantExpressionRunner : public ExpressionRunner<ConstantExpressionRunner> {
+ std::map<Name, Literal>& globals;
public:
+ ConstantExpressionRunner(std::map<Name, Literal>& globals) : globals(globals) {}
+
Flow visitLoop(Loop* curr) { WASM_UNREACHABLE(); }
Flow visitCall(Call* curr) { WASM_UNREACHABLE(); }
Flow visitCallImport(CallImport* curr) { WASM_UNREACHABLE(); }
Flow visitCallIndirect(CallIndirect* curr) { WASM_UNREACHABLE(); }
Flow visitGetLocal(GetLocal *curr) { WASM_UNREACHABLE(); }
Flow visitSetLocal(SetLocal *curr) { WASM_UNREACHABLE(); }
- Flow visitGetGlobal(GetGlobal *curr) { WASM_UNREACHABLE(); }
+ Flow visitGetGlobal(GetGlobal *curr) { return Flow(globals[curr->name]); }
Flow visitSetGlobal(SetGlobal *curr) { WASM_UNREACHABLE(); }
Flow visitLoad(Load *curr) { WASM_UNREACHABLE(); }
Flow visitStore(Store *curr) { WASM_UNREACHABLE(); }
@@ -535,9 +520,10 @@ public:
// an imported function or accessing memory.
//
struct ExternalInterface {
- virtual void init(Module& wasm) {}
+ virtual void init(Module& wasm, ModuleInstance& instance) {}
+ virtual void importGlobals(std::map<Name, Literal>& globals, Module& wasm) = 0;
virtual Literal callImport(Import* import, LiteralList& arguments) = 0;
- virtual Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) = 0;
+ virtual Literal callTable(Index index, LiteralList& arguments, WasmType result, ModuleInstance& instance) = 0;
virtual Literal load(Load* load, Address addr) = 0;
virtual void store(Store* store, Address addr, Literal value) = 0;
virtual void growMemory(Address oldSize, Address newSize) = 0;
@@ -547,14 +533,20 @@ public:
Module& wasm;
// Values of globals
- std::vector<Literal> globals;
+ std::map<Name, Literal> globals;
ModuleInstance(Module& wasm, ExternalInterface* externalInterface) : wasm(wasm), externalInterface(externalInterface) {
+ // import globals from the outside
+ externalInterface->importGlobals(globals, wasm);
+ // prepare memory
memorySize = wasm.memory.initial;
- for (Index i = 0; i < wasm.globals.size(); i++) {
- globals.push_back(ConstantExpressionRunner().visit(wasm.globals[i]->init).value);
+ // generate internal (non-imported) globals
+ for (auto& global : wasm.globals) {
+ globals[global->name] = ConstantExpressionRunner(globals).visit(global->init).value;
}
- externalInterface->init(wasm);
+ // initialize the rest of the external interface
+ externalInterface->init(wasm, *this);
+ // run start, if present
if (wasm.start.is()) {
LiteralList arguments;
callFunction(wasm.start, arguments);
@@ -670,13 +662,13 @@ public:
}
Flow visitCallIndirect(CallIndirect *curr) {
NOTE_ENTER("CallIndirect");
- Flow target = visit(curr->target);
- if (target.breaking()) return target;
LiteralList arguments;
Flow flow = generateArguments(curr->operands, arguments);
if (flow.breaking()) return flow;
+ Flow target = visit(curr->target);
+ if (target.breaking()) return target;
Index index = target.value.geti32();
- return instance.externalInterface->callTable(index, curr->fullType, arguments, instance);
+ return instance.externalInterface->callTable(index, arguments, curr->type, instance);
}
Flow visitGetLocal(GetLocal *curr) {
@@ -693,28 +685,27 @@ public:
if (flow.breaking()) return flow;
NOTE_EVAL1(index);
NOTE_EVAL1(flow.value);
- assert(flow.value.type == curr->type);
+ assert(curr->isTee() ? flow.value.type == curr->type : true);
scope.locals[index] = flow.value;
- return flow;
+ return curr->isTee() ? flow : Flow();
}
Flow visitGetGlobal(GetGlobal *curr) {
NOTE_ENTER("GetGlobal");
- auto index = curr->index;
- NOTE_EVAL1(index);
- NOTE_EVAL1(instance.globals[index]);
- return instance.globals[index];
+ auto name = curr->name;
+ NOTE_EVAL1(name);
+ NOTE_EVAL1(instance.globals[name]);
+ return instance.globals[name];
}
Flow visitSetGlobal(SetGlobal *curr) {
NOTE_ENTER("SetGlobal");
- auto index = curr->index;
+ auto name = curr->name;
Flow flow = visit(curr->value);
if (flow.breaking()) return flow;
- NOTE_EVAL1(index);
+ NOTE_EVAL1(name);
NOTE_EVAL1(flow.value);
- assert(flow.value.type == curr->type);
- instance.globals[index] = flow.value;
- return flow;
+ instance.globals[name] = flow.value;
+ return Flow();
}
Flow visitLoad(Load *curr) {
@@ -730,7 +721,7 @@ public:
Flow value = visit(curr->value);
if (value.breaking()) return value;
instance.externalInterface->store(curr, instance.getFinalAddress(curr, ptr.value), value.value);
- return value;
+ return Flow();
}
Flow visitHost(Host *curr) {
@@ -739,14 +730,15 @@ public:
case PageSize: return Literal((int32_t)Memory::kPageSize);
case CurrentMemory: return Literal(int32_t(instance.memorySize));
case GrowMemory: {
+ auto fail = Literal(int32_t(-1));
Flow flow = visit(curr->operands[0]);
if (flow.breaking()) return flow;
int32_t ret = instance.memorySize;
uint32_t delta = flow.value.geti32();
- if (delta > uint32_t(-1) /Memory::kPageSize) trap("growMemory: delta relatively too big");
- if (instance.memorySize >= uint32_t(-1) - delta) trap("growMemory: delta objectively too big");
+ if (delta > uint32_t(-1) /Memory::kPageSize) return fail;
+ if (instance.memorySize >= uint32_t(-1) - delta) return fail;
uint32_t newSize = instance.memorySize + delta;
- if (newSize > instance.wasm.memory.max) trap("growMemory: exceeds max");
+ if (newSize > instance.wasm.memory.max) return fail;
instance.externalInterface->growMemory(instance.memorySize * Memory::kPageSize, newSize * Memory::kPageSize);
instance.memorySize = newSize;
return Literal(int32_t(ret));
diff --git a/src/wasm-js.cpp b/src/wasm-js.cpp
index 83956e47e..7d71cec3c 100644
--- a/src/wasm-js.cpp
+++ b/src/wasm-js.cpp
@@ -81,21 +81,6 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm2wasm(char *input) {
if (wasmJSDebug) std::cerr << "wasming...\n";
asm2wasm = new Asm2WasmBuilder(*module, pre.memoryGrowth, debug, false /* TODO: support imprecise? */, false /* TODO: support optimizing? */);
asm2wasm->processAsm(asmjs);
-
- if (wasmJSDebug) std::cerr << "mapping globals...\n";
- for (auto& pair : asm2wasm->mappedGlobals) {
- auto name = pair.first;
- auto& global = pair.second;
- if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here
- double value = EM_ASM_DOUBLE({ return Module['lookupImport'](Pointer_stringify($0), Pointer_stringify($1)) }, global.module.str, global.base.str);
- uint32_t address = global.address;
- switch (global.type) {
- case i32: EM_ASM_({ Module['info'].parent['HEAP32'][$0 >> 2] = $1 }, address, value); break;
- case f32: EM_ASM_({ Module['info'].parent['HEAPF32'][$0 >> 2] = $1 }, address, value); break;
- case f64: EM_ASM_({ Module['info'].parent['HEAPF64'][$0 >> 3] = $1 }, address, value); break;
- default: abort();
- }
- }
}
void finalizeModule() {
@@ -160,14 +145,16 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
Module['asmExports'] = {};
});
for (auto& curr : module->exports) {
- EM_ASM_({
- var name = Pointer_stringify($0);
- Module['asmExports'][name] = function() {
- Module['tempArguments'] = Array.prototype.slice.call(arguments);
- Module['_call_from_js']($0);
- return Module['tempReturn'];
- };
- }, curr->name.str);
+ if (curr->kind == Export::Function) {
+ EM_ASM_({
+ var name = Pointer_stringify($0);
+ Module['asmExports'][name] = function() {
+ Module['tempArguments'] = Array.prototype.slice.call(arguments);
+ Module['_call_from_js']($0);
+ return Module['tempReturn'];
+ };
+ }, curr->name.str);
+ }
}
// verify imports are provided
@@ -176,32 +163,68 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
var mod = Pointer_stringify($0);
var base = Pointer_stringify($1);
var name = Pointer_stringify($2);
- assert(Module['lookupImport'](mod, base), 'checking import ' + name + ' = ' + mod + '.' + base);
+ assert(Module['lookupImport'](mod, base) !== undefined, 'checking import ' + name + ' = ' + mod + '.' + base);
}, import->module.str, import->base.str, import->name.str);
}
if (wasmJSDebug) std::cerr << "creating instance...\n";
struct JSExternalInterface : ModuleInstance::ExternalInterface {
- void init(Module& wasm) override {
- // create a new buffer here, just like native wasm support would.
- EM_ASM_({
- Module['outside']['newBuffer'] = new ArrayBuffer($0);
- }, wasm.memory.initial * Memory::kPageSize);
+ void init(Module& wasm, ModuleInstance& instance) override {
+ // look for imported memory
+ {
+ bool found = false;
+ for (auto& import : wasm.imports) {
+ if (import->module == ENV && import->base == MEMORY) {
+ assert(import->kind == Import::Memory);
+ // memory is imported
+ EM_ASM({
+ Module['asmExports']['memory'] = Module['lookupImport']('env', 'memory');
+ });
+ found = true;
+ }
+ }
+ if (!found) {
+ // no memory import; create a new buffer here, just like native wasm support would.
+ EM_ASM_({
+ Module['asmExports']['memory'] = Module['outside']['newBuffer'] = new ArrayBuffer($0);
+ }, wasm.memory.initial * Memory::kPageSize);
+ }
+ }
for (auto segment : wasm.memory.segments) {
EM_ASM_({
var source = Module['HEAP8'].subarray($1, $1 + $2);
- var target = new Int8Array(Module['outside']['newBuffer']);
+ var target = new Int8Array(Module['asmExports']['memory']);
target.set(source, $0);
- }, ConstantExpressionRunner().visit(segment.offset).value.geti32(), &segment.data[0], segment.data.size());
+ }, ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32(), &segment.data[0], segment.data.size());
+ }
+ // look for imported table
+ {
+ bool found = false;
+ for (auto& import : wasm.imports) {
+ if (import->module == ENV && import->base == TABLE) {
+ assert(import->kind == Import::Table);
+ // table is imported
+ EM_ASM({
+ Module['outside']['wasmTable'] = Module['lookupImport']('env', 'table');
+ });
+ found = true;
+ }
+ }
+ if (!found) {
+ // no table import; create a new one here, just like native wasm support would.
+ EM_ASM_({
+ Module['outside']['wasmTable'] = new Array($0);
+ }, wasm.table.initial);
+ }
}
- // Table support is in a JS array. If the entry is a number, it's a function pointer. If not, it's a JS method to be called directly
+ EM_ASM({
+ Module['asmExports']['table'] = Module['outside']['wasmTable'];
+ });
+ // Emulated table support is in a JS array. If the entry is a number, it's a function pointer. If not, it's a JS method to be called directly
// TODO: make them all JS methods, wrapping a dynCall where necessary?
- EM_ASM_({
- Module['outside']['wasmTable'] = new Array($0);
- }, wasm.table.initial);
for (auto segment : wasm.table.segments) {
- Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32();
+ Address offset = ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32();
assert(offset + segment.data.size() <= wasm.table.initial);
for (size_t i = 0; i != segment.data.size(); ++i) {
EM_ASM_({
@@ -238,6 +261,23 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
}
}
+ void importGlobals(std::map<Name, Literal>& globals, Module& wasm) override {
+ for (auto& import : wasm.imports) {
+ if (import->kind == Import::Global) {
+ double ret = EM_ASM_DOUBLE({
+ var mod = Pointer_stringify($0);
+ var base = Pointer_stringify($1);
+ var lookup = Module['lookupImport'](mod, base);
+ return lookup;
+ }, import->module.str, import->base.str);
+
+ if (wasmJSDebug) std::cout << "calling importGlobal for " << import->name << " returning " << ret << '\n';
+
+ globals[import->name] = getResultFromJS(ret, import->globalType);
+ }
+ }
+ }
+
Literal callImport(Import *import, LiteralList& arguments) override {
if (wasmJSDebug) std::cout << "calling import " << import->name.str << '\n';
prepareTempArgments(arguments);
@@ -252,10 +292,10 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
if (wasmJSDebug) std::cout << "calling import returning " << ret << '\n';
- return getResultFromJS(ret, import->type->result);
+ return getResultFromJS(ret, import->functionType->result);
}
- Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) override {
+ Literal callTable(Index index, LiteralList& arguments, WasmType result, ModuleInstance& instance) override {
void* ptr = (void*)EM_ASM_INT({
var value = Module['outside']['wasmTable'][$0];
return typeof value === "number" ? value : -1;
@@ -264,7 +304,6 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
if (ptr != (void*)-1) {
// a Function we can call
Function* func = (Function*)ptr;
- if (func->type.is() && func->type != type) trap("callIndirect: bad type");
if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
for (size_t i = 0; i < func->params.size(); i++) {
if (func->params[i] != arguments[i].type) {
@@ -281,7 +320,7 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
Module['tempArguments'] = null;
return func.apply(null, tempArguments);
}, index);
- return getResultFromJS(ret, instance.wasm.getFunctionType(type)->result);
+ return getResultFromJS(ret, result);
}
}
@@ -411,11 +450,11 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() {
Module["info"].parent["HEAPU8"][addr + i] = HEAPU8[i];
}
HEAP32[0] = save0; HEAP32[1] = save1;
- }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->type), isWasmTypeFloat(store_->type) ? value.getFloat() : (double)value.getInteger());
+ }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->valueType), isWasmTypeFloat(store_->valueType) ? value.getFloat() : (double)value.getInteger());
return;
}
// nicely aligned
- if (!isWasmTypeFloat(store_->type)) {
+ if (!isWasmTypeFloat(store_->valueType)) {
if (store_->bytes == 1) {
EM_ASM_INT({ Module['info'].parent['HEAP8'][$0] = $1 }, addr, value.geti32());
} else if (store_->bytes == 2) {
diff --git a/src/wasm-linker.cpp b/src/wasm-linker.cpp
index b1160e94a..04deddf4a 100644
--- a/src/wasm-linker.cpp
+++ b/src/wasm-linker.cpp
@@ -49,7 +49,8 @@ void Linker::ensureImport(Name target, std::string signature) {
auto import = new Import;
import->name = import->base = target;
import->module = ENV;
- import->type = ensureFunctionType(signature, &out.wasm);
+ import->functionType = ensureFunctionType(signature, &out.wasm);
+ import->kind = Import::Function;
out.wasm.addImport(import);
}
}
@@ -106,7 +107,12 @@ void Linker::layout() {
}
if (userMaxMemory) out.wasm.memory.max = userMaxMemory / Memory::kPageSize;
- out.wasm.memory.exportName = MEMORY;
+
+ auto memoryExport = make_unique<Export>();
+ memoryExport->name = MEMORY;
+ memoryExport->value = Name::fromInt(0);
+ memoryExport->kind = Export::Memory;
+ out.wasm.addExport(memoryExport.release());
// XXX For now, export all functions marked .globl.
for (Name name : out.globls) exportFunction(name, false);
@@ -333,7 +339,8 @@ void Linker::emscriptenGlue(std::ostream& o) {
auto import = new Import;
import->name = import->base = curr->target;
import->module = ENV;
- import->type = ensureFunctionType(getSig(curr), &parent->out.wasm);
+ import->functionType = ensureFunctionType(getSig(curr), &parent->out.wasm);
+ import->kind = Import::Function;
parent->out.wasm.addImport(import);
}
}
diff --git a/src/wasm-linker.h b/src/wasm-linker.h
index a6f5d319a..21d336273 100644
--- a/src/wasm-linker.h
+++ b/src/wasm-linker.h
@@ -310,6 +310,7 @@ class Linker {
if (out.wasm.checkExport(name)) return; // Already exported
auto exp = new Export;
exp->name = exp->value = name;
+ exp->kind = Export::Function;
out.wasm.addExport(exp);
}
diff --git a/src/wasm-module-building.h b/src/wasm-module-building.h
index ead074991..43cc493d1 100644
--- a/src/wasm-module-building.h
+++ b/src/wasm-module-building.h
@@ -73,6 +73,7 @@ static std::mutex debug;
class OptimizingIncrementalModuleBuilder {
Module* wasm;
uint32_t numFunctions;
+ std::function<void (PassRunner&)> addPrePasses;
Function* endMarker;
std::atomic<Function*>* list;
uint32_t nextFunction; // only used on main thread
@@ -86,8 +87,8 @@ class OptimizingIncrementalModuleBuilder {
public:
// numFunctions must be equal to the number of functions allocated, or higher. Knowing
// this bounds helps avoid locking.
- OptimizingIncrementalModuleBuilder(Module* wasm, Index numFunctions)
- : wasm(wasm), numFunctions(numFunctions), endMarker(nullptr), list(nullptr), nextFunction(0),
+ OptimizingIncrementalModuleBuilder(Module* wasm, Index numFunctions, std::function<void (PassRunner&)> addPrePasses)
+ : wasm(wasm), numFunctions(numFunctions), addPrePasses(addPrePasses), endMarker(nullptr), list(nullptr), nextFunction(0),
numWorkers(0), liveWorkers(0), activeWorkers(0), availableFuncs(0), finishedFuncs(0),
finishing(false) {
if (numFunctions == 0) {
@@ -201,6 +202,7 @@ private:
void optimizeFunction(Function* func) {
PassRunner passRunner(wasm);
+ addPrePasses(passRunner);
passRunner.addDefaultFunctionOptimizationPasses();
passRunner.runFunction(func);
}
diff --git a/src/wasm-printing.h b/src/wasm-printing.h
index 2f1c97831..830f2dda2 100644
--- a/src/wasm-printing.h
+++ b/src/wasm-printing.h
@@ -36,7 +36,7 @@ struct WasmPrinter {
return printModule(module, std::cout);
}
- static std::ostream& printExpression(Expression* expression, std::ostream& o, bool minify = false);
+ static std::ostream& printExpression(Expression* expression, std::ostream& o, bool minify = false, bool full = false);
};
}
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index b70f70b4e..9abdea744 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -60,13 +60,15 @@ class Element {
List list_;
IString str_;
bool dollared_;
+ bool quoted_;
public:
Element(MixedArena& allocator) : isList_(true), list_(allocator), line(-1), col(-1) {}
bool isList() { return isList_; }
bool isStr() { return !isList_; }
- bool dollared() { return dollared_; }
+ bool dollared() { return isStr() && dollared_; }
+ bool quoted() { return isStr() && quoted_; }
size_t line, col;
@@ -98,10 +100,11 @@ public:
return str_.str;
}
- Element* setString(IString str__, bool dollared__) {
+ Element* setString(IString str__, bool dollared__, bool quoted__) {
isList_ = false;
str_ = str__;
dollared_ = dollared__;
+ quoted_ = quoted__;
return this;
}
@@ -242,13 +245,13 @@ private:
input++;
}
input++;
- return allocator.alloc<Element>()->setString(IString(str.c_str(), false), dollared)->setMetadata(line, start - lineStart);
+ return allocator.alloc<Element>()->setString(IString(str.c_str(), false), dollared, true)->setMetadata(line, start - lineStart);
}
while (input[0] && !isspace(input[0]) && input[0] != ')' && input[0] != '(' && input[0] != ';') input++;
if (start == input) throw ParseException("expected string", line, input - lineStart);
char temp = input[0];
input[0] = 0;
- auto ret = allocator.alloc<Element>()->setString(IString(start, false), dollared)->setMetadata(line, start - lineStart);
+ auto ret = allocator.alloc<Element>()->setString(IString(start, false), dollared, false)->setMetadata(line, start - lineStart);
input[0] = temp;
return ret;
}
@@ -319,10 +322,11 @@ private:
Element& curr = *s[i];
IString id = curr[0]->str();
if (id == RESULT) {
+ if (curr.size() > 2) throw ParseException("invalid result arity", curr.line, curr.col);
functionTypes[name] = stringToWasmType(curr[1]->str());
} else if (id == TYPE) {
Name typeName = curr[1]->str();
- if (!wasm.checkFunctionType(typeName)) throw ParseException("unknown function");
+ if (!wasm.checkFunctionType(typeName)) throw ParseException("unknown function", curr.line, curr.col);
type = wasm.getFunctionType(typeName);
functionTypes[name] = type->result;
} else if (id == PARAM && curr.size() > 1) {
@@ -387,7 +391,11 @@ private:
bool brokeToAutoBlock;
Name getPrefixedName(std::string prefix) {
- return IString((prefix + std::to_string(otherIndex++)).c_str(), false);
+ // make sure to return a unique name not already on the stack
+ while (1) {
+ Name ret = IString((prefix + std::to_string(otherIndex++)).c_str(), false);
+ if (std::find(labelStack.begin(), labelStack.end(), ret) == labelStack.end()) return ret;
+ }
}
Name getFunctionName(Element& s) {
@@ -408,14 +416,23 @@ private:
// returns the next index in s
size_t parseFunctionNames(Element& s, Name& name, Name& exportName) {
size_t i = 1;
- while (i < s.size() && s[i]->isStr()) {
- if (!s[i]->dollared()) {
+ while (i < s.size() && i < 3 && s[i]->isStr()) {
+ if (s[i]->quoted()) {
// an export name
exportName = s[i]->str();
i++;
- } else {
+ } else if (s[i]->dollared()) {
name = s[i]->str();
i++;
+ } else {
+ break;
+ }
+ }
+ if (i < s.size() && s[i]->isList()) {
+ auto& inner = *s[i];
+ if (inner.size() > 0 && inner[0]->str() == EXPORT) {
+ exportName = inner[1]->str();
+ i++;
}
}
return i;
@@ -433,6 +450,7 @@ private:
auto ex = make_unique<Export>();
ex->name = exportName;
ex->value = name;
+ ex->kind = Export::Function;
wasm.addExport(ex.release());
}
functionCounter++;
@@ -490,6 +508,7 @@ private:
currLocalTypes[name] = type;
}
} else if (id == RESULT) {
+ if (curr.size() > 2) throw ParseException("invalid result arity", curr.line, curr.col);
result = stringToWasmType(curr[1]->str());
} else if (id == TYPE) {
Name name = curr[1]->str();
@@ -564,7 +583,6 @@ private:
if (str[1] == '6' && str[2] == '4' && (prefix || str[3] == 0)) return f64;
}
if (allowError) return none;
- throw ParseException("unknown type");
abort();
}
@@ -576,6 +594,7 @@ public:
#define abort_on(str) { throw ParseException(std::string("abort_on ") + str); }
Expression* parseExpression(Element& s) {
+ if (!s.isList()) throw ParseException("invalid node for parseExpression, needed list", s.line, s.col);
IString id = s[0]->str();
const char *str = id.str;
const char *dot = strchr(str, '.');
@@ -617,7 +636,7 @@ public:
if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BINARY_INT(DivU) : BINARY_INT(DivS), type);
if (op[3] == 0) return makeBinary(s, BINARY_FLOAT(Div), type);
}
- if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type);
+ if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type);
abort_on(op);
}
case 'e': {
@@ -735,6 +754,10 @@ public:
} else if (str[1] == 'u') return makeHost(s, HostOp::CurrentMemory);
abort_on(str);
}
+ case 'd': {
+ if (str[1] == 'r') return makeDrop(s);
+ abort_on(str);
+ }
case 'e': {
if (str[1] == 'l') return makeThenOrElse(s);
abort_on(str);
@@ -781,6 +804,7 @@ public:
}
case 't': {
if (str[1] == 'h') return makeThenOrElse(s);
+ if (str[1] == 'e' && str[2] == 'e') return makeTeeLocal(s);
abort_on(str);
}
case 'u': {
@@ -874,13 +898,20 @@ private:
return ret;
}
+ Expression* makeDrop(Element& s) {
+ auto ret = allocator.alloc<Drop>();
+ ret->value = parseExpression(s[1]);
+ ret->finalize();
+ return ret;
+ }
+
Expression* makeHost(Element& s, HostOp op) {
auto ret = allocator.alloc<Host>();
ret->op = op;
if (op == HostOp::HasFeature) {
ret->nameOperand = s[1]->str();
} else {
- parseCallOperands(s, 1, ret);
+ parseCallOperands(s, 1, s.size(), ret);
}
ret->finalize();
return ret;
@@ -906,40 +937,41 @@ private:
return ret;
}
- Expression* makeSetLocal(Element& s) {
+ Expression* makeTeeLocal(Element& s) {
auto ret = allocator.alloc<SetLocal>();
ret->index = getLocalIndex(*s[1]);
ret->value = parseExpression(s[2]);
- ret->type = currFunction->getLocalType(ret->index);
+ ret->setTee(true);
return ret;
}
-
- Index getGlobalIndex(Element& s) {
- if (s.dollared()) {
- auto name = s.str();
- for (Index i = 0; i < wasm.globals.size(); i++) {
- if (wasm.globals[i]->name == name) return i;
- }
- throw ParseException("bad global name", s.line, s.col);
- }
- // this is a numeric index
- Index ret = atoi(s.c_str());
- if (!wasm.checkGlobal(ret)) throw ParseException("bad global index", s.line, s.col);
+ Expression* makeSetLocal(Element& s) {
+ auto ret = allocator.alloc<SetLocal>();
+ ret->index = getLocalIndex(*s[1]);
+ ret->value = parseExpression(s[2]);
+ ret->setTee(false);
return ret;
}
Expression* makeGetGlobal(Element& s) {
auto ret = allocator.alloc<GetGlobal>();
- ret->index = getGlobalIndex(*s[1]);
- ret->type = wasm.getGlobal(ret->index)->type;
- return ret;
+ ret->name = s[1]->str();
+ auto* global = wasm.checkGlobal(ret->name);
+ if (global) {
+ ret->type = global->type;
+ return ret;
+ }
+ auto* import = wasm.checkImport(ret->name);
+ if (import && import->kind == Import::Global) {
+ ret->type = import->globalType;
+ return ret;
+ }
+ throw ParseException("bad get_global name", s.line, s.col);
}
Expression* makeSetGlobal(Element& s) {
auto ret = allocator.alloc<SetGlobal>();
- ret->index = getGlobalIndex(*s[1]);
+ ret->name = s[1]->str();
ret->value = parseExpression(s[2]);
- ret->type = wasm.getGlobal(ret->index)->type;
return ret;
}
@@ -1057,7 +1089,7 @@ private:
Expression* makeStore(Element& s, WasmType type) {
const char *extra = strchr(s[0]->c_str(), '.') + 6; // after "type.store"
auto ret = allocator.alloc<Store>();
- ret->type = type;
+ ret->valueType = type;
ret->bytes = getWasmTypeSize(type);
if (extra[0] == '8') {
ret->bytes = 1;
@@ -1088,6 +1120,7 @@ private:
}
ret->ptr = parseExpression(s[i]);
ret->value = parseExpression(s[i+1]);
+ ret->finalize();
return ret;
}
@@ -1148,24 +1181,28 @@ private:
Expression* makeLoop(Element& s) {
auto ret = allocator.alloc<Loop>();
size_t i = 1;
+ Name out;
if (s.size() > i + 1 && s[i]->isStr() && s[i + 1]->isStr()) { // out can only be named if both are
- ret->out = s[i]->str();
+ out = s[i]->str();
i++;
- } else {
- ret->out = getPrefixedName("loop-out");
}
if (s.size() > i && s[i]->isStr()) {
- ret->in = s[i]->str();
+ ret->name = s[i]->str();
i++;
} else {
- ret->in = getPrefixedName("loop-in");
+ ret->name = getPrefixedName("loop-in");
}
- labelStack.push_back(ret->out);
- labelStack.push_back(ret->in);
+ labelStack.push_back(ret->name);
ret->body = makeMaybeBlock(s, i);
labelStack.pop_back();
- labelStack.pop_back();
ret->finalize();
+ if (out.is()) {
+ auto* block = allocator.alloc<Block>();
+ block->name = out;
+ block->list.push_back(ret);
+ block->finalize();
+ return block;
+ }
return ret;
}
@@ -1173,7 +1210,7 @@ private:
auto ret = allocator.alloc<Call>();
ret->target = s[1]->str();
ret->type = functionTypes[ret->target];
- parseCallOperands(s, 2, ret);
+ parseCallOperands(s, 2, s.size(), ret);
return ret;
}
@@ -1181,8 +1218,8 @@ private:
auto ret = allocator.alloc<CallImport>();
ret->target = s[1]->str();
Import* import = wasm.getImport(ret->target);
- ret->type = import->type->result;
- parseCallOperands(s, 2, ret);
+ ret->type = import->functionType->result;
+ parseCallOperands(s, 2, s.size(), ret);
return ret;
}
@@ -1193,14 +1230,14 @@ private:
if (!fullType) throw ParseException("invalid call_indirect type", s.line, s.col);
ret->fullType = fullType->name;
ret->type = fullType->result;
- ret->target = parseExpression(s[2]);
- parseCallOperands(s, 3, ret);
+ parseCallOperands(s, 2, s.size() - 1, ret);
+ ret->target = parseExpression(s[s.size() - 1]);
return ret;
}
template<class T>
- void parseCallOperands(Element& s, size_t i, T* call) {
- while (i < s.size()) {
+ void parseCallOperands(Element& s, Index i, Index j, T* call) {
+ while (i < j) {
call->operands.push_back(parseExpression(s[i]));
i++;
}
@@ -1311,10 +1348,29 @@ private:
void parseMemory(Element& s) {
hasMemory = true;
-
- wasm.memory.initial = atoi(s[1]->c_str());
- if (s.size() == 2) return;
- size_t i = 2;
+ Index i = 1;
+ if (s[i]->dollared()) {
+ wasm.memory.name = s[i++]->str();
+ }
+ if (s[i]->isList()) {
+ auto& inner = *s[i];
+ if (inner[0]->str() == EXPORT) {
+ auto ex = make_unique<Export>();
+ ex->name = inner[1]->str();
+ ex->value = wasm.memory.name;
+ ex->kind = Export::Memory;
+ wasm.addExport(ex.release());
+ i++;
+ } else {
+ assert(inner.size() > 0 ? inner[0]->str() != IMPORT : true);
+ // (memory (data ..)) format
+ parseData(*s[i]);
+ wasm.memory.initial = wasm.memory.segments[0].data.size();
+ return;
+ }
+ }
+ wasm.memory.initial = atoi(s[i++]->c_str());
+ if (i == s.size()) return;
if (s[i]->isStr()) {
uint64_t max = atoll(s[i]->c_str());
if (max > Memory::kMaxSize) throw ParseException("total memory must be <= 4GB");
@@ -1346,108 +1402,220 @@ private:
}
void parseData(Element& s) {
- auto* offset = parseExpression(s[1]);
- const char *input = s[2]->c_str();
- if (auto size = strlen(input)) {
- std::vector<char> data;
- stringToBinary(input, size, data);
- wasm.memory.segments.emplace_back(offset, data.data(), data.size());
+ if (!hasMemory) throw ParseException("data but no memory");
+ Index i = 1;
+ Expression* offset;
+ if (i < s.size() && s[i]->isList()) {
+ // there is an init expression
+ offset = parseExpression(s[i++]);
} else {
- wasm.memory.segments.emplace_back(offset, "", 0);
+ offset = allocator.alloc<Const>()->set(Literal(int32_t(0)));
+ }
+ std::vector<char> data;
+ while (i < s.size()) {
+ const char *input = s[i++]->c_str();
+ if (auto size = strlen(input)) {
+ stringToBinary(input, size, data);
+ }
}
+ wasm.memory.segments.emplace_back(offset, data.data(), data.size());
}
void parseExport(Element& s) {
- if (!s[2]->dollared() && !std::isdigit(s[2]->str()[0])) {
- assert(s[2]->str() == MEMORY);
- if (!hasMemory) throw ParseException("memory exported but no memory");
- wasm.memory.exportName = s[1]->str();
- return;
- }
std::unique_ptr<Export> ex = make_unique<Export>();
ex->name = s[1]->str();
- ex->value = s[2]->str();
+ if (s[2]->isList()) {
+ auto& inner = *s[2];
+ if (inner[0]->str() == FUNC) {
+ ex->value = inner[1]->str();
+ ex->kind = Export::Function;
+ } else if (inner[0]->str() == MEMORY) {
+ if (!hasMemory) throw ParseException("memory exported but no memory");
+ ex->value = Name::fromInt(0);
+ ex->kind = Export::Memory;
+ } else if (inner[0]->str() == TABLE) {
+ ex->value = Name::fromInt(0);
+ ex->kind = Export::Table;
+ } else if (inner[0]->str() == GLOBAL) {
+ ex->value = inner[1]->str();
+ ex->kind = Export::Table;
+ } else {
+ WASM_UNREACHABLE();
+ }
+ } else if (!s[2]->dollared() && !std::isdigit(s[2]->str()[0])) {
+ if (s[2]->str() == MEMORY) {
+ if (!hasMemory) throw ParseException("memory exported but no memory");
+ ex->value = Name::fromInt(0);
+ ex->kind = Export::Memory;
+ } else if (s[2]->str() == TABLE) {
+ ex->value = Name::fromInt(0);
+ ex->kind = Export::Table;
+ } else if (s[2]->str() == GLOBAL) {
+ ex->value = s[3]->str();
+ ex->kind = Export::Table;
+ } else {
+ WASM_UNREACHABLE();
+ }
+ } else {
+ // function
+ ex->value = s[2]->str();
+ ex->kind = Export::Function;
+ }
wasm.addExport(ex.release());
}
void parseImport(Element& s) {
std::unique_ptr<Import> im = make_unique<Import>();
size_t i = 1;
+ bool newStyle = s.size() == 4 && s[3]->isList(); // (import "env" "STACKTOP" (global $stackTop i32))
+ if (newStyle) {
+ if ((*s[3])[0]->str() == FUNC) {
+ im->kind = Import::Function;
+ } else if ((*s[3])[0]->str() == MEMORY) {
+ im->kind = Import::Memory;
+ } else if ((*s[3])[0]->str() == TABLE) {
+ im->kind = Import::Table;
+ } else if ((*s[3])[0]->str() == GLOBAL) {
+ im->kind = Import::Global;
+ } else {
+ newStyle = false; // either (param..) or (result..)
+ }
+ }
if (s.size() > 3 && s[3]->isStr()) {
im->name = s[i++]->str();
+ } else if (newStyle) {
+ im->name = (*s[3])[1]->str();
} else {
im->name = Name::fromInt(importCounter);
}
importCounter++;
+ if (!s[i]->quoted()) {
+ if (s[i]->str() == MEMORY) {
+ im->kind = Import::Memory;
+ } else if (s[i]->str() == TABLE) {
+ im->kind = Import::Table;
+ } else if (s[i]->str() == GLOBAL) {
+ im->kind = Import::Global;
+ } else {
+ WASM_UNREACHABLE();
+ }
+ i++;
+ } else if (!newStyle) {
+ im->kind = Import::Function;
+ }
im->module = s[i++]->str();
if (!s[i]->isStr()) throw ParseException("no name for import");
im->base = s[i++]->str();
- std::unique_ptr<FunctionType> type = make_unique<FunctionType>();
- if (s.size() > i) {
- Element& params = *s[i];
- IString id = params[0]->str();
- if (id == PARAM) {
- for (size_t i = 1; i < params.size(); i++) {
- type->params.push_back(stringToWasmType(params[i]->str()));
+ // parse internals
+ Element& inner = newStyle ? *s[3] : s;
+ Index j = newStyle ? 2 : i;
+ if (im->kind == Import::Function) {
+ std::unique_ptr<FunctionType> type = make_unique<FunctionType>();
+ if (inner.size() > j) {
+ Element& params = *inner[j];
+ IString id = params[0]->str();
+ if (id == PARAM) {
+ for (size_t i = 1; i < params.size(); i++) {
+ type->params.push_back(stringToWasmType(params[i]->str()));
+ }
+ } else if (id == RESULT) {
+ type->result = stringToWasmType(params[1]->str());
+ } else if (id == TYPE) {
+ IString name = params[1]->str();
+ if (!wasm.checkFunctionType(name)) throw ParseException("bad function type for import");
+ *type = *wasm.getFunctionType(name);
+ } else {
+ throw ParseException("bad import element");
+ }
+ if (inner.size() > j+1) {
+ Element& result = *inner[j+1];
+ assert(result[0]->str() == RESULT);
+ type->result = stringToWasmType(result[1]->str());
}
- } else if (id == RESULT) {
- type->result = stringToWasmType(params[1]->str());
- } else if (id == TYPE) {
- IString name = params[1]->str();
- if (!wasm.checkFunctionType(name)) throw ParseException("bad function type for import");
- *type = *wasm.getFunctionType(name);
- } else {
- throw ParseException("bad import element");
- }
- if (s.size() > i+1) {
- Element& result = *s[i+1];
- assert(result[0]->str() == RESULT);
- type->result = stringToWasmType(result[1]->str());
}
+ im->functionType = ensureFunctionType(getSig(type.get()), &wasm);
+ } else if (im->kind == Import::Global) {
+ im->globalType = stringToWasmType(inner[j]->str());
}
- im->type = ensureFunctionType(getSig(type.get()), &wasm);
wasm.addImport(im.release());
}
void parseGlobal(Element& s) {
std::unique_ptr<Global> global = make_unique<Global>();
size_t i = 1;
- if (s.size() == 4) {
+ if (s[i]->dollared()) {
global->name = s[i++]->str();
} else {
global->name = Name::fromInt(globalCounter);
}
globalCounter++;
+ if (s[i]->isList()) {
+ auto& inner = *s[i];
+ if (inner[0]->str() == EXPORT) {
+ auto ex = make_unique<Export>();
+ ex->name = inner[1]->str();
+ ex->value = global->name;
+ ex->kind = Export::Global;
+ wasm.addExport(ex.release());
+ i++;
+ } else {
+ WASM_UNREACHABLE();
+ }
+ }
global->type = stringToWasmType(s[i++]->str());
global->init = parseExpression(s[i++]);
assert(i == s.size());
wasm.addGlobal(global.release());
}
+ bool seenTable = false;
+
void parseTable(Element& s) {
- if (s.size() == 1) return; // empty table in old notation
- if (!s[1]->dollared()) {
- if (s[1]->str() == ANYFUNC) {
+ seenTable = true;
+ Index i = 1;
+ if (i == s.size()) return; // empty table in old notation
+#if 0 // TODO: new table notation
+ if (s[i]->dollared()) {
+ wasm.table.name = s[i++]->str();
+ }
+#endif
+ if (i == s.size()) return;
+ if (s[i]->isList()) {
+ auto& inner = *s[i];
+ if (inner[0]->str() == EXPORT) {
+ auto ex = make_unique<Export>();
+ ex->name = inner[1]->str();
+ ex->value = wasm.table.name;
+ ex->kind = Export::Table;
+ wasm.addExport(ex.release());
+ i++;
+ } else {
+ WASM_UNREACHABLE();
+ }
+ }
+ if (i == s.size()) return;
+ if (!s[i]->dollared()) {
+ if (s[i]->str() == ANYFUNC) {
// (table type (elem ..))
- parseElem(*s[2]);
+ parseElem(*s[i + 1]);
wasm.table.initial = wasm.table.max = wasm.table.segments[0].data.size();
return;
}
// first element isn't dollared, and isn't anyfunc. this could be old syntax for (table 0 1) which means function 0 and 1, or it could be (table initial max? type), look for type
if (s[s.size() - 1]->str() == ANYFUNC) {
// (table initial max? type)
- wasm.table.initial = atoi(s[1]->c_str());
- wasm.table.max = atoi(s[2]->c_str());
+ wasm.table.initial = atoi(s[i]->c_str());
+ wasm.table.max = atoi(s[i + 1]->c_str());
return;
}
}
// old notation (table func1 func2 ..)
- parseElem(s);
+ parseElem(s, i);
wasm.table.initial = wasm.table.max = wasm.table.segments[0].data.size();
}
- void parseElem(Element& s) {
- Index i = 1;
+ void parseElem(Element& s, Index i = 1) {
+ if (!seenTable) throw ParseException("elem without table", s.line, s.col);
Expression* offset;
if (s[i]->isList()) {
// there is an init expression
@@ -1478,6 +1646,7 @@ private:
type->params.push_back(stringToWasmType(curr[j]->str()));
}
} else if (curr[0]->str() == RESULT) {
+ if (curr.size() > 2) throw ParseException("invalid result arity", curr.line, curr.col);
type->result = stringToWasmType(curr[1]->str());
}
}
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index b50ca0fb2..4725237dd 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -53,6 +53,7 @@ struct Visitor {
ReturnType visitUnary(Unary *curr) {}
ReturnType visitBinary(Binary *curr) {}
ReturnType visitSelect(Select *curr) {}
+ ReturnType visitDrop(Drop *curr) {}
ReturnType visitReturn(Return *curr) {}
ReturnType visitHost(Host *curr) {}
ReturnType visitNop(Nop *curr) {}
@@ -93,6 +94,7 @@ struct Visitor {
case Expression::Id::UnaryId: DELEGATE(Unary);
case Expression::Id::BinaryId: DELEGATE(Binary);
case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::DropId: DELEGATE(Drop);
case Expression::Id::ReturnId: DELEGATE(Return);
case Expression::Id::HostId: DELEGATE(Host);
case Expression::Id::NopId: DELEGATE(Nop);
@@ -132,6 +134,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> {
ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitDrop(Drop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
@@ -153,8 +156,8 @@ struct Walker : public VisitorType {
// Note that the visit*() for the result node is not called for you (i.e.,
// just one visit*() method is called by the traversal; if you replace a node,
// and you want to process the output, you must do that explicitly).
- void replaceCurrent(Expression *expression) {
- replace = expression;
+ Expression* replaceCurrent(Expression *expression) {
+ return replace = expression;
}
// Get the current module
@@ -264,14 +267,15 @@ struct Walker : public VisitorType {
static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); }
static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); }
static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); }
- static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); }
- static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); }
+ static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); }
+ static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); }
static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); }
static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); }
static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); }
static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); }
static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); }
static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); }
+ static void doVisitDrop(SubType* self, Expression** currp) { self->visitDrop((*currp)->cast<Drop>()); }
static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); }
static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); }
static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); }
@@ -354,10 +358,10 @@ struct PostWalker : public Walker<SubType, VisitorType> {
case Expression::Id::CallIndirectId: {
self->pushTask(SubType::doVisitCallIndirect, currp);
auto& list = curr->cast<CallIndirect>()->operands;
+ self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target);
for (int i = int(list.size()) - 1; i >= 0; i--) {
self->pushTask(SubType::scan, &list[i]);
}
- self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target);
break;
}
case Expression::Id::GetLocalId: {
@@ -411,6 +415,11 @@ struct PostWalker : public Walker<SubType, VisitorType> {
self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue);
break;
}
+ case Expression::Id::DropId: {
+ self->pushTask(SubType::doVisitDrop, currp);
+ self->pushTask(SubType::scan, &curr->cast<Drop>()->value);
+ break;
+ }
case Expression::Id::ReturnId: {
self->pushTask(SubType::doVisitReturn, currp);
self->maybePushTask(SubType::scan, &curr->cast<Return>()->value);
@@ -437,6 +446,112 @@ struct PostWalker : public Walker<SubType, VisitorType> {
}
};
+// Traversal with a control-flow stack.
+
+template<typename SubType, typename VisitorType>
+struct ControlFlowWalker : public PostWalker<SubType, VisitorType> {
+ ControlFlowWalker() {}
+
+ std::vector<Expression*> controlFlowStack; // contains blocks, loops, and ifs
+
+ // Uses the control flow stack to find the target of a break to a name
+ Expression* findBreakTarget(Name name) {
+ assert(!controlFlowStack.empty());
+ Index i = controlFlowStack.size() - 1;
+ while (1) {
+ auto* curr = controlFlowStack[i];
+ if (Block* block = curr->template dynCast<Block>()) {
+ if (name == block->name) return curr;
+ } else if (Loop* loop = curr->template dynCast<Loop>()) {
+ if (name == loop->name) return curr;
+ } else {
+ // an if, ignorable
+ assert(curr->template is<If>());
+ }
+ if (i == 0) return nullptr;
+ i--;
+ }
+ }
+
+ static void doPreVisitControlFlow(SubType* self, Expression** currp) {
+ self->controlFlowStack.push_back(*currp);
+ }
+
+ static void doPostVisitControlFlow(SubType* self, Expression** currp) {
+ assert(self->controlFlowStack.back() == *currp);
+ self->controlFlowStack.pop_back();
+ }
+
+ static void scan(SubType* self, Expression** currp) {
+ auto* curr = *currp;
+
+ switch (curr->_id) {
+ case Expression::Id::BlockId:
+ case Expression::Id::IfId:
+ case Expression::Id::LoopId: {
+ self->pushTask(SubType::doPostVisitControlFlow, currp);
+ break;
+ }
+ default: {}
+ }
+
+ PostWalker<SubType, VisitorType>::scan(self, currp);
+
+ switch (curr->_id) {
+ case Expression::Id::BlockId:
+ case Expression::Id::IfId:
+ case Expression::Id::LoopId: {
+ self->pushTask(SubType::doPreVisitControlFlow, currp);
+ break;
+ }
+ default: {}
+ }
+ }
+};
+
+// Traversal with an expression stack.
+
+template<typename SubType, typename VisitorType>
+struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> {
+ ExpressionStackWalker() {}
+
+ std::vector<Expression*> expressionStack;
+
+ // Uses the control flow stack to find the target of a break to a name
+ Expression* findBreakTarget(Name name) {
+ assert(!expressionStack.empty());
+ Index i = expressionStack.size() - 1;
+ while (1) {
+ auto* curr = expressionStack[i];
+ if (Block* block = curr->template dynCast<Block>()) {
+ if (name == block->name) return curr;
+ } else if (Loop* loop = curr->template dynCast<Loop>()) {
+ if (name == loop->name) return curr;
+ } else {
+ WASM_UNREACHABLE();
+ }
+ if (i == 0) return nullptr;
+ i--;
+ }
+ }
+
+ static void doPreVisit(SubType* self, Expression** currp) {
+ self->expressionStack.push_back(*currp);
+ }
+
+ static void doPostVisit(SubType* self, Expression** currp) {
+ self->expressionStack.pop_back();
+ }
+
+ static void scan(SubType* self, Expression** currp) {
+ self->pushTask(SubType::doPostVisit, currp);
+
+ PostWalker<SubType, VisitorType>::scan(self, currp);
+
+ self->pushTask(SubType::doPreVisit, currp);
+ }
+};
+
// Traversal in the order of execution. This is quick and simple, but
// does not provide the same comprehensive information that a full
// conversion to basic blocks would. What it does give is a quick
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 3d9a0e1c3..58a30f9a3 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>>
bool valid = true;
bool validateWebConstraints = false;
- std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type
+ struct BreakInfo {
+ WasmType type;
+ Index arity;
+ BreakInfo() {}
+ BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
+ };
+
+ std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them
+ std::map<Expression*, BreakInfo> breakInfos;
+
WasmType returnType = unreachable; // type used in returns
public:
@@ -42,55 +51,107 @@ public:
// visitors
+ static void visitPreBlock(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Block>();
+ if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
+ }
+
void visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
- // none or unreachable means a poison value that we should ignore - if consumed, it will error
- if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) {
- shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value");
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value");
+ }
+ shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
+ if (curr->list.size() > 0) {
+ auto last = curr->list.back()->type;
+ if (isConcreteWasmType(last) && info.type != unreachable) {
+ shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value");
+ }
+ if (last == none) {
+ shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type");
+ }
+ }
+ }
+ breakTargets[curr->name].pop_back();
+ }
+ if (curr->list.size() > 1) {
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) {
+ std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
+ }
}
- breakTypes.erase(curr->name);
}
}
- void visitIf(If *curr) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+
+ static void visitPreLoop(WasmValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Loop>();
+ if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
}
+
void visitLoop(Loop *curr) {
- if (curr->in.is()) {
- breakTypes.erase(curr->in);
- }
- if (curr->out.is()) {
- breakTypes.erase(curr->out);
+ if (curr->name.is()) {
+ breakTargets[curr->name].pop_back();
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value");
+ }
}
}
- void noteBreak(Name name, Expression* value) {
+
+ void visitIf(If *curr) {
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+ }
+
+ // override scan to add a pre and a post check task to all nodes
+ static void scan(WasmValidator* self, Expression** currp) {
+ PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp);
+
+ auto* curr = *currp;
+ if (curr->is<Block>()) self->pushTask(visitPreBlock, currp);
+ if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
+ }
+
+ void noteBreak(Name name, Expression* value, Expression* curr) {
WasmType valueType = none;
+ Index arity = 0;
if (value) {
valueType = value->type;
+ shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
+ arity = 1;
}
- if (breakTypes.count(name) == 0) {
- breakTypes[name] = valueType;
+ if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
+ auto* target = breakTargets[name].back();
+ if (breakInfos.count(target) == 0) {
+ breakInfos[target] = BreakInfo(valueType, arity);
} else {
- if (breakTypes[name] == unreachable) {
- breakTypes[name] = valueType;
+ auto& info = breakInfos[target];
+ if (info.type == unreachable) {
+ info.type = valueType;
} else if (valueType != unreachable) {
- if (valueType != breakTypes[name]) {
- breakTypes[name] = none; // a poison value that must not be consumed
+ if (valueType != info.type) {
+ info.type = none; // a poison value that must not be consumed
}
}
+ if (arity != info.arity) {
+ info.arity = Index(-1); // a poison value
+ }
}
}
void visitBreak(Break *curr) {
- noteBreak(curr->name, curr->value);
+ noteBreak(curr->name, curr->value, curr);
if (curr->condition) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
}
}
void visitSwitch(Switch *curr) {
for (auto& target : curr->targets) {
- noteBreak(target, curr->value);
+ noteBreak(target, curr->value, curr);
}
- noteBreak(curr->default_, curr->value);
+ noteBreak(curr->default_, curr->value, curr);
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
void visitCall(Call *curr) {
@@ -106,7 +167,7 @@ public:
void visitCallImport(CallImport *curr) {
auto* import = getModule()->checkImport(curr->target);
if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
- auto* type = import->type;
+ auto* type = import->functionType;
if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) {
@@ -115,6 +176,7 @@ public:
}
}
void visitCallIndirect(CallIndirect *curr) {
+ shouldBeTrue(getModule()->table.segments.size() > 0, curr, "no table");
auto* type = getModule()->checkFunctionType(curr->fullType);
if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return;
shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
@@ -128,18 +190,21 @@ public:
void visitSetLocal(SetLocal *curr) {
shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
if (curr->value->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ if (curr->type != none) { // tee is ok anyhow
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ }
shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
}
}
void visitLoad(Load *curr) {
- validateAlignment(curr->align);
+ validateAlignment(curr->align, curr->type, curr->bytes);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
}
void visitStore(Store *curr) {
- validateAlignment(curr->align);
+ validateAlignment(curr->align, curr->type, curr->bytes);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match");
+ shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
}
void visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
@@ -216,6 +281,10 @@ public:
default: abort();
}
}
+ void visitSelect(Select* curr) {
+ shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid");
+ shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid");
+ }
void visitReturn(Return* curr) {
if (curr->value) {
@@ -245,9 +314,11 @@ public:
void visitImport(Import* curr) {
if (!validateWebConstraints) return;
- shouldBeUnequal(curr->type->result, i64, curr->name, "Imported function must not have i64 return type");
- for (WasmType param : curr->type->params) {
- shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
+ if (curr->kind == Import::Function) {
+ shouldBeUnequal(curr->functionType->result, i64, curr->name, "Imported function must not have i64 return type");
+ for (WasmType param : curr->functionType->params) {
+ shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
+ }
}
}
@@ -262,39 +333,66 @@ public:
void visitGlobal(Global* curr) {
shouldBeTrue(curr->init->is<Const>(), curr->name, "global init must be valid");
- shouldBeEqual(curr->type, curr->init->type, curr, "global init must have correct type");
+ shouldBeEqual(curr->type, curr->init->type, nullptr, "global init must have correct type");
}
void visitFunction(Function *curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
- if (curr->result != none) {
- if (curr->body->type != unreachable) {
- shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
- }
+ if (curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
+ }
+ if (curr->result != none) { // TODO: over previous too?
if (returnType != unreachable) {
shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns");
}
}
returnType = unreachable;
}
+
+ bool isConstant(Expression* curr) {
+ return curr->is<Const>() || curr->is<GetGlobal>();
+ }
+
void visitMemory(Memory *curr) {
shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
+ Index mustBeGreaterOrEqual = 0;
+ for (auto& segment : curr->segments) {
+ if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue;
+ shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant");
+ Index size = segment.data.size();
+ shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ if (segment.offset->is<Const>()) {
+ Index start = segment.offset->cast<Const>()->value.geti32();
+ Index end = start + size;
+ shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory");
+ mustBeGreaterOrEqual = end;
+ }
+ }
+ }
+ void visitTable(Table* curr) {
+ for (auto& segment : curr->segments) {
+ shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32");
+ shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant");
+ }
}
void visitModule(Module *curr) {
// exports
std::set<Name> exportNames;
for (auto& exp : curr->exports) {
Name name = exp->value;
- bool found = false;
- for (auto& func : curr->functions) {
- if (func->name == name) {
- found = true;
- break;
+ if (exp->kind == Export::Function) {
+ bool found = false;
+ for (auto& func : curr->functions) {
+ if (func->name == name) {
+ found = true;
+ break;
+ }
}
+ shouldBeTrue(found, name, "module exports must be found");
}
- shouldBeTrue(found, name, "module exports must be found");
Name exportName = exp->name;
shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
exportNames.insert(exportName);
@@ -311,12 +409,6 @@ public:
void doWalkFunction(Function* func) {
PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func);
- if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) {
- for (auto& target : breakTypes) {
- std::cerr << " - " << target.first << '\n';
- }
- breakTypes.clear();
- }
}
private:
@@ -360,7 +452,8 @@ private:
template<typename T, typename S>
bool shouldBeEqual(S left, S right, T curr, const char* text) {
if (left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl;
+ fail() << "" << left << " != " << right << ": " << text << ", on \n";
+ WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
valid = false;
return false;
}
@@ -379,7 +472,8 @@ private:
template<typename T, typename S>
bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
if (left != unreachable && left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl;
+ fail() << "" << left << " != " << right << ": " << text << ", on \n";
+ WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
valid = false;
return false;
}
@@ -396,7 +490,7 @@ private:
return true;
}
- void validateAlignment(size_t align) {
+ void validateAlignment(size_t align, WasmType type, Index bytes) {
switch (align) {
case 1:
case 2:
@@ -408,6 +502,20 @@ private:
break;
}
}
+ shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
+ switch (type) {
+ case i32:
+ case f32: {
+ shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
+ break;
+ }
+ case i64:
+ case f64: {
+ shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
+ break;
+ }
+ default: {}
+ }
}
};
diff --git a/src/wasm.cpp b/src/wasm.cpp
index 8cfdee759..f6486eb50 100644
--- a/src/wasm.cpp
+++ b/src/wasm.cpp
@@ -120,7 +120,7 @@ struct TypeSeeker : public PostWalker<TypeSeeker, Visitor<TypeSeeker>> {
void visitLoop(Loop* curr) {
if (curr == target) {
types.push_back(curr->body->type);
- } else if (curr->in == targetName || curr->out == targetName) {
+ } else if (curr->name == targetName) {
types.clear(); // ignore all breaks til now, they were captured by someone with the same name
}
}
@@ -162,13 +162,7 @@ void Block::finalize() {
}
void Loop::finalize() {
- if (!out.is()) {
- type = body->type;
- return;
- }
-
- TypeSeeker seeker(this, this->out);
- type = mergeTypes(seeker.types);
+ type = body->type;
}
} // namespace wasm
diff --git a/src/wasm.h b/src/wasm.h
index 68558033d..d6bdfe91f 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -802,7 +802,7 @@ enum UnaryOp {
ConvertSInt32ToFloat32, ConvertSInt32ToFloat64, ConvertUInt32ToFloat32, ConvertUInt32ToFloat64, ConvertSInt64ToFloat32, ConvertSInt64ToFloat64, ConvertUInt64ToFloat32, ConvertUInt64ToFloat64, // int to float
PromoteFloat32, // f32 to f64
DemoteFloat64, // f64 to f32
- ReinterpretInt32, ReinterpretInt64 // reinterpret bits to float
+ ReinterpretInt32, ReinterpretInt64, // reinterpret bits to float
};
enum BinaryOp {
@@ -877,6 +877,7 @@ public:
UnaryId,
BinaryId,
SelectId,
+ DropId,
ReturnId,
HostId,
NopId,
@@ -929,6 +930,7 @@ inline const char *getExpressionName(Expression *curr) {
case Expression::Id::UnaryId: return "unary";
case Expression::Id::BinaryId: return "binary";
case Expression::Id::SelectId: return "select";
+ case Expression::Id::DropId: return "drop";
case Expression::Id::ReturnId: return "return";
case Expression::Id::HostId: return "host";
case Expression::Id::NopId: return "nop";
@@ -999,7 +1001,7 @@ public:
Loop() {}
Loop(MixedArena& allocator) {}
- Name out, in;
+ Name name;
Expression *body;
// set the type of a loop if you already know it
@@ -1108,8 +1110,13 @@ public:
Index index;
Expression *value;
- void finalize() {
- type = value->type;
+ bool isTee() {
+ return type != none;
+ }
+
+ void setTee(bool is) {
+ if (is) type = value->type;
+ else type = none;
}
};
@@ -1118,7 +1125,7 @@ public:
GetGlobal() {}
GetGlobal(MixedArena& allocator) {}
- Index index;
+ Name name;
};
class SetGlobal : public SpecificExpression<Expression::SetGlobalId> {
@@ -1126,12 +1133,8 @@ public:
SetGlobal() {}
SetGlobal(MixedArena& allocator) {}
- Index index;
+ Name name;
Expression *value;
-
- void finalize() {
- type = value->type;
- }
};
class Load : public SpecificExpression<Expression::LoadId> {
@@ -1150,16 +1153,17 @@ public:
class Store : public SpecificExpression<Expression::StoreId> {
public:
- Store() {}
- Store(MixedArena& allocator) {}
+ Store() : valueType(none) {}
+ Store(MixedArena& allocator) : Store() {}
uint8_t bytes;
Address offset;
Address align;
Expression *ptr, *value;
+ WasmType valueType; // the store never returns a value
void finalize() {
- type = value->type;
+ assert(valueType != none); // must be set
}
};
@@ -1312,6 +1316,14 @@ public:
}
};
+class Drop : public SpecificExpression<Expression::DropId> {
+public:
+ Drop() {}
+ Drop(MixedArena& allocator) {}
+
+ Expression *value;
+};
+
class Return : public SpecificExpression<Expression::ReturnId> {
public:
Return() : value(nullptr) {
@@ -1418,16 +1430,33 @@ public:
class Import {
public:
- Import() : type(nullptr) {}
+ enum Kind {
+ Function = 0,
+ Table = 1,
+ Memory = 2,
+ Global = 3,
+ };
+
+ Import() : functionType(nullptr), globalType(none) {}
Name name, module, base; // name = module.base
- FunctionType* type;
+ Kind kind;
+ FunctionType* functionType; // for Function imports
+ WasmType globalType; // for Global imports
};
class Export {
public:
- Name name; // exported name
+ enum Kind {
+ Function = 0,
+ Table = 1,
+ Memory = 2,
+ Global = 3,
+ };
+
+ Name name; // exported name - note that this is the key, as the internal name is non-unique (can have multiple exports for an internal, also over kinds)
Name value; // internal name
+ Kind kind;
};
class Table {
@@ -1445,10 +1474,13 @@ public:
}
};
+ Name name;
Address initial, max;
std::vector<Segment> segments;
- Table() : initial(0), max(kMaxSize) {}
+ Table() : initial(0), max(kMaxSize) {
+ name = Name::fromInt(0);
+ }
};
class Memory {
@@ -1470,11 +1502,13 @@ public:
}
};
+ Name name;
Address initial, max; // sizes are in pages
std::vector<Segment> segments;
- Name exportName;
- Memory() : initial(0), max(kMaxSize) {}
+ Memory() : initial(0), max(kMaxSize) {
+ name = Name::fromInt(0);
+ }
};
class Global {
@@ -1503,7 +1537,7 @@ private:
// TODO: add a build option where Names are just indices, and then these methods are not needed
std::map<Name, FunctionType*> functionTypesMap;
std::map<Name, Import*> importsMap;
- std::map<Name, Export*> exportsMap;
+ std::map<Name, Export*> exportsMap; // exports map is by the *exported* name, which is unique
std::map<Name, Function*> functionsMap;
std::map<Name, Global*> globalsMap;