diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/FuncCastEmulation.cpp | 235 | ||||
-rw-r--r-- | src/passes/pass.cpp | 1 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/wasm-builder.h | 16 | ||||
-rw-r--r-- | src/wasm-linker.cpp | 8 | ||||
-rw-r--r-- | src/wasm/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 10 | ||||
-rw-r--r-- | src/wasm/wasm-emscripten.cpp (renamed from src/wasm-emscripten.cpp) | 0 | ||||
-rw-r--r-- | src/wasm/wasm-interpreter.cpp (renamed from src/wasm-interpreter.cpp) | 0 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 6 |
11 files changed, 272 insertions, 8 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index e5334fff1..d4fff78df 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -8,6 +8,7 @@ SET(passes_SOURCES DuplicateFunctionElimination.cpp ExtractFunction.cpp Flatten.cpp + FuncCastEmulation.cpp Inlining.cpp LegalizeJSInterface.cpp LocalCSE.cpp diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp new file mode 100644 index 000000000..59a2588da --- /dev/null +++ b/src/passes/FuncCastEmulation.cpp @@ -0,0 +1,235 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Instruments all indirect calls so that they work even if a function +// pointer was cast incorrectly. For example, if you cast an int (int, float) +// to an int (int, float, int) and call it natively, on most archs it will +// happen to work, ignoring the extra param, whereas in wasm it will trap. +// When porting code that relies on such casts working (like e.g. Python), +// this pass may be useful. It sets a new "ABI" for indirect calls, in which +// they all return an i64 and they have a fixed number of i64 params, and +// the pass converts everything to go through that. +// +// This should work even with dynamic linking, however, the number of +// params must be identical, i.e., the "ABI" must match. + +#include <wasm.h> +#include <wasm-builder.h> +#include <asm_v_wasm.h> +#include <pass.h> +#include <wasm-emscripten.h> +#include <ir/literal-utils.h> + +namespace wasm { + +// This should be enough for everybody. (As described above, we need this +// to match when dynamically linking, and also dynamic linking is why we +// can't just detect this automatically in the module we see.) +static const int NUM_PARAMS = 15; + +// Converts a value to the ABI type of i64. +static Expression* toABI(Expression* value, Module* module) { + Builder builder(*module); + switch (value->type) { + case i32: { + value = builder.makeUnary(ExtendUInt32, value); + break; + } + case i64: { + // already good + break; + } + case f32: { + value = builder.makeUnary( + ExtendUInt32, + builder.makeUnary(ReinterpretFloat32, value) + ); + break; + } + case f64: { + value = builder.makeUnary(ReinterpretFloat64, value); + break; + } + case none: { + // the value is none, but we need a value here + value = builder.makeSequence( + value, + LiteralUtils::makeZero(i64, *module) + ); + break; + } + case unreachable: { + // can leave it, the call isn't taken anyhow + break; + } + default: { + // SIMD may be interesting some day + WASM_UNREACHABLE(); + } + } + return value; +} + +// Converts a value from the ABI type of i64 to the expected type +static Expression* fromABI(Expression* value, Type type, Module* module) { + Builder builder(*module); + switch (type) { + case i32: { + value = builder.makeUnary(WrapInt64, value); + break; + } + case i64: { + // already good + break; + } + case f32: { + value = builder.makeUnary( + ReinterpretInt32, + builder.makeUnary(WrapInt64, value) + ); + break; + } + case f64: { + value = builder.makeUnary(ReinterpretInt64, value); + break; + } + case none: { + value = builder.makeDrop(value); + } + case unreachable: { + // can leave it, the call isn't taken anyhow + break; + } + default: { + // SIMD may be interesting some day + WASM_UNREACHABLE(); + } + } + return value; +} + +struct ParallelFuncCastEmulation : public WalkerPass<PostWalker<ParallelFuncCastEmulation>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new ParallelFuncCastEmulation(ABIType); } + + ParallelFuncCastEmulation(Name ABIType) : ABIType(ABIType) {} + + void visitCallIndirect(CallIndirect* curr) { + if (curr->operands.size() > NUM_PARAMS) { + Fatal() << "FuncCastEmulation::NUM_PARAMS needs to be at least " << + curr->operands.size(); + } + for (Expression*& operand : curr->operands) { + operand = toABI(operand, getModule()); + } + // Add extra operands as needed. + while (curr->operands.size() < NUM_PARAMS) { + curr->operands.push_back(LiteralUtils::makeZero(i64, *getModule())); + } + // Set the new types + auto oldType = curr->type; + curr->type = i64; + curr->fullType = ABIType; + // Fix up return value + replaceCurrent(fromABI(curr, oldType, getModule())); + } + +private: + // the name of a type for a call with the right params and return + Name ABIType; +}; + +struct FuncCastEmulation : public Pass { + void run(PassRunner* runner, Module* module) override { + // we just need the one ABI function type for all indirect calls + std::string sig = "j"; + for (Index i = 0; i < NUM_PARAMS; i++) { + sig += 'j'; + } + ABIType = ensureFunctionType(sig, module)->name; + // Add a way for JS to call into the table (as our i64 ABI means an i64 + // is returned when there is a return value, which JS engines will fail on), + // using dynCalls + EmscriptenGlueGenerator generator(*module); + generator.generateDynCallThunks(); + // Add a thunk for each function in the table, and do the call through it. + std::unordered_map<Name, Name> funcThunks; + for (auto& segment : module->table.segments) { + for (auto& name : segment.data) { + auto iter = funcThunks.find(name); + if (iter == funcThunks.end()) { + auto thunk = makeThunk(name, module); + funcThunks[name] = thunk; + name = thunk; + } else { + name = iter->second; + } + } + } + // update call_indirects + PassRunner subRunner(module, runner->options); + subRunner.setIsNested(true); + subRunner.add<ParallelFuncCastEmulation>(ABIType); + subRunner.run(); + } + +private: + // the name of a type for a call with the right params and return + Name ABIType; + + // Creates a thunk for a function, casting args and return value as needed. + Name makeThunk(Name name, Module* module) { + Name thunk = std::string("byn$fpcast-emu$") + name.str; + if (module->getFunctionOrNull(thunk)) { + Fatal() << "FuncCastEmulation::makeThunk seems a thunk name already in use. Was the pass already run on this code?"; + } + // The item in the table may be a function or a function import. + auto* func = module->getFunctionOrNull(name); + Import* imp = nullptr; + if (!func) imp = module->getImport(name); + std::vector<Type>& params = func ? func->params : module->getFunctionType(imp->functionType)->params; + Type type = func ? func->result : module->getFunctionType(imp->functionType)->result; + Builder builder(*module); + std::vector<Expression*> callOperands; + for (Index i = 0; i < params.size(); i++) { + callOperands.push_back(fromABI(builder.makeGetLocal(i, i64), params[i], module)); + } + Expression* call = func ? (Expression*)builder.makeCall(name, callOperands, type) + : (Expression*)builder.makeCallImport(name, callOperands, type); + std::vector<Type> thunkParams; + for (Index i = 0; i < NUM_PARAMS; i++) { + thunkParams.push_back(i64); + } + auto* thunkFunc = builder.makeFunction( + thunk, + std::move(thunkParams), + i64, + {}, // no vars + toABI(call, module) + ); + thunkFunc->type = ABIType; + module->addFunction(thunkFunc); + return thunk; + } +}; + +Pass* createFuncCastEmulationPass() { + return new FuncCastEmulation(); +} + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 37c59f570..6518a6f3d 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -74,6 +74,7 @@ void PassRegistry::registerPasses() { registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass); registerPass("flatten", "flattens out code, removing nesting", createFlattenPass); + registerPass("fpcast-emu", "emulates function pointer casts, allowing incorrect indirect calls to (sometimes) work", createFuncCastEmulationPass); registerPass("func-metrics", "reports function metrics", createFunctionMetricsPass); registerPass("inlining", "inline functions (you probably want inlining-optimizing)", createInliningPass); registerPass("inlining-optimizing", "inline functions and optimizes where we inlined", createInliningOptimizingPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index 230cdfd86..2eaf1049e 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -31,6 +31,7 @@ Pass* createDeadCodeEliminationPass(); Pass* createDuplicateFunctionEliminationPass(); Pass* createExtractFunctionPass(); Pass* createFlattenPass(); +Pass* createFuncCastEmulationPass(); Pass* createFullPrinterPass(); Pass* createFunctionMetricsPass(); Pass* createI64ToI32LoweringPass(); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 94699df7e..02304acec 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -43,6 +43,20 @@ public: // make* functions, create nodes Function* makeFunction(Name name, + std::vector<Type>&& params, + Type resultType, + std::vector<Type>&& vars, + Expression* body = nullptr) { + auto* func = new Function; + func->name = name; + func->result = resultType; + func->body = body; + func->params.swap(params); + func->vars.swap(vars); + return func; + } + + Function* makeFunction(Name name, std::vector<NameType>&& params, Type resultType, std::vector<NameType>&& vars, @@ -51,7 +65,6 @@ public: func->name = name; func->result = resultType; func->body = body; - for (auto& param : params) { func->params.push_back(param.type); Index index = func->localNames.size(); @@ -64,7 +77,6 @@ public: func->localIndices[var.name] = index; func->localNames[index] = var.name; } - return func; } diff --git a/src/wasm-linker.cpp b/src/wasm-linker.cpp index c284de81f..df51d85c6 100644 --- a/src/wasm-linker.cpp +++ b/src/wasm-linker.cpp @@ -382,7 +382,13 @@ void Linker::makeDummyFunction() { if (!create) return; wasm::Builder wasmBuilder(out.wasm); Expression *unreachable = wasmBuilder.makeUnreachable(); - Function *dummy = wasmBuilder.makeFunction(Name(dummyFunction), {}, Type::none, {}, unreachable); + Function *dummy = wasmBuilder.makeFunction( + Name(dummyFunction), + std::vector<Type>{}, + Type::none, + std::vector<Type>{}, + unreachable + ); out.wasm.addFunction(dummy); getFunctionIndex(dummy->name); } diff --git a/src/wasm/CMakeLists.txt b/src/wasm/CMakeLists.txt index 1a8a9b8ba..da876b56f 100644 --- a/src/wasm/CMakeLists.txt +++ b/src/wasm/CMakeLists.txt @@ -2,6 +2,8 @@ SET(wasm_SOURCES literal.cpp wasm.cpp wasm-binary.cpp + wasm-emscripten.cpp + wasm-interpreter.cpp wasm-io.cpp wasm-s-parser.cpp wasm-type.cpp diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 56b35b712..69e939c44 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1730,11 +1730,11 @@ void WasmBinaryBuilder::readFunctions() { } } auto func = Builder(wasm).makeFunction( - Name::fromInt(i), - std::move(params), - type->result, - std::move(vars) - ); + Name::fromInt(i), + std::move(params), + type->result, + std::move(vars) + ); func->type = type->name; currFunction = func; { diff --git a/src/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 9a393db43..9a393db43 100644 --- a/src/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp diff --git a/src/wasm-interpreter.cpp b/src/wasm/wasm-interpreter.cpp index e7df785ac..e7df785ac 100644 --- a/src/wasm-interpreter.cpp +++ b/src/wasm/wasm-interpreter.cpp diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 8ff2fb9b8..d6c6e7bd0 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -806,6 +806,12 @@ void FunctionValidator::visitFunction(Function* curr) { shouldBeTrue(breakTargets.empty(), curr->body, "all named break targets must exist"); returnType = unreachable; labelNames.clear(); + // if function has a named type, it must match up with the function's params and result + if (info.validateGlobally && curr->type.is()) { + auto* ft = getModule()->getFunctionType(curr->type); + shouldBeTrue(ft->params == curr->params, curr->name, "function params must match its declared type"); + shouldBeTrue(ft->result == curr->result, curr->name, "function result must match its declared type"); + } // expressions must not be seen more than once struct Walker : public PostWalker<Walker, UnifiedExpressionVisitor<Walker>> { std::unordered_set<Expression*>& seen; |