summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/wasm-interpreter.h4
-rw-r--r--src/wasm-s-parser.h15
-rw-r--r--src/wasm-validator.h17
3 files changed, 25 insertions, 11 deletions
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index 0ee2aff78..82ead1d8f 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -359,6 +359,10 @@ private:
LiteralList arguments;
Flow flow = generateArguments(curr->operands, arguments);
if (flow.breaking()) return flow;
+ if (func->params.size() != arguments.size()) trap("callIndirect: bad # of arguments");
+ for (size_t i = 0; i < func->getNumLocals(); i++) {
+ if (func->params[i] != arguments[i].type) trap("callIndirect: bad argument type");
+ }
return instance.callFunctionInternal(name, arguments);
}
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index e6262c645..2f1a1bc1b 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -822,9 +822,15 @@ private:
}
Index getLocalIndex(Element& s) {
- if (s.dollared()) return currFunction->getLocalIndex(s.str());
+ if (s.dollared()) {
+ auto ret = s.str();
+ if (currFunction->localIndices.count(ret) == 0) throw ParseException("bad local name", s.line, s.col);
+ return currFunction->getLocalIndex(ret);
+ }
// this is a numeric index
- return atoi(s.c_str());
+ Index ret = atoi(s.c_str());
+ if (ret >= currFunction->getNumLocals()) throw ParseException("bad local index", s.line, s.col);
+ return ret;
}
Expression* makeGetLocal(Element& s) {
@@ -1088,7 +1094,8 @@ private:
Expression* makeCallIndirect(Element& s) {
auto ret = allocator.alloc<CallIndirect>();
IString type = s[1]->str();
- ret->fullType = wasm.getFunctionType(type);
+ ret->fullType = wasm.checkFunctionType(type);
+ if (!ret->fullType) throw ParseException("invalid call_indirect type", s.line, s.col);
assert(ret->fullType);
ret->type = ret->fullType->result;
ret->target = parseExpression(s[2]);
@@ -1109,7 +1116,7 @@ private:
} else {
// offset, break to nth outside label
uint64_t offset = std::stoll(s.c_str(), nullptr, 0);
- if (offset >= labelStack.size()) throw ParseException("total memory must be <= 4GB", s.line, s.col);
+ if (offset >= labelStack.size()) throw ParseException("invalid label", s.line, s.col);
return labelStack[labelStack.size() - 1 - offset];
}
}
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 39eaca572..2a11bf64f 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -93,23 +93,26 @@ public:
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
void visitCall(Call *curr) {
- auto* target = getModule()->getFunction(curr->target);
- shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match");
+ auto* target = getModule()->checkFunction(curr->target);
+ if (!shouldBeTrue(!!target, curr, "call target must exist")) return;
+ if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match");
}
}
void visitCallImport(CallImport *curr) {
- auto* target = getModule()->getImport(curr->target)->type;
- shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match");
+ auto* import = getModule()->checkImport(curr->target);
+ if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
+ auto* type = import->type;
+ if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
- shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match");
+ shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match");
}
}
void visitCallIndirect(CallIndirect *curr) {
auto* type = curr->fullType;
shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
- shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match");
+ if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match");
}
@@ -245,7 +248,7 @@ public:
}
void visitMemory(Memory *curr) {
shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
- shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "total memory must be <= 4GB");
+ shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
size_t top = 0;
for (auto& segment : curr->segments) {
shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough");