diff options
-rw-r--r-- | src/passes/Outlining.cpp | 14 | ||||
-rw-r--r-- | src/wasm-ir-builder.h | 20 | ||||
-rw-r--r-- | src/wasm/wasm-ir-builder.cpp | 46 | ||||
-rw-r--r-- | test/lit/passes/outlining.wast | 100 |
4 files changed, 174 insertions, 6 deletions
diff --git a/src/passes/Outlining.cpp b/src/passes/Outlining.cpp index 345b3f9a2..68c3d0397 100644 --- a/src/passes/Outlining.cpp +++ b/src/passes/Outlining.cpp @@ -132,7 +132,19 @@ struct ReconstructStringifyWalker : state == NotInSeq ? &existingBuilder : nullptr; if (builder) { - ASSERT_OK(builder->visit(curr)); + if (auto* expr = curr->dynCast<Break>()) { + Type type = expr->value ? expr->value->type : Type::none; + ASSERT_OK(builder->visitBreakWithType(expr, type)); + } else if (auto* expr = curr->dynCast<Switch>()) { + Type type = expr->value ? expr->value->type : Type::none; + ASSERT_OK(builder->visitSwitchWithType(expr, type)); + } else { + // Assert ensures new unhandled branch instructions + // will quickly cause an error. Serves as a reminder to + // implement a new special-case visit*WithType. + assert(curr->is<BrOn>() || !Properties::isBranch(curr)); + ASSERT_OK(builder->visit(curr)); + } } DBG(printVisitExpression(curr)); diff --git a/src/wasm-ir-builder.h b/src/wasm-ir-builder.h index b37b352e3..d31a532a7 100644 --- a/src/wasm-ir-builder.h +++ b/src/wasm-ir-builder.h @@ -224,10 +224,26 @@ public: [[nodiscard]] Result<> visitStructNew(StructNew*); [[nodiscard]] Result<> visitArrayNew(ArrayNew*); [[nodiscard]] Result<> visitArrayNewFixed(ArrayNewFixed*); + // Used to visit break exprs when traversing the module in the fully nested + // format. Break label destinations are assumed to have already been visited, + // with a corresponding push onto the scope stack. As a result, an error will + // return if a corresponding scope is not found for the break. [[nodiscard]] Result<> visitBreak(Break*, std::optional<Index> label = std::nullopt); + // Used to visit break nodes when traversing a single block without its + // context. The type indicates how many values the break carries to its + // destination. + [[nodiscard]] Result<> visitBreakWithType(Break*, Type); [[nodiscard]] Result<> + // Used to visit switch exprs when traversing the module in the fully nested + // format. Switch label destinations are assumed to have already been visited, + // with a corresponding push onto the scope stack. As a result, an error will + // return if a corresponding scope is not found for the switch. visitSwitch(Switch*, std::optional<Index> defaultLabel = std::nullopt); + // Used to visit switch nodes when traversing a single block without its + // context. The type indicates how many values the switch carries to its + // destination. + [[nodiscard]] Result<> visitSwitchWithType(Switch*, Type); [[nodiscard]] Result<> visitCall(Call*); [[nodiscard]] Result<> visitCallIndirect(CallIndirect*); [[nodiscard]] Result<> visitCallRef(CallRef*); @@ -535,8 +551,8 @@ private: [[nodiscard]] Result<> packageHoistedValue(const HoistedVal&, size_t sizeHint = 1); - [[nodiscard]] Result<Expression*> getBranchValue(Name labelName, - std::optional<Index> label); + [[nodiscard]] Result<Expression*> + getBranchValue(Expression* curr, Name labelName, std::optional<Index> label); void dump(); }; diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index b1bf8c855..8d87d0f33 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp @@ -419,8 +419,14 @@ Result<> IRBuilder::visitArrayNewFixed(ArrayNewFixed* curr) { return Ok{}; } -Result<Expression*> IRBuilder::getBranchValue(Name labelName, +Result<Expression*> IRBuilder::getBranchValue(Expression* curr, + Name labelName, std::optional<Index> label) { + // As new branch instructions are added, one of the existing branch visit* + // functions is likely to be copied, along with its call to getBranchValue(). + // This assert serves as a reminder to also add an implementation of + // visit*WithType() for new branch instructions. + assert(curr->is<Break>() || curr->is<Switch>()); if (!label) { auto index = getLabelIndex(labelName); CHECK_ERR(index); @@ -440,23 +446,57 @@ Result<> IRBuilder::visitBreak(Break* curr, std::optional<Index> label) { CHECK_ERR(cond); curr->condition = *cond; } - auto value = getBranchValue(curr->name, label); + auto value = getBranchValue(curr, curr->name, label); CHECK_ERR(value); curr->value = *value; return Ok{}; } +Result<> IRBuilder::visitBreakWithType(Break* curr, Type type) { + if (curr->condition) { + auto cond = pop(); + CHECK_ERR(cond); + curr->condition = *cond; + } + if (type == Type::none) { + curr->value = nullptr; + } else { + auto value = pop(type.size()); + CHECK_ERR(value) + curr->value = *value; + } + curr->finalize(); + push(curr); + return Ok{}; +} + Result<> IRBuilder::visitSwitch(Switch* curr, std::optional<Index> defaultLabel) { auto cond = pop(); CHECK_ERR(cond); curr->condition = *cond; - auto value = getBranchValue(curr->default_, defaultLabel); + auto value = getBranchValue(curr, curr->default_, defaultLabel); CHECK_ERR(value); curr->value = *value; return Ok{}; } +Result<> IRBuilder::visitSwitchWithType(Switch* curr, Type type) { + auto cond = pop(); + CHECK_ERR(cond); + curr->condition = *cond; + if (type == Type::none) { + curr->value = nullptr; + } else { + auto value = pop(type.size()); + CHECK_ERR(value) + curr->value = *value; + } + curr->finalize(); + push(curr); + return Ok{}; +} + Result<> IRBuilder::visitCall(Call* curr) { auto numArgs = wasm.getFunction(curr->target)->getNumParams(); curr->operands.resize(numArgs); diff --git a/test/lit/passes/outlining.wast b/test/lit/passes/outlining.wast index befce7513..76f305db7 100644 --- a/test/lit/passes/outlining.wast +++ b/test/lit/passes/outlining.wast @@ -614,6 +614,106 @@ ) ) +;; Tests branch with condition is reconstructed without error. +(module + ;; CHECK: (type $0 (func)) + + ;; CHECK: (func $outline$ (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + + ;; CHECK: (func $a (type $0) + ;; CHECK-NEXT: (block $label1 + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: (loop $loop-in + ;; CHECK-NEXT: (br $label1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $a + (block $label1 + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + (loop + (br $label1) + ) + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + ) + ) +) + +;; Tests br_table instruction is reconstructed without error. +(module + ;; CHECK: (type $0 (func)) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (func $outline$ (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + + ;; CHECK: (func $a (type $1) (param $0 i32) (result i32) + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: (block $block + ;; CHECK-NEXT: (block $block0 + ;; CHECK-NEXT: (br_table $block $block0 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (i32.const 21) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (i32.const 20) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: (i32.const 22) + ;; CHECK-NEXT: ) + (func $a (param i32) (result i32) + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + (block + (block + (br_table 1 0 (local.get $0)) + (return (i32.const 21)) + ) + (return (i32.const 20)) + ) + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + (i32.const 22) + ) +) + ;; Tests return instructions are correctly filtered from being outlined. (module ;; CHECK: (type $0 (func (result i32))) |