From e74f66e92affd637cb19af2ad5f3d015ba86aa1c Mon Sep 17 00:00:00 2001 From: Thomas Lively <7121787+tlively@users.noreply.github.com> Date: Wed, 1 Jun 2022 13:00:54 -0700 Subject: [Parser] Token classification (#4699) Add methods to `Token` for determining whether the token can be interpreted as a particular token type, returning the interpreted value as appropriate. These methods perform additional bounds checks for integers and NaN payloads that could not be done during the initial lexing because the lexer did not know what the intended token type was. The float methods also reinterpret integer tokens as floating point tokens since the float grammar is a superset of the integer grammar and inject the NaN payloads into parsed NaN values. Move all bounds checking to these new classifier functions to have it in one place. --- src/wasm/wat-lexer.cpp | 173 ++++++++++++++++++++++++++++++++++++++++++------- src/wat-lexer.h | 32 +++++++++ 2 files changed, 183 insertions(+), 22 deletions(-) (limited to 'src') diff --git a/src/wasm/wat-lexer.cpp b/src/wasm/wat-lexer.cpp index 4bea32b59..0d1dc2794 100644 --- a/src/wasm/wat-lexer.cpp +++ b/src/wasm/wat-lexer.cpp @@ -151,23 +151,10 @@ public: if (overflow) { return {}; } - auto basic = LexCtx::lexed(); - if (!basic) { - return {}; - } - // Check most significant bit for overflow of signed numbers. - if (sign == Neg) { - if (n > (1ull << 63)) { - // TODO: Add error production for signed underflow. - return {}; - } - } else if (sign == Pos) { - if (n > (1ull << 63) - 1) { - // TODO: Add error production for signed overflow. - return {}; - } + if (auto basic = LexCtx::lexed()) { + return LexIntResult{*basic, sign == Neg ? -n : n, sign}; } - return LexIntResult{*basic, sign == Neg ? 
-n : n, sign}; + return {}; } void takeSign() { @@ -592,12 +579,7 @@ std::optional<LexFloatResult> float_(std::string_view in) { if (ctx.takePrefix(":0x"sv)) { if (auto lexed = hexnum(ctx.next())) { ctx.take(*lexed); - if (1 <= lexed->n && lexed->n < (1ull << 52)) { - ctx.nanPayload = lexed->n; - } else { - // TODO: Add error production for invalid NaN payload. - return {}; - } + ctx.nanPayload = lexed->n; } else { // TODO: Add error production for malformed NaN payload. return {}; @@ -781,6 +763,153 @@ std::optional<LexResult> keyword(std::string_view in) { } // anonymous namespace +std::optional<uint64_t> Token::getU64() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == NoSign) { + return tok->n; + } + } + return {}; +} + +std::optional<int64_t> Token::getS64() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (uint64_t(INT64_MIN) <= tok->n || tok->n == 0) { + return int64_t(tok->n); + } + // TODO: Add error production for signed underflow. + } else { + if (tok->n <= uint64_t(INT64_MAX)) { + return int64_t(tok->n); + } + // TODO: Add error production for signed overflow. + } + } + return {}; +} + +std::optional<uint64_t> Token::getI64() const { + if (auto n = getU64()) { + return *n; + } + if (auto n = getS64()) { + return *n; + } + return {}; +} + +std::optional<uint32_t> Token::getU32() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == NoSign && tok->n <= UINT32_MAX) { + return int32_t(tok->n); + } + // TODO: Add error production for unsigned overflow. 
+ } + return {}; +} + +std::optional<int32_t> Token::getS32() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (uint64_t(INT32_MIN) <= tok->n || tok->n == 0) { + return int32_t(tok->n); + } + } else { + if (tok->n <= uint64_t(INT32_MAX)) { + return int32_t(tok->n); + } + } + } + return {}; +} + +std::optional<uint32_t> Token::getI32() const { + if (auto n = getU32()) { + return *n; + } + if (auto n = getS32()) { + return uint32_t(*n); + } + return {}; +} + +std::optional<double> Token::getF64() const { + constexpr int signif = 52; + constexpr uint64_t payloadMask = (1ull << signif) - 1; + constexpr uint64_t nanDefault = 1ull << (signif - 1); + if (auto* tok = std::get_if<FloatTok>(&data)) { + double d = tok->d; + if (std::isnan(d)) { + // Inject payload. + uint64_t payload = tok->nanPayload ? *tok->nanPayload : nanDefault; + if (payload == 0 || payload > payloadMask) { + // TODO: Add error production for out-of-bounds payload. + return {}; + } + uint64_t bits; + static_assert(sizeof(bits) == sizeof(d)); + memcpy(&bits, &d, sizeof(bits)); + bits = (bits & ~payloadMask) | payload; + memcpy(&d, &bits, sizeof(bits)); + } + return d; + } + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (tok->n == 0) { + return -0.0; + } + return double(int64_t(tok->n)); + } + return double(tok->n); + } + return {}; +} + +std::optional<float> Token::getF32() const { + constexpr int signif = 23; + constexpr uint32_t payloadMask = (1u << signif) - 1; + constexpr uint64_t nanDefault = 1ull << (signif - 1); + if (auto* tok = std::get_if<FloatTok>(&data)) { + float f = tok->d; + if (std::isnan(f)) { + // Validate and inject payload. + uint64_t payload = tok->nanPayload ? *tok->nanPayload : nanDefault; + if (payload == 0 || payload > payloadMask) { + // TODO: Add error production for out-of-bounds payload. 
+ return {}; + } + uint32_t bits; + static_assert(sizeof(bits) == sizeof(f)); + memcpy(&bits, &f, sizeof(bits)); + bits = (bits & ~payloadMask) | payload; + memcpy(&f, &bits, sizeof(bits)); + } + return f; + } + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (tok->n == 0) { + return -0.0f; + } + return float(int64_t(tok->n)); + } + return float(tok->n); + } + return {}; +} + +std::optional<std::string_view> Token::getString() const { + if (auto* tok = std::get_if<StringTok>(&data)) { + if (tok->str) { + return std::string_view(*tok->str); + } + return span.substr(1, span.size() - 2); + } + return {}; +} + void Lexer::skipSpace() { if (auto ctx = space(next())) { index += ctx->span.size(); diff --git a/src/wat-lexer.h b/src/wat-lexer.h index 5a955f5c0..e4ba2efa8 100644 --- a/src/wat-lexer.h +++ b/src/wat-lexer.h @@ -15,6 +15,7 @@ */ #include <cstddef> +#include <cstring> #include <iterator> #include <optional> #include <ostream> @@ -101,6 +102,37 @@ struct Token { std::string_view span; Data data; + // ==================== + // Token classification + // ==================== + + bool isLParen() const { return std::get_if<LParenTok>(&data); } + + bool isRParen() const { return std::get_if<RParenTok>(&data); } + + std::optional<std::string_view> getID() const { + if (std::get_if<IdTok>(&data)) { + return span; + } + return {}; + } + + std::optional<std::string_view> getKeyword() const { + if (std::get_if<KeywordTok>(&data)) { + return span; + } + return {}; + } + std::optional<uint64_t> getU64() const; + std::optional<int64_t> getS64() const; + std::optional<uint64_t> getI64() const; + std::optional<uint32_t> getU32() const; + std::optional<int32_t> getS32() const; + std::optional<uint32_t> getI32() const; + std::optional<double> getF64() const; + std::optional<float> getF32() const; + std::optional<std::string_view> getString() const; + bool operator==(const Token&) const; friend std::ostream& operator<<(std::ostream& os, const Token&); }; -- cgit v1.2.3