summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/apply-names.cc3
-rw-r--r--src/binary-reader-ir.cc7
-rw-r--r--src/binary-writer.cc3
-rw-r--r--src/ir.h81
-rw-r--r--src/resolve-names.cc3
-rw-r--r--src/validator.cc55
-rw-r--r--src/wast-parser.cc95
-rw-r--r--src/wast-parser.h2
-rw-r--r--src/wat-writer.cc23
9 files changed, 177 insertions, 95 deletions
diff --git a/src/apply-names.cc b/src/apply-names.cc
index 2080e958..b7db07db 100644
--- a/src/apply-names.cc
+++ b/src/apply-names.cc
@@ -272,7 +272,8 @@ Result NameApplier::OnCallExpr(CallExpr* expr) {
}
Result NameApplier::OnCallIndirectExpr(CallIndirectExpr* expr) {
- CHECK_RESULT(UseNameForFuncTypeVar(&expr->var));
+ if (expr->decl.has_func_type)
+ CHECK_RESULT(UseNameForFuncTypeVar(&expr->decl.type_var));
return Result::Ok;
}
diff --git a/src/binary-reader-ir.cc b/src/binary-reader-ir.cc
index 5e4dd9f3..6ee7f9dc 100644
--- a/src/binary-reader-ir.cc
+++ b/src/binary-reader-ir.cc
@@ -595,8 +595,11 @@ Result BinaryReaderIR::OnCallExpr(Index func_index) {
Result BinaryReaderIR::OnCallIndirectExpr(Index sig_index) {
assert(sig_index < module_->func_types.size());
- return AppendExpr(
- MakeUnique<CallIndirectExpr>(Var(sig_index, GetLocation())));
+ auto expr = MakeUnique<CallIndirectExpr>(GetLocation());
+ expr->decl.has_func_type = true;
+ expr->decl.type_var = Var(sig_index, GetLocation());
+ expr->decl.sig = module_->func_types[sig_index]->sig;
+ return AppendExpr(std::move(expr));
}
Result BinaryReaderIR::OnCompareExpr(Opcode opcode) {
diff --git a/src/binary-writer.cc b/src/binary-writer.cc
index 687b117d..b0af2c35 100644
--- a/src/binary-writer.cc
+++ b/src/binary-writer.cc
@@ -398,7 +398,8 @@ void BinaryWriter::WriteExpr(const Module* module,
break;
}
case ExprType::CallIndirect: {
- Index index = module->GetFuncTypeIndex(cast<CallIndirectExpr>(expr)->var);
+ Index index =
+ module->GetFuncTypeIndex(cast<CallIndirectExpr>(expr)->decl);
WriteOpcode(stream_, Opcode::CallIndirect);
WriteU32Leb128WithReloc(index, "signature index",
RelocType::TypeIndexLEB);
diff --git a/src/ir.h b/src/ir.h
index 99950421..40455536 100644
--- a/src/ir.h
+++ b/src/ir.h
@@ -113,6 +113,42 @@ struct Const {
};
typedef std::vector<Const> ConstVector;
+struct FuncSignature {
+ TypeVector param_types;
+ TypeVector result_types;
+
+ Index GetNumParams() const { return param_types.size(); }
+ Index GetNumResults() const { return result_types.size(); }
+ Type GetParamType(Index index) const { return param_types[index]; }
+ Type GetResultType(Index index) const { return result_types[index]; }
+
+ bool operator==(const FuncSignature&) const;
+};
+
+struct FuncType {
+ FuncType() = default;
+ explicit FuncType(string_view name) : name(name.to_string()) {}
+
+ Index GetNumParams() const { return sig.GetNumParams(); }
+ Index GetNumResults() const { return sig.GetNumResults(); }
+ Type GetParamType(Index index) const { return sig.GetParamType(index); }
+ Type GetResultType(Index index) const { return sig.GetResultType(index); }
+
+ std::string name;
+ FuncSignature sig;
+};
+
+struct FuncDeclaration {
+ Index GetNumParams() const { return sig.GetNumParams(); }
+ Index GetNumResults() const { return sig.GetNumResults(); }
+ Type GetParamType(Index index) const { return sig.GetParamType(index); }
+ Type GetResultType(Index index) const { return sig.GetResultType(index); }
+
+ bool has_func_type = false;
+ Var type_var;
+ FuncSignature sig;
+};
+
enum class ExprType {
AtomicLoad,
AtomicStore,
@@ -246,7 +282,6 @@ class VarExpr : public ExprMixin<TypeEnum> {
typedef VarExpr<ExprType::Br> BrExpr;
typedef VarExpr<ExprType::BrIf> BrIfExpr;
typedef VarExpr<ExprType::Call> CallExpr;
-typedef VarExpr<ExprType::CallIndirect> CallIndirectExpr;
typedef VarExpr<ExprType::GetGlobal> GetGlobalExpr;
typedef VarExpr<ExprType::GetLocal> GetLocalExpr;
typedef VarExpr<ExprType::Rethrow> RethrowExpr;
@@ -255,6 +290,14 @@ typedef VarExpr<ExprType::SetLocal> SetLocalExpr;
typedef VarExpr<ExprType::TeeLocal> TeeLocalExpr;
typedef VarExpr<ExprType::Throw> ThrowExpr;
+class CallIndirectExpr : public ExprMixin<ExprType::CallIndirect> {
+ public:
+ explicit CallIndirectExpr(const Location& loc = Location())
+ : ExprMixin<ExprType::CallIndirect>(loc) {}
+
+ FuncDeclaration decl;
+};
+
template <ExprType TypeEnum>
class BlockExprBase : public ExprMixin<TypeEnum> {
public:
@@ -337,42 +380,6 @@ struct Exception {
TypeVector sig;
};
-struct FuncSignature {
- TypeVector param_types;
- TypeVector result_types;
-
- Index GetNumParams() const { return param_types.size(); }
- Index GetNumResults() const { return result_types.size(); }
- Type GetParamType(Index index) const { return param_types[index]; }
- Type GetResultType(Index index) const { return result_types[index]; }
-
- bool operator==(const FuncSignature&) const;
-};
-
-struct FuncType {
- FuncType() = default;
- explicit FuncType(string_view name) : name(name.to_string()) {}
-
- Index GetNumParams() const { return sig.GetNumParams(); }
- Index GetNumResults() const { return sig.GetNumResults(); }
- Type GetParamType(Index index) const { return sig.GetParamType(index); }
- Type GetResultType(Index index) const { return sig.GetResultType(index); }
-
- std::string name;
- FuncSignature sig;
-};
-
-struct FuncDeclaration {
- Index GetNumParams() const { return sig.GetNumParams(); }
- Index GetNumResults() const { return sig.GetNumResults(); }
- Type GetParamType(Index index) const { return sig.GetParamType(index); }
- Type GetResultType(Index index) const { return sig.GetResultType(index); }
-
- bool has_func_type = false;
- Var type_var;
- FuncSignature sig;
-};
-
struct Func {
Func() = default;
explicit Func(string_view name) : name(name.to_string()) {}
diff --git a/src/resolve-names.cc b/src/resolve-names.cc
index dfdee366..2b6199b8 100644
--- a/src/resolve-names.cc
+++ b/src/resolve-names.cc
@@ -245,7 +245,8 @@ Result NameResolver::OnCallExpr(CallExpr* expr) {
}
Result NameResolver::OnCallIndirectExpr(CallIndirectExpr* expr) {
- ResolveFuncTypeVar(&expr->var);
+ if (expr->decl.has_func_type)
+ ResolveFuncTypeVar(&expr->decl.type_var);
return Result::Ok;
}
diff --git a/src/validator.cc b/src/validator.cc
index 9ef679ec..c4ad1f9a 100644
--- a/src/validator.cc
+++ b/src/validator.cc
@@ -25,6 +25,7 @@
#include "src/binary-reader.h"
#include "src/cast.h"
+#include "src/expr-visitor.h"
#include "src/error-handler.h"
#include "src/ir.h"
#include "src/type-checker.h"
@@ -115,7 +116,9 @@ class Validator {
template <typename T>
void CheckAtomicExpr(const T* expr, Result (TypeChecker::*func)(Opcode));
void CheckExpr(const Expr* expr);
- void CheckFuncSignature(const Location* loc, const Func* func);
+ void CheckFuncSignature(const Location* loc, const FuncDeclaration& decl);
+ class CheckFuncSignatureExprVisitorDelegate;
+
void CheckFunc(const Location* loc, const Func* func);
void PrintConstExprError(const Location* loc, const char* desc);
void CheckConstInitExpr(const Location* loc,
@@ -496,15 +499,16 @@ void Validator::CheckExpr(const Expr* expr) {
}
case ExprType::CallIndirect: {
- const FuncType* func_type;
if (current_module_->tables.size() == 0) {
PrintError(&expr->loc, "found call_indirect operator, but no table");
}
- if (Succeeded(CheckFuncTypeVar(&cast<CallIndirectExpr>(expr)->var,
- &func_type))) {
- typechecker_.OnCallIndirect(&func_type->sig.param_types,
- &func_type->sig.result_types);
+ auto ci_expr = cast<CallIndirectExpr>(expr);
+ if (ci_expr->decl.has_func_type) {
+ const FuncType* func_type;
+ CheckFuncTypeVar(&ci_expr->decl.type_var, &func_type);
}
+ typechecker_.OnCallIndirect(&ci_expr->decl.sig.param_types,
+ &ci_expr->decl.sig.result_types);
break;
}
@@ -667,13 +671,14 @@ void Validator::CheckExpr(const Expr* expr) {
}
}
-void Validator::CheckFuncSignature(const Location* loc, const Func* func) {
- if (func->decl.has_func_type) {
+void Validator::CheckFuncSignature(const Location* loc,
+ const FuncDeclaration& decl) {
+ if (decl.has_func_type) {
const FuncType* func_type;
- if (Succeeded(CheckFuncTypeVar(&func->decl.type_var, &func_type))) {
- CheckTypes(loc, func->decl.sig.result_types, func_type->sig.result_types,
+ if (Succeeded(CheckFuncTypeVar(&decl.type_var, &func_type))) {
+ CheckTypes(loc, decl.sig.result_types, func_type->sig.result_types,
"function", "result");
- CheckTypes(loc, func->decl.sig.param_types, func_type->sig.param_types,
+ CheckTypes(loc, decl.sig.param_types, func_type->sig.param_types,
"function", "argument");
}
}
@@ -681,7 +686,7 @@ void Validator::CheckFuncSignature(const Location* loc, const Func* func) {
void Validator::CheckFunc(const Location* loc, const Func* func) {
current_func_ = func;
- CheckFuncSignature(loc, func);
+ CheckFuncSignature(loc, func->decl);
if (func->GetNumResults() > 1) {
PrintError(loc, "multiple result values not currently supported.");
// Don't run any other checks, the won't test the result_type properly.
@@ -1192,13 +1197,35 @@ Result Validator::CheckScript(const Script* script) {
return result_;
}
+class Validator::CheckFuncSignatureExprVisitorDelegate
+ : public ExprVisitor::DelegateNop {
+ public:
+ explicit CheckFuncSignatureExprVisitorDelegate(Validator* validator)
+ : validator_(validator) {}
+
+ Result OnCallIndirectExpr(CallIndirectExpr* expr) override {
+ validator_->CheckFuncSignature(&expr->loc, expr->decl);
+ return Result::Ok;
+ }
+
+ private:
+ Validator* validator_;
+};
+
Result Validator::CheckAllFuncSignatures(const Module* module) {
current_module_ = module;
for (const ModuleField& field : module->fields) {
switch (field.type()) {
- case ModuleFieldType::Func:
- CheckFuncSignature(&field.loc, &cast<FuncModuleField>(&field)->func);
+ case ModuleFieldType::Func: {
+ auto func_field = cast<FuncModuleField>(&field);
+ CheckFuncSignature(&field.loc, func_field->func.decl);
+ CheckFuncSignatureExprVisitorDelegate delegate(this);
+ ExprVisitor visitor(&delegate);
+ // TODO(binji): would rather not do a const_cast here, but the visitor
+ // is non-const only.
+ visitor.VisitFunc(const_cast<Func*>(&func_field->func));
break;
+ }
default:
break;
diff --git a/src/wast-parser.cc b/src/wast-parser.cc
index ee3acdba..577b18c7 100644
--- a/src/wast-parser.cc
+++ b/src/wast-parser.cc
@@ -19,6 +19,7 @@
#include "src/binary-reader.h"
#include "src/binary-reader-ir.h"
#include "src/cast.h"
+#include "src/expr-visitor.h"
#include "src/error-handler.h"
#include "src/make-unique.h"
#include "src/utf8.h"
@@ -253,15 +254,56 @@ bool IsEmptySignature(const FuncSignature* sig) {
return sig->result_types.empty() && sig->param_types.empty();
}
+void ResolveFuncType(const Location& loc,
+ Module* module,
+ FuncDeclaration* decl) {
+ // Resolve func type variables where the signature was not specified
+ // explicitly, e.g.: (func (type 1) ...)
+ if (decl->has_func_type && IsEmptySignature(&decl->sig)) {
+ FuncType* func_type = module->GetFuncType(decl->type_var);
+ if (func_type) {
+ decl->sig = func_type->sig;
+ }
+ }
+
+ // Resolve implicitly defined function types, e.g.: (func (param i32) ...)
+ if (!decl->has_func_type) {
+ Index func_type_index = module->GetFuncTypeIndex(decl->sig);
+ if (func_type_index == kInvalidIndex) {
+ auto func_type_field = MakeUnique<FuncTypeModuleField>(loc);
+ func_type_field->func_type.sig = decl->sig;
+ module->AppendField(std::move(func_type_field));
+ }
+ }
+}
+
+class ResolveFuncTypesExprVisitorDelegate : public ExprVisitor::DelegateNop {
+ public:
+ explicit ResolveFuncTypesExprVisitorDelegate(Module* module)
+ : module_(module) {}
+
+ Result OnCallIndirectExpr(CallIndirectExpr* expr) override {
+ ResolveFuncType(expr->loc, module_, &expr->decl);
+ return Result::Ok;
+ }
+
+ private:
+ Module* module_;
+};
+
void ResolveFuncTypes(Module* module) {
for (ModuleField& field : module->fields) {
Func* func = nullptr;
+ FuncDeclaration* decl = nullptr;
if (auto* func_field = dyn_cast<FuncModuleField>(&field)) {
func = &func_field->func;
+ decl = &func->decl;
} else if (auto* import_field = dyn_cast<ImportModuleField>(&field)) {
if (auto* func_import =
dyn_cast<FuncImport>(import_field->import.get())) {
- func = &func_import->func;
+ // Only check the declaration, not the function itself, since it is an
+ // import.
+ decl = &func_import->func.decl;
} else {
continue;
}
@@ -269,23 +311,13 @@ void ResolveFuncTypes(Module* module) {
continue;
}
- // Resolve func type variables where the signature was not specified
- // explicitly, e.g.: (func (type 1) ...)
- if (func->decl.has_func_type && IsEmptySignature(&func->decl.sig)) {
- FuncType* func_type = module->GetFuncType(func->decl.type_var);
- if (func_type) {
- func->decl.sig = func_type->sig;
- }
- }
+ if (decl)
+ ResolveFuncType(field.loc, module, decl);
- // Resolve implicitly defined function types, e.g.: (func (param i32) ...)
- if (!func->decl.has_func_type) {
- Index func_type_index = module->GetFuncTypeIndex(func->decl.sig);
- if (func_type_index == kInvalidIndex) {
- auto func_type_field = MakeUnique<FuncTypeModuleField>(field.loc);
- func_type_field->func_type.sig = func->decl.sig;
- module->AppendField(std::move(func_type_field));
- }
+ if (func) {
+ ResolveFuncTypesExprVisitorDelegate delegate(module);
+ ExprVisitor visitor(&delegate);
+ visitor.VisitFunc(func);
}
}
}
@@ -1143,6 +1175,13 @@ Result WastParser::ParseFuncSignature(FuncSignature* sig,
return Result::Ok;
}
+Result WastParser::ParseUnboundFuncSignature(FuncSignature* sig) {
+ WABT_TRACE(ParseUnboundFuncSignature);
+ CHECK_RESULT(ParseUnboundValueTypeList(TokenType::Param, &sig->param_types));
+ CHECK_RESULT(ParseResultList(&sig->result_types));
+ return Result::Ok;
+}
+
Result WastParser::ParseBoundValueTypeList(TokenType token,
TypeVector* types,
BindingHash* bindings) {
@@ -1164,15 +1203,21 @@ Result WastParser::ParseBoundValueTypeList(TokenType token,
return Result::Ok;
}
-Result WastParser::ParseResultList(TypeVector* result_types) {
- WABT_TRACE(ParseResultList);
- while (MatchLpar(TokenType::Result)) {
- CHECK_RESULT(ParseValueTypeList(result_types));
+Result WastParser::ParseUnboundValueTypeList(TokenType token,
+ TypeVector* types) {
+ WABT_TRACE(ParseUnboundValueTypeList);
+ while (MatchLpar(token)) {
+ CHECK_RESULT(ParseValueTypeList(types));
EXPECT(Rpar);
}
return Result::Ok;
}
+Result WastParser::ParseResultList(TypeVector* result_types) {
+ WABT_TRACE(ParseResultList);
+ return ParseUnboundValueTypeList(TokenType::Result, result_types);
+}
+
Result WastParser::ParseInstrList(ExprList* exprs) {
WABT_TRACE(ParseInstrList);
ExprList new_exprs;
@@ -1293,10 +1338,14 @@ Result WastParser::ParsePlainInstr(std::unique_ptr<Expr>* out_expr) {
CHECK_RESULT(ParsePlainInstrVar<CallExpr>(loc, out_expr));
break;
- case TokenType::CallIndirect:
+ case TokenType::CallIndirect: {
Consume();
- CHECK_RESULT(ParsePlainInstrVar<CallIndirectExpr>(loc, out_expr));
+ auto expr = MakeUnique<CallIndirectExpr>(loc);
+ CHECK_RESULT(ParseTypeUseOpt(&expr->decl));
+ CHECK_RESULT(ParseUnboundFuncSignature(&expr->decl.sig));
+ *out_expr = std::move(expr);
break;
+ }
case TokenType::GetLocal:
Consume();
diff --git a/src/wast-parser.h b/src/wast-parser.h
index 0f36cb70..19233079 100644
--- a/src/wast-parser.h
+++ b/src/wast-parser.h
@@ -149,7 +149,9 @@ class WastParser {
Result ParseInlineImport(Import*);
Result ParseTypeUseOpt(FuncDeclaration*);
Result ParseFuncSignature(FuncSignature*, BindingHash* param_bindings);
+ Result ParseUnboundFuncSignature(FuncSignature*);
Result ParseBoundValueTypeList(TokenType, TypeVector*, BindingHash*);
+ Result ParseUnboundValueTypeList(TokenType, TypeVector*);
Result ParseResultList(TypeVector*);
Result ParseInstrList(ExprList*);
Result ParseTerminatingInstrList(ExprList*);
diff --git a/src/wat-writer.cc b/src/wat-writer.cc
index 1a019718..718c312b 100644
--- a/src/wat-writer.cc
+++ b/src/wat-writer.cc
@@ -160,8 +160,6 @@ class WatWriter {
Index GetLabelArity(const Var& var);
Index GetFuncParamCount(const Var& var);
Index GetFuncResultCount(const Var& var);
- Index GetFuncSigParamCount(const Var& var);
- Index GetFuncSigResultCount(const Var& var);
void PushExpr(const Expr* expr, Index operand_count, Index result_count);
void FlushExprTree(const ExprTree& expr_tree);
void FlushExprTreeVector(const std::vector<ExprTree>&);
@@ -524,7 +522,10 @@ void WatWriter::WriteExpr(const Expr* expr) {
case ExprType::CallIndirect:
WritePutsSpace(Opcode::CallIndirect_Opcode.GetName());
- WriteVar(cast<CallIndirectExpr>(expr)->var, NextChar::Newline);
+ WriteOpenSpace("type");
+ WriteVar(cast<CallIndirectExpr>(expr)->decl.type_var,
+ NextChar::Space);
+ WriteCloseNewline();
break;
case ExprType::Compare:
@@ -707,16 +708,6 @@ Index WatWriter::GetFuncResultCount(const Var& var) {
return func ? func->GetNumResults() : 0;
}
-Index WatWriter::GetFuncSigParamCount(const Var& var) {
- const FuncType* func_type = module_->GetFuncType(var);
- return func_type ? func_type->GetNumParams() : 0;
-}
-
-Index WatWriter::GetFuncSigResultCount(const Var& var) {
- const FuncType* func_type = module_->GetFuncType(var);
- return func_type ? func_type->GetNumResults() : 0;
-}
-
void WatWriter::WriteFoldedExpr(const Expr* expr) {
WABT_TRACE_ARGS(WriteFoldedExpr, "%s", GetExprTypeName(*expr));
switch (expr->type()) {
@@ -758,9 +749,9 @@ void WatWriter::WriteFoldedExpr(const Expr* expr) {
}
case ExprType::CallIndirect: {
- const Var& var = cast<CallIndirectExpr>(expr)->var;
- PushExpr(expr, GetFuncSigParamCount(var) + 1,
- GetFuncSigResultCount(var));
+ const auto* ci_expr = cast<CallIndirectExpr>(expr);
+ PushExpr(expr, ci_expr->decl.GetNumParams() + 1,
+ ci_expr->decl.GetNumResults());
break;
}