/*
* Copyright 2017 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef wasm_ir_module_h
#define wasm_ir_module_h
#include "ir/find_all.h"
#include "ir/manipulation.h"
#include "ir/properties.h"
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm.h"
namespace wasm {
namespace ModuleUtils {
// Creates a copy of `func`, registers it in `out`, and returns the new
// function. Plain fields are copied by value; the body is deep-copied so that
// the new expression tree is created through `out`.
inline Function* copyFunction(Function* func, Module& out) {
  auto* copied = new Function();
  // Identity and import information.
  copied->name = func->name;
  copied->module = func->module;
  copied->base = func->base;
  // Signature and locals.
  copied->sig = func->sig;
  copied->vars = func->vars;
  copied->localNames = func->localNames;
  copied->localIndices = func->localIndices;
  // Body and the debug info attached to it.
  copied->debugLocations = func->debugLocations;
  copied->body = ExpressionManipulator::copy(func->body, out);
  // TODO: copy Stack IR
  assert(!func->stackIR);
  out.addFunction(copied);
  return copied;
}
// Creates a copy of `global`, registers it in `out`, and returns the new
// global. An imported global gets a null initializer; otherwise the
// initializer expression is deep-copied through `out`.
inline Global* copyGlobal(Global* global, Module& out) {
  auto* copied = new Global();
  copied->name = global->name;
  copied->type = global->type;
  copied->mutable_ = global->mutable_;
  copied->module = global->module;
  copied->base = global->base;
  copied->init = global->imported()
                   ? nullptr
                   : ExpressionManipulator::copy(global->init, out);
  out.addGlobal(copied);
  return copied;
}
// Creates a copy of `event`, registers it in `out`, and returns the new event.
inline Event* copyEvent(Event* event, Module& out) {
  auto* ret = new Event();
  ret->name = event->name;
  ret->attribute = event->attribute;
  ret->sig = event->sig;
  // Carry over the import module/base, as the other copy helpers in this file
  // do; without them an imported event would no longer be an import in the
  // copy.
  ret->module = event->module;
  ret->base = event->base;
  out.addEvent(ret);
  return ret;
}
// Creates a copy of `table` in `out` that has the same name, import
// information and limits, but no segments. Returns the table as added to
// `out`.
inline Table* copyTableWithoutSegments(Table* table, Module& out) {
  auto ret = std::make_unique<Table>();
  ret->name = table->name;
  ret->module = table->module;
  ret->base = table->base;
  ret->initial = table->initial;
  ret->max = table->max;
  // Ownership moves to the module, which hands back the pointer to use.
  return out.addTable(std::move(ret));
}
// Copies `table` into `out` including its segments. Each segment is copied by
// value and its offset expression is then deep-copied through `out`.
inline Table* copyTable(Table* table, Module& out) {
  Table* copied = copyTableWithoutSegments(table, out);
  for (const auto& original : table->segments) {
    auto segment = original;
    segment.offset = ExpressionManipulator::copy(original.offset, out);
    copied->segments.push_back(segment);
  }
  return copied;
}
// Copies the contents of `in` into `out`: exports, functions, globals, events,
// tables, memory, the start function, user sections and debug info file names.
inline void copyModule(const Module& in, Module& out) {
  // Module elements refer to each other by name rather than by pointer, so a
  // plain copy is enough for everything except expressions, which must be
  // deep-copied through the destination module.
  for (auto& exp : in.exports) {
    out.addExport(new Export(*exp));
  }
  for (auto& func : in.functions) {
    copyFunction(func.get(), out);
  }
  for (auto& global : in.globals) {
    copyGlobal(global.get(), out);
  }
  for (auto& event : in.events) {
    copyEvent(event.get(), out);
  }
  for (auto& table : in.tables) {
    copyTable(table.get(), out);
  }

  // The memory is copied by value first; the segment offsets copied that way
  // are then replaced by deep copies.
  out.memory = in.memory;
  for (auto& dataSegment : out.memory.segments) {
    dataSegment.offset = ExpressionManipulator::copy(dataSegment.offset, out);
  }

  out.start = in.start;
  out.userSections = in.userSections;
  out.debugInfoFileNames = in.debugInfoFileNames;
}
// Resets `wasm` in place: the existing object is destroyed and a new
// default-initialized Module is constructed at the same address, so
// references to `wasm` keep referring to the same storage.
inline void clearModule(Module& wasm) {
// Run the destructor explicitly to release whatever the module holds.
wasm.~Module();
// Placement-new a fresh Module into the storage that was just vacated.
new (&wasm) Module;
}
// Renaming
// Rename functions along with all their uses.
// Note that for this to work the functions themselves don't necessarily need
// to exist. For example, it is possible to remove a given function and then
// call this to redirect all of its uses.
template inline void renameFunctions(Module& wasm, T& map) {
// Update the function itself.
for (auto& pair : map) {
if (Function* F = wasm.getFunctionOrNull(pair.first)) {
assert(!wasm.getFunctionOrNull(pair.second) || F->name == pair.second);
F->name = pair.second;
}
}
wasm.updateMaps();
// Update other global things.
auto maybeUpdate = [&](Name& name) {
auto iter = map.find(name);
if (iter != map.end()) {
name = iter->second;
}
};
maybeUpdate(wasm.start);
for (auto& table : wasm.tables) {
for (auto& segment : table->segments) {
for (auto& name : segment.data) {
maybeUpdate(name);
}
}
}
for (auto& exp : wasm.exports) {
if (exp->kind == ExternalKind::Function) {
maybeUpdate(exp->value);
}
}
// Update call instructions.
for (auto& func : wasm.functions) {
// TODO: parallelize
if (!func->imported()) {
FindAll calls(func->body);
for (auto* call : calls.list) {
maybeUpdate(call->target);
}
}
}
}
// Renames a single function, along with its uses, by delegating to
// renameFunctions() with a one-entry map.
inline void renameFunction(Module& wasm, Name oldName, Name newName) {
  std::map<Name, Name> map;
  map[oldName] = newName;
  renameFunctions(wasm, map);
}
// Convenient iteration over imported/non-imported module elements
template inline void iterImportedMemories(Module& wasm, T visitor) {
if (wasm.memory.exists && wasm.memory.imported()) {
visitor(&wasm.memory);
}
}
template inline void iterDefinedMemories(Module& wasm, T visitor) {
if (wasm.memory.exists && !wasm.memory.imported()) {
visitor(&wasm.memory);
}
}
template inline void iterImportedTables(Module& wasm, T visitor) {
for (auto& import : wasm.tables) {
if (import->imported()) {
visitor(import.get());
}
}
}
template inline void iterDefinedTables(Module& wasm, T visitor) {
for (auto& import : wasm.tables) {
if (!import->imported()) {
visitor(import.get());
}
}
}
template inline void iterImportedGlobals(Module& wasm, T visitor) {
for (auto& import : wasm.globals) {
if (import->imported()) {
visitor(import.get());
}
}
}
template inline void iterDefinedGlobals(Module& wasm, T visitor) {
for (auto& import : wasm.globals) {
if (!import->imported()) {
visitor(import.get());
}
}
}
template
inline void iterImportedFunctions(Module& wasm, T visitor) {
for (auto& import : wasm.functions) {
if (import->imported()) {
visitor(import.get());
}
}
}
template inline void iterDefinedFunctions(Module& wasm, T visitor) {
for (auto& import : wasm.functions) {
if (!import->imported()) {
visitor(import.get());
}
}
}
template inline void iterImportedEvents(Module& wasm, T visitor) {
for (auto& import : wasm.events) {
if (import->imported()) {
visitor(import.get());
}
}
}
template inline void iterDefinedEvents(Module& wasm, T visitor) {
for (auto& import : wasm.events) {
if (!import->imported()) {
visitor(import.get());
}
}
}
template inline void iterImports(Module& wasm, T visitor) {
iterImportedMemories(wasm, visitor);
iterImportedTables(wasm, visitor);
iterImportedGlobals(wasm, visitor);
iterImportedFunctions(wasm, visitor);
iterImportedEvents(wasm, visitor);
}
// Helper class for performing an operation on all the functions in the module,
// in parallel, with an Info object for each one that can contain results of
// some computation that the operation performs.
// The operation performed should not modify the wasm module in any way.
// TODO: enforce this
template struct ParallelFunctionAnalysis {
Module& wasm;
typedef std::map Map;
Map map;
typedef std::function Func;
ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
// Fill in map, as we operate on it in parallel (each function to its own
// entry).
for (auto& func : wasm.functions) {
map[func.get()];
}
// Run on the imports first. TODO: parallelize this too
for (auto& func : wasm.functions) {
if (func->imported()) {
work(func.get(), map[func.get()]);
}
}
struct Mapper : public WalkerPass> {
bool isFunctionParallel() override { return true; }
bool modifiesBinaryenIR() override { return false; }
Mapper(Module& module, Map& map, Func work)
: module(module), map(map), work(work) {}
Mapper* create() override { return new Mapper(module, map, work); }
void doWalkFunction(Function* curr) {
assert(map.count(curr));
work(curr, map[curr]);
}
private:
Module& module;
Map& map;
Func work;
};
PassRunner runner(&wasm);
Mapper(wasm, map, work).run(&runner, &wasm);
}
};
// Helper class for analyzing the call graph.
//
// Provides hooks for running some initial calculation on each function (which
// is done in parallel), writing to a FunctionInfo structure for each function.
// Then you can call propagateBack() to propagate a property of interest to the
// calling functions, transitively.
//
// For example, if some functions are known to call an import "foo", then you
// can use this to find which functions call something that might eventually
// reach foo, by initially marking the direct callers as "calling foo" and
// propagating that backwards.
template struct CallGraphPropertyAnalysis {
Module& wasm;
// The basic information for each function about whom it calls and who is
// called by it.
struct FunctionInfo {
std::set callsTo;
std::set calledBy;
// A non-direct call is any call that is not direct. That includes
// CallIndirect and CallRef.
bool hasNonDirectCall = false;
};
typedef std::map Map;
Map map;
typedef std::function Func;
CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
ParallelFunctionAnalysis analysis(wasm, [&](Function* func, T& info) {
work(func, info);
if (func->imported()) {
return;
}
struct Mapper : public PostWalker {
Mapper(Module* module, T& info, Func work)
: module(module), info(info), work(work) {}
void visitCall(Call* curr) {
info.callsTo.insert(module->getFunction(curr->target));
}
void visitCallIndirect(CallIndirect* curr) {
info.hasNonDirectCall = true;
}
void visitCallRef(CallRef* curr) { info.hasNonDirectCall = true; }
private:
Module* module;
T& info;
Func work;
} mapper(&wasm, info, work);
mapper.walk(func->body);
});
map.swap(analysis.map);
// Find what is called by what.
for (auto& pair : map) {
auto* func = pair.first;
auto& info = pair.second;
for (auto* target : info.callsTo) {
map[target].calledBy.insert(func);
}
}
}
enum NonDirectCalls { IgnoreNonDirectCalls, NonDirectCallsHaveProperty };
// Propagate a property from a function to those that call it.
//
// hasProperty() - Check if the property is present.
// canHaveProperty() - Check if the property could be present.
// addProperty() - Adds the property. This receives a second parameter which
// is the function due to which we are adding the property.
void propagateBack(std::function hasProperty,
std::function canHaveProperty,
std::function addProperty,
NonDirectCalls nonDirectCalls) {
// The work queue contains items we just learned can change the state.
UniqueDeferredQueue work;
for (auto& func : wasm.functions) {
if (hasProperty(map[func.get()]) ||
(nonDirectCalls == NonDirectCallsHaveProperty &&
map[func.get()].hasNonDirectCall)) {
addProperty(map[func.get()], func.get());
work.push(func.get());
}
}
while (!work.empty()) {
auto* func = work.pop();
for (auto* caller : map[func].calledBy) {
// If we don't already have the property, and we are not forbidden
// from getting it, then it propagates back to us now.
if (!hasProperty(map[caller]) && canHaveProperty(map[caller])) {
addProperty(map[caller], func);
work.push(caller);
}
}
}
}
};
// Helper function for collecting all the types that are declared in a module,
// which means the HeapTypes (that are non-basic, that is, not eqref etc., which
// do not need to be defined).
//
// Used when emitting or printing a module to give HeapTypes canonical
// indices. HeapTypes are sorted in order of decreasing frequency to minimize the
// size of their collective encoding. Both a vector mapping indices to
// HeapTypes and a map mapping HeapTypes to indices are produced.
inline void collectHeapTypes(Module& wasm,
std::vector& types,
std::unordered_map& typeIndices) {
struct Counts : public std::unordered_map {
bool isRelevant(Type type) {
if (type.isRef()) {
return !type.getHeapType().isBasic();
}
return type.isRtt();
}
void note(HeapType type) { (*this)[type]++; }
void maybeNote(Type type) {
if (isRelevant(type)) {
note(type.getHeapType());
}
}
};
// Collect the type use counts for a single function
auto updateCounts = [&](Function* func, Counts& counts) {
if (func->imported()) {
return;
}
struct TypeCounter
: PostWalker> {
Counts& counts;
TypeCounter(Counts& counts) : counts(counts) {}
void visitExpression(Expression* curr) {
if (auto* call = curr->dynCast()) {
counts.note(call->sig);
} else if (curr->is()) {
counts.maybeNote(curr->type);
} else if (curr->is() || curr->is()) {
counts.note(curr->type.getRtt().heapType);
} else if (auto* get = curr->dynCast()) {
counts.maybeNote(get->ref->type);
} else if (auto* set = curr->dynCast()) {
counts.maybeNote(set->ref->type);
} else if (Properties::isControlFlowStructure(curr)) {
counts.maybeNote(curr->type);
if (curr->type.isTuple()) {
// TODO: Allow control flow to have input types as well
counts.note(Signature(Type::none, curr->type));
}
}
}
};
TypeCounter(counts).walk(func->body);
};
ModuleUtils::ParallelFunctionAnalysis analysis(wasm, updateCounts);
// Collect all the counts.
Counts counts;
for (auto& curr : wasm.functions) {
counts.note(curr->sig);
for (auto type : curr->vars) {
counts.maybeNote(type);
if (type.isTuple()) {
for (auto t : type) {
counts.maybeNote(t);
}
}
}
}
for (auto& curr : wasm.events) {
counts.note(curr->sig);
}
for (auto& curr : wasm.globals) {
counts.maybeNote(curr->type);
}
for (auto& pair : analysis.map) {
Counts& functionCounts = pair.second;
for (auto& innerPair : functionCounts) {
counts[innerPair.first] += innerPair.second;
}
}
// A generic utility to traverse the child types of a type.
// TODO: work with tlively to refactor this to a shared place
auto walkRelevantChildren = [&](HeapType type,
std::function callback) {
auto callIfRelevant = [&](Type type) {
if (counts.isRelevant(type)) {
callback(type.getHeapType());
}
};
if (type.isSignature()) {
auto sig = type.getSignature();
for (Type type : {sig.params, sig.results}) {
for (auto element : type) {
callIfRelevant(element);
}
}
} else if (type.isArray()) {
callIfRelevant(type.getArray().element.type);
} else if (type.isStruct()) {
auto fields = type.getStruct().fields;
for (auto field : fields) {
callIfRelevant(field.type);
}
}
};
// Recursively traverse each reference type, which may have a child type that
// is itself a reference type. This reflects an appearance in the binary
// format that is in the type section itself.
// As we do this we may find more and more types, as nested children of
// previous ones. Each such type will appear in the type section once, so
// we just need to visit it once.
// TODO: handle struct and array fields
std::unordered_set newTypes;
for (auto& pair : counts) {
newTypes.insert(pair.first);
}
while (!newTypes.empty()) {
auto iter = newTypes.begin();
auto type = *iter;
newTypes.erase(iter);
walkRelevantChildren(type, [&](HeapType type) {
if (!counts.count(type)) {
newTypes.insert(type);
}
counts.note(type);
});
}
// We must sort all the dependencies of a type before it. For example,
// (func (param (ref (func)))) must appear after (func). To do that, find the
// depth of dependencies of each type. For example, if A depends on B
// which depends on C, then A's depth is 2, B's is 1, and C's is 0 (assuming
// no other dependencies).
Counts depthOfDependencies;
std::unordered_map> isDependencyOf;
// To calculate the depth of dependencies, we'll do a flow analysis, visiting
// each type as we find out new things about it.
std::set toVisit;
for (auto& pair : counts) {
auto type = pair.first;
depthOfDependencies[type] = 0;
toVisit.insert(type);
walkRelevantChildren(type, [&](HeapType childType) {
isDependencyOf[childType].insert(type); // XXX flip?
});
}
while (!toVisit.empty()) {
auto iter = toVisit.begin();
auto type = *iter;
toVisit.erase(iter);
// Anything that depends on this has a depth of dependencies equal to this
// type's, plus this type itself.
auto newDepth = depthOfDependencies[type] + 1;
if (newDepth > counts.size()) {
Fatal() << "Cyclic types detected, cannot sort them.";
}
for (auto& other : isDependencyOf[type]) {
if (depthOfDependencies[other] < newDepth) {
// We found something new to propagate.
depthOfDependencies[other] = newDepth;
toVisit.insert(other);
}
}
}
// Sort by frequency and then simplicity, and also keeping every type
// before things that depend on it.
std::vector> sorted(counts.begin(), counts.end());
std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
if (depthOfDependencies[a.first] != depthOfDependencies[b.first]) {
return depthOfDependencies[a.first] < depthOfDependencies[b.first];
}
if (a.second != b.second) {
return a.second > b.second;
}
return a.first < b.first;
});
for (Index i = 0; i < sorted.size(); ++i) {
typeIndices[sorted[i].first] = i;
types.push_back(sorted[i].first);
}
}
} // namespace ModuleUtils
} // namespace wasm
#endif // wasm_ir_module_h