summary refs log tree commit diff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- candle-core/src/cuda_backend.rs 56
1 files changed, 50 insertions, 6 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index fec37c39..f0f03053 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -384,6 +384,44 @@ impl BackendDevice for CudaDevice {
self.const_impl(1., shape, dtype)
}
+ unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> { // Allocates storage for `shape`/`dtype` with no fill step; assumed caller contract (from `unsafe` + name): write every element before reading — TODO confirm.
+ let elem_count = shape.elem_count(); // Element count handed to every `alloc` call below.
+ let slice = match dtype { // One arm per dtype: allocate with the matching Rust element type and wrap it in the same-named slice variant.
+ DType::U8 => {
+ let data = self.alloc::<u8>(elem_count).w()?; // NOTE(review): `alloc` presumably leaves memory uninitialized and `.w()` presumably converts the driver error into this crate's error — confirm.
+ CudaStorageSlice::U8(data)
+ }
+ DType::U32 => {
+ let data = self.alloc::<u32>(elem_count).w()?;
+ CudaStorageSlice::U32(data)
+ }
+ DType::I64 => {
+ let data = self.alloc::<i64>(elem_count).w()?;
+ CudaStorageSlice::I64(data)
+ }
+ DType::BF16 => {
+ let data = self.alloc::<bf16>(elem_count).w()?;
+ CudaStorageSlice::BF16(data)
+ }
+ DType::F16 => {
+ let data = self.alloc::<f16>(elem_count).w()?;
+ CudaStorageSlice::F16(data)
+ }
+ DType::F32 => {
+ let data = self.alloc::<f32>(elem_count).w()?;
+ CudaStorageSlice::F32(data)
+ }
+ DType::F64 => {
+ let data = self.alloc::<f64>(elem_count).w()?;
+ CudaStorageSlice::F64(data)
+ }
+ };
+ Ok(CudaStorage { // Bundle the freshly allocated slice with a clone of this device handle.
+ slice,
+ device: self.clone(),
+ })
+ }
+
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
@@ -1916,7 +1954,10 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
- let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ let mut kernel_c = unsafe {
+ self.device()
+ .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+ };
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@@ -1924,7 +1965,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
- let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@@ -1981,7 +2022,10 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
- let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ let mut kernel_c = unsafe {
+ self.device()
+ .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+ };
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@@ -1991,7 +2035,7 @@ impl BackendStorage for CudaStorage {
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
- let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@@ -2128,7 +2172,7 @@ impl BackendStorage for CudaStorage {
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
- let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+ let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
@@ -2143,7 +2187,7 @@ impl BackendStorage for CudaStorage {
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
- let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+ let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)