diff options
author | Thomas Lively <tlively@google.com> | 2024-04-15 14:02:24 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-15 14:02:24 -0700 |
commit | b1245577ba92b77a97e266cf4c7f7cd15e6e7f28 (patch) | |
tree | 333e17f651e6ed9d24fa13aa86f38fcc907541cf /src | |
parent | 8c834e8257b03ea87b639ddac9adefec64fcad00 (diff) | |
download | binaryen-b1245577ba92b77a97e266cf4c7f7cd15e6e7f28.tar.gz binaryen-b1245577ba92b77a97e266cf4c7f7cd15e6e7f28.tar.bz2 binaryen-b1245577ba92b77a97e266cf4c7f7cd15e6e7f28.zip |
[Strings] Add a string lowering pass using magic imports (#6497)
The latest idea for efficient string constants is to encode the constants in
the import names of their globals and implement fast paths in the engines for
materializing those constants at instantiation time without needing to parse
anything in JS. This strategy only works for valid strings (i.e. strings without
unpaired surrogates) because only valid strings can be used as import names in
the WebAssembly syntax.
Add a new configuration of the StringLowering pass that encodes valid string
contents in import names, falling back to the JSON custom section approach for
invalid strings.
To test this change, update the printer to escape import and export names
properly and update the legacy parser to parse escapes in import and export
names properly. As a drive-by, remove the incorrect check in the parser that the
import module and base names are non-empty.
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/Print.cpp | 13 | ||||
-rw-r--r-- | src/passes/StringLowering.cpp | 35 | ||||
-rw-r--r-- | src/passes/pass.cpp | 4 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/pretty_printing.h | 18 | ||||
-rw-r--r-- | src/support/string.cpp | 41 | ||||
-rw-r--r-- | src/support/string.h | 5 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 19 |
8 files changed, 96 insertions, 40 deletions
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 68a2e4cb6..a90ef4669 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2838,8 +2838,9 @@ void PrintSExpression::handleSignature(HeapType curr, Name name) { void PrintSExpression::visitExport(Export* curr) { o << '('; printMedium(o, "export "); - // TODO: Escape the string properly. - printText(o, curr->name.str.data()) << " ("; + std::stringstream escaped; + String::printEscaped(escaped, curr->name.str); + printText(o, escaped.str(), false) << " ("; switch (curr->kind) { case ExternalKind::Function: o << "func"; @@ -2865,9 +2866,11 @@ void PrintSExpression::visitExport(Export* curr) { void PrintSExpression::emitImportHeader(Importable* curr) { printMedium(o, "import "); - // TODO: Escape the strings properly and use std::string_view. - printText(o, curr->module.str.data()) << ' '; - printText(o, curr->base.str.data()) << ' '; + std::stringstream escapedModule, escapedBase; + String::printEscaped(escapedModule, curr->module.str); + String::printEscaped(escapedBase, curr->base.str); + printText(o, escapedModule.str(), false) << ' '; + printText(o, escapedBase.str(), false) << ' '; } void PrintSExpression::visitGlobal(Global* curr) { diff --git a/src/passes/StringLowering.cpp b/src/passes/StringLowering.cpp index df2d66860..dd7428546 100644 --- a/src/passes/StringLowering.cpp +++ b/src/passes/StringLowering.cpp @@ -189,6 +189,13 @@ struct StringGathering : public Pass { }; struct StringLowering : public StringGathering { + // If true, then encode well-formed strings as (import "'" "string...") + // instead of emitting them into the JSON custom section. 
+ bool useMagicImports; + + StringLowering(bool useMagicImports = false) + : useMagicImports(useMagicImports) {} + void run(Module* module) override { if (!module->features.has(FeatureSet::Strings)) { return; @@ -217,25 +224,30 @@ struct StringLowering : public StringGathering { } void makeImports(Module* module) { - Index importIndex = 0; + Index jsonImportIndex = 0; std::stringstream json; json << '['; bool first = true; - std::vector<Name> importedStrings; for (auto& global : module->globals) { if (global->init) { if (auto* c = global->init->dynCast<StringConst>()) { - global->module = "string.const"; - global->base = std::to_string(importIndex); - importIndex++; - global->init = nullptr; - - if (first) { - first = false; + std::stringstream utf8; + if (useMagicImports && + String::convertUTF16ToUTF8(utf8, c->string.str)) { + global->module = "'"; + global->base = Name(utf8.str()); } else { - json << ','; + global->module = "string.const"; + global->base = std::to_string(jsonImportIndex); + if (first) { + first = false; + } else { + json << ','; + } + String::printEscapedJSON(json, c->string.str); + jsonImportIndex++; } - String::printEscapedJSON(json, c->string.str); + global->init = nullptr; } } } @@ -516,5 +528,6 @@ struct StringLowering : public StringGathering { Pass* createStringGatheringPass() { return new StringGathering(); } Pass* createStringLoweringPass() { return new StringLowering(); } +Pass* createStringLoweringMagicImportPass() { return new StringLowering(true); } } // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 0955082ac..19ddaf2d4 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -484,6 +484,10 @@ void PassRegistry::registerPasses() { "lowers wasm strings and operations to imports", createStringLoweringPass); registerPass( + "string-lowering-magic-imports", + "same as string-lowering, but encodes well-formed strings as magic imports", + createStringLoweringMagicImportPass); + registerPass( "strip", 
"deprecated; same as strip-debug", createStripDebugPass); registerPass("stack-check", "enforce limits on llvm's __stack_pointer global", diff --git a/src/passes/passes.h b/src/passes/passes.h index 1b1ca99c6..23a9ea70b 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -156,6 +156,7 @@ Pass* createSimplifyLocalsNoTeeNoStructurePass(); Pass* createStackCheckPass(); Pass* createStringGatheringPass(); Pass* createStringLoweringPass(); +Pass* createStringLoweringMagicImportPass(); Pass* createStripDebugPass(); Pass* createStripDWARFPass(); Pass* createStripProducersPass(); diff --git a/src/pretty_printing.h b/src/pretty_printing.h index f693c4d51..0f1a0ed87 100644 --- a/src/pretty_printing.h +++ b/src/pretty_printing.h @@ -51,29 +51,35 @@ inline std::ostream& restoreNormalColor(std::ostream& o) { return o; } -inline std::ostream& printText(std::ostream& o, const char* str) { - o << '"'; +inline std::ostream& +printText(std::ostream& o, std::string_view str, bool needQuotes = true) { + if (needQuotes) { + o << '"'; + } Colors::green(o); o << str; Colors::normal(o); - return o << '"'; + if (needQuotes) { + o << '"'; + } + return o; } -inline std::ostream& printMajor(std::ostream& o, const char* str) { +inline std::ostream& printMajor(std::ostream& o, std::string_view str) { prepareMajorColor(o); o << str; restoreNormalColor(o); return o; } -inline std::ostream& printMedium(std::ostream& o, const char* str) { +inline std::ostream& printMedium(std::ostream& o, std::string_view str) { prepareColor(o); o << str; restoreNormalColor(o); return o; } -inline std::ostream& printMinor(std::ostream& o, const char* str) { +inline std::ostream& printMinor(std::ostream& o, std::string_view str) { prepareMinorColor(o); o << str; restoreNormalColor(o); diff --git a/src/support/string.cpp b/src/support/string.cpp index 68249f51e..31d0e9170 100644 --- a/src/support/string.cpp +++ b/src/support/string.cpp @@ -213,7 +213,8 @@ std::optional<uint16_t> 
takeWTF16CodeUnit(std::string_view& str) { return u; } -std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str) { +std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str, + bool allowWTF = true) { auto u = takeWTF16CodeUnit(str); if (!u) { return std::nullopt; @@ -228,7 +229,13 @@ std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str) { uint16_t highBits = *u - 0xD800; uint16_t lowBits = *low - 0xDC00; return 0x10000 + ((highBits << 10) | lowBits); + } else if (!allowWTF) { + // Unpaired high surrogate. + return std::nullopt; } + } else if (!allowWTF && 0xDC00 <= *u && *u < 0xE000) { + // Unpaired low surrogate. + return std::nullopt; } return *u; @@ -242,6 +249,23 @@ void writeWTF16CodeUnit(std::ostream& os, uint16_t u) { constexpr uint32_t replacementCharacter = 0xFFFD; +bool doConvertWTF16ToWTF8(std::ostream& os, + std::string_view str, + bool allowWTF) { + bool valid = true; + + while (str.size()) { + auto u = takeWTF16CodePoint(str, allowWTF); + if (!u) { + valid = false; + u = replacementCharacter; + } + writeWTF8CodePoint(os, *u); + } + + return valid; +} + } // anonymous namespace std::ostream& writeWTF8CodePoint(std::ostream& os, uint32_t u) { @@ -308,18 +332,11 @@ bool convertWTF8ToWTF16(std::ostream& os, std::string_view str) { } bool convertWTF16ToWTF8(std::ostream& os, std::string_view str) { - bool valid = true; - - while (str.size()) { - auto u = takeWTF16CodePoint(str); - if (!u) { - valid = false; - u = replacementCharacter; - } - writeWTF8CodePoint(os, *u); - } + return doConvertWTF16ToWTF8(os, str, true); +} - return valid; +bool convertUTF16ToUTF8(std::ostream& os, std::string_view str) { + return doConvertWTF16ToWTF8(os, str, false); } std::ostream& printEscapedJSON(std::ostream& os, std::string_view str) { diff --git a/src/support/string.h b/src/support/string.h index be2c3c6a3..af120ab4e 100644 --- a/src/support/string.h +++ b/src/support/string.h @@ -94,6 +94,11 @@ bool convertWTF8ToWTF16(std::ostream& os, 
std::string_view str); // Returns `true` iff the input was valid WTF-16. bool convertWTF16ToWTF8(std::ostream& os, std::string_view str); +// Writes the UTF-8 encoding of the given UTF-16LE string to `os`, inserting a +// replacement character in place of any unpaired surrogate or incomplete code +// unit. Returns `true` if the input was valid UTF-16. +bool convertUTF16ToUTF8(std::ostream& os, std::string_view str); + } // namespace wasm::String #endif // wasm_support_string_h diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index eb31355c8..bca94a768 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -3622,7 +3622,9 @@ void SExpressionWasmBuilder::parseInnerData(Element& s, void SExpressionWasmBuilder::parseExport(Element& s) { std::unique_ptr<Export> ex = std::make_unique<Export>(); - ex->name = s[1]->str(); + std::vector<char> nameBytes; + stringToBinary(*s[1], s[1]->str().str, nameBytes); + ex->name = std::string(nameBytes.data(), nameBytes.size()); if (s[2]->isList()) { auto& inner = *s[2]; if (elementStartsWith(inner, FUNC)) { @@ -3703,15 +3705,20 @@ void SExpressionWasmBuilder::parseImport(Element& s) { if (!newStyle) { kind = ExternalKind::Function; } - auto module = s[i++]->str(); + std::vector<char> moduleBytes; + stringToBinary(*s[i], s[i]->str().str, moduleBytes); + Name module = std::string(moduleBytes.data(), moduleBytes.size()); + i++; + if (!s[i]->isStr()) { throw SParseException("no name for import", s, *s[i]); } - auto base = s[i]->str(); - if (!module.size() || !base.size()) { - throw SParseException("imports must have module and base", s, *s[i]); - } + + std::vector<char> baseBytes; + stringToBinary(*s[i], s[i]->str().str, baseBytes); + Name base = std::string(baseBytes.data(), baseBytes.size()); i++; + // parse internals Element& inner = newStyle ? *s[3] : s; Index j = newStyle ? newStyleInner : i; |