/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/validator.h"

#include <cassert>
#include <cinttypes>
#include <cstdarg>
#include <cstdio>

#include "config.h"

#include "src/binary-reader.h"
#include "src/cast.h"
#include "src/error-handler.h"
#include "src/ir.h"
#include "src/type-checker.h"
#include "src/wast-parser-lexer-shared.h"

namespace wabt {

namespace {

// Validates modules and scripts produced by the wast parser.
//
// Every problem found is reported through PrintError, which forwards to the
// ErrorHandler and latches result_ to Result::Error. The public Check*
// entry points return result_, so a single Validator accumulates errors
// across every module/command it is asked to check.
class Validator {
 public:
  WABT_DISALLOW_COPY_AND_ASSIGN(Validator);
  Validator(ErrorHandler*, WastLexer*, const Script*);

  Result CheckModule(const Module* module);
  Result CheckScript(const Script* script);
  Result CheckAllFuncSignatures(const Module* module);

 private:
  // Outcome of validating an action:
  //   Types -> `types` holds the invoked function's result types,
  //   Type  -> `type` holds the type of the global read by a get action,
  //   Error -> a diagnostic was already printed; skip further checks.
  struct ActionResult {
    enum class Kind {
      Error,
      Types,
      Type,
    } kind;

    union {
      const TypeVector* types;
      Type type;
    };
  };

  void WABT_PRINTF_FORMAT(3, 4)
      PrintError(const Location* loc, const char* fmt, ...);
  void OnTypecheckerError(const char* msg);

  // Index-space lookups. Each reports an error and returns Result::Error if
  // the var is out of range; otherwise the out-params (when non-null) receive
  // the resolved entity.
  Result CheckVar(Index max_index,
                  const Var* var,
                  const char* desc,
                  Index* out_index);
  Result CheckFuncVar(const Var* var, const Func** out_func);
  Result CheckGlobalVar(const Var* var,
                        const Global** out_global,
                        Index* out_global_index);
  Type GetGlobalVarTypeOrAny(const Var* var);
  Result CheckFuncTypeVar(const Var* var, const FuncType** out_func_type);
  Result CheckTableVar(const Var* var, const Table** out_table);
  Result CheckMemoryVar(const Var* var, const Memory** out_memory);
  Result CheckLocalVar(const Var* var, Type* out_type);
  Type GetLocalVarTypeOrAny(const Var* var);

  void CheckAlign(const Location* loc,
                  Address alignment,
                  Address natural_alignment);
  void CheckAtomicAlign(const Location* loc,
                        Address alignment,
                        Address natural_alignment);
  void CheckType(const Location* loc,
                 Type actual,
                 Type expected,
                 const char* desc);
  void CheckTypeIndex(const Location* loc,
                      Type actual,
                      Type expected,
                      const char* desc,
                      Index index,
                      const char* index_kind);
  void CheckTypes(const Location* loc,
                  const TypeVector& actual,
                  const TypeVector& expected,
                  const char* desc,
                  const char* index_kind);
  void CheckConstTypes(const Location* loc,
                       const TypeVector& actual,
                       const ConstVector& expected,
                       const char* desc);
  void CheckConstType(const Location* loc,
                      Type actual,
                      const ConstVector& expected,
                      const char* desc);
  void CheckAssertReturnNanType(const Location* loc,
                                Type actual,
                                const char* desc);
  void CheckExprList(const Location* loc, const ExprList& exprs);
  bool CheckHasMemory(const Location* loc, Opcode opcode);
  void CheckHasSharedMemory(const Location* loc, Opcode opcode);
  void CheckBlockSig(const Location* loc,
                     Opcode opcode,
                     const BlockSignature* sig);
  template <typename T>
  void CheckAtomicExpr(const T* expr, Result (TypeChecker::*func)(Opcode));
  void CheckExpr(const Expr* expr);
  void CheckFuncSignature(const Location* loc, const Func* func);
  void CheckFunc(const Location* loc, const Func* func);
  void PrintConstExprError(const Location* loc, const char* desc);
  void CheckConstInitExpr(const Location* loc,
                          const ExprList& expr,
                          Type expected_type,
                          const char* desc);
  void CheckGlobal(const Location* loc, const Global* global);
  void CheckLimits(const Location* loc,
                   const Limits* limits,
                   uint64_t absolute_max,
                   const char* desc,
                   LimitsShareable sharing);
  void CheckTable(const Location* loc, const Table* table);
  void CheckElemSegments(const Module* module);
  void CheckMemory(const Location* loc, const Memory* memory);
  void CheckDataSegments(const Module* module);
  void CheckImport(const Location* loc, const Import* import);
  void CheckExport(const Location* loc, const Export* export_);

  void CheckDuplicateExportBindings(const Module* module);
  const TypeVector* CheckInvoke(const InvokeAction* action);
  Result CheckGet(const GetAction* action, Type* out_type);
  ActionResult CheckAction(const Action* action);
  void CheckAssertReturnNanCommand(const Action* action);
  void CheckCommand(const Command* command);

  void CheckExcept(const Location* loc, const Exception* except);
  Result CheckExceptVar(const Var* var, const Exception** out_except);

  ErrorHandler* error_handler_ = nullptr;
  WastLexer* lexer_ = nullptr;
  const Script* script_ = nullptr;
  const Module* current_module_ = nullptr;
  const Func* current_func_ = nullptr;
  Index current_table_index_ = 0;
  Index current_memory_index_ = 0;
  Index current_global_index_ = 0;
  Index num_imported_globals_ = 0;
  Index current_except_index_ = 0;
  TypeChecker typechecker_;
  // Cached for access by OnTypecheckerError.
  const Location* expr_loc_ = nullptr;
  Result result_ = Result::Ok;
};

Validator::Validator(ErrorHandler* error_handler,
                     WastLexer* lexer,
                     const Script* script)
    : error_handler_(error_handler), lexer_(lexer), script_(script) {
  // Route type-checker diagnostics through PrintError so they are reported at
  // the location of the expression currently being checked.
  typechecker_.set_error_callback(
      [this](const char* msg) { OnTypecheckerError(msg); });
}

// Reports a printf-style error at `loc` and marks validation as failed.
void Validator::PrintError(const Location* loc, const char* fmt, ...) {
  result_ = Result::Error;
  va_list args;
  va_start(args, fmt);
  WastFormatError(error_handler_, loc, lexer_, fmt, args);
  va_end(args);
}

void Validator::OnTypecheckerError(const char* msg) {
  PrintError(expr_loc_, "%s", msg);
}

static bool is_power_of_two(uint32_t x) {
  return x && ((x & (x - 1)) == 0);
}

// The natural alignment of a memory access is its access size in bytes.
static Address get_opcode_natural_alignment(Opcode opcode) {
  Address memory_size = opcode.GetMemorySize();
  assert(memory_size != 0);
  return memory_size;
}

// Checks that var's numeric index is < max_index; `desc` names the index
// space in the error message. NOTE(review): calls var->index(), so assumes
// names were already resolved to indices by an earlier pass.
Result Validator::CheckVar(Index max_index,
                           const Var* var,
                           const char* desc,
                           Index* out_index) {
  if (var->index() < max_index) {
    if (out_index) {
      *out_index = var->index();
    }
    return Result::Ok;
  }
  PrintError(&var->loc, "%s variable out of range (max %" PRIindex ")", desc,
             max_index);
  return Result::Error;
}

Result Validator::CheckFuncVar(const Var* var, const Func** out_func) {
  Index index;
  CHECK_RESULT(
      CheckVar(current_module_->funcs.size(), var, "function", &index));
  if (out_func) {
    *out_func = current_module_->funcs[index];
  }
  return Result::Ok;
}

Result Validator::CheckGlobalVar(const Var* var,
                                 const Global** out_global,
                                 Index* out_global_index) {
  Index index;
  CHECK_RESULT(
      CheckVar(current_module_->globals.size(), var, "global", &index));
  if (out_global) {
    *out_global = current_module_->globals[index];
  }
  if (out_global_index) {
    *out_global_index = index;
  }
  return Result::Ok;
}

// Returns the global's type, or Type::Any (which the type checker does not
// check) if the var is invalid; the error has then already been reported.
Type Validator::GetGlobalVarTypeOrAny(const Var* var) {
  const Global* global;
  if (Succeeded(CheckGlobalVar(var, &global, nullptr))) {
    return global->type;
  }
  return Type::Any;
}

Result Validator::CheckFuncTypeVar(const Var* var,
                                   const FuncType** out_func_type) {
  Index index;
  CHECK_RESULT(CheckVar(current_module_->func_types.size(), var,
                        "function type", &index));
  if (out_func_type) {
    *out_func_type = current_module_->func_types[index];
  }
  return Result::Ok;
}

Result Validator::CheckTableVar(const Var* var, const Table** out_table) {
  Index index;
  CHECK_RESULT(CheckVar(current_module_->tables.size(), var, "table", &index));
  if (out_table) {
    *out_table = current_module_->tables[index];
  }
  return Result::Ok;
}

Result Validator::CheckMemoryVar(const Var* var, const Memory** out_memory) {
  Index index;
  CHECK_RESULT(
      CheckVar(current_module_->memories.size(), var, "memory", &index));
  if (out_memory) {
    *out_memory = current_module_->memories[index];
  }
  return Result::Ok;
}

// Resolves a local (params come first, then declared locals) in the function
// currently being checked and, on success, stores its type in *out_type.
Result Validator::CheckLocalVar(const Var* var, Type* out_type) {
  const Func* func = current_func_;
  Index max_index = func->GetNumParamsAndLocals();
  Index index = func->GetLocalIndex(*var);
  if (index < max_index) {
    if (out_type) {
      Index num_params = func->GetNumParams();
      if (index < num_params) {
        *out_type = func->GetParamType(index);
      } else {
        *out_type = func->local_types[index - num_params];
      }
    }
    return Result::Ok;
  }

  if (var->is_name()) {
    PrintError(&var->loc, "undefined local variable \"%s\"",
               var->name().c_str());
  } else {
    PrintError(&var->loc, "local variable out of range (max %" PRIindex ")",
               max_index);
  }
  return Result::Error;
}

Type Validator::GetLocalVarTypeOrAny(const Var* var) {
  Type type = Type::Any;
  CheckLocalVar(var, &type);
  return type;
}

// Plain loads/stores: an explicit alignment must be a power of two no larger
// than the access size.
void Validator::CheckAlign(const Location* loc,
                           Address alignment,
                           Address natural_alignment) {
  if (alignment != WABT_USE_NATURAL_ALIGNMENT) {
    if (!is_power_of_two(alignment)) {
      PrintError(loc, "alignment must be power-of-two");
    }
    if (alignment > natural_alignment) {
      PrintError(loc,
                 "alignment must not be larger than natural alignment (%u)",
                 natural_alignment);
    }
  }
}

// Atomic accesses: an explicit alignment must equal the access size.
void Validator::CheckAtomicAlign(const Location* loc,
                                 Address alignment,
                                 Address natural_alignment) {
  if (alignment != WABT_USE_NATURAL_ALIGNMENT) {
    if (!is_power_of_two(alignment)) {
      PrintError(loc, "alignment must be power-of-two");
    }
    if (alignment != natural_alignment) {
      PrintError(loc, "alignment must be equal to natural alignment (%u)",
                 natural_alignment);
    }
  }
}

void Validator::CheckType(const Location* loc,
                          Type actual,
                          Type expected,
                          const char* desc) {
  if (expected != actual) {
    PrintError(loc, "type mismatch at %s. got %s, expected %s", desc,
               GetTypeName(actual), GetTypeName(expected));
  }
}

// Like CheckType, but Type::Any on either side is treated as a wildcard
// (used when an earlier error already made the type unknown).
void Validator::CheckTypeIndex(const Location* loc,
                               Type actual,
                               Type expected,
                               const char* desc,
                               Index index,
                               const char* index_kind) {
  if (expected != actual && expected != Type::Any && actual != Type::Any) {
    PrintError(
        loc, "type mismatch for %s %" PRIindex " of %s. got %s, expected %s",
        index_kind, index, desc, GetTypeName(actual), GetTypeName(expected));
  }
}

void Validator::CheckTypes(const Location* loc,
                           const TypeVector& actual,
                           const TypeVector& expected,
                           const char* desc,
                           const char* index_kind) {
  if (actual.size() == expected.size()) {
    for (size_t i = 0; i < actual.size(); ++i) {
      CheckTypeIndex(loc, actual[i], expected[i], desc, i, index_kind);
    }
  } else {
    PrintError(loc, "expected %" PRIzd " %ss, got %" PRIzd, expected.size(),
               index_kind, actual.size());
  }
}

void Validator::CheckConstTypes(const Location* loc,
                                const TypeVector& actual,
                                const ConstVector& expected,
                                const char* desc) {
  if (actual.size() == expected.size()) {
    for (size_t i = 0; i < actual.size(); ++i) {
      CheckTypeIndex(loc, actual[i], expected[i].type, desc, i, "result");
    }
  } else {
    PrintError(loc, "expected %" PRIzd " results, got %" PRIzd,
               expected.size(), actual.size());
  }
}

// Single-type variant: Type::Void is treated as "no results".
void Validator::CheckConstType(const Location* loc,
                               Type actual,
                               const ConstVector& expected,
                               const char* desc) {
  TypeVector actual_types;
  if (actual != Type::Void) {
    actual_types.push_back(actual);
  }
  CheckConstTypes(loc, actual_types, expected, desc);
}

void Validator::CheckAssertReturnNanType(const Location* loc,
                                         Type actual,
                                         const char* desc) {
  // When using assert_return_nan, the result can be either a f32 or f64 type
  // so we special case it here.
  if (actual != Type::F32 && actual != Type::F64) {
    PrintError(loc, "type mismatch at %s. got %s, expected f32 or f64", desc,
               GetTypeName(actual));
  }
}

void Validator::CheckExprList(const Location* loc, const ExprList& exprs) {
  for (const Expr& expr : exprs) {
    CheckExpr(&expr);
  }
}

// Returns false (after reporting) if the module has no memory at all.
bool Validator::CheckHasMemory(const Location* loc, Opcode opcode) {
  if (current_module_->memories.size() == 0) {
    PrintError(loc, "%s requires an imported or defined memory.",
               opcode.GetName());
    return false;
  }

  return true;
}

// Atomic and wait/wake operators need memory 0 to be declared shared.
void Validator::CheckHasSharedMemory(const Location* loc, Opcode opcode) {
  if (CheckHasMemory(loc, opcode)) {
    const Memory* memory = current_module_->memories[0];
    if (!memory->page_limits.is_shared) {
      PrintError(loc, "%s requires memory to be shared.", opcode.GetName());
    }
  }
}

void Validator::CheckBlockSig(const Location* loc,
                              Opcode opcode,
                              const BlockSignature* sig) {
  if (sig->size() > 1) {
    PrintError(loc,
               "multiple %s signature result types not currently supported.",
               opcode.GetName());
  }
}

// Shared checks for all atomic memory expressions: shared memory present,
// exact natural alignment, then the opcode-specific type-checker hook.
template <typename T>
void Validator::CheckAtomicExpr(const T* expr,
                                Result (TypeChecker::*func)(Opcode)) {
  CheckHasSharedMemory(&expr->loc, expr->opcode);
  CheckAtomicAlign(&expr->loc, expr->align,
                   get_opcode_natural_alignment(expr->opcode));
  (typechecker_.*func)(expr->opcode);
}

// Validates one expression (recursing into nested blocks) and feeds it to the
// type checker. expr_loc_ is updated first so type-checker errors point here.
void Validator::CheckExpr(const Expr* expr) {
  expr_loc_ = &expr->loc;

  switch (expr->type()) {
    case ExprType::AtomicLoad:
      CheckAtomicExpr(cast<AtomicLoadExpr>(expr), &TypeChecker::OnAtomicLoad);
      break;

    case ExprType::AtomicRmw:
      CheckAtomicExpr(cast<AtomicRmwExpr>(expr), &TypeChecker::OnAtomicRmw);
      break;

    case ExprType::AtomicRmwCmpxchg:
      CheckAtomicExpr(cast<AtomicRmwCmpxchgExpr>(expr),
                      &TypeChecker::OnAtomicRmwCmpxchg);
      break;

    case ExprType::AtomicStore:
      CheckAtomicExpr(cast<AtomicStoreExpr>(expr),
                      &TypeChecker::OnAtomicStore);
      break;

    case ExprType::Binary:
      typechecker_.OnBinary(cast<BinaryExpr>(expr)->opcode);
      break;

    case ExprType::Block: {
      auto block_expr = cast<BlockExpr>(expr);
      CheckBlockSig(&block_expr->loc, Opcode::Block, &block_expr->block.sig);
      typechecker_.OnBlock(&block_expr->block.sig);
      CheckExprList(&block_expr->loc, block_expr->block.exprs);
      typechecker_.OnEnd();
      break;
    }

    case ExprType::Br:
      typechecker_.OnBr(cast<BrExpr>(expr)->var.index());
      break;

    case ExprType::BrIf:
      typechecker_.OnBrIf(cast<BrIfExpr>(expr)->var.index());
      break;

    case ExprType::BrTable: {
      auto br_table_expr = cast<BrTableExpr>(expr);
      typechecker_.BeginBrTable();
      for (const Var& var : br_table_expr->targets) {
        typechecker_.OnBrTableTarget(var.index());
      }
      typechecker_.OnBrTableTarget(br_table_expr->default_target.index());
      typechecker_.EndBrTable();
      break;
    }

    case ExprType::Call: {
      const Func* callee;
      if (Succeeded(CheckFuncVar(&cast<CallExpr>(expr)->var, &callee))) {
        typechecker_.OnCall(&callee->decl.sig.param_types,
                            &callee->decl.sig.result_types);
      }
      break;
    }

    case ExprType::CallIndirect: {
      const FuncType* func_type;
      if (current_module_->tables.size() == 0) {
        PrintError(&expr->loc, "found call_indirect operator, but no table");
      }
      if (Succeeded(CheckFuncTypeVar(&cast<CallIndirectExpr>(expr)->var,
                                     &func_type))) {
        typechecker_.OnCallIndirect(&func_type->sig.param_types,
                                    &func_type->sig.result_types);
      }
      break;
    }

    case ExprType::Compare:
      typechecker_.OnCompare(cast<CompareExpr>(expr)->opcode);
      break;

    case ExprType::Const:
      typechecker_.OnConst(cast<ConstExpr>(expr)->const_.type);
      break;

    case ExprType::Convert:
      typechecker_.OnConvert(cast<ConvertExpr>(expr)->opcode);
      break;

    case ExprType::Drop:
      typechecker_.OnDrop();
      break;

    case ExprType::GetGlobal:
      typechecker_.OnGetGlobal(
          GetGlobalVarTypeOrAny(&cast<GetGlobalExpr>(expr)->var));
      break;

    case ExprType::GetLocal:
      typechecker_.OnGetLocal(
          GetLocalVarTypeOrAny(&cast<GetLocalExpr>(expr)->var));
      break;

    case ExprType::GrowMemory:
      CheckHasMemory(&expr->loc, Opcode::GrowMemory);
      typechecker_.OnGrowMemory();
      break;

    case ExprType::If: {
      auto if_expr = cast<IfExpr>(expr);
      CheckBlockSig(&if_expr->loc, Opcode::If, &if_expr->true_.sig);
      typechecker_.OnIf(&if_expr->true_.sig);
      CheckExprList(&if_expr->loc, if_expr->true_.exprs);
      if (!if_expr->false_.empty()) {
        typechecker_.OnElse();
        CheckExprList(&if_expr->loc, if_expr->false_);
      }
      typechecker_.OnEnd();
      break;
    }

    case ExprType::Load: {
      auto load_expr = cast<LoadExpr>(expr);
      CheckHasMemory(&load_expr->loc, load_expr->opcode);
      CheckAlign(&load_expr->loc, load_expr->align,
                 get_opcode_natural_alignment(load_expr->opcode));
      typechecker_.OnLoad(load_expr->opcode);
      break;
    }

    case ExprType::Loop: {
      auto loop_expr = cast<LoopExpr>(expr);
      CheckBlockSig(&loop_expr->loc, Opcode::Loop, &loop_expr->block.sig);
      typechecker_.OnLoop(&loop_expr->block.sig);
      CheckExprList(&loop_expr->loc, loop_expr->block.exprs);
      typechecker_.OnEnd();
      break;
    }

    case ExprType::CurrentMemory:
      CheckHasMemory(&expr->loc, Opcode::CurrentMemory);
      typechecker_.OnCurrentMemory();
      break;

    case ExprType::Nop:
      break;

    case ExprType::Rethrow:
      typechecker_.OnRethrow(cast<RethrowExpr>(expr)->var.index());
      break;

    case ExprType::Return:
      typechecker_.OnReturn();
      break;

    case ExprType::Select:
      typechecker_.OnSelect();
      break;

    case ExprType::SetGlobal:
      typechecker_.OnSetGlobal(
          GetGlobalVarTypeOrAny(&cast<SetGlobalExpr>(expr)->var));
      break;

    case ExprType::SetLocal:
      typechecker_.OnSetLocal(
          GetLocalVarTypeOrAny(&cast<SetLocalExpr>(expr)->var));
      break;

    case ExprType::Store: {
      auto store_expr = cast<StoreExpr>(expr);
      CheckHasMemory(&store_expr->loc, store_expr->opcode);
      CheckAlign(&store_expr->loc, store_expr->align,
                 get_opcode_natural_alignment(store_expr->opcode));
      typechecker_.OnStore(store_expr->opcode);
      break;
    }

    case ExprType::TeeLocal:
      typechecker_.OnTeeLocal(
          GetLocalVarTypeOrAny(&cast<TeeLocalExpr>(expr)->var));
      break;

    case ExprType::Throw: {
      // Braced so `except` does not leak into (or get jumped over by) the
      // following case labels.
      const Exception* except = nullptr;
      if (Succeeded(CheckExceptVar(&cast<ThrowExpr>(expr)->var, &except))) {
        typechecker_.OnThrow(&except->sig);
      }
      break;
    }

    case ExprType::TryBlock: {
      auto try_expr = cast<TryExpr>(expr);
      CheckBlockSig(&try_expr->loc, Opcode::Try, &try_expr->block.sig);

      typechecker_.OnTryBlock(&try_expr->block.sig);
      CheckExprList(&try_expr->loc, try_expr->block.exprs);

      if (try_expr->catches.empty()) {
        PrintError(&try_expr->loc, "TryBlock: doesn't have any catch clauses");
      }
      // A catch_all must be the last clause; any typed catch after it is
      // reported.
      bool found_catch_all = false;
      for (const Catch& catch_ : try_expr->catches) {
        typechecker_.OnCatchBlock(&try_expr->block.sig);
        if (catch_.IsCatchAll()) {
          found_catch_all = true;
        } else {
          if (found_catch_all) {
            PrintError(&catch_.loc, "Appears after catch all block");
          }
          const Exception* except = nullptr;
          if (Succeeded(CheckExceptVar(&catch_.var, &except))) {
            typechecker_.OnCatch(&except->sig);
          }
        }
        CheckExprList(&catch_.loc, catch_.exprs);
      }
      typechecker_.OnEnd();
      break;
    }

    case ExprType::Unary:
      typechecker_.OnUnary(cast<UnaryExpr>(expr)->opcode);
      break;

    case ExprType::Unreachable:
      typechecker_.OnUnreachable();
      break;

    case ExprType::Wait:
      CheckAtomicExpr(cast<WaitExpr>(expr), &TypeChecker::OnWait);
      break;

    case ExprType::Wake:
      CheckAtomicExpr(cast<WakeExpr>(expr), &TypeChecker::OnWake);
      break;
  }
}

// If the function references a type index, its inline signature must match
// that function type exactly.
void Validator::CheckFuncSignature(const Location* loc, const Func* func) {
  if (func->decl.has_func_type) {
    const FuncType* func_type;
    if (Succeeded(CheckFuncTypeVar(&func->decl.type_var, &func_type))) {
      CheckTypes(loc, func->decl.sig.result_types, func_type->sig.result_types,
                 "function", "result");
      CheckTypes(loc, func->decl.sig.param_types, func_type->sig.param_types,
                 "function", "argument");
    }
  }
}

// Checks a function's signature and type-checks its body. current_func_ is
// set only for the duration of the check (it is read by CheckLocalVar).
void Validator::CheckFunc(const Location* loc, const Func* func) {
  current_func_ = func;
  CheckFuncSignature(loc, func);
  if (func->GetNumResults() > 1) {
    PrintError(loc, "multiple result values not currently supported.");
    // Don't run any other checks, they won't test the result_type properly.
    current_func_ = nullptr;
    return;
  }

  expr_loc_ = loc;
  typechecker_.BeginFunction(&func->decl.sig.result_types);
  CheckExprList(loc, func->exprs);
  typechecker_.EndFunction();
  current_func_ = nullptr;
}

void Validator::PrintConstExprError(const Location* loc, const char* desc) {
  PrintError(loc,
             "invalid %s, must be a constant expression; either *.const or "
             "get_global.",
             desc);
}

// An initializer must be empty (type Void), a single *.const, or a single
// get_global of an immutable imported global; its type must equal
// expected_type.
void Validator::CheckConstInitExpr(const Location* loc,
                                   const ExprList& exprs,
                                   Type expected_type,
                                   const char* desc) {
  Type type = Type::Void;
  if (!exprs.empty()) {
    if (exprs.size() > 1) {
      PrintConstExprError(loc, desc);
      return;
    }

    const Expr* expr = &exprs.front();
    loc = &expr->loc;

    switch (expr->type()) {
      case ExprType::Const:
        type = cast<ConstExpr>(expr)->const_.type;
        break;

      case ExprType::GetGlobal: {
        const Global* ref_global = nullptr;
        Index ref_global_index;
        if (Failed(CheckGlobalVar(&cast<GetGlobalExpr>(expr)->var, &ref_global,
                                  &ref_global_index))) {
          return;
        }

        type = ref_global->type;
        // Imports are counted as they are visited, so this only accepts
        // globals imported before the current field.
        if (ref_global_index >= num_imported_globals_) {
          PrintError(
              loc,
              "initializer expression can only reference an imported global");
        }

        if (ref_global->mutable_) {
          PrintError(
              loc, "initializer expression cannot reference a mutable global");
        }
        break;
      }

      default:
        PrintConstExprError(loc, desc);
        return;
    }
  }

  CheckType(loc, type, expected_type, desc);
}

void Validator::CheckGlobal(const Location* loc, const Global* global) {
  CheckConstInitExpr(loc, global->init_expr, global->type,
                     "global initializer expression");
}

// Checks initial/max against absolute_max and each other, and the sharing
// rules (sharing must be allowed, and shared limits need a max).
void Validator::CheckLimits(const Location* loc,
                            const Limits* limits,
                            uint64_t absolute_max,
                            const char* desc,
                            LimitsShareable sharing) {
  if (limits->initial > absolute_max) {
    PrintError(loc, "initial %s (%" PRIu64 ") must be <= (%" PRIu64 ")", desc,
               limits->initial, absolute_max);
  }

  if (limits->has_max) {
    if (limits->max > absolute_max) {
      PrintError(loc, "max %s (%" PRIu64 ") must be <= (%" PRIu64 ")", desc,
                 limits->max, absolute_max);
    }

    if (limits->max < limits->initial) {
      PrintError(loc,
                 "max %s (%" PRIu64 ") must be >= initial %s (%" PRIu64 ")",
                 desc, limits->max, desc, limits->initial);
    }
  }

  if (limits->is_shared) {
    if (sharing == LimitsShareable::NotAllowed) {
      PrintError(loc, "tables may not be shared");
      return;
    }
    if (!limits->has_max) {
      PrintError(loc, "shared memories must have max sizes");
    }
  }
}

void Validator::CheckTable(const Location* loc, const Table* table) {
  if (current_table_index_ == 1) {
    PrintError(loc, "only one table allowed");
  }
  CheckLimits(loc, &table->elem_limits, UINT32_MAX, "elems",
              LimitsShareable::NotAllowed);
}

// Run after all fields are visited so forward references to tables and
// functions resolve.
void Validator::CheckElemSegments(const Module* module) {
  for (const ModuleField& field : module->fields) {
    if (auto elem_segment_field = dyn_cast<ElemSegmentModuleField>(&field)) {
      auto&& elem_segment = elem_segment_field->elem_segment;
      const Table* table;
      if (Failed(CheckTableVar(&elem_segment.table_var, &table))) {
        continue;
      }

      for (const Var& var : elem_segment.vars) {
        CheckFuncVar(&var, nullptr);
      }

      CheckConstInitExpr(&field.loc, elem_segment.offset, Type::I32,
                         "elem segment offset");
    }
  }
}

void Validator::CheckMemory(const Location* loc, const Memory* memory) {
  if (current_memory_index_ == 1) {
    PrintError(loc, "only one memory block allowed");
  }
  CheckLimits(loc, &memory->page_limits, WABT_MAX_PAGES, "pages",
              LimitsShareable::Allowed);
}

void Validator::CheckDataSegments(const Module* module) {
  for (const ModuleField& field : module->fields) {
    if (auto data_segment_field = dyn_cast<DataSegmentModuleField>(&field)) {
      auto&& data_segment = data_segment_field->data_segment;
      const Memory* memory;
      if (Failed(CheckMemoryVar(&data_segment.memory_var, &memory))) {
        continue;
      }

      CheckConstInitExpr(&field.loc, data_segment.offset, Type::I32,
                         "data segment offset");
    }
  }
}

// Validates an import and advances the per-kind counters that later fields
// rely on (e.g. num_imported_globals_ for initializer expressions).
void Validator::CheckImport(const Location* loc, const Import* import) {
  switch (import->kind()) {
    case ExternalKind::Except:
      ++current_except_index_;
      CheckExcept(loc, &cast<ExceptionImport>(import)->except);
      break;

    case ExternalKind::Func: {
      auto* func_import = cast<FuncImport>(import);
      if (func_import->func.decl.has_func_type) {
        CheckFuncTypeVar(&func_import->func.decl.type_var, nullptr);
      }
      break;
    }

    case ExternalKind::Table:
      CheckTable(loc, &cast<TableImport>(import)->table);
      ++current_table_index_;
      break;

    case ExternalKind::Memory:
      CheckMemory(loc, &cast<MemoryImport>(import)->memory);
      ++current_memory_index_;
      break;

    case ExternalKind::Global: {
      auto* global_import = cast<GlobalImport>(import);
      if (global_import->global.mutable_) {
        PrintError(loc, "mutable globals cannot be imported");
      }
      ++num_imported_globals_;
      ++current_global_index_;
      break;
    }
  }
}

void Validator::CheckExport(const Location* loc, const Export* export_) {
  switch (export_->kind) {
    case ExternalKind::Except:
      CheckExceptVar(&export_->var, nullptr);
      break;
    case ExternalKind::Func:
      CheckFuncVar(&export_->var, nullptr);
      break;
    case ExternalKind::Table:
      CheckTableVar(&export_->var, nullptr);
      break;
    case ExternalKind::Memory:
      CheckMemoryVar(&export_->var, nullptr);
      break;
    case ExternalKind::Global: {
      const Global* global;
      if (Succeeded(CheckGlobalVar(&export_->var, &global, nullptr))) {
        if (global->mutable_) {
          PrintError(&export_->var.loc, "mutable globals cannot be exported");
        }
      }
      break;
    }
  }
}

void Validator::CheckDuplicateExportBindings(const Module* module) {
  module->export_bindings.FindDuplicates([this](
      const BindingHash::value_type& a, const BindingHash::value_type& b) {
    // Choose the location that is later in the file.
    const Location& a_loc = a.second.loc;
    const Location& b_loc = b.second.loc;
    const Location& loc = a_loc.line > b_loc.line ? a_loc : b_loc;
    PrintError(&loc, "redefinition of export \"%s\"", a.first.c_str());
  });
}

// Validates every field of a module. Index counters are reset first; result_
// is not, so the return value also reflects earlier errors from this
// Validator.
Result Validator::CheckModule(const Module* module) {
  bool seen_start = false;

  current_module_ = module;
  current_table_index_ = 0;
  current_memory_index_ = 0;
  current_global_index_ = 0;
  num_imported_globals_ = 0;
  current_except_index_ = 0;

  for (const ModuleField& field : module->fields) {
    switch (field.type()) {
      case ModuleFieldType::Except:
        ++current_except_index_;
        CheckExcept(&field.loc, &cast<ExceptionModuleField>(&field)->except);
        break;

      case ModuleFieldType::Func:
        CheckFunc(&field.loc, &cast<FuncModuleField>(&field)->func);
        break;

      case ModuleFieldType::Global:
        CheckGlobal(&field.loc, &cast<GlobalModuleField>(&field)->global);
        current_global_index_++;
        break;

      case ModuleFieldType::Import:
        CheckImport(&field.loc, cast<ImportModuleField>(&field)->import.get());
        break;

      case ModuleFieldType::Export:
        CheckExport(&field.loc, &cast<ExportModuleField>(&field)->export_);
        break;

      case ModuleFieldType::Table:
        CheckTable(&field.loc, &cast<TableModuleField>(&field)->table);
        current_table_index_++;
        break;

      case ModuleFieldType::ElemSegment:
        // Checked below.
        break;

      case ModuleFieldType::Memory:
        CheckMemory(&field.loc, &cast<MemoryModuleField>(&field)->memory);
        current_memory_index_++;
        break;

      case ModuleFieldType::DataSegment:
        // Checked below.
        break;

      case ModuleFieldType::FuncType:
        break;

      case ModuleFieldType::Start: {
        if (seen_start) {
          PrintError(&field.loc, "only one start function allowed");
        }

        const Func* start_func = nullptr;
        CheckFuncVar(&cast<StartModuleField>(&field)->start, &start_func);
        if (start_func) {
          if (start_func->GetNumParams() != 0) {
            PrintError(&field.loc, "start function must be nullary");
          }

          if (start_func->GetNumResults() != 0) {
            PrintError(&field.loc, "start function must not return anything");
          }
        }
        seen_start = true;
        break;
      }
    }
  }

  CheckElemSegments(module);
  CheckDataSegments(module);
  CheckDuplicateExportBindings(module);

  return result_;
}

// Returns the result type of the invoked function, checked by the caller;
// returning nullptr means that another error occurred first, so the result type
// should be ignored.
const TypeVector* Validator::CheckInvoke(const InvokeAction* action) {
  const Module* module = script_->GetModule(action->module_var);
  if (!module) {
    PrintError(&action->loc, "unknown module");
    return nullptr;
  }

  const Export* export_ = module->GetExport(action->name);
  // The export must name a function; otherwise export_->var indexes a
  // different index space and GetFunc would resolve an unrelated function.
  if (!export_ || export_->kind != ExternalKind::Func) {
    PrintError(&action->loc, "unknown function export \"%s\"",
               action->name.c_str());
    return nullptr;
  }

  const Func* func = module->GetFunc(export_->var);
  if (!func) {
    // This error will have already been reported, just skip it.
    return nullptr;
  }

  size_t actual_args = action->args.size();
  size_t expected_args = func->GetNumParams();
  if (expected_args != actual_args) {
    PrintError(&action->loc,
               "too %s parameters to function. got %" PRIzd
               ", expected %" PRIzd,
               actual_args > expected_args ? "many" : "few", actual_args,
               expected_args);
    return nullptr;
  }
  for (size_t i = 0; i < actual_args; ++i) {
    const Const* const_ = &action->args[i];
    CheckTypeIndex(&const_->loc, const_->type, func->GetParamType(i), "invoke",
                   i, "argument");
  }

  return &func->decl.sig.result_types;
}

// Resolves a get action to an exported global and stores its type in
// *out_type. Returns Result::Error if anything could not be resolved.
Result Validator::CheckGet(const GetAction* action, Type* out_type) {
  const Module* module = script_->GetModule(action->module_var);
  if (!module) {
    PrintError(&action->loc, "unknown module");
    return Result::Error;
  }

  const Export* export_ = module->GetExport(action->name);
  // The export must name a global; otherwise export_->var indexes a
  // different index space and GetGlobal would resolve an unrelated global.
  if (!export_ || export_->kind != ExternalKind::Global) {
    PrintError(&action->loc, "unknown global export \"%s\"",
               action->name.c_str());
    return Result::Error;
  }

  const Global* global = module->GetGlobal(export_->var);
  if (!global) {
    // This error will have already been reported, just skip it.
    return Result::Error;
  }

  *out_type = global->type;
  return Result::Ok;
}

Result Validator::CheckExceptVar(const Var* var,
                                 const Exception** out_except) {
  Index index;
  CHECK_RESULT(
      CheckVar(current_module_->excepts.size(), var, "except", &index));
  if (out_except) {
    *out_except = current_module_->excepts[index];
  }
  return Result::Ok;
}

// Exception signatures may only contain the four numeric value types.
void Validator::CheckExcept(const Location* loc, const Exception* except) {
  for (Type ty : except->sig) {
    switch (ty) {
      case Type::I32:
      case Type::I64:
      case Type::F32:
      case Type::F64:
        break;
      default:
        PrintError(loc, "Invalid exception type: %s", GetTypeName(ty));
        break;
    }
  }
}

Validator::ActionResult Validator::CheckAction(const Action* action) {
  ActionResult result;
  ZeroMemory(result);

  switch (action->type()) {
    case ActionType::Invoke:
      result.types = CheckInvoke(cast<InvokeAction>(action));
      result.kind =
          result.types ? ActionResult::Kind::Types : ActionResult::Kind::Error;
      break;

    case ActionType::Get:
      if (Succeeded(CheckGet(cast<GetAction>(action), &result.type))) {
        result.kind = ActionResult::Kind::Type;
      } else {
        result.kind = ActionResult::Kind::Error;
      }
      break;
  }

  return result;
}

void Validator::CheckAssertReturnNanCommand(const Action* action) {
  ActionResult result = CheckAction(action);

  // A valid result type will either be f32 or f64; convert a Types result into
  // a Type result, so it is easier to check below. Type::Any is used to
  // specify a type that should not be checked (because an earlier error
  // occurred).
  if (result.kind == ActionResult::Kind::Types) {
    if (result.types->size() == 1) {
      result.kind = ActionResult::Kind::Type;
      result.type = (*result.types)[0];
    } else {
      PrintError(&action->loc, "expected 1 result, got %" PRIzd,
                 result.types->size());
      result.type = Type::Any;
    }
  }

  if (result.kind == ActionResult::Kind::Type && result.type != Type::Any) {
    CheckAssertReturnNanType(&action->loc, result.type, "action");
  }
}

void Validator::CheckCommand(const Command* command) {
  switch (command->type) {
    case CommandType::Module:
      CheckModule(&cast<ModuleCommand>(command)->module);
      break;

    case CommandType::Action:
      // Ignore result type.
      CheckAction(cast<ActionCommand>(command)->action.get());
      break;

    case CommandType::Register:
    case CommandType::AssertMalformed:
    case CommandType::AssertInvalid:
    case CommandType::AssertUnlinkable:
    case CommandType::AssertUninstantiable:
      // Ignore.
      break;

    case CommandType::AssertReturn: {
      auto* assert_return_command = cast<AssertReturnCommand>(command);
      const Action* action = assert_return_command->action.get();
      ActionResult result = CheckAction(action);
      switch (result.kind) {
        case ActionResult::Kind::Types:
          CheckConstTypes(&action->loc, *result.types,
                          assert_return_command->expected, "action");
          break;

        case ActionResult::Kind::Type:
          CheckConstType(&action->loc, result.type,
                         assert_return_command->expected, "action");
          break;

        case ActionResult::Kind::Error:
          // Error occurred, don't do any further checks.
          break;
      }
      break;
    }

    case CommandType::AssertReturnCanonicalNan:
      CheckAssertReturnNanCommand(
          cast<AssertReturnCanonicalNanCommand>(command)->action.get());
      break;

    case CommandType::AssertReturnArithmeticNan:
      CheckAssertReturnNanCommand(
          cast<AssertReturnArithmeticNanCommand>(command)->action.get());
      break;

    case CommandType::AssertTrap:
      // ignore result type.
      CheckAction(cast<AssertTrapCommand>(command)->action.get());
      break;

    case CommandType::AssertExhaustion:
      // ignore result type.
      CheckAction(cast<AssertExhaustionCommand>(command)->action.get());
      break;
  }
}

Result Validator::CheckScript(const Script* script) {
  for (const std::unique_ptr<Command>& command : script->commands) {
    CheckCommand(command.get());
  }
  return result_;
}

// Only checks each function's inline signature against its referenced type;
// bodies are not type-checked.
Result Validator::CheckAllFuncSignatures(const Module* module) {
  current_module_ = module;
  for (const ModuleField& field : module->fields) {
    switch (field.type()) {
      case ModuleFieldType::Func:
        CheckFuncSignature(&field.loc, &cast<FuncModuleField>(&field)->func);
        break;
      default:
        break;
    }
  }
  return result_;
}

}  // end anonymous namespace

Result ValidateScript(WastLexer* lexer,
                      const Script* script,
                      ErrorHandler* error_handler) {
  Validator validator(error_handler, lexer, script);

  return validator.CheckScript(script);
}

Result ValidateModule(WastLexer* lexer,
                      const Module* module,
                      ErrorHandler* error_handler) {
  Validator validator(error_handler, lexer, nullptr);

  return validator.CheckModule(module);
}

Result ValidateFuncSignatures(WastLexer* lexer,
                              const Module* module,
                              ErrorHandler* error_handler) {
  Validator validator(error_handler, lexer, nullptr);

  return validator.CheckAllFuncSignatures(module);
}

}  // namespace wabt