/*
 * Copyright 2024 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Instruments the build with code to intercept selected function calls.
// This can be e.g. used to trace allocations (malloc, free, calloc, realloc)
// and build tools for memory usage analysis.
// The pass supports SIMD but the multi-value feature is not supported yet.
//
// Example:
//
// Instrumenting function `void* malloc(int32_t)` with a user-defined
// name of the tracer `trace_alloc` and function `void free(void*)`
// with the default name of the tracer `trace_free` (`trace_` prefix
// is added by default):
// wasm-opt --trace-calls=malloc:trace_alloc,free -o test-opt.wasm test.wasm
//
//  Before:
//   (call $malloc
//     (i32.const 32))
//   (call $free (i32.const 64))
//
//  After:
//   (local $0 i32)
//   (local $1 i32)
//   (local $2 i32)
//   (block (result i32)
//     (call $trace_alloc
//       (local.tee $1
//         (call $malloc
//           (local.tee $0 (i32.const 32))
//         )
//       )
//       (local.get $0)
//     )
//     (local.get $1)
//   )
//   (block
//     (call $free
//       (local.tee $2
//         (i32.const 64)
//       )
//     )
//     (call $trace_free
//       (local.get $2)
//     )
//   )

#include <map>

#include "asmjs/shared-constants.h"
#include "ir/import-utils.h"
#include "pass.h"
#include "support/string.h"
#include "wasm-builder.h"

namespace wasm {

// Maps the name of each function to trace to the name of the tracer function
// that is called for it.
using TracedFunctions = std::map<Name /* originName */, Name /* tracerName */>;

// Walker that rewrites every direct call to a traced function so that the
// original call is followed by a call to its tracer. The tracer receives the
// call's result (when the traced function has one) followed by the call's
// arguments.
struct AddTraceWrappers : public WalkerPass<PostWalker<AddTraceWrappers>> {
  AddTraceWrappers(TracedFunctions tracedFunctions)
    : tracedFunctions(std::move(tracedFunctions)) {}

  // Instruments `curr` when its target is one of the traced functions.
  void visitCall(Call* curr) {
    auto iter = tracedFunctions.find(curr->target);
    if (iter == tracedFunctions.end()) {
      return;
    }

    auto* target = getModule()->getFunction(curr->target);
    addInstrumentation(curr, target, iter->second);
  }

private:
  // Replaces `curr` with a block that performs the original call, saving the
  // arguments (and the result, if any) in fresh locals, and then passes the
  // saved values to the function named `wrapperName`.
  void addInstrumentation(Call* curr,
                          const wasm::Function* target,
                          const Name& wrapperName) {
    // A call with an unreachable operand is never executed, and no local can
    // hold an unreachable value, so such calls are left untouched.
    for (auto* operand : curr->operands) {
      if (operand->type == Type::unreachable) {
        return;
      }
    }

    Builder builder(*getModule());
    std::vector<wasm::Expression*> realCallParams, trackerCallParams;
    realCallParams.reserve(curr->operands.size());
    // One extra slot for the result, which is prepended below if present.
    trackerCallParams.reserve(curr->operands.size() + 1);

    // Each argument is evaluated once: it is tee'd into a local on its way
    // into the real call and read back from that local for the tracer.
    for (auto* operand : curr->operands) {
      auto localVar = builder.addVar(getFunction(), operand->type);
      realCallParams.push_back(
        builder.makeLocalTee(localVar, operand, operand->type));
      trackerCallParams.push_back(
        builder.makeLocalGet(localVar, operand->type));
    }

    auto resultType = target->type.getSignature().results;
    // The tracer has to run after the traced function returns, so the real
    // call is always emitted as a regular call, even if `curr` is a tail
    // call. An explicit return is appended below in that case.
    auto* realCall = builder.makeCall(target->name, realCallParams, resultType);

    if (resultType.isConcrete()) {
      // The real call becomes the first argument of the tracer call, which
      // guarantees it executes before the saved arguments are read back.
      auto resultLocal = builder.addVar(getFunction(), resultType);
      trackerCallParams.insert(
        trackerCallParams.begin(),
        builder.makeLocalTee(resultLocal, realCall, resultType));

      Expression* traced = builder.makeBlock(
        {builder.makeCall(
           wrapperName, trackerCallParams, Type::BasicType::none),
         builder.makeLocalGet(resultLocal, resultType)});
      if (curr->isReturn) {
        // Preserve the tail call's semantics: return the traced result.
        traced = builder.makeReturn(traced);
      }
      replaceCurrent(traced);
    } else {
      auto* tracerCall =
        builder.makeCall(wrapperName, trackerCallParams, Type::BasicType::none);
      if (curr->isReturn) {
        // Preserve the tail call's semantics: return once tracing is done.
        replaceCurrent(
          builder.makeBlock({realCall, tracerCall, builder.makeReturn()}));
      } else {
        replaceCurrent(builder.makeBlock({realCall, tracerCall}));
      }
    }
  }

  TracedFunctions tracedFunctions;
};

struct TraceCalls : public Pass {
  // Adds calls to new imports.
  bool addsEffects() override { return true; }

  void run(Module* module) override {
    auto functionsDefinitions =
      getArgument("trace-calls",
                  "TraceCalls usage: wasm-opt "
                  "--trace-calls=FUNCTION_TO_TRACE[:TRACER_NAME][,...]");

    auto tracedFunctions = parseArgument(functionsDefinitions);

    for (const auto& tracedFunction : tracedFunctions) {
      auto func = module->getFunctionOrNull(tracedFunction.first);
      if (!func) {
        std::cerr << "[TraceCalls] Function '" << tracedFunction.first
                  << "' not found" << std::endl;
      } else {
        addImport(module, *func, tracedFunction.second);
      }
    }

    AddTraceWrappers(std::move(tracedFunctions)).run(getPassRunner(), module);
  }

private:
  Type getTracerParamsType(ImportInfo& info, const Function& func) {
    auto resultsType = func.type.getSignature().results;
    if (resultsType.isTuple()) {
      Fatal() << "Failed to instrument function '" << func.name
              << "': Multi-value result type is not supported";
    }

    std::vector<Type> tracerParamTypes;
    if (resultsType.isConcrete()) {
      tracerParamTypes.push_back(resultsType);
    }
    for (auto& op : func.type.getSignature().params) {
      tracerParamTypes.push_back(op);
    }

    return Type(tracerParamTypes);
  }

  TracedFunctions parseArgument(const std::string& arg) {
    TracedFunctions tracedFunctions;

    for (const auto& definition : String::Split(arg, ",")) {
      if (definition.empty()) {
        // Empty definition, ignore.
        continue;
      }

      std::string originName, traceName;
      parseFunctionName(definition, originName, traceName);

      tracedFunctions[Name(originName)] = Name(traceName);
    }

    return tracedFunctions;
  }

  void parseFunctionName(const std::string& str,
                         std::string& originName,
                         std::string& traceName) {
    auto parts = String::Split(str, ":");
    switch (parts.size()) {
      case 1:
        originName = parts[0];
        traceName = "trace_" + originName;
        break;
      case 2:
        originName = parts[0];
        traceName = parts[1];
        break;
      default:
        Fatal() << "Failed to parse function name ('" << str
                << "'): expected format FUNCTION_TO_TRACE[:TRACER_NAME]";
    }
  }

  void addImport(Module* wasm, const Function& f, const Name& tracerName) {
    ImportInfo info(*wasm);

    if (!info.getImportedFunction(ENV, tracerName)) {
      auto import = Builder::makeFunction(
        tracerName, Signature(getTracerParamsType(info, f), Type::none), {});
      import->module = ENV;
      import->base = tracerName;
      wasm->addFunction(std::move(import));
    }
  }
};

// Factory for the pass: returns a newly allocated TraceCalls instance.
Pass* createTraceCallsPass() {
  return new TraceCalls();
}

} // namespace wasm