diff options
author | Brendan Dahl <brendan.dahl@gmail.com> | 2024-08-06 14:15:58 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-06 14:15:58 -0700 |
commit | 0c269482097ae9da62a690b0ace406e2d2109c48 (patch) | |
tree | d7fa2f2f5b9e7e703dce9805f9dbd1bb16f65cfb | |
parent | d5a5425c0c76cfc08711b81d6ec70c3a2879e405 (diff) | |
download | binaryen-0c269482097ae9da62a690b0ace406e2d2109c48.tar.gz binaryen-0c269482097ae9da62a690b0ace406e2d2109c48.tar.bz2 binaryen-0c269482097ae9da62a690b0ace406e2d2109c48.zip |
[FP16] Implement load and store instructions. (#6796)
Specified at
https://github.com/WebAssembly/half-precision/blob/main/proposals/half-precision/Overview.md
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | LICENSE | 5 | ||||
-rwxr-xr-x | scripts/gen-s-parser.py | 2 | ||||
-rw-r--r-- | src/gen-s-parser.inc | 42 | ||||
-rw-r--r-- | src/passes/Print.cpp | 16 | ||||
-rw-r--r-- | src/wasm-binary.h | 12 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 37 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 65 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 30 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 5 | ||||
-rw-r--r-- | test/lit/basic/f16.wast | 79 | ||||
-rw-r--r-- | test/spec/f16.wast | 15 | ||||
-rw-r--r-- | third_party/FP16/LICENSE | 11 | ||||
-rw-r--r-- | third_party/FP16/include/fp16.h | 7 | ||||
-rw-r--r-- | third_party/FP16/include/fp16/bitcasts.h | 92 | ||||
-rw-r--r-- | third_party/FP16/include/fp16/fp16.h | 515 | ||||
-rw-r--r-- | third_party/FP16/include/fp16/macros.h | 46 | ||||
-rw-r--r-- | third_party/FP16/readme.txt | 13 |
18 files changed, 952 insertions, 41 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b38f0baa..94224299c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -175,6 +175,7 @@ endif() # Compiler setup. include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/FP16/include) if(BUILD_LLVM_DWARF) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/include) endif() @@ -199,3 +199,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +============================================================================== + +The FP16 project is used in this repo, and it has the MIT license, see +third_party/FP16/LICENSE. diff --git a/scripts/gen-s-parser.py b/scripts/gen-s-parser.py index e1bfe3f3d..f35109644 100755 --- a/scripts/gen-s-parser.py +++ b/scripts/gen-s-parser.py @@ -46,6 +46,7 @@ instructions = [ ("i32.load", "makeLoad(Type::i32, /*signed=*/false, 4, /*isAtomic=*/false)"), ("i64.load", "makeLoad(Type::i64, /*signed=*/false, 8, /*isAtomic=*/false)"), ("f32.load", "makeLoad(Type::f32, /*signed=*/false, 4, /*isAtomic=*/false)"), + ("f32.load_f16", "makeLoad(Type::f32, /*signed=*/false, 2, /*isAtomic=*/false)"), ("f64.load", "makeLoad(Type::f64, /*signed=*/false, 8, /*isAtomic=*/false)"), ("i32.load8_s", "makeLoad(Type::i32, /*signed=*/true, 1, /*isAtomic=*/false)"), ("i32.load8_u", "makeLoad(Type::i32, /*signed=*/false, 1, /*isAtomic=*/false)"), @@ -60,6 +61,7 @@ instructions = [ ("i32.store", "makeStore(Type::i32, 4, /*isAtomic=*/false)"), ("i64.store", "makeStore(Type::i64, 8, /*isAtomic=*/false)"), ("f32.store", "makeStore(Type::f32, 4, /*isAtomic=*/false)"), + ("f32.store_f16", "makeStore(Type::f32, 2, /*isAtomic=*/false)"), ("f64.store", "makeStore(Type::f64, 8, /*isAtomic=*/false)"), ("i32.store8", "makeStore(Type::i32, 1, /*isAtomic=*/false)"), ("i32.store16", "makeStore(Type::i32, 2, 
/*isAtomic=*/false)"), diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc index 98f5b7831..c54a0de54 100644 --- a/src/gen-s-parser.inc +++ b/src/gen-s-parser.inc @@ -454,12 +454,23 @@ switch (buf[0]) { return Ok{}; } goto parse_error; - case 'o': - if (op == "f32.load"sv) { - CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 4, /*isAtomic=*/false)); - return Ok{}; + case 'o': { + switch (buf[8]) { + case '\0': + if (op == "f32.load"sv) { + CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 4, /*isAtomic=*/false)); + return Ok{}; + } + goto parse_error; + case '_': + if (op == "f32.load_f16"sv) { + CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 2, /*isAtomic=*/false)); + return Ok{}; + } + goto parse_error; + default: goto parse_error; } - goto parse_error; + } case 't': if (op == "f32.lt"sv) { CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::LtFloat32)); @@ -529,12 +540,23 @@ switch (buf[0]) { return Ok{}; } goto parse_error; - case 't': - if (op == "f32.store"sv) { - CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 4, /*isAtomic=*/false)); - return Ok{}; + case 't': { + switch (buf[9]) { + case '\0': + if (op == "f32.store"sv) { + CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 4, /*isAtomic=*/false)); + return Ok{}; + } + goto parse_error; + case '_': + if (op == "f32.store_f16"sv) { + CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 2, /*isAtomic=*/false)); + return Ok{}; + } + goto parse_error; + default: goto parse_error; } - goto parse_error; + } case 'u': if (op == "f32.sub"sv) { CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::SubFloat32)); diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index ede49ab38..aca43924d 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -548,13 +548,19 @@ struct PrintExpressionContents if (curr->bytes == 1) { o << '8'; } else if (curr->bytes == 2) { - o << "16"; + if (curr->type == Type::f32) { + o << "_f16"; + } 
else { + o << "16"; + } } else if (curr->bytes == 4) { o << "32"; } else { abort(); } - o << (curr->signed_ ? "_s" : "_u"); + if (curr->type != Type::f32) { + o << (curr->signed_ ? "_s" : "_u"); + } } restoreNormalColor(o); printMemoryName(curr->memory, o, wasm); @@ -575,7 +581,11 @@ struct PrintExpressionContents if (curr->bytes == 1) { o << '8'; } else if (curr->bytes == 2) { - o << "16"; + if (curr->valueType == Type::f32) { + o << "_f16"; + } else { + o << "16"; + } } else if (curr->bytes == 4) { o << "32"; } else { diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 8a6f825ff..5fae1b64d 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1051,6 +1051,10 @@ enum ASTNodes { I16x8DotI8x16I7x16S = 0x112, I32x4DotI8x16I7x16AddS = 0x113, + // half precision opcodes + F32_F16LoadMem = 0x30, + F32_F16StoreMem = 0x31, + // bulk memory opcodes MemoryInit = 0x08, @@ -1703,8 +1707,12 @@ public: void visitLocalSet(LocalSet* curr, uint8_t code); void visitGlobalGet(GlobalGet* curr); void visitGlobalSet(GlobalSet* curr); - bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic); - bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic); + bool maybeVisitLoad(Expression*& out, + uint8_t code, + std::optional<BinaryConsts::ASTNodes> prefix); + bool maybeVisitStore(Expression*& out, + uint8_t code, + std::optional<BinaryConsts::ASTNodes> prefix); bool maybeVisitNontrappingTrunc(Expression*& out, uint32_t code); bool maybeVisitAtomicRMW(Expression*& out, uint8_t code); bool maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 9bdf0e72c..3e62d5335 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -28,6 +28,7 @@ #include <sstream> #include <variant> +#include "fp16.h" #include "ir/intrinsics.h" #include "ir/module-utils.h" #include "support/bits.h" @@ -2540,8 +2541,22 @@ public: } break; } - case Type::f32: - return Literal(load32u(addr, memory)).castToF32(); 
+ case Type::f32: { + switch (load->bytes) { + case 2: { + // Convert the float16 to float32 and store the binary + // representation. + return Literal(bit_cast<int32_t>( + fp16_ieee_to_fp32_value(load16u(addr, memory)))) + .castToF32(); + } + case 4: + return Literal(load32u(addr, memory)).castToF32(); + default: + WASM_UNREACHABLE("invalid size"); + } + break; + } case Type::f64: return Literal(load64u(addr, memory)).castToF64(); case Type::v128: @@ -2590,9 +2605,23 @@ public: break; } // write floats carefully, ensuring all bits reach memory - case Type::f32: - store32(addr, value.reinterpreti32(), memory); + case Type::f32: { + switch (store->bytes) { + case 2: { + float f32 = bit_cast<float>(value.reinterpreti32()); + // Convert the float32 to float16 and store the binary + // representation. + store16(addr, fp16_ieee_from_fp32_value(f32), memory); + break; + } + case 4: + store32(addr, value.reinterpreti32(), memory); + break; + default: + WASM_UNREACHABLE("invalid store size"); + } break; + } case Type::f64: store64(addr, value.reinterpreti64(), memory); break; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 574e13aa2..b9645ab8f 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -4145,10 +4145,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) { } case BinaryConsts::AtomicPrefix: { code = static_cast<uint8_t>(getU32LEB()); - if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) { + if (maybeVisitLoad(curr, code, BinaryConsts::AtomicPrefix)) { break; } - if (maybeVisitStore(curr, code, /*isAtomic=*/true)) { + if (maybeVisitStore(curr, code, BinaryConsts::AtomicPrefix)) { break; } if (maybeVisitAtomicRMW(curr, code)) { @@ -4198,6 +4198,12 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) { if (maybeVisitTableCopy(curr, opcode)) { break; } + if (maybeVisitLoad(curr, opcode, BinaryConsts::MiscPrefix)) { + break; + } + if (maybeVisitStore(curr, opcode, 
BinaryConsts::MiscPrefix)) { + break; + } throwError("invalid code after misc prefix: " + std::to_string(opcode)); break; } @@ -4338,10 +4344,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) { if (maybeVisitConst(curr, code)) { break; } - if (maybeVisitLoad(curr, code, /*isAtomic=*/false)) { + if (maybeVisitLoad(curr, code, /*prefix=*/std::nullopt)) { break; } - if (maybeVisitStore(curr, code, /*isAtomic=*/false)) { + if (maybeVisitStore(curr, code, /*prefix=*/std::nullopt)) { break; } throwError("bad node code " + std::to_string(code)); @@ -4717,14 +4723,15 @@ Index WasmBinaryReader::readMemoryAccess(Address& alignment, Address& offset) { return memIdx; } -bool WasmBinaryReader::maybeVisitLoad(Expression*& out, - uint8_t code, - bool isAtomic) { +bool WasmBinaryReader::maybeVisitLoad( + Expression*& out, + uint8_t code, + std::optional<BinaryConsts::ASTNodes> prefix) { Load* curr; auto allocate = [&]() { curr = allocator.alloc<Load>(); }; - if (!isAtomic) { + if (!prefix) { switch (code) { case BinaryConsts::I32LoadMem8S: allocate(); @@ -4805,7 +4812,7 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out, return false; } BYN_TRACE("zz node: Load\n"); - } else { + } else if (prefix == BinaryConsts::AtomicPrefix) { switch (code) { case BinaryConsts::I32AtomicLoad8U: allocate(); @@ -4846,9 +4853,22 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out, return false; } BYN_TRACE("zz node: AtomicLoad\n"); + } else if (prefix == BinaryConsts::MiscPrefix) { + switch (code) { + case BinaryConsts::F32_F16LoadMem: + allocate(); + curr->bytes = 2; + curr->type = Type::f32; + break; + default: + return false; + } + BYN_TRACE("zz node: Load\n"); + } else { + return false; } - curr->isAtomic = isAtomic; + curr->isAtomic = prefix == BinaryConsts::AtomicPrefix; Index memIdx = readMemoryAccess(curr->align, curr->offset); memoryRefs[memIdx].push_back(&curr->memory); curr->ptr = popNonVoidExpression(); @@ -4857,11 +4877,12 @@ bool 
WasmBinaryReader::maybeVisitLoad(Expression*& out, return true; } -bool WasmBinaryReader::maybeVisitStore(Expression*& out, - uint8_t code, - bool isAtomic) { +bool WasmBinaryReader::maybeVisitStore( + Expression*& out, + uint8_t code, + std::optional<BinaryConsts::ASTNodes> prefix) { Store* curr; - if (!isAtomic) { + if (!prefix) { switch (code) { case BinaryConsts::I32StoreMem8: curr = allocator.alloc<Store>(); @@ -4911,7 +4932,7 @@ bool WasmBinaryReader::maybeVisitStore(Expression*& out, default: return false; } - } else { + } else if (prefix == BinaryConsts::AtomicPrefix) { switch (code) { case BinaryConsts::I32AtomicStore8: curr = allocator.alloc<Store>(); @@ -4951,9 +4972,21 @@ bool WasmBinaryReader::maybeVisitStore(Expression*& out, default: return false; } + } else if (prefix == BinaryConsts::MiscPrefix) { + switch (code) { + case BinaryConsts::F32_F16StoreMem: + curr = allocator.alloc<Store>(); + curr->bytes = 2; + curr->valueType = Type::f32; + break; + default: + return false; + } + } else { + return false; } - curr->isAtomic = isAtomic; + curr->isAtomic = prefix == BinaryConsts::AtomicPrefix; BYN_TRACE("zz node: Store\n"); Index memIdx = readMemoryAccess(curr->align, curr->offset); memoryRefs[memIdx].push_back(&curr->memory); diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index cd0a9928e..35db3b322 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -258,9 +258,20 @@ void BinaryInstWriter::visitLoad(Load* curr) { } break; } - case Type::f32: - o << int8_t(BinaryConsts::F32LoadMem); + case Type::f32: { + switch (curr->bytes) { + case 2: + o << int8_t(BinaryConsts::MiscPrefix) + << U32LEB(BinaryConsts::F32_F16LoadMem); + break; + case 4: + o << int8_t(BinaryConsts::F32LoadMem); + break; + default: + WASM_UNREACHABLE("invalid load size"); + } break; + } case Type::f64: o << int8_t(BinaryConsts::F64LoadMem); break; @@ -359,9 +370,20 @@ void BinaryInstWriter::visitStore(Store* curr) { } break; } - case Type::f32: - o << 
int8_t(BinaryConsts::F32StoreMem); + case Type::f32: { + switch (curr->bytes) { + case 2: + o << int8_t(BinaryConsts::MiscPrefix) + << U32LEB(BinaryConsts::F32_F16StoreMem); + break; + case 4: + o << int8_t(BinaryConsts::F32StoreMem); + break; + default: + WASM_UNREACHABLE("invalid store size"); + } break; + } case Type::f64: o << int8_t(BinaryConsts::F64StoreMem); break; diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index ce7d0df3c..b32917432 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -1579,8 +1579,9 @@ void FunctionValidator::validateMemBytes(uint8_t bytes, "expected i64 operation to touch 1, 2, 4, or 8 bytes"); break; case Type::f32: - shouldBeEqual( - bytes, uint8_t(4), curr, "expected f32 operation to touch 4 bytes"); + shouldBeTrue(bytes == 2 || bytes == 4, + curr, + "expected f32 operation to touch 2 or 4 bytes"); break; case Type::f64: shouldBeEqual( diff --git a/test/lit/basic/f16.wast b/test/lit/basic/f16.wast new file mode 100644 index 000000000..c68b0306f --- /dev/null +++ b/test/lit/basic/f16.wast @@ -0,0 +1,79 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. 
+ +;; RUN: wasm-opt %s -all -o %t.text.wast -g -S +;; RUN: wasm-as %s -all -g -o %t.wasm +;; RUN: wasm-dis %t.wasm -all -o %t.bin.wast +;; RUN: wasm-as %s -all -o %t.nodebug.wasm +;; RUN: wasm-dis %t.nodebug.wasm -all -o %t.bin.nodebug.wast +;; RUN: cat %t.text.wast | filecheck %s --check-prefix=CHECK-TEXT +;; RUN: cat %t.bin.wast | filecheck %s --check-prefix=CHECK-BIN +;; RUN: cat %t.bin.nodebug.wast | filecheck %s --check-prefix=CHECK-BIN-NODEBUG + +(module + (memory 1 1) + + + ;; CHECK-TEXT: (type $0 (func (param i32) (result f32))) + + ;; CHECK-TEXT: (type $1 (func (param i32 f32))) + + ;; CHECK-TEXT: (memory $0 1 1) + + ;; CHECK-TEXT: (func $f32.load_f16 (type $0) (param $0 i32) (result f32) + ;; CHECK-TEXT-NEXT: (f32.load_f16 + ;; CHECK-TEXT-NEXT: (local.get $0) + ;; CHECK-TEXT-NEXT: ) + ;; CHECK-TEXT-NEXT: ) + ;; CHECK-BIN: (type $0 (func (param i32) (result f32))) + + ;; CHECK-BIN: (type $1 (func (param i32 f32))) + + ;; CHECK-BIN: (memory $0 1 1) + + ;; CHECK-BIN: (func $f32.load_f16 (type $0) (param $0 i32) (result f32) + ;; CHECK-BIN-NEXT: (f32.load_f16 + ;; CHECK-BIN-NEXT: (local.get $0) + ;; CHECK-BIN-NEXT: ) + ;; CHECK-BIN-NEXT: ) + (func $f32.load_f16 (param $0 i32) (result f32) + (f32.load_f16 + (local.get $0) + ) + ) + ;; CHECK-TEXT: (func $f32.store_f16 (type $1) (param $0 i32) (param $1 f32) + ;; CHECK-TEXT-NEXT: (f32.store_f16 + ;; CHECK-TEXT-NEXT: (local.get $0) + ;; CHECK-TEXT-NEXT: (local.get $1) + ;; CHECK-TEXT-NEXT: ) + ;; CHECK-TEXT-NEXT: ) + ;; CHECK-BIN: (func $f32.store_f16 (type $1) (param $0 i32) (param $1 f32) + ;; CHECK-BIN-NEXT: (f32.store_f16 + ;; CHECK-BIN-NEXT: (local.get $0) + ;; CHECK-BIN-NEXT: (local.get $1) + ;; CHECK-BIN-NEXT: ) + ;; CHECK-BIN-NEXT: ) + (func $f32.store_f16 (param $0 i32) (param $1 f32) + (f32.store_f16 + (local.get $0) + (local.get $1) + ) + ) +) +;; CHECK-BIN-NODEBUG: (type $0 (func (param i32) (result f32))) + +;; CHECK-BIN-NODEBUG: (type $1 (func (param i32 f32))) + +;; CHECK-BIN-NODEBUG: (memory $0 1 
1) + +;; CHECK-BIN-NODEBUG: (func $0 (type $0) (param $0 i32) (result f32) +;; CHECK-BIN-NODEBUG-NEXT: (f32.load_f16 +;; CHECK-BIN-NODEBUG-NEXT: (local.get $0) +;; CHECK-BIN-NODEBUG-NEXT: ) +;; CHECK-BIN-NODEBUG-NEXT: ) + +;; CHECK-BIN-NODEBUG: (func $1 (type $1) (param $0 i32) (param $1 f32) +;; CHECK-BIN-NODEBUG-NEXT: (f32.store_f16 +;; CHECK-BIN-NODEBUG-NEXT: (local.get $0) +;; CHECK-BIN-NODEBUG-NEXT: (local.get $1) +;; CHECK-BIN-NODEBUG-NEXT: ) +;; CHECK-BIN-NODEBUG-NEXT: ) diff --git a/test/spec/f16.wast b/test/spec/f16.wast new file mode 100644 index 000000000..19bad1756 --- /dev/null +++ b/test/spec/f16.wast @@ -0,0 +1,15 @@ +;; Test float 16 operations. + +(module + (memory (data "\40\51\AD\DE")) + + (func (export "f32.load_f16") (result f32) (f32.load_f16 (i32.const 0))) + (func (export "f32.store_f16") (f32.store_f16 (i32.const 0) (f32.const 100.5))) + (func (export "i32.load16_u") (result i32) (i32.load16_u (i32.const 2))) +) + +(assert_return (invoke "f32.load_f16") (f32.const 42.0)) +(invoke "f32.store_f16") +(assert_return (invoke "f32.load_f16") (f32.const 100.5)) +;; Ensure that the above operations didn't write to memory they shouldn't have. +(assert_return (invoke "i32.load16_u") (i32.const 0xDEAD)) diff --git a/third_party/FP16/LICENSE b/third_party/FP16/LICENSE new file mode 100644 index 000000000..eabec6c86 --- /dev/null +++ b/third_party/FP16/LICENSE @@ -0,0 +1,11 @@ +The MIT License (MIT) + +Copyright (c) 2017 Facebook Inc. 
+Copyright (c) 2017 Georgia Institute of Technology +Copyright 2019 Google LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/third_party/FP16/include/fp16.h b/third_party/FP16/include/fp16.h new file mode 100644 index 000000000..adbcb961c --- /dev/null +++ b/third_party/FP16/include/fp16.h @@ -0,0 +1,7 @@ +#pragma once +#ifndef FP16_H +#define FP16_H + +#include <fp16/fp16.h> + +#endif /* FP16_H */ diff --git a/third_party/FP16/include/fp16/bitcasts.h b/third_party/FP16/include/fp16/bitcasts.h new file mode 100644 index 000000000..ae6884325 --- /dev/null +++ b/third_party/FP16/include/fp16/bitcasts.h @@ -0,0 +1,92 @@ +#pragma once +#ifndef FP16_BITCASTS_H +#define FP16_BITCASTS_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include <cstdint> +#elif !defined(__OPENCL_VERSION__) + #include <stdint.h> +#endif + +#if defined(__INTEL_COMPILER) || defined(_MSC_VER) && (_MSC_VER >= 1932) && (defined(_M_IX86) || defined(_M_X64)) + #include <immintrin.h> +#endif + +#if defined(_MSC_VER) && !defined(__clang__) && (defined(_M_ARM) || defined(_M_ARM64)) + #include <intrin.h> +#endif + + +static inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int) w); +#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) && (_MSC_VER >= 1932) && (defined(_M_IX86) || defined(_M_X64)) + return _castu32_f32(w); +#elif defined(_MSC_VER) && !defined(__clang__) && (defined(_M_ARM) || defined(_M_ARM64)) + return _CopyFloatFromInt32((__int32) w); +#else + union { + uint32_t as_bits; + float as_value; + } fp32 = { w }; + return fp32.as_value; +#endif +} + +static inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t) __float_as_uint(f); +#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) && (_MSC_VER >= 1932) && (defined(_M_IX86) || defined(_M_X64)) + return _castf32_u32(f); +#elif defined(_MSC_VER) && !defined(__clang__) && (defined(_M_ARM) || defined(_M_ARM64)) + return (uint32_t) 
_CopyInt32FromFloat(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = { f }; + return fp32.as_bits; +#endif +} + +static inline double fp64_from_bits(uint64_t w) { +#if defined(__OPENCL_VERSION__) + return as_double(w); +#elif defined(__CUDA_ARCH__) + return __longlong_as_double((long long) w); +#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) && (_MSC_VER >= 1932) && (defined(_M_IX86) || defined(_M_X64)) + return _castu64_f64(w); +#elif defined(_MSC_VER) && !defined(__clang__) && (defined(_M_ARM) || defined(_M_ARM64)) + return _CopyDoubleFromInt64((__int64) w); +#else + union { + uint64_t as_bits; + double as_value; + } fp64 = { w }; + return fp64.as_value; +#endif +} + +static inline uint64_t fp64_to_bits(double f) { +#if defined(__OPENCL_VERSION__) + return as_ulong(f); +#elif defined(__CUDA_ARCH__) + return (uint64_t) __double_as_longlong(f); +#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) && (_MSC_VER >= 1932) && (defined(_M_IX86) || defined(_M_X64)) + return _castf64_u64(f); +#elif defined(_MSC_VER) && !defined(__clang__) && (defined(_M_ARM) || defined(_M_ARM64)) + return (uint64_t) _CopyInt64FromDouble(f); +#else + union { + double as_value; + uint64_t as_bits; + } fp64 = { f }; + return fp64.as_bits; +#endif +} + +#endif /* FP16_BITCASTS_H */ diff --git a/third_party/FP16/include/fp16/fp16.h b/third_party/FP16/include/fp16/fp16.h new file mode 100644 index 000000000..95fa0c2de --- /dev/null +++ b/third_party/FP16/include/fp16/fp16.h @@ -0,0 +1,515 @@ +#pragma once +#ifndef FP16_FP16_H +#define FP16_FP16_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include <cstdint> + #include <cmath> +#elif !defined(__OPENCL_VERSION__) + #include <stdint.h> + #include <math.h> +#endif + +#include <fp16/bitcasts.h> +#include <fp16/macros.h> + +#if defined(_MSC_VER) + #include <intrin.h> +#endif +#if defined(__F16C__) && FP16_USE_NATIVE_CONVERSION && !FP16_USE_FLOAT16_TYPE && !FP16_USE_FP16_TYPE + #include <immintrin.h> +#endif 
+#if (defined(__aarch64__) || defined(_M_ARM64)) && FP16_USE_NATIVE_CONVERSION && !FP16_USE_FLOAT16_TYPE && !FP16_USE_FP16_TYPE + #include <arm_neon.h> +#endif + + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. 
Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long) nonsign); + uint32_t renorm_shift = (uint32_t) nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows it into bit 31, + * and the subsequent shift turns the high 9 bits into 1. Thus + * inf_nan_mask == + * 0x7F800000 if the half-precision number had exponent of 15 (i.e. was NaN or infinity) + * 0x00000000 otherwise + */ + const int32_t inf_nan_mask = ((int32_t) (nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 into 1. Otherwise, bit 31 remains 0. + * The signed shift right by 31 broadcasts bit 31 into all bits of the zero_mask. Thus + * zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t) (nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) becomes an 8-bit field and 10-bit mantissa + * shifts into the 10 high bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the different in exponent bias + * (0x7F for single-precision number less 0xF for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to account for renormalization. As renorm_shift + * is less than 0x70, this can be combined with step 3. + * 5. 
Binary OR with inf_nan_mask to turn the exponent into 0xFF if the input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_ieee_to_fp32_value(uint16_t h) { +#if FP16_USE_NATIVE_CONVERSION + #if FP16_USE_FLOAT16_TYPE + union { + uint16_t as_bits; + _Float16 as_value; + } fp16 = { h }; + return (float) fp16.as_value; + #elif FP16_USE_FP16_TYPE + union { + uint16_t as_bits; + __fp16 as_value; + } fp16 = { h }; + return (float) fp16.as_value; + #else + #if (defined(__INTEL_COMPILER) || defined(__GNUC__)) && defined(__F16C__) + return _cvtsh_ss((unsigned short) h); + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) && defined(__AVX2__) + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128((int) (unsigned int) h))); + #elif defined(_M_ARM64) || defined(__aarch64__) + return vgetq_lane_f32(vcvt_f32_f16(vreinterpret_f16_u16(vdup_n_u16(h))), 0); + #else + #error "Archtecture- or compiler-specific implementation required" + #endif + #endif +#else + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 
0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. 
+ * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. 
When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +#endif +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. 
+ */ +static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if FP16_USE_NATIVE_CONVERSION /* native path: the conversion is done by a compiler-provided 16-bit float type or by an intrinsic */ + #if FP16_USE_FLOAT16_TYPE + union { + _Float16 as_value; + uint16_t as_bits; + } fp16 = { (_Float16) f }; /* convert by cast, then read the bits back through the union */ + return fp16.as_bits; + #elif FP16_USE_FP16_TYPE + union { + __fp16 as_value; + uint16_t as_bits; + } fp16 = { (__fp16) f }; + return fp16.as_bits; + #else /* no 16-bit float type selected: use a per-compiler conversion intrinsic */ + #if (defined(__INTEL_COMPILER) || defined(__GNUC__)) && defined(__F16C__) + return _cvtss_sh(f, _MM_FROUND_CUR_DIRECTION); + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) && defined(__AVX2__) + return (uint16_t) _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(f), _MM_FROUND_CUR_DIRECTION)); + #elif defined(_M_ARM64) || defined(__aarch64__) + return vget_lane_u16(vcvt_f16_f32(vdupq_n_f32(f)), 0); + #else + #error "Archtecture- or compiler-specific implementation required" + #endif + #endif +#else /* portable path: float multiplications plus integer manipulation of the bit pattern */ +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else /* same two constants, 2**112 and 2**(-110), spelled as bit patterns; presumably for compilers that lack hex-float literals */ + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif +#if defined(_MSC_VER) && defined(_M_IX86_FP) && (_M_IX86_FP == 0) || defined(__GNUC__) && defined(__FLT_EVAL_METHOD__) && (__FLT_EVAL_METHOD__ != 0) + const volatile float saturated_f = fabsf(f) * scale_to_inf; /* NOTE(review): volatile is chosen only when float expressions may be evaluated in a wider format; presumably it forces the product to be rounded to float before the next multiply - confirm */ +#else + const float saturated_f = fabsf(f) * scale_to_inf; +#endif + float base = saturated_f * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; /* same as w << 1: the sign bit is shifted out, the exponent lands in bits 24-31 */ + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { /* clamp the exponent field to at least 0x71 */ + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t 
mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); /* shl1_w > 0xFF000000 means exponent all ones with a non-zero mantissa, i.e. a NaN input: return the pattern 0x7E00 combined with the input sign */ +#endif +} + +/* + * Convert a 16-bit floating-point number in ARM alternative half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +static inline uint32_t fp16_alt_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-30 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, at least one of its high 6 bits (the cleared sign bit and the 5-bit exponent) is set. + * In this case renorm_shift == 0. If the number is denormalized, renorm_shift > 0. 
Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long) nonsign); + uint32_t renorm_shift = (uint32_t) nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff nonsign is 0, nonsign - 1 wraps around to 0xFFFFFFFF, turning bit 31 into 1. Otherwise, bit 31 remains 0. + * The signed shift right by 31 broadcasts bit 31 into all bits of the zero_mask. Thus + * zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + * + * NOTE(review): for a zero input, nonsign == 0 is also what the bit scan above receives. GCC documents the result of + * __builtin_clz(0) as undefined, and the index written by _BitScanReverse is not defined when no bit is set, so + * renorm_shift is not meaningful in that case. zero_mask discards the value computed from it, but the scan itself + * is still executed - confirm whether an explicit guard for zero is wanted. + */ + const int32_t zero_mask = (int32_t) (nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) becomes an 8-bit field and 10-bit mantissa + * shifts into the 10 high bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate for the difference in exponent bias + * (0x7F for single-precision number less 0xF for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to account for renormalization. As renorm_shift + * is less than 0x70, this can be combined with step 3. + * 5. Binary ANDNOT with zero_mask to turn the mantissa and exponent into zero if the input was zero. + * 6. Combine with the sign of the input number. + */ + return sign | (((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) & ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in ARM alternative half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. 
+ * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_alt_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-30 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift the exponent into bits 23-27 and the mantissa into bits 13-22 so they become the exponent and mantissa + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, the exponent is adjusted for the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70). 
This operation never overflows or generates non-finite values, as the largest + * half-precision exponent is 0x1F and after the adjustment it cannot exceed 0x8F < 0xFE (largest single-precision + * exponent for finite values). + * + * Note that this operation does not handle denormal inputs (where biased exponent == 0). However, it also does not + * operate on denormal inputs, and does not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0x70) << 23; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset); + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has non-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa as the half-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructed single-precision number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision number. 
When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose the result of converting the input either as a normalized number or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if it is smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * ARM alternative half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline uint16_t fp16_alt_from_fp32_value(float f) { + const uint32_t w = fp32_to_bits(f); + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t shl1_w = w + w; + + const uint32_t shl1_max_fp16_fp32 = UINT32_C(0x8FFFC000); + const uint32_t shl1_base = shl1_w > shl1_max_fp16_fp32 ? 
shl1_max_fp16_fp32 : shl1_w; + uint32_t shl1_bias = shl1_base & UINT32_C(0xFF000000); + const uint32_t exp_difference = 23 - 10; + const uint32_t shl1_bias_min = (127 - 1 - exp_difference) << 24; + if (shl1_bias < shl1_bias_min) { + shl1_bias = shl1_bias_min; + } + + const float bias = fp32_from_bits((shl1_bias >> 1) + ((exp_difference + 2) << 23)); + const float base = fp32_from_bits((shl1_base >> 1) + (2 << 23)) + bias; + + const uint32_t exp_f = fp32_to_bits(base) >> 13; + return (sign >> 16) | ((exp_f & UINT32_C(0x00007C00)) + (fp32_to_bits(base) & UINT32_C(0x00000FFF))); +} + +#endif /* FP16_FP16_H */ diff --git a/third_party/FP16/include/fp16/macros.h b/third_party/FP16/include/fp16/macros.h new file mode 100644 index 000000000..4018b0c9d --- /dev/null +++ b/third_party/FP16/include/fp16/macros.h @@ -0,0 +1,46 @@ +#pragma once +#ifndef FP16_MACROS_H +#define FP16_MACROS_H + +#ifndef FP16_USE_NATIVE_CONVERSION + #if (defined(__INTEL_COMPILER) || defined(__GNUC__)) && defined(__F16C__) + #define FP16_USE_NATIVE_CONVERSION 1 + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) && defined(__AVX2__) + #define FP16_USE_NATIVE_CONVERSION 1 + #elif defined(_MSC_VER) && defined(_M_ARM64) + #define FP16_USE_NATIVE_CONVERSION 1 + #elif defined(__GNUC__) && defined(__aarch64__) + #define FP16_USE_NATIVE_CONVERSION 1 + #endif + #if !defined(FP16_USE_NATIVE_CONVERSION) + #define FP16_USE_NATIVE_CONVERSION 0 + #endif // !defined(FP16_USE_NATIVE_CONVERSION) +#endif // !defined(FP16_USE_NATIVE_CONVERSION) + +#ifndef FP16_USE_FLOAT16_TYPE + #if !defined(__clang__) && !defined(__INTEL_COMPILER) && defined(__GNUC__) && (__GNUC__ >= 12) + #if defined(__F16C__) + #define FP16_USE_FLOAT16_TYPE 1 + #endif + #endif + #if !defined(FP16_USE_FLOAT16_TYPE) + #define FP16_USE_FLOAT16_TYPE 0 + #endif // !defined(FP16_USE_FLOAT16_TYPE) +#endif // !defined(FP16_USE_FLOAT16_TYPE) + +#ifndef FP16_USE_FP16_TYPE + #if defined(__clang__) + #if defined(__F16C__) || 
defined(__aarch64__) + #define FP16_USE_FP16_TYPE 1 + #endif + #elif defined(__GNUC__) + #if defined(__aarch64__) + #define FP16_USE_FP16_TYPE 1 + #endif + #endif + #if !defined(FP16_USE_FP16_TYPE) + #define FP16_USE_FP16_TYPE 0 + #endif // !defined(FP16_USE_FP16_TYPE) +#endif // !defined(FP16_USE_FP16_TYPE) + +#endif /* FP16_MACROS_H */ diff --git a/third_party/FP16/readme.txt b/third_party/FP16/readme.txt new file mode 100644 index 000000000..9a09e623d --- /dev/null +++ b/third_party/FP16/readme.txt @@ -0,0 +1,13 @@ +This folder contains files from FP16. See LICENSE for their license. + +These files were synced from + +commit 98b0a46bce017382a6351a19577ec43a715b6835 (HEAD -> master, origin/master, origin/HEAD) +Author: Marat Dukhan <maratek@gmail.com> +Date: Wed Jun 19 23:11:08 2024 -0700 + + Support native conversions without __fp16/_Float16 types + +and also contain the patch from + +https://github.com/Maratyszcza/FP16/pull/31 616ad91f449a03d0b48a8a124f4d1baa94f869b2 |