summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/wasm-validator.h57
-rw-r--r--src/wasm/wasm-validator.cpp52
2 files changed, 76 insertions, 33 deletions
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 403a057fe..24affdb44 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -38,12 +38,26 @@
#define wasm_wasm_validator_h
#include <set>
+#include <sstream>
#include "wasm.h"
#include "wasm-printing.h"
namespace wasm {
+// Print anything that can be streamed to an ostream
+template <typename T>
+inline std::ostream& printModuleComponent(T curr, std::ostream& stream) {
+ stream << curr << std::endl;
+ return stream;
+}
+// Specialization for Expressions to print type info too
+template <>
+inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) {
+ WasmPrinter::printExpression(curr, stream, false, true) << std::endl;
+ return stream;
+}
+
struct WasmValidator : public PostWalker<WasmValidator> {
bool valid = true;
@@ -123,6 +137,8 @@ public:
void visitSetLocal(SetLocal *curr);
void visitLoad(Load *curr);
void visitStore(Store *curr);
+ void visitAtomicRMW(AtomicRMW *curr);
+ void visitAtomicCmpxchg(AtomicCmpxchg *curr);
void visitBinary(Binary *curr);
void visitUnary(Unary *curr);
void visitSelect(Select* curr);
@@ -144,12 +160,14 @@ public:
// helpers
private:
- std::ostream& fail();
+ template <typename T, typename S>
+ std::ostream& fail(T curr, S text);
+ std::ostream& printFailureHeader();
+
template<typename T>
bool shouldBeTrue(bool result, T curr, const char* text) {
if (!result) {
- fail() << "unexpected false: " << text << ", on \n" << curr << std::endl;
- valid = false;
+ fail(curr, "unexpected false: " + std::string(text));
return false;
}
return result;
@@ -157,8 +175,7 @@ public:
template<typename T>
bool shouldBeFalse(bool result, T curr, const char* text) {
if (result) {
- fail() << "unexpected true: " << text << ", on \n" << curr << std::endl;
- valid = false;
+ fail(curr, "unexpected true: " + std::string(text));
return false;
}
return result;
@@ -167,18 +184,9 @@ public:
template<typename T, typename S>
bool shouldBeEqual(S left, S right, T curr, const char* text) {
if (left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n";
- WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
- valid = false;
- return false;
- }
- return true;
- }
- template<typename T, typename S, typename U>
- bool shouldBeEqual(S left, S right, T curr, U other, const char* text) {
- if (left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n" << curr << " / " << other << std::endl;
- valid = false;
+ std::ostringstream ss;
+ ss << left << " != " << right << ": " << text;
+ fail(curr, ss.str());
return false;
}
return true;
@@ -187,9 +195,9 @@ public:
template<typename T, typename S>
bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
if (left != unreachable && left != right) {
- fail() << "" << left << " != " << right << ": " << text << ", on \n";
- WasmPrinter::printExpression(curr, std::cerr, false, true) << std::endl;
- valid = false;
+ std::ostringstream ss;
+ ss << left << " != " << right << ": " << text;
+ fail(curr, ss.str());
return false;
}
return true;
@@ -198,14 +206,17 @@ public:
template<typename T, typename S>
bool shouldBeUnequal(S left, S right, T curr, const char* text) {
if (left == right) {
- fail() << "" << left << " == " << right << ": " << text << ", on \n" << curr << std::endl;
- valid = false;
+ std::ostringstream ss;
+ ss << left << " == " << right << ": " << text;
+ fail(curr, ss.str());
return false;
}
return true;
}
- void validateAlignment(size_t align, WasmType type, Index bytes);
+ void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic,
+ Expression* curr);
+ void validateMemBytes(uint8_t bytes, WasmType ty, Expression* curr);
void validateBinaryenIR(Module& wasm);
};
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 21bf68b62..4c02fe6d3 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -219,15 +219,36 @@ void WasmValidator::visitSetLocal(SetLocal *curr) {
}
}
void WasmValidator::visitLoad(Load *curr) {
- validateAlignment(curr->align, curr->type, curr->bytes);
+ validateMemBytes(curr->bytes, curr->type, curr);
+ validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
}
void WasmValidator::visitStore(Store *curr) {
- validateAlignment(curr->align, curr->type, curr->bytes);
+ validateMemBytes(curr->bytes, curr->valueType, curr);
+ validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
}
+void WasmValidator::visitAtomicRMW(AtomicRMW* curr) {
+ validateMemBytes(curr->bytes, curr->type, curr);
+}
+void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) {
+ validateMemBytes(curr->bytes, curr->type, curr);
+}
+void WasmValidator::validateMemBytes(uint8_t bytes, WasmType ty, Expression* curr) {
+ switch (bytes) {
+ case 1:
+ case 2:
+ case 4:
+ break;
+ case 8: {
+ shouldBeEqual(getWasmTypeSize(ty), 8U, curr, "8-byte mem operations are only allowed with 8-byte wasm types");
+ break;
+ }
+ default: fail("Memory operations must be 1,2,4, or 8 bytes", curr);
+ }
+}
void WasmValidator::visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
@@ -566,28 +587,32 @@ void WasmValidator::visitModule(Module *curr) {
}
}
-void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes) {
+void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes,
+ bool isAtomic, Expression* curr) {
+ if (isAtomic) {
+ shouldBeEqual(align, (size_t)bytes, curr, "atomic accesses must have natural alignment");
+ return;
+ }
switch (align) {
case 1:
case 2:
case 4:
case 8: break;
default:{
- fail() << "bad alignment: " << align << std::endl;
- valid = false;
+ fail("bad alignment: " + std::to_string(align), curr);
break;
}
}
- shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
+ shouldBeTrue(align <= bytes, curr, "alignment must not exceed natural");
switch (type) {
case i32:
case f32: {
- shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
+ shouldBeTrue(align <= 4, curr, "alignment must not exceed natural");
break;
}
case i64:
case f64: {
- shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
+ shouldBeTrue(align <= 8, curr, "alignment must not exceed natural");
break;
}
default: {}
@@ -614,7 +639,7 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
// The block has an added type, not derived from the ast itself, so it is
// ok for it to be either i32 or unreachable.
if (!(isConcreteWasmType(oldType) && newType == unreachable)) {
- parent.fail() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
+ parent.printFailureHeader() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
parent.valid = false;
}
curr->type = oldType;
@@ -625,7 +650,14 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
binaryenIRValidator.walkModule(&wasm);
}
-std::ostream& WasmValidator::fail() {
+template <typename T, typename S>
+std::ostream& WasmValidator::fail(T curr, S text) {
+ valid = false;
+ auto& ret = printFailureHeader() << text << ", on \n";
+ return printModuleComponent(curr, ret);
+}
+
+std::ostream& WasmValidator::printFailureHeader() {
Colors::red(std::cerr);
if (getFunction()) {
std::cerr << "[wasm-validator error in function ";