diff options
-rw-r--r-- | CHANGELOG.md | 1 | ||||
-rwxr-xr-x | scripts/fuzz_opt.py | 3 | ||||
-rw-r--r-- | src/binaryen-c.cpp | 8 | ||||
-rw-r--r-- | src/literal.h | 2 | ||||
-rw-r--r-- | src/parser/contexts.h | 9 | ||||
-rw-r--r-- | src/parser/lexer.cpp | 21 | ||||
-rw-r--r-- | src/passes/Print.cpp | 8 | ||||
-rw-r--r-- | src/passes/StringLowering.cpp | 8 | ||||
-rw-r--r-- | src/support/json.cpp | 7 | ||||
-rw-r--r-- | src/support/string.cpp | 231 | ||||
-rw-r--r-- | src/support/string.h | 19 | ||||
-rw-r--r-- | src/tools/fuzzing/fuzzing.cpp | 10 | ||||
-rw-r--r-- | src/wasm-builder.h | 13 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 53 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 23 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 16 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 8 | ||||
-rw-r--r-- | test/lit/passes/precompute-strings.wast | 84 | ||||
-rw-r--r-- | test/lit/passes/string-lowering.wast | 18 | ||||
-rw-r--r-- | test/lit/strings.wast | 4 |
20 files changed, 339 insertions, 207 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index ba9adda87..c4b81c710 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Current Trunk - Add a new `BinaryenModuleReadWithFeatures` function to the C API that allows to configure which features to enable in the parser. - The build-time option to use legacy WasmGC opcodes is removed. + - The strings in `string.const` instructions must now be valid WTF-8. v117 ---- diff --git a/scripts/fuzz_opt.py b/scripts/fuzz_opt.py index 9831eb467..686895790 100755 --- a/scripts/fuzz_opt.py +++ b/scripts/fuzz_opt.py @@ -333,9 +333,6 @@ INITIAL_CONTENTS_IGNORE = [ 'exception-handling.wast', 'translate-to-new-eh.wast', 'rse-eh.wast', - # Non-UTF8 strings trap in V8, and have limitations in our interpreter - 'string-lowering.wast', - 'precompute-strings.wast', ] diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 29e0597d7..402dce553 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -26,6 +26,7 @@ #include "pass.h" #include "shell-interface.h" #include "support/colors.h" +#include "support/string.h" #include "wasm-binary.h" #include "wasm-builder.h" #include "wasm-interpreter.h" @@ -1895,8 +1896,13 @@ BinaryenExpressionRef BinaryenStringNew(BinaryenModuleRef module, } BinaryenExpressionRef BinaryenStringConst(BinaryenModuleRef module, const char* name) { + // Re-encode from WTF-8 to WTF-16. + std::stringstream wtf16; + [[maybe_unused]] bool valid = String::convertWTF8ToWTF16(wtf16, name); + assert(valid); + // TODO: Use wtf16.view() once we have C++20. return static_cast<Expression*>( - Builder(*(Module*)module).makeStringConst(name)); + Builder(*(Module*)module).makeStringConst(wtf16.str())); } BinaryenExpressionRef BinaryenStringMeasure(BinaryenModuleRef module, BinaryenOp op, diff --git a/src/literal.h b/src/literal.h index 971cc8b3e..a4017f6ec 100644 --- a/src/literal.h +++ b/src/literal.h @@ -85,7 +85,7 @@ public: assert(type.isSignature()); } explicit Literal(std::shared_ptr<GCData> gcData, HeapType type); - explicit Literal(std::string string); + explicit Literal(std::string_view string); Literal(const Literal& other); Literal& operator=(const Literal& other); ~Literal(); diff --git a/src/parser/contexts.h b/src/parser/contexts.h index 8b59ab40b..0979461a0 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h @@ -22,6 +22,7 @@ #include "lexer.h" #include "support/name.h" #include "support/result.h" +#include "support/string.h" #include "wasm-builder.h" #include "wasm-ir-builder.h" #include "wasm.h" @@ -2491,7 +2492,13 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> { Result<> makeStringConst(Index pos, const std::vector<Annotation>& annotations, std::string_view str) { - return withLoc(pos, irBuilder.makeStringConst(Name(str))); + // Re-encode from WTF-8 to WTF-16. + std::stringstream wtf16; + if (!String::convertWTF8ToWTF16(wtf16, str)) { + return in.err(pos, "invalid string constant"); + } + // TODO: Use wtf16.view() once we have C++20. + return withLoc(pos, irBuilder.makeStringConst(wtf16.str())); } Result<> makeStringMeasure(Index pos, diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp index 8c7542dd7..48da163e1 100644 --- a/src/parser/lexer.cpp +++ b/src/parser/lexer.cpp @@ -23,6 +23,7 @@ #include <variant> #include "lexer.h" +#include "support/string.h" using namespace std::string_view_literals; @@ -308,25 +309,7 @@ public: if ((0xd800 <= u && u < 0xe000) || 0x110000 <= u) { return false; } - if (u < 0x80) { - // 0xxxxxxx - *escapeBuilder << uint8_t(u); - } else if (u < 0x800) { - // 110xxxxx 10xxxxxx - *escapeBuilder << uint8_t(0b11000000 | ((u >> 6) & 0b00011111)); - *escapeBuilder << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); - } else if (u < 0x10000) { - // 1110xxxx 10xxxxxx 10xxxxxx - *escapeBuilder << uint8_t(0b11100000 | ((u >> 12) & 0b00001111)); - *escapeBuilder << uint8_t(0b10000000 | ((u >> 6) & 0b00111111)); - *escapeBuilder << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); - } else { - // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - *escapeBuilder << uint8_t(0b11110000 | ((u >> 18) & 0b00000111)); - *escapeBuilder << uint8_t(0b10000000 | ((u >> 12) & 0b00111111)); - *escapeBuilder << uint8_t(0b10000000 | ((u >> 6) & 0b00111111)); - *escapeBuilder << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); - } + String::writeWTF8CodePoint(*escapeBuilder, u); return true; } }; diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 643f1cc3f..80047a281 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2232,7 +2232,13 @@ struct PrintExpressionContents } void visitStringConst(StringConst* curr) { printMedium(o, "string.const "); - String::printEscaped(o, curr->string.str); + // Re-encode from WTF-16 to WTF-8. + std::stringstream wtf8; + [[maybe_unused]] bool valid = + String::convertWTF16ToWTF8(wtf8, curr->string.str); + assert(valid); + // TODO: Use wtf8.view() once we have C++20. + String::printEscaped(o, wtf8.str()); } void visitStringMeasure(StringMeasure* curr) { switch (curr->op) { diff --git a/src/passes/StringLowering.cpp b/src/passes/StringLowering.cpp index e0d3fbad0..322f0deb2 100644 --- a/src/passes/StringLowering.cpp +++ b/src/passes/StringLowering.cpp @@ -147,8 +147,14 @@ struct StringGathering : public Pass { } auto& string = strings[i]; + // Re-encode from WTF-16 to WTF-8 to make the name easier to read. + std::stringstream wtf8; + [[maybe_unused]] bool valid = + String::convertWTF16ToWTF8(wtf8, string.str); + assert(valid); + // TODO: Use wtf8.view() once we have C++20. auto name = Names::getValidGlobalName( - *module, std::string("string.const_") + std::string(string.str)); + *module, std::string("string.const_") + std::string(wtf8.str())); globalName = name; newNames.insert(name); auto* stringConst = builder.makeStringConst(string); diff --git a/src/support/json.cpp b/src/support/json.cpp index ab55cc75f..dd94719d4 100644 --- a/src/support/json.cpp +++ b/src/support/json.cpp @@ -21,7 +21,12 @@ namespace json { void Value::stringify(std::ostream& os, bool pretty) { if (isString()) { - wasm::String::printEscapedJSON(os, getCString()); + std::stringstream wtf16; + [[maybe_unused]] bool valid = + wasm::String::convertWTF8ToWTF16(wtf16, getIString().str); + assert(valid); + // TODO: Use wtf16.view() once we have C++20. + wasm::String::printEscapedJSON(os, wtf16.str()); } else if (isArray()) { os << '['; auto first = true; diff --git a/src/support/string.cpp b/src/support/string.cpp index c3a9ce4e4..68249f51e 100644 --- a/src/support/string.cpp +++ b/src/support/string.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include <optional> #include <ostream> #include "support/string.h" @@ -106,7 +107,7 @@ std::string trim(const std::string& input) { return input.substr(0, size); } -std::ostream& printEscaped(std::ostream& os, const std::string_view str) { +std::ostream& printEscaped(std::ostream& os, std::string_view str) { os << '"'; for (unsigned char c : str) { switch (c) { @@ -140,67 +141,193 @@ std::ostream& printEscaped(std::ostream& os, const std::string_view str) { return os << '"'; } -std::ostream& printEscapedJSON(std::ostream& os, const std::string_view str) { - os << '"'; - constexpr uint32_t replacementCharacter = 0xFFFD; - bool lastWasLeadingSurrogate = false; - for (size_t i = 0; i < str.size();) { - // Decode from WTF-8 into a unicode code point. - uint8_t leading = str[i]; - size_t trailingBytes; - uint32_t u; - if ((leading & 0b10000000) == 0b00000000) { - // 0xxxxxxx - trailingBytes = 0; - u = leading; - } else if ((leading & 0b11100000) == 0b11000000) { - // 110xxxxx 10xxxxxx - trailingBytes = 1; - u = (leading & 0b00011111) << 6; - } else if ((leading & 0b11110000) == 0b11100000) { - // 1110xxxx 10xxxxxx 10xxxxxx - trailingBytes = 2; - u = (leading & 0b00001111) << 12; - } else if ((leading & 0b11111000) == 0b11110000) { - // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - trailingBytes = 3; - u = (leading & 0b00000111) << 18; - } else { - std::cerr << "warning: Bad WTF-8 leading byte (" << std::hex - << int(leading) << std::dec << "). Replacing.\n"; - trailingBytes = 0; - u = replacementCharacter; - } +namespace { - ++i; +std::optional<uint32_t> takeWTF8CodePoint(std::string_view& str) { + bool valid = true; - if (i + trailingBytes > str.size()) { - std::cerr << "warning: Unexpected end of string. Replacing.\n"; - u = replacementCharacter; - } else { - for (size_t j = 0; j < trailingBytes; ++j) { - uint8_t trailing = str[i + j]; - if ((trailing & 0b11000000) != 0b10000000) { - std::cerr << "warning: Bad WTF-8 trailing byte (" << std::hex - << int(trailing) << std::dec << "). Replacing.\n"; - u = replacementCharacter; - break; - } - // Shift 6 bits for every remaining trailing byte after this one. - u |= (trailing & 0b00111111) << (6 * (trailingBytes - j - 1)); + if (str.size() == 0) { + return std::nullopt; + } + + uint8_t leading = str[0]; + size_t trailingBytes; + uint32_t u; + if ((leading & 0b10000000) == 0b00000000) { + // 0xxxxxxx + trailingBytes = 0; + u = leading; + } else if ((leading & 0b11100000) == 0b11000000) { + // 110xxxxx 10xxxxxx + trailingBytes = 1; + u = (leading & 0b00011111) << 6; + } else if ((leading & 0b11110000) == 0b11100000) { + // 1110xxxx 10xxxxxx 10xxxxxx + trailingBytes = 2; + u = (leading & 0b00001111) << 12; + } else if ((leading & 0b11111000) == 0b11110000) { + // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + trailingBytes = 3; + u = (leading & 0b00000111) << 18; + } else { + // Bad WTF-8 leading byte. + trailingBytes = 0; + valid = false; + } + + if (str.size() <= trailingBytes) { + // Unexpected end of string. + str = str.substr(str.size()); + return std::nullopt; + } + + if (valid) { + for (size_t j = 0; j < trailingBytes; ++j) { + uint8_t trailing = str[1 + j]; + if ((trailing & 0b11000000) != 0b10000000) { + // Bad WTF-8 trailing byte. + valid = false; + break; } + // Shift 6 bits for every remaining trailing byte after this one. + u |= (trailing & 0b00111111) << (6 * (trailingBytes - j - 1)); } + } + + str = str.substr(1 + trailingBytes); + if (!valid) { + return std::nullopt; + } + return u; +} + +std::optional<uint16_t> takeWTF16CodeUnit(std::string_view& str) { + if (str.size() < 2) { + str = str.substr(str.size()); + return std::nullopt; + } + + // Use a little-endian encoding. + uint16_t u = uint8_t(str[0]) | (uint8_t(str[1]) << 8); + str = str.substr(2); + return u; +} + +std::optional<uint32_t> takeWTF16CodePoint(std::string_view& str) { + auto u = takeWTF16CodeUnit(str); + if (!u) { + return std::nullopt; + } + + if (0xD800 <= *u && *u < 0xDC00) { + // High surrogate; take the next low surrogate if it exists. + auto next = str; + auto low = takeWTF16CodeUnit(next); + if (low && 0xDC00 <= *low && *low < 0xE000) { + str = next; + uint16_t highBits = *u - 0xD800; + uint16_t lowBits = *low - 0xDC00; + return 0x10000 + ((highBits << 10) | lowBits); + } + } + + return *u; +} + +void writeWTF16CodeUnit(std::ostream& os, uint16_t u) { + // Little-endian encoding. + os << uint8_t(u & 0xFF); + os << uint8_t(u >> 8); +} + +constexpr uint32_t replacementCharacter = 0xFFFD; + +} // anonymous namespace + +std::ostream& writeWTF8CodePoint(std::ostream& os, uint32_t u) { + assert(u < 0x110000); + if (u < 0x80) { + // 0xxxxxxx + os << uint8_t(u); + } else if (u < 0x800) { + // 110xxxxx 10xxxxxx + os << uint8_t(0b11000000 | ((u >> 6) & 0b00011111)); + os << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); + } else if (u < 0x10000) { + // 1110xxxx 10xxxxxx 10xxxxxx + os << uint8_t(0b11100000 | ((u >> 12) & 0b00001111)); + os << uint8_t(0b10000000 | ((u >> 6) & 0b00111111)); + os << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); + } else { + // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + os << uint8_t(0b11110000 | ((u >> 18) & 0b00000111)); + os << uint8_t(0b10000000 | ((u >> 12) & 0b00111111)); + os << uint8_t(0b10000000 | ((u >> 6) & 0b00111111)); + os << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); + } + return os; +} + +std::ostream& writeWTF16CodePoint(std::ostream& os, uint32_t u) { + assert(u < 0x110000); + if (u < 0x10000) { + writeWTF16CodeUnit(os, u); + } else { + // Encode with a surrogate pair. + uint16_t high = 0xD800 + ((u - 0x10000) >> 10); + uint16_t low = 0xDC00 + ((u - 0x10000) & 0x3FF); + writeWTF16CodeUnit(os, high); + writeWTF16CodeUnit(os, low); + } + return os; +} + +bool convertWTF8ToWTF16(std::ostream& os, std::string_view str) { + bool valid = true; + bool lastWasLeadingSurrogate = false; - i += trailingBytes; + while (str.size()) { + auto u = takeWTF8CodePoint(str); + if (!u) { + valid = false; + u = replacementCharacter; + } - bool isLeadingSurrogate = 0xD800 <= u && u <= 0xDBFF; - bool isTrailingSurrogate = 0xDC00 <= u && u <= 0xDFFF; + bool isLeadingSurrogate = 0xD800 <= *u && *u < 0xDC00; + bool isTrailingSurrogate = 0xDC00 <= *u && *u < 0xE000; if (lastWasLeadingSurrogate && isTrailingSurrogate) { - std::cerr << "warning: Invalid surrogate sequence in WTF-8.\n"; + // Invalid surrogate sequence. + valid = false; } lastWasLeadingSurrogate = isLeadingSurrogate; - // Encode unicode code point into JSON. + writeWTF16CodePoint(os, *u); + } + + return valid; +} + +bool convertWTF16ToWTF8(std::ostream& os, std::string_view str) { + bool valid = true; + + while (str.size()) { + auto u = takeWTF16CodePoint(str); + if (!u) { + valid = false; + u = replacementCharacter; + } + writeWTF8CodePoint(os, *u); + } + + return valid; +} + +std::ostream& printEscapedJSON(std::ostream& os, std::string_view str) { + os << '"'; + while (str.size()) { + auto u = *takeWTF16CodePoint(str); + + // Use escape sequences mandated by the JSON spec. switch (u) { case '"': os << "\\\""; diff --git a/src/support/string.h b/src/support/string.h index 6fb3f693b..be2c3c6a3 100644 --- a/src/support/string.h +++ b/src/support/string.h @@ -75,9 +75,24 @@ inline bool isNumber(const std::string& str) { return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); } -std::ostream& printEscaped(std::ostream& os, const std::string_view str); +std::ostream& printEscaped(std::ostream& os, std::string_view str); -std::ostream& printEscapedJSON(std::ostream& os, const std::string_view str); +// `str` must be a valid WTF-16 string. +std::ostream& printEscapedJSON(std::ostream& os, std::string_view str); + +std::ostream& writeWTF8CodePoint(std::ostream& os, uint32_t u); + +std::ostream& writeWTF16CodePoint(std::ostream& os, uint32_t u); + +// Writes the WTF-16LE encoding of the given WTF-8 string to `os`, inserting +// replacement characters as necessary when encountering invalid WTF-8. Returns +// `true` iff the input was valid WTF-8. +bool convertWTF8ToWTF16(std::ostream& os, std::string_view str); + +// Writes the WTF-8 encoding of the given WTF-16LE string to `os`, inserting a +// replacement character at the end if the string is an odd number of bytes. +// Returns `true` iff the input was valid WTF-16. +bool convertWTF16ToWTF8(std::ostream& os, std::string_view str); } // namespace wasm::String diff --git a/src/tools/fuzzing/fuzzing.cpp b/src/tools/fuzzing/fuzzing.cpp index 1c4ee4cc5..c62114c3f 100644 --- a/src/tools/fuzzing/fuzzing.cpp +++ b/src/tools/fuzzing/fuzzing.cpp @@ -20,6 +20,7 @@ #include "ir/module-utils.h" #include "ir/subtypes.h" #include "ir/type-updating.h" +#include "support/string.h" #include "tools/fuzzing/heap-types.h" #include "tools/fuzzing/parameters.h" @@ -2465,8 +2466,13 @@ Expression* TranslateToFuzzReader::makeBasicRef(Type type) { } return null; } - case HeapType::string: - return builder.makeStringConst(std::to_string(upTo(1024))); + case HeapType::string: { + auto wtf8 = std::to_string(upTo(1024)); + std::stringstream wtf16; + String::convertWTF8ToWTF16(wtf16, wtf8); + // TODO: Use wtf16.view() once we have C++20. + return builder.makeStringConst(wtf16.str()); + } case HeapType::stringview_wtf8: return builder.makeStringAs( StringAsWTF8, makeBasicRef(Type(HeapType::string, NonNullable))); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 7d4991d9a..cc90a8abe 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -1265,12 +1265,17 @@ public: return makeRefI31(makeConst(value.geti31())); } if (type.isString()) { - // TODO: more than ascii support - std::string string; + // The string is already WTF-16, but we need to convert from `Literals` to + // actual string. + std::stringstream wtf16; for (auto c : value.getGCData()->values) { - string.push_back(c.getInteger()); + auto u = c.getInteger(); + assert(u < 0x10000); + wtf16 << uint8_t(u & 0xFF); + wtf16 << uint8_t(u >> 8); } - return makeStringConst(string); + // TODO: Use wtf16.view() once we have C++20. + return makeStringConst(wtf16.str()); } if (type.isRef() && type.getHeapType() == HeapType::ext) { return makeRefAs(ExternExternalize, diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index c8031f617..f34cb83be 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1900,23 +1900,7 @@ public: return Flow(NONCONSTANT_FLOW); } } - Flow visitStringConst(StringConst* curr) { - return Literal(curr->string.toString()); - } - - // Returns if there is a non-ascii character in a list of values, looking only - // up to an index that is provided (not inclusive). If the index is not - // provided we look in the entire list. - bool hasNonAsciiUpTo(const Literals& values, - std::optional<Index> maybeEnd = std::nullopt) { - Index end = maybeEnd ? *maybeEnd : values.size(); - for (Index i = 0; i < end; ++i) { - if (uint32_t(values[i].geti32()) > 127) { - return true; - } - } - return false; - } + Flow visitStringConst(StringConst* curr) { return Literal(curr->string.str); } Flow visitStringMeasure(StringMeasure* curr) { // For now we only support JS-style strings. @@ -1934,12 +1918,6 @@ public: trap("null ref"); } - // This is only correct if all the bytes stored in `values` correspond to - // single unicode code points. See `visitStringWTF16Get` for details. - if (hasNonAsciiUpTo(data->values)) { - return Flow(NONCONSTANT_FLOW); - } - return Literal(int32_t(data->values.size())); } Flow visitStringConcat(StringConcat* curr) { @@ -1960,18 +1938,13 @@ public: if (!leftData || !rightData) { trap("null ref"); } - // This is only correct if all the bytes in the left operand correspond - // to single unicode code points. - if (hasNonAsciiUpTo(leftData->values)) { - return Flow(NONCONSTANT_FLOW); - } Literals contents; contents.reserve(leftData->values.size() + rightData->values.size()); - for (Literal l : leftData->values) { + for (Literal& l : leftData->values) { contents.push_back(l); } - for (Literal l : rightData->values) { + for (Literal& l : rightData->values) { contents.push_back(l); } @@ -2011,11 +1984,6 @@ public: trap("oob"); } - // We don't handle non-ascii code points correctly yet. - if (hasNonAsciiUpTo(refValues)) { - return Flow(NONCONSTANT_FLOW); - } - for (Index i = 0; i < refValues.size(); i++) { ptrValues[startVal + i] = refValues[i]; } @@ -2132,17 +2100,6 @@ public: trap("string oob"); } - // This naive indexing approach is only correct if the first `i` bytes - // stored in `values` each corresponds to a single unicode code point. To - // implement this correctly in general, we would have to reinterpret the - // bytes as WTF-8, then count up to the `i`th code point, accounting - // properly for code points that would be represented by surrogate pairs in - // WTF-16. Alternatively, we could represent string contents as WTF-16 to - // begin with. - if (hasNonAsciiUpTo(values, i + 1)) { - return Flow(NONCONSTANT_FLOW); - } - return Literal(values[i].geti32()); } Flow visitStringIterNext(StringIterNext* curr) { @@ -2178,9 +2135,7 @@ public: auto startVal = start.getSingleValue().getUnsigned(); auto endVal = end.getSingleValue().getUnsigned(); endVal = std::min<size_t>(endVal, refValues.size()); - if (hasNonAsciiUpTo(refValues, endVal)) { - return Flow(NONCONSTANT_FLOW); - } + Literals contents; if (endVal > startVal) { contents.reserve(endVal - startVal); diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 7c674ffc5..afdc14c72 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -23,6 +23,7 @@ #include "ir/bits.h" #include "pretty_printing.h" #include "support/bits.h" +#include "support/string.h" #include "support/utilities.h" namespace wasm { @@ -77,12 +78,15 @@ Literal::Literal(std::shared_ptr<GCData> gcData, HeapType type) (type.isBottom() && !gcData)); } -Literal::Literal(std::string string) +Literal::Literal(std::string_view string) : gcData(nullptr), type(Type(HeapType::string, NonNullable)) { // TODO: we could in theory internalize strings + // Extract individual WTF-16LE code units. Literals contents; - for (auto c : string) { - contents.push_back(Literal(int32_t(c))); + assert(string.size() % 2 == 0); + for (size_t i = 0; i < string.size(); i += 2) { + int32_t u = uint8_t(string[i]) | (uint8_t(string[i + 1]) << 8); + contents.push_back(Literal(u)); } gcData = std::make_shared<GCData>(HeapType::string, contents); } @@ -636,10 +640,19 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { o << "nullstring"; } else { o << "string(\""; + // Convert WTF-16 literals to WTF-16 string. + std::stringstream wtf16; for (auto c : data->values) { - // TODO: more than ascii - o << char(c.getInteger()); + auto u = c.getInteger(); + assert(u < 0x10000); + wtf16 << uint8_t(u & 0xFF); + wtf16 << uint8_t(u >> 8); } + // Convert to WTF-8 for printing. + // TODO: Use wtf16.view() once we have C++20. + [[maybe_unused]] bool valid = + String::convertWTF16ToWTF8(o, wtf16.str()); + assert(valid); o << "\")"; } break; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 55cdd726c..f9a643d66 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -23,6 +23,7 @@ #include "ir/type-updating.h" #include "support/bits.h" #include "support/debug.h" +#include "support/string.h" #include "wasm-binary.h" #include "wasm-debug.h" #include "wasm-stack.h" @@ -527,7 +528,12 @@ void WasmBinaryWriter::writeStrings() { // The number of strings and then their contents. o << U32LEB(num); for (auto& string : sorted) { - writeInlineString(string.str); + // Re-encode from WTF-16 to WTF-8. + std::stringstream wtf8; + [[maybe_unused]] bool valid = String::convertWTF16ToWTF8(wtf8, string.str); + assert(valid); + // TODO: Use wtf8.view() once we have C++20. + writeInlineString(wtf8.str()); } finishSection(start); @@ -2960,7 +2966,13 @@ void WasmBinaryReader::readStrings() { size_t num = getU32LEB(); for (size_t i = 0; i < num; i++) { auto string = getInlineString(); - strings.push_back(string); + // Re-encode from WTF-8 to WTF-16. + std::stringstream wtf16; + if (!String::convertWTF8ToWTF16(wtf16, string.str)) { + throwError("invalid string constant"); + } + // TODO: Use wtf16.view() once we have C++20. + strings.push_back(wtf16.str()); } } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 2d4f3fffe..d6df7ab0f 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -3335,7 +3335,13 @@ Expression* SExpressionWasmBuilder::makeStringConst(Element& s) { std::vector<char> data; stringToBinary(*s[1], s[1]->str().str, data); Name str = std::string_view(data.data(), data.size()); - return Builder(wasm).makeStringConst(str); + // Re-encode from WTF-8 to WTF-16. + std::stringstream wtf16; + if (!String::convertWTF8ToWTF16(wtf16, str.str)) { + throw SParseException("invalid string constant", s); + } + // TODO: Use wtf16.view() once we have C++20. + return Builder(wasm).makeStringConst(wtf16.str()); } Expression* SExpressionWasmBuilder::makeStringMeasure(Element& s, diff --git a/test/lit/passes/precompute-strings.wast b/test/lit/passes/precompute-strings.wast index 58c7e52d0..6046ee6b1 100644 --- a/test/lit/passes/precompute-strings.wast +++ b/test/lit/passes/precompute-strings.wast @@ -12,13 +12,15 @@ ;; CHECK: (type $3 (func (result (ref any)))) - ;; CHECK: (export "get_codepoint-bad" (func $get_codepoint-bad)) + ;; CHECK: (export "get_codepoint-unicode" (func $get_codepoint-unicode)) + + ;; CHECK: (export "get_codepoint-surrogate" (func $get_codepoint-surrogate)) ;; CHECK: (export "test" (func $encode-stashed)) ;; CHECK: (export "slice" (func $slice)) - ;; CHECK: (export "slice-bad" (func $slice-bad)) + ;; CHECK: (export "slice-unicode" (func $slice-unicode)) ;; CHECK: (func $eq-no (type $0) (result i32) ;; CHECK-NEXT: (i32.const 0) @@ -50,19 +52,14 @@ ) ) - ;; CHECK: (func $concat-bad (type $0) (result i32) - ;; CHECK-NEXT: (string.eq - ;; CHECK-NEXT: (string.concat - ;; CHECK-NEXT: (string.const "a\f0") - ;; CHECK-NEXT: (string.const "b") - ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (string.const "a\f0b") - ;; CHECK-NEXT: ) + ;; CHECK: (func $concat-surrogates (type $0) (result i32) + ;; CHECK-NEXT: (i32.const 1) ;; CHECK-NEXT: ) - (func $concat-bad (result i32) + (func $concat-surrogates (result i32) (string.eq - (string.concat (string.const "a\F0") (string.const "b")) - (string.const "a\F0b") + ;; Concatenating these surrogates creates '𐍈', which has a different UTF-8 encoding. + (string.concat (string.const "\ED\A0\80") (string.const "\ED\BD\88")) + (string.const "\F0\90\8D\88") ) ) @@ -77,18 +74,13 @@ ) ) - ;; CHECK: (func $length-bad (type $0) (result i32) - ;; CHECK-NEXT: (stringview_wtf16.length - ;; CHECK-NEXT: (string.as_wtf16 - ;; CHECK-NEXT: (string.const "$_\c2\a3_\e2\82\ac_\f0\90\8d\88") - ;; CHECK-NEXT: ) - ;; CHECK-NEXT: ) + ;; CHECK: (func $length-unicode (type $0) (result i32) + ;; CHECK-NEXT: (i32.const 8) ;; CHECK-NEXT: ) - (func $length-bad (result i32) - ;; Not precomputable because we don't handle unicode yet. + (func $length-unicode (result i32) (stringview_wtf16.length (string.as_wtf16 - ;; $_£_€_𐍈 + ;; $_£_€_𐍈 (the last character is encoded as a surrogate pair) (string.const "$_\C2\A3_\E2\82\AC_\F0\90\8D\88") ) ) @@ -98,7 +90,7 @@ ;; CHECK-NEXT: (i32.const 95) ;; CHECK-NEXT: ) (func $get_codepoint (result i32) - ;; This is computable because everything up to the requested index is ascii. Returns 95 ('_'). + ;; Returns 95 ('_'). (stringview_wtf16.get_codeunit (string.as_wtf16 ;; $_£_€_𐍈 @@ -108,22 +100,31 @@ ) ) - ;; CHECK: (func $get_codepoint-bad (type $0) (result i32) - ;; CHECK-NEXT: (stringview_wtf16.get_codeunit - ;; CHECK-NEXT: (string.as_wtf16 - ;; CHECK-NEXT: (string.const "$_\c2\a3_\e2\82\ac_\f0\90\8d\88") - ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (i32.const 2) - ;; CHECK-NEXT: ) + ;; CHECK: (func $get_codepoint-unicode (type $0) (result i32) + ;; CHECK-NEXT: (i32.const 8364) ;; CHECK-NEXT: ) - (func $get_codepoint-bad (export "get_codepoint-bad") (result i32) - ;; This is not computable because the requested code unit is not ascii. + (func $get_codepoint-unicode (export "get_codepoint-unicode") (result i32) + ;; Returns 8364 ('€') (stringview_wtf16.get_codeunit (string.as_wtf16 ;; $_£_€_𐍈 (string.const "$_\C2\A3_\E2\82\AC_\F0\90\8D\88") ) - (i32.const 2) + (i32.const 4) + ) + ) + + ;; CHECK: (func $get_codepoint-surrogate (type $0) (result i32) + ;; CHECK-NEXT: (i32.const 55296) + ;; CHECK-NEXT: ) + (func $get_codepoint-surrogate (export "get_codepoint-surrogate") (result i32) + ;; Returns 0xd800 (the high surrogate in '𐍈') + (stringview_wtf16.get_codeunit + (string.as_wtf16 + ;; $_£_€_𐍈 + (string.const "$_\C2\A3_\E2\82\AC_\F0\90\8D\88") + ) + (i32.const 6) ) ) @@ -148,7 +149,7 @@ ) ) - ;; CHECK: (func $encode-bad (type $0) (result i32) + ;; CHECK: (func $encode-unicode (type $0) (result i32) ;; CHECK-NEXT: (string.encode_wtf16_array ;; CHECK-NEXT: (string.const "$_\c2\a3_\e2\82\ac_\f0\90\8d\88") ;; CHECK-NEXT: (array.new_default $array16 @@ -157,7 +158,7 @@ ;; CHECK-NEXT: (i32.const 0) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) - (func $encode-bad (result i32) + (func $encode-unicode (result i32) (string.encode_wtf16_array ;; $_£_€_𐍈 (string.const "$_\C2\A3_\E2\82\AC_\F0\90\8D\88") @@ -220,17 +221,10 @@ ) ) - ;; CHECK: (func $slice-bad (type $2) (result (ref string)) - ;; CHECK-NEXT: (stringview_wtf16.slice - ;; CHECK-NEXT: (string.as_wtf16 - ;; CHECK-NEXT: (string.const "abcd\c2\a3fgh") - ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (i32.const 3) - ;; CHECK-NEXT: (i32.const 6) - ;; CHECK-NEXT: ) + ;; CHECK: (func $slice-unicode (type $2) (result (ref string)) + ;; CHECK-NEXT: (string.const "d\c2\a3f") ;; CHECK-NEXT: ) - (func $slice-bad (export "slice-bad") (result (ref string)) - ;; This slice contains non-ascii, so we do not optimize. + (func $slice-unicode (export "slice-unicode") (result (ref string)) (stringview_wtf16.slice ;; abcd£fgh (string.as_wtf16 diff --git a/test/lit/passes/string-lowering.wast b/test/lit/passes/string-lowering.wast index f7f47871b..c060bc8bd 100644 --- a/test/lit/passes/string-lowering.wast +++ b/test/lit/passes/string-lowering.wast @@ -16,18 +16,6 @@ (drop (string.const "needs\tescaping\00.'#%\"- .\r\n\\08\0C\0A\0D\09.ꙮ") ) - (drop - (string.const "invalid WTF-8 leading byte \FF") - ) - (drop - (string.const "invalid trailing byte \C0\00") - ) - (drop - (string.const "unexpected end \C0") - ) - (drop - (string.const "invalid surrogate sequence \ED\A0\81\ED\B0\B7") - ) ) ) @@ -36,7 +24,7 @@ ;; ;; RUN: wasm-opt %s --string-lowering -all -S -o - | filecheck %s ;; -;; CHECK: custom section "string.consts", size 202, contents: "[\"bar\",\"foo\",\"invalid WTF-8 leading byte \\ufffd\",\"invalid surrogate sequence \\ud801\\udc37\",\"invalid trailing byte \\ufffd\",\"needs\\tescaping\\u0000.'#%\\\"- .\\r\\n\\\\08\\f\\n\\r\\t.\\ua66e\",\"unexpected end \\ufffd\"]" +;; CHECK: custom section "string.consts", size 69, contents: "[\"bar\",\"foo\",\"needs\\tescaping\\u0000.'#%\\\"- .\\r\\n\\\\08\\f\\n\\r\\t.\\ua66e\"]" ;; The custom section should parse OK using JSON.parse from node. ;; (Note we run --remove-unused-module-elements to remove externref-using @@ -45,5 +33,5 @@ ;; RUN: wasm-opt %s --string-lowering --remove-unused-module-elements -all -o %t.wasm ;; RUN: node %S/string-lowering.js %t.wasm | filecheck %s --check-prefix=CHECK-JS ;; -;; CHECK-JS: string: ["bar","foo","invalid WTF-8 leading byte \ufffd","invalid surrogate sequence \ud801\udc37","invalid trailing byte \ufffd","needs\tescaping\x00.'#%\"- .\r\n\\08\f\n\r\t.\ua66e","unexpected end \ufffd"] -;; CHECK-JS: JSON: ["bar","foo","invalid WTF-8 leading byte �","invalid surrogate sequence 𐐷","invalid trailing byte �","needs\tescaping\x00.'#%\"- .\r\n\\08\f\n\r\t.ꙮ","unexpected end �"] +;; CHECK-JS: string: ["bar","foo","needs\tescaping\x00.'#%\"- .\r\n\\08\f\n\r\t.\ua66e"] +;; CHECK-JS: JSON: ["bar","foo","needs\tescaping\x00.'#%\"- .\r\n\\08\f\n\r\t.ꙮ"] diff --git a/test/lit/strings.wast b/test/lit/strings.wast index b7cf0bf63..1166f13d0 100644 --- a/test/lit/strings.wast +++ b/test/lit/strings.wast @@ -41,8 +41,8 @@ ;; CHECK: (type $12 (func (param stringref) (result i32))) - ;; CHECK: (global $string-const stringref (string.const "string in a global \01\ff\00\t\t\n\n\r\r\"\"\'\'\\\\")) - (global $string-const stringref (string.const "string in a global \01\ff\00\t\09\n\0a\r\0d\"\22\'\27\\\5c")) + ;; CHECK: (global $string-const stringref (string.const "string in a global \c2\a3_\e2\82\ac_\f0\90\8d\88 \01\00\t\t\n\n\r\r\"\"\'\'\\\\ ")) + (global $string-const stringref (string.const "string in a global \C2\A3_\E2\82\AC_\F0\90\8D\88 \01\00\t\t\n\n\r\r\"\"\'\'\\\\ ")) ;; CHECK: (memory $0 10 10) |