diff options
Diffstat (limited to 'candle-flash-attn/kernels/philox.cuh')
-rw-r--r-- | candle-flash-attn/kernels/philox.cuh | 120 |
1 files changed, 3 insertions, 117 deletions
diff --git a/candle-flash-attn/kernels/philox.cuh b/candle-flash-attn/kernels/philox.cuh index 6ce1440f..cd7e4d2f 100644 --- a/candle-flash-attn/kernels/philox.cuh +++ b/candle-flash-attn/kernels/philox.cuh @@ -9,7 +9,7 @@ struct ull2 { unsigned long long y; }; -inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { +__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { uint2 *res; unsigned long long tmp; asm ("mul.wide.u32 %0, %1, %2;\n\t" @@ -19,7 +19,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { return *res; } -inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { +__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { constexpr unsigned long kPhiloxSA = 0xD2511F53; constexpr unsigned long kPhiloxSB = 0xCD9E8D57; uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); @@ -28,7 +28,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { return ret; } -inline __device__ uint4 philox(unsigned long long seed, +__forceinline__ __device__ uint4 philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset) { constexpr unsigned long kPhilox10A = 0x9E3779B9; @@ -49,117 +49,3 @@ inline __device__ uint4 philox(unsigned long long seed, } } // namespace flash - -namespace { - -class Philox { -public: - __device__ inline Philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) - : STATE(0) - , seed_(seed) - , offset_(offset) - , key(reinterpret_cast<const uint2&>(seed)) { - //key.x = (unsigned int)seed; - //key.y = (unsigned int)(seed >> 32); - //counter = make_uint4(0, 0, 0, 0); - //counter.z = (unsigned int)(subsequence); - //counter.w = (unsigned int)(subsequence >> 32); - //STATE = 0; - //incr_n(offset / 4); - - // key = reinterpret_cast<const uint2&>(seed); - ull2 * tmp = reinterpret_cast<ull2*>(&counter); - tmp->x = offset / 4; - tmp->y = subsequence; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); - // } - } - __device__ inline uint4 operator()() { - // // if (STATE == 0) { - // uint4 counter_ = counter; - // uint2 key_ = key; - // // 7-round philox - // #pragma unroll - // for (int i = 0; i < 6; i++) { - // counter_ = flash::philox_single_round(counter_, key_); - // key_.x += (kPhilox10A); - // key_.y += (kPhilox10B); - // } - // // output = philox_single_round(counter_, key_); - // uint4 output = flash::philox_single_round(counter_, key_); - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); - // // } - // incr(); - // // } - // // return a float4 directly - // // unsigned long ret; - // // switch(STATE) { - // // case 0: ret = output.x; break; - // // case 1: ret = output.y; break; - // // case 2: ret = output.z; break; - // // case 3: ret = output.w; break; - // //} - // // STATE = (STATE + 1) % 4; - // return output; - return flash::philox(seed_, offset_, offset_); - } - -private: - unsigned long long offset_, seed_; - struct ull2 { - uint64_t x; - uint64_t y; - }; - uint4 counter; - // uint4 output; - const uint2 key; - unsigned int STATE; - __device__ inline void incr_n(unsigned long long n) { - unsigned int nlo = (unsigned int)(n); - unsigned int nhi = (unsigned int)(n >> 32); - counter.x += nlo; - if (counter.x < nlo) - nhi++; - counter.y += nhi; - if (nhi <= counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - - __device__ uint4 incr128 (uint4 ctr) - { - uint4 res; - asm ("add.cc.u32 %0, %4, %8;\n\t" - "addc.cc.u32 %1, %5, %9;\n\t" - "addc.cc.u32 %2, %6, %10;\n\t" - "addc.u32 %3, %7, %11;\n\t" - : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) - : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), - "n"(1), "n"(0), "n"(0), "n"(0)); - return res; - } - - __device__ inline void incr() { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // } - counter = incr128(counter); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // } - } - - static const unsigned long kPhilox10A = 0x9E3779B9; - static const unsigned long kPhilox10B = 0xBB67AE85; - // static const unsigned long kPhiloxSA = 0xD2511F53; - // static const unsigned long kPhiloxSB = 0xCD9E8D57; -}; - -} // namespace |