diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-23 08:15:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 07:15:37 +0100 |
commit | b8a10425ad550b04ccf3b5ff2493714615d7df4b (patch) | |
tree | eab9ad609e34e4bad23cedc81ee338fe00961c3f /candle-core/src/cuda_backend.rs | |
parent | 5f20acf0804a624d6c274e488c897fb88d698f1a (diff) | |
download | candle-b8a10425ad550b04ccf3b5ff2493714615d7df4b.tar.gz candle-b8a10425ad550b04ccf3b5ff2493714615d7df4b.tar.bz2 candle-b8a10425ad550b04ccf3b5ff2493714615d7df4b.zip |
Kernel build example (#224)
* Build example kernels.
* Add some sample custom kernels.
* Get the example kernel to compile.
* Add some cuda code.
* More cuda custom op.
* More cuda custom ops.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 44 |
1 file changed, 44 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index f9fefe17..d2cc3e41 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -771,6 +771,50 @@ pub struct CudaStorage { device: CudaDevice, } +pub trait CudaDType: Sized { + fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>; + fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage; +} + +macro_rules! cuda_dtype { + ($ty:ty, $dtype:ident) => { + impl CudaDType for $ty { + fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> { + match &s.slice { + CudaStorageSlice::$dtype(data) => Ok(&data), + _ => Err(crate::Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + msg: "unexpected dtype", + } + .bt()), + } + } + + fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage { + let slice = CudaStorageSlice::$dtype(slice); + CudaStorage { slice, device } + } + } + }; +} +cuda_dtype!(u8, U8); +cuda_dtype!(u32, U32); +cuda_dtype!(f16, F16); +cuda_dtype!(bf16, BF16); +cuda_dtype!(f32, F32); +cuda_dtype!(f64, F64); + +impl CudaStorage { + pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage { + T::wrap_cuda_slice(slice, device) + } + + pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> { + T::as_cuda_slice(self) + } +} + fn gemm_config<T>( alpha: T, beta: T, |