summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast/import-utils.h41
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/SafeHeap.cpp318
-rw-r--r--src/passes/pass.cpp1
-rw-r--r--src/passes/passes.h1
5 files changed, 362 insertions, 0 deletions
diff --git a/src/ast/import-utils.h b/src/ast/import-utils.h
new file mode 100644
index 000000000..d12c23182
--- /dev/null
+++ b/src/ast/import-utils.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ast_import_h
+#define wasm_ast_import_h
+
+#include "literal.h"
+#include "wasm.h"
+
+namespace wasm {
+
+namespace ImportUtils {
+ // find an import by the module.base that is being imported.
+ // return the internal name
+ inline Name getImport(Module& wasm, Name module, Name base) {
+ for (auto& import : wasm.imports) {
+ if (import->module == module && import->base == base) {
+ return import->name;
+ }
+ }
+ return Name();
+ }
+};
+
+} // namespace wasm
+
+#endif // wasm_ast_import_h
+
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 7c0166786..a575d8b27 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -32,6 +32,7 @@ SET(passes_SOURCES
RemoveUnusedModuleElements.cpp
ReorderLocals.cpp
ReorderFunctions.cpp
+ SafeHeap.cpp
SimplifyLocals.cpp
SSAify.cpp
Untee.cpp
diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp
new file mode 100644
index 000000000..ebaf42358
--- /dev/null
+++ b/src/passes/SafeHeap.cpp
@@ -0,0 +1,318 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Instruments code to check for incorrect heap access. This checks
+// for dereferencing 0 (null pointer access), reading past the valid
+// top of sbrk()-addressible memory, and incorrect alignment notation.
+//
+
+#include "wasm.h"
+#include "pass.h"
+#include "asm_v_wasm.h"
+#include "asmjs/shared-constants.h"
+#include "wasm-builder.h"
+#include "ast/import-utils.h"
+
+namespace wasm {
+
+const Name DYNAMICTOP_PTR_IMPORT("DYNAMICTOP_PTR"),
+ SEGFAULT_IMPORT("segfault"),
+ ALIGNFAULT_IMPORT("alignfault");
+
+static Name getLoadName(Load* curr) {
+ std::string ret = "SAFE_HEAP_LOAD_";
+ ret += printWasmType(curr->type);
+ ret += "_" + std::to_string(curr->bytes) + "_";
+ if (!isWasmTypeFloat(curr->type) && !curr->signed_) {
+ ret += "U_";
+ }
+ if (curr->isAtomic) {
+ ret += "A";
+ } else {
+ ret += std::to_string(curr->align);
+ }
+ return ret;
+}
+
+static Name getStoreName(Store* curr) {
+ std::string ret = "SAFE_HEAP_STORE_";
+ ret += printWasmType(curr->valueType);
+ ret += "_" + std::to_string(curr->bytes) + "_";
+ if (curr->isAtomic) {
+ ret += "A";
+ } else {
+ ret += std::to_string(curr->align);
+ }
+ return ret;
+}
+
+struct AccessInstrumenter : public WalkerPass<PostWalker<AccessInstrumenter>> {
+ bool isFunctionParallel() override { return true; }
+
+ AccessInstrumenter* create() override { return new AccessInstrumenter; }
+
+ void visitLoad(Load* curr) {
+ if (curr->type == unreachable) return;
+ Builder builder(*getModule());
+ replaceCurrent(
+ builder.makeCall(
+ getLoadName(curr),
+ {
+ curr->ptr,
+ builder.makeConst(Literal(int32_t(curr->offset))),
+ },
+ curr->type
+ )
+ );
+ }
+
+ void visitStore(Store* curr) {
+ if (curr->type == unreachable) return;
+ Builder builder(*getModule());
+ replaceCurrent(
+ builder.makeCall(
+ getStoreName(curr),
+ {
+ curr->ptr,
+ builder.makeConst(Literal(int32_t(curr->offset))),
+ curr->value,
+ },
+ none
+ )
+ );
+ }
+};
+
+struct SafeHeap : public Pass {
+ void run(PassRunner* runner, Module* module) override {
+ // add imports
+ addImports(module);
+ // instrument loads and stores
+ PassRunner instrumenter(module);
+ instrumenter.setIsNested(true);
+ instrumenter.add<AccessInstrumenter>();
+ instrumenter.run();
+ // add helper checking funcs and imports
+ addGlobals(module);
+ }
+
+ Name dynamicTopPtr, segfault, alignfault;
+
+ void addImports(Module* module) {
+ // imports
+ dynamicTopPtr = ImportUtils::getImport(*module, ENV, DYNAMICTOP_PTR_IMPORT);
+ if (!dynamicTopPtr.is()) {
+ auto* import = new Import;
+ import->name = dynamicTopPtr = DYNAMICTOP_PTR_IMPORT;
+ import->module = ENV;
+ import->base = DYNAMICTOP_PTR_IMPORT;
+ import->kind = ExternalKind::Global;
+ import->globalType = i32;
+ module->addImport(import);
+ }
+ segfault = ImportUtils::getImport(*module, ENV, SEGFAULT_IMPORT);
+ if (!segfault.is()) {
+ auto* import = new Import;
+ import->name = segfault = SEGFAULT_IMPORT;
+ import->module = ENV;
+ import->base = SEGFAULT_IMPORT;
+ import->kind = ExternalKind::Function;
+ import->functionType = ensureFunctionType("v", module)->name;
+ module->addImport(import);
+ }
+ alignfault = ImportUtils::getImport(*module, ENV, ALIGNFAULT_IMPORT);
+ if (!alignfault.is()) {
+ auto* import = new Import;
+ import->name = alignfault = ALIGNFAULT_IMPORT;
+ import->module = ENV;
+ import->base = ALIGNFAULT_IMPORT;
+ import->kind = ExternalKind::Function;
+ import->functionType = ensureFunctionType("v", module)->name;
+ module->addImport(import);
+ }
+ }
+
+ void addGlobals(Module* module) {
+ // load funcs
+ Load load;
+ for (auto type : { i32, i64, f32, f64 }) {
+ load.type = type;
+ for (Index bytes : { 1, 2, 4, 8 }) {
+ load.bytes = bytes;
+ if (bytes > getWasmTypeSize(type)) continue;
+ for (auto signed_ : { true, false }) {
+ load.signed_ = signed_;
+ if (isWasmTypeFloat(type) && signed_) continue;
+ for (Index align : { 1, 2, 4, 8 }) {
+ load.align = align;
+ if (align > bytes) continue;
+ for (auto isAtomic : { true, false }) {
+ load.isAtomic = isAtomic;
+ if (isAtomic && align != bytes) continue;
+ if (isAtomic && !module->memory.shared) continue;
+ addLoadFunc(load, module);
+ }
+ }
+ }
+ }
+ }
+ // store funcs
+ Store store;
+ for (auto valueType : { i32, i64, f32, f64 }) {
+ store.valueType = valueType;
+ store.type = none;
+ for (Index bytes : { 1, 2, 4, 8 }) {
+ store.bytes = bytes;
+ if (bytes > getWasmTypeSize(valueType)) continue;
+ for (Index align : { 1, 2, 4, 8 }) {
+ store.align = align;
+ if (align > bytes) continue;
+ for (auto isAtomic : { true, false }) {
+ store.isAtomic = isAtomic;
+ if (isAtomic && align != bytes) continue;
+ if (isAtomic && !module->memory.shared) continue;
+ addStoreFunc(store, module);
+ }
+ }
+ }
+ }
+ }
+
+ // creates a function for a particular style of load
+ void addLoadFunc(Load style, Module* module) {
+ auto* func = new Function;
+ func->name = getLoadName(&style);
+ func->params.push_back(i32); // pointer
+ func->params.push_back(i32); // offset
+ func->vars.push_back(i32); // pointer + offset
+ func->result = style.type;
+ Builder builder(*module);
+ auto* block = builder.makeBlock();
+ block->list.push_back(
+ builder.makeSetLocal(
+ 2,
+ builder.makeBinary(
+ AddInt32,
+ builder.makeGetLocal(0, i32),
+ builder.makeGetLocal(1, i32)
+ )
+ )
+ );
+ // check for reading past valid memory: if pointer + offset + bytes
+ block->list.push_back(
+ makeBoundsCheck(style.type, builder, 2)
+ );
+ // check proper alignment
+ if (style.align > 1) {
+ block->list.push_back(
+ makeAlignCheck(style.align, builder, 2)
+ );
+ }
+ // do the load
+ auto* load = module->allocator.alloc<Load>();
+ *load = style; // basically the same as the template we are given!
+ load->ptr = builder.makeGetLocal(2, i32);
+ block->list.push_back(load);
+ block->finalize(style.type);
+ func->body = block;
+ module->addFunction(func);
+ }
+
+ // creates a function for a particular type of store
+ void addStoreFunc(Store style, Module* module) {
+ auto* func = new Function;
+ func->name = getStoreName(&style);
+ func->params.push_back(i32); // pointer
+ func->params.push_back(i32); // offset
+ func->params.push_back(style.valueType); // value
+ func->vars.push_back(i32); // pointer + offset
+ func->result = none;
+ Builder builder(*module);
+ auto* block = builder.makeBlock();
+ block->list.push_back(
+ builder.makeSetLocal(
+ 3,
+ builder.makeBinary(
+ AddInt32,
+ builder.makeGetLocal(0, i32),
+ builder.makeGetLocal(1, i32)
+ )
+ )
+ );
+ // check for reading past valid memory: if pointer + offset + bytes
+ block->list.push_back(
+ makeBoundsCheck(style.valueType, builder, 3)
+ );
+ // check proper alignment
+ if (style.align > 1) {
+ block->list.push_back(
+ makeAlignCheck(style.align, builder, 3)
+ );
+ }
+ // do the store
+ auto* store = module->allocator.alloc<Store>();
+ *store = style; // basically the same as the template we are given!
+ store->ptr = builder.makeGetLocal(3, i32);
+ store->value = builder.makeGetLocal(2, style.valueType);
+ block->list.push_back(store);
+ block->finalize(none);
+ func->body = block;
+ module->addFunction(func);
+ }
+
+ Expression* makeAlignCheck(Address align, Builder& builder, Index local) {
+ return builder.makeIf(
+ builder.makeBinary(
+ AndInt32,
+ builder.makeGetLocal(local, i32),
+ builder.makeConst(Literal(int32_t(align - 1)))
+ ),
+ builder.makeCallImport(alignfault, {}, none)
+ );
+ }
+
+ Expression* makeBoundsCheck(WasmType type, Builder& builder, Index local) {
+ return builder.makeIf(
+ builder.makeBinary(
+ OrInt32,
+ builder.makeBinary(
+ EqInt32,
+ builder.makeGetLocal(local, i32),
+ builder.makeConst(Literal(int32_t(0)))
+ ),
+ builder.makeBinary(
+ GtUInt32,
+ builder.makeBinary(
+ AddInt32,
+ builder.makeGetLocal(local, i32),
+ builder.makeConst(Literal(int32_t(getWasmTypeSize(type))))
+ ),
+ builder.makeLoad(4, false, 0, 4,
+ builder.makeGetGlobal(dynamicTopPtr, i32), i32
+ )
+ )
+ ),
+ builder.makeCallImport(segfault, {}, none)
+ );
+ }
+};
+
+Pass *createSafeHeapPass() {
+ return new SafeHeap();
+}
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 1e56169df..4ffd74047 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -101,6 +101,7 @@ void PassRegistry::registerPasses() {
registerPass("reorder-locals", "sorts locals by access frequency", createReorderLocalsPass);
registerPass("rereloop", "re-optimize control flow using the relooper algorithm", createReReloopPass);
registerPass("simplify-locals", "miscellaneous locals-related optimizations", createSimplifyLocalsPass);
+ registerPass("safe-heap", "instrument loads and stores to check for invalid behavior", createSafeHeapPass);
registerPass("simplify-locals-notee", "miscellaneous locals-related optimizations", createSimplifyLocalsNoTeePass);
registerPass("simplify-locals-nostructure", "miscellaneous locals-related optimizations", createSimplifyLocalsNoStructurePass);
registerPass("simplify-locals-notee-nostructure", "miscellaneous locals-related optimizations", createSimplifyLocalsNoTeeNoStructurePass);
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 5e7eba540..18c92b2cb 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -58,6 +58,7 @@ Pass *createRemoveUnusedNamesPass();
Pass *createReorderFunctionsPass();
Pass *createReorderLocalsPass();
Pass *createReReloopPass();
+Pass *createSafeHeapPass();
Pass *createSimplifyLocalsPass();
Pass *createSimplifyLocalsNoTeePass();
Pass *createSimplifyLocalsNoStructurePass();