diff options
author | Alon Zakai <alonzakai@gmail.com> | 2016-09-07 10:55:02 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-07 10:55:02 -0700 |
commit | 135a20cd110d356d5d098a08a7b447205adaed7a (patch) | |
tree | f5200a6b35f19d1bf95dea1fa7e339f40391413b /src | |
parent | fbe77b167002e8a49225b607ca8c37dc7e4b41fe (diff) | |
parent | dd197d3212ac28e778d372df9d03e58b21386648 (diff) | |
download | binaryen-135a20cd110d356d5d098a08a7b447205adaed7a.tar.gz binaryen-135a20cd110d356d5d098a08a7b447205adaed7a.tar.bz2 binaryen-135a20cd110d356d5d098a08a7b447205adaed7a.zip |
Merge pull request #678 from WebAssembly/stack
Stack machine + 0xc update
Diffstat (limited to 'src')
40 files changed, 1649 insertions, 893 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index ce3c0749d..53fa78db4 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -143,15 +143,13 @@ class Asm2WasmBuilder { // globals - unsigned nextGlobal; // next place to put a global - unsigned maxGlobal; // highest address we can put a global struct MappedGlobal { - unsigned address; WasmType type; bool import; // if true, this is an import - we should read the value, not just set a zero IString module, base; - MappedGlobal() : address(0), type(none), import(false) {} - MappedGlobal(unsigned address, WasmType type, bool import, IString module, IString base) : address(address), type(type), import(import), module(module), base(base) {} + MappedGlobal() : type(none), import(false) {} + MappedGlobal(WasmType type) : type(type), import(false) {} + MappedGlobal(WasmType type, bool import, IString module, IString base) : type(type), import(import), module(module), base(base) {} }; // function table @@ -165,31 +163,20 @@ class Asm2WasmBuilder { public: std::map<IString, MappedGlobal> mappedGlobals; - // the global mapping info is not present in the output wasm. We need to save it on the side - // if we intend to load and run this module's wasm. - void serializeMappedGlobals(const char *filename) { - FILE *f = fopen(filename, "w"); - assert(f); - fprintf(f, "{\n"); - bool first = true; - for (auto& pair : mappedGlobals) { - auto name = pair.first; - auto& global = pair.second; - if (first) first = false; - else fprintf(f, ","); - fprintf(f, "\"%s\": { \"address\": %d, \"type\": %d, \"import\": %d, \"module\": \"%s\", \"base\": \"%s\" }\n", - name.str, global.address, global.type, global.import, global.module.str, global.base.str); - } - fprintf(f, "}"); - fclose(f); - } - private: - void allocateGlobal(IString name, WasmType type, bool import, IString module = IString(), IString base = IString()) { + void allocateGlobal(IString name, WasmType type) { assert(mappedGlobals.find(name) == mappedGlobals.end()); - mappedGlobals.emplace(name, MappedGlobal(nextGlobal, type, import, module, base)); - nextGlobal += 8; - assert(nextGlobal < maxGlobal); + mappedGlobals.emplace(name, MappedGlobal(type)); + auto global = new Global(); + global->name = name; + global->type = type; + Literal value; + if (type == i32) value = Literal(uint32_t(0)); + else if (type == f32) value = Literal(float(0)); + else if (type == f64) value = Literal(double(0)); + else WASM_UNREACHABLE(); + global->init = wasm.allocator.alloc<Const>()->set(value); + wasm.addGlobal(global); } struct View { @@ -237,12 +224,6 @@ private: // if we already saw this signature, verify it's the same (or else handle that) if (importedFunctionTypes.find(importName) != importedFunctionTypes.end()) { FunctionType* previous = importedFunctionTypes[importName].get(); -#if 0 - std::cout << "compare " << importName.str << "\nfirst: "; - type.print(std::cout, 0); - std::cout << "\nsecond: "; - previous.print(std::cout, 0) << ".\n"; -#endif if (*type != *previous) { // merge it in. we'll add on extra 0 parameters for ones not actually used, and upgrade types to // double where there is a conflict (which is ok since in JS, double can contain everything @@ -260,6 +241,8 @@ private: } if (previous->result == none) { previous->result = type->result; // use a more concrete type + } else if (previous->result != type->result) { + previous->result = f64; // overloaded return type, make it a double } } } else { @@ -278,8 +261,6 @@ public: : wasm(wasm), allocator(wasm.allocator), builder(wasm), - nextGlobal(8), - maxGlobal(1000), memoryGrowth(memoryGrowth), debug(debug), imprecise(imprecise), @@ -416,9 +397,9 @@ private: } void fixCallType(Expression* call, WasmType type) { - if (call->is<Call>()) call->type = type; - if (call->is<CallImport>()) call->type = type; - else if (call->is<CallIndirect>()) call->type = type; + if (call->is<Call>()) call->cast<Call>()->type = type; + if (call->is<CallImport>()) call->cast<CallImport>()->type = type; + else if (call->is<CallIndirect>()) call->cast<CallIndirect>()->type = type; } FunctionType* getBuiltinFunctionType(Name module, Name base, ExpressionList* operands = nullptr) { @@ -520,12 +501,13 @@ void Asm2WasmBuilder::processAsm(Ref ast) { type = WasmType::f64; } if (type != WasmType::none) { - // wasm has no imported constants, so allocate a global, and we need to write the value into that - allocateGlobal(name, type, true, import->module, import->base); - delete import; + import->kind = Import::Global; + import->globalType = type; + mappedGlobals.emplace(name, type); } else { - wasm.addImport(import); + import->kind = Import::Function; } + wasm.addImport(import); }; IString Int8Array, Int16Array, Int32Array, UInt8Array, UInt16Array, UInt32Array, Float32Array, Float64Array; @@ -537,7 +519,10 @@ void Asm2WasmBuilder::processAsm(Ref ast) { for (unsigned i = 1; i < body->size(); i++) { if (body[i][0] == DEFUN) numFunctions++; } - optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions); + optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions, [&](PassRunner& passRunner) { + // run autodrop first, before optimizations + passRunner.add<AutoDrop>(); + }); } // first pass - do almost everything, but function imports and indirect calls @@ -553,7 +538,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { if (value[0] == NUM) { // global int assert(value[1]->getNumber() == 0); - allocateGlobal(name, WasmType::i32, false); + allocateGlobal(name, WasmType::i32); } else if (value[0] == BINARY) { // int import assert(value[1] == OR && value[3][0] == NUM && value[3][1]->getNumber() == 0); @@ -566,14 +551,14 @@ void Asm2WasmBuilder::processAsm(Ref ast) { if (import[0] == NUM) { // global assert(import[1]->getNumber() == 0); - allocateGlobal(name, WasmType::f64, false); + allocateGlobal(name, WasmType::f64); } else { // import addImport(name, import, WasmType::f64); } } else if (value[0] == CALL) { assert(value[1][0] == NAME && value[1][1] == Math_fround && value[2][0][0] == NUM && value[2][0][1]->getNumber() == 0); - allocateGlobal(name, WasmType::f32, false); + allocateGlobal(name, WasmType::f32); } else if (value[0] == DOT) { // simple module.base import. can be a view, or a function. if (value[1][0] == NAME) { @@ -710,6 +695,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { auto* export_ = new Export; export_->name = key; export_->value = value; + export_->kind = Export::Function; wasm.addExport(export_); exported[key] = export_; } @@ -719,11 +705,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) { if (optimize) { optimizingBuilder->finish(); - if (maxGlobal < 1024) { - PassRunner passRunner(&wasm); - passRunner.add("post-emscripten"); - passRunner.run(); - } + PassRunner passRunner(&wasm); + passRunner.add("post-emscripten"); + passRunner.run(); } // second pass. first, function imports @@ -731,15 +715,16 @@ void Asm2WasmBuilder::processAsm(Ref ast) { std::vector<IString> toErase; for (auto& import : wasm.imports) { + if (import->kind != Import::Function) continue; IString name = import->name; if (importedFunctionTypes.find(name) != importedFunctionTypes.end()) { // special math builtins FunctionType* builtin = getBuiltinFunctionType(import->module, import->base); if (builtin) { - import->type = builtin; + import->functionType = builtin; continue; } - import->type = ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm); + import->functionType = ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm); } else if (import->module != ASM2WASM) { // special-case the special module // never actually used toErase.push_back(name); @@ -750,7 +735,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { wasm.removeImport(curr); } - // Finalize indirect calls and import calls + // Finalize calls now that everything is known and generated struct FinalizeCalls : public WalkerPass<PostWalker<FinalizeCalls, Visitor<FinalizeCalls>>> { bool isFunctionParallel() override { return true; } @@ -761,6 +746,14 @@ void Asm2WasmBuilder::processAsm(Ref ast) { FinalizeCalls(Asm2WasmBuilder* parent) : parent(parent) {} + void visitCall(Call* curr) { + assert(getModule()->checkFunction(curr->target) ? true : (std::cerr << curr->target << '\n', false)); + auto result = getModule()->getFunction(curr->target)->result; + if (curr->type != result) { + curr->type = result; + } + } + void visitCallImport(CallImport* curr) { // fill out call_import - add extra params as needed, etc. asm tolerates ffi overloading, wasm does not auto iter = parent->importedFunctionTypes.find(curr->target); @@ -782,7 +775,25 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } } } + auto importResult = getModule()->getImport(curr->target)->functionType->result; + if (curr->type != importResult) { + if (importResult == f64) { + // we use a JS f64 value which is the most general, and convert to it + switch (curr->type) { + case i32: replaceCurrent(parent->builder.makeUnary(TruncSFloat64ToInt32, curr)); break; + case f32: replaceCurrent(parent->builder.makeUnary(DemoteFloat64, curr)); break; + case none: replaceCurrent(parent->builder.makeDrop(curr)); break; + default: WASM_UNREACHABLE(); + } + } else { + assert(curr->type == none); + // we don't want a return value here, but the import does provide one + replaceCurrent(parent->builder.makeDrop(curr)); + } + curr->type = importResult; + } } + void visitCallIndirect(CallIndirect* curr) { // we already call into target = something + offset, where offset is a callImport with the name of the table. replace that with the table offset auto add = curr->target->cast<Binary>(); @@ -793,6 +804,12 @@ void Asm2WasmBuilder::processAsm(Ref ast) { }; PassRunner passRunner(&wasm); passRunner.add<FinalizeCalls>(this); + passRunner.add<ReFinalize>(); // FinalizeCalls changes call types, need to percolate + passRunner.add<AutoDrop>(); // FinalizeCalls may cause us to require additional drops + if (optimize) { + passRunner.add("vacuum"); // autodrop can add some garbage + passRunner.add("remove-unused-brs"); // vacuum may open up more opportunities + } passRunner.run(); // apply memory growth, if relevant @@ -802,7 +819,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { wasm.addFunction(builder.makeFunction( GROW_WASM_MEMORY, { { NEW_SIZE, i32 } }, - none, + i32, {}, builder.makeHost( GrowMemory, @@ -812,10 +829,57 @@ void Asm2WasmBuilder::processAsm(Ref ast) { )); auto export_ = new Export; export_->name = export_->value = GROW_WASM_MEMORY; + export_->kind = Export::Function; wasm.addExport(export_); } - wasm.memory.exportName = MEMORY; +#if 0 + // export memory + auto memoryExport = make_unique<Export>(); + memoryExport->name = MEMORY; + memoryExport->value = Name::fromInt(0); + memoryExport->kind = Export::Memory; + wasm.addExport(memoryExport.release()); +#else + // import memory + auto memoryImport = make_unique<Import>(); + memoryImport->name = MEMORY; + memoryImport->module = ENV; + memoryImport->base = MEMORY; + memoryImport->kind = Import::Memory; + wasm.addImport(memoryImport.release()); + + // import table + auto tableImport = make_unique<Import>(); + tableImport->name = TABLE; + tableImport->module = ENV; + tableImport->base = TABLE; + tableImport->kind = Import::Table; + wasm.addImport(tableImport.release()); + + // Import memory offset + { + auto* import = new Import; + import->name = Name("memoryBase"); + import->module = Name("env"); + import->base = Name("memoryBase"); + import->kind = Import::Global; + import->globalType = i32; + wasm.addImport(import); + } + + // Import table offset + { + auto* import = new Import; + import->name = Name("tableBase"); + import->module = Name("env"); + import->base = Name("tableBase"); + import->kind = Import::Global; + import->globalType = i32; + wasm.addImport(import); + } + +#endif #if 0 // enable asm2wasm i64 optimizations when browsers have consistent i64 support in wasm if (udivmoddi4.is() && getTempRet0.is()) { @@ -1009,20 +1073,16 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { auto ret = allocator.alloc<SetLocal>(); ret->index = function->getLocalIndex(ast[2][1]->getIString()); ret->value = process(ast[3]); - ret->type = ret->value->type; + ret->setTee(false); + ret->finalize(); return ret; } - // global var, do a store to memory + // global var assert(mappedGlobals.find(name) != mappedGlobals.end()); - MappedGlobal global = mappedGlobals[name]; - auto ret = allocator.alloc<Store>(); - ret->bytes = getWasmTypeSize(global.type); - ret->offset = 0; - ret->align = ret->bytes; - ret->ptr = builder.makeConst(Literal(int32_t(global.address))); - ret->value = process(ast[3]); - ret->type = global.type; - return ret; + auto* ret = builder.makeSetGlobal(name, process(ast[3])); + // set_global does not return; if our value is trivially not used, don't emit a load (if nontrivially not used, opts get it later) + if (astStackHelper.getParent()[0] == STAT) return ret; + return builder.makeSequence(ret, builder.makeGetGlobal(name, ret->value->type)); } else if (ast[2][0] == SUB) { Ref target = ast[2]; assert(target[1][0] == NAME); @@ -1035,10 +1095,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->align = view.bytes; ret->ptr = processUnshifted(target[2], view.bytes); ret->value = process(ast[3]); - ret->type = asmToWasmType(view.type); - if (ret->type != ret->value->type) { + ret->valueType = asmToWasmType(view.type); + ret->finalize(); + if (ret->valueType != ret->value->type) { // in asm.js we have some implicit coercions that we must do explicitly here - if (ret->type == f32 && ret->value->type == f64) { + if (ret->valueType == f32 && ret->value->type == f64) { auto conv = allocator.alloc<Unary>(); conv->op = DemoteFloat64; conv->value = ret->value; @@ -1076,7 +1137,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { import->name = F64_REM; import->module = ASM2WASM; import->base = F64_REM; - import->type = ensureFunctionType("ddd", &wasm); + import->functionType = ensureFunctionType("ddd", &wasm); + import->kind = Import::Function; wasm.addImport(import); } return call; @@ -1101,7 +1163,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { import->name = call->target; import->module = ASM2WASM; import->base = call->target; - import->type = ensureFunctionType("iii", &wasm); + import->functionType = ensureFunctionType("iii", &wasm); + import->kind = Import::Function; wasm.addImport(import); } return call; @@ -1139,22 +1202,16 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { import->name = DEBUGGER; import->module = ASM2WASM; import->base = DEBUGGER; - import->type = ensureFunctionType("v", &wasm); + import->functionType = ensureFunctionType("v", &wasm); + import->kind = Import::Function; wasm.addImport(import); } return call; } - // global var, do a load from memory - assert(mappedGlobals.find(name) != mappedGlobals.end()); - MappedGlobal global = mappedGlobals[name]; - auto ret = allocator.alloc<Load>(); - ret->bytes = getWasmTypeSize(global.type); - ret->signed_ = true; // but doesn't matter - ret->offset = 0; - ret->align = ret->bytes; - ret->ptr = builder.makeConst(Literal(int32_t(global.address))); - ret->type = global.type; - return ret; + // global var + assert(mappedGlobals.find(name) != mappedGlobals.end() ? true : (std::cerr << name.str << '\n', false)); + MappedGlobal& global = mappedGlobals[name]; + return builder.makeGetGlobal(name, global.type); } else if (what == SUB) { Ref target = ast[1]; assert(target[0] == NAME); @@ -1251,7 +1308,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { import->name = F64_TO_INT; import->module = ASM2WASM; import->base = F64_TO_INT; - import->type = ensureFunctionType("id", &wasm); + import->functionType = ensureFunctionType("id", &wasm); + import->kind = Import::Function; wasm.addImport(import); } return ret; @@ -1273,11 +1331,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } abort_on("bad unary", ast); } else if (what == IF) { - auto ret = allocator.alloc<If>(); - ret->condition = process(ast[1]); - ret->ifTrue = process(ast[2]); - ret->ifFalse = !!ast[3] ? process(ast[3]) : nullptr; - return ret; + auto* condition = process(ast[1]); + auto* ifTrue = process(ast[2]); + return builder.makeIf(condition, ifTrue, !!ast[3] ? process(ast[3]) : nullptr); } else if (what == CALL) { if (ast[1][0] == NAME) { IString name = ast[1][1]->getIString(); @@ -1330,9 +1386,10 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { // No wasm support, so use a temp local ensureI32Temp(); auto set = allocator.alloc<SetLocal>(); + set->setTee(false); set->index = function->getLocalIndex(I32_TEMP); set->value = value; - set->type = i32; + set->finalize(); auto get = [&]() { auto ret = allocator.alloc<GetLocal>(); ret->index = function->getLocalIndex(I32_TEMP); @@ -1477,8 +1534,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { out = getNextId("while-out"); in = getNextId("while-in"); } - ret->out = out; - ret->in = in; + ret->name = in; breakStack.push_back(out); continueStack.push_back(in); if (forever) { @@ -1489,6 +1545,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { If *condition = allocator.alloc<If>(); condition->condition = builder.makeUnary(EqZInt32, process(ast[1])); condition->ifTrue = breakOut; + condition->finalize(); auto body = allocator.alloc<Block>(); body->list.push_back(condition); body->list.push_back(process(ast[2])); @@ -1496,11 +1553,13 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->body = body; } // loops do not automatically loop, add a branch back - Block* block = blockify(ret->body); + Block* block = builder.blockifyWithName(ret->body, out); auto continuer = allocator.alloc<Break>(); - continuer->name = ret->in; + continuer->name = ret->name; block->list.push_back(continuer); + block->finalize(); ret->body = block; + ret->finalize(); continueStack.pop_back(); breakStack.pop_back(); return ret; @@ -1526,19 +1585,22 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { if (breakSeeker.found == 0) { auto block = allocator.alloc<Block>(); block->list.push_back(child); + if (isConcreteWasmType(child->type)) { + block->list.push_back(builder.makeNop()); // ensure a nop at the end, so the block has guaranteed none type and no values fall through + } block->name = stop; block->finalize(); return block; } else { auto loop = allocator.alloc<Loop>(); loop->body = child; - loop->out = stop; - loop->in = more; - return loop; + loop->name = more; + loop->finalize(); + return builder.blockifyWithName(loop, stop); } } // general do-while loop - auto ret = allocator.alloc<Loop>(); + auto loop = allocator.alloc<Loop>(); IString out, in; if (!parentLabel.isNull()) { out = getBreakLabelName(parentLabel); @@ -1548,20 +1610,19 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { out = getNextId("do-out"); in = getNextId("do-in"); } - ret->out = out; - ret->in = in; + loop->name = in; breakStack.push_back(out); continueStack.push_back(in); - ret->body = process(ast[2]); + loop->body = process(ast[2]); continueStack.pop_back(); breakStack.pop_back(); Break *continuer = allocator.alloc<Break>(); continuer->name = in; continuer->condition = process(ast[1]); - Block *block = blockify(ret->body); - block->list.push_back(continuer); - ret->body = block; - return ret; + Block *block = builder.blockifyWithName(loop->body, out, continuer); + loop->body = block; + loop->finalize(); + return loop; } else if (what == FOR) { Ref finit = ast[1], fcond = ast[2], @@ -1577,8 +1638,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { out = getNextId("for-out"); in = getNextId("for-in"); } - ret->out = out; - ret->in = in; + ret->name = in; breakStack.push_back(out); continueStack.push_back(in); Break *breakOut = allocator.alloc<Break>(); @@ -1586,6 +1646,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { If *condition = allocator.alloc<If>(); condition->condition = builder.makeUnary(EqZInt32, process(fcond)); condition->ifTrue = breakOut; + condition->finalize(); auto body = allocator.alloc<Block>(); body->list.push_back(condition); body->list.push_back(process(fbody)); @@ -1593,11 +1654,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { body->finalize(); ret->body = body; // loops do not automatically loop, add a branch back - Block* block = blockify(ret->body); auto continuer = allocator.alloc<Break>(); - continuer->name = ret->in; - block->list.push_back(continuer); + continuer->name = ret->name; + Block* block = builder.blockifyWithName(ret->body, out, continuer); ret->body = block; + ret->finalize(); continueStack.pop_back(); breakStack.pop_back(); Block *outer = allocator.alloc<Block>(); @@ -1615,7 +1676,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->condition = process(ast[1]); ret->ifTrue = process(ast[2]); ret->ifFalse = process(ast[3]); - ret->type = ret->ifTrue->type; + ret->finalize(); return ret; } else if (what == SEQ) { // Some (x, y) patterns can be optimized, like bitcasts, @@ -1774,7 +1835,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { // if there is a shift, we can just look through it, etc. processUnshifted = [&](Ref ptr, unsigned bytes) { auto shifts = bytesToShift(bytes); - if (ptr[0] == BINARY && ptr[1] == RSHIFT && ptr[3][0] == NUM && ptr[3][1]->getInteger() == shifts) { + // HEAP?[addr >> ?], or HEAP8[x | 0] + if ((ptr[0] == BINARY && ptr[1] == RSHIFT && ptr[3][0] == NUM && ptr[3][1]->getInteger() == shifts) || + (bytes == 1 && ptr[0] == BINARY && ptr[1] == OR && ptr[3][0] == NUM && ptr[3][1]->getInteger() == 0)) { return process(ptr[2]); // look through it } else if (ptr[0] == NUM) { // constant, apply a shift (e.g. HEAP32[1] is address 4) diff --git a/src/ast_utils.h b/src/ast_utils.h index 3e45d0e33..9b2ff10cd 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -21,6 +21,7 @@ #include "wasm.h" #include "wasm-traversal.h" #include "wasm-builder.h" +#include "pass.h" namespace wasm { @@ -129,6 +130,9 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer bool checkPost(Expression* curr) { visit(curr); + if (curr->is<Loop>()) { + branches = true; + } return hasAnything(); } @@ -147,8 +151,7 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks } void visitLoop(Loop* curr) { - if (curr->in.is()) breakNames.erase(curr->in); // these were internal breaks - if (curr->out.is()) breakNames.erase(curr->out); // these were internal breaks + if (curr->name.is()) breakNames.erase(curr->name); // these were internal breaks } void visitCall(Call *curr) { calls = true; } @@ -244,7 +247,7 @@ struct ExpressionManipulator { return builder.makeIf(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } Expression* visitLoop(Loop *curr) { - return builder.makeLoop(curr->out, curr->in, copy(curr->body)); + return builder.makeLoop(curr->name, copy(curr->body)); } Expression* visitBreak(Break *curr) { return builder.makeBreak(curr->name, copy(curr->value), copy(curr->condition)); @@ -277,19 +280,23 @@ struct ExpressionManipulator { return builder.makeGetLocal(curr->index, curr->type); } Expression* visitSetLocal(SetLocal *curr) { - return builder.makeSetLocal(curr->index, copy(curr->value)); + if (curr->isTee()) { + return builder.makeTeeLocal(curr->index, copy(curr->value)); + } else { + return builder.makeSetLocal(curr->index, copy(curr->value)); + } } Expression* visitGetGlobal(GetGlobal *curr) { - return builder.makeGetGlobal(curr->index, curr->type); + return builder.makeGetGlobal(curr->name, curr->type); } Expression* visitSetGlobal(SetGlobal *curr) { - return builder.makeSetGlobal(curr->index, copy(curr->value)); + return builder.makeSetGlobal(curr->name, copy(curr->value)); } Expression* visitLoad(Load *curr) { return builder.makeLoad(curr->bytes, curr->signed_, curr->offset, curr->align, copy(curr->ptr), curr->type); } Expression* visitStore(Store *curr) { - return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value)); + return builder.makeStore(curr->bytes, curr->offset, curr->align, copy(curr->ptr), copy(curr->value), curr->valueType); } Expression* visitConst(Const *curr) { return builder.makeConst(curr->value); @@ -303,6 +310,9 @@ struct ExpressionManipulator { Expression* visitSelect(Select *curr) { return builder.makeSelect(copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); } + Expression* visitDrop(Drop *curr) { + return builder.makeDrop(copy(curr->value)); + } Expression* visitReturn(Return *curr) { return builder.makeReturn(copy(curr->value)); } @@ -340,7 +350,7 @@ struct ExpressionAnalyzer { for (int i = int(stack.size()) - 2; i >= 0; i--) { auto* curr = stack[i]; auto* above = stack[i + 1]; - // only if and block can drop values + // only if and block can drop values (pre-drop expression was added) FIXME if (curr->is<Block>()) { auto* block = curr->cast<Block>(); for (size_t j = 0; j < block->list.size() - 1; j++) { @@ -355,6 +365,7 @@ struct ExpressionAnalyzer { assert(above == iff->ifTrue || above == iff->ifFalse); // continue down } else { + if (curr->is<Drop>()) return false; return true; // all other node types use the result } } @@ -429,8 +440,7 @@ struct ExpressionAnalyzer { break; } case Expression::Id::LoopId: { - if (!noteNames(left->cast<Loop>()->out, right->cast<Loop>()->out)) return false; - if (!noteNames(left->cast<Loop>()->in, right->cast<Loop>()->in)) return false; + if (!noteNames(left->cast<Loop>()->name, right->cast<Loop>()->name)) return false; PUSH(Loop, body); break; } @@ -481,15 +491,16 @@ struct ExpressionAnalyzer { } case Expression::Id::SetLocalId: { CHECK(SetLocal, index); + CHECK(SetLocal, type); // for tee/set PUSH(SetLocal, value); break; } case Expression::Id::GetGlobalId: { - CHECK(GetGlobal, index); + CHECK(GetGlobal, name); break; } case Expression::Id::SetGlobalId: { - CHECK(SetGlobal, index); + CHECK(SetGlobal, name); PUSH(SetGlobal, value); break; } @@ -505,6 +516,7 @@ struct ExpressionAnalyzer { CHECK(Store, bytes); CHECK(Store, offset); CHECK(Store, align); + CHECK(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; @@ -530,6 +542,10 @@ struct ExpressionAnalyzer { PUSH(Select, condition); break; } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } case Expression::Id::ReturnId: { PUSH(Return, value); break; @@ -640,8 +656,7 @@ struct ExpressionAnalyzer { break; } case Expression::Id::LoopId: { - noteName(curr->cast<Loop>()->out); - noteName(curr->cast<Loop>()->in); + noteName(curr->cast<Loop>()->name); PUSH(Loop, body); break; } @@ -696,11 +711,11 @@ struct ExpressionAnalyzer { break; } case Expression::Id::GetGlobalId: { - HASH(GetGlobal, index); + HASH_NAME(GetGlobal, name); break; } case Expression::Id::SetGlobalId: { - HASH(SetGlobal, index); + HASH_NAME(SetGlobal, name); PUSH(SetGlobal, value); break; } @@ -716,6 +731,7 @@ struct ExpressionAnalyzer { HASH(Store, bytes); HASH(Store, offset); HASH(Store, align); + HASH(Store, valueType); PUSH(Store, ptr); PUSH(Store, value); break; @@ -742,6 +758,10 @@ struct ExpressionAnalyzer { PUSH(Select, condition); break; } + case Expression::Id::DropId: { + PUSH(Drop, value); + break; + } case Expression::Id::ReturnId: { PUSH(Return, value); break; @@ -770,6 +790,65 @@ struct ExpressionAnalyzer { } }; +// Adds drop() operations where necessary. This lets you not worry about adding drop when +// generating code. +struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<AutoDrop>>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new AutoDrop; } + + void visitBlock(Block* curr) { + if (curr->list.size() == 0) return; + for (Index i = 0; i < curr->list.size() - 1; i++) { + auto* child = curr->list[i]; + if (isConcreteWasmType(child->type)) { + curr->list[i] = Builder(*getModule()).makeDrop(child); + } + } + auto* last = curr->list.back(); + expressionStack.push_back(last); + if (isConcreteWasmType(last->type) && !ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { + curr->list.back() = Builder(*getModule()).makeDrop(last); + } + expressionStack.pop_back(); + curr->finalize(); // we may have changed our type + } + + void visitFunction(Function* curr) { + if (curr->result == none && isConcreteWasmType(curr->body->type)) { + curr->body = Builder(*getModule()).makeDrop(curr->body); + } + } +}; + +// Finalizes a node + +struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, Visitor<ReFinalize>>> { + void visitBlock(Block *curr) { curr->finalize(); } + void visitIf(If *curr) { curr->finalize(); } + void visitLoop(Loop *curr) { curr->finalize(); } + void visitBreak(Break *curr) { curr->finalize(); } + void visitSwitch(Switch *curr) { curr->finalize(); } + void visitCall(Call *curr) { curr->finalize(); } + void visitCallImport(CallImport *curr) { curr->finalize(); } + void visitCallIndirect(CallIndirect *curr) { curr->finalize(); } + void visitGetLocal(GetLocal *curr) { curr->finalize(); } + void visitSetLocal(SetLocal *curr) { curr->finalize(); } + void visitGetGlobal(GetGlobal *curr) { curr->finalize(); } + void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } + void visitLoad(Load *curr) { curr->finalize(); } + void visitStore(Store *curr) { curr->finalize(); } + void visitConst(Const *curr) { curr->finalize(); } + void visitUnary(Unary *curr) { curr->finalize(); } + void visitBinary(Binary *curr) { curr->finalize(); } + void visitSelect(Select *curr) { curr->finalize(); } + void visitDrop(Drop *curr) { curr->finalize(); } + void visitReturn(Return *curr) { curr->finalize(); } + void visitHost(Host *curr) { curr->finalize(); } + void visitNop(Nop *curr) { curr->finalize(); } + void visitUnreachable(Unreachable *curr) { curr->finalize(); } +}; + } // namespace wasm #endif // wasm_ast_utils_h diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 83c626eb1..3ce24bdbd 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -338,16 +338,13 @@ BinaryenExpressionRef BinaryenIf(BinaryenModuleRef module, BinaryenExpressionRef return static_cast<Expression*>(ret); } -BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* out, const char* in, BinaryenExpressionRef body) { - if (out && !in) abort(); - auto* ret = Builder(*((Module*)module)).makeLoop(out ? Name(out) : Name(), in ? Name(in) : Name(), (Expression*)body); +BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* name, BinaryenExpressionRef body) { + auto* ret = Builder(*((Module*)module)).makeLoop(name ? Name(name) : Name(), (Expression*)body); if (tracing) { auto id = noteExpression(ret); std::cout << " expressions[" << id << "] = BinaryenLoop(the_module, "; - traceNameOrNULL(out); - std::cout << ", "; - traceNameOrNULL(in); + traceNameOrNULL(name); std::cout << ", expressions[" << expressions[body] << "]);\n"; } @@ -488,6 +485,21 @@ BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex i ret->index = index; ret->value = (Expression*)value; + ret->setTee(false); + ret->finalize(); + return static_cast<Expression*>(ret); +} +BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value) { + auto* ret = ((Module*)module)->allocator.alloc<SetLocal>(); + + if (tracing) { + auto id = noteExpression(ret); + std::cout << " expressions[" << id << "] = BinaryenTeeLocal(the_module, " << index << ", expressions[" << expressions[value] << "]);\n"; + } + + ret->index = index; + ret->value = (Expression*)value; + ret->setTee(true); ret->finalize(); return static_cast<Expression*>(ret); } @@ -508,12 +520,12 @@ BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int ret->finalize(); return static_cast<Expression*>(ret); } -BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value) { +BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type) { auto* ret = ((Module*)module)->allocator.alloc<Store>(); if (tracing) { auto id = noteExpression(ret); - std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "]);\n"; + std::cout << " expressions[" << id << "] = BinaryenStore(the_module, " << bytes << ", " << offset << ", " << align << ", expressions[" << expressions[ptr] << "], expressions[" << expressions[value] << "], " << type << ");\n"; } ret->bytes = bytes; @@ -521,6 +533,7 @@ BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, ui ret->align = align ? align : bytes; ret->ptr = (Expression*)ptr; ret->value = (Expression*)value; + ret->valueType = WasmType(type); ret->finalize(); return static_cast<Expression*>(ret); } @@ -584,6 +597,18 @@ BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressio ret->finalize(); return static_cast<Expression*>(ret); } +BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value) { + auto* ret = ((Module*)module)->allocator.alloc<Drop>(); + + if (tracing) { + auto id = noteExpression(ret); + std::cout << " expressions[" << id << "] = BinaryenDrop(the_module, expressions[" << expressions[value] << "]);\n"; + } + + ret->value = (Expression*)value; + ret->finalize(); + return static_cast<Expression*>(ret); +} BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value) { auto* ret = Builder(*((Module*)module)).makeReturn((Expression*)value); @@ -692,7 +717,8 @@ BinaryenImportRef BinaryenAddImport(BinaryenModuleRef module, const char* intern ret->name = internalName; ret->module = externalModuleName; ret->base = externalBaseName; - ret->type = (FunctionType*)type; + ret->functionType = (FunctionType*)type; + ret->kind = Import::Function; wasm->addImport(ret); return ret; } @@ -780,7 +806,13 @@ void BinaryenSetMemory(BinaryenModuleRef module, BinaryenIndex initial, Binaryen auto* wasm = (Module*)module; wasm->memory.initial = initial; wasm->memory.max = maximum; - if (exportName) wasm->memory.exportName = exportName; + if (exportName) { + auto memoryExport = make_unique<Export>(); + memoryExport->name = exportName; + memoryExport->value = Name::fromInt(0); + memoryExport->kind = Export::Memory; + wasm->addExport(memoryExport.release()); + } for (BinaryenIndex i = 0; i < numSegments; i++) { wasm->memory.segments.emplace_back((Expression*)segmentOffsets[i], segments[i], segmentSizes[i]); } @@ -829,6 +861,17 @@ void BinaryenModuleOptimize(BinaryenModuleRef module) { passRunner.run(); } +void BinaryenModuleAutoDrop(BinaryenModuleRef module) { + if (tracing) { + std::cout << " BinaryenModuleAutoDrop(the_module);\n"; + } + + Module* wasm = (Module*)module; + PassRunner passRunner(wasm); + passRunner.add<AutoDrop>(); + passRunner.run(); +} + size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize) { if (tracing) { std::cout << " // BinaryenModuleWrite\n"; diff --git a/src/binaryen-c.h b/src/binaryen-c.h index e2086072b..9ecb3f409 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -265,8 +265,7 @@ typedef void* BinaryenExpressionRef; BinaryenExpressionRef BinaryenBlock(BinaryenModuleRef module, const char* name, BinaryenExpressionRef* children, BinaryenIndex numChildren); // If: ifFalse can be NULL BinaryenExpressionRef BinaryenIf(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, BinaryenExpressionRef ifFalse); -// Loop: both out and in can be NULL, or just out can be NULL -BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* out, const char* in, BinaryenExpressionRef body); +BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* in, BinaryenExpressionRef body); // Break: value and condition can be NULL BinaryenExpressionRef BinaryenBreak(BinaryenModuleRef module, const char* name, BinaryenExpressionRef condition, BinaryenExpressionRef value); // Switch: value can be NULL @@ -294,14 +293,16 @@ BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExp // for more details. BinaryenExpressionRef BinaryenGetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenType type); BinaryenExpressionRef BinaryenSetLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value); +BinaryenExpressionRef BinaryenTeeLocal(BinaryenModuleRef module, BinaryenIndex index, BinaryenExpressionRef value); // Load: align can be 0, in which case it will be the natural alignment (equal to bytes) BinaryenExpressionRef BinaryenLoad(BinaryenModuleRef module, uint32_t bytes, int8_t signed_, uint32_t offset, uint32_t align, BinaryenType type, BinaryenExpressionRef ptr); // Store: align can be 0, in which case it will be the natural alignment (equal to bytes) -BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value); +BinaryenExpressionRef BinaryenStore(BinaryenModuleRef module, uint32_t bytes, uint32_t offset, uint32_t align, BinaryenExpressionRef ptr, BinaryenExpressionRef value, BinaryenType type); BinaryenExpressionRef BinaryenConst(BinaryenModuleRef module, struct BinaryenLiteral value); BinaryenExpressionRef BinaryenUnary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef value); BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenOp op, BinaryenExpressionRef left, BinaryenExpressionRef right); BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, BinaryenExpressionRef ifFalse); +BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value); // Return: value can be NULL BinaryenExpressionRef BinaryenReturn(BinaryenModuleRef module, BinaryenExpressionRef value); // Host: name may be NULL @@ -366,8 +367,13 @@ int BinaryenModuleValidate(BinaryenModuleRef module); // Run the standard optimization passes on the module. void BinaryenModuleOptimize(BinaryenModuleRef module); +// Auto-generate drop() operations where needed. This lets you generate code without +// worrying about where they are needed. (It is more efficient to do it yourself, +// but simpler to use autodrop). +void BinaryenModuleAutoDrop(BinaryenModuleRef module); + // Serialize a module into binary form. -// @return how many bytes were written. This will be less than or equal to bufferSize +// @return how many bytes were written. This will be less than or equal to outputSize size_t BinaryenModuleWrite(BinaryenModuleRef module, char* output, size_t outputSize); // Deserialize a module from binary form. diff --git a/src/cfg/Relooper.cpp b/src/cfg/Relooper.cpp index e75adbce5..8a5337ed0 100644 --- a/src/cfg/Relooper.cpp +++ b/src/cfg/Relooper.cpp @@ -383,7 +383,7 @@ wasm::Expression* MultipleShape::Render(RelooperBuilder& Builder, bool InLoop) { // LoopShape wasm::Expression* LoopShape::Render(RelooperBuilder& Builder, bool InLoop) { - wasm::Expression* Ret = Builder.makeLoop(wasm::Name(), Builder.getShapeContinueName(Id), Inner->Render(Builder, true)); + wasm::Expression* Ret = Builder.makeLoop(Builder.getShapeContinueName(Id), Inner->Render(Builder, true)); Ret = HandleFollowupMultiples(Ret, this, Builder, InLoop); if (Next) { Ret = Builder.makeSequence(Ret, Next->Render(Builder, InLoop)); diff --git a/src/cfg/cfg-traversal.h b/src/cfg/cfg-traversal.h index d690de4aa..5bc593691 100644 --- a/src/cfg/cfg-traversal.h +++ b/src/cfg/cfg-traversal.h @@ -36,7 +36,7 @@ namespace wasm { template<typename SubType, typename VisitorType, typename Contents> -struct CFGWalker : public PostWalker<SubType, VisitorType> { +struct CFGWalker : public ControlFlowWalker<SubType, VisitorType> { // public interface @@ -57,7 +57,7 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { // traversal state BasicBlock* currBasicBlock; // the current block in play during traversal - std::map<Name, std::vector<BasicBlock*>> branches; + std::map<Expression*, std::vector<BasicBlock*>> branches; // a block or loop => its branches std::vector<BasicBlock*> ifStack; std::vector<BasicBlock*> loopStack; @@ -74,7 +74,7 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { static void doEndBlock(SubType* self, Expression** currp) { auto* curr = (*currp)->cast<Block>(); if (!curr->name.is()) return; - auto iter = self->branches.find(curr->name); + auto iter = self->branches.find(curr); if (iter == self->branches.end()) return; auto& origins = iter->second; if (origins.size() == 0) return; @@ -83,12 +83,10 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { doStartBasicBlock(self, currp); self->link(last, self->currBasicBlock); // fallthrough // branches to the new one - if (curr->name.is()) { - for (auto* origin : origins) { - self->link(origin, self->currBasicBlock); - } - self->branches.erase(curr->name); + for (auto* origin : origins) { + self->link(origin, self->currBasicBlock); } + self->branches.erase(curr); } static void doStartIfTrue(SubType* self, Expression** currp) { @@ -130,30 +128,22 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { auto* last = self->currBasicBlock; doStartBasicBlock(self, currp); self->link(last, self->currBasicBlock); // fallthrough - // branches to the new one auto* curr = (*currp)->cast<Loop>(); - if (curr->out.is()) { - auto& origins = self->branches[curr->out]; - for (auto* origin : origins) { - self->link(origin, self->currBasicBlock); - } - self->branches.erase(curr->out); - } // branches to the top of the loop - if (curr->in.is()) { + if (curr->name.is()) { auto* loopStart = self->loopStack.back(); - auto& origins = self->branches[curr->in]; + auto& origins = self->branches[curr]; for (auto* origin : origins) { self->link(origin, loopStart); } - self->branches.erase(curr->in); + self->branches.erase(curr); } self->loopStack.pop_back(); } static void doEndBreak(SubType* self, Expression** currp) { auto* curr = (*currp)->cast<Break>(); - self->branches[curr->name].push_back(self->currBasicBlock); // branch to the target + self->branches[self->findBreakTarget(curr->name)].push_back(self->currBasicBlock); // branch to the target auto* last = self->currBasicBlock; doStartBasicBlock(self, currp); if (curr->condition) { @@ -166,12 +156,12 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { std::set<Name> seen; // we might see the same label more than once; do not spam branches for (Name target : curr->targets) { if (!seen.count(target)) { - self->branches[target].push_back(self->currBasicBlock); // branch to the target + self->branches[self->findBreakTarget(target)].push_back(self->currBasicBlock); // branch to the target seen.insert(target); } } if (!seen.count(curr->default_)) { - self->branches[curr->default_].push_back(self->currBasicBlock); // branch to the target + self->branches[self->findBreakTarget(curr->default_)].push_back(self->currBasicBlock); // branch to the target } doStartBasicBlock(self, currp); } @@ -182,14 +172,10 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { switch (curr->_id) { case Expression::Id::BlockId: { self->pushTask(SubType::doEndBlock, currp); - self->pushTask(SubType::doVisitBlock, currp); - auto& list = curr->cast<Block>()->list; - for (int i = int(list.size()) - 1; i >= 0; i--) { - self->pushTask(SubType::scan, &list[i]); - } break; } case Expression::Id::IfId: { + self->pushTask(SubType::doPostVisitControlFlow, currp); self->pushTask(SubType::doEndIf, currp); self->pushTask(SubType::doVisitIf, currp); auto* ifFalse = curr->cast<If>()->ifFalse; @@ -200,44 +186,40 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue); self->pushTask(SubType::doStartIfTrue, currp); self->pushTask(SubType::scan, &curr->cast<If>()->condition); - break; + self->pushTask(SubType::doPreVisitControlFlow, currp); + return; // don't do anything else } case Expression::Id::LoopId: { self->pushTask(SubType::doEndLoop, currp); - self->pushTask(SubType::doVisitLoop, currp); - self->pushTask(SubType::scan, &curr->cast<Loop>()->body); - self->pushTask(SubType::doStartLoop, currp); break; } case Expression::Id::BreakId: { self->pushTask(SubType::doEndBreak, currp); - self->pushTask(SubType::doVisitBreak, currp); - self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition); - self->maybePushTask(SubType::scan, &curr->cast<Break>()->value); break; } case Expression::Id::SwitchId: { self->pushTask(SubType::doEndSwitch, currp); - self->pushTask(SubType::doVisitSwitch, currp); - self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value); - self->pushTask(SubType::scan, &curr->cast<Switch>()->condition); break; } case Expression::Id::ReturnId: { self->pushTask(SubType::doStartBasicBlock, currp); - self->pushTask(SubType::doVisitReturn, currp); - self->maybePushTask(SubType::scan, &curr->cast<Return>()->value); break; } case Expression::Id::UnreachableId: { self->pushTask(SubType::doStartBasicBlock, currp); - self->pushTask(SubType::doVisitUnreachable, currp); break; } - default: { - // other node types do not have control flow, use regular post-order - PostWalker<SubType, VisitorType>::scan(self, currp); + default: {} + } + + ControlFlowWalker<SubType, VisitorType>::scan(self, currp); + + switch (curr->_id) { + case Expression::Id::LoopId: { + self->pushTask(SubType::doStartLoop, currp); + break; } + default: {} } } @@ -246,7 +228,7 @@ struct CFGWalker : public PostWalker<SubType, VisitorType> { doStartBasicBlock(static_cast<SubType*>(this), nullptr); entry = currBasicBlock; - PostWalker<SubType, VisitorType>::doWalkFunction(func); + ControlFlowWalker<SubType, VisitorType>::doWalkFunction(func); assert(branches.size() == 0); assert(ifStack.size() == 0); diff --git a/src/js/wasm.js-post.js b/src/js/wasm.js-post.js index ae42ef175..1c5e6b0b1 100644 --- a/src/js/wasm.js-post.js +++ b/src/js/wasm.js-post.js @@ -29,6 +29,7 @@ function integrateWasmJS(Module) { // inputs var method = Module['wasmJSMethod'] || {{{ wasmJSMethod }}} || 'native-wasm,interpret-s-expr'; // by default, try native and then .wast + Module['wasmJSMethod'] = method; var wasmTextFile = Module['wasmTextFile'] || {{{ wasmTextFile }}}; var wasmBinaryFile = Module['wasmBinaryFile'] || {{{ wasmBinaryFile }}}; @@ -100,19 +101,15 @@ function integrateWasmJS(Module) { } var oldView = new Int8Array(oldBuffer); var newView = new Int8Array(newBuffer); - if ({{{ WASM_BACKEND }}}) { - // memory segments arrived in the wast, do not trample them + + // If we have a mem init file, do not trample it + if (!memoryInitializer) { oldView.set(newView.subarray(STATIC_BASE, STATIC_BASE + STATIC_BUMP), STATIC_BASE); } + newView.set(oldView); updateGlobalBuffer(newBuffer); updateGlobalBufferViews(); - Module['reallocBuffer'] = function(size) { - size = Math.ceil(size / wasmPageSize) * wasmPageSize; // round up to wasm page size - var old = Module['buffer']; - exports['__growWasmMemory'](size / wasmPageSize); // tiny wasm method that just does grow_memory - return Module['buffer'] !== old ? Module['buffer'] : null; // if it was reallocated, it changed - }; } var WasmTypes = { @@ -123,26 +120,6 @@ function integrateWasmJS(Module) { f64: 4 }; - // wasm lacks globals, so asm2wasm maps them into locations in memory. that information cannot - // be present in the wasm output of asm2wasm, so we store it in a side file. If we load asm2wasm - // output, either generated ahead of time or on the client, we need to apply those mapped - // globals after loading the module. - function applyMappedGlobals(globalsFileBase) { - var mappedGlobals = JSON.parse(Module['read'](globalsFileBase + '.mappedGlobals')); - for (var name in mappedGlobals) { - var global = mappedGlobals[name]; - if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here - var value = lookupImport(global.module, global.base); - var address = global.address; - switch (global.type) { - case WasmTypes.i32: Module['HEAP32'][address >> 2] = value; break; - case WasmTypes.f32: Module['HEAPF32'][address >> 2] = value; break; - case WasmTypes.f64: Module['HEAPF64'][address >> 3] = value; break; - default: abort(); - } - } - } - function fixImports(imports) { if (!{{{ WASM_BACKEND }}}) return imports; var ret = {}; @@ -206,9 +183,7 @@ function integrateWasmJS(Module) { return false; } exports = instance.exports; - mergeMemory(exports.memory); - - applyMappedGlobals(wasmBinaryFile); + if (exports.memory) mergeMemory(exports.memory); Module["usingWasm"] = true; @@ -237,6 +212,13 @@ function integrateWasmJS(Module) { info.global = global; info.env = env; + if (!('memoryBase' in env)) { + env['memoryBase'] = STATIC_BASE; // tell the memory segments where to place themselves + } + if (!('tableBase' in env)) { + env['tableBase'] = 0; // tell the memory segments where to place themselves + } + wasmJS['providedTotalMemory'] = Module['buffer'].byteLength; // Prepare to generate wasm, using either asm2wasm or s-exprs @@ -271,12 +253,6 @@ function integrateWasmJS(Module) { Module['newBuffer'] = null; } - if (method == 'interpret-s-expr') { - applyMappedGlobals(wasmTextFile); - } else if (method == 'interpret-binary') { - applyMappedGlobals(wasmBinaryFile); - } - exports = wasmJS['asmExports']; return exports; @@ -285,6 +261,14 @@ function integrateWasmJS(Module) { // We may have a preloaded value in Module.asm, save it Module['asmPreload'] = Module['asm']; + // Memory growth integration code + Module['reallocBuffer'] = function(size) { + size = Math.ceil(size / wasmPageSize) * wasmPageSize; // round up to wasm page size + var old = Module['buffer']; + exports['__growWasmMemory'](size / wasmPageSize); // tiny wasm method that just does grow_memory + return Module['buffer'] !== old ? Module['buffer'] : null; // if it was reallocated, it changed + }; + // Provide an "asm.js function" for the application, called to "link" the asm.js module. We instantiate // the wasm module at that time, and it receives imports and provides exports and so forth, the app // doesn't need to care that it is wasm or olyfilled wasm or asm.js. @@ -293,6 +277,14 @@ function integrateWasmJS(Module) { global = fixImports(global); env = fixImports(env); + // import memory and table + if (!env['memory']) { + env['memory'] = providedBuffer; + } + if (!env['table']) { + env['table'] = new Array(1024); + } + // try the methods. each should return the exports if it succeeded var exports; diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 2ccf5f040..1b4c65562 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -2,9 +2,7 @@ SET(passes_SOURCES pass.cpp CoalesceLocals.cpp DeadCodeElimination.cpp - DropReturnValues.cpp DuplicateFunctionElimination.cpp - LowerIfElse.cpp MergeBlocks.cpp Metrics.cpp NameManager.cpp diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp index 2c707b614..2063a20db 100644 --- a/src/passes/CoalesceLocals.cpp +++ b/src/passes/CoalesceLocals.cpp @@ -485,9 +485,19 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root) // in addition, we can optimize out redundant copies and ineffective sets GetLocal* get; if ((get = set->value->dynCast<GetLocal>()) && get->index == set->index) { - *action.origin = get; // further optimizations may get rid of the get, if this is in a place where the output does not matter + if (set->isTee()) { + *action.origin = get; + } else { + ExpressionManipulator::nop(set); + } } else if (!action.effective) { *action.origin = set->value; // value may have no side effects, further optimizations can eliminate it + if (!set->isTee()) { + // we need to drop it + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(set); + drop->value = *action.origin; + *action.origin = drop; + } } } } diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index de1191131..b30b8ffbd 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -31,6 +31,7 @@ #include <wasm.h> #include <pass.h> #include <ast_utils.h> +#include <wasm-builder.h> namespace wasm { @@ -131,12 +132,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V } void visitLoop(Loop* curr) { - if (curr->in.is()) { - reachableBreaks.erase(curr->in); - } - if (curr->out.is()) { - reachable = reachable || reachableBreaks.count(curr->out); - reachableBreaks.erase(curr->out); + if (curr->name.is()) { + reachableBreaks.erase(curr->name); } if (isDead(curr->body)) { replaceCurrent(curr->body); @@ -191,6 +188,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V case Expression::Id::UnaryId: DELEGATE(Unary); case Expression::Id::BinaryId: DELEGATE(Binary); case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); case Expression::Id::ReturnId: DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); @@ -226,46 +224,52 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V // other things + Expression* drop(Expression* toDrop) { + if (toDrop->is<Unreachable>()) return toDrop; + return Builder(*getModule()).makeDrop(toDrop); + } + template<typename T> - void handleCall(T* curr, Expression* initial) { + Expression* handleCall(T* curr) { for (Index i = 0; i < curr->operands.size(); i++) { if (isDead(curr->operands[i])) { - if (i > 0 || initial != nullptr) { + if (i > 0) { auto* block = getModule()->allocator.alloc<Block>(); - Index newSize = i + 1 + (initial ? 1 : 0); + Index newSize = i + 1; block->list.resize(newSize); Index j = 0; - if (initial) { - block->list[j] = initial; - j++; - } for (; j < newSize; j++) { - block->list[j] = curr->operands[j - (initial ? 1 : 0)]; + block->list[j] = drop(curr->operands[j]); } block->finalize(); - replaceCurrent(block); + return replaceCurrent(block); } else { - replaceCurrent(curr->operands[i]); + return replaceCurrent(curr->operands[i]); } - return; } } + return curr; } void visitCall(Call* curr) { - handleCall(curr, nullptr); + handleCall(curr); } void visitCallImport(CallImport* curr) { - handleCall(curr, nullptr); + handleCall(curr); } void visitCallIndirect(CallIndirect* curr) { + if (handleCall(curr) != curr) return; if (isDead(curr->target)) { - replaceCurrent(curr->target); - return; + auto* block = getModule()->allocator.alloc<Block>(); + for (auto* operand : curr->operands) { + block->list.push_back(drop(operand)); + } + block->list.push_back(curr->target); + block->finalize(); + replaceCurrent(block); } - handleCall(curr, curr->target); } void visitSetLocal(SetLocal* curr) { @@ -288,7 +292,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->value)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->ptr; + block->list[0] = drop(curr->ptr); block->list[1] = curr->value; block->finalize(); replaceCurrent(block); @@ -309,7 +313,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->right)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->left; + block->list[0] = drop(curr->left); block->list[1] = curr->right; block->finalize(); replaceCurrent(block); @@ -324,7 +328,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->ifFalse)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->ifTrue; + block->list[0] = drop(curr->ifTrue); block->list[1] = curr->ifFalse; block->finalize(); replaceCurrent(block); @@ -333,8 +337,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->condition)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(3); - block->list[0] = curr->ifTrue; - block->list[1] = curr->ifFalse; + block->list[0] = drop(curr->ifTrue); + block->list[1] = drop(curr->ifFalse); block->list[2] = curr->condition; block->finalize(); replaceCurrent(block); diff --git a/src/passes/DropReturnValues.cpp b/src/passes/DropReturnValues.cpp deleted file mode 100644 index 8715f3f61..000000000 --- a/src/passes/DropReturnValues.cpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2016 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -// Stops using return values from set_local and store nodes. -// - -#include <wasm.h> -#include <pass.h> -#include <ast_utils.h> -#include <wasm-builder.h> - -namespace wasm { - -struct DropReturnValues : public WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>> { - bool isFunctionParallel() override { return true; } - - Pass* create() override { return new DropReturnValues; } - - std::vector<Expression*> expressionStack; - - void visitSetLocal(SetLocal* curr) { - if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { - Builder builder(*getModule()); - replaceCurrent(builder.makeSequence( - curr, - builder.makeGetLocal(curr->index, curr->type) - )); - } - } - - void visitStore(Store* curr) { - if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { - Index index = getFunction()->getNumLocals(); - getFunction()->vars.emplace_back(curr->type); - Builder builder(*getModule()); - replaceCurrent(builder.makeSequence( - builder.makeSequence( - builder.makeSetLocal(index, curr->value), - curr - ), - builder.makeGetLocal(index, curr->type) - )); - curr->value = builder.makeGetLocal(index, curr->type); - } - } - - static void visitPre(DropReturnValues* self, Expression** currp) { - self->expressionStack.push_back(*currp); - } - - static void visitPost(DropReturnValues* self, Expression** currp) { - self->expressionStack.pop_back(); - } - - static void scan(DropReturnValues* self, Expression** currp) { - self->pushTask(visitPost, currp); - - WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>>::scan(self, currp); - - self->pushTask(visitPre, currp); - } -}; - -Pass *createDropReturnValuesPass() { - return new DropReturnValues(); -} - -} // namespace wasm - diff --git a/src/passes/LowerIfElse.cpp b/src/passes/LowerIfElse.cpp deleted file mode 100644 index b566e8207..000000000 --- a/src/passes/LowerIfElse.cpp +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2015 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -// Lowers if (x) y else z into -// -// L: { -// if (x) break (y) L -// z -// } -// -// This is useful for investigating how beneficial if_else is. -// - -#include <memory> - -#include <wasm.h> -#include <pass.h> - -namespace wasm { - -struct LowerIfElse : public WalkerPass<PostWalker<LowerIfElse, Visitor<LowerIfElse>>> { - MixedArena* allocator; - std::unique_ptr<NameManager> namer; - - void prepare(PassRunner* runner, Module *module) override { - allocator = runner->allocator; - namer = make_unique<NameManager>(); - namer->run(runner, module); - } - - void visitIf(If *curr) { - if (curr->ifFalse) { - auto block = allocator->alloc<Block>(); - auto name = namer->getUnique("L"); // TODO: getUniqueInFunction - block->name = name; - block->list.push_back(curr); - block->list.push_back(curr->ifFalse); - block->finalize(); - curr->ifFalse = nullptr; - auto break_ = allocator->alloc<Break>(); - break_->name = name; - break_->value = curr->ifTrue; - curr->ifTrue = break_; - replaceCurrent(block); - } - } -}; - -Pass *createLowerIfElsePass() { - return new LowerIfElse(); -} - -} // namespace wasm diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 686bb5d75..bde9397a8 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -64,9 +64,40 @@ #include <wasm.h> #include <pass.h> #include <ast_utils.h> +#include <wasm-builder.h> namespace wasm { +struct SwitchFinder : public ControlFlowWalker<SwitchFinder, Visitor<SwitchFinder>> { + Expression* origin; + bool found = false; + + void visitSwitch(Switch* curr) { + if (findBreakTarget(curr->default_) == origin) { + found = true; + return; + } + for (auto& target : curr->targets) { + if (findBreakTarget(target) == origin) { + found = true; + return; + } + } + } +}; + +struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper, Visitor<BreakValueDropper>> { + Expression* origin; + + void visitBreak(Break* curr) { + if (curr->value && findBreakTarget(curr->name) == origin) { + Builder builder(*getModule()); + replaceCurrent(builder.makeSequence(builder.makeDrop(curr->value), curr)); + curr->value = nullptr; + } + } +}; + struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBlocks>>> { bool isFunctionParallel() override { return true; } @@ -74,10 +105,46 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc void visitBlock(Block *curr) { bool more = true; + bool changed = false; while (more) { more = false; for (size_t i = 0; i < curr->list.size(); i++) { Block* child = curr->list[i]->dynCast<Block>(); + if (!child) { + // if we have a child that is (drop (block ..)) then we can move the drop into the block, and remove br values. this allows more merging, + auto* drop = curr->list[i]->dynCast<Drop>(); + if (drop) { + child = drop->value->dynCast<Block>(); + if (child) { + if (child->name.is()) { + Expression* expression = child; + // if there is a switch targeting us, we can't do it - we can't remove the value from other targets too + SwitchFinder finder; + finder.origin = child; + finder.walk(expression); + if (finder.found) { + child = nullptr; + } else { + // fix up breaks + BreakValueDropper fixer; + fixer.origin = child; + fixer.setModule(getModule()); + fixer.walk(expression); + } + } + if (child) { + // we can do it! + // reuse the drop + drop->value = child->list.back(); + child->list.back() = drop; + child->finalize(); + curr->list[i] = child; + more = true; + changed = true; + } + } + } + } if (!child) continue; if (child->name.is()) continue; // named blocks can have breaks to them (and certainly do, if we ran RemoveUnusedNames and RemoveUnusedBrs) ExpressionList merged(getModule()->allocator); @@ -92,9 +159,11 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc } curr->list = merged; more = true; + changed = true; break; } } + if (changed) curr->finalize(); } Block* optimize(Expression* curr, Expression*& child, Block* outer = nullptr, Expression** dependency1 = nullptr, Expression** dependency2 = nullptr) { diff --git a/src/passes/NameManager.cpp b/src/passes/NameManager.cpp index df8b34557..9f0198c2f 100644 --- a/src/passes/NameManager.cpp +++ b/src/passes/NameManager.cpp @@ -37,8 +37,7 @@ void NameManager::visitBlock(Block* curr) { names.insert(curr->name); } void NameManager::visitLoop(Loop* curr) { - names.insert(curr->out); - names.insert(curr->in); + names.insert(curr->name); } void NameManager::visitBreak(Break* curr) { names.insert(curr->name); diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 5eea38bdc..43ddc954c 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -32,14 +32,17 @@ struct PrintSExpression : public Visitor<PrintSExpression> { const char *maybeSpace; const char *maybeNewLine; - bool fullAST = false; // whether to not elide nodes in output when possible - // (like implicit blocks) + bool full = false; // whether to not elide nodes in output when possible + // (like implicit blocks) and to emit types Module* currModule = nullptr; Function* currFunction = nullptr; PrintSExpression(std::ostream& o) : o(o) { setMinify(false); + if (getenv("BINARYEN_PRINT_FULL")) { + full = std::stoi(getenv("BINARYEN_PRINT_FULL")); + } } void setMinify(bool minify_) { @@ -48,7 +51,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { maybeNewLine = minify ? "" : "\n"; } - void setFullAST(bool fullAST_) { fullAST = fullAST_; } + void setFull(bool full_) { full = full_; } void incIndent() { if (minify) return; @@ -64,6 +67,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void printFullLine(Expression *expression) { !minify && doIndent(o, indent); + if (full) { + o << "[" << printWasmType(expression->type) << "] "; + } visit(expression); o << maybeNewLine; } @@ -79,10 +85,6 @@ struct PrintSExpression : public Visitor<PrintSExpression> { return name; } - Name printableGlobal(Index index) { - return currModule->getGlobal(index)->name; - } - std::ostream& printName(Name name) { // we need to quote names if they have tricky chars if (strpbrk(name.str, "()")) { @@ -99,6 +101,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { while (1) { if (stack.size() > 0) doIndent(o, indent); stack.push_back(curr); + if (full) { + o << "[" << printWasmType(curr->type) << "] "; + } printOpening(o, "block"); if (curr->name.is()) { o << ' '; @@ -135,13 +140,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { incIndent(); printFullLine(curr->condition); // ifTrue and False have implict blocks, avoid printing them if possible - if (!fullAST && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) { + if (!full && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) { printFullLine(curr->ifTrue->dynCast<Block>()->list.back()); } else { printFullLine(curr->ifTrue); } if (curr->ifFalse) { - if (!fullAST && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) { + if (!full && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) { printFullLine(curr->ifFalse->dynCast<Block>()->list.back()); } else { printFullLine(curr->ifFalse); @@ -151,16 +156,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitLoop(Loop *curr) { printOpening(o, "loop"); - if (curr->out.is()) { - o << ' ' << curr->out; - assert(curr->in.is()); // if just one is printed, it must be the in - } - if (curr->in.is()) { - o << ' ' << curr->in; + if (curr->name.is()) { + o << ' ' << curr->name; } incIndent(); auto block = curr->body->dynCast<Block>(); - if (!fullAST && block && block->name.isNull()) { + if (!full && block && block->name.isNull()) { // wasm spec has loops containing children directly, while our ast // has a single child for simplicity. print out the optimal form. for (auto expression : block->list) { @@ -229,26 +230,33 @@ struct PrintSExpression : public Visitor<PrintSExpression> { void visitCallIndirect(CallIndirect *curr) { printOpening(o, "call_indirect ") << curr->fullType; incIndent(); - printFullLine(curr->target); for (auto operand : curr->operands) { printFullLine(operand); } + printFullLine(curr->target); decIndent(); } void visitGetLocal(GetLocal *curr) { printOpening(o, "get_local ") << printableLocal(curr->index) << ')'; } void visitSetLocal(SetLocal *curr) { - printOpening(o, "set_local ") << printableLocal(curr->index); + if (curr->isTee()) { + printOpening(o, "tee_local "); + } else { + printOpening(o, "set_local "); + } + o << printableLocal(curr->index); incIndent(); printFullLine(curr->value); decIndent(); } void visitGetGlobal(GetGlobal *curr) { - printOpening(o, "get_global ") << printableGlobal(curr->index) << ')'; + printOpening(o, "get_global "); + printName(curr->name) << ')'; } void visitSetGlobal(SetGlobal *curr) { - printOpening(o, "set_global ") << printableGlobal(curr->index); + printOpening(o, "set_global "); + printName(curr->name); incIndent(); printFullLine(curr->value); decIndent(); @@ -281,7 +289,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitStore(Store *curr) { o << '('; - prepareColor(o) << printWasmType(curr->type) << ".store"; + prepareColor(o) << printWasmType(curr->valueType) << ".store"; if (curr->bytes < 4 || (curr->type == i64 && curr->bytes < 8)) { if (curr->bytes == 1) { o << '8'; @@ -466,9 +474,16 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printFullLine(curr->condition); decIndent(); } + void visitDrop(Drop *curr) { + o << '('; + prepareColor(o) << "drop"; + incIndent(); + printFullLine(curr->value); + decIndent(); + } void visitReturn(Return *curr) { printOpening(o, "return"); - if (!curr->value || curr->value->is<Nop>()) { + if (!curr->value) { // avoid a new line just for the parens o << ')'; return; @@ -499,11 +514,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printMinorOpening(o, "unreachable") << ')'; } // Module-level visitors - void visitFunctionType(FunctionType *curr, bool full=false) { - if (full) { - printOpening(o, "type") << ' '; - printName(curr->name) << " (func"; - } + void visitFunctionType(FunctionType *curr, Name* internalName = nullptr) { + o << "(func"; + if (internalName) o << ' ' << *internalName; if (curr->params.size() > 0) { o << maybeSpace; printMinorOpening(o, "param"); @@ -516,27 +529,39 @@ struct PrintSExpression : public Visitor<PrintSExpression> { o << maybeSpace; printMinorOpening(o, "result ") << printWasmType(curr->result) << ')'; } - if (full) { - o << "))"; - } + o << ")"; } void visitImport(Import *curr) { printOpening(o, "import "); - printName(curr->name) << ' '; printText(o, curr->module.str) << ' '; - printText(o, curr->base.str); - if (curr->type) visitFunctionType(curr->type); + printText(o, curr->base.str) << ' '; + switch (curr->kind) { + case Export::Function: if (curr->functionType) visitFunctionType(curr->functionType, &curr->name); break; + case Export::Table: o << "(table " << curr->name << ")"; break; + case Export::Memory: o << "(memory " << curr->name << ")"; break; + case Export::Global: o << "(global " << curr->name << ' ' << printWasmType(curr->globalType) << ")"; break; + default: WASM_UNREACHABLE(); + } o << ')'; } void visitExport(Export *curr) { printOpening(o, "export "); - printText(o, curr->name.str) << ' '; - printName(curr->value) << ')'; + printText(o, curr->name.str) << " ("; + switch (curr->kind) { + case Export::Function: o << "func"; break; + case Export::Table: o << "table"; break; + case Export::Memory: o << "memory"; break; + case Export::Global: o << "global"; break; + default: WASM_UNREACHABLE(); + } + o << ' '; + printName(curr->value) << "))"; } void visitGlobal(Global *curr) { printOpening(o, "global "); - printName(curr->name) << ' ' << printWasmType(curr->type); - printFullLine(curr->init); + printName(curr->name) << ' '; + o << printWasmType(curr->type) << ' '; + visit(curr->init); o << ')'; } void visitFunction(Function *curr) { @@ -564,7 +589,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } // It is ok to emit a block here, as a function can directly contain a list, even if our // ast avoids that for simplicity. We can just do that optimization here.. - if (!fullAST && curr->body->is<Block>() && curr->body->cast<Block>()->name.isNull()) { + if (!full && curr->body->is<Block>() && curr->body->cast<Block>()->name.isNull()) { Block* block = curr->body->cast<Block>(); for (auto item : block->list) { printFullLine(item); @@ -575,7 +600,8 @@ struct PrintSExpression : public Visitor<PrintSExpression> { decIndent(); } void visitTable(Table *curr) { - printOpening(o, "table") << ' ' << curr->initial; + printOpening(o, "table") << ' '; + o << curr->initial; if (curr->max && curr->max != Table::kMaxSize) o << ' ' << curr->max; o << " anyfunc)\n"; doIndent(o, indent); @@ -589,15 +615,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> { o << ')'; } } - void visitModule(Module *curr) { - currModule = curr; - printOpening(o, "module", true); - incIndent(); - doIndent(o, indent); - printOpening(o, "memory") << ' ' << curr->memory.initial; - if (curr->memory.max && curr->memory.max != Memory::kMaxSize) o << ' ' << curr->memory.max; + void visitMemory(Memory* curr) { + printOpening(o, "memory") << ' '; + o << curr->initial; + if (curr->max && curr->max != Memory::kMaxSize) o << ' ' << curr->max; o << ")\n"; - for (auto segment : curr->memory.segments) { + for (auto segment : curr->segments) { doIndent(o, indent); printOpening(o, "data ", true); visit(segment.offset); @@ -624,12 +647,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } o << "\")\n"; } - if (curr->memory.exportName.is()) { - doIndent(o, indent); - printOpening(o, "export "); - printText(o, curr->memory.exportName.str) << " memory)"; - o << maybeNewLine; - } + } + void visitModule(Module *curr) { + currModule = curr; + printOpening(o, "module", true); + incIndent(); + doIndent(o, indent); + visitMemory(&curr->memory); if (curr->start.is()) { doIndent(o, indent); printOpening(o, "start") << ' ' << curr->start << ')'; @@ -637,8 +661,10 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } for (auto& child : curr->functionTypes) { doIndent(o, indent); - visitFunctionType(child.get(), true); - o << maybeNewLine; + printOpening(o, "type") << ' '; + printName(child->name) << ' '; + visitFunctionType(child.get()); + o << ")" << maybeNewLine; } for (auto& child : curr->imports) { doIndent(o, indent); @@ -707,7 +733,7 @@ public: void run(PassRunner* runner, Module* module) override { PrintSExpression print(o); - print.setFullAST(true); + print.setFull(true); print.visitModule(module); } }; @@ -718,9 +744,17 @@ Pass *createFullPrinterPass() { // Print individual expressions -std::ostream& WasmPrinter::printExpression(Expression* expression, std::ostream& o, bool minify) { +std::ostream& WasmPrinter::printExpression(Expression* expression, std::ostream& o, bool minify, bool full) { + if (!expression) { + o << "(null expression)"; + return o; + } PrintSExpression print(o); print.setMinify(minify); + if (full) { + print.setFull(true); + o << "[" << printWasmType(expression->type) << "] "; + } print.visit(expression); return o; } diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 0b3f50049..19d6c3eb1 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -37,7 +37,7 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports, Visitor<Remov } void visitCallImport(CallImport *curr) { - WasmType type = module->getImport(curr->target)->type->result; + WasmType type = module->getImport(curr->target)->functionType->result; if (type == none) { replaceCurrent(allocator->alloc<Nop>()); } else { diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index 263cea655..59d3af6fc 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -189,7 +189,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R // finally, we may have simplified ifs enough to turn them into selects struct Selectifier : public WalkerPass<PostWalker<Selectifier, Visitor<Selectifier>>> { void visitIf(If* curr) { - if (curr->ifFalse) { + if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) { // if with else, consider turning it into a select if there is no control flow // TODO: estimate cost EffectAnalyzer condition(curr->condition); diff --git a/src/passes/RemoveUnusedNames.cpp b/src/passes/RemoveUnusedNames.cpp index 9c6743479..8e24f9549 100644 --- a/src/passes/RemoveUnusedNames.cpp +++ b/src/passes/RemoveUnusedNames.cpp @@ -76,34 +76,12 @@ struct RemoveUnusedNames : public WalkerPass<PostWalker<RemoveUnusedNames, Visit } } handleBreakTarget(curr->name); - if (curr->name.is() && curr->list.size() == 1) { - auto* child = curr->list[0]->dynCast<Loop>(); - if (child && !child->out.is()) { - // we have just one child, this loop, and it lacks an out label. So this block's name is doing just that! - child->out = curr->name; - replaceCurrent(child); - } - } } void visitLoop(Loop *curr) { - handleBreakTarget(curr->in); - // Loops can have just 'in', but cannot have just 'out' - auto out = curr->out; - handleBreakTarget(curr->out); - if (curr->out.is() && !curr->in.is()) { - auto* block = getModule()->allocator.alloc<Block>(); - block->name = out; - block->list.push_back(curr->body); - replaceCurrent(block); - } - if (curr->in.is() && !curr->out.is()) { - auto* child = curr->body->dynCast<Block>(); - if (child && child->name.is()) { - // we have just one child, this block, and we lack an out label. So we can take the block's! - curr->out = child->name; - child->name = Name(); - } + handleBreakTarget(curr->name); + if (!curr->name.is()) { + replaceCurrent(curr->body); } } diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index d9786e62c..5315edac4 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -55,7 +55,13 @@ struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemo void visitSetLocal(SetLocal *curr) { if ((*numGetLocals)[curr->index] == 0) { - replaceCurrent(curr->value); + auto* value = curr->value; + if (curr->isTee()) { + replaceCurrent(value); + } else { + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(curr); + drop->value = value; + } } } }; @@ -180,7 +186,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, auto found = sinkables.find(curr->index); if (found != sinkables.end()) { // sink it, and nop the origin - replaceCurrent(*found->second.item); + auto* set = (*found->second.item)->cast<SetLocal>(); + replaceCurrent(set); + assert(!set->isTee()); + set->setTee(true); // reuse the getlocal that is dying *found->second.item = curr; ExpressionManipulator::nop(curr); @@ -189,6 +198,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, } } + void visitDrop(Drop* curr) { + // collapse drop-tee into set, which can occur if a get was sunk into a tee + auto* set = curr->value->dynCast<SetLocal>(); + if (set) { + assert(set->isTee()); + set->setTee(false); + replaceCurrent(set); + } + } + void checkInvalidations(EffectAnalyzer& effects) { // TODO: this is O(bad) std::vector<Index> invalidated; @@ -225,7 +244,11 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, // store is dead, leave just the value auto found = self->sinkables.find(set->index); if (found != self->sinkables.end()) { - *found->second.item = (*found->second.item)->cast<SetLocal>()->value; + auto* previous = (*found->second.item)->cast<SetLocal>(); + assert(!previous->isTee()); + auto* previousValue = previous->value; + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(previous); + drop->value = previousValue; self->sinkables.erase(found); self->anotherCycle = true; } @@ -236,15 +259,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, self->checkInvalidations(effects); } - if (set) { - // we may be a replacement for the current node, update the stack - self->expressionStack.pop_back(); - self->expressionStack.push_back(set); - if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) { - Index index = set->index; - assert(self->sinkables.count(index) == 0); - self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); - } + if (set && !set->isTee()) { + Index index = set->index; + assert(self->sinkables.count(index) == 0); + self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); } self->expressionStack.pop_back(); diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index 0195427d7..db42a994a 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -25,13 +25,11 @@ namespace wasm { -struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { +struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new Vacuum; } - std::vector<Expression*> expressionStack; - // returns nullptr if curr is dead, curr if it must stay as is, or another node if it can be replaced Expression* optimize(Expression* curr, bool resultUsed) { while (1) { @@ -41,6 +39,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { case Expression::Id::BlockId: return curr; // not always needed, but handled in visitBlock() case Expression::Id::IfId: return curr; // not always needed, but handled in visitIf() case Expression::Id::LoopId: return curr; // not always needed, but handled in visitLoop() + case Expression::Id::DropId: return curr; // not always needed, but handled in visitDrop() case Expression::Id::BreakId: case Expression::Id::SwitchId: @@ -51,6 +50,8 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { case Expression::Id::LoadId: case Expression::Id::StoreId: case Expression::Id::ReturnId: + case Expression::Id::GetGlobalId: + case Expression::Id::SetGlobalId: case Expression::Id::HostId: case Expression::Id::UnreachableId: return curr; // always needed @@ -189,7 +190,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { // no else if (curr->ifTrue->is<Nop>()) { // no nothing - replaceCurrent(curr->condition); + replaceCurrent(Builder(*getModule()).makeDrop(curr->condition)); } } } @@ -198,21 +199,30 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { if (curr->body->is<Nop>()) ExpressionManipulator::nop(curr); } - static void visitPre(Vacuum* self, Expression** currp) { - self->expressionStack.push_back(*currp); - } - - static void visitPost(Vacuum* self, Expression** currp) { - self->expressionStack.pop_back(); - } - - // override scan to add a pre and a post check task to all nodes - static void scan(Vacuum* self, Expression** currp) { - self->pushTask(visitPost, currp); - - WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>>::scan(self, currp); - - self->pushTask(visitPre, currp); + void visitDrop(Drop* curr) { + // if the drop input has no side effects, it can be wiped out + if (!EffectAnalyzer(curr->value).hasSideEffects()) { + ExpressionManipulator::nop(curr); + return; + } + // sink a drop into an arm of an if-else if the other arm ends in an unreachable, as it if is a branch, this can make that branch optimizable and more vaccuming possible + auto* iff = curr->value->dynCast<If>(); + if (iff && iff->ifFalse && isConcreteWasmType(iff->type)) { + // reuse the drop in both cases + if (iff->ifTrue->type == unreachable) { + assert(isConcreteWasmType(iff->ifFalse->type)); + curr->value = iff->ifFalse; + iff->ifFalse = curr; + iff->type = none; + replaceCurrent(iff); + } else if (iff->ifFalse->type == unreachable) { + assert(isConcreteWasmType(iff->ifTrue->type)); + curr->value = iff->ifTrue; + iff->ifTrue = curr; + iff->type = none; + replaceCurrent(iff); + } + } } void visitFunction(Function* curr) { diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index f27eccf6a..437b023bc 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -64,9 +64,7 @@ void PassRegistry::registerPasses() { registerPass("coalesce-locals", "reduce # of locals by coalescing", createCoalesceLocalsPass); registerPass("coalesce-locals-learning", "reduce # of locals by coalescing and learning", createCoalesceLocalsWithLearningPass); registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass); - registerPass("drop-return-values", "stops relying on return values from set_local and store", createDropReturnValuesPass); registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); - registerPass("lower-if-else", "lowers if-elses into ifs, blocks and branches", createLowerIfElsePass); registerPass("merge-blocks", "merges blocks to their parents", createMergeBlocksPass); registerPass("metrics", "reports metrics", createMetricsPass); registerPass("nm", "name list", createNameListPass); @@ -211,6 +209,9 @@ void PassRunner::run() { } void PassRunner::runFunction(Function* func) { + if (debug) { + std::cerr << "[PassRunner] running passes on function " << func->name << std::endl; + } for (auto* pass : passes) { runPassOnFunction(pass, func); } diff --git a/src/passes/passes.h b/src/passes/passes.h index 731536d7d..4bb76edad 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -25,7 +25,6 @@ class Pass; Pass *createCoalesceLocalsPass(); Pass *createCoalesceLocalsWithLearningPass(); Pass *createDeadCodeEliminationPass(); -Pass *createDropReturnValuesPass(); Pass *createDuplicateFunctionEliminationPass(); Pass *createLowerIfElsePass(); Pass *createMergeBlocksPass(); diff --git a/src/s2wasm.h b/src/s2wasm.h index a23cbb549..8c04f7891 100644 --- a/src/s2wasm.h +++ b/src/s2wasm.h @@ -753,6 +753,7 @@ class S2WasmBuilder { set->index = func->getLocalIndex(assign); set->value = curr; set->type = curr->type; + set->setTee(false); addToBlock(set); } }; @@ -834,7 +835,7 @@ class S2WasmBuilder { auto makeStore = [&](WasmType type) { skipComma(); auto curr = allocator->alloc<Store>(); - curr->type = type; + curr->valueType = type; int32_t bytes = getInt() / CHAR_BIT; curr->bytes = bytes > 0 ? bytes : getWasmTypeSize(type); Name assign = getAssign(); @@ -849,6 +850,7 @@ class S2WasmBuilder { curr->align = 1U << getInt(attributes[0] + 8); } curr->value = inputs[1]; + curr->finalize(); setOutput(curr, assign); }; auto makeSelect = [&](WasmType type) { @@ -1062,7 +1064,7 @@ class S2WasmBuilder { if (target->is<Block>()) { return target->cast<Block>()->name; } else { - return target->cast<Loop>()->in; + return target->cast<Loop>()->name; } }; // fixups @@ -1097,10 +1099,9 @@ class S2WasmBuilder { } else if (match("loop")) { auto curr = allocator->alloc<Loop>(); addToBlock(curr); - curr->in = getNextLabel(); - curr->out = getNextLabel(); + curr->name = getNextLabel(); auto block = allocator->alloc<Block>(); - block->name = curr->out; // temporary, fake - this way, on bstack we have the right label at the right offset for a br + block->name = getNextLabel(); curr->body = block; loopBlocks.push_back(block); bstack.push_back(block); @@ -1146,6 +1147,7 @@ class S2WasmBuilder { Name assign = getAssign(); skipComma(); auto curr = allocator->alloc<SetLocal>(); + curr->setTee(true); curr->index = func->getLocalIndex(getAssign()); skipComma(); curr->value = getInput(); diff --git a/src/shell-interface.h b/src/shell-interface.h index f332307ad..ee9ff166a 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -90,11 +90,11 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { ShellExternalInterface() : memory() {} - void init(Module& wasm) override { + void init(Module& wasm, ModuleInstance& instance) override { memory.resize(wasm.memory.initial * wasm::Memory::kPageSize); // apply memory segments for (auto& segment : wasm.memory.segments) { - Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32(); + Address offset = ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32(); assert(offset + segment.data.size() <= wasm.memory.initial * wasm::Memory::kPageSize); for (size_t i = 0; i != segment.data.size(); ++i) { memory.set(offset + i, segment.data[i]); @@ -103,7 +103,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { table.resize(wasm.table.initial); for (auto& segment : wasm.table.segments) { - Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32(); + Address offset = ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32(); assert(offset + segment.data.size() <= wasm.table.initial); for (size_t i = 0; i != segment.data.size(); ++i) { table[offset + i] = segment.data[i]; @@ -111,6 +111,8 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } } + void importGlobals(std::map<Name, Literal>& globals, Module& wasm) override {} + Literal callImport(Import *import, LiteralList& arguments) override { if (import->module == SPECTEST && import->base == PRINT) { for (auto argument : arguments) { @@ -126,10 +128,9 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { abort(); } - Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) override { + Literal callTable(Index index, LiteralList& arguments, WasmType result, ModuleInstance& instance) override { if (index >= table.size()) trap("callTable overflow"); auto* func = instance.wasm.getFunction(table[index]); - if (func->type.is() && func->type != type) trap("callIndirect: bad type"); if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments"); for (size_t i = 0; i < func->params.size(); i++) { if (func->params[i] != arguments[i].type) { @@ -167,7 +168,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } void store(Store* store, Address addr, Literal value) override { - switch (store->type) { + switch (store->valueType) { case i32: { switch (store->bytes) { case 1: memory.set<int8_t>(addr, value.geti32()); break; diff --git a/src/support/colors.h b/src/support/colors.h index d81ecbc16..fb5267ce1 100644 --- a/src/support/colors.h +++ b/src/support/colors.h @@ -42,6 +42,15 @@ inline void grey(std::ostream& stream) { outputColorCode(stream, 0x08); } inline void green(std::ostream& stream) { outputColorCode(stream, 0x02); } inline void blue(std::ostream& stream) { outputColorCode(stream, 0x09); } inline void bold(std::ostream& stream) { /* Do nothing */ } +#else +inline void normal(std::ostream& stream) {} +inline void red(std::ostream& stream) {} +inline void magenta(std::ostream& stream) {} +inline void orange(std::ostream& stream) {} +inline void grey(std::ostream& stream) {} +inline void green(std::ostream& stream) {} +inline void blue(std::ostream& stream) {} +inline void bold(std::ostream& stream) {} #endif }; diff --git a/src/tools/asm2wasm.cpp b/src/tools/asm2wasm.cpp index 7f1a07d1d..031b1fd23 100644 --- a/src/tools/asm2wasm.cpp +++ b/src/tools/asm2wasm.cpp @@ -21,6 +21,7 @@ #include "support/colors.h" #include "support/command-line.h" #include "support/file.h" +#include "wasm-builder.h" #include "wasm-printing.h" #include "asm2wasm.h" @@ -40,9 +41,13 @@ int main(int argc, const char *argv[]) { o->extra["output"] = argument; Colors::disable(); }) - .add("--mapped-globals", "-m", "Mapped globals", Options::Arguments::One, + .add("--mapped-globals", "-n", "Mapped globals", Options::Arguments::One, [](Options *o, const std::string &argument) { - o->extra["mapped globals"] = argument; + std::cerr << "warning: the --mapped-globals/-m option is deprecated (a mapped globals file is no longer needed as we use wasm globals)" << std::endl; + }) + .add("--mem-init", "-t", "Import a memory initialization file into the output module", Options::Arguments::One, + [](Options *o, const std::string &argument) { + o->extra["mem init"] = argument; }) .add("--total-memory", "-m", "Total memory size", Options::Arguments::One, [](Options *o, const std::string &argument) { @@ -62,10 +67,6 @@ int main(int argc, const char *argv[]) { }); options.parse(argc, argv); - const auto &mg_it = options.extra.find("mapped globals"); - const char *mappedGlobals = - mg_it == options.extra.end() ? nullptr : mg_it->second.c_str(); - const auto &tm_it = options.extra.find("total memory"); size_t totalMemory = tm_it == options.extra.end() ? 16 * 1024 * 1024 : atoi(tm_it->second.c_str()); @@ -95,15 +96,18 @@ int main(int argc, const char *argv[]) { Asm2WasmBuilder asm2wasm(wasm, pre.memoryGrowth, options.debug, imprecise, opts); asm2wasm.processAsm(asmjs); + // import mem init file, if provided + const auto &memInit = options.extra.find("mem init"); + if (memInit != options.extra.end()) { + auto filename = memInit->second.c_str(); + auto data(read_file<std::vector<char>>(filename, Flags::Binary, options.debug ? Flags::Debug : Flags::Release)); + // create the memory segment + wasm.memory.segments.emplace_back(Builder(wasm).makeGetGlobal(Name("memoryBase"), i32), data); + } + if (options.debug) std::cerr << "printing..." << std::endl; Output output(options.extra["output"], Flags::Text, options.debug ? Flags::Debug : Flags::Release); WasmPrinter::printModule(&wasm, output.getStream()); - if (mappedGlobals) { - if (options.debug) - std::cerr << "serializing mapped globals..." << std::endl; - asm2wasm.serializeMappedGlobals(mappedGlobals); - } - if (options.debug) std::cerr << "done." << std::endl; } diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 72e8e13ba..a95ced87f 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -125,14 +125,19 @@ static void run_asserts(size_t* i, bool* checked, Module* wasm, // maybe parsed ok, but otherwise incorrect invalid = !WasmValidator().validate(wasm); } - assert(invalid); + if (!invalid) { + Colors::red(std::cerr); + std::cerr << "[should have been invalid]\n"; + Colors::normal(std::cerr); + std::cerr << &wasm << '\n'; + abort(); + } } else if (id == INVOKE) { assert(wasm); Invocation invocation(curr, instance.get(), *builder->get()); invocation.invoke(); - } else { + } else if (wasm) { // if no wasm, we skipped the module // an invoke test - assert(wasm); bool trapped = false; WASM_UNUSED(trapped); Literal result; @@ -169,6 +174,7 @@ static void run_asserts(size_t* i, bool* checked, Module* wasm, int main(int argc, const char* argv[]) { Name entry; + std::set<size_t> skipped; Options options("wasm-shell", "Execute .wast files"); options @@ -176,6 +182,21 @@ int main(int argc, const char* argv[]) { "--entry", "-e", "call the entry point after parsing the module", Options::Arguments::One, [&entry](Options*, const std::string& argument) { entry = argument; }) + .add( + "--skip", "-s", "skip input on certain lines (comma-separated-list)", + Options::Arguments::One, + [&skipped](Options*, const std::string& argument) { + size_t i = 0; + while (i < argument.size()) { + auto ending = argument.find(',', i); + if (ending == std::string::npos) { + ending = argument.size(); + } + auto sub = argument.substr(i, ending - i); + skipped.insert(atoi(sub.c_str())); + i = ending + 1; + } + }) .add_positional("INFILE", Options::Arguments::One, [](Options* o, const std::string& argument) { o->extra["infile"] = argument; @@ -195,9 +216,19 @@ int main(int argc, const char* argv[]) { size_t i = 0; while (i < root.size()) { Element& curr = *root[i]; + if (skipped.count(curr.line) > 0) { + Colors::green(std::cerr); + std::cerr << "SKIPPING [line: " << curr.line << "]\n"; + Colors::normal(std::cerr); + i++; + continue; + } IString id = curr[0]->str(); if (id == MODULE) { if (options.debug) std::cerr << "parsing s-expressions to wasm...\n"; + Colors::green(std::cerr); + std::cerr << "BUILDING MODULE [line: " << curr.line << "]\n"; + Colors::normal(std::cerr); Module wasm; std::unique_ptr<SExpressionWasmBuilder> builder; builder = wasm::make_unique<SExpressionWasmBuilder>(wasm, *root[i]); diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 04bebddc5..40671b082 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -234,7 +234,7 @@ namespace BinaryConsts { enum Meta { Magic = 0x6d736100, - Version = 11 + Version = 0x0c }; namespace Section { @@ -412,6 +412,7 @@ enum ASTNodes { CallFunction = 0x16, CallIndirect = 0x17, CallImport = 0x18, + TeeLocal = 0x19, GetGlobal = 0x1a, SetGlobal = 0x1b, @@ -426,6 +427,7 @@ enum ASTNodes { TableSwitch = 0x08, Return = 0x09, Unreachable = 0x0a, + Drop = 0x0b, End = 0x0f }; @@ -529,8 +531,7 @@ public: if (debug) std::cerr << "== writeMemory" << std::endl; auto start = startSection(BinaryConsts::Section::Memory); o << U32LEB(wasm->memory.initial) - << U32LEB(wasm->memory.max) - << int8_t(wasm->memory.exportName.is()); // export memory + << U32LEB(wasm->memory.max); finishSection(start); } @@ -571,7 +572,14 @@ public: o << U32LEB(wasm->imports.size()); for (auto& import : wasm->imports) { if (debug) std::cerr << "write one" << std::endl; - o << U32LEB(getFunctionTypeIndex(import->type->name)); + o << U32LEB(import->kind); + switch (import->kind) { + case Export::Function: o << U32LEB(getFunctionTypeIndex(import->functionType->name)); + case Export::Table: break; + case Export::Memory: break; + case Export::Global: o << binaryWasmType(import->globalType);break; + default: WASM_UNREACHABLE(); + } writeInlineString(import->module.str); writeInlineString(import->base.str); } @@ -690,7 +698,14 @@ public: o << U32LEB(wasm->exports.size()); for (auto& curr : wasm->exports) { if (debug) std::cerr << "write one" << std::endl; - o << U32LEB(getFunctionIndex(curr->value)); + o << U32LEB(curr->kind); + switch (curr->kind) { + case Export::Function: o << U32LEB(getFunctionIndex(curr->value)); break; + case Export::Table: o << U32LEB(0); break; + case Export::Memory: o << U32LEB(0); break; + case Export::Global: o << U32LEB(getGlobalIndex(curr->value)); break; + default: WASM_UNREACHABLE(); + } writeInlineString(curr->name.str); } finishSection(start); @@ -726,7 +741,7 @@ public: return mappedImports[name]; } - std::map<Name, uint32_t> mappedFunctions; // name of the Function => index + std::map<Name, uint32_t> mappedFunctions; // name of the Function => index uint32_t getFunctionIndex(Name name) { if (!mappedFunctions.size()) { // Create name => index mapping. @@ -739,6 +754,26 @@ public: return mappedFunctions[name]; } + std::map<Name, uint32_t> mappedGlobals; // name of the Global => index. first imported globals, then internal globals + uint32_t getGlobalIndex(Name name) { + if (!mappedGlobals.size()) { + // Create name => index mapping. + for (auto& import : wasm->imports) { + if (import->kind != Import::Global) continue; + assert(mappedGlobals.count(import->name) == 0); + auto index = mappedGlobals.size(); + mappedGlobals[import->name] = index; + } + for (size_t i = 0; i < wasm->globals.size(); i++) { + assert(mappedGlobals.count(wasm->globals[i]->name) == 0); + auto index = mappedGlobals.size(); + mappedGlobals[wasm->globals[i]->name] = index; + } + } + assert(mappedGlobals.count(name)); + return mappedGlobals[name]; + } + void writeFunctionTable() { if (wasm->table.segments.size() == 0) return; if (debug) std::cerr << "== writeFunctionTable" << std::endl; @@ -873,11 +908,9 @@ public: void visitLoop(Loop *curr) { if (debug) std::cerr << "zz node: Loop" << std::endl; o << int8_t(BinaryConsts::Loop); - breakStack.push_back(curr->out); - breakStack.push_back(curr->in); + breakStack.push_back(curr->name); recursePossibleBlockContents(curr->body); breakStack.pop_back(); - breakStack.pop_back(); o << int8_t(BinaryConsts::End); } @@ -939,18 +972,18 @@ public: o << int8_t(BinaryConsts::GetLocal) << U32LEB(mappedLocals[curr->index]); } void visitSetLocal(SetLocal *curr) { - if (debug) std::cerr << "zz node: SetLocal" << std::endl; + if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl; recurse(curr->value); - o << int8_t(BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]); + o << int8_t(curr->isTee() ? BinaryConsts::TeeLocal : BinaryConsts::SetLocal) << U32LEB(mappedLocals[curr->index]); } void visitGetGlobal(GetGlobal *curr) { if (debug) std::cerr << "zz node: GetGlobal " << (o.size() + 1) << std::endl; - o << int8_t(BinaryConsts::GetGlobal) << U32LEB(curr->index); + o << int8_t(BinaryConsts::GetGlobal) << U32LEB(getGlobalIndex(curr->name)); } void visitSetGlobal(SetGlobal *curr) { if (debug) std::cerr << "zz node: SetGlobal" << std::endl; recurse(curr->value); - o << int8_t(BinaryConsts::SetGlobal) << U32LEB(curr->index); + o << int8_t(BinaryConsts::SetGlobal) << U32LEB(getGlobalIndex(curr->name)); } void emitMemoryAccess(size_t alignment, size_t bytes, uint32_t offset) { @@ -991,7 +1024,7 @@ public: if (debug) std::cerr << "zz node: Store" << std::endl; recurse(curr->ptr); recurse(curr->value); - switch (curr->type) { + switch (curr->valueType) { case i32: { switch (curr->bytes) { case 1: o << int8_t(BinaryConsts::I32StoreMem8); break; @@ -1219,6 +1252,11 @@ public: if (debug) std::cerr << "zz node: Unreachable" << std::endl; o << int8_t(BinaryConsts::Unreachable); } + void visitDrop(Drop *curr) { + if (debug) std::cerr << "zz node: Drop" << std::endl; + recurse(curr->value); + o << int8_t(BinaryConsts::Drop); + } }; class WasmBinaryBuilder { @@ -1432,10 +1470,6 @@ public: if (debug) std::cerr << "== readMemory" << std::endl; wasm.memory.initial = getU32LEB(); wasm.memory.max = getU32LEB(); - auto exportMemory = getInt8(); - if (exportMemory) { - wasm.memory.exportName = Name("memory"); - } } void readSignatures() { @@ -1472,10 +1506,20 @@ public: if (debug) std::cerr << "read one" << std::endl; auto curr = new Import; curr->name = Name(std::string("import$") + std::to_string(i)); - auto index = getU32LEB(); - assert(index < wasm.functionTypes.size()); - curr->type = wasm.getFunctionType(index); - assert(curr->type->name.is()); + curr->kind = (Import::Kind)getU32LEB(); + switch (curr->kind) { + case Export::Function: { + auto index = getU32LEB(); + assert(index < wasm.functionTypes.size()); + curr->functionType = wasm.getFunctionType(index); + assert(curr->functionType->name.is()); + break; + } + case Export::Table: break; + case Export::Memory: break; + case Export::Global: curr->globalType = getWasmType(); break; + default: WASM_UNREACHABLE(); + } curr->module = getInlineString(); curr->base = getInlineString(); wasm.addImport(curr); @@ -1573,8 +1617,8 @@ public: for (size_t i = 0; i < num; i++) { if (debug) std::cerr << "read one" << std::endl; auto curr = new Export; + curr->kind = (Export::Kind)getU32LEB(); auto index = getU32LEB(); - assert(index < functionTypes.size()); curr->name = getInlineString(); exportIndexes[curr] = index; } @@ -1627,6 +1671,24 @@ public: return ret; } + std::map<Index, Name> mappedGlobals; // index of the Global => name. first imported globals, then internal globals + Name getGlobalName(Index index) { + if (!mappedGlobals.size()) { + // Create name => index mapping. + for (auto& import : wasm.imports) { + if (import->kind != Import::Global) continue; + auto index = mappedGlobals.size(); + mappedGlobals[index] = import->name; + } + for (size_t i = 0; i < wasm.globals.size(); i++) { + auto index = mappedGlobals.size(); + mappedGlobals[index] = wasm.globals[i]->name; + } + } + assert(mappedGlobals.count(index)); + return mappedGlobals[index]; + } + void processFunctions() { for (auto& func : functions) { wasm.addFunction(func); @@ -1639,7 +1701,13 @@ public: for (auto& iter : exportIndexes) { Export* curr = iter.first; - curr->value = wasm.functions[iter.second]->name; + switch (curr->kind) { + case Export::Function: curr->value = wasm.functions[iter.second]->name; break; + case Export::Table: curr->value = Name::fromInt(0); break; + case Export::Memory: curr->value = Name::fromInt(0); break; + case Export::Global: curr->value = getGlobalName(iter.second); break; + default: WASM_UNREACHABLE(); + } wasm.addExport(curr); } @@ -1728,13 +1796,15 @@ public: case BinaryConsts::CallImport: visitCallImport((curr = allocator.alloc<CallImport>())->cast<CallImport>()); break; case BinaryConsts::CallIndirect: visitCallIndirect((curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break; case BinaryConsts::GetLocal: visitGetLocal((curr = allocator.alloc<GetLocal>())->cast<GetLocal>()); break; - case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>()); break; + case BinaryConsts::TeeLocal: + case BinaryConsts::SetLocal: visitSetLocal((curr = allocator.alloc<SetLocal>())->cast<SetLocal>(), code); break; case BinaryConsts::GetGlobal: visitGetGlobal((curr = allocator.alloc<GetGlobal>())->cast<GetGlobal>()); break; case BinaryConsts::SetGlobal: visitSetGlobal((curr = allocator.alloc<SetGlobal>())->cast<SetGlobal>()); break; case BinaryConsts::Select: visitSelect((curr = allocator.alloc<Select>())->cast<Select>()); break; case BinaryConsts::Return: visitReturn((curr = allocator.alloc<Return>())->cast<Return>()); break; case BinaryConsts::Nop: visitNop((curr = allocator.alloc<Nop>())->cast<Nop>()); break; case BinaryConsts::Unreachable: visitUnreachable((curr = allocator.alloc<Unreachable>())->cast<Unreachable>()); break; + case BinaryConsts::Drop: visitDrop((curr = allocator.alloc<Drop>())->cast<Drop>()); break; case BinaryConsts::End: case BinaryConsts::Else: curr = nullptr; break; default: { @@ -1832,13 +1902,10 @@ public: } void visitLoop(Loop *curr) { if (debug) std::cerr << "zz node: Loop" << std::endl; - curr->out = getNextLabel(); - curr->in = getNextLabel(); - breakStack.push_back(curr->out); - breakStack.push_back(curr->in); + curr->name = getNextLabel(); + breakStack.push_back(curr->name); curr->body = getMaybeBlock(); breakStack.pop_back(); - breakStack.pop_back(); curr->finalize(); } @@ -1890,7 +1957,7 @@ public: WASM_UNUSED(arity); auto import = wasm.getImport(getU32LEB()); curr->target = import->name; - auto type = import->type; + auto type = import->functionType; assert(type); auto num = type->params.size(); assert(num == arity); @@ -1922,25 +1989,35 @@ public: assert(curr->index < currFunction->getNumLocals()); curr->type = currFunction->getLocalType(curr->index); } - void visitSetLocal(SetLocal *curr) { - if (debug) std::cerr << "zz node: SetLocal" << std::endl; + void visitSetLocal(SetLocal *curr, uint8_t code) { + if (debug) std::cerr << "zz node: Set|TeeLocal" << std::endl; curr->index = getU32LEB(); assert(curr->index < currFunction->getNumLocals()); curr->value = popExpression(); curr->type = curr->value->type; + curr->setTee(code == BinaryConsts::TeeLocal); } void visitGetGlobal(GetGlobal *curr) { if (debug) std::cerr << "zz node: GetGlobal " << pos << std::endl; - curr->index = getU32LEB(); - assert(curr->index < wasm.globals.size()); - curr->type = wasm.globals[curr->index]->type; + auto index = getU32LEB(); + curr->name = getGlobalName(index); + auto* global = wasm.checkGlobal(curr->name); + if (global) { + curr->type = global->type; + return; + } + auto* import = wasm.checkImport(curr->name); + if (import && import->kind == Import::Global) { + curr->type = import->globalType; + return; + } + throw ParseException("bad get_global"); } void visitSetGlobal(SetGlobal *curr) { if (debug) std::cerr << "zz node: SetGlobal" << std::endl; - curr->index = getU32LEB(); - assert(curr->index < wasm.globals.size()); + auto index = getU32LEB(); + curr->name = getGlobalName(index); curr->value = popExpression(); - curr->type = curr->value->type; } void readMemoryAccess(Address& alignment, size_t bytes, Address& offset) { @@ -1976,21 +2053,22 @@ public: bool maybeVisitStore(Expression*& out, uint8_t code) { Store* curr; switch (code) { - case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i32; break; - case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i32; break; - case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i32; break; - case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->type = i64; break; - case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->type = i64; break; - case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = i64; break; - case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = i64; break; - case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->type = f32; break; - case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->type = f64; break; + case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i32; break; + case BinaryConsts::I32StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i32; break; + case BinaryConsts::I32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i32; break; + case BinaryConsts::I64StoreMem8: curr = allocator.alloc<Store>(); curr->bytes = 1; curr->valueType = i64; break; + case BinaryConsts::I64StoreMem16: curr = allocator.alloc<Store>(); curr->bytes = 2; curr->valueType = i64; break; + case BinaryConsts::I64StoreMem32: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = i64; break; + case BinaryConsts::I64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = i64; break; + case BinaryConsts::F32StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 4; curr->valueType = f32; break; + case BinaryConsts::F64StoreMem: curr = allocator.alloc<Store>(); curr->bytes = 8; curr->valueType = f64; break; default: return false; } if (debug) std::cerr << "zz node: Store" << std::endl; readMemoryAccess(curr->align, curr->bytes, curr->offset); curr->value = popExpression(); curr->ptr = popExpression(); + curr->finalize(); out = curr; return true; } @@ -2175,6 +2253,10 @@ public: void visitUnreachable(Unreachable *curr) { if (debug) std::cerr << "zz node: Unreachable" << std::endl; } + void visitDrop(Drop *curr) { + if (debug) std::cerr << "zz node: Drop" << std::endl; + curr->value = popExpression(); + } }; } // namespace wasm diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 22e1f9a00..2841585e6 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -82,9 +82,9 @@ public: ret->finalize(); return ret; } - Loop* makeLoop(Name out, Name in, Expression* body) { + Loop* makeLoop(Name name, Expression* body) { auto* ret = allocator.alloc<Loop>(); - ret->out = out; ret->in = in; ret->body = body; + ret->name = name; ret->body = body; ret->finalize(); return ret; } @@ -142,20 +142,26 @@ public: auto* ret = allocator.alloc<SetLocal>(); ret->index = index; ret->value = value; + ret->type = none; + return ret; + } + SetLocal* makeTeeLocal(Index index, Expression* value) { + auto* ret = allocator.alloc<SetLocal>(); + ret->index = index; + ret->value = value; ret->type = value->type; return ret; } - GetGlobal* makeGetGlobal(Index index, WasmType type) { + GetGlobal* makeGetGlobal(Name name, WasmType type) { auto* ret = allocator.alloc<GetGlobal>(); - ret->index = index; + ret->name = name; ret->type = type; return ret; } - SetGlobal* makeSetGlobal(Index index, Expression* value) { + SetGlobal* makeSetGlobal(Name name, Expression* value) { auto* ret = allocator.alloc<SetGlobal>(); - ret->index = index; + ret->name = name; ret->value = value; - ret->type = value->type; return ret; } Load* makeLoad(unsigned bytes, bool signed_, uint32_t offset, unsigned align, Expression *ptr, WasmType type) { @@ -164,10 +170,11 @@ public: ret->type = type; return ret; } - Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value) { + Store* makeStore(unsigned bytes, uint32_t offset, unsigned align, Expression *ptr, Expression *value, WasmType type) { auto* ret = allocator.alloc<Store>(); - ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value; - ret->type = value->type; + ret->bytes = bytes; ret->offset = offset; ret->align = align; ret->ptr = ptr; ret->value = value; ret->valueType = type; + ret->finalize(); + assert(isConcreteWasmType(ret->value->type) ? ret->value->type == type : true); return ret; } Const* makeConst(Literal value) { @@ -205,12 +212,22 @@ public: ret->op = op; ret->nameOperand = nameOperand; ret->operands.set(operands); + ret->finalize(); return ret; } Unreachable* makeUnreachable() { return allocator.alloc<Unreachable>(); } + // Additional helpers + + Drop* makeDrop(Expression *value) { + auto* ret = allocator.alloc<Drop>(); + ret->value = value; + ret->finalize(); + return ret; + } + // Additional utility functions for building on top of nodes static Index addParam(Function* func, Name name, WasmType type) { @@ -247,7 +264,21 @@ public: if (!block) block = makeBlock(any); if (append) { block->list.push_back(append); - block->finalize(); + block->finalize(); // TODO: move out of if + } + return block; + } + + // ensure a node is a block, if it isn't already, and optionally append to the block + // this variant sets a name for the block, so it will not reuse a block already named + Block* blockifyWithName(Expression* any, Name name, Expression* append = nullptr) { + Block* block = nullptr; + if (any) block = any->dynCast<Block>(); + if (!block || block->name.is()) block = makeBlock(any); + block->name = name; + if (append) { + block->list.push_back(append); + block->finalize(); // TODO: move out of if } return block; } diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 28a3e8e6d..292d6a521 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -31,8 +31,6 @@ #ifdef WASM_INTERPRETER_DEBUG #include "wasm-printing.h" - -int indent = 0; #endif @@ -77,32 +75,11 @@ typedef std::vector<Literal> LiteralList; // Debugging helpers #ifdef WASM_INTERPRETER_DEBUG -struct IndentHandler { - const char *name; - IndentHandler(const char *name, Expression *expression) : name(name) { - doIndent(std::cout, indent); - std::cout << "visit " << name << " :\n"; - indent++; -#if WASM_INTERPRETER_DEBUG == 2 - doIndent(std::cout, indent); - if (expression) std::cout << "\n" << expression << '\n'; - indent++; -#endif - } - ~IndentHandler() { -#if WASM_INTERPRETER_DEBUG == 2 - indent--; -#endif - indent--; - doIndent(std::cout, indent); - std::cout << "exit " << name << '\n'; - } -}; -#define NOTE_ENTER(x) IndentHandler indentHandler(x, curr); -#define NOTE_ENTER_(x) IndentHandler indentHandler(x, nullptr); -#define NOTE_NAME(p0) { doIndent(std::cout, indent); std::cout << "name in " << indentHandler.name << '(' << Name(p0) << ")\n"; } -#define NOTE_EVAL1(p0) { doIndent(std::cout, indent); std::cout << "eval in " << indentHandler.name << '(' << p0 << ")\n"; } -#define NOTE_EVAL2(p0, p1) { doIndent(std::cout, indent); std::cout << "eval in " << indentHandler.name << '(' << p0 << ", " << p1 << ")\n"; } +#define NOTE_ENTER(x) { std::cout << "visit " << x << " : " << curr << "\n"; } +#define NOTE_ENTER_(x) { std::cout << "visit " << x << "\n"; } +#define NOTE_NAME(p0) { std::cout << "name " << '(' << Name(p0) << ")\n"; } +#define NOTE_EVAL1(p0) { std::cout << "eval " << '(' << p0 << ")\n"; } +#define NOTE_EVAL2(p0, p1) { std::cout << "eval " << '(' << p0 << ", " << p1 << ")\n"; } #else // WASM_INTERPRETER_DEBUG #define NOTE_ENTER(x) #define NOTE_ENTER_(x) @@ -170,8 +147,7 @@ public: while (1) { Flow flow = visit(curr->body); if (flow.breaking()) { - if (flow.breakTo == curr->in) continue; // lol - flow.clearIf(curr->out); + if (flow.breakTo == curr->name) continue; // lol } return flow; // loop does not loop automatically, only continue achieves that } @@ -435,6 +411,12 @@ public: NOTE_EVAL1(condition.value); return condition.value.geti32() ? ifTrue : ifFalse; // ;-) } + Flow visitDrop(Drop *curr) { + NOTE_ENTER("Drop"); + Flow value = visit(curr->value); + if (value.breaking()) return value; + return Flow(); + } Flow visitReturn(Return *curr) { NOTE_ENTER("Return"); Flow flow; @@ -503,14 +485,17 @@ public: // Execute an constant expression in a global init or memory offset class ConstantExpressionRunner : public ExpressionRunner<ConstantExpressionRunner> { + std::map<Name, Literal>& globals; public: + ConstantExpressionRunner(std::map<Name, Literal>& globals) : globals(globals) {} + Flow visitLoop(Loop* curr) { WASM_UNREACHABLE(); } Flow visitCall(Call* curr) { WASM_UNREACHABLE(); } Flow visitCallImport(CallImport* curr) { WASM_UNREACHABLE(); } Flow visitCallIndirect(CallIndirect* curr) { WASM_UNREACHABLE(); } Flow visitGetLocal(GetLocal *curr) { WASM_UNREACHABLE(); } Flow visitSetLocal(SetLocal *curr) { WASM_UNREACHABLE(); } - Flow visitGetGlobal(GetGlobal *curr) { WASM_UNREACHABLE(); } + Flow visitGetGlobal(GetGlobal *curr) { return Flow(globals[curr->name]); } Flow visitSetGlobal(SetGlobal *curr) { WASM_UNREACHABLE(); } Flow visitLoad(Load *curr) { WASM_UNREACHABLE(); } Flow visitStore(Store *curr) { WASM_UNREACHABLE(); } @@ -535,9 +520,10 @@ public: // an imported function or accessing memory. // struct ExternalInterface { - virtual void init(Module& wasm) {} + virtual void init(Module& wasm, ModuleInstance& instance) {} + virtual void importGlobals(std::map<Name, Literal>& globals, Module& wasm) = 0; virtual Literal callImport(Import* import, LiteralList& arguments) = 0; - virtual Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) = 0; + virtual Literal callTable(Index index, LiteralList& arguments, WasmType result, ModuleInstance& instance) = 0; virtual Literal load(Load* load, Address addr) = 0; virtual void store(Store* store, Address addr, Literal value) = 0; virtual void growMemory(Address oldSize, Address newSize) = 0; @@ -547,14 +533,20 @@ public: Module& wasm; // Values of globals - std::vector<Literal> globals; + std::map<Name, Literal> globals; ModuleInstance(Module& wasm, ExternalInterface* externalInterface) : wasm(wasm), externalInterface(externalInterface) { + // import globals from the outside + externalInterface->importGlobals(globals, wasm); + // prepare memory memorySize = wasm.memory.initial; - for (Index i = 0; i < wasm.globals.size(); i++) { - globals.push_back(ConstantExpressionRunner().visit(wasm.globals[i]->init).value); + // generate internal (non-imported) globals + for (auto& global : wasm.globals) { + globals[global->name] = ConstantExpressionRunner(globals).visit(global->init).value; } - externalInterface->init(wasm); + // initialize the rest of the external interface + externalInterface->init(wasm, *this); + // run start, if present if (wasm.start.is()) { LiteralList arguments; callFunction(wasm.start, arguments); @@ -670,13 +662,13 @@ public: } Flow visitCallIndirect(CallIndirect *curr) { NOTE_ENTER("CallIndirect"); - Flow target = visit(curr->target); - if (target.breaking()) return target; LiteralList arguments; Flow flow = generateArguments(curr->operands, arguments); if (flow.breaking()) return flow; + Flow target = visit(curr->target); + if (target.breaking()) return target; Index index = target.value.geti32(); - return instance.externalInterface->callTable(index, curr->fullType, arguments, instance); + return instance.externalInterface->callTable(index, arguments, curr->type, instance); } Flow visitGetLocal(GetLocal *curr) { @@ -693,28 +685,27 @@ public: if (flow.breaking()) return flow; NOTE_EVAL1(index); NOTE_EVAL1(flow.value); - assert(flow.value.type == curr->type); + assert(curr->isTee() ? flow.value.type == curr->type : true); scope.locals[index] = flow.value; - return flow; + return curr->isTee() ? flow : Flow(); } Flow visitGetGlobal(GetGlobal *curr) { NOTE_ENTER("GetGlobal"); - auto index = curr->index; - NOTE_EVAL1(index); - NOTE_EVAL1(instance.globals[index]); - return instance.globals[index]; + auto name = curr->name; + NOTE_EVAL1(name); + NOTE_EVAL1(instance.globals[name]); + return instance.globals[name]; } Flow visitSetGlobal(SetGlobal *curr) { NOTE_ENTER("SetGlobal"); - auto index = curr->index; + auto name = curr->name; Flow flow = visit(curr->value); if (flow.breaking()) return flow; - NOTE_EVAL1(index); + NOTE_EVAL1(name); NOTE_EVAL1(flow.value); - assert(flow.value.type == curr->type); - instance.globals[index] = flow.value; - return flow; + instance.globals[name] = flow.value; + return Flow(); } Flow visitLoad(Load *curr) { @@ -730,7 +721,7 @@ public: Flow value = visit(curr->value); if (value.breaking()) return value; instance.externalInterface->store(curr, instance.getFinalAddress(curr, ptr.value), value.value); - return value; + return Flow(); } Flow visitHost(Host *curr) { @@ -739,14 +730,15 @@ public: case PageSize: return Literal((int32_t)Memory::kPageSize); case CurrentMemory: return Literal(int32_t(instance.memorySize)); case GrowMemory: { + auto fail = Literal(int32_t(-1)); Flow flow = visit(curr->operands[0]); if (flow.breaking()) return flow; int32_t ret = instance.memorySize; uint32_t delta = flow.value.geti32(); - if (delta > uint32_t(-1) /Memory::kPageSize) trap("growMemory: delta relatively too big"); - if (instance.memorySize >= uint32_t(-1) - delta) trap("growMemory: delta objectively too big"); + if (delta > uint32_t(-1) /Memory::kPageSize) return fail; + if (instance.memorySize >= uint32_t(-1) - delta) return fail; uint32_t newSize = instance.memorySize + delta; - if (newSize > instance.wasm.memory.max) trap("growMemory: exceeds max"); + if (newSize > instance.wasm.memory.max) return fail; instance.externalInterface->growMemory(instance.memorySize * Memory::kPageSize, newSize * Memory::kPageSize); instance.memorySize = newSize; return Literal(int32_t(ret)); diff --git a/src/wasm-js.cpp b/src/wasm-js.cpp index 83956e47e..7d71cec3c 100644 --- a/src/wasm-js.cpp +++ b/src/wasm-js.cpp @@ -81,21 +81,6 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm2wasm(char *input) { if (wasmJSDebug) std::cerr << "wasming...\n"; asm2wasm = new Asm2WasmBuilder(*module, pre.memoryGrowth, debug, false /* TODO: support imprecise? */, false /* TODO: support optimizing? */); asm2wasm->processAsm(asmjs); - - if (wasmJSDebug) std::cerr << "mapping globals...\n"; - for (auto& pair : asm2wasm->mappedGlobals) { - auto name = pair.first; - auto& global = pair.second; - if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here - double value = EM_ASM_DOUBLE({ return Module['lookupImport'](Pointer_stringify($0), Pointer_stringify($1)) }, global.module.str, global.base.str); - uint32_t address = global.address; - switch (global.type) { - case i32: EM_ASM_({ Module['info'].parent['HEAP32'][$0 >> 2] = $1 }, address, value); break; - case f32: EM_ASM_({ Module['info'].parent['HEAPF32'][$0 >> 2] = $1 }, address, value); break; - case f64: EM_ASM_({ Module['info'].parent['HEAPF64'][$0 >> 3] = $1 }, address, value); break; - default: abort(); - } - } } void finalizeModule() { @@ -160,14 +145,16 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { Module['asmExports'] = {}; }); for (auto& curr : module->exports) { - EM_ASM_({ - var name = Pointer_stringify($0); - Module['asmExports'][name] = function() { - Module['tempArguments'] = Array.prototype.slice.call(arguments); - Module['_call_from_js']($0); - return Module['tempReturn']; - }; - }, curr->name.str); + if (curr->kind == Export::Function) { + EM_ASM_({ + var name = Pointer_stringify($0); + Module['asmExports'][name] = function() { + Module['tempArguments'] = Array.prototype.slice.call(arguments); + Module['_call_from_js']($0); + return Module['tempReturn']; + }; + }, curr->name.str); + } } // verify imports are provided @@ -176,32 +163,68 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { var mod = Pointer_stringify($0); var base = Pointer_stringify($1); var name = Pointer_stringify($2); - assert(Module['lookupImport'](mod, base), 'checking import ' + name + ' = ' + mod + '.' + base); + assert(Module['lookupImport'](mod, base) !== undefined, 'checking import ' + name + ' = ' + mod + '.' + base); }, import->module.str, import->base.str, import->name.str); } if (wasmJSDebug) std::cerr << "creating instance...\n"; struct JSExternalInterface : ModuleInstance::ExternalInterface { - void init(Module& wasm) override { - // create a new buffer here, just like native wasm support would. - EM_ASM_({ - Module['outside']['newBuffer'] = new ArrayBuffer($0); - }, wasm.memory.initial * Memory::kPageSize); + void init(Module& wasm, ModuleInstance& instance) override { + // look for imported memory + { + bool found = false; + for (auto& import : wasm.imports) { + if (import->module == ENV && import->base == MEMORY) { + assert(import->kind == Import::Memory); + // memory is imported + EM_ASM({ + Module['asmExports']['memory'] = Module['lookupImport']('env', 'memory'); + }); + found = true; + } + } + if (!found) { + // no memory import; create a new buffer here, just like native wasm support would. + EM_ASM_({ + Module['asmExports']['memory'] = Module['outside']['newBuffer'] = new ArrayBuffer($0); + }, wasm.memory.initial * Memory::kPageSize); + } + } for (auto segment : wasm.memory.segments) { EM_ASM_({ var source = Module['HEAP8'].subarray($1, $1 + $2); - var target = new Int8Array(Module['outside']['newBuffer']); + var target = new Int8Array(Module['asmExports']['memory']); target.set(source, $0); - }, ConstantExpressionRunner().visit(segment.offset).value.geti32(), &segment.data[0], segment.data.size()); + }, ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32(), &segment.data[0], segment.data.size()); + } + // look for imported table + { + bool found = false; + for (auto& import : wasm.imports) { + if (import->module == ENV && import->base == TABLE) { + assert(import->kind == Import::Table); + // table is imported + EM_ASM({ + Module['outside']['wasmTable'] = Module['lookupImport']('env', 'table'); + }); + found = true; + } + } + if (!found) { + // no table import; create a new one here, just like native wasm support would. + EM_ASM_({ + Module['outside']['wasmTable'] = new Array($0); + }, wasm.table.initial); + } } - // Table support is in a JS array. If the entry is a number, it's a function pointer. If not, it's a JS method to be called directly + EM_ASM({ + Module['asmExports']['table'] = Module['outside']['wasmTable']; + }); + // Emulated table support is in a JS array. If the entry is a number, it's a function pointer. If not, it's a JS method to be called directly // TODO: make them all JS methods, wrapping a dynCall where necessary? - EM_ASM_({ - Module['outside']['wasmTable'] = new Array($0); - }, wasm.table.initial); for (auto segment : wasm.table.segments) { - Address offset = ConstantExpressionRunner().visit(segment.offset).value.geti32(); + Address offset = ConstantExpressionRunner(instance.globals).visit(segment.offset).value.geti32(); assert(offset + segment.data.size() <= wasm.table.initial); for (size_t i = 0; i != segment.data.size(); ++i) { EM_ASM_({ @@ -238,6 +261,23 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { } } + void importGlobals(std::map<Name, Literal>& globals, Module& wasm) override { + for (auto& import : wasm.imports) { + if (import->kind == Import::Global) { + double ret = EM_ASM_DOUBLE({ + var mod = Pointer_stringify($0); + var base = Pointer_stringify($1); + var lookup = Module['lookupImport'](mod, base); + return lookup; + }, import->module.str, import->base.str); + + if (wasmJSDebug) std::cout << "calling importGlobal for " << import->name << " returning " << ret << '\n'; + + globals[import->name] = getResultFromJS(ret, import->globalType); + } + } + } + Literal callImport(Import *import, LiteralList& arguments) override { if (wasmJSDebug) std::cout << "calling import " << import->name.str << '\n'; prepareTempArgments(arguments); @@ -252,10 +292,10 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { if (wasmJSDebug) std::cout << "calling import returning " << ret << '\n'; - return getResultFromJS(ret, import->type->result); + return getResultFromJS(ret, import->functionType->result); } - Literal callTable(Index index, Name type, LiteralList& arguments, ModuleInstance& instance) override { + Literal callTable(Index index, LiteralList& arguments, WasmType result, ModuleInstance& instance) override { void* ptr = (void*)EM_ASM_INT({ var value = Module['outside']['wasmTable'][$0]; return typeof value === "number" ? value : -1; @@ -264,7 +304,6 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { if (ptr != (void*)-1) { // a Function we can call Function* func = (Function*)ptr; - if (func->type.is() && func->type != type) trap("callIndirect: bad type"); if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments"); for (size_t i = 0; i < func->params.size(); i++) { if (func->params[i] != arguments[i].type) { @@ -281,7 +320,7 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { Module['tempArguments'] = null; return func.apply(null, tempArguments); }, index); - return getResultFromJS(ret, instance.wasm.getFunctionType(type)->result); + return getResultFromJS(ret, result); } } @@ -411,11 +450,11 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { Module["info"].parent["HEAPU8"][addr + i] = HEAPU8[i]; } HEAP32[0] = save0; HEAP32[1] = save1; - }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->type), isWasmTypeFloat(store_->type) ? value.getFloat() : (double)value.getInteger()); + }, (uint32_t)addr, store_->bytes, isWasmTypeFloat(store_->valueType), isWasmTypeFloat(store_->valueType) ? value.getFloat() : (double)value.getInteger()); return; } // nicely aligned - if (!isWasmTypeFloat(store_->type)) { + if (!isWasmTypeFloat(store_->valueType)) { if (store_->bytes == 1) { EM_ASM_INT({ Module['info'].parent['HEAP8'][$0] = $1 }, addr, value.geti32()); } else if (store_->bytes == 2) { diff --git a/src/wasm-linker.cpp b/src/wasm-linker.cpp index b1160e94a..04deddf4a 100644 --- a/src/wasm-linker.cpp +++ b/src/wasm-linker.cpp @@ -49,7 +49,8 @@ void Linker::ensureImport(Name target, std::string signature) { auto import = new Import; import->name = import->base = target; import->module = ENV; - import->type = ensureFunctionType(signature, &out.wasm); + import->functionType = ensureFunctionType(signature, &out.wasm); + import->kind = Import::Function; out.wasm.addImport(import); } } @@ -106,7 +107,12 @@ void Linker::layout() { } if (userMaxMemory) out.wasm.memory.max = userMaxMemory / Memory::kPageSize; - out.wasm.memory.exportName = MEMORY; + + auto memoryExport = make_unique<Export>(); + memoryExport->name = MEMORY; + memoryExport->value = Name::fromInt(0); + memoryExport->kind = Export::Memory; + out.wasm.addExport(memoryExport.release()); // XXX For now, export all functions marked .globl. for (Name name : out.globls) exportFunction(name, false); @@ -333,7 +339,8 @@ void Linker::emscriptenGlue(std::ostream& o) { auto import = new Import; import->name = import->base = curr->target; import->module = ENV; - import->type = ensureFunctionType(getSig(curr), &parent->out.wasm); + import->functionType = ensureFunctionType(getSig(curr), &parent->out.wasm); + import->kind = Import::Function; parent->out.wasm.addImport(import); } } diff --git a/src/wasm-linker.h b/src/wasm-linker.h index a6f5d319a..21d336273 100644 --- a/src/wasm-linker.h +++ b/src/wasm-linker.h @@ -310,6 +310,7 @@ class Linker { if (out.wasm.checkExport(name)) return; // Already exported auto exp = new Export; exp->name = exp->value = name; + exp->kind = Export::Function; out.wasm.addExport(exp); } diff --git a/src/wasm-module-building.h b/src/wasm-module-building.h index ead074991..43cc493d1 100644 --- a/src/wasm-module-building.h +++ b/src/wasm-module-building.h @@ -73,6 +73,7 @@ static std::mutex debug; class OptimizingIncrementalModuleBuilder { Module* wasm; uint32_t numFunctions; + std::function<void (PassRunner&)> addPrePasses; Function* endMarker; std::atomic<Function*>* list; uint32_t nextFunction; // only used on main thread @@ -86,8 +87,8 @@ class OptimizingIncrementalModuleBuilder { public: // numFunctions must be equal to the number of functions allocated, or higher. Knowing // this bounds helps avoid locking. - OptimizingIncrementalModuleBuilder(Module* wasm, Index numFunctions) - : wasm(wasm), numFunctions(numFunctions), endMarker(nullptr), list(nullptr), nextFunction(0), + OptimizingIncrementalModuleBuilder(Module* wasm, Index numFunctions, std::function<void (PassRunner&)> addPrePasses) + : wasm(wasm), numFunctions(numFunctions), addPrePasses(addPrePasses), endMarker(nullptr), list(nullptr), nextFunction(0), numWorkers(0), liveWorkers(0), activeWorkers(0), availableFuncs(0), finishedFuncs(0), finishing(false) { if (numFunctions == 0) { @@ -201,6 +202,7 @@ private: void optimizeFunction(Function* func) { PassRunner passRunner(wasm); + addPrePasses(passRunner); passRunner.addDefaultFunctionOptimizationPasses(); passRunner.runFunction(func); } diff --git a/src/wasm-printing.h b/src/wasm-printing.h index 2f1c97831..830f2dda2 100644 --- a/src/wasm-printing.h +++ b/src/wasm-printing.h @@ -36,7 +36,7 @@ struct WasmPrinter { return printModule(module, std::cout); } - static std::ostream& printExpression(Expression* expression, std::ostream& o, bool minify = false); + static std::ostream& printExpression(Expression* expression, std::ostream& o, bool minify = false, bool full = false); }; } diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index b70f70b4e..9abdea744 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -60,13 +60,15 @@ class Element { List list_; IString str_; bool dollared_; + bool quoted_; public: Element(MixedArena& allocator) : isList_(true), list_(allocator), line(-1), col(-1) {} bool isList() { return isList_; } bool isStr() { return !isList_; } - bool dollared() { return dollared_; } + bool dollared() { return isStr() && dollared_; } + bool quoted() { return isStr() && quoted_; } size_t line, col; @@ -98,10 +100,11 @@ public: return str_.str; } - Element* setString(IString str__, bool dollared__) { + Element* setString(IString str__, bool dollared__, bool quoted__) { isList_ = false; str_ = str__; dollared_ = dollared__; + quoted_ = quoted__; return this; } @@ -242,13 +245,13 @@ private: input++; } input++; - return allocator.alloc<Element>()->setString(IString(str.c_str(), false), dollared)->setMetadata(line, start - lineStart); + return allocator.alloc<Element>()->setString(IString(str.c_str(), false), dollared, true)->setMetadata(line, start - lineStart); } while (input[0] && !isspace(input[0]) && input[0] != ')' && input[0] != '(' && input[0] != ';') input++; if (start == input) throw ParseException("expected string", line, input - lineStart); char temp = input[0]; input[0] = 0; - auto ret = allocator.alloc<Element>()->setString(IString(start, false), dollared)->setMetadata(line, start - lineStart); + auto ret = allocator.alloc<Element>()->setString(IString(start, false), dollared, false)->setMetadata(line, start - lineStart); input[0] = temp; return ret; } @@ -319,10 +322,11 @@ private: Element& curr = *s[i]; IString id = curr[0]->str(); if (id == RESULT) { + if (curr.size() > 2) throw ParseException("invalid result arity", curr.line, curr.col); functionTypes[name] = stringToWasmType(curr[1]->str()); } else if (id == TYPE) { Name typeName = curr[1]->str(); - if (!wasm.checkFunctionType(typeName)) throw ParseException("unknown function"); + if (!wasm.checkFunctionType(typeName)) throw ParseException("unknown function", curr.line, curr.col); type = wasm.getFunctionType(typeName); functionTypes[name] = type->result; } else if (id == PARAM && curr.size() > 1) { @@ -387,7 +391,11 @@ private: bool brokeToAutoBlock; Name getPrefixedName(std::string prefix) { - return IString((prefix + std::to_string(otherIndex++)).c_str(), false); + // make sure to return a unique name not already on the stack + while (1) { + Name ret = IString((prefix + std::to_string(otherIndex++)).c_str(), false); + if (std::find(labelStack.begin(), labelStack.end(), ret) == labelStack.end()) return ret; + } } Name getFunctionName(Element& s) { @@ -408,14 +416,23 @@ private: // returns the next index in s size_t parseFunctionNames(Element& s, Name& name, Name& exportName) { size_t i = 1; - while (i < s.size() && s[i]->isStr()) { - if (!s[i]->dollared()) { + while (i < s.size() && i < 3 && s[i]->isStr()) { + if (s[i]->quoted()) { // an export name exportName = s[i]->str(); i++; - } else { + } else if (s[i]->dollared()) { name = s[i]->str(); i++; + } else { + break; + } + } + if (i < s.size() && s[i]->isList()) { + auto& inner = *s[i]; + if (inner.size() > 0 && inner[0]->str() == EXPORT) { + exportName = inner[1]->str(); + i++; } } return i; @@ -433,6 +450,7 @@ private: auto ex = make_unique<Export>(); ex->name = exportName; ex->value = name; + ex->kind = Export::Function; wasm.addExport(ex.release()); } functionCounter++; @@ -490,6 +508,7 @@ private: currLocalTypes[name] = type; } } else if (id == RESULT) { + if (curr.size() > 2) throw ParseException("invalid result arity", curr.line, curr.col); result = stringToWasmType(curr[1]->str()); } else if (id == TYPE) { Name name = curr[1]->str(); @@ -564,7 +583,6 @@ private: if (str[1] == '6' && str[2] == '4' && (prefix || str[3] == 0)) return f64; } if (allowError) return none; - throw ParseException("unknown type"); abort(); } @@ -576,6 +594,7 @@ public: #define abort_on(str) { throw ParseException(std::string("abort_on ") + str); } Expression* parseExpression(Element& s) { + if (!s.isList()) throw ParseException("invalid node for parseExpression, needed list", s.line, s.col); IString id = s[0]->str(); const char *str = id.str; const char *dot = strchr(str, '.'); @@ -617,7 +636,7 @@ public: if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BINARY_INT(DivU) : BINARY_INT(DivS), type); if (op[3] == 0) return makeBinary(s, BINARY_FLOAT(Div), type); } - if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); + if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); abort_on(op); } case 'e': { @@ -735,6 +754,10 @@ public: } else if (str[1] == 'u') return makeHost(s, HostOp::CurrentMemory); abort_on(str); } + case 'd': { + if (str[1] == 'r') return makeDrop(s); + abort_on(str); + } case 'e': { if (str[1] == 'l') return makeThenOrElse(s); abort_on(str); @@ -781,6 +804,7 @@ public: } case 't': { if (str[1] == 'h') return makeThenOrElse(s); + if (str[1] == 'e' && str[2] == 'e') return makeTeeLocal(s); abort_on(str); } case 'u': { @@ -874,13 +898,20 @@ private: return ret; } + Expression* makeDrop(Element& s) { + auto ret = allocator.alloc<Drop>(); + ret->value = parseExpression(s[1]); + ret->finalize(); + return ret; + } + Expression* makeHost(Element& s, HostOp op) { auto ret = allocator.alloc<Host>(); ret->op = op; if (op == HostOp::HasFeature) { ret->nameOperand = s[1]->str(); } else { - parseCallOperands(s, 1, ret); + parseCallOperands(s, 1, s.size(), ret); } ret->finalize(); return ret; @@ -906,40 +937,41 @@ private: return ret; } - Expression* makeSetLocal(Element& s) { + Expression* makeTeeLocal(Element& s) { auto ret = allocator.alloc<SetLocal>(); ret->index = getLocalIndex(*s[1]); ret->value = parseExpression(s[2]); - ret->type = currFunction->getLocalType(ret->index); + ret->setTee(true); return ret; } - - Index getGlobalIndex(Element& s) { - if (s.dollared()) { - auto name = s.str(); - for (Index i = 0; i < wasm.globals.size(); i++) { - if (wasm.globals[i]->name == name) return i; - } - throw ParseException("bad global name", s.line, s.col); - } - // this is a numeric index - Index ret = atoi(s.c_str()); - if (!wasm.checkGlobal(ret)) throw ParseException("bad global index", s.line, s.col); + Expression* makeSetLocal(Element& s) { + auto ret = allocator.alloc<SetLocal>(); + ret->index = getLocalIndex(*s[1]); + ret->value = parseExpression(s[2]); + ret->setTee(false); return ret; } Expression* makeGetGlobal(Element& s) { auto ret = allocator.alloc<GetGlobal>(); - ret->index = getGlobalIndex(*s[1]); - ret->type = wasm.getGlobal(ret->index)->type; - return ret; + ret->name = s[1]->str(); + auto* global = wasm.checkGlobal(ret->name); + if (global) { + ret->type = global->type; + return ret; + } + auto* import = wasm.checkImport(ret->name); + if (import && import->kind == Import::Global) { + ret->type = import->globalType; + return ret; + } + throw ParseException("bad get_global name", s.line, s.col); } Expression* makeSetGlobal(Element& s) { auto ret = allocator.alloc<SetGlobal>(); - ret->index = getGlobalIndex(*s[1]); + ret->name = s[1]->str(); ret->value = parseExpression(s[2]); - ret->type = wasm.getGlobal(ret->index)->type; return ret; } @@ -1057,7 +1089,7 @@ private: Expression* makeStore(Element& s, WasmType type) { const char *extra = strchr(s[0]->c_str(), '.') + 6; // after "type.store" auto ret = allocator.alloc<Store>(); - ret->type = type; + ret->valueType = type; ret->bytes = getWasmTypeSize(type); if (extra[0] == '8') { ret->bytes = 1; @@ -1088,6 +1120,7 @@ private: } ret->ptr = parseExpression(s[i]); ret->value = parseExpression(s[i+1]); + ret->finalize(); return ret; } @@ -1148,24 +1181,28 @@ private: Expression* makeLoop(Element& s) { auto ret = allocator.alloc<Loop>(); size_t i = 1; + Name out; if (s.size() > i + 1 && s[i]->isStr() && s[i + 1]->isStr()) { // out can only be named if both are - ret->out = s[i]->str(); + out = s[i]->str(); i++; - } else { - ret->out = getPrefixedName("loop-out"); } if (s.size() > i && s[i]->isStr()) { - ret->in = s[i]->str(); + ret->name = s[i]->str(); i++; } else { - ret->in = getPrefixedName("loop-in"); + ret->name = getPrefixedName("loop-in"); } - labelStack.push_back(ret->out); - labelStack.push_back(ret->in); + labelStack.push_back(ret->name); ret->body = makeMaybeBlock(s, i); labelStack.pop_back(); - labelStack.pop_back(); ret->finalize(); + if (out.is()) { + auto* block = allocator.alloc<Block>(); + block->name = out; + block->list.push_back(ret); + block->finalize(); + return block; + } return ret; } @@ -1173,7 +1210,7 @@ private: auto ret = allocator.alloc<Call>(); ret->target = s[1]->str(); ret->type = functionTypes[ret->target]; - parseCallOperands(s, 2, ret); + parseCallOperands(s, 2, s.size(), ret); return ret; } @@ -1181,8 +1218,8 @@ private: auto ret = allocator.alloc<CallImport>(); ret->target = s[1]->str(); Import* import = wasm.getImport(ret->target); - ret->type = import->type->result; - parseCallOperands(s, 2, ret); + ret->type = import->functionType->result; + parseCallOperands(s, 2, s.size(), ret); return ret; } @@ -1193,14 +1230,14 @@ private: if (!fullType) throw ParseException("invalid call_indirect type", s.line, s.col); ret->fullType = fullType->name; ret->type = fullType->result; - ret->target = parseExpression(s[2]); - parseCallOperands(s, 3, ret); + parseCallOperands(s, 2, s.size() - 1, ret); + ret->target = parseExpression(s[s.size() - 1]); return ret; } template<class T> - void parseCallOperands(Element& s, size_t i, T* call) { - while (i < s.size()) { + void parseCallOperands(Element& s, Index i, Index j, T* call) { + while (i < j) { call->operands.push_back(parseExpression(s[i])); i++; } @@ -1311,10 +1348,29 @@ private: void parseMemory(Element& s) { hasMemory = true; - - wasm.memory.initial = atoi(s[1]->c_str()); - if (s.size() == 2) return; - size_t i = 2; + Index i = 1; + if (s[i]->dollared()) { + wasm.memory.name = s[i++]->str(); + } + if (s[i]->isList()) { + auto& inner = *s[i]; + if (inner[0]->str() == EXPORT) { + auto ex = make_unique<Export>(); + ex->name = inner[1]->str(); + ex->value = wasm.memory.name; + ex->kind = Export::Memory; + wasm.addExport(ex.release()); + i++; + } else { + assert(inner.size() > 0 ? inner[0]->str() != IMPORT : true); + // (memory (data ..)) format + parseData(*s[i]); + wasm.memory.initial = wasm.memory.segments[0].data.size(); + return; + } + } + wasm.memory.initial = atoi(s[i++]->c_str()); + if (i == s.size()) return; if (s[i]->isStr()) { uint64_t max = atoll(s[i]->c_str()); if (max > Memory::kMaxSize) throw ParseException("total memory must be <= 4GB"); @@ -1346,108 +1402,220 @@ private: } void parseData(Element& s) { - auto* offset = parseExpression(s[1]); - const char *input = s[2]->c_str(); - if (auto size = strlen(input)) { - std::vector<char> data; - stringToBinary(input, size, data); - wasm.memory.segments.emplace_back(offset, data.data(), data.size()); + if (!hasMemory) throw ParseException("data but no memory"); + Index i = 1; + Expression* offset; + if (i < s.size() && s[i]->isList()) { + // there is an init expression + offset = parseExpression(s[i++]); } else { - wasm.memory.segments.emplace_back(offset, "", 0); + offset = allocator.alloc<Const>()->set(Literal(int32_t(0))); + } + std::vector<char> data; + while (i < s.size()) { + const char *input = s[i++]->c_str(); + if (auto size = strlen(input)) { + stringToBinary(input, size, data); + } } + wasm.memory.segments.emplace_back(offset, data.data(), data.size()); } void parseExport(Element& s) { - if (!s[2]->dollared() && !std::isdigit(s[2]->str()[0])) { - assert(s[2]->str() == MEMORY); - if (!hasMemory) throw ParseException("memory exported but no memory"); - wasm.memory.exportName = s[1]->str(); - return; - } std::unique_ptr<Export> ex = make_unique<Export>(); ex->name = s[1]->str(); - ex->value = s[2]->str(); + if (s[2]->isList()) { + auto& inner = *s[2]; + if (inner[0]->str() == FUNC) { + ex->value = inner[1]->str(); + ex->kind = Export::Function; + } else if (inner[0]->str() == MEMORY) { + if (!hasMemory) throw ParseException("memory exported but no memory"); + ex->value = Name::fromInt(0); + ex->kind = Export::Memory; + } else if (inner[0]->str() == TABLE) { + ex->value = Name::fromInt(0); + ex->kind = Export::Table; + } else if (inner[0]->str() == GLOBAL) { + ex->value = inner[1]->str(); + ex->kind = Export::Table; + } else { + WASM_UNREACHABLE(); + } + } else if (!s[2]->dollared() && !std::isdigit(s[2]->str()[0])) { + if (s[2]->str() == MEMORY) { + if (!hasMemory) throw ParseException("memory exported but no memory"); + ex->value = Name::fromInt(0); + ex->kind = Export::Memory; + } else if (s[2]->str() == TABLE) { + ex->value = Name::fromInt(0); + ex->kind = Export::Table; + } else if (s[2]->str() == GLOBAL) { + ex->value = s[3]->str(); + ex->kind = Export::Table; + } else { + WASM_UNREACHABLE(); + } + } else { + // function + ex->value = s[2]->str(); + ex->kind = Export::Function; + } wasm.addExport(ex.release()); } void parseImport(Element& s) { std::unique_ptr<Import> im = make_unique<Import>(); size_t i = 1; + bool newStyle = s.size() == 4 && s[3]->isList(); // (import "env" "STACKTOP" (global $stackTop i32)) + if (newStyle) { + if ((*s[3])[0]->str() == FUNC) { + im->kind = Import::Function; + } else if ((*s[3])[0]->str() == MEMORY) { + im->kind = Import::Memory; + } else if ((*s[3])[0]->str() == TABLE) { + im->kind = Import::Table; + } else if ((*s[3])[0]->str() == GLOBAL) { + im->kind = Import::Global; + } else { + newStyle = false; // either (param..) or (result..) + } + } if (s.size() > 3 && s[3]->isStr()) { im->name = s[i++]->str(); + } else if (newStyle) { + im->name = (*s[3])[1]->str(); } else { im->name = Name::fromInt(importCounter); } importCounter++; + if (!s[i]->quoted()) { + if (s[i]->str() == MEMORY) { + im->kind = Import::Memory; + } else if (s[i]->str() == TABLE) { + im->kind = Import::Table; + } else if (s[i]->str() == GLOBAL) { + im->kind = Import::Global; + } else { + WASM_UNREACHABLE(); + } + i++; + } else if (!newStyle) { + im->kind = Import::Function; + } im->module = s[i++]->str(); if (!s[i]->isStr()) throw ParseException("no name for import"); im->base = s[i++]->str(); - std::unique_ptr<FunctionType> type = make_unique<FunctionType>(); - if (s.size() > i) { - Element& params = *s[i]; - IString id = params[0]->str(); - if (id == PARAM) { - for (size_t i = 1; i < params.size(); i++) { - type->params.push_back(stringToWasmType(params[i]->str())); + // parse internals + Element& inner = newStyle ? *s[3] : s; + Index j = newStyle ? 2 : i; + if (im->kind == Import::Function) { + std::unique_ptr<FunctionType> type = make_unique<FunctionType>(); + if (inner.size() > j) { + Element& params = *inner[j]; + IString id = params[0]->str(); + if (id == PARAM) { + for (size_t i = 1; i < params.size(); i++) { + type->params.push_back(stringToWasmType(params[i]->str())); + } + } else if (id == RESULT) { + type->result = stringToWasmType(params[1]->str()); + } else if (id == TYPE) { + IString name = params[1]->str(); + if (!wasm.checkFunctionType(name)) throw ParseException("bad function type for import"); + *type = *wasm.getFunctionType(name); + } else { + throw ParseException("bad import element"); + } + if (inner.size() > j+1) { + Element& result = *inner[j+1]; + assert(result[0]->str() == RESULT); + type->result = stringToWasmType(result[1]->str()); } - } else if (id == RESULT) { - type->result = stringToWasmType(params[1]->str()); - } else if (id == TYPE) { - IString name = params[1]->str(); - if (!wasm.checkFunctionType(name)) throw ParseException("bad function type for import"); - *type = *wasm.getFunctionType(name); - } else { - throw ParseException("bad import element"); - } - if (s.size() > i+1) { - Element& result = *s[i+1]; - assert(result[0]->str() == RESULT); - type->result = stringToWasmType(result[1]->str()); } + im->functionType = ensureFunctionType(getSig(type.get()), &wasm); + } else if (im->kind == Import::Global) { + im->globalType = stringToWasmType(inner[j]->str()); } - im->type = ensureFunctionType(getSig(type.get()), &wasm); wasm.addImport(im.release()); } void parseGlobal(Element& s) { std::unique_ptr<Global> global = make_unique<Global>(); size_t i = 1; - if (s.size() == 4) { + if (s[i]->dollared()) { global->name = s[i++]->str(); } else { global->name = Name::fromInt(globalCounter); } globalCounter++; + if (s[i]->isList()) { + auto& inner = *s[i]; + if (inner[0]->str() == EXPORT) { + auto ex = make_unique<Export>(); + ex->name = inner[1]->str(); + ex->value = global->name; + ex->kind = Export::Global; + wasm.addExport(ex.release()); + i++; + } else { + WASM_UNREACHABLE(); + } + } global->type = stringToWasmType(s[i++]->str()); global->init = parseExpression(s[i++]); assert(i == s.size()); wasm.addGlobal(global.release()); } + bool seenTable = false; + void parseTable(Element& s) { - if (s.size() == 1) return; // empty table in old notation - if (!s[1]->dollared()) { - if (s[1]->str() == ANYFUNC) { + seenTable = true; + Index i = 1; + if (i == s.size()) return; // empty table in old notation +#if 0 // TODO: new table notation + if (s[i]->dollared()) { + wasm.table.name = s[i++]->str(); + } +#endif + if (i == s.size()) return; + if (s[i]->isList()) { + auto& inner = *s[i]; + if (inner[0]->str() == EXPORT) { + auto ex = make_unique<Export>(); + ex->name = inner[1]->str(); + ex->value = wasm.table.name; + ex->kind = Export::Table; + wasm.addExport(ex.release()); + i++; + } else { + WASM_UNREACHABLE(); + } + } + if (i == s.size()) return; + if (!s[i]->dollared()) { + if (s[i]->str() == ANYFUNC) { // (table type (elem ..)) - parseElem(*s[2]); + parseElem(*s[i + 1]); wasm.table.initial = wasm.table.max = wasm.table.segments[0].data.size(); return; } // first element isn't dollared, and isn't anyfunc. this could be old syntax for (table 0 1) which means function 0 and 1, or it could be (table initial max? type), look for type if (s[s.size() - 1]->str() == ANYFUNC) { // (table initial max? type) - wasm.table.initial = atoi(s[1]->c_str()); - wasm.table.max = atoi(s[2]->c_str()); + wasm.table.initial = atoi(s[i]->c_str()); + wasm.table.max = atoi(s[i + 1]->c_str()); return; } } // old notation (table func1 func2 ..) - parseElem(s); + parseElem(s, i); wasm.table.initial = wasm.table.max = wasm.table.segments[0].data.size(); } - void parseElem(Element& s) { - Index i = 1; + void parseElem(Element& s, Index i = 1) { + if (!seenTable) throw ParseException("elem without table", s.line, s.col); Expression* offset; if (s[i]->isList()) { // there is an init expression @@ -1478,6 +1646,7 @@ private: type->params.push_back(stringToWasmType(curr[j]->str())); } } else if (curr[0]->str() == RESULT) { + if (curr.size() > 2) throw ParseException("invalid result arity", curr.line, curr.col); type->result = stringToWasmType(curr[1]->str()); } } diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index b50ca0fb2..4725237dd 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -53,6 +53,7 @@ struct Visitor { ReturnType visitUnary(Unary *curr) {} ReturnType visitBinary(Binary *curr) {} ReturnType visitSelect(Select *curr) {} + ReturnType visitDrop(Drop *curr) {} ReturnType visitReturn(Return *curr) {} ReturnType visitHost(Host *curr) {} ReturnType visitNop(Nop *curr) {} @@ -93,6 +94,7 @@ struct Visitor { case Expression::Id::UnaryId: DELEGATE(Unary); case Expression::Id::BinaryId: DELEGATE(Binary); case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); case Expression::Id::ReturnId: DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); @@ -132,6 +134,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> { ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitDrop(Drop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } @@ -153,8 +156,8 @@ struct Walker : public VisitorType { // Note that the visit*() for the result node is not called for you (i.e., // just one visit*() method is called by the traversal; if you replace a node, // and you want to process the output, you must do that explicitly). - void replaceCurrent(Expression *expression) { - replace = expression; + Expression* replaceCurrent(Expression *expression) { + return replace = expression; } // Get the current module @@ -264,14 +267,15 @@ struct Walker : public VisitorType { static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); } static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); } static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); } - static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); } - static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); } + static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); } + static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); } static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); } static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); } static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); } static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); } static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); } static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); } + static void doVisitDrop(SubType* self, Expression** currp) { self->visitDrop((*currp)->cast<Drop>()); } static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); } static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); } static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); } @@ -354,10 +358,10 @@ struct PostWalker : public Walker<SubType, VisitorType> { case Expression::Id::CallIndirectId: { self->pushTask(SubType::doVisitCallIndirect, currp); auto& list = curr->cast<CallIndirect>()->operands; + self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target); for (int i = int(list.size()) - 1; i >= 0; i--) { self->pushTask(SubType::scan, &list[i]); } - self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target); break; } case Expression::Id::GetLocalId: { @@ -411,6 +415,11 @@ struct PostWalker : public Walker<SubType, VisitorType> { self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue); break; } + case Expression::Id::DropId: { + self->pushTask(SubType::doVisitDrop, currp); + self->pushTask(SubType::scan, &curr->cast<Drop>()->value); + break; + } case Expression::Id::ReturnId: { self->pushTask(SubType::doVisitReturn, currp); self->maybePushTask(SubType::scan, &curr->cast<Return>()->value); @@ -437,6 +446,112 @@ struct PostWalker : public Walker<SubType, VisitorType> { } }; +// Traversal with a control-flow stack. + +template<typename SubType, typename VisitorType> +struct ControlFlowWalker : public PostWalker<SubType, VisitorType> { + ControlFlowWalker() {} + + std::vector<Expression*> controlFlowStack; // contains blocks, loops, and ifs + + // Uses the control flow stack to find the target of a break to a name + Expression* findBreakTarget(Name name) { + assert(!controlFlowStack.empty()); + Index i = controlFlowStack.size() - 1; + while (1) { + auto* curr = controlFlowStack[i]; + if (Block* block = curr->template dynCast<Block>()) { + if (name == block->name) return curr; + } else if (Loop* loop = curr->template dynCast<Loop>()) { + if (name == loop->name) return curr; + } else { + // an if, ignorable + assert(curr->template is<If>()); + } + if (i == 0) return nullptr; + i--; + } + } + + static void doPreVisitControlFlow(SubType* self, Expression** currp) { + self->controlFlowStack.push_back(*currp); + } + + static void doPostVisitControlFlow(SubType* self, Expression** currp) { + assert(self->controlFlowStack.back() == *currp); + self->controlFlowStack.pop_back(); + } + + static void scan(SubType* self, Expression** currp) { + auto* curr = *currp; + + switch (curr->_id) { + case Expression::Id::BlockId: + case Expression::Id::IfId: + case Expression::Id::LoopId: { + self->pushTask(SubType::doPostVisitControlFlow, currp); + break; + } + default: {} + } + + PostWalker<SubType, VisitorType>::scan(self, currp); + + switch (curr->_id) { + case Expression::Id::BlockId: + case Expression::Id::IfId: + case Expression::Id::LoopId: { + self->pushTask(SubType::doPreVisitControlFlow, currp); + break; + } + default: {} + } + } +}; + +// Traversal with an expression stack. + +template<typename SubType, typename VisitorType> +struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> { + ExpressionStackWalker() {} + + std::vector<Expression*> expressionStack; + + // Uses the control flow stack to find the target of a break to a name + Expression* findBreakTarget(Name name) { + assert(!expressionStack.empty()); + Index i = expressionStack.size() - 1; + while (1) { + auto* curr = expressionStack[i]; + if (Block* block = curr->template dynCast<Block>()) { + if (name == block->name) return curr; + } else if (Loop* loop = curr->template dynCast<Loop>()) { + if (name == loop->name) return curr; + } else { + WASM_UNREACHABLE(); + } + if (i == 0) return nullptr; + i--; + } + } + + static void doPreVisit(SubType* self, Expression** currp) { + self->expressionStack.push_back(*currp); + } + + static void doPostVisit(SubType* self, Expression** currp) { + self->expressionStack.pop_back(); + } + + static void scan(SubType* self, Expression** currp) { + self->pushTask(SubType::doPostVisit, currp); + + PostWalker<SubType, VisitorType>::scan(self, currp); + + self->pushTask(SubType::doPreVisit, currp); + } +}; + // Traversal in the order of execution. This is quick and simple, but // does not provide the same comprehensive information that a full // conversion to basic blocks would. What it does give is a quick diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 3d9a0e1c3..58a30f9a3 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -30,7 +30,16 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> bool valid = true; bool validateWebConstraints = false; - std::map<Name, WasmType> breakTypes; // breaks to a label must all have the same type, and the right type + struct BreakInfo { + WasmType type; + Index arity; + BreakInfo() {} + BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} + }; + + std::map<Name, std::vector<Expression*>> breakTargets; // more than one block/loop may use a label name, so stack them + std::map<Expression*, BreakInfo> breakInfos; + WasmType returnType = unreachable; // type used in returns public: @@ -42,55 +51,107 @@ public: // visitors + static void visitPreBlock(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Block>(); + if (curr->name.is()) self->breakTargets[curr->name].push_back(curr); + } + void visitBlock(Block *curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { - // none or unreachable means a poison value that we should ignore - if consumed, it will error - if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) { - shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value"); + if (breakInfos.count(curr) > 0) { + auto& info = breakInfos[curr]; + // none or unreachable means a poison value that we should ignore - if consumed, it will error + if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) { + shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value"); + } + shouldBeTrue(info.arity != Index(-1), curr, "break arities must match"); + if (curr->list.size() > 0) { + auto last = curr->list.back()->type; + if (isConcreteWasmType(last) && info.type != unreachable) { + shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value"); + } + if (last == none) { + shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type"); + } + } + } + breakTargets[curr->name].pop_back(); + } + if (curr->list.size() > 1) { + for (Index i = 0; i < curr->list.size() - 1; i++) { + if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) { + std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; + } } - breakTypes.erase(curr->name); } } - void visitIf(If *curr) { - shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + + static void visitPreLoop(WasmValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Loop>(); + if (curr->name.is()) self->breakTargets[curr->name].push_back(curr); } + void visitLoop(Loop *curr) { - if (curr->in.is()) { - breakTypes.erase(curr->in); - } - if (curr->out.is()) { - breakTypes.erase(curr->out); + if (curr->name.is()) { + breakTargets[curr->name].pop_back(); + if (breakInfos.count(curr) > 0) { + auto& info = breakInfos[curr]; + shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value"); + } } } - void noteBreak(Name name, Expression* value) { + + void visitIf(If *curr) { + shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + } + + // override scan to add a pre and a post check task to all nodes + static void scan(WasmValidator* self, Expression** currp) { + PostWalker<WasmValidator, Visitor<WasmValidator>>::scan(self, currp); + + auto* curr = *currp; + if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); + if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); + } + + void noteBreak(Name name, Expression* value, Expression* curr) { WasmType valueType = none; + Index arity = 0; if (value) { valueType = value->type; + shouldBeUnequal(valueType, none, curr, "breaks must have a valid value"); + arity = 1; } - if (breakTypes.count(name) == 0) { - breakTypes[name] = valueType; + if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return; + auto* target = breakTargets[name].back(); + if (breakInfos.count(target) == 0) { + breakInfos[target] = BreakInfo(valueType, arity); } else { - if (breakTypes[name] == unreachable) { - breakTypes[name] = valueType; + auto& info = breakInfos[target]; + if (info.type == unreachable) { + info.type = valueType; } else if (valueType != unreachable) { - if (valueType != breakTypes[name]) { - breakTypes[name] = none; // a poison value that must not be consumed + if (valueType != info.type) { + info.type = none; // a poison value that must not be consumed } } + if (arity != info.arity) { + info.arity = Index(-1); // a poison value + } } } void visitBreak(Break *curr) { - noteBreak(curr->name, curr->value); + noteBreak(curr->name, curr->value, curr); if (curr->condition) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); } } void visitSwitch(Switch *curr) { for (auto& target : curr->targets) { - noteBreak(target, curr->value); + noteBreak(target, curr->value, curr); } - noteBreak(curr->default_, curr->value); + noteBreak(curr->default_, curr->value, curr); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); } void visitCall(Call *curr) { @@ -106,7 +167,7 @@ public: void visitCallImport(CallImport *curr) { auto* import = getModule()->checkImport(curr->target); if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return; - auto* type = import->type; + auto* type = import->functionType; if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) { @@ -115,6 +176,7 @@ public: } } void visitCallIndirect(CallIndirect *curr) { + shouldBeTrue(getModule()->table.segments.size() > 0, curr, "no table"); auto* type = getModule()->checkFunctionType(curr->fullType); if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return; shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32"); @@ -128,18 +190,21 @@ public: void visitSetLocal(SetLocal *curr) { shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough"); if (curr->value->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + if (curr->type != none) { // tee is ok anyhow + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct"); + } shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function"); } } void visitLoad(Load *curr) { - validateAlignment(curr->align); + validateAlignment(curr->align, curr->type, curr->bytes); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32"); } void visitStore(Store *curr) { - validateAlignment(curr->align); + validateAlignment(curr->align, curr->type, curr->bytes); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); - shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "store value type must match"); + shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none"); + shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match"); } void visitBinary(Binary *curr) { if (curr->left->type != unreachable && curr->right->type != unreachable) { @@ -216,6 +281,10 @@ public: default: abort(); } } + void visitSelect(Select* curr) { + shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid"); + shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid"); + } void visitReturn(Return* curr) { if (curr->value) { @@ -245,9 +314,11 @@ public: void visitImport(Import* curr) { if (!validateWebConstraints) return; - shouldBeUnequal(curr->type->result, i64, curr->name, "Imported function must not have i64 return type"); - for (WasmType param : curr->type->params) { - shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); + if (curr->kind == Import::Function) { + shouldBeUnequal(curr->functionType->result, i64, curr->name, "Imported function must not have i64 return type"); + for (WasmType param : curr->functionType->params) { + shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); + } } } @@ -262,39 +333,66 @@ public: void visitGlobal(Global* curr) { shouldBeTrue(curr->init->is<Const>(), curr->name, "global init must be valid"); - shouldBeEqual(curr->type, curr->init->type, curr, "global init must have correct type"); + shouldBeEqual(curr->type, curr->init->type, nullptr, "global init must have correct type"); } void visitFunction(Function *curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return - if (curr->result != none) { - if (curr->body->type != unreachable) { - shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); - } + if (curr->body->type != unreachable) { + shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns"); + } + if (curr->result != none) { // TODO: over previous too? if (returnType != unreachable) { shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function returns"); } } returnType = unreachable; } + + bool isConstant(Expression* curr) { + return curr->is<Const>() || curr->is<GetGlobal>(); + } + void visitMemory(Memory *curr) { shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial"); shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB"); + Index mustBeGreaterOrEqual = 0; + for (auto& segment : curr->segments) { + if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue; + shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant"); + Index size = segment.data.size(); + shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); + if (segment.offset->is<Const>()) { + Index start = segment.offset->cast<Const>()->value.geti32(); + Index end = start + size; + shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); + shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory"); + mustBeGreaterOrEqual = end; + } + } + } + void visitTable(Table* curr) { + for (auto& segment : curr->segments) { + shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32"); + shouldBeTrue(isConstant(segment.offset), segment.offset, "segment offset should be constant"); + } } void visitModule(Module *curr) { // exports std::set<Name> exportNames; for (auto& exp : curr->exports) { Name name = exp->value; - bool found = false; - for (auto& func : curr->functions) { - if (func->name == name) { - found = true; - break; + if (exp->kind == Export::Function) { + bool found = false; + for (auto& func : curr->functions) { + if (func->name == name) { + found = true; + break; + } } + shouldBeTrue(found, name, "module exports must be found"); } - shouldBeTrue(found, name, "module exports must be found"); Name exportName = exp->name; shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); exportNames.insert(exportName); @@ -311,12 +409,6 @@ public: void doWalkFunction(Function* func) { PostWalker<WasmValidator, Visitor<WasmValidator>>::doWalkFunction(func); - if (!shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid")) { - for (auto& target : breakTypes) { - std::cerr << " - " << target.first << '\n'; - } - breakTypes.clear(); - } } private: @@ -360,7 +452,8 @@ private: template<typename T, typename S> bool shouldBeEqual(S left, S right, T curr, const char* text) { if (left != right) { - fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; + fail() << "" << left << " != " << right << ": " << text << ", on \n"; + WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl; valid = false; return false; } @@ -379,7 +472,8 @@ private: template<typename T, typename S> bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { if (left != unreachable && left != right) { - fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << std::endl; + fail() << "" << left << " != " << right << ": " << text << ", on \n"; + WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl; valid = false; return false; } @@ -396,7 +490,7 @@ private: return true; } - void validateAlignment(size_t align) { + void validateAlignment(size_t align, WasmType type, Index bytes) { switch (align) { case 1: case 2: @@ -408,6 +502,20 @@ private: break; } } + shouldBeTrue(align <= bytes, align, "alignment must not exceed natural"); + switch (type) { + case i32: + case f32: { + shouldBeTrue(align <= 4, align, "alignment must not exceed natural"); + break; + } + case i64: + case f64: { + shouldBeTrue(align <= 8, align, "alignment must not exceed natural"); + break; + } + default: {} + } } }; diff --git a/src/wasm.cpp b/src/wasm.cpp index 8cfdee759..f6486eb50 100644 --- a/src/wasm.cpp +++ b/src/wasm.cpp @@ -120,7 +120,7 @@ struct TypeSeeker : public PostWalker<TypeSeeker, Visitor<TypeSeeker>> { void visitLoop(Loop* curr) { if (curr == target) { types.push_back(curr->body->type); - } else if (curr->in == targetName || curr->out == targetName) { + } else if (curr->name == targetName) { types.clear(); // ignore all breaks til now, they were captured by someone with the same name } } @@ -162,13 +162,7 @@ void Block::finalize() { } void Loop::finalize() { - if (!out.is()) { - type = body->type; - return; - } - - TypeSeeker seeker(this, this->out); - type = mergeTypes(seeker.types); + type = body->type; } } // namespace wasm diff --git a/src/wasm.h b/src/wasm.h index 68558033d..d6bdfe91f 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -802,7 +802,7 @@ enum UnaryOp { ConvertSInt32ToFloat32, ConvertSInt32ToFloat64, ConvertUInt32ToFloat32, ConvertUInt32ToFloat64, ConvertSInt64ToFloat32, ConvertSInt64ToFloat64, ConvertUInt64ToFloat32, ConvertUInt64ToFloat64, // int to float PromoteFloat32, // f32 to f64 DemoteFloat64, // f64 to f32 - ReinterpretInt32, ReinterpretInt64 // reinterpret bits to float + ReinterpretInt32, ReinterpretInt64, // reinterpret bits to float }; enum BinaryOp { @@ -877,6 +877,7 @@ public: UnaryId, BinaryId, SelectId, + DropId, ReturnId, HostId, NopId, @@ -929,6 +930,7 @@ inline const char *getExpressionName(Expression *curr) { case Expression::Id::UnaryId: return "unary"; case Expression::Id::BinaryId: return "binary"; case Expression::Id::SelectId: return "select"; + case Expression::Id::DropId: return "drop"; case Expression::Id::ReturnId: return "return"; case Expression::Id::HostId: return "host"; case Expression::Id::NopId: return "nop"; @@ -999,7 +1001,7 @@ public: Loop() {} Loop(MixedArena& allocator) {} - Name out, in; + Name name; Expression *body; // set the type of a loop if you already know it @@ -1108,8 +1110,13 @@ public: Index index; Expression *value; - void finalize() { - type = value->type; + bool isTee() { + return type != none; + } + + void setTee(bool is) { + if (is) type = value->type; + else type = none; } }; @@ -1118,7 +1125,7 @@ public: GetGlobal() {} GetGlobal(MixedArena& allocator) {} - Index index; + Name name; }; class SetGlobal : public SpecificExpression<Expression::SetGlobalId> { @@ -1126,12 +1133,8 @@ public: SetGlobal() {} SetGlobal(MixedArena& allocator) {} - Index index; + Name name; Expression *value; - - void finalize() { - type = value->type; - } }; class Load : public SpecificExpression<Expression::LoadId> { @@ -1150,16 +1153,17 @@ public: class Store : public SpecificExpression<Expression::StoreId> { public: - Store() {} - Store(MixedArena& allocator) {} + Store() : valueType(none) {} + Store(MixedArena& allocator) : Store() {} uint8_t bytes; Address offset; Address align; Expression *ptr, *value; + WasmType valueType; // the store never returns a value void finalize() { - type = value->type; + assert(valueType != none); // must be set } }; @@ -1312,6 +1316,14 @@ public: } }; +class Drop : public SpecificExpression<Expression::DropId> { +public: + Drop() {} + Drop(MixedArena& allocator) {} + + Expression *value; +}; + class Return : public SpecificExpression<Expression::ReturnId> { public: Return() : value(nullptr) { @@ -1418,16 +1430,33 @@ public: class Import { public: - Import() : type(nullptr) {} + enum Kind { + Function = 0, + Table = 1, + Memory = 2, + Global = 3, + }; + + Import() : functionType(nullptr), globalType(none) {} Name name, module, base; // name = module.base - FunctionType* type; + Kind kind; + FunctionType* functionType; // for Function imports + WasmType globalType; // for Global imports }; class Export { public: - Name name; // exported name + enum Kind { + Function = 0, + Table = 1, + Memory = 2, + Global = 3, + }; + + Name name; // exported name - note that this is the key, as the internal name is non-unique (can have multiple exports for an internal, also over kinds) Name value; // internal name + Kind kind; }; class Table { @@ -1445,10 +1474,13 @@ public: } }; + Name name; Address initial, max; std::vector<Segment> segments; - Table() : initial(0), max(kMaxSize) {} + Table() : initial(0), max(kMaxSize) { + name = Name::fromInt(0); + } }; class Memory { @@ -1470,11 +1502,13 @@ public: } }; + Name name; Address initial, max; // sizes are in pages std::vector<Segment> segments; - Name exportName; - Memory() : initial(0), max(kMaxSize) {} + Memory() : initial(0), max(kMaxSize) { + name = Name::fromInt(0); + } }; class Global { @@ -1503,7 +1537,7 @@ private: // TODO: add a build option where Names are just indices, and then these methods are not needed std::map<Name, FunctionType*> functionTypesMap; std::map<Name, Import*> importsMap; - std::map<Name, Export*> exportsMap; + std::map<Name, Export*> exportsMap; // exports map is by the *exported* name, which is unique std::map<Name, Function*> functionsMap; std::map<Name, Global*> globalsMap; |