diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/apply-names.cc | 3 | ||||
-rw-r--r-- | src/binary-reader-ir.cc | 7 | ||||
-rw-r--r-- | src/binary-writer.cc | 3 | ||||
-rw-r--r-- | src/ir.h | 81 | ||||
-rw-r--r-- | src/resolve-names.cc | 3 | ||||
-rw-r--r-- | src/validator.cc | 55 | ||||
-rw-r--r-- | src/wast-parser.cc | 95 | ||||
-rw-r--r-- | src/wast-parser.h | 2 | ||||
-rw-r--r-- | src/wat-writer.cc | 23 |
9 files changed, 177 insertions, 95 deletions
diff --git a/src/apply-names.cc b/src/apply-names.cc index 2080e958..b7db07db 100644 --- a/src/apply-names.cc +++ b/src/apply-names.cc @@ -272,7 +272,8 @@ Result NameApplier::OnCallExpr(CallExpr* expr) { } Result NameApplier::OnCallIndirectExpr(CallIndirectExpr* expr) { - CHECK_RESULT(UseNameForFuncTypeVar(&expr->var)); + if (expr->decl.has_func_type) + CHECK_RESULT(UseNameForFuncTypeVar(&expr->decl.type_var)); return Result::Ok; } diff --git a/src/binary-reader-ir.cc b/src/binary-reader-ir.cc index 5e4dd9f3..6ee7f9dc 100644 --- a/src/binary-reader-ir.cc +++ b/src/binary-reader-ir.cc @@ -595,8 +595,11 @@ Result BinaryReaderIR::OnCallExpr(Index func_index) { Result BinaryReaderIR::OnCallIndirectExpr(Index sig_index) { assert(sig_index < module_->func_types.size()); - return AppendExpr( - MakeUnique<CallIndirectExpr>(Var(sig_index, GetLocation()))); + auto expr = MakeUnique<CallIndirectExpr>(GetLocation()); + expr->decl.has_func_type = true; + expr->decl.type_var = Var(sig_index, GetLocation()); + expr->decl.sig = module_->func_types[sig_index]->sig; + return AppendExpr(std::move(expr)); } Result BinaryReaderIR::OnCompareExpr(Opcode opcode) { diff --git a/src/binary-writer.cc b/src/binary-writer.cc index 687b117d..b0af2c35 100644 --- a/src/binary-writer.cc +++ b/src/binary-writer.cc @@ -398,7 +398,8 @@ void BinaryWriter::WriteExpr(const Module* module, break; } case ExprType::CallIndirect: { - Index index = module->GetFuncTypeIndex(cast<CallIndirectExpr>(expr)->var); + Index index = + module->GetFuncTypeIndex(cast<CallIndirectExpr>(expr)->decl); WriteOpcode(stream_, Opcode::CallIndirect); WriteU32Leb128WithReloc(index, "signature index", RelocType::TypeIndexLEB); @@ -113,6 +113,42 @@ struct Const { }; typedef std::vector<Const> ConstVector; +struct FuncSignature { + TypeVector param_types; + TypeVector result_types; + + Index GetNumParams() const { return param_types.size(); } + Index GetNumResults() const { return result_types.size(); } + Type GetParamType(Index index) const { return param_types[index]; } + Type GetResultType(Index index) const { return result_types[index]; } + + bool operator==(const FuncSignature&) const; +}; + +struct FuncType { + FuncType() = default; + explicit FuncType(string_view name) : name(name.to_string()) {} + + Index GetNumParams() const { return sig.GetNumParams(); } + Index GetNumResults() const { return sig.GetNumResults(); } + Type GetParamType(Index index) const { return sig.GetParamType(index); } + Type GetResultType(Index index) const { return sig.GetResultType(index); } + + std::string name; + FuncSignature sig; +}; + +struct FuncDeclaration { + Index GetNumParams() const { return sig.GetNumParams(); } + Index GetNumResults() const { return sig.GetNumResults(); } + Type GetParamType(Index index) const { return sig.GetParamType(index); } + Type GetResultType(Index index) const { return sig.GetResultType(index); } + + bool has_func_type = false; + Var type_var; + FuncSignature sig; +}; + enum class ExprType { AtomicLoad, AtomicStore, @@ -246,7 +282,6 @@ class VarExpr : public ExprMixin<TypeEnum> { typedef VarExpr<ExprType::Br> BrExpr; typedef VarExpr<ExprType::BrIf> BrIfExpr; typedef VarExpr<ExprType::Call> CallExpr; -typedef VarExpr<ExprType::CallIndirect> CallIndirectExpr; typedef VarExpr<ExprType::GetGlobal> GetGlobalExpr; typedef VarExpr<ExprType::GetLocal> GetLocalExpr; typedef VarExpr<ExprType::Rethrow> RethrowExpr; @@ -255,6 +290,14 @@ typedef VarExpr<ExprType::SetLocal> SetLocalExpr; typedef VarExpr<ExprType::TeeLocal> TeeLocalExpr; typedef VarExpr<ExprType::Throw> ThrowExpr; +class CallIndirectExpr : public ExprMixin<ExprType::CallIndirect> { + public: + explicit CallIndirectExpr(const Location& loc = Location()) + : ExprMixin<ExprType::CallIndirect>(loc) {} + + FuncDeclaration decl; +}; + template <ExprType TypeEnum> class BlockExprBase : public ExprMixin<TypeEnum> { public: @@ -337,42 +380,6 @@ struct Exception { TypeVector sig; }; -struct FuncSignature { - TypeVector param_types; - TypeVector result_types; - - Index GetNumParams() const { return param_types.size(); } - Index GetNumResults() const { return result_types.size(); } - Type GetParamType(Index index) const { return param_types[index]; } - Type GetResultType(Index index) const { return result_types[index]; } - - bool operator==(const FuncSignature&) const; -}; - -struct FuncType { - FuncType() = default; - explicit FuncType(string_view name) : name(name.to_string()) {} - - Index GetNumParams() const { return sig.GetNumParams(); } - Index GetNumResults() const { return sig.GetNumResults(); } - Type GetParamType(Index index) const { return sig.GetParamType(index); } - Type GetResultType(Index index) const { return sig.GetResultType(index); } - - std::string name; - FuncSignature sig; -}; - -struct FuncDeclaration { - Index GetNumParams() const { return sig.GetNumParams(); } - Index GetNumResults() const { return sig.GetNumResults(); } - Type GetParamType(Index index) const { return sig.GetParamType(index); } - Type GetResultType(Index index) const { return sig.GetResultType(index); } - - bool has_func_type = false; - Var type_var; - FuncSignature sig; -}; - struct Func { Func() = default; explicit Func(string_view name) : name(name.to_string()) {} diff --git a/src/resolve-names.cc b/src/resolve-names.cc index dfdee366..2b6199b8 100644 --- a/src/resolve-names.cc +++ b/src/resolve-names.cc @@ -245,7 +245,8 @@ Result NameResolver::OnCallExpr(CallExpr* expr) { } Result NameResolver::OnCallIndirectExpr(CallIndirectExpr* expr) { - ResolveFuncTypeVar(&expr->var); + if (expr->decl.has_func_type) + ResolveFuncTypeVar(&expr->decl.type_var); return Result::Ok; } diff --git a/src/validator.cc b/src/validator.cc index 9ef679ec..c4ad1f9a 100644 --- a/src/validator.cc +++ b/src/validator.cc @@ -25,6 +25,7 @@ #include "src/binary-reader.h" #include "src/cast.h" +#include "src/expr-visitor.h" #include "src/error-handler.h" #include "src/ir.h" #include "src/type-checker.h" @@ -115,7 +116,9 @@ class Validator { template <typename T> void CheckAtomicExpr(const T* expr, Result (TypeChecker::*func)(Opcode)); void CheckExpr(const Expr* expr); - void CheckFuncSignature(const Location* loc, const Func* func); + void CheckFuncSignature(const Location* loc, const FuncDeclaration& decl); + class CheckFuncSignatureExprVisitorDelegate; + void CheckFunc(const Location* loc, const Func* func); void PrintConstExprError(const Location* loc, const char* desc); void CheckConstInitExpr(const Location* loc, @@ -496,15 +499,16 @@ void Validator::CheckExpr(const Expr* expr) { } case ExprType::CallIndirect: { - const FuncType* func_type; if (current_module_->tables.size() == 0) { PrintError(&expr->loc, "found call_indirect operator, but no table"); } - if (Succeeded(CheckFuncTypeVar(&cast<CallIndirectExpr>(expr)->var, - &func_type))) { - typechecker_.OnCallIndirect(&func_type->sig.param_types, - &func_type->sig.result_types); + auto ci_expr = cast<CallIndirectExpr>(expr); + if (ci_expr->decl.has_func_type) { + const FuncType* func_type; + CheckFuncTypeVar(&ci_expr->decl.type_var, &func_type); } + typechecker_.OnCallIndirect(&ci_expr->decl.sig.param_types, + &ci_expr->decl.sig.result_types); break; } @@ -667,13 +671,14 @@ void Validator::CheckExpr(const Expr* expr) { } } -void Validator::CheckFuncSignature(const Location* loc, const Func* func) { - if (func->decl.has_func_type) { +void Validator::CheckFuncSignature(const Location* loc, + const FuncDeclaration& decl) { + if (decl.has_func_type) { const FuncType* func_type; - if (Succeeded(CheckFuncTypeVar(&func->decl.type_var, &func_type))) { - CheckTypes(loc, func->decl.sig.result_types, func_type->sig.result_types, + if (Succeeded(CheckFuncTypeVar(&decl.type_var, &func_type))) { + CheckTypes(loc, decl.sig.result_types, func_type->sig.result_types, "function", "result"); - CheckTypes(loc, func->decl.sig.param_types, func_type->sig.param_types, + CheckTypes(loc, decl.sig.param_types, func_type->sig.param_types, "function", "argument"); } } @@ -681,7 +686,7 @@ void Validator::CheckFuncSignature(const Location* loc, const Func* func) { void Validator::CheckFunc(const Location* loc, const Func* func) { current_func_ = func; - CheckFuncSignature(loc, func); + CheckFuncSignature(loc, func->decl); if (func->GetNumResults() > 1) { PrintError(loc, "multiple result values not currently supported."); // Don't run any other checks, the won't test the result_type properly. @@ -1192,13 +1197,35 @@ Result Validator::CheckScript(const Script* script) { return result_; } +class Validator::CheckFuncSignatureExprVisitorDelegate + : public ExprVisitor::DelegateNop { + public: + explicit CheckFuncSignatureExprVisitorDelegate(Validator* validator) + : validator_(validator) {} + + Result OnCallIndirectExpr(CallIndirectExpr* expr) override { + validator_->CheckFuncSignature(&expr->loc, expr->decl); + return Result::Ok; + } + + private: + Validator* validator_; +}; + Result Validator::CheckAllFuncSignatures(const Module* module) { current_module_ = module; for (const ModuleField& field : module->fields) { switch (field.type()) { - case ModuleFieldType::Func: - CheckFuncSignature(&field.loc, &cast<FuncModuleField>(&field)->func); + case ModuleFieldType::Func: { + auto func_field = cast<FuncModuleField>(&field); + CheckFuncSignature(&field.loc, func_field->func.decl); + CheckFuncSignatureExprVisitorDelegate delegate(this); + ExprVisitor visitor(&delegate); + // TODO(binji): would rather not do a const_cast here, but the visitor + // is non-const only. + visitor.VisitFunc(const_cast<Func*>(&func_field->func)); break; + } default: break; diff --git a/src/wast-parser.cc b/src/wast-parser.cc index ee3acdba..577b18c7 100644 --- a/src/wast-parser.cc +++ b/src/wast-parser.cc @@ -19,6 +19,7 @@ #include "src/binary-reader.h" #include "src/binary-reader-ir.h" #include "src/cast.h" +#include "src/expr-visitor.h" #include "src/error-handler.h" #include "src/make-unique.h" #include "src/utf8.h" @@ -253,15 +254,56 @@ bool IsEmptySignature(const FuncSignature* sig) { return sig->result_types.empty() && sig->param_types.empty(); } +void ResolveFuncType(const Location& loc, + Module* module, + FuncDeclaration* decl) { + // Resolve func type variables where the signature was not specified + // explicitly, e.g.: (func (type 1) ...) + if (decl->has_func_type && IsEmptySignature(&decl->sig)) { + FuncType* func_type = module->GetFuncType(decl->type_var); + if (func_type) { + decl->sig = func_type->sig; + } + } + + // Resolve implicitly defined function types, e.g.: (func (param i32) ...) + if (!decl->has_func_type) { + Index func_type_index = module->GetFuncTypeIndex(decl->sig); + if (func_type_index == kInvalidIndex) { + auto func_type_field = MakeUnique<FuncTypeModuleField>(loc); + func_type_field->func_type.sig = decl->sig; + module->AppendField(std::move(func_type_field)); + } + } +} + +class ResolveFuncTypesExprVisitorDelegate : public ExprVisitor::DelegateNop { + public: + explicit ResolveFuncTypesExprVisitorDelegate(Module* module) + : module_(module) {} + + Result OnCallIndirectExpr(CallIndirectExpr* expr) override { + ResolveFuncType(expr->loc, module_, &expr->decl); + return Result::Ok; + } + + private: + Module* module_; +}; + void ResolveFuncTypes(Module* module) { for (ModuleField& field : module->fields) { Func* func = nullptr; + FuncDeclaration* decl = nullptr; if (auto* func_field = dyn_cast<FuncModuleField>(&field)) { func = &func_field->func; + decl = &func->decl; } else if (auto* import_field = dyn_cast<ImportModuleField>(&field)) { if (auto* func_import = dyn_cast<FuncImport>(import_field->import.get())) { - func = &func_import->func; + // Only check the declaration, not the function itself, since it is an + // import. + decl = &func_import->func.decl; } else { continue; } @@ -269,23 +311,13 @@ void ResolveFuncTypes(Module* module) { continue; } - // Resolve func type variables where the signature was not specified - // explicitly, e.g.: (func (type 1) ...) - if (func->decl.has_func_type && IsEmptySignature(&func->decl.sig)) { - FuncType* func_type = module->GetFuncType(func->decl.type_var); - if (func_type) { - func->decl.sig = func_type->sig; - } - } + if (decl) + ResolveFuncType(field.loc, module, decl); - // Resolve implicitly defined function types, e.g.: (func (param i32) ...) - if (!func->decl.has_func_type) { - Index func_type_index = module->GetFuncTypeIndex(func->decl.sig); - if (func_type_index == kInvalidIndex) { - auto func_type_field = MakeUnique<FuncTypeModuleField>(field.loc); - func_type_field->func_type.sig = func->decl.sig; - module->AppendField(std::move(func_type_field)); - } + if (func) { + ResolveFuncTypesExprVisitorDelegate delegate(module); + ExprVisitor visitor(&delegate); + visitor.VisitFunc(func); } } } @@ -1143,6 +1175,13 @@ Result WastParser::ParseFuncSignature(FuncSignature* sig, return Result::Ok; } +Result WastParser::ParseUnboundFuncSignature(FuncSignature* sig) { + WABT_TRACE(ParseUnboundFuncSignature); + CHECK_RESULT(ParseUnboundValueTypeList(TokenType::Param, &sig->param_types)); + CHECK_RESULT(ParseResultList(&sig->result_types)); + return Result::Ok; +} + Result WastParser::ParseBoundValueTypeList(TokenType token, TypeVector* types, BindingHash* bindings) { @@ -1164,15 +1203,21 @@ Result WastParser::ParseBoundValueTypeList(TokenType token, return Result::Ok; } -Result WastParser::ParseResultList(TypeVector* result_types) { - WABT_TRACE(ParseResultList); - while (MatchLpar(TokenType::Result)) { - CHECK_RESULT(ParseValueTypeList(result_types)); +Result WastParser::ParseUnboundValueTypeList(TokenType token, + TypeVector* types) { + WABT_TRACE(ParseUnboundValueTypeList); + while (MatchLpar(token)) { + CHECK_RESULT(ParseValueTypeList(types)); EXPECT(Rpar); } return Result::Ok; } +Result WastParser::ParseResultList(TypeVector* result_types) { + WABT_TRACE(ParseResultList); + return ParseUnboundValueTypeList(TokenType::Result, result_types); +} + Result WastParser::ParseInstrList(ExprList* exprs) { WABT_TRACE(ParseInstrList); ExprList new_exprs; @@ -1293,10 +1338,14 @@ Result WastParser::ParsePlainInstr(std::unique_ptr<Expr>* out_expr) { CHECK_RESULT(ParsePlainInstrVar<CallExpr>(loc, out_expr)); break; - case TokenType::CallIndirect: + case TokenType::CallIndirect: { Consume(); - CHECK_RESULT(ParsePlainInstrVar<CallIndirectExpr>(loc, out_expr)); + auto expr = MakeUnique<CallIndirectExpr>(loc); + CHECK_RESULT(ParseTypeUseOpt(&expr->decl)); + CHECK_RESULT(ParseUnboundFuncSignature(&expr->decl.sig)); + *out_expr = std::move(expr); break; + } case TokenType::GetLocal: Consume(); diff --git a/src/wast-parser.h b/src/wast-parser.h index 0f36cb70..19233079 100644 --- a/src/wast-parser.h +++ b/src/wast-parser.h @@ -149,7 +149,9 @@ class WastParser { Result ParseInlineImport(Import*); Result ParseTypeUseOpt(FuncDeclaration*); Result ParseFuncSignature(FuncSignature*, BindingHash* param_bindings); + Result ParseUnboundFuncSignature(FuncSignature*); Result ParseBoundValueTypeList(TokenType, TypeVector*, BindingHash*); + Result ParseUnboundValueTypeList(TokenType, TypeVector*); Result ParseResultList(TypeVector*); Result ParseInstrList(ExprList*); Result ParseTerminatingInstrList(ExprList*); diff --git a/src/wat-writer.cc b/src/wat-writer.cc index 1a019718..718c312b 100644 --- a/src/wat-writer.cc +++ b/src/wat-writer.cc @@ -160,8 +160,6 @@ class WatWriter { Index GetLabelArity(const Var& var); Index GetFuncParamCount(const Var& var); Index GetFuncResultCount(const Var& var); - Index GetFuncSigParamCount(const Var& var); - Index GetFuncSigResultCount(const Var& var); void PushExpr(const Expr* expr, Index operand_count, Index result_count); void FlushExprTree(const ExprTree& expr_tree); void FlushExprTreeVector(const std::vector<ExprTree>&); @@ -524,7 +522,10 @@ void WatWriter::WriteExpr(const Expr* expr) { case ExprType::CallIndirect: WritePutsSpace(Opcode::CallIndirect_Opcode.GetName()); - WriteVar(cast<CallIndirectExpr>(expr)->var, NextChar::Newline); + WriteOpenSpace("type"); + WriteVar(cast<CallIndirectExpr>(expr)->decl.type_var, + NextChar::Space); + WriteCloseNewline(); break; case ExprType::Compare: @@ -707,16 +708,6 @@ Index WatWriter::GetFuncResultCount(const Var& var) { return func ? func->GetNumResults() : 0; } -Index WatWriter::GetFuncSigParamCount(const Var& var) { - const FuncType* func_type = module_->GetFuncType(var); - return func_type ? func_type->GetNumParams() : 0; -} - -Index WatWriter::GetFuncSigResultCount(const Var& var) { - const FuncType* func_type = module_->GetFuncType(var); - return func_type ? func_type->GetNumResults() : 0; -} - void WatWriter::WriteFoldedExpr(const Expr* expr) { WABT_TRACE_ARGS(WriteFoldedExpr, "%s", GetExprTypeName(*expr)); switch (expr->type()) { @@ -758,9 +749,9 @@ void WatWriter::WriteFoldedExpr(const Expr* expr) { } case ExprType::CallIndirect: { - const Var& var = cast<CallIndirectExpr>(expr)->var; - PushExpr(expr, GetFuncSigParamCount(var) + 1, - GetFuncSigResultCount(var)); + const auto* ci_expr = cast<CallIndirectExpr>(expr); + PushExpr(expr, ci_expr->decl.GetNumParams() + 1, + ci_expr->decl.GetNumResults()); break; } |