summaryrefslogtreecommitdiff
path: root/src/passes/AvoidReinterprets.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/AvoidReinterprets.cpp')
-rw-r--r--src/passes/AvoidReinterprets.cpp181
1 files changed, 181 insertions, 0 deletions
diff --git a/src/passes/AvoidReinterprets.cpp b/src/passes/AvoidReinterprets.cpp
new file mode 100644
index 000000000..d79645dc6
--- /dev/null
+++ b/src/passes/AvoidReinterprets.cpp
@@ -0,0 +1,181 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Avoids reinterprets by using more loads: if we load a value and
+// reinterpret it, we could have loaded it with the other type
+// anyhow. This uses more locals and loads, so it is not generally
+// beneficial, unless reinterprets are very costly.
+
+#include <ir/local-graph.h>
+#include <ir/properties.h>
+#include <pass.h>
+#include <wasm-builder.h>
+#include <wasm.h>
+
+namespace wasm {
+
+static Load* getSingleLoad(LocalGraph* localGraph, GetLocal* get) {
+ while (1) {
+ auto& sets = localGraph->getSetses[get];
+ if (sets.size() != 1) {
+ return nullptr;
+ }
+ auto* set = *sets.begin();
+ if (!set) {
+ return nullptr;
+ }
+ auto* value = Properties::getFallthrough(set->value);
+ if (auto* parentGet = value->dynCast<GetLocal>()) {
+ get = parentGet;
+ continue;
+ }
+ if (auto* load = value->dynCast<Load>()) {
+ return load;
+ }
+ return nullptr;
+ }
+}
+
+static bool isReinterpret(Unary* curr) {
+ return curr->op == ReinterpretInt32 || curr->op == ReinterpretInt64 ||
+ curr->op == ReinterpretFloat32 || curr->op == ReinterpretFloat64;
+}
+
+struct AvoidReinterprets : public WalkerPass<PostWalker<AvoidReinterprets>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new AvoidReinterprets; }
+
+ struct Info {
+ // Info used when analyzing.
+ bool reinterpreted;
+ // Info used when optimizing.
+ Index ptrLocal;
+ Index reinterpretedLocal;
+ };
+ std::map<Load*, Info> infos;
+
+ LocalGraph* localGraph;
+
+ void doWalkFunction(Function* func) {
+ // prepare
+ LocalGraph localGraph_(func);
+ localGraph = &localGraph_;
+ // walk
+ PostWalker<AvoidReinterprets>::doWalkFunction(func);
+ // optimize
+ optimize(func);
+ }
+
+ void visitUnary(Unary* curr) {
+ if (isReinterpret(curr)) {
+ if (auto* get =
+ Properties::getFallthrough(curr->value)->dynCast<GetLocal>()) {
+ if (auto* load = getSingleLoad(localGraph, get)) {
+ auto& info = infos[load];
+ info.reinterpreted = true;
+ }
+ }
+ }
+ }
+
+ void optimize(Function* func) {
+ std::set<Load*> unoptimizables;
+ for (auto& pair : infos) {
+ auto* load = pair.first;
+ auto& info = pair.second;
+ if (info.reinterpreted && load->type != unreachable) {
+ // We should use another load here, to avoid reinterprets.
+ info.ptrLocal = Builder::addVar(func, i32);
+ info.reinterpretedLocal =
+ Builder::addVar(func, reinterpretType(load->type));
+ } else {
+ unoptimizables.insert(load);
+ }
+ }
+ for (auto* load : unoptimizables) {
+ infos.erase(load);
+ }
+ // We now know which we can optimize, and how.
+ struct FinalOptimizer : public PostWalker<FinalOptimizer> {
+ std::map<Load*, Info>& infos;
+ LocalGraph* localGraph;
+ Module* module;
+
+ FinalOptimizer(std::map<Load*, Info>& infos,
+ LocalGraph* localGraph,
+ Module* module)
+ : infos(infos), localGraph(localGraph), module(module) {}
+
+ void visitUnary(Unary* curr) {
+ if (isReinterpret(curr)) {
+ auto* value = Properties::getFallthrough(curr->value);
+ if (auto* load = value->dynCast<Load>()) {
+ // A reinterpret of a load - flip it right here.
+ replaceCurrent(makeReinterpretedLoad(load, load->ptr));
+ } else if (auto* get = value->dynCast<GetLocal>()) {
+ if (auto* load = getSingleLoad(localGraph, get)) {
+ auto iter = infos.find(load);
+ if (iter != infos.end()) {
+ auto& info = iter->second;
+ // A reinterpret of a get of a load - use the new local.
+ Builder builder(*module);
+ replaceCurrent(builder.makeGetLocal(
+ info.reinterpretedLocal, reinterpretType(load->type)));
+ }
+ }
+ }
+ }
+ }
+
+ void visitLoad(Load* curr) {
+ auto iter = infos.find(curr);
+ if (iter != infos.end()) {
+ auto& info = iter->second;
+ Builder builder(*module);
+ auto* ptr = curr->ptr;
+ curr->ptr = builder.makeGetLocal(info.ptrLocal, i32);
+ // Note that the other load can have its sign set to false - if the
+ // original were an integer, the other is a float anyhow; and if
+ // original were a float, we don't know what sign to use.
+ replaceCurrent(builder.makeBlock(
+ {builder.makeSetLocal(info.ptrLocal, ptr),
+ builder.makeSetLocal(
+ info.reinterpretedLocal,
+ makeReinterpretedLoad(curr,
+ builder.makeGetLocal(info.ptrLocal, i32))),
+ curr}));
+ }
+ }
+
+ Load* makeReinterpretedLoad(Load* load, Expression* ptr) {
+ Builder builder(*module);
+ return builder.makeLoad(load->bytes,
+ false,
+ load->offset,
+ load->align,
+ ptr,
+ reinterpretType(load->type));
+ }
+ } finalOptimizer(infos, localGraph, getModule());
+
+ finalOptimizer.walk(func->body);
+ }
+};
+
+Pass* createAvoidReinterpretsPass() { return new AvoidReinterprets(); }
+
+} // namespace wasm