/*
 * Copyright 2023 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef wasm_wasm_ir_builder_h
#define wasm_wasm_ir_builder_h

#include <vector>

#include "ir/names.h"
#include "support/result.h"
#include "wasm-builder.h"
#include "wasm-traversal.h"
#include "wasm-type.h"
#include "wasm.h"

namespace wasm {

// A utility for constructing valid Binaryen IR from arbitrary valid sequences
// of WebAssembly instructions. The user is responsible for providing Expression
// nodes with all of their non-child fields already filled out, and IRBuilder is
// responsible for setting child fields and finalizing nodes.
//
// To use, call CHECK_ERR(visit(...)) or CHECK_ERR(makeXYZ(...)) on each
// expression in the sequence, then call build().
//
// Unlike `Builder`, `IRBuilder` requires referenced module-level items (e.g.
// globals, tables, functions, etc.) to already exist in the module.
class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
public:
  IRBuilder(Module& wasm, Function* func = nullptr)
    : wasm(wasm), func(func), builder(wasm) {}

  // Get the valid Binaryen IR expression representing the sequence of visited
  // instructions. The IRBuilder is reset and can be used with a fresh sequence
  // of instructions after this is called.
  [[nodiscard]] Result<Expression*> build();

  // Call visit() on an existing Expression with its non-child fields
  // initialized to initialize the child fields and refinalize it.
  [[nodiscard]] Result<> visit(Expression*);

  // Like visit, but pushes the expression onto the stack as-is without popping
  // any children or refinalization.
  void push(Expression*);

  // Set the debug location to be attached to the next visited, created, or
  // pushed instruction.
  void setDebugLocation(const std::optional<Function::DebugLocation>&);

  // Handle the boundaries of control flow structures. Users may choose to use
  // the corresponding `makeXYZ` function below instead of `visitXYZStart`, but
  // either way must call `visitEnd` and friends at the appropriate times.
  [[nodiscard]] Result<> visitFunctionStart(Function* func);
  [[nodiscard]] Result<> visitBlockStart(Block* block);
  [[nodiscard]] Result<> visitIfStart(If* iff, Name label = {});
  [[nodiscard]] Result<> visitElse();
  [[nodiscard]] Result<> visitLoopStart(Loop* iff);
  [[nodiscard]] Result<> visitTryStart(Try* tryy, Name label = {});
  [[nodiscard]] Result<> visitCatch(Name tag);
  [[nodiscard]] Result<> visitCatchAll();
  [[nodiscard]] Result<> visitDelegate(Index label);
  [[nodiscard]] Result<> visitTryTableStart(TryTable* trytable,
                                            Name label = {});
  [[nodiscard]] Result<> visitEnd();

  // Used to visit break nodes when traversing a single block without its
  // context. The type indicates how many values the break carries to its
  // destination.
  [[nodiscard]] Result<> visitBreakWithType(Break*, Type);
  // Used to visit switch nodes when traversing a single block without its
  // context. The type indicates how many values the switch carries to its
  // destination.
  [[nodiscard]] Result<> visitSwitchWithType(Switch*, Type);

  // Binaryen IR uses names to refer to branch targets, but in general there may
  // be branches to constructs that do not yet have names, so in IRBuilder we
  // use indices to refer to branch targets instead, just as the binary format
  // does. This function converts a branch target name to the correct index.
  //
  // Labels in delegates need special handling because the indexing needs to be
  // relative to the try's enclosing scope rather than the try itself.
  [[nodiscard]] Result<Index> getLabelIndex(Name label,
                                            bool inDelegate = false);

  // Instead of calling visit, call makeXYZ to have the IRBuilder allocate the
  // nodes. This is generally safer than calling `visit` because the function
  // signatures ensure that there are no missing fields.
  [[nodiscard]] Result<> makeNop();
  [[nodiscard]] Result<> makeBlock(Name label, Type type);
  [[nodiscard]] Result<> makeIf(Name label, Type type);
  [[nodiscard]] Result<> makeLoop(Name label, Type type);
  [[nodiscard]] Result<> makeBreak(Index label, bool isConditional);
  [[nodiscard]] Result<> makeSwitch(const std::vector<Index>& labels,
                                    Index defaultLabel);
  // Unlike Builder::makeCall, this assumes the function already exists.
  [[nodiscard]] Result<> makeCall(Name func, bool isReturn);
  [[nodiscard]] Result<>
  makeCallIndirect(Name table, HeapType type, bool isReturn);
  [[nodiscard]] Result<> makeLocalGet(Index local);
  [[nodiscard]] Result<> makeLocalSet(Index local);
  [[nodiscard]] Result<> makeLocalTee(Index local);
  [[nodiscard]] Result<> makeGlobalGet(Name global);
  [[nodiscard]] Result<> makeGlobalSet(Name global);
  [[nodiscard]] Result<> makeLoad(unsigned bytes,
                                  bool signed_,
                                  Address offset,
                                  unsigned align,
                                  Type type,
                                  Name mem);
  [[nodiscard]] Result<> makeStore(
    unsigned bytes, Address offset, unsigned align, Type type, Name mem);
  [[nodiscard]] Result<>
  makeAtomicLoad(unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<>
  makeAtomicStore(unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<> makeAtomicRMW(
    AtomicRMWOp op, unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<>
  makeAtomicCmpxchg(unsigned bytes, Address offset, Type type, Name mem);
  [[nodiscard]] Result<> makeAtomicWait(Type type, Address offset, Name mem);
  [[nodiscard]] Result<> makeAtomicNotify(Address offset, Name mem);
  [[nodiscard]] Result<> makeAtomicFence();
  [[nodiscard]] Result<> makeSIMDExtract(SIMDExtractOp op, uint8_t lane);
  [[nodiscard]] Result<> makeSIMDReplace(SIMDReplaceOp op, uint8_t lane);
  [[nodiscard]] Result<> makeSIMDShuffle(const std::array<uint8_t, 16>& lanes);
  [[nodiscard]] Result<> makeSIMDTernary(SIMDTernaryOp op);
  [[nodiscard]] Result<> makeSIMDShift(SIMDShiftOp op);
  [[nodiscard]] Result<>
  makeSIMDLoad(SIMDLoadOp op, Address offset, unsigned align, Name mem);
  [[nodiscard]] Result<> makeSIMDLoadStoreLane(SIMDLoadStoreLaneOp op,
                                               Address offset,
                                               unsigned align,
                                               uint8_t lane,
                                               Name mem);
  [[nodiscard]] Result<> makeMemoryInit(Name data, Name mem);
  [[nodiscard]] Result<> makeDataDrop(Name data);
  [[nodiscard]] Result<> makeMemoryCopy(Name destMem, Name srcMem);
  [[nodiscard]] Result<> makeMemoryFill(Name mem);
  [[nodiscard]] Result<> makeConst(Literal val);
  [[nodiscard]] Result<> makeUnary(UnaryOp op);
  [[nodiscard]] Result<> makeBinary(BinaryOp op);
  [[nodiscard]] Result<> makeSelect(std::optional<Type> type = std::nullopt);
  [[nodiscard]] Result<> makeDrop();
  [[nodiscard]] Result<> makeReturn();
  [[nodiscard]] Result<> makeMemorySize(Name mem);
  [[nodiscard]] Result<> makeMemoryGrow(Name mem);
  [[nodiscard]] Result<> makeUnreachable();
  [[nodiscard]] Result<> makePop(Type type);
  [[nodiscard]] Result<> makeRefNull(HeapType type);
  [[nodiscard]] Result<> makeRefIsNull();
  [[nodiscard]] Result<> makeRefFunc(Name func);
  [[nodiscard]] Result<> makeRefEq();
  [[nodiscard]] Result<> makeTableGet(Name table);
  [[nodiscard]] Result<> makeTableSet(Name table);
  [[nodiscard]] Result<> makeTableSize(Name table);
  [[nodiscard]] Result<> makeTableGrow(Name table);
  [[nodiscard]] Result<> makeTableFill(Name table);
  [[nodiscard]] Result<> makeTableCopy(Name destTable, Name srcTable);
  [[nodiscard]] Result<> makeTableInit(Name elem, Name table);
  [[nodiscard]] Result<> makeTry(Name label, Type type);
  [[nodiscard]] Result<> makeTryTable(Name label,
                                      Type type,
                                      const std::vector<Name>& tags,
                                      const std::vector<Index>& labels,
                                      const std::vector<bool>& isRefs);
  [[nodiscard]] Result<> makeThrow(Name tag);
  [[nodiscard]] Result<> makeRethrow(Index label);
  [[nodiscard]] Result<> makeThrowRef();
  [[nodiscard]] Result<> makeTupleMake(uint32_t arity);
  [[nodiscard]] Result<> makeTupleExtract(uint32_t arity, uint32_t index);
  [[nodiscard]] Result<> makeTupleDrop(uint32_t arity);
  [[nodiscard]] Result<> makeRefI31(Shareability share);
  [[nodiscard]] Result<> makeI31Get(bool signed_);
  [[nodiscard]] Result<> makeCallRef(HeapType type, bool isReturn);
  [[nodiscard]] Result<> makeRefTest(Type type);
  [[nodiscard]] Result<> makeRefCast(Type type);
  [[nodiscard]] Result<>
  makeBrOn(Index label, BrOnOp op, Type in = Type::none, Type out = Type::none);
  [[nodiscard]] Result<> makeStructNew(HeapType type);
  [[nodiscard]] Result<> makeStructNewDefault(HeapType type);
  [[nodiscard]] Result<>
  makeStructGet(HeapType type, Index field, bool signed_);
  [[nodiscard]] Result<> makeStructSet(HeapType type, Index field);
  [[nodiscard]] Result<> makeArrayNew(HeapType type);
  [[nodiscard]] Result<> makeArrayNewDefault(HeapType type);
  [[nodiscard]] Result<> makeArrayNewData(HeapType type, Name data);
  [[nodiscard]] Result<> makeArrayNewElem(HeapType type, Name elem);
  [[nodiscard]] Result<> makeArrayNewFixed(HeapType type, uint32_t arity);
  [[nodiscard]] Result<> makeArrayGet(HeapType type, bool signed_);
  [[nodiscard]] Result<> makeArraySet(HeapType type);
  [[nodiscard]] Result<> makeArrayLen();
  [[nodiscard]] Result<> makeArrayCopy(HeapType destType, HeapType srcType);
  [[nodiscard]] Result<> makeArrayFill(HeapType type);
  [[nodiscard]] Result<> makeArrayInitData(HeapType type, Name data);
  [[nodiscard]] Result<> makeArrayInitElem(HeapType type, Name elem);
  [[nodiscard]] Result<> makeRefAs(RefAsOp op);
  [[nodiscard]] Result<> makeStringNew(StringNewOp op);
  [[nodiscard]] Result<> makeStringConst(Name string);
  [[nodiscard]] Result<> makeStringMeasure(StringMeasureOp op);
  [[nodiscard]] Result<> makeStringEncode(StringEncodeOp op);
  [[nodiscard]] Result<> makeStringConcat();
  [[nodiscard]] Result<> makeStringEq(StringEqOp op);
  [[nodiscard]] Result<> makeStringWTF8Advance();
  [[nodiscard]] Result<> makeStringWTF16Get();
  [[nodiscard]] Result<> makeStringIterNext();
  [[nodiscard]] Result<> makeStringSliceWTF();
  [[nodiscard]] Result<> makeContBind(HeapType contTypeBefore,
                                      HeapType contTypeAfter);
  [[nodiscard]] Result<> makeContNew(HeapType ct);
  [[nodiscard]] Result<> makeResume(HeapType ct,
                                    const std::vector<Name>& tags,
                                    const std::vector<Index>& labels);
  [[nodiscard]] Result<> makeSuspend(Name tag);

  // Private functions that must be public for technical reasons.
  [[nodiscard]] Result<> visitExpression(Expression*);

  // Do not push pops onto the stack since we generate our own pops as necessary
  // when visiting the beginnings of try blocks.
  [[nodiscard]] Result<> visitPop(Pop*) { return Ok{}; }

private:
  Module& wasm;
  Function* func;
  Builder builder;

  // The location lacks debug info as it was marked as not having it.
  struct NoDebug : public std::monostate {};
  // The location lacks debug info, but was not marked as not having
  // it, and it can receive it from the parent or its previous sibling
  // (if it has one).
  struct CanReceiveDebug : public std::monostate {};
  using DebugVariant =
    std::variant<NoDebug, CanReceiveDebug, Function::DebugLocation>;

  DebugVariant debugLoc;

  struct ChildPopper;

  void applyDebugLoc(Expression* expr);

  // The context for a single block scope, including the instructions parsed
  // inside that scope so far and the ultimate result type we expect this block
  // to have.
  struct ScopeCtx {
    struct NoScope {};
    struct FuncScope {
      Function* func;
    };
    struct BlockScope {
      Block* block;
    };
    struct IfScope {
      If* iff;
      Name originalLabel;
    };
    struct ElseScope {
      If* iff;
      Name originalLabel;
    };
    struct LoopScope {
      Loop* loop;
    };
    struct TryScope {
      Try* tryy;
      Name originalLabel;
    };
    struct CatchScope {
      Try* tryy;
      Name originalLabel;
    };
    struct CatchAllScope {
      Try* tryy;
      Name originalLabel;
    };
    struct TryTableScope {
      TryTable* trytable;
      Name originalLabel;
    };
    using Scope = std::variant<NoScope,
                               FuncScope,
                               BlockScope,
                               IfScope,
                               ElseScope,
                               LoopScope,
                               TryScope,
                               CatchScope,
                               CatchAllScope,
                               TryTableScope>;

    // The control flow structure we are building expressions for.
    Scope scope;

    // The branch label name for this scope. Always fresh, never shadowed.
    Name label;
    // For Try/Catch/CatchAll scopes, we need to separately track a label used
    // for branches, since the normal label is only used for delegates.
    Name branchLabel;

    bool labelUsed = false;

    std::vector<Expression*> exprStack;
    // Whether we have seen an unreachable instruction and are in
    // stack-polymorphic unreachable mode.
    bool unreachable = false;

    ScopeCtx() : scope(NoScope{}) {}
    ScopeCtx(Scope scope) : scope(scope) {}
    ScopeCtx(Scope scope, Name label, bool labelUsed)
      : scope(scope), label(label), labelUsed(labelUsed) {}
    ScopeCtx(Scope scope, Name label, bool labelUsed, Name branchLabel)
      : scope(scope), label(label), branchLabel(branchLabel),
        labelUsed(labelUsed) {}

    static ScopeCtx makeFunc(Function* func) {
      return ScopeCtx(FuncScope{func});
    }
    static ScopeCtx makeBlock(Block* block) {
      return ScopeCtx(BlockScope{block});
    }
    static ScopeCtx makeIf(If* iff, Name originalLabel = {}) {
      return ScopeCtx(IfScope{iff, originalLabel});
    }
    static ScopeCtx
    makeElse(If* iff, Name originalLabel, Name label, bool labelUsed) {
      return ScopeCtx(ElseScope{iff, originalLabel}, label, labelUsed);
    }
    static ScopeCtx makeLoop(Loop* loop) { return ScopeCtx(LoopScope{loop}); }
    static ScopeCtx makeTry(Try* tryy, Name originalLabel = {}) {
      return ScopeCtx(TryScope{tryy, originalLabel});
    }
    static ScopeCtx makeCatch(Try* tryy,
                              Name originalLabel,
                              Name label,
                              bool labelUsed,
                              Name branchLabel) {
      return ScopeCtx(
        CatchScope{tryy, originalLabel}, label, labelUsed, branchLabel);
    }
    static ScopeCtx makeCatchAll(Try* tryy,
                                 Name originalLabel,
                                 Name label,
                                 bool labelUsed,
                                 Name branchLabel) {
      return ScopeCtx(
        CatchAllScope{tryy, originalLabel}, label, labelUsed, branchLabel);
    }
    static ScopeCtx makeTryTable(TryTable* trytable, Name originalLabel = {}) {
      return ScopeCtx(TryTableScope{trytable, originalLabel});
    }

    bool isNone() { return std::get_if<NoScope>(&scope); }
    Function* getFunction() {
      if (auto* funcScope = std::get_if<FuncScope>(&scope)) {
        return funcScope->func;
      }
      return nullptr;
    }
    Block* getBlock() {
      if (auto* blockScope = std::get_if<BlockScope>(&scope)) {
        return blockScope->block;
      }
      return nullptr;
    }
    If* getIf() {
      if (auto* ifScope = std::get_if<IfScope>(&scope)) {
        return ifScope->iff;
      }
      return nullptr;
    }
    If* getElse() {
      if (auto* elseScope = std::get_if<ElseScope>(&scope)) {
        return elseScope->iff;
      }
      return nullptr;
    }
    Loop* getLoop() {
      if (auto* loopScope = std::get_if<LoopScope>(&scope)) {
        return loopScope->loop;
      }
      return nullptr;
    }
    Try* getTry() {
      if (auto* tryScope = std::get_if<TryScope>(&scope)) {
        return tryScope->tryy;
      }
      return nullptr;
    }
    Try* getCatch() {
      if (auto* catchScope = std::get_if<CatchScope>(&scope)) {
        return catchScope->tryy;
      }
      return nullptr;
    }
    Try* getCatchAll() {
      if (auto* catchAllScope = std::get_if<CatchAllScope>(&scope)) {
        return catchAllScope->tryy;
      }
      return nullptr;
    }
    TryTable* getTryTable() {
      if (auto* tryTableScope = std::get_if<TryTableScope>(&scope)) {
        return tryTableScope->trytable;
      }
      return nullptr;
    }
    Type getResultType() {
      if (auto* func = getFunction()) {
        return func->type.getSignature().results;
      }
      if (auto* block = getBlock()) {
        return block->type;
      }
      if (auto* iff = getIf()) {
        return iff->type;
      }
      if (auto* iff = getElse()) {
        return iff->type;
      }
      if (auto* loop = getLoop()) {
        return loop->type;
      }
      if (auto* tryy = getTry()) {
        return tryy->type;
      }
      if (auto* tryy = getCatch()) {
        return tryy->type;
      }
      if (auto* tryy = getCatchAll()) {
        return tryy->type;
      }
      if (auto* trytable = getTryTable()) {
        return trytable->type;
      }
      WASM_UNREACHABLE("unexpected scope kind");
    }
    Name getOriginalLabel() {
      if (std::get_if<NoScope>(&scope) || getFunction()) {
        return Name{};
      }
      if (auto* block = getBlock()) {
        return block->name;
      }
      if (auto* ifScope = std::get_if<IfScope>(&scope)) {
        return ifScope->originalLabel;
      }
      if (auto* elseScope = std::get_if<ElseScope>(&scope)) {
        return elseScope->originalLabel;
      }
      if (auto* loop = getLoop()) {
        return loop->name;
      }
      if (auto* tryScope = std::get_if<TryScope>(&scope)) {
        return tryScope->originalLabel;
      }
      if (auto* catchScope = std::get_if<CatchScope>(&scope)) {
        return catchScope->originalLabel;
      }
      if (auto* catchAllScope = std::get_if<CatchAllScope>(&scope)) {
        return catchAllScope->originalLabel;
      }
      if (auto* tryTableScope = std::get_if<TryTableScope>(&scope)) {
        return tryTableScope->originalLabel;
      }
      WASM_UNREACHABLE("unexpected scope kind");
    }
  };

  // The stack of block contexts currently being parsed.
  std::vector<ScopeCtx> scopeStack;

  // Map label names to stacks of label depths at which they appear. The
  // relative index of a label name is the current depth minus the top depth on
  // its stack.
  std::unordered_map<Name, std::vector<Index>> labelDepths;

  Name makeFresh(Name label) {
    return Names::getValidName(
      label,
      [&](Name candidate) {
        return labelDepths.insert({candidate, {}}).second;
      },
      0,
      "");
  }

  void pushScope(ScopeCtx scope) {
    if (auto label = scope.getOriginalLabel()) {
      // Assign a fresh label to the scope, if necessary.
      if (!scope.label) {
        scope.label = makeFresh(label);
      }
      // Record the original label to handle references to it correctly.
      labelDepths[label].push_back(scopeStack.size() + 1);
    }
    scopeStack.push_back(scope);
  }

  ScopeCtx& getScope() {
    if (scopeStack.empty()) {
      // We are not in a block context, so push a dummy scope.
      scopeStack.push_back({});
    }
    return scopeStack.back();
  }

  Result<ScopeCtx*> getScope(Index label) {
    Index numLabels = scopeStack.size();
    if (!scopeStack.empty() && scopeStack[0].isNone()) {
      --numLabels;
    }
    if (label >= numLabels) {
      return Err{"label index out of bounds"};
    }
    return &scopeStack[scopeStack.size() - 1 - label];
  }

  // Collect the current scope into a single expression. If it has multiple
  // top-level expressions, this requires collecting them into a block. If we
  // are in a block context, we can collect them directly into the destination
  // `block`, but otherwise we will have to allocate a new block.
  Result<Expression*> finishScope(Block* block = nullptr);

  [[nodiscard]] Result<Name> getLabelName(Index label,
                                          bool forDelegate = false);
  [[nodiscard]] Result<Name> getDelegateLabelName(Index label) {
    return getLabelName(label, true);
  }
  [[nodiscard]] Result<Index> addScratchLocal(Type);

  struct HoistedVal {
    // The index in the stack of the original value-producing expression.
    Index valIndex;
    // The local.get placed on the stack, if any.
    LocalGet* get;
  };

  // Find the last value-producing expression, if any, and hoist its value to
  // the top of the stack using a scratch local if necessary.
  [[nodiscard]] MaybeResult<HoistedVal> hoistLastValue();
  // Transform the stack as necessary such that the original producer of the
  // hoisted value will be popped along with the final expression that produces
  // the value, if they are different. May only be called directly after
  // hoistLastValue(). `sizeHint` is the size of the type we ultimately want to
  // consume, so if the hoisted value has `sizeHint` elements, it is left intact
  // even if it is a tuple. Otherwise, hoisted tuple values will be broken into
  // pieces.
  [[nodiscard]] Result<> packageHoistedValue(const HoistedVal&,
                                             size_t sizeHint = 1);

  [[nodiscard]] Result<Type> getLabelType(Index label);
  [[nodiscard]] Result<Type> getLabelType(Name labelName);

  void dump();
};

} // namespace wasm

#endif // wasm_wasm_ir_builder_h