summary | refs | log | tree | commit | diff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com>, 2023-07-23 08:15:37 +0200
committer: GitHub <noreply@github.com>, 2023-07-23 07:15:37 +0100
commit: b8a10425ad550b04ccf3b5ff2493714615d7df4b (patch)
tree: eab9ad609e34e4bad23cedc81ee338fe00961c3f /candle-core/src/cuda_backend.rs
parent: 5f20acf0804a624d6c274e488c897fb88d698f1a (diff)
download: candle-b8a10425ad550b04ccf3b5ff2493714615d7df4b.tar.gz
candle-b8a10425ad550b04ccf3b5ff2493714615d7df4b.tar.bz2
candle-b8a10425ad550b04ccf3b5ff2493714615d7df4b.zip
Kernel build example (#224)
* Build example kernels.
* Add some sample custom kernel.
* Get the example kernel to compile.
* Add some cuda code.
* More cuda custom op.
* More cuda custom ops.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- candle-core/src/cuda_backend.rs 44
1 files changed, 44 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f9fefe17..d2cc3e41 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -771,6 +771,50 @@ pub struct CudaStorage {
device: CudaDevice,
}
+/// Scalar types that have a matching typed slice inside a [`CudaStorage`].
+///
+/// The `cuda_dtype!` macro in this file provides the implementations, pairing
+/// each Rust type with one `CudaStorageSlice` variant.
+pub trait CudaDType: Sized {
+    /// Borrows the `CudaSlice<Self>` held by `storage`.
+    ///
+    /// The macro-generated implementations return an `UnexpectedDType` error
+    /// when `storage` holds a different variant.
+    fn as_cuda_slice(storage: &CudaStorage) -> Result<&CudaSlice<Self>>;
+
+    /// Builds a [`CudaStorage`] out of `slice` and `device`.
+    fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage;
+}
+
+// Implements `CudaDType` for the scalar type `$ty`, using the
+// `CudaStorageSlice::$dtype` variant as its backing slice and `DType::$dtype`
+// as the expected dtype reported on a mismatch.
+macro_rules! cuda_dtype {
+    ($ty:ty, $dtype:ident) => {
+        impl CudaDType for $ty {
+            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
+                // Only the `$dtype` variant can be handed out as `CudaSlice<$ty>`.
+                if let CudaStorageSlice::$dtype(data) = &s.slice {
+                    return Ok(data);
+                }
+                // Any other variant is reported together with the dtype actually stored.
+                let mismatch = crate::Error::UnexpectedDType {
+                    expected: DType::$dtype,
+                    got: s.dtype(),
+                    msg: "unexpected dtype",
+                };
+                Err(mismatch.bt())
+            }
+
+            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
+                CudaStorage {
+                    slice: CudaStorageSlice::$dtype(slice),
+                    device,
+                }
+            }
+        }
+    };
+}
+
+// Instantiate the impl for each (Rust type, storage variant) pair below.
+cuda_dtype!(u8, U8);
+cuda_dtype!(u32, U32);
+cuda_dtype!(f16, F16);
+cuda_dtype!(bf16, BF16);
+cuda_dtype!(f32, F32);
+cuda_dtype!(f64, F64);
+
+impl CudaStorage {
+    /// Wraps `slice` and `device` into a storage value by delegating to
+    /// [`CudaDType::wrap_cuda_slice`] for the element type `T`.
+    pub fn wrap_cuda_slice<T>(slice: CudaSlice<T>, device: CudaDevice) -> Self
+    where
+        T: CudaDType,
+    {
+        <T as CudaDType>::wrap_cuda_slice(slice, device)
+    }
+
+    /// Borrows this storage as a `CudaSlice<T>` by delegating to
+    /// [`CudaDType::as_cuda_slice`]; any error comes from that implementation.
+    pub fn as_cuda_slice<T>(&self) -> Result<&CudaSlice<T>>
+    where
+        T: CudaDType,
+    {
+        <T as CudaDType>::as_cuda_slice(self)
+    }
+}
+
fn gemm_config<T>(
alpha: T,
beta: T,