summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs123
1 files changed, 72 insertions, 51 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index a174edd0..e06f1d37 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,6 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::shape::Dim;
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use std::cell::RefCell;
use std::sync::Arc;
/// Unique identifier for tensors.
@@ -18,10 +19,23 @@ impl TensorId {
pub struct Tensor_ {
id: TensorId,
- storage: Arc<Storage>,
+ // Storage uses a mutex here so inner mutability is available and borrow rules are checked
+ // dynamically. The alternatives would be:
+ // - Using a mutex, this would have the highest cost when retrieving the storage but would
+ // prevent errors when concurrent access takes place. Mutex would also be subject to
+ // deadlocks for example using the current code if the same tensor is used twice by a single
+ // binary op.
+ // - Using an unsafe cell would have the lowest cost but undefined behavior on concurrent
+ // accesses.
+ // Ideally, we would use Arc<Storage> for tensors on which we don't plan on modifying the data
+ // and Arc<Mutex<Storage>> for tensors where the data could be modified, e.g. variables but
+ // that's tricky to encode in the current setup.
+ storage: Arc<RefCell<Storage>>,
layout: Layout,
op: Option<Op>,
is_variable: bool,
+ dtype: DType,
+ device: Device,
}
impl AsRef<Tensor> for Tensor {
@@ -62,7 +76,7 @@ macro_rules! unary_op {
pub fn $fn_name(&self) -> Result<Self> {
let shape = self.shape();
let storage = self
- .storage
+ .storage()?
.unary_impl::<crate::op::$op_name>(self.layout())?;
let op = if self.track_op() {
Some(Op::$op_name(self.clone()))
@@ -78,8 +92,8 @@ macro_rules! binary_op {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
- let storage = self.storage.binary_impl::<crate::op::$op_name>(
- &rhs.storage,
+ let storage = self.storage()?.binary_impl::<crate::op::$op_name>(
+ &*rhs.storage()?,
self.layout(),
rhs.layout(),
)?;
@@ -119,12 +133,16 @@ fn from_storage<S: Into<Shape>>(
op: Option<Op>,
is_variable: bool,
) -> Tensor {
+ let dtype = storage.dtype();
+ let device = storage.device();
let tensor_ = Tensor_ {
id: TensorId::new(),
- storage: Arc::new(storage),
+ storage: Arc::new(RefCell::new(storage)),
layout: Layout::contiguous(shape),
op,
is_variable,
+ dtype,
+ device,
};
Tensor(Arc::new(tensor_))
}
@@ -169,7 +187,7 @@ impl Tensor {
/// # Ok::<(), candle::Error>(())
/// ```
pub fn ones_like(&self) -> Result<Self> {
- Tensor::ones(self.shape(), self.dtype(), &self.device())
+ Tensor::ones(self.shape(), self.dtype(), self.device())
}
/// Creates a new tensor filled with zeros.
@@ -219,7 +237,7 @@ impl Tensor {
/// # Ok::<(), candle::Error>(())
/// ```
pub fn zeros_like(&self) -> Result<Self> {
- Tensor::zeros(self.shape(), self.dtype(), &self.device())
+ Tensor::zeros(self.shape(), self.dtype(), self.device())
}
fn rand_impl<S: Into<Shape>>(
@@ -502,7 +520,7 @@ impl Tensor {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok::<_, Error>(data[self.layout().start_offset()])
};
- match self.storage.as_ref() {
+ match &*self.storage()? {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
@@ -520,7 +538,7 @@ impl Tensor {
/// # Ok::<(), candle::Error>(())
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
- let storage = self.storage.affine(self.layout(), mul, add)?;
+ let storage = self.storage()?.affine(self.layout(), mul, add)?;
let op = if self.track_op() {
Some(Op::Affine {
arg: self.clone(),
@@ -535,7 +553,7 @@ impl Tensor {
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
- let storage = self.storage.elu(self.layout(), alpha)?;
+ let storage = self.storage()?.elu(self.layout(), alpha)?;
let op = if self.track_op() {
Some(Op::Elu(self.clone(), alpha))
} else {
@@ -585,6 +603,8 @@ impl Tensor {
layout,
op,
is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -616,7 +636,9 @@ impl Tensor {
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
- let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
+ let mut storage = self
+ .storage()?
+ .unary_impl::<crate::op::Exp>(self.layout())?;
// The resulting storage is contiguous.
storage.divide_by_sum_over_dim(shape, dim)?;
let op = if self.track_op() {
@@ -649,7 +671,7 @@ impl Tensor {
for &dim in sum_dims {
self.check_dim(dim, "sum")?;
}
- let storage = self.storage.sum(self.layout(), sum_dims)?;
+ let storage = self.storage()?.sum(self.layout(), sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
} else {
@@ -695,8 +717,8 @@ impl Tensor {
stride,
};
let storage =
- self.storage
- .conv1d(self.layout(), &kernel.storage, kernel.layout(), &params)?;
+ self.storage()?
+ .conv1d(self.layout(), &*kernel.storage()?, kernel.layout(), &params)?;
let op = if self.track_op() || kernel.track_op() {
Some(Op::Conv1D {
arg: self.clone(),
@@ -749,8 +771,8 @@ impl Tensor {
})?
}
- let storage = self.storage.matmul(
- &rhs.storage,
+ let storage = self.storage()?.matmul(
+ &*rhs.storage()?,
(batching, m, n, k),
self.layout(),
rhs.layout(),
@@ -769,11 +791,11 @@ impl Tensor {
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
- let storage = self.storage.where_cond(
+ let storage = self.storage()?.where_cond(
self.layout(),
- &on_true.storage,
+ &*on_true.storage()?,
on_true.layout(),
- &on_false.storage,
+ &*on_false.storage()?,
on_false.layout(),
)?;
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
@@ -821,8 +843,8 @@ impl Tensor {
let seq_len = ids_shape.r1()?;
let (_, hidden_size) = rhs.shape().r2()?;
let storage = ids
- .storage
- .embedding(ids.layout(), &rhs.storage, rhs.layout())?;
+ .storage()?
+ .embedding(ids.layout(), &*rhs.storage()?, rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -836,23 +858,6 @@ impl Tensor {
self.layout.strided_index()
}
- /// Returns data from the underlying storage, this does not take the strides
- /// into account so the size of the resulting buffer might be larger than the
- /// tensor number of elements.
- pub fn storage_data<S: crate::WithDType>(&self) -> Result<std::borrow::Cow<[S]>> {
- match self.storage.as_ref() {
- Storage::Cpu(cpu_storage) => {
- let slice = S::cpu_storage_as_slice(cpu_storage)?;
- Ok(std::borrow::Cow::Borrowed(slice))
- }
- Storage::Cuda(slice) => {
- let cpu_storage = slice.to_cpu_storage()?;
- let storage_data = S::cpu_storage_data(cpu_storage)?;
- Ok(std::borrow::Cow::Owned(storage_data))
- }
- }
- }
-
/// Returns the data contained in a 1D tensor as a vector of scalar values.
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
@@ -862,7 +867,7 @@ impl Tensor {
shape: self.shape().clone(),
});
}
- match self.storage.as_ref() {
+ match &*self.storage()? {
Storage::Cpu(cpu_storage) => {
let data = S::cpu_storage_as_slice(cpu_storage)?;
Ok(self.strided_index().map(|i| data[i]).collect())
@@ -890,7 +895,7 @@ impl Tensor {
assert!(src_index.next().is_none());
Ok(rows)
};
- match self.storage.as_ref() {
+ match &*self.storage()? {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
@@ -914,7 +919,7 @@ impl Tensor {
assert!(src_index.next().is_none());
Ok(top_rows)
};
- match self.storage.as_ref() {
+ match &*self.storage()? {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
@@ -922,12 +927,12 @@ impl Tensor {
/// The dtype for the elements stored in the input tensor.
pub fn dtype(&self) -> DType {
- self.storage.dtype()
+ self.dtype
}
/// The device on which the input tensor is located.
- pub fn device(&self) -> Device {
- self.storage.device()
+ pub fn device(&self) -> &Device {
+ &self.device
}
/// The tensor shape, i.e. dimension sizes on each axis.
@@ -1114,6 +1119,8 @@ impl Tensor {
layout: self.layout.transpose(dim1, dim2)?,
op,
is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1133,10 +1140,12 @@ impl Tensor {
pub fn copy(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
- storage: Arc::new(self.storage.try_clone(self.layout())?),
+ storage: Arc::new(RefCell::new(self.storage()?.try_clone(self.layout())?)),
layout: self.layout.clone(),
op: None, // TODO
is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1150,6 +1159,8 @@ impl Tensor {
layout: self.layout.clone(),
op: None,
is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1159,7 +1170,7 @@ impl Tensor {
if self.device().same_device(device) {
Ok(self.clone())
} else {
- let storage = match (self.storage.as_ref(), device) {
+ let storage = match (&*self.storage()?, device) {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
@@ -1179,10 +1190,12 @@ impl Tensor {
};
let tensor_ = Tensor_ {
id: TensorId::new(),
- storage: Arc::new(storage),
+ storage: Arc::new(RefCell::new(storage)),
layout: self.layout.clone(),
op,
is_variable: false,
+ dtype: self.dtype,
+ device: device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1216,6 +1229,8 @@ impl Tensor {
layout: self.layout.broadcast_as(shape)?,
op,
is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
@@ -1240,7 +1255,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
- let storage = self.storage.to_dtype(self.layout(), dtype)?;
+ let storage = self.storage()?.to_dtype(self.layout(), dtype)?;
let op = if self.track_op() {
Some(Op::ToDType(self.clone()))
} else {
@@ -1258,7 +1273,7 @@ impl Tensor {
} else {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
- self.storage
+ self.storage()?
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(
storage,
@@ -1307,11 +1322,13 @@ impl Tensor {
layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
op,
is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
} else {
let mut storage = self.device().zeros(&shape, self.dtype())?;
- self.storage
+ self.storage()?
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))
}
@@ -1507,11 +1524,15 @@ impl Tensor {
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
- arg.storage
+ arg.storage()?
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(from_storage(storage, shape, op, false))
}
+
+ fn storage(&self) -> Result<std::cell::Ref<'_, Storage>> {
+ Ok(self.storage.try_borrow()?)
+ }
}
macro_rules! bin_trait {