summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend/mod.rs16
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 67ed56e0..dbb89eaf 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1896,8 +1896,8 @@ unsafe fn gemm_strided_batched_f16(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
- let alpha = cfg.gemm.alpha;
- let beta = cfg.gemm.beta;
+ let alpha: f32 = cfg.gemm.alpha.to_f32();
+ let beta: f32 = cfg.gemm.beta.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
@@ -1905,7 +1905,7 @@ unsafe fn gemm_strided_batched_f16(
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
- (&alpha) as *const f16 as *const _,
+ (&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
@@ -1914,7 +1914,7 @@ unsafe fn gemm_strided_batched_f16(
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
- (&beta) as *const f16 as *const _,
+ (&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
@@ -1941,8 +1941,8 @@ unsafe fn gemm_strided_batched_bf16(
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
- let alpha = cfg.gemm.alpha;
- let beta = cfg.gemm.beta;
+ let alpha: f32 = cfg.gemm.alpha.to_f32();
+ let beta: f32 = cfg.gemm.beta.to_f32();
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
@@ -1950,7 +1950,7 @@ unsafe fn gemm_strided_batched_bf16(
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
- (&alpha) as *const bf16 as *const _,
+ (&alpha) as *const f32 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.lda,
@@ -1959,7 +1959,7 @@ unsafe fn gemm_strided_batched_bf16(
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldb,
cfg.stride_b,
- (&beta) as *const bf16 as *const _,
+ (&beta) as *const f32 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldc,