-rw-r--r--  candle-core/src/cuda_backend/mod.rs  50
1 file changed, 36 insertions(+), 14 deletions(-)
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index dbb89eaf..39b41d2e 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -1890,14 +1890,24 @@ unsafe fn gemm_strided_batched_f16(
    use cudarc::cublas::sys;
    use cudarc::driver::DevicePtrMut;
-    let compute_type = if gemm_reduced_precision_f16() {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
+    let alpha = cfg.gemm.alpha;
+    let beta = cfg.gemm.beta;
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
    } else {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
    };
-    let alpha: f32 = cfg.gemm.alpha.to_f32();
-    let beta: f32 = cfg.gemm.beta.to_f32();
    cudarc::cublas::result::gemm_strided_batched_ex(
        *cublas.handle(),
        cfg.gemm.transa,
@@ -1905,7 +1915,7 @@ unsafe fn gemm_strided_batched_f16(
        cfg.gemm.m,
        cfg.gemm.n,
        cfg.gemm.k,
-        (&alpha) as *const f32 as *const _,
+        alpha,
        *a.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.lda,
@@ -1914,7 +1924,7 @@ unsafe fn gemm_strided_batched_f16(
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldb,
        cfg.stride_b,
-        (&beta) as *const f32 as *const _,
+        beta,
        *c.device_ptr_mut() as *mut _,
        sys::cudaDataType_t::CUDA_R_16F,
        cfg.gemm.ldc,
@@ -1935,14 +1945,26 @@ unsafe fn gemm_strided_batched_bf16(
    use cudarc::cublas::sys;
    use cudarc::driver::DevicePtrMut;
-    let compute_type = if gemm_reduced_precision_bf16() {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
+    let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
+    let beta_f32: f32 = cfg.gemm.beta.to_f32();
+    let alpha = f16::from_f32(alpha_f32);
+    let beta = f16::from_f32(beta_f32);
+    // The type for alpha and beta depends on the computeType.
+    // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
+    let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
+            (&alpha) as *const f16 as *const _,
+            (&beta) as *const f16 as *const _,
+        )
    } else {
-        sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
+        (
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            (&alpha_f32) as *const f32 as *const _,
+            (&beta_f32) as *const f32 as *const _,
+        )
    };
-    let alpha: f32 = cfg.gemm.alpha.to_f32();
-    let beta: f32 = cfg.gemm.beta.to_f32();
    cudarc::cublas::result::gemm_strided_batched_ex(
        *cublas.handle(),
        cfg.gemm.transa,
@@ -1950,7 +1972,7 @@ unsafe fn gemm_strided_batched_bf16(
        cfg.gemm.m,
        cfg.gemm.n,
        cfg.gemm.k,
-        (&alpha) as *const f32 as *const _,
+        alpha,
        *a.device_ptr() as *const _,
        sys::cudaDataType_t::CUDA_R_16BF,
        cfg.gemm.lda,
@@ -1959,7 +1981,7 @@ unsafe fn gemm_strided_batched_bf16(
        sys::cudaDataType_t::CUDA_R_16BF,
        cfg.gemm.ldb,
        cfg.stride_b,
-        (&beta) as *const f32 as *const _,
+        beta,
        *c.device_ptr_mut() as *mut _,
        sys::cudaDataType_t::CUDA_R_16BF,
        cfg.gemm.ldc,
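
The pattern introduced in both hunks follows the cublasGemmStridedBatchedEx contract referenced in the comment: alpha and beta are passed as type-erased pointers, and their pointee type has to match the selected computeType (f16 scalars for CUBLAS_COMPUTE_16F, f32 scalars for CUBLAS_COMPUTE_32F). The sketch below isolates just that selection logic. It is illustrative only: the helper name alpha_beta_ptrs is not part of the patch, it assumes the half crate for f16, and it performs no cuBLAS call.

use half::f16;
use std::ffi::c_void;

// Pick alpha/beta pointers whose pointee type matches the cuBLAS compute type:
// f16 scalars for CUBLAS_COMPUTE_16F, f32 scalars for CUBLAS_COMPUTE_32F.
fn alpha_beta_ptrs(
    reduced_precision: bool,
    alpha_f16: &f16,
    beta_f16: &f16,
    alpha_f32: &f32,
    beta_f32: &f32,
) -> (*const c_void, *const c_void) {
    if reduced_precision {
        (
            alpha_f16 as *const f16 as *const c_void,
            beta_f16 as *const f16 as *const c_void,
        )
    } else {
        (
            alpha_f32 as *const f32 as *const c_void,
            beta_f32 as *const f32 as *const c_void,
        )
    }
}

fn main() {
    // Keep both precisions of the scalars alive in the caller, as the patch does,
    // so the raw pointers remain valid while the (hypothetical) cuBLAS call runs.
    let (alpha_f32, beta_f32) = (1.0f32, 0.0f32);
    let (alpha_f16, beta_f16) = (f16::from_f32(alpha_f32), f16::from_f32(beta_f32));
    let (alpha_ptr, beta_ptr) = alpha_beta_ptrs(true, &alpha_f16, &beta_f16, &alpha_f32, &beta_f32);
    println!("alpha: {alpha_ptr:?}, beta: {beta_ptr:?}");
}

The patched functions rely on the same lifetime property implicitly: the f16 and f32 copies of alpha and beta are locals of the calling function, so the addresses cast to *const _ stay valid for the duration of gemm_strided_batched_ex.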