-rw-r--r--  candle-core/src/backend.rs             |  6
-rw-r--r--  candle-core/src/cpu_backend.rs         | 61
-rw-r--r--  candle-core/src/cuda_backend.rs        | 56
-rw-r--r--  candle-core/src/device.rs              | 17
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  |  4
-rw-r--r--  candle-core/src/dummy_metal_backend.rs |  4
-rw-r--r--  candle-core/src/metal_backend.rs       | 10
-rw-r--r--  candle-core/src/tensor.rs              |  8
-rw-r--r--  candle-core/src/tensor_cat.rs          |  4
9 files changed, 154 insertions(+), 16 deletions(-)
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index c63aad54..27ffe934 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -127,6 +127,12 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+ /// # Safety
+ /// This function is unsafe as it doesn't initialize the underlying data store.
+ /// The caller should ensure that the data is properly initialized as early as possible
+ /// after this call.
+ unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
+
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
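The new `alloc_uninit` lets backends skip zero-initialization on paths that overwrite the whole buffer anyway. Its contract is the one every call site in this diff follows: allocate, then write every element before anything reads the storage, typically via `copy_strided_src`. A minimal sketch of that calling pattern against the public `BackendDevice`/`BackendStorage` traits (the `copy_contiguous` helper is hypothetical, not part of this patch):

```rust
use candle_core::backend::{BackendDevice, BackendStorage};
use candle_core::{Layout, Result};

// Hypothetical helper showing the intended use of `alloc_uninit`.
fn copy_contiguous<D: BackendDevice>(dev: &D, src: &D::Storage, l: &Layout) -> Result<D::Storage> {
    // SAFETY: `copy_strided_src` writes every element of `dst` (the
    // destination is contiguous and has the same shape as `l`), so no
    // uninitialized byte is ever read.
    let mut dst = unsafe { dev.alloc_uninit(l.shape(), src.dtype())? };
    src.copy_strided_src(&mut dst, 0, l)?;
    Ok(dst)
}
```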
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index fa48577c..6d2ba361 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2582,7 +2582,10 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
- let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ let mut kernel_c = unsafe {
+ self.device()
+ .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+ };
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@@ -2590,7 +2593,7 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
- let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@@ -2681,7 +2684,10 @@ impl BackendStorage for CpuStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
- let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ let mut kernel_c = unsafe {
+ self.device()
+ .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+ };
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@@ -2691,7 +2697,7 @@ impl BackendStorage for CpuStorage {
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)?
.transpose(1, 3)?;
- let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@@ -2919,6 +2925,53 @@ impl BackendDevice for CpuDevice {
}
}
+ #[allow(clippy::uninit_vec)]
+ unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
+ let elem_count = shape.elem_count();
+ // The code below is highly unsafe but hopefully not directly unsound as we only consider
+ // types that are Copy, not Drop, and for which all bit patterns are proper values.
+ // It's still pretty risky, see the following for more details:
+ // https://github.com/rust-lang/rust-clippy/issues/4483
+ let storage = match dtype {
+ DType::U8 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::U8(v)
+ }
+ DType::U32 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::U32(v)
+ }
+ DType::I64 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::I64(v)
+ }
+ DType::BF16 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::BF16(v)
+ }
+ DType::F16 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::F16(v)
+ }
+ DType::F32 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::F32(v)
+ }
+ DType::F64 => {
+ let mut v = Vec::with_capacity(elem_count);
+ v.set_len(elem_count);
+ CpuStorage::F64(v)
+ }
+ };
+ Ok(storage)
+ }
+
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
let elem_count = shape.elem_count();
let storage = match dtype {
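All seven dtype arms above repeat the same `Vec::with_capacity` / `set_len` two-liner. Under the invariant stated in the comment (`Copy` element types, no `Drop`, every bit pattern a valid value), the repetition could be folded into one generic helper; a sketch under those same assumptions, not part of the patch:

```rust
/// Hypothetical helper folding the per-dtype repetition into one place.
///
/// # Safety
/// `T` must be `Copy`, with no `Drop` glue and every bit pattern a valid
/// value (true for u8, u32, i64, f16, bf16, f32 and f64), and the caller
/// must overwrite the contents before reading them.
#[allow(clippy::uninit_vec)]
unsafe fn uninit_vec<T: Copy>(len: usize) -> Vec<T> {
    let mut v = Vec::with_capacity(len);
    // Claim `len` elements without writing them; see the clippy issue linked
    // in the comment above for the caveats.
    v.set_len(len);
    v
}
```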
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index fec37c39..f0f03053 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -384,6 +384,44 @@ impl BackendDevice for CudaDevice {
self.const_impl(1., shape, dtype)
}
+ unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
+ let elem_count = shape.elem_count();
+ let slice = match dtype {
+ DType::U8 => {
+ let data = self.alloc::<u8>(elem_count).w()?;
+ CudaStorageSlice::U8(data)
+ }
+ DType::U32 => {
+ let data = self.alloc::<u32>(elem_count).w()?;
+ CudaStorageSlice::U32(data)
+ }
+ DType::I64 => {
+ let data = self.alloc::<i64>(elem_count).w()?;
+ CudaStorageSlice::I64(data)
+ }
+ DType::BF16 => {
+ let data = self.alloc::<bf16>(elem_count).w()?;
+ CudaStorageSlice::BF16(data)
+ }
+ DType::F16 => {
+ let data = self.alloc::<f16>(elem_count).w()?;
+ CudaStorageSlice::F16(data)
+ }
+ DType::F32 => {
+ let data = self.alloc::<f32>(elem_count).w()?;
+ CudaStorageSlice::F32(data)
+ }
+ DType::F64 => {
+ let data = self.alloc::<f64>(elem_count).w()?;
+ CudaStorageSlice::F64(data)
+ }
+ };
+ Ok(CudaStorage {
+ slice,
+ device: self.clone(),
+ })
+ }
+
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
@@ -1916,7 +1954,10 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
- let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ let mut kernel_c = unsafe {
+ self.device()
+ .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+ };
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@@ -1924,7 +1965,7 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
- let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@@ -1981,7 +2022,10 @@ impl BackendStorage for CudaStorage {
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
- let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ let mut kernel_c = unsafe {
+ self.device()
+ .alloc_uninit(kernel_l.shape(), kernel.dtype())?
+ };
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
@@ -1991,7 +2035,7 @@ impl BackendStorage for CudaStorage {
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
- let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
@@ -2128,7 +2172,7 @@ impl BackendStorage for CudaStorage {
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
- let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+ let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
@@ -2143,7 +2187,7 @@ impl BackendStorage for CudaStorage {
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
- let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+ let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
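On the CUDA side the heavy lifting is done by cudarc's `alloc`, which hands back device memory without clearing it (and, being itself unsafe, needs no extra `unsafe` block inside the already-`unsafe` method). The `scatter_add`/`index_add` changes above deserve a note: replacing `zeros_impl` with `alloc_uninit` stays sound because the accumulator is seeded with a full copy of `self` before the kernel adds into it. A scalar reference of that semantics for `index_add` along dim 0 (hypothetical `index_add_ref`; candle runs the real op as the `IndexAdd` CUDA kernel):

```rust
// Reference semantics for `index_add` along dim 0 on flat f32 slices.
fn index_add_ref(init: &[f32], ids: &[usize], src: &[f32]) -> Vec<f32> {
    // Full copy of `init`: upstream this is the copy_strided_src call that
    // writes every element of the uninitialized accumulator.
    let mut acc = init.to_vec();
    for (i, &id) in ids.iter().enumerate() {
        acc[id] += src[i];
    }
    acc
}
```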
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 9c39d27a..846c62ce 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -289,6 +289,23 @@ impl Device {
}
}
+ pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
+ match self {
+ Device::Cpu => {
+ let storage = CpuDevice.alloc_uninit(shape, dtype)?;
+ Ok(Storage::Cpu(storage))
+ }
+ Device::Cuda(device) => {
+ let storage = device.alloc_uninit(shape, dtype)?;
+ Ok(Storage::Cuda(storage))
+ }
+ Device::Metal(device) => {
+ let storage = device.alloc_uninit(shape, dtype)?;
+ Ok(Storage::Metal(storage))
+ }
+ }
+ }
+
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index d4887f19..5348233c 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -210,6 +210,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
+ unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index 33c6c9fe..322f81d2 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -222,6 +222,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
+ unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 4f4162e2..ef044fc8 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1886,6 +1886,16 @@ impl BackendDevice for MetalDevice {
self.device.registry_id() == rhs.device.registry_id()
}
+ unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-uninit")?;
+ Ok(MetalStorage::new(
+ buffer,
+ self.clone(),
+ shape.elem_count(),
+ dtype,
+ ))
+ }
+
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let size = shape.elem_count() * dtype.size_in_bytes();
let buffer = self.allocate_zeros(size)?;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d7c2ed66..6b5aed96 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1349,7 +1349,7 @@ impl Tensor {
}
.bt())?
}
- let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+ let mut storage = unsafe { self.device().alloc_uninit(self.shape(), self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
@@ -1999,7 +1999,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
- let mut storage = self.device().zeros(shape, self.dtype())?;
+ let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = BackpropOp::new1(self, Op::Copy);
@@ -2011,7 +2011,7 @@ impl Tensor {
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
- let mut storage = self.device().zeros(&shape, self.dtype())?;
+ let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
@@ -2064,7 +2064,7 @@ impl Tensor {
};
Ok(Tensor(Arc::new(tensor_)))
} else {
- let mut storage = self.device().zeros(&shape, self.dtype())?;
+ let mut storage = unsafe { self.device().alloc_uninit(&shape, self.dtype())? };
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))
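In `tensor.rs` the replacement lands in `contiguous`, `make_var`, and two other copy-producing paths; each copies the entire source into the fresh allocation, so every element is initialized before use. A small usage-level check of the `contiguous` path, using only safe, public candle API:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::arange(0f32, 6., &Device::Cpu)?.reshape((2, 3))?;
    let tt = t.t()?; // transposed view: strided, not contiguous
    assert!(!tt.is_contiguous());
    // `contiguous` now allocates uninitialized storage and overwrites all of
    // it via copy_strided_src instead of zeroing first.
    let c = tt.contiguous()?;
    assert!(c.is_contiguous());
    assert_eq!(c.to_vec2::<f32>()?, vec![vec![0., 3.], vec![1., 4.], vec![2., 5.]]);
    Ok(())
}
```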
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
index 25acc80e..31cc8503 100644
--- a/candle-core/src/tensor_cat.rs
+++ b/candle-core/src/tensor_cat.rs
@@ -141,7 +141,7 @@ impl Tensor {
}
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
- let mut storage = device.zeros(&shape, dtype)?;
+ let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage()
@@ -215,7 +215,7 @@ impl Tensor {
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
let shape = Shape::from(cat_dims);
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
- let mut storage = device.zeros(&shape, dtype)?;
+ let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
let mut dst_o = 0;
for arg in args.iter() {
let arg = arg.as_ref();
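The two `cat` paths are the only spots where the fresh storage is filled by several copies rather than one, so soundness rests on the per-argument offsets tiling the output exactly. For concatenation along dim 0 the offsets are running sums of each input's element count; a sketch of that bookkeeping (hypothetical `cat0_offsets`, mirroring the `offsets` vector built earlier in the real function):

```rust
// Illustrative bookkeeping for dim-0 concat: each input is copied at a
// distinct offset into the uninitialized output.
fn cat0_offsets(elem_counts: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(elem_counts.len());
    let mut acc = 0;
    for &n in elem_counts {
        offsets.push(acc);
        acc += n;
    }
    offsets
}

// Inputs with 6, 6 and 3 elements land at offsets [0, 6, 12], filling an
// output of 15 elements with no gaps and no overlap, which is exactly the
// property that makes `alloc_uninit` safe here.
```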