summaryrefslogtreecommitdiff
path: root/candle-kernels/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-25 23:26:05 +0100
committerGitHub <noreply@github.com>2024-03-25 23:26:05 +0100
commit196765e995f7f4bd3b9610a22f8ef5b009437a4e (patch)
tree707ef5ce23ded99c14f0c9e29e8115fa5386ff5c /candle-kernels/src
parent60676780a9436fd0de43b1e8ff99445ab863c066 (diff)
downloadcandle-196765e995f7f4bd3b9610a22f8ef5b009437a4e.tar.gz
candle-196765e995f7f4bd3b9610a22f8ef5b009437a4e.tar.bz2
candle-196765e995f7f4bd3b9610a22f8ef5b009437a4e.zip
Use the new rope kernel in mistral. (#1937)
* Use the new rope kernel in mistral. * Compute the cos and sin with full precision. * Bugfix.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--candle-kernels/src/reduce.cu4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 48bbcd83..2af81c42 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -150,7 +150,7 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
- if (2 * idx > bh * td) return;
+ if (2 * idx >= bh * td) return;
uint32_t rope_idx = idx % (td / 2);
T c = cos[rope_idx];
@@ -163,7 +163,7 @@ __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, cons
template <typename T>
__device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td, const uint32_t d) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
- if (2 * idx > bh * td) return;
+ if (2 * idx >= bh * td) return;
uint32_t i_bh = idx / (td / 2);
uint32_t i_td = idx - (td / 2) * i_bh;