diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/module-utils.h | 67 | ||||
-rw-r--r-- | src/ir/properties.h | 2 | ||||
-rw-r--r-- | src/literal.h | 3 | ||||
-rw-r--r-- | src/passes/Flatten.cpp | 2 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 3 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 4 | ||||
-rw-r--r-- | src/passes/Print.cpp | 10 | ||||
-rw-r--r-- | src/tools/fuzzing.h | 44 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 3 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 10 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 46 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 8 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 2 |
13 files changed, 153 insertions, 51 deletions
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index 2b1c6812c..008ec0714 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -402,7 +402,17 @@ inline void collectSignatures(Module& wasm, std::vector<Signature>& signatures, std::unordered_map<Signature, Index>& sigIndices) { - using Counts = std::unordered_map<Signature, size_t>; + struct Counts : public std::unordered_map<Signature, size_t> { + void note(Signature sig) { (*this)[sig]++; } + void maybeNote(Type type) { + if (type.isRef()) { + auto heapType = type.getHeapType(); + if (heapType.isSignature()) { + note(heapType.getSignature()); + } + } + } + }; // Collect the signature use counts for a single function auto updateCounts = [&](Function* func, Counts& counts) { @@ -417,23 +427,14 @@ collectSignatures(Module& wasm, void visitExpression(Expression* curr) { if (curr->is<RefNull>()) { - maybeNote(curr->type); + counts.maybeNote(curr->type); } else if (auto* call = curr->dynCast<CallIndirect>()) { counts[call->sig]++; } else if (Properties::isControlFlowStructure(curr)) { - maybeNote(curr->type); + counts.maybeNote(curr->type); if (curr->type.isTuple()) { // TODO: Allow control flow to have input types as well - counts[Signature(Type::none, curr->type)]++; - } - } - } - - void maybeNote(Type type) { - if (type.isRef()) { - auto heapType = type.getHeapType(); - if (heapType.isSignature()) { - counts[heapType.getSignature()]++; + counts.note(Signature(Type::none, curr->type)); } } } @@ -448,10 +449,10 @@ collectSignatures(Module& wasm, for (auto& curr : wasm.functions) { counts[curr->sig]++; for (auto type : curr->vars) { - if (type.isRef()) { - auto heapType = type.getHeapType(); - if (heapType.isSignature()) { - counts[heapType.getSignature()]++; + counts.maybeNote(type); + if (type.isTuple()) { + for (auto t : type) { + counts.maybeNote(t); } } } @@ -466,8 +467,36 @@ collectSignatures(Module& wasm, } } - // TODO: recursively traverse each reference type, which may have a child type - // this is itself a reference type. + // Recursively traverse each reference type, which may have a child type that + // is itself a reference type. This reflects an appearance in the binary + // format that is in the type section itself. + // As we do this we may find more and more signatures, as nested children of + // previous ones. Each such signature will appear in the type section once, so + // we just need to visit it once. + // TODO: handle struct and array fields + std::unordered_set<Signature> newSigs; + for (auto& pair : counts) { + newSigs.insert(pair.first); + } + while (!newSigs.empty()) { + auto iter = newSigs.begin(); + auto sig = *iter; + newSigs.erase(iter); + for (Type type : {sig.params, sig.results}) { + for (auto element : type) { + if (element.isRef()) { + auto heapType = element.getHeapType(); + if (heapType.isSignature()) { + auto sig = heapType.getSignature(); + if (!counts.count(sig)) { + newSigs.insert(sig); + } + counts[sig]++; + } + } + } + } + } // We must sort all the dependencies of a signature before it. For example, // (func (param (ref (func)))) must appear after (func). To do that, find the diff --git a/src/ir/properties.h b/src/ir/properties.h index a2c70a205..7ec1701d6 100644 --- a/src/ir/properties.h +++ b/src/ir/properties.h @@ -105,7 +105,7 @@ inline Literal getLiteral(const Expression* curr) { } else if (auto* n = curr->dynCast<RefNull>()) { return Literal(n->type); } else if (auto* r = curr->dynCast<RefFunc>()) { - return Literal(r->func); + return Literal(r->func, r->type); } else if (auto* i = curr->dynCast<I31New>()) { if (auto* c = i->value->dynCast<Const>()) { return Literal::makeI31(c->value.geti32()); diff --git a/src/literal.h b/src/literal.h index f96d6c7f5..2099049f1 100644 --- a/src/literal.h +++ b/src/literal.h @@ -71,8 +71,7 @@ public: explicit Literal(const std::array<Literal, 8>&); explicit Literal(const std::array<Literal, 4>&); explicit Literal(const std::array<Literal, 2>&); - explicit Literal(Name func, Type type = Type::funcref) - : func(func), type(type) {} + explicit Literal(Name func, Type type) : func(func), type(type) {} explicit Literal(std::unique_ptr<ExceptionPackage>&& exn) : exn(std::move(exn)), type(Type::exnref) {} Literal(const Literal& other); diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp index c7b4acbd9..ebd0f3ba8 100644 --- a/src/passes/Flatten.cpp +++ b/src/passes/Flatten.cpp @@ -17,6 +17,8 @@ // // Flattens code into "Flat IR" form. See ir/flat.h. // +// TODO: handle non-nullability +// #include <ir/branch-utils.h> #include <ir/effects.h> diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index a44f02426..9eade8690 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -213,7 +213,8 @@ struct Updater : public PostWalker<Updater> { } void visitCallRef(CallRef* curr) { if (curr->isReturn) { - handleReturnCall(curr, curr->target->type); + handleReturnCall(curr, + curr->target->type.getHeapType().getSignature().results); } } void visitLocalGet(LocalGet* curr) { diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 56c99984a..dd088b03b 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -121,7 +121,9 @@ struct LocalScanner : PostWalker<LocalScanner> { Index getMaxBitsForLocal(LocalGet* get) { return getBitsForType(get->type); } Index getBitsForType(Type type) { - TODO_SINGLE_COMPOUND(type); + if (!type.isBasic()) { + return -1; + } switch (type.getBasic()) { case Type::i32: return 32; diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 864a46362..33b34904e 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -103,15 +103,21 @@ std::ostream& operator<<(std::ostream& os, SigName sigName) { if (t.isNullable()) { os << "_null"; } - os << "<"; + os << "["; + auto subsep = ""; for (auto s : sig.params) { + os << subsep; + subsep = "_"; printType(s); } os << "_->_"; + subsep = ""; for (auto s : sig.results) { + os << subsep; + subsep = "_"; printType(s); } - os << ">"; + os << "]"; continue; } } diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index 39298ce21..63f10cac4 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -1089,7 +1089,8 @@ private: WeightedOption{&Self::makeBreak, Important}, &Self::makeCall, &Self::makeCallIndirect) - .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef); + .add(FeatureSet::TypedFunctionReferences | FeatureSet::ReferenceTypes, + &Self::makeCallRef); } if (type.isSingle()) { options @@ -1149,7 +1150,8 @@ private: &Self::makeGlobalSet) .add(FeatureSet::BulkMemory, &Self::makeBulkMemory) .add(FeatureSet::Atomics, &Self::makeAtomic) - .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef); + .add(FeatureSet::TypedFunctionReferences | FeatureSet::ReferenceTypes, + &Self::makeCallRef); return (this->*pick(options))(Type::none); } @@ -1174,7 +1176,8 @@ private: &Self::makeSwitch, &Self::makeDrop, &Self::makeReturn) - .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef); + .add(FeatureSet::TypedFunctionReferences | FeatureSet::ReferenceTypes, + &Self::makeCallRef); return (this->*pick(options))(Type::unreachable); } @@ -1449,7 +1452,32 @@ private: } Expression* makeCallRef(Type type) { - return makeTrivial(type); // FIXME + // look for a call target with the right type + Function* target; + bool isReturn; + size_t i = 0; + while (1) { + if (i == TRIES || wasm.functions.empty()) { + // We can't find a proper target, give up. + return makeTrivial(type); + } + // TODO: handle unreachable + target = wasm.functions[upTo(wasm.functions.size())].get(); + isReturn = type == Type::unreachable && wasm.features.hasTailCall() && + funcContext->func->sig.results == target->sig.results; + if (target->sig.results == type || isReturn) { + break; + } + i++; + } + std::vector<Expression*> args; + for (const auto& type : target->sig.params) { + args.push_back(make(type)); + } + auto targetType = Type(HeapType(target->sig), /* nullable = */ true); + // TODO: half the time make a completely random item with that type. + return builder.makeCallRef( + builder.makeRefFunc(target->name, targetType), args, type, isReturn); } Expression* makeLocalGet(Type type) { @@ -2055,7 +2083,13 @@ private: if (type.isNullable()) { return builder.makeRefNull(type); } - WASM_UNREACHABLE("un-handleable non-nullable type"); + // Last resort: create a function. + auto* func = wasm.addFunction(builder.makeFunction( + Names::getValidFunctionName(wasm, "ref_func_target"), + type.getHeapType().getSignature(), + {}, + builder.makeUnreachable())); + return builder.makeRefFunc(func->name, type); } if (type.isTuple()) { std::vector<Expression*> operands; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 37719d4d9..93c409797 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -2110,6 +2110,9 @@ private: if (target.breaking()) { return target; } + if (target.getSingleValue().isNull()) { + trap("null target in call_ref"); + } Name funcName = target.getSingleValue().getFunc(); auto* func = instance.wasm.getFunction(funcName); Flow ret; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 20b0899a5..170aa17bb 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1357,7 +1357,8 @@ Type WasmBinaryBuilder::getType() { case BinaryConsts::EncodedType::nullable: return Type(getHeapType(), /* nullable = */ true); case BinaryConsts::EncodedType::nonnullable: - return Type(getHeapType(), /* nullable = */ false); + // FIXME: for now, force all inputs to be nullable + return Type(getHeapType(), /* nullable = */ true); case BinaryConsts::EncodedType::i31ref: return Type::i31ref; default: @@ -5291,6 +5292,7 @@ void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) { functionRefs[index].push_back(curr); // we don't know function names yet // To support typed function refs, we give the reference not just a general // funcref, but a specific subtype with the actual signature. + // FIXME: for now, emit a nullable type here curr->finalize( Type(HeapType(getFunctionSignatureByIndex(index)), /* nullable = */ true)); } @@ -5440,6 +5442,12 @@ void WasmBinaryBuilder::visitCallRef(CallRef* curr) { BYN_TRACE("zz node: CallRef\n"); curr->target = popNonVoidExpression(); auto type = curr->target->type; + if (type == Type::unreachable) { + // If our input is unreachable, then we cannot even find out how many inputs + // we have, and just set ourselves to unreachable as well. + curr->finalize(type); + return; + } if (!type.isRef()) { throwError("Non-ref type for a call_ref: " + type.toString()); } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index d8d9fa779..434dec9d7 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -536,14 +536,10 @@ SExpressionWasmBuilder::parseParamOrLocal(Element& s, size_t& localIndex) { } localIndex++; Type type; - if (s[i]->isStr()) { - type = stringToType(s[i]->str()); - } else { - type = elementToType(*s[i]); - if (elementStartsWith(s, PARAM) && type.isTuple()) { - throw ParseException( - "params may not have tuple types", s[i]->line, s[i]->col); - } + type = elementToType(*s[i]); + if (elementStartsWith(s, PARAM) && type.isTuple()) { + throw ParseException( + "params may not have tuple types", s[i]->line, s[i]->col); } namedParams.emplace_back(name, type); } @@ -555,7 +551,7 @@ std::vector<Type> SExpressionWasmBuilder::parseResults(Element& s) { assert(elementStartsWith(s, RESULT)); std::vector<Type> types; for (size_t i = 1; i < s.size(); i++) { - types.push_back(stringToType(s[i]->str())); + types.push_back(elementToType(*s[i])); } return types; } @@ -923,7 +919,7 @@ HeapType SExpressionWasmBuilder::stringToHeapType(const char* str, Type SExpressionWasmBuilder::elementToType(Element& s) { if (s.isStr()) { - return stringToType(s.str(), false, false); + return stringToType(s.str()); } auto& list = s.list(); auto size = list.size(); @@ -942,7 +938,8 @@ Type SExpressionWasmBuilder::elementToType(Element& s) { throw ParseException( std::string("invalid reference type qualifier"), s.line, s.col); } - bool nullable = false; + // FIXME: for now, force all inputs to be nullable + bool nullable = true; size_t i = 1; if (size == 3) { nullable = true; @@ -966,7 +963,7 @@ Type SExpressionWasmBuilder::elementToType(Element& s) { // It's a tuple. std::vector<Type> types; for (size_t i = 0; i < s.size(); ++i) { - types.push_back(stringToType(list[i]->str())); + types.push_back(elementToType(*list[i])); } return Type(types); } @@ -1911,9 +1908,30 @@ Expression* SExpressionWasmBuilder::makeRefNull(Element& s) { if (s.size() != 2) { throw ParseException("invalid heap type reference", s.line, s.col); } - auto heapType = stringToHeapType(s[1]->str()); auto ret = allocator.alloc<RefNull>(); - ret->finalize(heapType); + if (s[1]->isStr()) { + // For example, this parses + // (ref.null func) + ret->finalize(stringToHeapType(s[1]->str())); + } else { + // To parse a heap type, create an element around it, and call that method. + // That is, given (func) we wrap to (ref (func)). + // For example, this parses + // (ref.null (func (param i32))) + // TODO add a helper method, but this is the only user atm, and we are + // waiting on https://github.com/WebAssembly/function-references/issues/42 + Element wrapper(wasm.allocator); + auto& list = wrapper.list(); + list.resize(3); + Element ref(wasm.allocator); + ref.setString(REF, false, false); + Element null(wasm.allocator); + null.setString(NULL_, false, false); + list[0] = &ref; + list[1] = &null; + list[2] = s[1]; + ret->finalize(elementToType(wrapper)); + } return ret; } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 5faa8b2f5..fb55417d5 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -1963,6 +1963,8 @@ void FunctionValidator::visitRefNull(RefNull* curr) { shouldBeTrue(getModule()->features.hasReferenceTypes(), curr, "ref.null requires reference-types to be enabled"); + shouldBeTrue( + curr->type.isNullable(), curr, "ref.null types must be nullable"); } void FunctionValidator::visitRefIsNull(RefIsNull* curr) { @@ -2158,10 +2160,10 @@ void FunctionValidator::visitCallRef(CallRef* curr) { shouldBeTrue(getModule()->features.hasTypedFunctionReferences(), curr, "call_ref requires typed-function-references to be enabled"); - shouldBeTrue(curr->target->type.isFunction(), - curr, - "call_ref target must be a function reference"); if (curr->target->type != Type::unreachable) { + shouldBeTrue(curr->target->type.isFunction(), + curr, + "call_ref target must be a function reference"); validateCallParamsAndResult( curr, curr->target->type.getHeapType().getSignature()); } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index ac76a63ac..cbf573950 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -973,12 +973,10 @@ void MemoryGrow::finalize() { void RefNull::finalize(HeapType heapType) { type = Type(heapType, true); } void RefNull::finalize(Type type_) { - assert(type_ == Type::unreachable || type_.isNullable()); type = type_; } void RefNull::finalize() { - assert(type == Type::unreachable || type.isNullable()); } void RefIsNull::finalize() { |