diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/import-utils.h | 41 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/SafeHeap.cpp | 318 | ||||
-rw-r--r-- | src/passes/pass.cpp | 1 | ||||
-rw-r--r-- | src/passes/passes.h | 1 |
5 files changed, 362 insertions, 0 deletions
diff --git a/src/ast/import-utils.h b/src/ast/import-utils.h new file mode 100644 index 000000000..d12c23182 --- /dev/null +++ b/src/ast/import-utils.h @@ -0,0 +1,41 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_import_h +#define wasm_ast_import_h + +#include "literal.h" +#include "wasm.h" + +namespace wasm { + +namespace ImportUtils { + // find an import by the module.base that is being imported. + // return the internal name + inline Name getImport(Module& wasm, Name module, Name base) { + for (auto& import : wasm.imports) { + if (import->module == module && import->base == base) { + return import->name; + } + } + return Name(); + } +}; + +} // namespace wasm + +#endif // wasm_ast_import_h + diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 7c0166786..a575d8b27 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -32,6 +32,7 @@ SET(passes_SOURCES RemoveUnusedModuleElements.cpp ReorderLocals.cpp ReorderFunctions.cpp + SafeHeap.cpp SimplifyLocals.cpp SSAify.cpp Untee.cpp diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp new file mode 100644 index 000000000..ebaf42358 --- /dev/null +++ b/src/passes/SafeHeap.cpp @@ -0,0 +1,318 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Instruments code to check for incorrect heap access. This checks +// for dereferencing 0 (null pointer access), reading past the valid +// top of sbrk()-addressible memory, and incorrect alignment notation. +// + +#include "wasm.h" +#include "pass.h" +#include "asm_v_wasm.h" +#include "asmjs/shared-constants.h" +#include "wasm-builder.h" +#include "ast/import-utils.h" + +namespace wasm { + +const Name DYNAMICTOP_PTR_IMPORT("DYNAMICTOP_PTR"), + SEGFAULT_IMPORT("segfault"), + ALIGNFAULT_IMPORT("alignfault"); + +static Name getLoadName(Load* curr) { + std::string ret = "SAFE_HEAP_LOAD_"; + ret += printWasmType(curr->type); + ret += "_" + std::to_string(curr->bytes) + "_"; + if (!isWasmTypeFloat(curr->type) && !curr->signed_) { + ret += "U_"; + } + if (curr->isAtomic) { + ret += "A"; + } else { + ret += std::to_string(curr->align); + } + return ret; +} + +static Name getStoreName(Store* curr) { + std::string ret = "SAFE_HEAP_STORE_"; + ret += printWasmType(curr->valueType); + ret += "_" + std::to_string(curr->bytes) + "_"; + if (curr->isAtomic) { + ret += "A"; + } else { + ret += std::to_string(curr->align); + } + return ret; +} + +struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> { + bool isFunctionParallel() override { return true; } + + AccessInstrumenter* create() override { return new AccessInstrumenter; } + + void visitLoad(Load* curr) { + if (curr->type == unreachable) return; + Builder builder(*getModule()); + replaceCurrent( + builder.makeCall( + getLoadName(curr), + { + curr->ptr, + builder.makeConst(Literal(int32_t(curr->offset))), + }, + curr->type + ) + ); + } + + void visitStore(Store* curr) { + if (curr->type == unreachable) return; + Builder builder(*getModule()); + replaceCurrent( + builder.makeCall( + getStoreName(curr), + { + curr->ptr, + builder.makeConst(Literal(int32_t(curr->offset))), + curr->value, + }, + none + ) + ); + } +}; + +struct SafeHeap : public Pass { + void run(PassRunner* runner, Module* module) override { + // add imports + addImports(module); + // instrument loads and stores + PassRunner instrumenter(module); + instrumenter.setIsNested(true); + instrumenter.add<AccessInstrumenter>(); + instrumenter.run(); + // add helper checking funcs and imports + addGlobals(module); + } + + Name dynamicTopPtr, segfault, alignfault; + + void addImports(Module* module) { + // imports + dynamicTopPtr = ImportUtils::getImport(*module, ENV, DYNAMICTOP_PTR_IMPORT); + if (!dynamicTopPtr.is()) { + auto* import = new Import; + import->name = dynamicTopPtr = DYNAMICTOP_PTR_IMPORT; + import->module = ENV; + import->base = DYNAMICTOP_PTR_IMPORT; + import->kind = ExternalKind::Global; + import->globalType = i32; + module->addImport(import); + } + segfault = ImportUtils::getImport(*module, ENV, SEGFAULT_IMPORT); + if (!segfault.is()) { + auto* import = new Import; + import->name = segfault = SEGFAULT_IMPORT; + import->module = ENV; + import->base = SEGFAULT_IMPORT; + import->kind = ExternalKind::Function; + import->functionType = ensureFunctionType("v", module)->name; + module->addImport(import); + } + alignfault = ImportUtils::getImport(*module, ENV, ALIGNFAULT_IMPORT); + if (!alignfault.is()) { + auto* import = new Import; + import->name = alignfault = ALIGNFAULT_IMPORT; + import->module = ENV; + import->base = ALIGNFAULT_IMPORT; + import->kind = ExternalKind::Function; + import->functionType = ensureFunctionType("v", module)->name; + module->addImport(import); + } + } + + void addGlobals(Module* module) { + // load funcs + Load load; + for (auto type : { i32, i64, f32, f64 }) { + load.type = type; + for (Index bytes : { 1, 2, 4, 8 }) { + load.bytes = bytes; + if (bytes > getWasmTypeSize(type)) continue; + for (auto signed_ : { true, false }) { + load.signed_ = signed_; + if (isWasmTypeFloat(type) && signed_) continue; + for (Index align : { 1, 2, 4, 8 }) { + load.align = align; + if (align > bytes) continue; + for (auto isAtomic : { true, false }) { + load.isAtomic = isAtomic; + if (isAtomic && align != bytes) continue; + if (isAtomic && !module->memory.shared) continue; + addLoadFunc(load, module); + } + } + } + } + } + // store funcs + Store store; + for (auto valueType : { i32, i64, f32, f64 }) { + store.valueType = valueType; + store.type = none; + for (Index bytes : { 1, 2, 4, 8 }) { + store.bytes = bytes; + if (bytes > getWasmTypeSize(valueType)) continue; + for (Index align : { 1, 2, 4, 8 }) { + store.align = align; + if (align > bytes) continue; + for (auto isAtomic : { true, false }) { + store.isAtomic = isAtomic; + if (isAtomic && align != bytes) continue; + if (isAtomic && !module->memory.shared) continue; + addStoreFunc(store, module); + } + } + } + } + } + + // creates a function for a particular style of load + void addLoadFunc(Load style, Module* module) { + auto* func = new Function; + func->name = getLoadName(&style); + func->params.push_back(i32); // pointer + func->params.push_back(i32); // offset + func->vars.push_back(i32); // pointer + offset + func->result = style.type; + Builder builder(*module); + auto* block = builder.makeBlock(); + block->list.push_back( + builder.makeSetLocal( + 2, + builder.makeBinary( + AddInt32, + builder.makeGetLocal(0, i32), + builder.makeGetLocal(1, i32) + ) + ) + ); + // check for reading past valid memory: if pointer + offset + bytes + block->list.push_back( + makeBoundsCheck(style.type, builder, 2) + ); + // check proper alignment + if (style.align > 1) { + block->list.push_back( + makeAlignCheck(style.align, builder, 2) + ); + } + // do the load + auto* load = module->allocator.alloc<Load>(); + *load = style; // basically the same as the template we are given! + load->ptr = builder.makeGetLocal(2, i32); + block->list.push_back(load); + block->finalize(style.type); + func->body = block; + module->addFunction(func); + } + + // creates a function for a particular type of store + void addStoreFunc(Store style, Module* module) { + auto* func = new Function; + func->name = getStoreName(&style); + func->params.push_back(i32); // pointer + func->params.push_back(i32); // offset + func->params.push_back(style.valueType); // value + func->vars.push_back(i32); // pointer + offset + func->result = none; + Builder builder(*module); + auto* block = builder.makeBlock(); + block->list.push_back( + builder.makeSetLocal( + 3, + builder.makeBinary( + AddInt32, + builder.makeGetLocal(0, i32), + builder.makeGetLocal(1, i32) + ) + ) + ); + // check for reading past valid memory: if pointer + offset + bytes + block->list.push_back( + makeBoundsCheck(style.valueType, builder, 3) + ); + // check proper alignment + if (style.align > 1) { + block->list.push_back( + makeAlignCheck(style.align, builder, 3) + ); + } + // do the store + auto* store = module->allocator.alloc<Store>(); + *store = style; // basically the same as the template we are given! + store->ptr = builder.makeGetLocal(3, i32); + store->value = builder.makeGetLocal(2, style.valueType); + block->list.push_back(store); + block->finalize(none); + func->body = block; + module->addFunction(func); + } + + Expression* makeAlignCheck(Address align, Builder& builder, Index local) { + return builder.makeIf( + builder.makeBinary( + AndInt32, + builder.makeGetLocal(local, i32), + builder.makeConst(Literal(int32_t(align - 1))) + ), + builder.makeCallImport(alignfault, {}, none) + ); + } + + Expression* makeBoundsCheck(WasmType type, Builder& builder, Index local) { + return builder.makeIf( + builder.makeBinary( + OrInt32, + builder.makeBinary( + EqInt32, + builder.makeGetLocal(local, i32), + builder.makeConst(Literal(int32_t(0))) + ), + builder.makeBinary( + GtUInt32, + builder.makeBinary( + AddInt32, + builder.makeGetLocal(local, i32), + builder.makeConst(Literal(int32_t(getWasmTypeSize(type)))) + ), + builder.makeLoad(4, false, 0, 4, + builder.makeGetGlobal(dynamicTopPtr, i32), i32 + ) + ) + ), + builder.makeCallImport(segfault, {}, none) + ); + } +}; + +Pass *createSafeHeapPass() { + return new SafeHeap(); +} + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 1e56169df..4ffd74047 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -101,6 +101,7 @@ void PassRegistry::registerPasses() { registerPass("reorder-locals", "sorts locals by access frequency", createReorderLocalsPass); registerPass("rereloop", "re-optimize control flow using the relooper algorithm", createReReloopPass); registerPass("simplify-locals", "miscellaneous locals-related optimizations", createSimplifyLocalsPass); + registerPass("safe-heap", "instrument loads and stores to check for invalid behavior", createSafeHeapPass); registerPass("simplify-locals-notee", "miscellaneous locals-related optimizations", createSimplifyLocalsNoTeePass); registerPass("simplify-locals-nostructure", "miscellaneous locals-related optimizations", createSimplifyLocalsNoStructurePass); registerPass("simplify-locals-notee-nostructure", "miscellaneous locals-related optimizations", createSimplifyLocalsNoTeeNoStructurePass); diff --git a/src/passes/passes.h b/src/passes/passes.h index 5e7eba540..18c92b2cb 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -58,6 +58,7 @@ Pass *createRemoveUnusedNamesPass(); Pass *createReorderFunctionsPass(); Pass *createReorderLocalsPass(); Pass *createReReloopPass(); +Pass *createSafeHeapPass(); Pass *createSimplifyLocalsPass(); Pass *createSimplifyLocalsNoTeePass(); Pass *createSimplifyLocalsNoStructurePass(); |