path: root/src/tensor.rs
author     Nicolas Patry <patry.nicolas@protonmail.com>  2023-06-21 23:59:25 +0200
committer  Nicolas Patry <patry.nicolas@protonmail.com>  2023-06-22 12:37:02 +0200
commit     04cf14f35ae9773d9600ed98c39ada56c726338f
tree       89ad4d7d7120fcccca6b828ca8e28606a6de0e0a /src/tensor.rs
parent     9ea220fc6e9541e9bceea0b2fbc9587c5f1a96e8
Moving to `gemm` and adding matmul backprop.
- Tentative `T` operator.
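
For reference, the backward pass added in this patch implements the standard matrix-calculus rule: for C = A·B with upstream gradient G = ∂L/∂C,

    ∂L/∂A = G · Bᵀ
    ∂L/∂B = Aᵀ · G

The tentative `t()` operator below supplies the transposed operands for these two products.
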
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--  src/tensor.rs | 110
1 file changed, 61 insertions, 49 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 7607171c..7274c557 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -147,10 +147,11 @@ impl Tensor {
pub fn new_impl<A: crate::device::NdArray>(
array: A,
+ shape: Shape,
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let shape = array.shape()?;
+ // let shape = array.shape()?;
let storage = device.storage(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
@@ -165,31 +166,29 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, false)
+ let shape = array.shape()?.clone();
+ Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, true)
+ let shape = array.shape()?.clone();
+ Self::new_impl(array, shape, device, true)
}
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
- a: &[D],
+ array: &[D],
shape: S,
- device: Device,
+ device: &Device,
) -> Result<Self> {
- let shape = shape.into();
- let storage = device.storage(a)?;
- let stride = shape.stride_contiguous();
- let is_variable = false;
- let tensor_ = Tensor_ {
- id: TensorId::new(),
- storage,
- shape,
- stride,
- op: None,
- is_variable,
- };
- Ok(Self(Arc::new(tensor_)))
+ Self::new_impl(array, shape.into(), device, false)
+ }
+
+ pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
+ array: &[D],
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::new_impl(array, shape.into(), device, true)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
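
A hypothetical call site for the reworked constructors. It assumes that `(usize, usize)` implements `Into<Shape>` and that a `Device::Cpu` value exists; neither is shown in this diff:

    // Sketch only: `Device::Cpu` and the tuple-to-`Shape` conversion are assumptions.
    let device = Device::Cpu;
    let data = [1f32, 2., 3., 4., 5., 6.];
    // Plain tensor: is_variable is false, so no gradient is accumulated for it.
    let t = Tensor::from_slice(&data, (2, 3), &device)?;
    // Same data, but registered as a variable for the backward pass.
    let v = Tensor::var_from_slice(&data, (2, 3), &device)?;
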
@@ -260,6 +259,7 @@ impl Tensor {
let dim = a_dims.len();
+ // TODO
// if dim < 2 {
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
// }
@@ -309,6 +309,13 @@ impl Tensor {
crate::StridedIndex::new(self.dims(), self.stride())
}
+ pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
return Err(Error::UnexpectedNumberOfDims {
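
A sketch of how the new `as_slice` accessor could be exercised on CPU-backed storage (the CUDA arm is still `todo!()`). `Device::Cpu` and an `f32` implementation of `WithDType` are assumptions, not shown here:

    // Sketch only: assumes `Device::Cpu` and `f32: WithDType`.
    let t = Tensor::from_slice(&[0f32, 1., 2., 3.], (2, 2), &Device::Cpu)?;
    // Borrows the underlying CPU buffer; errors if the stored dtype is not f32.
    let raw: &[f32] = t.as_slice()?;
    assert_eq!(raw, &[0., 1., 2., 3.]);
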
@@ -404,6 +411,31 @@ impl Tensor {
self.id
}
+ pub fn t(&self) -> Result<Tensor> {
+ let mut stride = self.stride().to_vec();
+ let mut shape = self.shape().clone();
+ let n = stride.len();
+ if n < 2 {
+ return Err(Error::UnexpectedNumberOfDims {
+ expected: 2,
+ got: n,
+ shape: self.shape().clone(),
+ });
+ }
+ (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+ (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage: self.storage.clone(),
+ shape,
+ stride,
+ // TODO The op should have a backward
+ op: None,
+ is_variable: false,
+ };
+ Ok(Tensor(Arc::new(tensor_)))
+ }
+
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
}
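
Since `t()` only swaps the last two entries of the shape and stride while leaving the element order in storage untouched, a contiguous 2x3 tensor with stride [3, 1] becomes a 3x2 view with stride [1, 3]. A minimal sketch, assuming the `dims()`/`stride()` accessors used earlier in this file return slices and that `from_slice` accepts a tuple shape:

    // Sketch only: the transpose is metadata-only, so the checks are on shape/stride.
    let t = Tensor::from_slice(&[1f32, 2., 3., 4., 5., 6.], (2, 3), &Device::Cpu)?;
    let tt = t.t()?;
    assert_eq!(tt.dims(), &[3, 2]);   // last two dims swapped
    assert_eq!(tt.stride(), &[1, 3]); // strides swapped as well
    assert!(!tt.is_contiguous());     // the transposed view is strided
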
@@ -514,37 +546,17 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
- Op::Matmul(_lhs, _rhs) => {
- // let (m, k) = lhs.shape;
- // let n = rhs.shape.1;
- // let strides = (m, n).strides();
- // Self::matmul(
- // (m, n, k),
- // true,
- // grad_out.as_ptr(),
- // strides,
- // rhs.data.as_ptr(),
- // [rhs.strides[1], rhs.strides[0]],
- // grad_lhs.as_mut_ptr(),
- // lhs.strides,
- // );
- // Self::matmul(
- // (k, m, n),
- // true,
- // lhs.data.as_ptr(),
- // [lhs.strides[1], lhs.strides[0]],
- // grad_out.as_ptr(),
- // strides,
- // grad_rhs.as_mut_ptr(),
- // rhs.strides,
- // );
-
- // let lhs_grad = grad.matmul(rhs)?;
- // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
- // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
- // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
- // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
- // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ Op::Matmul(lhs, rhs) => {
+ // The forward matmul already validated the operand shapes, so the
+ // size checks are skipped here.
+
+ let lhs_grad = grad.matmul(&rhs.t()?)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+ let rhs_grad = lhs.t()?.matmul(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
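
The two products above are the usual matmul gradients. As a self-contained sanity check of that rule, written against plain nested arrays and independent of the crate, take L = sum(A·B) so that the upstream gradient G is all ones:

    // Standalone check of dL/dA = G·Bᵀ and dL/dB = Aᵀ·G for 2x2 matrices.
    fn matmul2(a: [[f32; 2]; 2], b: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
        let mut c = [[0f32; 2]; 2];
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        c
    }

    fn transpose2(m: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
        [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]
    }

    fn main() {
        let a = [[1f32, 2.], [3., 4.]];
        let b = [[5f32, 6.], [7., 8.]];
        let g = [[1f32, 1.], [1., 1.]]; // dL/dC for L = sum(C)
        let da = matmul2(g, transpose2(b)); // G · Bᵀ
        let db = matmul2(transpose2(a), g); // Aᵀ · G
        // With G all ones, dA[i][k] is the sum of row k of B
        // and dB[k][j] is the sum of column k of A.
        assert_eq!(da, [[11., 15.], [11., 15.]]);
        assert_eq!(db, [[4., 4.], [6., 6.]]);
    }
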
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;