From 56ce1eaba7f500b572bcfe06e3248372e9672322 Mon Sep 17 00:00:00 2001 From: Thomas Lively Date: Thu, 14 Sep 2023 16:15:26 -0700 Subject: Optimize tuple.extract of gets in BinaryInstWriter (#5941) In general, the binary lowering of tuple.extract expects that all the tuple values are on top of the stack, so it inserts drops and possibly uses a scratch local to ensure only the extracted value is left. However, when the extracted tuple expression is a local.get, local.tee, or global.get, it's much more efficient to change the lowering of the get or tee to ensure that only the extracted value is on the stack to begin with. Implement that optimization in the binary writer. --- src/wasm-stack.h | 5 +++++ src/wasm/wasm-stack.cpp | 43 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/wasm-stack.h b/src/wasm-stack.h index 03ec4eef5..1f66212ad 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -148,6 +148,11 @@ private: InsertOrderedMap scratchLocals; void countScratchLocals(); void setScratchLocals(); + + // local.get, local.tee, and glboal.get expressions that will be followed by + // tuple.extracts. We can optimize these by getting only the local for the + // extracted index. + std::unordered_map extractedGets; }; // Takes binaryen IR and converts it to something else (binary or stack IR) diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 1ddf69d41..c3d53f126 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -87,6 +87,13 @@ void BinaryInstWriter::visitCallIndirect(CallIndirect* curr) { } void BinaryInstWriter::visitLocalGet(LocalGet* curr) { + if (auto it = extractedGets.find(curr); it != extractedGets.end()) { + // We have a tuple of locals to get, but we will only end up using one of + // them, so we can just emit that one. + o << int8_t(BinaryConsts::LocalGet) + << U32LEB(mappedLocals[std::make_pair(curr->index, it->second)]); + return; + } size_t numValues = func->getLocalType(curr->index).size(); for (Index i = 0; i < numValues; ++i) { o << int8_t(BinaryConsts::LocalGet) @@ -96,14 +103,28 @@ void BinaryInstWriter::visitLocalGet(LocalGet* curr) { void BinaryInstWriter::visitLocalSet(LocalSet* curr) { size_t numValues = func->getLocalType(curr->index).size(); + // If this is a tuple, set all the elements with nonzero index. for (Index i = numValues - 1; i >= 1; --i) { o << int8_t(BinaryConsts::LocalSet) << U32LEB(mappedLocals[std::make_pair(curr->index, i)]); } if (!curr->isTee()) { + // This is not a tee, so just finish setting the values. o << int8_t(BinaryConsts::LocalSet) << U32LEB(mappedLocals[std::make_pair(curr->index, 0)]); + } else if (auto it = extractedGets.find(curr); it != extractedGets.end()) { + // We only need to get the single extracted value. + if (it->second == 0) { + o << int8_t(BinaryConsts::LocalTee) + << U32LEB(mappedLocals[std::make_pair(curr->index, 0)]); + } else { + o << int8_t(BinaryConsts::LocalSet) + << U32LEB(mappedLocals[std::make_pair(curr->index, 0)]); + o << int8_t(BinaryConsts::LocalGet) + << U32LEB(mappedLocals[std::make_pair(curr->index, it->second)]); + } } else { + // We need to get all the values. o << int8_t(BinaryConsts::LocalTee) << U32LEB(mappedLocals[std::make_pair(curr->index, 0)]); for (Index i = 1; i < numValues; ++i) { @@ -114,8 +135,14 @@ void BinaryInstWriter::visitLocalSet(LocalSet* curr) { } void BinaryInstWriter::visitGlobalGet(GlobalGet* curr) { - // Emit a global.get for each element if this is a tuple global Index index = parent.getGlobalIndex(curr->name); + if (auto it = extractedGets.find(curr); it != extractedGets.end()) { + // We have a tuple of globals to get, but we will only end up using one of + // them, so we can just emit that one. + o << int8_t(BinaryConsts::GlobalGet) << U32LEB(index + it->second); + return; + } + // Emit a global.get for each element if this is a tuple global size_t numValues = curr->type.size(); for (Index i = 0; i < numValues; ++i) { o << int8_t(BinaryConsts::GlobalGet) << U32LEB(index + i); @@ -1970,6 +1997,10 @@ void BinaryInstWriter::visitTupleMake(TupleMake* curr) { } void BinaryInstWriter::visitTupleExtract(TupleExtract* curr) { + if (extractedGets.count(curr->tuple)) { + // We already have just the extracted value on the stack. + return; + } size_t numVals = curr->tuple->type.size(); // Drop all values after the one we want for (size_t i = curr->index + 1; i < numVals; ++i) { @@ -2506,6 +2537,7 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { } } setScratchLocals(); + o << U32LEB(numLocalsByType.size()); for (auto& localType : localTypes) { o << U32LEB(numLocalsByType.at(localType)); @@ -2532,6 +2564,15 @@ void BinaryInstWriter::countScratchLocals() { for (auto& [type, _] : scratchLocals) { noteLocalType(type); } + // While we have all the tuple.extracts, also find extracts of local.gets, + // local.tees, and global.gets that we can optimize. + for (auto* extract : extracts.list) { + auto* tuple = extract->tuple; + if (tuple->is() || tuple->is() || + tuple->is()) { + extractedGets.insert({tuple, extract->index}); + } + } } void BinaryInstWriter::setScratchLocals() { -- cgit v1.2.3