-rw-r--r--   candle-core/src/backend.rs             |  6
-rw-r--r--   candle-core/src/cpu_backend.rs         | 61
-rw-r--r--   candle-core/src/cuda_backend.rs        | 56
-rw-r--r--   candle-core/src/device.rs              | 17
-rw-r--r--   candle-core/src/dummy_cuda_backend.rs  |  4
-rw-r--r--   candle-core/src/dummy_metal_backend.rs |  4
-rw-r--r--   candle-core/src/metal_backend.rs       | 10
-rw-r--r--   candle-core/src/tensor.rs              |  8
-rw-r--r--   candle-core/src/tensor_cat.rs          |  4
9 files changed, 154 insertions(+), 16 deletions(-)
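
Every hunk below applies the same transformation: a zero-filled allocation (zeros / zeros_impl) whose contents are fully overwritten right afterwards is replaced by the new unsafe alloc_uninit, skipping the redundant zero-fill. A minimal, self-contained sketch of the CPU-side pattern, for orientation before reading the diff; the helper name alloc_uninit_f32 is illustrative only, not candle's API:

// Sketch of the uninitialized-buffer pattern introduced by this diff. In
// candle the equivalent is BackendDevice::alloc_uninit followed immediately
// by copy_strided_src.
#[allow(clippy::uninit_vec)]
unsafe fn alloc_uninit_f32(len: usize) -> Vec<f32> {
    // Reserve the storage but skip the zero-fill that vec![0f32; len]
    // would perform.
    let mut v = Vec::with_capacity(len);
    v.set_len(len);
    v
}

fn main() {
    let src = vec![1f32, 2., 3., 4.];
    // Safety contract from backend.rs: initialize the data as early as
    // possible after the call. Here the buffer is fully overwritten before
    // any element is read.
    let mut dst = unsafe { alloc_uninit_f32(src.len()) };
    dst.copy_from_slice(&src);
    assert_eq!(dst, src);
}

The sketch is only plausible because f32 is Copy, has no Drop glue, and every bit pattern is a valid value; the cpu_backend.rs comment below spells out the same restriction.
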
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index c63aad54..27ffe934 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -127,6 +127,12 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
 
     fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
 
+    /// # Safety
+    /// This function is unsafe as it doesn't initialize the underlying data store.
+    /// The caller should ensure that the data is properly initialized as early as possible
+    /// after this call.
+    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
 
     fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index fa48577c..6d2ba361 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2582,7 +2582,10 @@ impl BackendStorage for CpuStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -2590,7 +2593,7 @@
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
@@ -2681,7 +2684,10 @@
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -2691,7 +2697,7 @@
         let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
             .transpose(1, 2)?
             .transpose(1, 3)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
@@ -2919,6 +2925,53 @@ impl BackendDevice for CpuDevice {
         }
     }
 
+    #[allow(clippy::uninit_vec)]
+    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
+        let elem_count = shape.elem_count();
+        // The code below is highly unsafe but hopefully not directly unsound as we only consider
+        // types that are Copy, not Drop, and for which all bit patterns are proper values.
+        // It's still pretty risky, see the following for more details:
+        // https://github.com/rust-lang/rust-clippy/issues/4483
+        let storage = match dtype {
+            DType::U8 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::U8(v)
+            }
+            DType::U32 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::U32(v)
+            }
+            DType::I64 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::I64(v)
+            }
+            DType::BF16 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::BF16(v)
+            }
+            DType::F16 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::F16(v)
+            }
+            DType::F32 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::F32(v)
+            }
+            DType::F64 => {
+                let mut v = Vec::with_capacity(elem_count);
+                v.set_len(elem_count);
+                CpuStorage::F64(v)
+            }
+        };
+        Ok(storage)
+    }
+
     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
         let storage = match dtype {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index fec37c39..f0f03053 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -384,6 +384,44 @@ impl BackendDevice for CudaDevice {
         self.const_impl(1., shape, dtype)
     }
 
+    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
+        let elem_count = shape.elem_count();
+        let slice = match dtype {
+            DType::U8 => {
+                let data = self.alloc::<u8>(elem_count).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            DType::U32 => {
+                let data = self.alloc::<u32>(elem_count).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            DType::I64 => {
+                let data = self.alloc::<i64>(elem_count).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            DType::BF16 => {
+                let data = self.alloc::<bf16>(elem_count).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            DType::F16 => {
+                let data = self.alloc::<f16>(elem_count).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            DType::F32 => {
+                let data = self.alloc::<f32>(elem_count).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let data = self.alloc::<f64>(elem_count).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
@@ -1916,7 +1954,10 @@ impl BackendStorage for CudaStorage {
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -1924,7 +1965,7 @@
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         };
         let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
@@ -1981,7 +2022,10 @@
             col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
         } else {
             // Make the kernel contiguous if not already the case.
-            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            let mut kernel_c = unsafe {
+                self.device()
+                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+            };
             kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
             let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                 .transpose(1, 2)?
@@ -1991,7 +2035,7 @@
         let res_l = Layout::contiguous((b, h_out, w_out, n))
             .transpose(1, 2)?
             .transpose(1, 3)?;
-        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
         res.copy_strided_src(&mut res_t, 0, &res_l)?;
         Ok(res_t)
     }
@@ -2128,7 +2172,7 @@
         dim: usize,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
         self.copy_strided_src(&mut acc, 0, l)?;
         ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
         Ok(acc)
@@ -2143,7 +2187,7 @@
         dim: usize,
     ) -> Result<Self> {
         let device = self.device().clone();
-        let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
         self.copy_strided_src(&mut acc, 0, l)?;
         IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
         Ok(acc)
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 9c39d27a..846c62ce 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -289,6 +289,23 @@ impl Device {
         }
     }
 
+    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
+        match self {
+            Device::Cpu => {
+                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
+                Ok(Storage::Cpu(storage))
+            }
+            Device::Cuda(device) => {
+                let storage = device.alloc_uninit(shape, dtype)?;
+                Ok(Storage::Cuda(storage))
+            }
+            Device::Metal(device) => {
+                let storage = device.alloc_uninit(shape, dtype)?;
+                Ok(Storage::Metal(storage))
+            }
+        }
+    }
+
     pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
             Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index d4887f19..5348233c 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -210,6 +210,10 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index 33c6c9fe..322f81d2 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -222,6 +222,10 @@ impl crate::backend::BackendDevice for MetalDevice {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
+    unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 4f4162e2..ef044fc8 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1886,6 +1886,16 @@ impl BackendDevice for MetalDevice {
         self.device.registry_id() == rhs.device.registry_id()
     }
 
+    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?;
+        Ok(MetalStorage::new(
+            buffer,
+            self.clone(),
+            shape.elem_count(),
+            dtype,
+        ))
+    }
+
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
         let size = shape.elem_count() * dtype.size_in_bytes();
         let buffer = self.allocate_zeros(size)?;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d7c2ed66..6b5aed96 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1349,7 +1349,7 @@ impl Tensor {
             }
             .bt())?
         }
-        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
         self.storage()
             .copy_strided_src(&mut storage, 0, self.layout())?;
         let offset = start * src.dims()[1..].iter().product::<usize>();
@@ -1999,7 +1999,7 @@
             Ok(self.clone())
         } else {
             let shape = self.shape();
-            let mut storage = self.device().zeros(shape, self.dtype())?;
+            let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
             self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             let op = BackpropOp::new1(self, Op::Copy);
@@ -2011,7 +2011,7 @@
     /// copied.
     pub(crate) fn make_var(&self) -> Result<Tensor> {
         let shape = self.shape().clone();
-        let mut storage = self.device().zeros(&shape, self.dtype())?;
+        let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
         self.storage()
             .copy_strided_src(&mut storage, 0, self.layout())?;
         Ok(from_storage(storage, shape, BackpropOp::none(), true))
@@ -2064,7 +2064,7 @@
             };
             Ok(Tensor(Arc::new(tensor_)))
         } else {
-            let mut storage = self.device().zeros(&shape, self.dtype())?;
+            let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
             self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(storage, shape, op, false))
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 25acc80e..31cc8503 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -141,7 +141,7 @@ impl Tensor {
         }
         let shape = Shape::from(cat_dims);
         let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
-        let mut storage = device.zeros(&shape, dtype)?;
+        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
             arg.storage()
@@ -215,7 +215,7 @@
         let block_size: usize = cat_dims.iter().skip(1 + dim).product();
         let shape = Shape::from(cat_dims);
         let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
-        let mut storage = device.zeros(&shape, dtype)?;
+        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
         let mut dst_o = 0;
         for arg in args.iter() {
             let arg = arg.as_ref();
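
As the comment in cpu_backend.rs notes, calling set_len on an uninitialized Vec remains risky even for Copy types (see the linked clippy issue rust-lang/rust-clippy#4483). For comparison, a stricter sketch of the same allocate-then-initialize idea using MaybeUninit instead of set_len; the filled_from helper is hypothetical and not part of this diff:

use std::mem::{ManuallyDrop, MaybeUninit};

// Allocate `len` slots without initializing them, write every slot exactly
// once, and only then reassemble a Vec<T>. `fill` plays the role that
// copy_strided_src plays at the call sites above.
fn filled_from<T: Copy>(len: usize, fill: impl Fn(usize) -> T) -> Vec<T> {
    let mut buf: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
    // Growing a Vec<MaybeUninit<T>> is fine: uninitialized contents are a
    // legitimate state for MaybeUninit.
    buf.resize_with(len, MaybeUninit::uninit);
    for (i, slot) in buf.iter_mut().enumerate() {
        slot.write(fill(i));
    }
    let mut buf = ManuallyDrop::new(buf);
    let (ptr, len, cap) = (buf.as_mut_ptr(), buf.len(), buf.capacity());
    // Safety: every element was written in the loop above, and
    // MaybeUninit<T> is layout-compatible with T.
    unsafe { Vec::from_raw_parts(ptr.cast::<T>(), len, cap) }
}

fn main() {
    let v = filled_from(4, |i| (i * i) as u32);
    assert_eq!(v, [0, 1, 4, 9]);
}

The trade-off: MaybeUninit makes the initialization obligation explicit in the types at the cost of an extra conversion step, while the diff keeps the plain Vec representation and confines the obligation to the safety comment on alloc_uninit.
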