-rw-r--r--   Cargo.toml                 |  2
-rw-r--r--   candle-core/src/error.rs   |  6
-rw-r--r--   candle-core/src/tensor.rs  | 67
3 files changed, 33 insertions, 42 deletions
diff --git a/Cargo.toml b/Cargo.toml
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,7 +5,7 @@ members = [
     "candle-kernels",
     "candle-hub",
     "candle-nn",
-# "candle-pyo3",
+    "candle-pyo3",
     "candle-transformers",
 ]
 
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 36b56aee..2ab2ec1d 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -155,12 +155,6 @@ pub enum Error {
     #[error(transparent)]
     SafeTensor(#[from] safetensors::SafeTensorError),
 
-    // Maybe we could have a more detailed error here, including the line of the function that
-    // triggered this or some backtrace.
-    /// Borrow error.
-    #[error(transparent)]
-    BorrowError(#[from] std::cell::BorrowError),
-
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
 
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index b9edfedc..4aa2ff91 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,8 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
-use std::cell::RefCell;
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
 /// Unique identifier for tensors.
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -31,7 +30,7 @@ pub struct Tensor_ {
     // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
     // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
     // that's tricky to encode in the current setup.
-    storage: Arc<RefCell<Storage>>,
+    storage: Arc<RwLock<Storage>>,
     layout: Layout,
     op: Option<Op>,
     is_variable: bool,
@@ -77,7 +76,7 @@ macro_rules! unary_op {
         pub fn $fn_name(&self) -> Result<Self> {
             let shape = self.shape();
             let storage = self
-                .storage()?
+                .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
             let op = if self.track_op() {
                 Some(Op::$op_name(self.clone()))
@@ -93,8 +92,8 @@ macro_rules! binary_op {
     ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
             let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage()?.binary_impl::<crate::op::$op_name>(
-                &*rhs.storage()?,
+            let storage = self.storage().binary_impl::<crate::op::$op_name>(
+                &*rhs.storage(),
                 self.layout(),
                 rhs.layout(),
             )?;
@@ -138,7 +137,7 @@ fn from_storage<S: Into<Shape>>(
     let device = storage.device();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
-        storage: Arc::new(RefCell::new(storage)),
+        storage: Arc::new(RwLock::new(storage)),
         layout: Layout::contiguous(shape),
         op,
         is_variable,
@@ -521,7 +520,7 @@ impl Tensor {
             let data = S::cpu_storage_as_slice(cpu_storage)?;
             Ok::<_, Error>(data[self.layout().start_offset()])
         };
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -539,7 +538,7 @@ impl Tensor {
     /// # Ok::<(), candle::Error>(())
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
-        let storage = self.storage()?.affine(self.layout(), mul, add)?;
+        let storage = self.storage().affine(self.layout(), mul, add)?;
         let op = if self.track_op() {
             Some(Op::Affine {
                 arg: self.clone(),
@@ -554,7 +553,7 @@ impl Tensor {
 
     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
-        let storage = self.storage()?.elu(self.layout(), alpha)?;
+        let storage = self.storage().elu(self.layout(), alpha)?;
         let op = if self.track_op() {
             Some(Op::Elu(self.clone(), alpha))
         } else {
@@ -637,9 +636,7 @@ impl Tensor {
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
-            let mut storage = self
-                .storage()?
-                .unary_impl::<crate::op::Exp>(self.layout())?;
+            let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
             let op = if self.track_op() {
@@ -672,7 +669,7 @@ impl Tensor {
         for &dim in sum_dims {
             self.check_dim(dim, "sum")?;
         }
-        let storage = self.storage()?.sum(self.layout(), sum_dims)?;
+        let storage = self.storage().sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -718,8 +715,8 @@ impl Tensor {
             stride,
         };
         let storage =
-            self.storage()?
-                .conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), &params)?;
+            self.storage()
+                .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
         let op = if self.track_op() || kernel.track_op() {
             Some(Op::Conv1D {
                 arg: self.clone(),
@@ -772,8 +769,8 @@ impl Tensor {
             })?
         }
 
-        let storage = self.storage()?.matmul(
-            &*rhs.storage()?,
+        let storage = self.storage().matmul(
+            &rhs.storage(),
             (batching, m, n, k),
             self.layout(),
             rhs.layout(),
@@ -792,11 +789,11 @@ impl Tensor {
     pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
         let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
         let shape = self.same_shape_binary_op(on_false, "where_cond")?;
-        let storage = self.storage()?.where_cond(
+        let storage = self.storage().where_cond(
             self.layout(),
-            &*on_true.storage()?,
+            &on_true.storage(),
             on_true.layout(),
-            &*on_false.storage()?,
+            &on_false.storage(),
             on_false.layout(),
         )?;
         let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
@@ -844,8 +841,8 @@ impl Tensor {
         let seq_len = ids_shape.r1()?;
         let (_, hidden_size) = rhs.shape().r2()?;
         let storage = ids
-            .storage()?
-            .embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?;
+            .storage()
+            .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -868,7 +865,7 @@ impl Tensor {
                 shape: self.shape().clone(),
             });
         }
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(cpu_storage) => {
                 let data = S::cpu_storage_as_slice(cpu_storage)?;
                 Ok(self.strided_index().map(|i| data[i]).collect())
@@ -896,7 +893,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(rows)
         };
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -920,7 +917,7 @@ impl Tensor {
             assert!(src_index.next().is_none());
             Ok(top_rows)
         };
-        match &*self.storage()? {
+        match &*self.storage() {
             Storage::Cpu(storage) => from_cpu_storage(storage),
             Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
         }
@@ -1141,7 +1138,7 @@ impl Tensor {
     pub fn copy(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)),
+            storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
             layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
@@ -1171,7 +1168,7 @@ impl Tensor {
         if self.device().same_device(device) {
             Ok(self.clone())
         } else {
-            let storage = match (&*self.storage()?, device) {
+            let storage = match (&*self.storage(), device) {
                 (Storage::Cpu(storage), Device::Cuda(cuda)) => {
                     Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
                 }
@@ -1191,7 +1188,7 @@ impl Tensor {
             };
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
-                storage: Arc::new(RefCell::new(storage)),
+                storage: Arc::new(RwLock::new(storage)),
                 layout: self.layout.clone(),
                 op,
                 is_variable: false,
@@ -1256,7 +1253,7 @@ impl Tensor {
             Ok(self.clone())
         } else {
             let shape = self.shape();
-            let storage = self.storage()?.to_dtype(self.layout(), dtype)?;
+            let storage = self.storage().to_dtype(self.layout(), dtype)?;
             let op = if self.track_op() {
                 Some(Op::ToDType(self.clone()))
             } else {
@@ -1274,7 +1271,7 @@ impl Tensor {
         } else {
             let shape = self.shape();
             let mut storage = self.device().zeros(shape, self.dtype())?;
-            self.storage()?
+            self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(
                 storage,
@@ -1329,7 +1326,7 @@ impl Tensor {
             Ok(Tensor(Arc::new(tensor_)))
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
-            self.storage()?
+            self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(storage, shape, op, false))
         }
@@ -1525,14 +1522,14 @@ impl Tensor {
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
-            arg.storage()?
+            arg.storage()
                 .copy_strided_src(&mut storage, offset, arg.layout())?;
         }
         Ok(from_storage(storage, shape, op, false))
     }
 
-    fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> {
-        Ok(self.storage.try_borrow()?)
+    fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
+        self.storage.read().unwrap()
     }
 }
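
Note (not part of the commit): the core of this change is swapping the tensor's inner storage from Arc<RefCell<Storage>> to Arc<RwLock<Storage>>, which lets the storage be shared across threads and lets the private storage() accessor return a read guard directly instead of a Result carrying the now-removed BorrowError variant. Below is a minimal standalone sketch of that pattern; FakeStorage and FakeTensor are hypothetical stand-ins for candle's Storage and Tensor_, not actual candle APIs.

// Minimal sketch of the Arc<RefCell<..>> -> Arc<RwLock<..>> pattern applied in the diff above.
// FakeStorage / FakeTensor are hypothetical stand-ins, not candle types.
use std::sync::{Arc, RwLock, RwLockReadGuard};

#[derive(Debug)]
struct FakeStorage(Vec<f32>);

#[derive(Clone)]
struct FakeTensor {
    storage: Arc<RwLock<FakeStorage>>,
}

impl FakeTensor {
    fn new(data: Vec<f32>) -> Self {
        Self {
            storage: Arc::new(RwLock::new(FakeStorage(data))),
        }
    }

    // Mirrors the new private accessor: the read guard is returned directly and
    // lock poisoning is treated as unrecoverable (unwrap), so callers no longer
    // need a `?` or a BorrowError-like variant.
    fn storage(&self) -> RwLockReadGuard<'_, FakeStorage> {
        self.storage.read().unwrap()
    }
}

fn main() {
    let t = FakeTensor::new(vec![1.0, 2.0, 3.0]);
    // Unlike RefCell, RwLock is Sync, so a clone sharing the same storage can be
    // read from another thread.
    let t2 = t.clone();
    let handle = std::thread::spawn(move || t2.storage().0.iter().sum::<f32>());
    let sum: f32 = handle.join().unwrap();
    println!("sum from worker thread = {sum}");
    println!("main thread still reads {:?}", t.storage().0);
}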