diff options
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/reduce.cu | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 48bbcd83..2af81c42 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -150,7 +150,7 @@ __device__ void softmax(const T * x, T * dst, const int ncols) { template <typename T> __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (2 * idx > bh * td) return; + if (2 * idx >= bh * td) return; uint32_t rope_idx = idx % (td / 2); T c = cos[rope_idx]; @@ -163,7 +163,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons template <typename T> __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (2 * idx > bh * td) return; + if (2 * idx >= bh * td) return; uint32_t i_bh = idx / (td / 2); uint32_t i_td = idx - (td / 2) * i_bh; |