diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-22 07:25:23 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 07:25:23 +0100 |
commit | 6708870e633af636660c556c19703c38cbe2af8d (patch) | |
tree | 5b1f7e3eac1e0be5cc25a5d16db43cc717ff6ee0 /candle-core/src/cuda_backend.rs | |
parent | a00e24d752f3f62978c878859a01a4246244d4bc (diff) | |
download | candle-6708870e633af636660c556c19703c38cbe2af8d.tar.gz candle-6708870e633af636660c556c19703c38cbe2af8d.tar.bz2 candle-6708870e633af636660c556c19703c38cbe2af8d.zip |
Add the alloc_uninit function. (#1901)
* Add the alloc_uninit function.
* Dummy metal fix.
* Lazy initialization.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 56 |
1 files changed, 50 insertions, 6 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index fec37c39..f0f03053 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -384,6 +384,44 @@ impl BackendDevice for CudaDevice { self.const_impl(1., shape, dtype) } + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc::<u8>(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc::<u32>(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + let data = self.alloc::<i64>(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc::<bf16>(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc::<f16>(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc::<f32>(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc::<f64>(elem_count).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> { let slice = match storage { CpuStorage::U8(storage) => { @@ -1916,7 +1954,10 @@ impl BackendStorage for CudaStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. - let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? @@ -1924,7 +1965,7 @@ impl BackendStorage for CudaStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; - let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } @@ -1981,7 +2022,10 @@ impl BackendStorage for CudaStorage { col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. - let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?; + let mut kernel_c = unsafe { + self.device() + .alloc_uninit(kernel_l.shape(), kernel.dtype())? + }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? @@ -1991,7 +2035,7 @@ impl BackendStorage for CudaStorage { let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? .transpose(1, 3)?; - let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?; + let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } @@ -2128,7 +2172,7 @@ impl BackendStorage for CudaStorage { dim: usize, ) -> Result<Self> { let device = self.device().clone(); - let mut acc = device.zeros_impl(l.shape(), self.dtype())?; + let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; self.copy_strided_src(&mut acc, 0, l)?; ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; Ok(acc) @@ -2143,7 +2187,7 @@ impl BackendStorage for CudaStorage { dim: usize, ) -> Result<Self> { let device = self.device().clone(); - let mut acc = device.zeros_impl(l.shape(), self.dtype())?; + let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; self.copy_strided_src(&mut acc, 0, l)?; IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; Ok(acc) |