summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/tensor.rs16
1 files changed, 8 insertions, 8 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index e55050c6..7607171c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -178,7 +178,7 @@ impl Tensor {
device: Device,
) -> Result<Self> {
let shape = shape.into();
- let storage = device.storage(a);
+ let storage = device.storage(a)?;
let stride = shape.stride_contiguous();
let is_variable = false;
let tensor_ = Tensor_ {
@@ -514,7 +514,7 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
- Op::Matmul(lhs, rhs) => {
+ Op::Matmul(_lhs, _rhs) => {
// let (m, k) = lhs.shape;
// let n = rhs.shape.1;
// let strides = (m, n).strides();
@@ -539,12 +539,12 @@ impl Tensor {
// rhs.strides,
// );
- let lhs_grad = grad.matmul(rhs)?;
- let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
- *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
- let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
- let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
- *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ // let lhs_grad = grad.matmul(rhs)?;
+ // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+ // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+ // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+ // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+ // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;