summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend/mod.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-29 13:30:11 +0200
committerGitHub <noreply@github.com>2024-04-29 13:30:11 +0200
commit09d4845aa842dc5d9da650fd7865c4f0855dcf97 (patch)
treee19b40da2eaa05c250ad776dcafa1f6789763406 /candle-core/src/cuda_backend/mod.rs
parenta0d03aded1b8c4cfe96f7d6490f5c709c31b76f0 (diff)
downloadcandle-09d4845aa842dc5d9da650fd7865c4f0855dcf97.tar.gz
candle-09d4845aa842dc5d9da650fd7865c4f0855dcf97.tar.bz2
candle-09d4845aa842dc5d9da650fd7865c4f0855dcf97.zip
Bugfix the recent f16/bf16 changes. (#2142)
Diffstat (limited to 'candle-core/src/cuda_backend/mod.rs')
-rw-r--r--candle-core/src/cuda_backend/mod.rs16
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 67ed56e0..dbb89eaf 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1896,8 +1896,8 @@ unsafe fn gemm_strided_batched_f16(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
- let alpha = cfg.gemm.alpha;
- let beta = cfg.gemm.beta;
+ let alpha: f32 = cfg.gemm.alpha.to_f32();
+ let beta: f32 = cfg.gemm.beta.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
@@ -1905,7 +1905,7 @@ unsafe fn gemm_strided_batched_f16(
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
- (&alpha) as *const f16 as *const _,
+ (&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
@@ -1914,7 +1914,7 @@ unsafe fn gemm_strided_batched_f16(
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
- (&beta) as *const f16 as *const _,
+ (&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
@@ -1941,8 +1941,8 @@ unsafe fn gemm_strided_batched_bf16(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
- let alpha = cfg.gemm.alpha;
- let beta = cfg.gemm.beta;
+ let alpha: f32 = cfg.gemm.alpha.to_f32();
+ let beta: f32 = cfg.gemm.beta.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
@@ -1950,7 +1950,7 @@ unsafe fn gemm_strided_batched_bf16(
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
- (&alpha) as *const bf16 as *const _,
+ (&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.lda,
@@ -1959,7 +1959,7 @@ unsafe fn gemm_strided_batched_bf16(
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldb,
cfg.stride_b,
- (&beta) as *const bf16 as *const _,
+ (&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldc,