diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 11:04:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 11:04:40 +0100 |
commit | 50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb (patch) | |
tree | c48c4ecc686748e10b678d347af8d46cb0955a6c /candle-core/src/tensor.rs | |
parent | a3663ce2f2b03263075099baed677340974b7f4c (diff) | |
download | candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.gz candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.bz2 candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.zip |
Tensor mutability (#154)
* Working towards tensor mutability.
* Use a ref-cell to provide tensor mutability.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 123 |
1 files changed, 72 insertions, 51 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a174edd0..e06f1d37 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,6 +1,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::shape::Dim; use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape}; +use std::cell::RefCell; use std::sync::Arc; /// Unique identifier for tensors. @@ -18,10 +19,23 @@ impl TensorId { pub struct Tensor_ { id: TensorId, - storage: Arc<Storage>, + // Storage uses a mutex here so inner mutability is available and borrow rules are checked + // dynamically. The alternatives would be: + // - Using a mutex, this would have the highest cost when retrieving the storage but would + // prevent errors when concurrent access takes place. Mutex would also be subject to + // deadlocks for example using the current code if the same tensor is used twice by a single + // binary op. + // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent + // accesses. + // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data + // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but + // that's tricky to encode in the current setup. + storage: Arc<RefCell<Storage>>, layout: Layout, op: Option<Op>, is_variable: bool, + dtype: DType, + device: Device, } impl AsRef<Tensor> for Tensor { @@ -62,7 +76,7 @@ macro_rules! unary_op { pub fn $fn_name(&self) -> Result<Self> { let shape = self.shape(); let storage = self - .storage + .storage()? .unary_impl::<crate::op::$op_name>(self.layout())?; let op = if self.track_op() { Some(Op::$op_name(self.clone())) @@ -78,8 +92,8 @@ macro_rules! 
binary_op { ($fn_name:ident, $op_name:ident) => { pub fn $fn_name(&self, rhs: &Self) -> Result<Self> { let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?; - let storage = self.storage.binary_impl::<crate::op::$op_name>( - &rhs.storage, + let storage = self.storage()?.binary_impl::<crate::op::$op_name>( + &*rhs.storage()?, self.layout(), rhs.layout(), )?; @@ -119,12 +133,16 @@ fn from_storage<S: Into<Shape>>( op: Option<Op>, is_variable: bool, ) -> Tensor { + let dtype = storage.dtype(); + let device = storage.device(); let tensor_ = Tensor_ { id: TensorId::new(), - storage: Arc::new(storage), + storage: Arc::new(RefCell::new(storage)), layout: Layout::contiguous(shape), op, is_variable, + dtype, + device, }; Tensor(Arc::new(tensor_)) } @@ -169,7 +187,7 @@ impl Tensor { /// # Ok::<(), candle::Error>(()) /// ``` pub fn ones_like(&self) -> Result<Self> { - Tensor::ones(self.shape(), self.dtype(), &self.device()) + Tensor::ones(self.shape(), self.dtype(), self.device()) } /// Creates a new tensor filled with zeros. @@ -219,7 +237,7 @@ impl Tensor { /// # Ok::<(), candle::Error>(()) /// ``` pub fn zeros_like(&self) -> Result<Self> { - Tensor::zeros(self.shape(), self.dtype(), &self.device()) + Tensor::zeros(self.shape(), self.dtype(), self.device()) } fn rand_impl<S: Into<Shape>>( @@ -502,7 +520,7 @@ impl Tensor { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok::<_, Error>(data[self.layout().start_offset()]) }; - match self.storage.as_ref() { + match &*self.storage()? 
{ Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } @@ -520,7 +538,7 @@ impl Tensor { /// # Ok::<(), candle::Error>(()) /// ``` pub fn affine(&self, mul: f64, add: f64) -> Result<Self> { - let storage = self.storage.affine(self.layout(), mul, add)?; + let storage = self.storage()?.affine(self.layout(), mul, add)?; let op = if self.track_op() { Some(Op::Affine { arg: self.clone(), @@ -535,7 +553,7 @@ impl Tensor { /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor. pub fn elu(&self, alpha: f64) -> Result<Self> { - let storage = self.storage.elu(self.layout(), alpha)?; + let storage = self.storage()?.elu(self.layout(), alpha)?; let op = if self.track_op() { Some(Op::Elu(self.clone(), alpha)) } else { @@ -585,6 +603,8 @@ impl Tensor { layout, op, is_variable: false, + dtype: self.dtype, + device: self.device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } @@ -616,7 +636,9 @@ impl Tensor { exp.broadcast_div(&sum_exp) } else { let shape = self.shape(); - let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?; + let mut storage = self + .storage()? + .unary_impl::<crate::op::Exp>(self.layout())?; // The resulting storage is contiguous. storage.divide_by_sum_over_dim(shape, dim)?; let op = if self.track_op() { @@ -649,7 +671,7 @@ impl Tensor { for &dim in sum_dims { self.check_dim(dim, "sum")?; } - let storage = self.storage.sum(self.layout(), sum_dims)?; + let storage = self.storage()?.sum(self.layout(), sum_dims)?; let op = if self.track_op() { Some(Op::Sum(self.clone(), sum_dims.to_vec())) } else { @@ -695,8 +717,8 @@ impl Tensor { stride, }; let storage = - self.storage - .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?; + self.storage()? 
+ .conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), &params)?; let op = if self.track_op() || kernel.track_op() { Some(Op::Conv1D { arg: self.clone(), @@ -749,8 +771,8 @@ impl Tensor { })? } - let storage = self.storage.matmul( - &rhs.storage, + let storage = self.storage()?.matmul( + &*rhs.storage()?, (batching, m, n, k), self.layout(), rhs.layout(), @@ -769,11 +791,11 @@ impl Tensor { pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> { let _shap = self.same_shape_binary_op(on_true, "where_cond")?; let shape = self.same_shape_binary_op(on_false, "where_cond")?; - let storage = self.storage.where_cond( + let storage = self.storage()?.where_cond( self.layout(), - &on_true.storage, + &*on_true.storage()?, on_true.layout(), - &on_false.storage, + &*on_false.storage()?, on_false.layout(), )?; let op = if self.track_op() || on_true.track_op() || on_false.track_op() { @@ -821,8 +843,8 @@ impl Tensor { let seq_len = ids_shape.r1()?; let (_, hidden_size) = rhs.shape().r2()?; let storage = ids - .storage - .embedding(ids.layout(), &rhs.storage, rhs.layout())?; + .storage()? + .embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) @@ -836,23 +858,6 @@ impl Tensor { self.layout.strided_index() } - /// Returns data from the underlying storage, this does not take the strides - /// into account so the size of the resulting buffer might be larger than the - /// tensor number of elements. 
- pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> { - match self.storage.as_ref() { - Storage::Cpu(cpu_storage) => { - let slice = S::cpu_storage_as_slice(cpu_storage)?; - Ok(std::borrow::Cow::Borrowed(slice)) - } - Storage::Cuda(slice) => { - let cpu_storage = slice.to_cpu_storage()?; - let storage_data = S::cpu_storage_data(cpu_storage)?; - Ok(std::borrow::Cow::Owned(storage_data)) - } - } - } - /// Returns the data contained in a 1D tensor as a vector of scalar values. pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> { if self.rank() != 1 { @@ -862,7 +867,7 @@ impl Tensor { shape: self.shape().clone(), }); } - match self.storage.as_ref() { + match &*self.storage()? { Storage::Cpu(cpu_storage) => { let data = S::cpu_storage_as_slice(cpu_storage)?; Ok(self.strided_index().map(|i| data[i]).collect()) @@ -890,7 +895,7 @@ impl Tensor { assert!(src_index.next().is_none()); Ok(rows) }; - match self.storage.as_ref() { + match &*self.storage()? { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } @@ -914,7 +919,7 @@ impl Tensor { assert!(src_index.next().is_none()); Ok(top_rows) }; - match self.storage.as_ref() { + match &*self.storage()? { Storage::Cpu(storage) => from_cpu_storage(storage), Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?), } @@ -922,12 +927,12 @@ impl Tensor { /// The dtype for the elements stored in the input tensor. pub fn dtype(&self) -> DType { - self.storage.dtype() + self.dtype } /// The device on which the input tensor is located. - pub fn device(&self) -> Device { - self.storage.device() + pub fn device(&self) -> &Device { + &self.device } /// The tensor shape, i.e. dimension sizes on each axis. 
@@ -1114,6 +1119,8 @@ impl Tensor { layout: self.layout.transpose(dim1, dim2)?, op, is_variable: false, + dtype: self.dtype, + device: self.device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } @@ -1133,10 +1140,12 @@ impl Tensor { pub fn copy(&self) -> Result<Tensor> { let tensor_ = Tensor_ { id: TensorId::new(), - storage: Arc::new(self.storage.try_clone(self.layout())?), + storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)), layout: self.layout.clone(), op: None, // TODO is_variable: false, + dtype: self.dtype, + device: self.device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } @@ -1150,6 +1159,8 @@ impl Tensor { layout: self.layout.clone(), op: None, is_variable: false, + dtype: self.dtype, + device: self.device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } @@ -1159,7 +1170,7 @@ impl Tensor { if self.device().same_device(device) { Ok(self.clone()) } else { - let storage = match (self.storage.as_ref(), device) { + let storage = match (&*self.storage()?, device) { (Storage::Cpu(storage), Device::Cuda(cuda)) => { Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) 
} @@ -1179,10 +1190,12 @@ impl Tensor { }; let tensor_ = Tensor_ { id: TensorId::new(), - storage: Arc::new(storage), + storage: Arc::new(RefCell::new(storage)), layout: self.layout.clone(), op, is_variable: false, + dtype: self.dtype, + device: device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } @@ -1216,6 +1229,8 @@ impl Tensor { layout: self.layout.broadcast_as(shape)?, op, is_variable: false, + dtype: self.dtype, + device: self.device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } @@ -1240,7 +1255,7 @@ impl Tensor { Ok(self.clone()) } else { let shape = self.shape(); - let storage = self.storage.to_dtype(self.layout(), dtype)?; + let storage = self.storage()?.to_dtype(self.layout(), dtype)?; let op = if self.track_op() { Some(Op::ToDType(self.clone())) } else { @@ -1258,7 +1273,7 @@ impl Tensor { } else { let shape = self.shape(); let mut storage = self.device().zeros(shape, self.dtype())?; - self.storage + self.storage()? .copy_strided_src(&mut storage, 0, self.layout())?; Ok(from_storage( storage, @@ -1307,11 +1322,13 @@ impl Tensor { layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()), op, is_variable: false, + dtype: self.dtype, + device: self.device.clone(), }; Ok(Tensor(Arc::new(tensor_))) } else { let mut storage = self.device().zeros(&shape, self.dtype())?; - self.storage + self.storage()? .copy_strided_src(&mut storage, 0, self.layout())?; Ok(from_storage(storage, shape, op, false)) } @@ -1507,11 +1524,15 @@ impl Tensor { let mut storage = device.zeros(&shape, dtype)?; for (arg, &offset) in args.iter().zip(offsets.iter()) { let arg = arg.as_ref(); - arg.storage + arg.storage()? .copy_strided_src(&mut storage, offset, arg.layout())?; } Ok(from_storage(storage, shape, op, false)) } + + fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> { + Ok(self.storage.try_borrow()?) + } } macro_rules! bin_trait { |