diff options
-rw-r--r-- | src/ir/find_all.h | 6 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 38 | ||||
-rw-r--r-- | src/passes/pass.cpp | 2 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | test/passes/inline-main.txt | 69 | ||||
-rw-r--r-- | test/passes/inline-main.wast | 53 |
6 files changed, 166 insertions, 3 deletions
diff --git a/src/ir/find_all.h b/src/ir/find_all.h index 1abaab772..bd3d265e2 100644 --- a/src/ir/find_all.h +++ b/src/ir/find_all.h @@ -58,9 +58,11 @@ struct PointerFinder template<typename T> struct FindAllPointers { std::vector<Expression**> list; - FindAllPointers(Expression* ast) { + // Note that a pointer may be to the function->body itself, so we must + // take \ast by reference. + FindAllPointers(Expression*& ast) { PointerFinder finder; - finder.id = T()._id; + finder.id = (Expression::Id)T::SpecificId; finder.list = &list; finder.walk(ast); } diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 9ba01c4c1..53a57cfeb 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -217,7 +217,7 @@ struct Updater : public PostWalker<Updater> { // Core inlining logic. Modifies the outside function (adding locals as // needed), and returns the inlined code. static Expression* -doInlining(Module* module, Function* into, InliningAction& action) { +doInlining(Module* module, Function* into, const InliningAction& action) { Function* from = action.contents; auto* call = (*action.callSite)->cast<Call>(); // Works for return_call, too @@ -415,4 +415,40 @@ Pass* createInliningOptimizingPass() { return ret; } +static const char* MAIN = "main"; +static const char* ORIGINAL_MAIN = "__original_main"; + +// Inline __original_main into main, if they exist. This works around the odd +// thing that clang/llvm currently do, where __original_main contains the user's +// actual main (this is done as a workaround for main having two different +// possible signatures). +struct InlineMainPass : public Pass { + void run(PassRunner* runner, Module* module) override { + auto* main = module->getFunctionOrNull(MAIN); + auto* originalMain = module->getFunctionOrNull(ORIGINAL_MAIN); + if (!main || main->imported() || !originalMain || + originalMain->imported()) { + return; + } + FindAllPointers<Call> calls(main->body); + Expression** callSite = nullptr; + for (auto* call : calls.list) { + if ((*call)->cast<Call>()->target == ORIGINAL_MAIN) { + if (callSite) { + // More than one call site. + return; + } + callSite = call; + } + } + if (!callSite) { + // No call at all. + return; + } + doInlining(module, main, InliningAction(callSite, originalMain)); + } +}; + +Pass* createInlineMainPass() { return new InlineMainPass(); } + } // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 9afea38e6..906fb85e2 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -129,6 +129,8 @@ void PassRegistry::registerPasses() { "func-metrics", "reports function metrics", createFunctionMetricsPass); registerPass( "generate-stack-ir", "generate Stack IR", createGenerateStackIRPass); + registerPass( + "inline-main", "inline __original_main into main", createInlineMainPass); registerPass("inlining", "inline functions (you probably want inlining-optimizing)", createInliningPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index cc33e4300..98c7d374d 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -45,6 +45,7 @@ Pass* createFullPrinterPass(); Pass* createFunctionMetricsPass(); Pass* createGenerateStackIRPass(); Pass* createI64ToI32LoweringPass(); +Pass* createInlineMainPass(); Pass* createInliningPass(); Pass* createInliningOptimizingPass(); Pass* createLegalizeJSInterfacePass(); diff --git a/test/passes/inline-main.txt b/test/passes/inline-main.txt new file mode 100644 index 000000000..f8704b7be --- /dev/null +++ b/test/passes/inline-main.txt @@ -0,0 +1,69 @@ +(module + (type $FUNCSIG$i (func (result i32))) + (type $FUNCSIG$iii (func (param i32 i32) (result i32))) + (export "main" (func $main)) + (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32) + (i32.const 0) + ) + (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32) + (block $__inlined_func$__original_main (result i32) + (i32.const 0) + ) + ) +) +(module + (type $FUNCSIG$i (func (result i32))) + (type $FUNCSIG$iii (func (param i32 i32) (result i32))) + (export "main" (func $main)) + (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32) + (i32.const 0) + ) + (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32) + (i32.const 0) + ) +) +(module + (type $FUNCSIG$iii (func (param i32 i32) (result i32))) + (export "main" (func $main)) + (func $main (; 0 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32) + (i32.const 0) + ) +) +(module + (type $FUNCSIG$i (func (result i32))) + (type $FUNCSIG$iii (func (param i32 i32) (result i32))) + (export "main" (func $main)) + (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32) + (i32.const 0) + ) + (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32) + (drop + (call $__original_main) + ) + (call $__original_main) + ) +) +(module + (type $FUNCSIG$i (func (result i32))) + (func $__original_main (; 0 ;) (type $FUNCSIG$i) (result i32) + (i32.const 0) + ) +) +(module + (type $FUNCSIG$iii (func (param i32 i32) (result i32))) + (type $FUNCSIG$i (func (result i32))) + (import "env" "main" (func $main (param i32 i32) (result i32))) + (export "main" (func $main)) + (func $__original_main (; 1 ;) (type $FUNCSIG$i) (result i32) + (i32.const 0) + ) +) +(module + (type $FUNCSIG$i (func (result i32))) + (type $FUNCSIG$iii (func (param i32 i32) (result i32))) + (import "env" "original_main" (func $__original_main (result i32))) + (export "main" (func $main)) + (func $main (; 1 ;) (type $FUNCSIG$iii) (param $0 i32) (param $1 i32) (result i32) + (call $__original_main) + ) +) diff --git a/test/passes/inline-main.wast b/test/passes/inline-main.wast new file mode 100644 index 000000000..d86776187 --- /dev/null +++ b/test/passes/inline-main.wast @@ -0,0 +1,53 @@ +(module + (export "main" (func $main)) + (func $__original_main (result i32) + (i32.const 0) + ) + (func $main (param $0 i32) (param $1 i32) (result i32) + (call $__original_main) + ) +) +(module + (export "main" (func $main)) + (func $__original_main (result i32) + (i32.const 0) + ) + (func $main (param $0 i32) (param $1 i32) (result i32) + (i32.const 0) + ) +) +(module + (export "main" (func $main)) + (func $main (param $0 i32) (param $1 i32) (result i32) + (i32.const 0) + ) +) +(module + (export "main" (func $main)) + (func $__original_main (result i32) + (i32.const 0) + ) + (func $main (param $0 i32) (param $1 i32) (result i32) + (drop (call $__original_main)) + (call $__original_main) + ) +) +(module + (func $__original_main (result i32) + (i32.const 0) + ) +) +(module + (import "env" "main" (func $main (param i32 i32) (result i32))) + (export "main" (func $main)) + (func $__original_main (result i32) + (i32.const 0) + ) +) +(module + (import "env" "original_main" (func $__original_main (result i32))) + (export "main" (func $main)) + (func $main (param $0 i32) (param $1 i32) (result i32) + (call $__original_main) + ) +) |