summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ir/find_all.h6
-rw-r--r--src/passes/Inlining.cpp38
-rw-r--r--src/passes/pass.cpp2
-rw-r--r--src/passes/passes.h1
-rw-r--r--test/passes/inline-main.txt69
-rw-r--r--test/passes/inline-main.wast53
6 files changed, 166 insertions, 3 deletions
diff --git a/src/ir/find_all.h b/src/ir/find_all.h
index 1abaab772..bd3d265e2 100644
--- a/src/ir/find_all.h
+++ b/src/ir/find_all.h
@@ -58,9 +58,11 @@ struct PointerFinder
template<typename T> struct FindAllPointers {
std::vector<Expression**> list;
- FindAllPointers(Expression* ast) {
+ // Note that a pointer may be to the function->body itself, so we must
+ // take \ast by reference.
+ FindAllPointers(Expression*& ast) {
PointerFinder finder;
- finder.id = T()._id;
+ finder.id = (Expression::Id)T::SpecificId;
finder.list = &list;
finder.walk(ast);
}
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index 9ba01c4c1..53a57cfeb 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -217,7 +217,7 @@ struct Updater : public PostWalker<Updater> {
// Core inlining logic. Modifies the outside function (adding locals as
// needed), and returns the inlined code.
static Expression*
-doInlining(Module* module, Function* into, InliningAction& action) {
+doInlining(Module* module, Function* into, const InliningAction& action) {
Function* from = action.contents;
auto* call = (*action.callSite)->cast<Call>();
// Works for return_call, too
@@ -415,4 +415,40 @@ Pass* createInliningOptimizingPass() {
return ret;
}
+static const char* MAIN = "main";
+static const char* ORIGINAL_MAIN = "__original_main";
+
+// Inline __original_main into main, if they exist. This works around the odd
+// thing that clang/llvm currently do, where __original_main contains the user's
+// actual main (this is done as a workaround for main having two different
+// possible signatures).
+struct InlineMainPass : public Pass {
+ void run(PassRunner* runner, Module* module) override {
+ auto* main = module->getFunctionOrNull(MAIN);
+ auto* originalMain = module->getFunctionOrNull(ORIGINAL_MAIN);
+ if (!main || main->imported() || !originalMain ||
+ originalMain->imported()) {
+ return;
+ }
+ FindAllPointers<Call> calls(main->body);
+ Expression** callSite = nullptr;
+ for (auto* call : calls.list) {
+ if ((*call)->cast<Call>()->target == ORIGINAL_MAIN) {
+ if (callSite) {
+ // More than one call site.
+ return;
+ }
+ callSite = call;
+ }
+ }
+ if (!callSite) {
+ // No call at all.
+ return;
+ }
+ doInlining(module, main, InliningAction(callSite, originalMain));
+ }
+};
+
+Pass* createInlineMainPass() { return new InlineMainPass(); }
+
} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 9afea38e6..906fb85e2 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -129,6 +129,8 @@ void PassRegistry::registerPasses() {
"func-metrics", "reports function metrics", createFunctionMetricsPass);
registerPass(
"generate-stack-ir", "generate Stack IR", createGenerateStackIRPass);
+ registerPass(
+ "inline-main", "inline __original_main into main", createInlineMainPass);
registerPass("inlining",
"inline functions (you probably want inlining-optimizing)",
createInliningPass);
diff --git a/src/passes/passes.h b/src/passes/passes.h
index cc33e4300..98c7d374d 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -45,6 +45,7 @@ Pass* createFullPrinterPass();
Pass* createFunctionMetricsPass();
Pass* createGenerateStackIRPass();
Pass* createI64ToI32LoweringPass();
+Pass* createInlineMainPass();
Pass* createInliningPass();
Pass* createInliningOptimizingPass();
Pass* createLegalizeJSInterfacePass();
diff --git a/test/passes/inline-main.txt b/test/passes/inline-main.txt
new file mode 100644
index 000000000..f8704b7be
--- /dev/null
+++ b/test/passes/inline-main.txt
@@ -0,0 +1,69 @@
+(module
+ (type $FUNCSIG$i (func (result i32)))
+ (type $FUNCSIG$iii (func (param i32 i32) (result i32)))
+ (export "main" (func $main))
+ (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32)
+ (i32.const 0)
+ )
+ (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32)
+ (block $__inlined_func$__original_main (result i32)
+ (i32.const 0)
+ )
+ )
+)
+(module
+ (type $FUNCSIG$i (func (result i32)))
+ (type $FUNCSIG$iii (func (param i32 i32) (result i32)))
+ (export "main" (func $main))
+ (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32)
+ (i32.const 0)
+ )
+ (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (type $FUNCSIG$iii (func (param i32 i32) (result i32)))
+ (export "main" (func $main))
+ (func $main (; 0 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (type $FUNCSIG$i (func (result i32)))
+ (type $FUNCSIG$iii (func (param i32 i32) (result i32)))
+ (export "main" (func $main))
+ (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32)
+ (i32.const 0)
+ )
+ (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32)
+ (drop
+ (call $__original_main)
+ )
+ (call $__original_main)
+ )
+)
+(module
+ (type $FUNCSIG$i (func (result i32)))
+ (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (type $FUNCSIG$iii (func (param i32 i32) (result i32)))
+ (type $FUNCSIG$i (func (result i32)))
+ (import "env" "main" (func $main (param i32 i32) (result i32)))
+ (export "main" (func $main))
+ (func $__original_main (; 1 ;) (type $FUNCSIG$i) (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (type $FUNCSIG$i (func (result i32)))
+ (type $FUNCSIG$iii (func (param i32 i32) (result i32)))
+ (import "env" "original_main" (func $__original_main (result i32)))
+ (export "main" (func $main))
+ (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32)
+ (call $__original_main)
+ )
+)
diff --git a/test/passes/inline-main.wast b/test/passes/inline-main.wast
new file mode 100644
index 000000000..d86776187
--- /dev/null
+++ b/test/passes/inline-main.wast
@@ -0,0 +1,53 @@
+(module
+ (export "main" (func $main))
+ (func $__original_main (result i32)
+ (i32.const 0)
+ )
+ (func $main (param $0 i32) (param $1 i32) (result i32)
+ (call $__original_main)
+ )
+)
+(module
+ (export "main" (func $main))
+ (func $__original_main (result i32)
+ (i32.const 0)
+ )
+ (func $main (param $0 i32) (param $1 i32) (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (export "main" (func $main))
+ (func $main (param $0 i32) (param $1 i32) (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (export "main" (func $main))
+ (func $__original_main (result i32)
+ (i32.const 0)
+ )
+ (func $main (param $0 i32) (param $1 i32) (result i32)
+ (drop (call $__original_main))
+ (call $__original_main)
+ )
+)
+(module
+ (func $__original_main (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (import "env" "main" (func $main (param i32 i32) (result i32)))
+ (export "main" (func $main))
+ (func $__original_main (result i32)
+ (i32.const 0)
+ )
+)
+(module
+ (import "env" "original_main" (func $__original_main (result i32)))
+ (export "main" (func $main))
+ (func $main (param $0 i32) (param $1 i32) (result i32)
+ (call $__original_main)
+ )
+)