diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-29 14:08:44 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-29 14:08:44 +0200 |
commit | fa06f5f5f9a05c8d0c246e761e94a73680c510a6 (patch) | |
tree | 73618e34b74f9c1876922dff0e27d1f04f0231a5 /candle-core/src/cuda_backend/mod.rs | |
parent | 09d4845aa842dc5d9da650fd7865c4f0855dcf97 (diff) | |
download | candle-fa06f5f5f9a05c8d0c246e761e94a73680c510a6.tar.gz candle-fa06f5f5f9a05c8d0c246e761e94a73680c510a6.tar.bz2 candle-fa06f5f5f9a05c8d0c246e761e94a73680c510a6.zip |
F16/BF16 bugfix (bis). (#2143)
* F16/BF16 bugfix (bis).
* Another fix.
* Yet another fix.
Diffstat (limited to 'candle-core/src/cuda_backend/mod.rs')
-rw-r--r-- | candle-core/src/cuda_backend/mod.rs | 50 |
1 file changed, 36 insertions, 14 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index dbb89eaf..39b41d2e 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1890,14 +1890,24 @@ unsafe fn gemm_strided_batched_f16( use cudarc::cublas::sys; use cudarc::driver::DevicePtrMut; - let compute_type = if gemm_reduced_precision_f16() { - sys::cublasComputeType_t::CUBLAS_COMPUTE_16F + let alpha = cfg.gemm.alpha; + let beta = cfg.gemm.beta; + let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); + let beta_f32: f32 = cfg.gemm.beta.to_f32(); + let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() { + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_16F, + (&alpha) as *const f16 as *const _, + (&beta) as *const f16 as *const _, + ) } else { - sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, + ) }; - let alpha: f32 = cfg.gemm.alpha.to_f32(); - let beta: f32 = cfg.gemm.beta.to_f32(); cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, @@ -1905,7 +1915,7 @@ unsafe fn gemm_strided_batched_f16( cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, - (&alpha) as *const f32 as *const _, + alpha, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.lda, @@ -1914,7 +1924,7 @@ unsafe fn gemm_strided_batched_f16( sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.ldb, cfg.stride_b, - (&beta) as *const f32 as *const _, + beta, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.ldc, @@ -1935,14 +1945,26 @@ unsafe fn gemm_strided_batched_bf16( use cudarc::cublas::sys; use cudarc::driver::DevicePtrMut; - let compute_type = if gemm_reduced_precision_bf16() { - sys::cublasComputeType_t::CUBLAS_COMPUTE_16F + let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); + let beta_f32: f32 = cfg.gemm.beta.to_f32(); + let alpha = f16::from_f32(alpha_f32); + let beta = f16::from_f32(beta_f32); 
+ // The type for alpha and beta depends on the computeType. + // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex + let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() { + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_16F, + (&alpha) as *const f16 as *const _, + (&beta) as *const f16 as *const _, + ) } else { - sys::cublasComputeType_t::CUBLAS_COMPUTE_32F + ( + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, + ) }; - let alpha: f32 = cfg.gemm.alpha.to_f32(); - let beta: f32 = cfg.gemm.beta.to_f32(); cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, @@ -1950,7 +1972,7 @@ unsafe fn gemm_strided_batched_bf16( cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, - (&alpha) as *const f32 as *const _, + alpha, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.lda, @@ -1959,7 +1981,7 @@ unsafe fn gemm_strided_batched_bf16( sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.ldb, cfg.stride_b, - (&beta) as *const f32 as *const _, + beta, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.ldc, |