diff options
Diffstat (limited to 'src/passes')
-rw-r--r-- | src/passes/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/passes/CoalesceLocals.cpp | 12 | ||||
-rw-r--r-- | src/passes/DeadCodeElimination.cpp | 58 | ||||
-rw-r--r-- | src/passes/DropReturnValues.cpp | 83 | ||||
-rw-r--r-- | src/passes/LowerIfElse.cpp | 67 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 69 | ||||
-rw-r--r-- | src/passes/NameManager.cpp | 3 | ||||
-rw-r--r-- | src/passes/Print.cpp | 148 | ||||
-rw-r--r-- | src/passes/RemoveImports.cpp | 2 | ||||
-rw-r--r-- | src/passes/RemoveUnusedBrs.cpp | 2 | ||||
-rw-r--r-- | src/passes/RemoveUnusedNames.cpp | 28 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 42 | ||||
-rw-r--r-- | src/passes/Vacuum.cpp | 48 | ||||
-rw-r--r-- | src/passes/pass.cpp | 5 | ||||
-rw-r--r-- | src/passes/passes.h | 1 |
15 files changed, 270 insertions, 300 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 2ccf5f040..1b4c65562 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -2,9 +2,7 @@ SET(passes_SOURCES pass.cpp CoalesceLocals.cpp DeadCodeElimination.cpp - DropReturnValues.cpp DuplicateFunctionElimination.cpp - LowerIfElse.cpp MergeBlocks.cpp Metrics.cpp NameManager.cpp diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp index 2c707b614..2063a20db 100644 --- a/src/passes/CoalesceLocals.cpp +++ b/src/passes/CoalesceLocals.cpp @@ -485,9 +485,19 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root) // in addition, we can optimize out redundant copies and ineffective sets GetLocal* get; if ((get = set->value->dynCast<GetLocal>()) && get->index == set->index) { - *action.origin = get; // further optimizations may get rid of the get, if this is in a place where the output does not matter + if (set->isTee()) { + *action.origin = get; + } else { + ExpressionManipulator::nop(set); + } } else if (!action.effective) { *action.origin = set->value; // value may have no side effects, further optimizations can eliminate it + if (!set->isTee()) { + // we need to drop it + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(set); + drop->value = *action.origin; + *action.origin = drop; + } } } } diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index de1191131..b30b8ffbd 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -31,6 +31,7 @@ #include <wasm.h> #include <pass.h> #include <ast_utils.h> +#include <wasm-builder.h> namespace wasm { @@ -131,12 +132,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V } void visitLoop(Loop* curr) { - if (curr->in.is()) { - reachableBreaks.erase(curr->in); - } - if (curr->out.is()) { - reachable = reachable || reachableBreaks.count(curr->out); - reachableBreaks.erase(curr->out); + if (curr->name.is()) { + reachableBreaks.erase(curr->name); } if (isDead(curr->body)) { replaceCurrent(curr->body); @@ -191,6 +188,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V case Expression::Id::UnaryId: DELEGATE(Unary); case Expression::Id::BinaryId: DELEGATE(Binary); case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); case Expression::Id::ReturnId: DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); @@ -226,46 +224,52 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V // other things + Expression* drop(Expression* toDrop) { + if (toDrop->is<Unreachable>()) return toDrop; + return Builder(*getModule()).makeDrop(toDrop); + } + template<typename T> - void handleCall(T* curr, Expression* initial) { + Expression* handleCall(T* curr) { for (Index i = 0; i < curr->operands.size(); i++) { if (isDead(curr->operands[i])) { - if (i > 0 || initial != nullptr) { + if (i > 0) { auto* block = getModule()->allocator.alloc<Block>(); - Index newSize = i + 1 + (initial ? 1 : 0); + Index newSize = i + 1; block->list.resize(newSize); Index j = 0; - if (initial) { - block->list[j] = initial; - j++; - } for (; j < newSize; j++) { - block->list[j] = curr->operands[j - (initial ? 1 : 0)]; + block->list[j] = drop(curr->operands[j]); } block->finalize(); - replaceCurrent(block); + return replaceCurrent(block); } else { - replaceCurrent(curr->operands[i]); + return replaceCurrent(curr->operands[i]); } - return; } } + return curr; } void visitCall(Call* curr) { - handleCall(curr, nullptr); + handleCall(curr); } void visitCallImport(CallImport* curr) { - handleCall(curr, nullptr); + handleCall(curr); } void visitCallIndirect(CallIndirect* curr) { + if (handleCall(curr) != curr) return; if (isDead(curr->target)) { - replaceCurrent(curr->target); - return; + auto* block = getModule()->allocator.alloc<Block>(); + for (auto* operand : curr->operands) { + block->list.push_back(drop(operand)); + } + block->list.push_back(curr->target); + block->finalize(); + replaceCurrent(block); } - handleCall(curr, curr->target); } void visitSetLocal(SetLocal* curr) { @@ -288,7 +292,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->value)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->ptr; + block->list[0] = drop(curr->ptr); block->list[1] = curr->value; block->finalize(); replaceCurrent(block); @@ -309,7 +313,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->right)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->left; + block->list[0] = drop(curr->left); block->list[1] = curr->right; block->finalize(); replaceCurrent(block); @@ -324,7 +328,7 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->ifFalse)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(2); - block->list[0] = curr->ifTrue; + block->list[0] = drop(curr->ifTrue); block->list[1] = curr->ifFalse; block->finalize(); replaceCurrent(block); @@ -333,8 +337,8 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, V if (isDead(curr->condition)) { auto* block = getModule()->allocator.alloc<Block>(); block->list.resize(3); - block->list[0] = curr->ifTrue; - block->list[1] = curr->ifFalse; + block->list[0] = drop(curr->ifTrue); + block->list[1] = drop(curr->ifFalse); block->list[2] = curr->condition; block->finalize(); replaceCurrent(block); diff --git a/src/passes/DropReturnValues.cpp b/src/passes/DropReturnValues.cpp deleted file mode 100644 index 8715f3f61..000000000 --- a/src/passes/DropReturnValues.cpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2016 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -// Stops using return values from set_local and store nodes. -// - -#include <wasm.h> -#include <pass.h> -#include <ast_utils.h> -#include <wasm-builder.h> - -namespace wasm { - -struct DropReturnValues : public WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>> { - bool isFunctionParallel() override { return true; } - - Pass* create() override { return new DropReturnValues; } - - std::vector<Expression*> expressionStack; - - void visitSetLocal(SetLocal* curr) { - if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { - Builder builder(*getModule()); - replaceCurrent(builder.makeSequence( - curr, - builder.makeGetLocal(curr->index, curr->type) - )); - } - } - - void visitStore(Store* curr) { - if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) { - Index index = getFunction()->getNumLocals(); - getFunction()->vars.emplace_back(curr->type); - Builder builder(*getModule()); - replaceCurrent(builder.makeSequence( - builder.makeSequence( - builder.makeSetLocal(index, curr->value), - curr - ), - builder.makeGetLocal(index, curr->type) - )); - curr->value = builder.makeGetLocal(index, curr->type); - } - } - - static void visitPre(DropReturnValues* self, Expression** currp) { - self->expressionStack.push_back(*currp); - } - - static void visitPost(DropReturnValues* self, Expression** currp) { - self->expressionStack.pop_back(); - } - - static void scan(DropReturnValues* self, Expression** currp) { - self->pushTask(visitPost, currp); - - WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>>::scan(self, currp); - - self->pushTask(visitPre, currp); - } -}; - -Pass *createDropReturnValuesPass() { - return new DropReturnValues(); -} - -} // namespace wasm - diff --git a/src/passes/LowerIfElse.cpp b/src/passes/LowerIfElse.cpp deleted file mode 100644 index b566e8207..000000000 --- a/src/passes/LowerIfElse.cpp +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2015 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -// Lowers if (x) y else z into -// -// L: { -// if (x) break (y) L -// z -// } -// -// This is useful for investigating how beneficial if_else is. -// - -#include <memory> - -#include <wasm.h> -#include <pass.h> - -namespace wasm { - -struct LowerIfElse : public WalkerPass<PostWalker<LowerIfElse, Visitor<LowerIfElse>>> { - MixedArena* allocator; - std::unique_ptr<NameManager> namer; - - void prepare(PassRunner* runner, Module *module) override { - allocator = runner->allocator; - namer = make_unique<NameManager>(); - namer->run(runner, module); - } - - void visitIf(If *curr) { - if (curr->ifFalse) { - auto block = allocator->alloc<Block>(); - auto name = namer->getUnique("L"); // TODO: getUniqueInFunction - block->name = name; - block->list.push_back(curr); - block->list.push_back(curr->ifFalse); - block->finalize(); - curr->ifFalse = nullptr; - auto break_ = allocator->alloc<Break>(); - break_->name = name; - break_->value = curr->ifTrue; - curr->ifTrue = break_; - replaceCurrent(block); - } - } -}; - -Pass *createLowerIfElsePass() { - return new LowerIfElse(); -} - -} // namespace wasm diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 686bb5d75..bde9397a8 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -64,9 +64,40 @@ #include <wasm.h> #include <pass.h> #include <ast_utils.h> +#include <wasm-builder.h> namespace wasm { +struct SwitchFinder : public ControlFlowWalker<SwitchFinder, Visitor<SwitchFinder>> { + Expression* origin; + bool found = false; + + void visitSwitch(Switch* curr) { + if (findBreakTarget(curr->default_) == origin) { + found = true; + return; + } + for (auto& target : curr->targets) { + if (findBreakTarget(target) == origin) { + found = true; + return; + } + } + } +}; + +struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper, Visitor<BreakValueDropper>> { + Expression* origin; + + void visitBreak(Break* curr) { + if (curr->value && findBreakTarget(curr->name) == origin) { + Builder builder(*getModule()); + replaceCurrent(builder.makeSequence(builder.makeDrop(curr->value), curr)); + curr->value = nullptr; + } + } +}; + struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBlocks>>> { bool isFunctionParallel() override { return true; } @@ -74,10 +105,46 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc void visitBlock(Block *curr) { bool more = true; + bool changed = false; while (more) { more = false; for (size_t i = 0; i < curr->list.size(); i++) { Block* child = curr->list[i]->dynCast<Block>(); + if (!child) { + // if we have a child that is (drop (block ..)) then we can move the drop into the block, and remove br values. this allows more merging, + auto* drop = curr->list[i]->dynCast<Drop>(); + if (drop) { + child = drop->value->dynCast<Block>(); + if (child) { + if (child->name.is()) { + Expression* expression = child; + // if there is a switch targeting us, we can't do it - we can't remove the value from other targets too + SwitchFinder finder; + finder.origin = child; + finder.walk(expression); + if (finder.found) { + child = nullptr; + } else { + // fix up breaks + BreakValueDropper fixer; + fixer.origin = child; + fixer.setModule(getModule()); + fixer.walk(expression); + } + } + if (child) { + // we can do it! + // reuse the drop + drop->value = child->list.back(); + child->list.back() = drop; + child->finalize(); + curr->list[i] = child; + more = true; + changed = true; + } + } + } + } if (!child) continue; if (child->name.is()) continue; // named blocks can have breaks to them (and certainly do, if we ran RemoveUnusedNames and RemoveUnusedBrs) ExpressionList merged(getModule()->allocator); @@ -92,9 +159,11 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBloc } curr->list = merged; more = true; + changed = true; break; } } + if (changed) curr->finalize(); } Block* optimize(Expression* curr, Expression*& child, Block* outer = nullptr, Expression** dependency1 = nullptr, Expression** dependency2 = nullptr) { diff --git a/src/passes/NameManager.cpp b/src/passes/NameManager.cpp index df8b34557..9f0198c2f 100644 --- a/src/passes/NameManager.cpp +++ b/src/passes/NameManager.cpp @@ -37,8 +37,7 @@ void NameManager::visitBlock(Block* curr) { names.insert(curr->name); } void NameManager::visitLoop(Loop* curr) { - names.insert(curr->out); - names.insert(curr->in); + names.insert(curr->name); } void NameManager::visitBreak(Break* curr) { names.insert(curr->name); diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 5eea38bdc..43ddc954c 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -32,14 +32,17 @@ struct PrintSExpression : public Visitor<PrintSExpression> { const char *maybeSpace; const char *maybeNewLine; - bool fullAST = false; // whether to not elide nodes in output when possible - // (like implicit blocks) + bool full = false; // whether to not elide nodes in output when possible + // (like implicit blocks) and to emit types Module* currModule = nullptr; Function* currFunction = nullptr; PrintSExpression(std::ostream& o) : o(o) { setMinify(false); + if (getenv("BINARYEN_PRINT_FULL")) { + full = std::stoi(getenv("BINARYEN_PRINT_FULL")); + } } void setMinify(bool minify_) { @@ -48,7 +51,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { maybeNewLine = minify ? "" : "\n"; } - void setFullAST(bool fullAST_) { fullAST = fullAST_; } + void setFull(bool full_) { full = full_; } void incIndent() { if (minify) return; @@ -64,6 +67,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void printFullLine(Expression *expression) { !minify && doIndent(o, indent); + if (full) { + o << "[" << printWasmType(expression->type) << "] "; + } visit(expression); o << maybeNewLine; } @@ -79,10 +85,6 @@ struct PrintSExpression : public Visitor<PrintSExpression> { return name; } - Name printableGlobal(Index index) { - return currModule->getGlobal(index)->name; - } - std::ostream& printName(Name name) { // we need to quote names if they have tricky chars if (strpbrk(name.str, "()")) { @@ -99,6 +101,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { while (1) { if (stack.size() > 0) doIndent(o, indent); stack.push_back(curr); + if (full) { + o << "[" << printWasmType(curr->type) << "] "; + } printOpening(o, "block"); if (curr->name.is()) { o << ' '; @@ -135,13 +140,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { incIndent(); printFullLine(curr->condition); // ifTrue and False have implict blocks, avoid printing them if possible - if (!fullAST && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) { + if (!full && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) { printFullLine(curr->ifTrue->dynCast<Block>()->list.back()); } else { printFullLine(curr->ifTrue); } if (curr->ifFalse) { - if (!fullAST && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) { + if (!full && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) { printFullLine(curr->ifFalse->dynCast<Block>()->list.back()); } else { printFullLine(curr->ifFalse); @@ -151,16 +156,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitLoop(Loop *curr) { printOpening(o, "loop"); - if (curr->out.is()) { - o << ' ' << curr->out; - assert(curr->in.is()); // if just one is printed, it must be the in - } - if (curr->in.is()) { - o << ' ' << curr->in; + if (curr->name.is()) { + o << ' ' << curr->name; } incIndent(); auto block = curr->body->dynCast<Block>(); - if (!fullAST && block && block->name.isNull()) { + if (!full && block && block->name.isNull()) { // wasm spec has loops containing children directly, while our ast // has a single child for simplicity. print out the optimal form. for (auto expression : block->list) { @@ -229,26 +230,33 @@ struct PrintSExpression : public Visitor<PrintSExpression> { void visitCallIndirect(CallIndirect *curr) { printOpening(o, "call_indirect ") << curr->fullType; incIndent(); - printFullLine(curr->target); for (auto operand : curr->operands) { printFullLine(operand); } + printFullLine(curr->target); decIndent(); } void visitGetLocal(GetLocal *curr) { printOpening(o, "get_local ") << printableLocal(curr->index) << ')'; } void visitSetLocal(SetLocal *curr) { - printOpening(o, "set_local ") << printableLocal(curr->index); + if (curr->isTee()) { + printOpening(o, "tee_local "); + } else { + printOpening(o, "set_local "); + } + o << printableLocal(curr->index); incIndent(); printFullLine(curr->value); decIndent(); } void visitGetGlobal(GetGlobal *curr) { - printOpening(o, "get_global ") << printableGlobal(curr->index) << ')'; + printOpening(o, "get_global "); + printName(curr->name) << ')'; } void visitSetGlobal(SetGlobal *curr) { - printOpening(o, "set_global ") << printableGlobal(curr->index); + printOpening(o, "set_global "); + printName(curr->name); incIndent(); printFullLine(curr->value); decIndent(); @@ -281,7 +289,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitStore(Store *curr) { o << '('; - prepareColor(o) << printWasmType(curr->type) << ".store"; + prepareColor(o) << printWasmType(curr->valueType) << ".store"; if (curr->bytes < 4 || (curr->type == i64 && curr->bytes < 8)) { if (curr->bytes == 1) { o << '8'; @@ -466,9 +474,16 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printFullLine(curr->condition); decIndent(); } + void visitDrop(Drop *curr) { + o << '('; + prepareColor(o) << "drop"; + incIndent(); + printFullLine(curr->value); + decIndent(); + } void visitReturn(Return *curr) { printOpening(o, "return"); - if (!curr->value || curr->value->is<Nop>()) { + if (!curr->value) { // avoid a new line just for the parens o << ')'; return; @@ -499,11 +514,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printMinorOpening(o, "unreachable") << ')'; } // Module-level visitors - void visitFunctionType(FunctionType *curr, bool full=false) { - if (full) { - printOpening(o, "type") << ' '; - printName(curr->name) << " (func"; - } + void visitFunctionType(FunctionType *curr, Name* internalName = nullptr) { + o << "(func"; + if (internalName) o << ' ' << *internalName; if (curr->params.size() > 0) { o << maybeSpace; printMinorOpening(o, "param"); @@ -516,27 +529,39 @@ struct PrintSExpression : public Visitor<PrintSExpression> { o << maybeSpace; printMinorOpening(o, "result ") << printWasmType(curr->result) << ')'; } - if (full) { - o << "))"; - } + o << ")"; } void visitImport(Import *curr) { printOpening(o, "import "); - printName(curr->name) << ' '; printText(o, curr->module.str) << ' '; - printText(o, curr->base.str); - if (curr->type) visitFunctionType(curr->type); + printText(o, curr->base.str) << ' '; + switch (curr->kind) { + case Export::Function: if (curr->functionType) visitFunctionType(curr->functionType, &curr->name); break; + case Export::Table: o << "(table " << curr->name << ")"; break; + case Export::Memory: o << "(memory " << curr->name << ")"; break; + case Export::Global: o << "(global " << curr->name << ' ' << printWasmType(curr->globalType) << ")"; break; + default: WASM_UNREACHABLE(); + } o << ')'; } void visitExport(Export *curr) { printOpening(o, "export "); - printText(o, curr->name.str) << ' '; - printName(curr->value) << ')'; + printText(o, curr->name.str) << " ("; + switch (curr->kind) { + case Export::Function: o << "func"; break; + case Export::Table: o << "table"; break; + case Export::Memory: o << "memory"; break; + case Export::Global: o << "global"; break; + default: WASM_UNREACHABLE(); + } + o << ' '; + printName(curr->value) << "))"; } void visitGlobal(Global *curr) { printOpening(o, "global "); - printName(curr->name) << ' ' << printWasmType(curr->type); - printFullLine(curr->init); + printName(curr->name) << ' '; + o << printWasmType(curr->type) << ' '; + visit(curr->init); o << ')'; } void visitFunction(Function *curr) { @@ -564,7 +589,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } // It is ok to emit a block here, as a function can directly contain a list, even if our // ast avoids that for simplicity. We can just do that optimization here.. - if (!fullAST && curr->body->is<Block>() && curr->body->cast<Block>()->name.isNull()) { + if (!full && curr->body->is<Block>() && curr->body->cast<Block>()->name.isNull()) { Block* block = curr->body->cast<Block>(); for (auto item : block->list) { printFullLine(item); @@ -575,7 +600,8 @@ struct PrintSExpression : public Visitor<PrintSExpression> { decIndent(); } void visitTable(Table *curr) { - printOpening(o, "table") << ' ' << curr->initial; + printOpening(o, "table") << ' '; + o << curr->initial; if (curr->max && curr->max != Table::kMaxSize) o << ' ' << curr->max; o << " anyfunc)\n"; doIndent(o, indent); @@ -589,15 +615,12 @@ struct PrintSExpression : public Visitor<PrintSExpression> { o << ')'; } } - void visitModule(Module *curr) { - currModule = curr; - printOpening(o, "module", true); - incIndent(); - doIndent(o, indent); - printOpening(o, "memory") << ' ' << curr->memory.initial; - if (curr->memory.max && curr->memory.max != Memory::kMaxSize) o << ' ' << curr->memory.max; + void visitMemory(Memory* curr) { + printOpening(o, "memory") << ' '; + o << curr->initial; + if (curr->max && curr->max != Memory::kMaxSize) o << ' ' << curr->max; o << ")\n"; - for (auto segment : curr->memory.segments) { + for (auto segment : curr->segments) { doIndent(o, indent); printOpening(o, "data ", true); visit(segment.offset); @@ -624,12 +647,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } o << "\")\n"; } - if (curr->memory.exportName.is()) { - doIndent(o, indent); - printOpening(o, "export "); - printText(o, curr->memory.exportName.str) << " memory)"; - o << maybeNewLine; - } + } + void visitModule(Module *curr) { + currModule = curr; + printOpening(o, "module", true); + incIndent(); + doIndent(o, indent); + visitMemory(&curr->memory); if (curr->start.is()) { doIndent(o, indent); printOpening(o, "start") << ' ' << curr->start << ')'; @@ -637,8 +661,10 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } for (auto& child : curr->functionTypes) { doIndent(o, indent); - visitFunctionType(child.get(), true); - o << maybeNewLine; + printOpening(o, "type") << ' '; + printName(child->name) << ' '; + visitFunctionType(child.get()); + o << ")" << maybeNewLine; } for (auto& child : curr->imports) { doIndent(o, indent); @@ -707,7 +733,7 @@ public: void run(PassRunner* runner, Module* module) override { PrintSExpression print(o); - print.setFullAST(true); + print.setFull(true); print.visitModule(module); } }; @@ -718,9 +744,17 @@ Pass *createFullPrinterPass() { // Print individual expressions -std::ostream& WasmPrinter::printExpression(Expression* expression, std::ostream& o, bool minify) { +std::ostream& WasmPrinter::printExpression(Expression* expression, std::ostream& o, bool minify, bool full) { + if (!expression) { + o << "(null expression)"; + return o; + } PrintSExpression print(o); print.setMinify(minify); + if (full) { + print.setFull(true); + o << "[" << printWasmType(expression->type) << "] "; + } print.visit(expression); return o; } diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 0b3f50049..19d6c3eb1 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -37,7 +37,7 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports, Visitor<Remov } void visitCallImport(CallImport *curr) { - WasmType type = module->getImport(curr->target)->type->result; + WasmType type = module->getImport(curr->target)->functionType->result; if (type == none) { replaceCurrent(allocator->alloc<Nop>()); } else { diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index 263cea655..59d3af6fc 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -189,7 +189,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R // finally, we may have simplified ifs enough to turn them into selects struct Selectifier : public WalkerPass<PostWalker<Selectifier, Visitor<Selectifier>>> { void visitIf(If* curr) { - if (curr->ifFalse) { + if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) { // if with else, consider turning it into a select if there is no control flow // TODO: estimate cost EffectAnalyzer condition(curr->condition); diff --git a/src/passes/RemoveUnusedNames.cpp b/src/passes/RemoveUnusedNames.cpp index 9c6743479..8e24f9549 100644 --- a/src/passes/RemoveUnusedNames.cpp +++ b/src/passes/RemoveUnusedNames.cpp @@ -76,34 +76,12 @@ struct RemoveUnusedNames : public WalkerPass<PostWalker<RemoveUnusedNames, Visit } } handleBreakTarget(curr->name); - if (curr->name.is() && curr->list.size() == 1) { - auto* child = curr->list[0]->dynCast<Loop>(); - if (child && !child->out.is()) { - // we have just one child, this loop, and it lacks an out label. So this block's name is doing just that! - child->out = curr->name; - replaceCurrent(child); - } - } } void visitLoop(Loop *curr) { - handleBreakTarget(curr->in); - // Loops can have just 'in', but cannot have just 'out' - auto out = curr->out; - handleBreakTarget(curr->out); - if (curr->out.is() && !curr->in.is()) { - auto* block = getModule()->allocator.alloc<Block>(); - block->name = out; - block->list.push_back(curr->body); - replaceCurrent(block); - } - if (curr->in.is() && !curr->out.is()) { - auto* child = curr->body->dynCast<Block>(); - if (child && child->name.is()) { - // we have just one child, this block, and we lack an out label. So we can take the block's! - curr->out = child->name; - child->name = Name(); - } + handleBreakTarget(curr->name); + if (!curr->name.is()) { + replaceCurrent(curr->body); } } diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index d9786e62c..5315edac4 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -55,7 +55,13 @@ struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemo void visitSetLocal(SetLocal *curr) { if ((*numGetLocals)[curr->index] == 0) { - replaceCurrent(curr->value); + auto* value = curr->value; + if (curr->isTee()) { + replaceCurrent(value); + } else { + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(curr); + drop->value = value; + } } } }; @@ -180,7 +186,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, auto found = sinkables.find(curr->index); if (found != sinkables.end()) { // sink it, and nop the origin - replaceCurrent(*found->second.item); + auto* set = (*found->second.item)->cast<SetLocal>(); + replaceCurrent(set); + assert(!set->isTee()); + set->setTee(true); // reuse the getlocal that is dying *found->second.item = curr; ExpressionManipulator::nop(curr); @@ -189,6 +198,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, } } + void visitDrop(Drop* curr) { + // collapse drop-tee into set, which can occur if a get was sunk into a tee + auto* set = curr->value->dynCast<SetLocal>(); + if (set) { + assert(set->isTee()); + set->setTee(false); + replaceCurrent(set); + } + } + void checkInvalidations(EffectAnalyzer& effects) { // TODO: this is O(bad) std::vector<Index> invalidated; @@ -225,7 +244,11 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, // store is dead, leave just the value auto found = self->sinkables.find(set->index); if (found != self->sinkables.end()) { - *found->second.item = (*found->second.item)->cast<SetLocal>()->value; + auto* previous = (*found->second.item)->cast<SetLocal>(); + assert(!previous->isTee()); + auto* previousValue = previous->value; + Drop* drop = ExpressionManipulator::convert<SetLocal, Drop>(previous); + drop->value = previousValue; self->sinkables.erase(found); self->anotherCycle = true; } @@ -236,15 +259,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, self->checkInvalidations(effects); } - if (set) { - // we may be a replacement for the current node, update the stack - self->expressionStack.pop_back(); - self->expressionStack.push_back(set); - if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) { - Index index = set->index; - assert(self->sinkables.count(index) == 0); - self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); - } + if (set && !set->isTee()) { + Index index = set->index; + assert(self->sinkables.count(index) == 0); + self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); } self->expressionStack.pop_back(); diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index 0195427d7..db42a994a 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -25,13 +25,11 @@ namespace wasm { -struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { +struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new Vacuum; } - std::vector<Expression*> expressionStack; - // returns nullptr if curr is dead, curr if it must stay as is, or another node if it can be replaced Expression* optimize(Expression* curr, bool resultUsed) { while (1) { @@ -41,6 +39,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { case Expression::Id::BlockId: return curr; // not always needed, but handled in visitBlock() case Expression::Id::IfId: return curr; // not always needed, but handled in visitIf() case Expression::Id::LoopId: return curr; // not always needed, but handled in visitLoop() + case Expression::Id::DropId: return curr; // not always needed, but handled in visitDrop() case Expression::Id::BreakId: case Expression::Id::SwitchId: @@ -51,6 +50,8 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { case Expression::Id::LoadId: case Expression::Id::StoreId: case Expression::Id::ReturnId: + case Expression::Id::GetGlobalId: + case Expression::Id::SetGlobalId: case Expression::Id::HostId: case Expression::Id::UnreachableId: return curr; // always needed @@ -189,7 +190,7 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { // no else if (curr->ifTrue->is<Nop>()) { // no nothing - replaceCurrent(curr->condition); + replaceCurrent(Builder(*getModule()).makeDrop(curr->condition)); } } } @@ -198,21 +199,30 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> { if (curr->body->is<Nop>()) ExpressionManipulator::nop(curr); } - static void visitPre(Vacuum* self, Expression** currp) { - self->expressionStack.push_back(*currp); - } - - static void visitPost(Vacuum* self, Expression** currp) { - self->expressionStack.pop_back(); - } - - // override scan to add a pre and a post check task to all nodes - static void scan(Vacuum* self, Expression** currp) { - self->pushTask(visitPost, currp); - - WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>>::scan(self, currp); - - self->pushTask(visitPre, currp); + void visitDrop(Drop* curr) { + // if the drop input has no side effects, it can be wiped out + if (!EffectAnalyzer(curr->value).hasSideEffects()) { + ExpressionManipulator::nop(curr); + return; + } + // sink a drop into an arm of an if-else if the other arm ends in an unreachable, as it if is a branch, this can make that branch optimizable and more vaccuming possible + auto* iff = curr->value->dynCast<If>(); + if (iff && iff->ifFalse && isConcreteWasmType(iff->type)) { + // reuse the drop in both cases + if (iff->ifTrue->type == unreachable) { + assert(isConcreteWasmType(iff->ifFalse->type)); + curr->value = iff->ifFalse; + iff->ifFalse = curr; + iff->type = none; + replaceCurrent(iff); + } else if (iff->ifFalse->type == unreachable) { + assert(isConcreteWasmType(iff->ifTrue->type)); + curr->value = iff->ifTrue; + iff->ifTrue = curr; + iff->type = none; + replaceCurrent(iff); + } + } } void visitFunction(Function* curr) { diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index f27eccf6a..437b023bc 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -64,9 +64,7 @@ void PassRegistry::registerPasses() { registerPass("coalesce-locals", "reduce # of locals by coalescing", createCoalesceLocalsPass); registerPass("coalesce-locals-learning", "reduce # of locals by coalescing and learning", createCoalesceLocalsWithLearningPass); registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass); - registerPass("drop-return-values", "stops relying on return values from set_local and store", createDropReturnValuesPass); registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); - registerPass("lower-if-else", "lowers if-elses into ifs, blocks and branches", createLowerIfElsePass); registerPass("merge-blocks", "merges blocks to their parents", createMergeBlocksPass); registerPass("metrics", "reports metrics", createMetricsPass); registerPass("nm", "name list", createNameListPass); @@ -211,6 +209,9 @@ void PassRunner::run() { } void PassRunner::runFunction(Function* func) { + if (debug) { + std::cerr << "[PassRunner] running passes on function " << func->name << std::endl; + } for (auto* pass : passes) { runPassOnFunction(pass, func); } diff --git a/src/passes/passes.h b/src/passes/passes.h index 731536d7d..4bb76edad 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -25,7 +25,6 @@ class Pass; Pass *createCoalesceLocalsPass(); Pass *createCoalesceLocalsWithLearningPass(); Pass *createDeadCodeEliminationPass(); -Pass *createDropReturnValuesPass(); Pass *createDuplicateFunctionEliminationPass(); Pass *createLowerIfElsePass(); Pass *createMergeBlocksPass(); |