summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/asm2wasm.h27
-rw-r--r--src/ast_utils.h4
-rw-r--r--src/pass.h10
-rw-r--r--src/passes/pass.cpp23
-rw-r--r--src/tools/asm2wasm.cpp5
-rw-r--r--src/wasm-module-building.h23
-rw-r--r--src/wasm-validator.h20
7 files changed, 80 insertions, 32 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index bf42a8a0e..fa3a3dd4c 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -633,11 +633,15 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
if (body[i][0] == DEFUN) numFunctions++;
}
optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions, [&](PassRunner& passRunner) {
+ if (debug) {
+ passRunner.setDebug(true);
+ passRunner.setValidateGlobally(false);
+ }
// run autodrop first, before optimizations
passRunner.add<AutoDrop>();
// optimize relooper label variable usage at the wasm level, where it is easy
passRunner.add("relooper-jump-threading");
- });
+ }, debug, false /* do not validate globally yet */);
}
// first pass - do almost everything, but function imports and indirect calls
@@ -821,6 +825,10 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
if (optimize) {
optimizingBuilder->finish();
PassRunner passRunner(&wasm);
+ if (debug) {
+ passRunner.setDebug(true);
+ passRunner.setValidateGlobally(false);
+ }
passRunner.add("post-emscripten");
passRunner.run();
}
@@ -859,7 +867,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
Asm2WasmBuilder* parent;
- FinalizeCalls(Asm2WasmBuilder* parent) : parent(parent) {}
+ FinalizeCalls(Asm2WasmBuilder* parent) : parent(parent) {
+ name = "finalize-calls";
+ }
void visitCall(Call* curr) {
if (!getModule()->checkFunction(curr->target)) {
@@ -930,6 +940,10 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
};
PassRunner passRunner(&wasm);
+ if (debug) {
+ passRunner.setDebug(true);
+ passRunner.setValidateGlobally(false);
+ }
passRunner.add<FinalizeCalls>(this);
passRunner.add<ReFinalize>(); // FinalizeCalls changes call types, need to percolate
passRunner.add<AutoDrop>(); // FinalizeCalls may cause us to require additional drops
@@ -1069,9 +1083,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
auto name = ast[1]->getIString();
if (debug) {
- std::cout << "\nfunc: " << ast[1]->getIString().str << '\n';
- ast->stringify(std::cout);
- std::cout << '\n';
+ std::cout << "asm2wasming func: " << ast[1]->getIString().str << '\n';
}
auto function = new Function;
@@ -1141,11 +1153,6 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
std::function<Expression* (Ref)> process = [&](Ref ast) -> Expression* {
AstStackHelper astStackHelper(ast); // TODO: only create one when we need it?
- if (debug) {
- std::cout << "at: ";
- ast->stringify(std::cout);
- std::cout << '\n';
- }
IString what = ast[0]->getIString();
if (what == STAT) {
return process(ast[1]); // and drop return value, if any
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 8b2ae0d59..8124d3a79 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -885,6 +885,8 @@ struct ExpressionAnalyzer {
// Finalizes a node
struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, Visitor<ReFinalize>>> {
+ ReFinalize() { name = "refinalize"; }
+
void visitBlock(Block *curr) { curr->finalize(); }
void visitIf(If *curr) { curr->finalize(); }
void visitLoop(Loop *curr) { curr->finalize(); }
@@ -917,6 +919,8 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<Auto
Pass* create() override { return new AutoDrop; }
+ AutoDrop() { name = "autodrop"; }
+
bool maybeDrop(Expression*& child) {
bool acted = false;
if (isConcreteWasmType(child->type)) {
diff --git a/src/pass.h b/src/pass.h
index e237b8a98..35065b0a3 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -62,12 +62,18 @@ struct PassRunner {
Module* wasm;
MixedArena* allocator;
std::vector<Pass*> passes;
- Pass* currPass;
bool debug = false;
+ bool validateGlobally = false;
PassRunner(Module* wasm) : wasm(wasm), allocator(&wasm->allocator) {}
- void setDebug(bool debug_) { debug = debug_; }
+ void setDebug(bool debug_) {
+ debug = debug_;
+ validateGlobally = debug; // validate everything by default if debugging
+ }
+ void setValidateGlobally(bool validate) {
+ validateGlobally = validate;
+ }
void add(std::string passName) {
auto pass = PassRegistry::get()->createPass(passName);
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index ac5ad04c3..59155c834 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -15,6 +15,7 @@
*/
#include <chrono>
+#include <sstream>
#include <passes/passes.h>
#include <pass.h>
@@ -143,7 +144,10 @@ void PassRunner::run() {
padding = std::max(padding, pass->name.size());
}
for (auto* pass : passes) {
- currPass = pass;
+ // ignoring the time, save a printout of the module before, in case this pass breaks it, so we can print the before and after
+ std::stringstream moduleBefore;
+ WasmPrinter::printModule(wasm, moduleBefore);
+ // prepare to run
std::chrono::high_resolution_clock::time_point before;
std::cerr << "[PassRunner] running pass: " << pass->name << "... ";
for (size_t i = 0; i < padding - pass->name.size(); i++) {
@@ -162,18 +166,20 @@ void PassRunner::run() {
std::chrono::duration<double> diff = after - before;
std::cerr << diff.count() << " seconds." << std::endl;
totalTime += diff;
-#if 0
// validate, ignoring the time
std::cerr << "[PassRunner] (validating)\n";
- if (!WasmValidator().validate(*wasm)) {
- std::cerr << "last pass (" << pass->name << ") broke validation\n";
+ if (!WasmValidator().validate(*wasm, false, validateGlobally)) {
+ std::cerr << "Last pass (" << pass->name << ") broke validation. Here is the module before: \n" << moduleBefore.str() << "\n";
abort();
}
-#endif
}
std::cerr << "[PassRunner] passes took " << totalTime.count() << " seconds." << std::endl;
// validate
- assert(WasmValidator().validate(*wasm));
+ std::cerr << "[PassRunner] (final validation)\n";
+ if (!WasmValidator().validate(*wasm, false, validateGlobally)) {
+ std::cerr << "final module does not validate\n";
+ abort();
+ }
} else {
// non-debug normal mode, run them in an optimal manner - for locality it is better
// to run as many passes as possible on a single function before moving to the next
@@ -241,6 +247,11 @@ void PassRunner::doAdd(Pass* pass) {
}
void PassRunner::runPassOnFunction(Pass* pass, Function* func) {
+#if 0
+ if (debug) {
+ std::cerr << "[PassRunner] runPass " << pass->name << " OnFunction " << func->name << "\n";
+ }
+#endif
// function-parallel passes get a new instance per function
if (pass->isFunctionParallel()) {
auto instance = std::unique_ptr<Pass>(pass->create());
diff --git a/src/tools/asm2wasm.cpp b/src/tools/asm2wasm.cpp
index 49e04470b..cd500b0d9 100644
--- a/src/tools/asm2wasm.cpp
+++ b/src/tools/asm2wasm.cpp
@@ -89,11 +89,6 @@ int main(int argc, const char *argv[]) {
if (options.debug) std::cerr << "parsing..." << std::endl;
cashew::Parser<Ref, DotZeroValueBuilder> builder;
Ref asmjs = builder.parseToplevel(start);
- if (options.debug) {
- std::cerr << "parsed:" << std::endl;
- asmjs->stringify(std::cerr, true);
- std::cerr << std::endl;
- }
if (options.debug) std::cerr << "wasming..." << std::endl;
Module wasm;
diff --git a/src/wasm-module-building.h b/src/wasm-module-building.h
index 92b96d98d..56d41ecb8 100644
--- a/src/wasm-module-building.h
+++ b/src/wasm-module-building.h
@@ -83,16 +83,18 @@ class OptimizingIncrementalModuleBuilder {
std::mutex mutex;
std::condition_variable condition;
bool finishing;
+ bool debug;
+ bool validateGlobally;
public:
// numFunctions must be equal to the number of functions allocated, or higher. Knowing
// this bounds helps avoid locking.
- OptimizingIncrementalModuleBuilder(Module* wasm, Index numFunctions, std::function<void (PassRunner&)> addPrePasses)
+ OptimizingIncrementalModuleBuilder(Module* wasm, Index numFunctions, std::function<void (PassRunner&)> addPrePasses, bool debug, bool validateGlobally)
: wasm(wasm), numFunctions(numFunctions), addPrePasses(addPrePasses), endMarker(nullptr), list(nullptr), nextFunction(0),
numWorkers(0), liveWorkers(0), activeWorkers(0), availableFuncs(0), finishedFuncs(0),
- finishing(false) {
- if (numFunctions == 0) {
- // special case: no functions to be optimized. Don't create any threads.
+ finishing(false), debug(debug), validateGlobally(validateGlobally) {
+ if (numFunctions == 0 || debug) {
+ // if no functions to be optimized, or debug non-parallel mode, don't create any threads.
return;
}
@@ -134,6 +136,7 @@ public:
// Add a function to the module, and to be optimized
void addFunction(Function* func) {
wasm->addFunction(func);
+ if (debug) return; // we optimize at the end if debugging
queueFunction(func);
// wake workers if needed
auto wake = availableFuncs.load();
@@ -145,6 +148,18 @@ public:
// All functions have been added, block until all are optimized, and then do
// global optimizations. When this returns, the module is ready and optimized.
void finish() {
+ if (debug) {
+ // in debug mode, optimize each function now that we are done adding functions,
+ // then optimize globally
+ PassRunner passRunner(wasm);
+ passRunner.setDebug(true);
+ passRunner.setValidateGlobally(validateGlobally);
+ addPrePasses(passRunner);
+ passRunner.addDefaultFunctionOptimizationPasses();
+ passRunner.addDefaultGlobalOptimizationPasses();
+ passRunner.run();
+ return;
+ }
DEBUG_THREAD("finish()ing");
assert(nextFunction == numFunctions);
wakeAllWorkers();
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index bdc075946..d741fafdc 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -28,7 +28,8 @@ namespace wasm {
struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>> {
bool valid = true;
- bool validateWebConstraints = false;
+ bool validateWeb = false;
+ bool validateGlobally = true;
struct BreakInfo {
WasmType type;
@@ -43,8 +44,9 @@ struct WasmValidator : public PostWalker<WasmValidator, Visitor<WasmValidator>>
WasmType returnType = unreachable; // type used in returns
public:
- bool validate(Module& module, bool validateWeb=false) {
- validateWebConstraints = validateWeb;
+ bool validate(Module& module, bool validateWeb_ = false, bool validateGlobally_ = true) {
+ validateWeb = validateWeb_;
+ validateGlobally = validateGlobally_;
walkModule(&module);
if (!valid) {
WasmPrinter::printModule(&module, std::cerr);
@@ -175,6 +177,7 @@ public:
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
void visitCall(Call *curr) {
+ if (!validateGlobally) return;
auto* target = getModule()->checkFunction(curr->target);
if (!shouldBeTrue(!!target, curr, "call target must exist")) return;
if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return;
@@ -185,8 +188,10 @@ public:
}
}
void visitCallImport(CallImport *curr) {
+ if (!validateGlobally) return;
auto* import = getModule()->checkImport(curr->target);
if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
+ if (!shouldBeTrue(import->functionType, curr, "called import must be function")) return;
auto* type = import->functionType;
if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
@@ -196,6 +201,7 @@ public:
}
}
void visitCallIndirect(CallIndirect *curr) {
+ if (!validateGlobally) return;
auto* type = getModule()->checkFunctionType(curr->fullType);
if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return;
shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
@@ -336,7 +342,8 @@ public:
}
void visitImport(Import* curr) {
- if (!validateWebConstraints) return;
+ if (!validateWeb) return;
+ if (!validateGlobally) return;
if (curr->kind == Import::Function) {
shouldBeUnequal(curr->functionType->result, i64, curr->name, "Imported function must not have i64 return type");
for (WasmType param : curr->functionType->params) {
@@ -346,7 +353,8 @@ public:
}
void visitExport(Export* curr) {
- if (!validateWebConstraints) return;
+ if (!validateWeb) return;
+ if (!validateGlobally) return;
Function* f = getModule()->getFunction(curr->value);
shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type");
for (auto param : f->params) {
@@ -355,6 +363,7 @@ public:
}
void visitGlobal(Global* curr) {
+ if (!validateGlobally) return;
shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid");
shouldBeEqual(curr->type, curr->init->type, nullptr, "global init must have correct type");
}
@@ -402,6 +411,7 @@ public:
}
}
void visitModule(Module *curr) {
+ if (!validateGlobally) return;
// exports
std::set<Name> exportNames;
for (auto& exp : curr->exports) {