/*
 * Copyright 2020 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Enforce stack pointer limits.  This pass will add checks around all
// assignments to the __stack_pointer global that LLVM uses for its
// shadow stack.
//

#include "abi/js.h"
#include "ir/abstract.h"
#include "ir/import-utils.h"
#include "ir/names.h"
#include "pass.h"
#include "shared-constants.h"
#include "support/debug.h"
#include "wasm-emscripten.h"

#define DEBUG_TYPE "stack-check"

namespace wasm {

// Name of the function this pass adds and exports. The function takes the
// stack base and the stack limit as its two parameters and stores them in the
// corresponding globals.
static Name SET_STACK_LIMITS("__set_stack_limits");

// Makes sure `module` imports a function called `name` from the ENV module.
// If such an import is already present it is left untouched (its signature is
// not compared against `sig`); otherwise a new import with signature `sig` is
// added.
static void
importStackOverflowHandler(Module& module, Name name, Signature sig) {
  ImportInfo existingImports(module);
  if (existingImports.getImportedFunction(ENV, name)) {
    // Already imported, nothing to add.
    return;
  }

  auto handlerImport = Builder::makeFunction(name, sig, {});
  handlerImport->module = ENV;
  handlerImport->base = name;
  module.addFunction(std::move(handlerImport));
}

// Hands `function` over to `module` and exports it as a function under its own
// name.
static void addExportedFunction(Module& module,
                                std::unique_ptr<Function> function) {
  // Take a copy of the name first: it is still needed after ownership of the
  // function has been moved into the module.
  Name exportedName = function->name;
  module.addFunction(std::move(function));
  module.addExport(
    Builder::makeExport(exportedName, exportedName, ExternalKind::Function));
}

// Function-parallel walker that replaces every global.set of the stack pointer
// with a sequence that first checks the new value against the stack base and
// stack limit globals, and only then performs the write.
struct EnforceStackLimits : public WalkerPass<PostWalker<EnforceStackLimits>> {
  // stackPointer: the global whose writes get instrumented.
  // stackBase:    global holding the highest value the stack pointer may take.
  // stackLimit:   global holding the lowest value the stack pointer may take.
  // builder:      used to create all the new IR nodes.
  // handler:      name of a function to call when a check fails. If it is not
  //               set, the failing path is an unreachable instead.
  EnforceStackLimits(const Global* stackPointer,
                     const Global* stackBase,
                     const Global* stackLimit,
                     Builder& builder,
                     Name handler)
    : stackPointer(stackPointer), stackBase(stackBase), stackLimit(stackLimit),
      builder(builder), handler(handler) {}

  bool isFunctionParallel() override { return true; }

  // Only affects linear memory operations.
  bool requiresNonNullableLocalFixups() override { return false; }

  // Every instance created for parallel execution refers to the same globals,
  // builder and handler as this one.
  std::unique_ptr<Pass> create() override {
    return std::make_unique<EnforceStackLimits>(
      stackPointer, stackBase, stackLimit, builder, handler);
  }

  // Builds the replacement for a write of `value` to the stack pointer inside
  // `func`: a bounds check on the new value followed by the actual
  // global.set. Returns the resulting expression.
  Expression* stackBoundsCheck(Function* func, Expression* value) {
    // Add a local to store the value of the expression. We need the value
    // more than once: to compare it against both bounds, to pass it to the
    // handler, and finally to assign it to the stack pointer.
    auto newSP = Builder::addVar(func, stackPointer->type);
    // If we imported a handler, call it. That can show a nice error in JS.
    // Otherwise, just trap.
    Expression* handlerExpr;
    if (handler.is()) {
      // The handler is imported with no results, and this call is the only
      // arm of an `if` without an `else`, so the call must have type none
      // rather than the type of the stack pointer.
      handlerExpr =
        builder.makeCall(handler,
                         {builder.makeLocalGet(newSP, stackPointer->type)},
                         Type::none);
    } else {
      handlerExpr = builder.makeUnreachable();
    }

    // If it is > the base or < the limit (unsigned), then error. The first
    // comparison also stores the new value in the local (via local.tee) so
    // that the later reads see it.
    auto check = builder.makeIf(
      builder.makeBinary(
        BinaryOp::OrInt32,
        builder.makeBinary(
          Abstract::getBinary(stackPointer->type, Abstract::GtU),
          builder.makeLocalTee(newSP, value, stackPointer->type),
          builder.makeGlobalGet(stackBase->name, stackBase->type)),
        builder.makeBinary(
          Abstract::getBinary(stackPointer->type, Abstract::LtU),
          builder.makeLocalGet(newSP, stackPointer->type),
          builder.makeGlobalGet(stackLimit->name, stackLimit->type))),
      handlerExpr);
    // (global.set $__stack_pointer (local.get $newSP))
    auto newSet = builder.makeGlobalSet(
      stackPointer->name, builder.makeLocalGet(newSP, stackPointer->type));
    return builder.blockify(check, newSet);
  }

  // Instruments writes to the stack pointer global; writes to any other
  // global are left as they are.
  void visitGlobalSet(GlobalSet* curr) {
    if (getModule()->getGlobalOrNull(curr->name) == stackPointer) {
      replaceCurrent(stackBoundsCheck(getFunction(), curr->value));
    }
  }

private:
  const Global* stackPointer;
  const Global* stackBase;
  const Global* stackLimit;
  Builder& builder;
  Name handler;
};

struct StackCheck : public Pass {
  // Adds calls to new imports.
  bool addsEffects() override { return true; }

  void run(Module* module) override {
    Global* stackPointer = getStackPointerGlobal(*module);
    if (!stackPointer) {
      BYN_DEBUG(std::cerr << "no stack pointer found\n");
      return;
    }

    // Pick appropriate names.
    auto stackBaseName = Names::getValidGlobalName(*module, "__stack_base");
    auto stackLimitName = Names::getValidGlobalName(*module, "__stack_limit");

    Name handler;
    auto handlerName = getArgumentOrDefault("stack-check-handler", "");
    if (handlerName != "") {
      handler = handlerName;
      importStackOverflowHandler(
        *module, handler, Signature({stackPointer->type}, Type::none));
    }

    Builder builder(*module);

    // Add the globals.
    Type addressType =
      module->memories.empty() ? Type::i32 : module->memories[0]->addressType;
    auto stackBase =
      module->addGlobal(builder.makeGlobal(stackBaseName,
                                           stackPointer->type,
                                           builder.makeConstPtr(0, addressType),
                                           Builder::Mutable));
    auto stackLimit =
      module->addGlobal(builder.makeGlobal(stackLimitName,
                                           stackPointer->type,
                                           builder.makeConstPtr(0, addressType),
                                           Builder::Mutable));

    // Instrument all the code.
    EnforceStackLimits(stackPointer, stackBase, stackLimit, builder, handler)
      .run(getPassRunner(), module);

    // Generate the exported function.
    auto limitsFunc = builder.makeFunction(
      SET_STACK_LIMITS,
      Signature({stackPointer->type, stackPointer->type}, Type::none),
      {});
    auto* getBase = builder.makeLocalGet(0, stackPointer->type);
    auto* storeBase = builder.makeGlobalSet(stackBaseName, getBase);
    auto* getLimit = builder.makeLocalGet(1, stackPointer->type);
    auto* storeLimit = builder.makeGlobalSet(stackLimitName, getLimit);
    limitsFunc->body = builder.makeBlock({storeBase, storeLimit});
    addExportedFunction(*module, std::move(limitsFunc));
  }
};

// Factory for the stack-check pass. Returns a newly allocated StackCheck as a
// raw pointer.
Pass* createStackCheckPass() {
  Pass* pass = new StackCheck;
  return pass;
}

} // namespace wasm