diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/AvoidReinterprets.cpp | 181 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/pass.cpp | 7 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/wasm-type.h | 1 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 19 | ||||
-rw-r--r-- | src/wasm2js.h | 11 |
7 files changed, 219 insertions, 2 deletions
diff --git a/src/passes/AvoidReinterprets.cpp b/src/passes/AvoidReinterprets.cpp new file mode 100644 index 000000000..d79645dc6 --- /dev/null +++ b/src/passes/AvoidReinterprets.cpp @@ -0,0 +1,181 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Avoids reinterprets by using more loads: if we load a value and +// reinterpret it, we could have loaded it with the other type +// anyhow. This uses more locals and loads, so it is not generally +// beneficial, unless reinterprets are very costly. 
+ +#include <ir/local-graph.h> +#include <ir/properties.h> +#include <pass.h> +#include <wasm-builder.h> +#include <wasm.h> + +namespace wasm { + +static Load* getSingleLoad(LocalGraph* localGraph, GetLocal* get) { + while (1) { + auto& sets = localGraph->getSetses[get]; + if (sets.size() != 1) { + return nullptr; + } + auto* set = *sets.begin(); + if (!set) { + return nullptr; + } + auto* value = Properties::getFallthrough(set->value); + if (auto* parentGet = value->dynCast<GetLocal>()) { + get = parentGet; + continue; + } + if (auto* load = value->dynCast<Load>()) { + return load; + } + return nullptr; + } +} + +static bool isReinterpret(Unary* curr) { + return curr->op == ReinterpretInt32 || curr->op == ReinterpretInt64 || + curr->op == ReinterpretFloat32 || curr->op == ReinterpretFloat64; +} + +struct AvoidReinterprets : public WalkerPass<PostWalker<AvoidReinterprets>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new AvoidReinterprets; } + + struct Info { + // Info used when analyzing. + bool reinterpreted; + // Info used when optimizing. + Index ptrLocal; + Index reinterpretedLocal; + }; + std::map<Load*, Info> infos; + + LocalGraph* localGraph; + + void doWalkFunction(Function* func) { + // prepare + LocalGraph localGraph_(func); + localGraph = &localGraph_; + // walk + PostWalker<AvoidReinterprets>::doWalkFunction(func); + // optimize + optimize(func); + } + + void visitUnary(Unary* curr) { + if (isReinterpret(curr)) { + if (auto* get = + Properties::getFallthrough(curr->value)->dynCast<GetLocal>()) { + if (auto* load = getSingleLoad(localGraph, get)) { + auto& info = infos[load]; + info.reinterpreted = true; + } + } + } + } + + void optimize(Function* func) { + std::set<Load*> unoptimizables; + for (auto& pair : infos) { + auto* load = pair.first; + auto& info = pair.second; + if (info.reinterpreted && load->type != unreachable) { + // We should use another load here, to avoid reinterprets. 
+ info.ptrLocal = Builder::addVar(func, i32); + info.reinterpretedLocal = + Builder::addVar(func, reinterpretType(load->type)); + } else { + unoptimizables.insert(load); + } + } + for (auto* load : unoptimizables) { + infos.erase(load); + } + // We now know which we can optimize, and how. + struct FinalOptimizer : public PostWalker<FinalOptimizer> { + std::map<Load*, Info>& infos; + LocalGraph* localGraph; + Module* module; + + FinalOptimizer(std::map<Load*, Info>& infos, + LocalGraph* localGraph, + Module* module) + : infos(infos), localGraph(localGraph), module(module) {} + + void visitUnary(Unary* curr) { + if (isReinterpret(curr)) { + auto* value = Properties::getFallthrough(curr->value); + if (auto* load = value->dynCast<Load>()) { + // A reinterpret of a load - flip it right here. + replaceCurrent(makeReinterpretedLoad(load, load->ptr)); + } else if (auto* get = value->dynCast<GetLocal>()) { + if (auto* load = getSingleLoad(localGraph, get)) { + auto iter = infos.find(load); + if (iter != infos.end()) { + auto& info = iter->second; + // A reinterpret of a get of a load - use the new local. + Builder builder(*module); + replaceCurrent(builder.makeGetLocal( + info.reinterpretedLocal, reinterpretType(load->type))); + } + } + } + } + } + + void visitLoad(Load* curr) { + auto iter = infos.find(curr); + if (iter != infos.end()) { + auto& info = iter->second; + Builder builder(*module); + auto* ptr = curr->ptr; + curr->ptr = builder.makeGetLocal(info.ptrLocal, i32); + // Note that the other load can have its sign set to false - if the + // original were an integer, the other is a float anyhow; and if + // original were a float, we don't know what sign to use. 
+ replaceCurrent(builder.makeBlock( + {builder.makeSetLocal(info.ptrLocal, ptr), + builder.makeSetLocal( + info.reinterpretedLocal, + makeReinterpretedLoad(curr, + builder.makeGetLocal(info.ptrLocal, i32))), + curr})); + } + } + + Load* makeReinterpretedLoad(Load* load, Expression* ptr) { + Builder builder(*module); + return builder.makeLoad(load->bytes, + false, + load->offset, + load->align, + ptr, + reinterpretType(load->type)); + } + } finalOptimizer(infos, localGraph, getModule()); + + finalOptimizer.walk(func->body); + } +}; + +Pass* createAvoidReinterpretsPass() { return new AvoidReinterprets(); } + +} // namespace wasm diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 9d98930ed..6605d7a09 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -6,6 +6,7 @@ add_custom_command( SET(passes_SOURCES pass.cpp AlignmentLowering.cpp + AvoidReinterprets.cpp CoalesceLocals.cpp CodePushing.cpp CodeFolding.cpp diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 287dd1cf9..bb1a062e7 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -71,11 +71,14 @@ std::string PassRegistry::getPassDescription(std::string name) { // PassRunner void PassRegistry::registerPasses() { - registerPass( - "dae", "removes arguments to calls in an lto-like manner", createDAEPass); registerPass("alignment-lowering", "lower unaligned loads and stores to smaller aligned ones", createAlignmentLoweringPass); + registerPass("avoid-reinterprets", + "Tries to avoid reinterpret operations via more loads", + createAvoidReinterpretsPass); + registerPass( + "dae", "removes arguments to calls in an lto-like manner", createDAEPass); registerPass("dae-optimizing", "removes arguments to calls in an lto-like manner, and " "optimizes where we removed", diff --git a/src/passes/passes.h b/src/passes/passes.h index cb2950bd4..f3bec3f04 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -23,6 +23,7 @@ class Pass; // All passes: Pass* 
createAlignmentLoweringPass(); +Pass* createAvoidReinterpretsPass(); Pass* createCoalesceLocalsPass(); Pass* createCoalesceLocalsWithLearningPass(); Pass* createCodeFoldingPass(); diff --git a/src/wasm-type.h b/src/wasm-type.h index 6c8ea82a6..e685bc4b5 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -45,6 +45,7 @@ bool isFloatType(Type type); bool isIntegerType(Type type); bool isVectorType(Type type); bool isReferenceType(Type type); +Type reinterpretType(Type type); } // namespace wasm diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 091d851f6..ebaba3f24 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -118,4 +118,23 @@ bool isReferenceType(Type type) { return type == except_ref; } +// Returns the counterpart of |type| under the mapping i32 <-> f32 and +// i64 <-> f64. Every other type (v128, except_ref, none, unreachable) +// reaches WASM_UNREACHABLE(), so callers must filter those out first. +Type reinterpretType(Type type) { + switch (type) { + case Type::i32: + return f32; + case Type::i64: + return f64; + case Type::f32: + return i32; + case Type::f64: + return i64; + case Type::v128: + case Type::except_ref: + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE(); + } + WASM_UNREACHABLE(); +} + } // namespace wasm diff --git a/src/wasm2js.h b/src/wasm2js.h index 24d168355..78bfd30ff 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -296,7 +296,18 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { // Next, optimize that as best we can. This should not generate // non-JS-friendly things. if (options.optimizeLevel > 0) { + // It is especially important to propagate constants after the lowering. + // However, this can be a slow operation, especially after flattening; + // some local simplification helps. + if (options.optimizeLevel >= 3 || options.shrinkLevel >= 1) { + runner.add("simplify-locals-nonesting"); + runner.add("precompute-propagate"); + // Avoiding reinterpretation is helped by propagation. We also run + // it later down as default optimizations help as well. 
+ runner.add("avoid-reinterprets"); + } runner.addDefaultOptimizationPasses(); + runner.add("avoid-reinterprets"); } // Finally, get the code into the flat form we need for wasm2js itself, and // optimize that a little in a way that keeps flat property. |