author     Alon Zakai <azakai@google.com>           2024-05-15 17:00:54 -0700
committer  GitHub <noreply@github.com>              2024-05-16 00:00:54 +0000
commit     e5f2edf4bedb1ab842c2f7ac0dfd58d73e26df7d (patch)
tree       258946ed4f722057af16ae640784b7029a1ed62c /src/wasm/wasm-stack.cpp
parent     268feb9d77b66d60649682d62a93baa3daadd143 (diff)
Fix binary emitting of br_if with a refined value by emitting a cast (#6510)
This makes us compliant with the wasm spec by adding a cast: we use the refined type for br_if fallthrough values, while the wasm spec uses the type of the branch target. If the two differ, we add a cast after the br_if to make things match.

Alternatively we could match the wasm spec's typing in our IR, but we hope the wasm spec will improve here, and in that case this will only be temporary. Even if not, this is useful: by using the most refined type in the IR we optimize as well as possible, and only pay for fixups when emitting the binary. In practice those cases are very rare, since in real-world code a br_if value is almost always dropped rather than used (fuzz cases and exploits aside).

We check carefully that a br_if value is actually used (not dropped), that its type actually differs from the target's, and that it does not already have a cast. The last condition ensures that we do not keep adding casts over repeated roundtripping.
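As a rough restatement of the condition this patch checks for (a sketch only, not code from the patch itself: the helper name brIfNeedsCastFixup is invented here, while Break, Drop, RefCast and Type::isSubType are the Binaryen IR pieces that appear in the diff below), a cast fixup is needed exactly when a br_if carries a reference value that is not dropped, is not already covered by an equal-or-stronger cast, and whose IR type differs from the type of the branch target:

#include "wasm.h" // Binaryen IR: Break, Drop, RefCast, Type (assumed available)

using namespace wasm;

// Hypothetical helper, not part of this patch. |parent| is the br_if's
// immediate parent expression (may be null) and |targetType| is the type of
// the block the br_if branches to, i.e. the type the wasm spec assigns to the
// br_if itself.
static bool brIfNeedsCastFixup(Break* brIf,
                               Expression* parent,
                               Type targetType) {
  if (!brIf->condition || !brIf->type.hasRef()) {
    // Not a br_if, or it has no reference-typed value to fix up.
    return false;
  }
  if (parent && parent->is<Drop>()) {
    // The fallthrough value is dropped, so its precise type never matters.
    return false;
  }
  if (parent) {
    if (auto* cast = parent->dynCast<RefCast>()) {
      if (Type::isSubType(cast->type, brIf->type)) {
        // An equal-or-stronger cast is already present (e.g. one emitted on
        // an earlier roundtrip), so no new cast is needed.
        return false;
      }
    }
  }
  // A fixup is needed only when Binaryen IR's refined type (the value's type)
  // differs from the unrefined type the spec gives it (the branch target's).
  return brIf->type != targetType;
}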
Diffstat (limited to 'src/wasm/wasm-stack.cpp')
-rw-r--r--  src/wasm/wasm-stack.cpp  159
1 files changed, 157 insertions, 2 deletions
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index 1d7fe5255..50c13da65 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -65,6 +65,56 @@ void BinaryInstWriter::visitLoop(Loop* curr) {
void BinaryInstWriter::visitBreak(Break* curr) {
o << int8_t(curr->condition ? BinaryConsts::BrIf : BinaryConsts::Br)
<< U32LEB(getBreakIndex(curr->name));
+
+ // See comment on |brIfsNeedingHandling| for the extra casts we need to emit
+ // here for certain br_ifs.
+ auto iter = brIfsNeedingHandling.find(curr);
+ if (iter != brIfsNeedingHandling.end()) {
+ auto unrefinedType = iter->second;
+ auto type = curr->type;
+ assert(type.size() == unrefinedType.size());
+
+ assert(curr->type.hasRef());
+
+ auto emitCast = [&](Type to) {
+ // Shim a tiny bit of IR, just enough to get visitRefCast to see what we
+ // are casting, and to emit the proper thing.
+ RefCast cast;
+ cast.type = to;
+ cast.ref = nullptr;
+ visitRefCast(&cast);
+ };
+
+ if (!type.isTuple()) {
+ // Simple: Just emit a cast, and then the type matches Binaryen IR's.
+ emitCast(type);
+ } else {
+ // Tuples are trickier to handle, and we need to use scratch locals. Stash
+ // all the values on the stack to those locals, then reload them, casting
+ // as we go.
+ //
+ // We must track how many scratch locals we've used from each type as we
+ // go, as a type might appear multiple times in the tuple. We allocated
+ // enough for each, in a contiguous range, so we just increment as we go.
+ std::unordered_map<Type, Index> scratchTypeUses;
+ for (Index i = 0; i < unrefinedType.size(); i++) {
+ auto t = unrefinedType[unrefinedType.size() - i - 1];
+ assert(scratchLocals.find(t) != scratchLocals.end());
+ auto localIndex = scratchLocals[t] + scratchTypeUses[t]++;
+ o << int8_t(BinaryConsts::LocalSet) << U32LEB(localIndex);
+ }
+ for (Index i = 0; i < unrefinedType.size(); i++) {
+ auto t = unrefinedType[i];
+ auto localIndex = scratchLocals[t] + --scratchTypeUses[t];
+ o << int8_t(BinaryConsts::LocalGet) << U32LEB(localIndex);
+ if (t.isRef()) {
+ // Note that we cast all types here, when perhaps only some of the
+ // tuple's lanes need that. This is simpler.
+ emitCast(type[i]);
+ }
+ }
+ }
+ }
}
void BinaryInstWriter::visitSwitch(Switch* curr) {
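The tuple branch of visitBreak above depends on stack ordering: local.set always consumes the value on top of the stack, so the lanes are stashed into scratch locals back to front and then reloaded front to back, with a ref.cast emitted as each reference-typed lane is reloaded. The toy program below is not Binaryen code and ignores the per-type scratchLocals/scratchTypeUses bookkeeping (it simply gives every lane its own slot); it only demonstrates that stashing in reverse and reloading in order round-trips the lanes unchanged:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // The tuple's lanes as they sit on the value stack, last lane on top.
  std::vector<std::string> stack = {"lane0", "lane1", "lane2"};
  std::vector<std::string> scratch(3);

  // Emulate the local.set loop: process lanes size-1 .. 0, popping the top.
  for (size_t i = 0; i < 3; i++) {
    size_t lane = 3 - i - 1;
    scratch[lane] = stack.back();
    stack.pop_back();
  }

  // Emulate the local.get loop: process lanes 0 .. size-1, pushing them back
  // (this is where the real code emits a ref.cast for reference-typed lanes).
  for (size_t i = 0; i < 3; i++) {
    stack.push_back(scratch[i]);
  }

  // The lanes are back on the stack in their original order.
  assert((stack == std::vector<std::string>{"lane0", "lane1", "lane2"}));
  return 0;
}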
@@ -2664,11 +2714,116 @@ InsertOrderedMap<Type, Index> BinaryInstWriter::countScratchLocals() {
auto& count = scratches[Type::i32];
count = std::max(count, numScratches);
}
- };
- ScratchLocalFinder finder(*this);
+ // As mentioned in BinaryInstWriter::visitBreak, the type of br_if with a
+ // value may be more refined in Binaryen IR compared to the wasm spec, as we
+ // give it the type of the value, while the spec gives it the type of the
+ // block it targets. To avoid problems we must handle the case where a br_if
+ // has a value, the value is more refined than the target, and the value is
+ // not dropped (the last condition is very rare in real-world wasm, making
+ // all of this a quite unusual situation). First, detect such situations by
+ // seeing if we have br_ifs that return reference types at all. We do so by
+ // counting them, and as we go we ignore ones that are dropped, since a
+ // dropped value is not a problem for us.
+ //
+ // Note that we do not check all the conditions here, such as whether the
+ // type matches the break target or whether the parent is a cast; we leave
+ // those for a more expensive analysis later, which we only run if we see
+ // something suspicious here.
+ Index numDangerousBrIfs = 0;
+
+ void visitBreak(Break* curr) {
+ if (curr->type.hasRef()) {
+ numDangerousBrIfs++;
+ }
+ }
+
+ void visitDrop(Drop* curr) {
+ if (curr->value->is<Break>() && curr->value->type.hasRef()) {
+ // The value is exactly a br_if of a ref, that we just visited before
+ // us. Undo the ++ from there as it can be ignored.
+ assert(numDangerousBrIfs > 0);
+ numDangerousBrIfs--;
+ }
+ }
+ } finder(*this);
finder.walk(func->body);
+ if (!finder.numDangerousBrIfs || !parent.getModule()->features.hasGC()) {
+ // Nothing more to do: either no such br_ifs, or GC is not enabled.
+ //
+ // The explicit check for GC is here because if only reference types are
+ // enabled then we still may seem to need a fixup here, e.g. if a ref.func
+ // is br_if'd to a block of type funcref. But that only appears that way
+ // because in Binaryen IR we allow non-nullable types even without GC (and
+ // if GC is not enabled then we always emit nullable types in the binary).
+ // That is, even if we see a type difference without GC, it will vanish in
+ // the binary format; there is never a need to add any ref.casts without GC
+ // being enabled.
+ return std::move(finder.scratches);
+ }
+
+ // There are dangerous-looking br_ifs, so we must do the harder work to
+ // actually investigate them in detail, including tracking block types. By
+ // being fully precise here, we'll only emit casts when absolutely necessary,
+ // which avoids repeated roundtrips adding more and more code.
+ struct RefinementScanner : public ExpressionStackWalker<RefinementScanner> {
+ BinaryInstWriter& writer;
+ ScratchLocalFinder& finder;
+
+ RefinementScanner(BinaryInstWriter& writer, ScratchLocalFinder& finder)
+ : writer(writer), finder(finder) {}
+
+ void visitBreak(Break* curr) {
+ // See if this is one of the dangerous br_ifs we must handle.
+ if (!curr->type.hasRef()) {
+ // Not even a reference.
+ return;
+ }
+ auto* parent = getParent();
+ if (parent) {
+ if (parent->is<Drop>()) {
+ // It is dropped anyhow.
+ return;
+ }
+ if (auto* cast = parent->dynCast<RefCast>()) {
+ if (Type::isSubType(cast->type, curr->type)) {
+ // It is cast to the same type or a better one. In particular this
+ // handles the case of repeated roundtripping: After the first
+ // roundtrip we emit a cast that we'll identify here, and not emit
+ // an additional one.
+ return;
+ }
+ }
+ }
+ auto* breakTarget = findBreakTarget(curr->name);
+ auto unrefinedType = breakTarget->type;
+ if (unrefinedType == curr->type) {
+ // It has the proper type anyhow.
+ return;
+ }
+
+ // Mark the br_if as needing handling, and add the type to the set of
+ // types we need scratch tuple locals for (if relevant).
+ writer.brIfsNeedingHandling[curr] = unrefinedType;
+
+ if (unrefinedType.isTuple()) {
+ // We must allocate enough scratch locals for this tuple. Note that we
+ // may need more than one per type in the tuple, if a type appears more
+ // than once, so we count their appearances.
+ InsertOrderedMap<Type, Index> scratchTypeUses;
+ for (auto t : unrefinedType) {
+ scratchTypeUses[t]++;
+ }
+ for (auto& [type, uses] : scratchTypeUses) {
+ auto& count = finder.scratches[type];
+ count = std::max(count, uses);
+ }
+ }
+ }
+ } refinementScanner(*this, finder);
+ refinementScanner.walk(func->body);
+
return std::move(finder.scratches);
}
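
RefinementScanner above sizes the scratch locals per type: within a single tuple a type can occur more than once, so each occurrence needs its own local, but across different br_ifs the locals are reused, so only the per-type maximum is kept (merged into finder.scratches with std::max). The standalone example below uses made-up tuple contents and plain std::map/std::string in place of InsertOrderedMap and wasm::Type; it illustrates the same counting rule:

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  using Tuple = std::vector<std::string>;
  // Hypothetical unrefined types of two "dangerous" br_ifs.
  std::vector<Tuple> dangerousBrIfTuples = {
    {"i32", "funcref", "i32"}, // needs 2 x i32 and 1 x funcref at once
    {"i32", "externref"},      // needs 1 x i32 and 1 x externref at once
  };

  std::map<std::string, int> scratches; // scratch locals needed per type
  for (auto& tuple : dangerousBrIfTuples) {
    // Count how many times each type appears within this one tuple.
    std::map<std::string, int> uses;
    for (auto& t : tuple) {
      uses[t]++;
    }
    // Keep the maximum per type across all br_ifs, since locals are reused.
    for (auto& [type, count] : uses) {
      scratches[type] = std::max(scratches[type], count);
    }
  }

  for (auto& [type, count] : scratches) {
    std::printf("%s: %d scratch local(s)\n", type.c_str(), count);
  }
  // Prints: externref: 1, funcref: 1, i32: 2
  return 0;
}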