Diffstat (limited to 'src/tensor.rs')
-rw-r--r--  src/tensor.rs | 94
1 file changed, 65 insertions(+), 29 deletions(-)
diff --git a/src/tensor.rs b/src/tensor.rs
index 2d704a65..9ba412f9 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -86,9 +86,9 @@ impl Tensor {
dtype: DType,
device: Device,
is_variable: bool,
- ) -> Self {
+ ) -> Result<Self> {
let shape = shape.into();
- let storage = device.ones(&shape, dtype);
+ let storage = device.ones(&shape, dtype)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -98,18 +98,18 @@ impl Tensor {
op: None,
is_variable,
};
- Self(Arc::new(tensor_))
+ Ok(Self(Arc::new(tensor_)))
}
- pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+ pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, false)
}
- pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+ pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, true)
}
- pub fn ones_like(&self) -> Self {
+ pub fn ones_like(&self) -> Result<Self> {
Tensor::ones(self.shape(), self.dtype(), self.device())
}
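With ones_impl and its public wrappers now returning Result<Self>, device allocation failures propagate to the caller instead of panicking. A minimal call-site sketch, assuming the crate's usual Result<T> alias, that tuples implement Into<Shape> as elsewhere in the codebase, and that the snippet sits inside a fn returning Result<()>:

    // Construction is now fallible; `?` surfaces device allocation errors.
    let ones = Tensor::ones((2, 3), DType::F32, Device::Cpu)?;
    // ones_like mirrors the shape, dtype and device of an existing tensor.
    let more = ones.ones_like()?;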
@@ -118,9 +118,9 @@ impl Tensor {
dtype: DType,
device: Device,
is_variable: bool,
- ) -> Self {
+ ) -> Result<Self> {
let shape = shape.into();
- let storage = device.zeros(&shape, dtype);
+ let storage = device.zeros(&shape, dtype)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -130,18 +130,18 @@ impl Tensor {
op: None,
is_variable,
};
- Self(Arc::new(tensor_))
+ Ok(Self(Arc::new(tensor_)))
}
- pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+ pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, false)
}
- pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+ pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, true)
}
- pub fn zeros_like(&self) -> Self {
+ pub fn zeros_like(&self) -> Result<Self> {
Tensor::zeros(self.shape(), self.dtype(), self.device())
}
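The _var constructors differ from their plain counterparts only in passing is_variable: true, which marks the tensor as a trainable leaf: backward (below) skips over variables rather than propagating past them, leaving their accumulated gradient in the store. A short sketch under the same assumptions as above:

    // zeros_var marks a trainable leaf; zeros yields a plain constant.
    let w = Tensor::zeros_var((4, 4), DType::F32, Device::Cpu)?;
    let c = Tensor::zeros((4, 4), DType::F32, Device::Cpu)?;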
@@ -151,7 +151,7 @@ impl Tensor {
is_variable: bool,
) -> Result<Self> {
let shape = array.shape()?;
- let storage = device.tensor(array);
+ let storage = device.tensor(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -376,16 +376,16 @@ impl Tensor {
nodes
}
- pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
+ pub fn backward(&self) -> Result<GradStore> {
let sorted_nodes = self.sorted_nodes();
println!("{}", sorted_nodes.len());
- let mut grads = HashMap::new();
- grads.insert(self.id, self.ones_like());
+ let mut grads = GradStore::new();
+ grads.insert(self, self.ones_like()?);
for node in sorted_nodes.iter() {
if node.is_variable {
continue;
}
- let grad = grads.remove(&node.id).unwrap();
+ let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph).
// The only drawback would be if we wanted to support grad of grad but this is out of
@@ -393,51 +393,51 @@ impl Tensor {
if let Some(op) = &node.op {
match op {
Op::Add(lhs, rhs) => {
- let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+ let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
- let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+ let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
}
Op::Sub(lhs, rhs) => {
- let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+ let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
- let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+ let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
}
Op::Mul(lhs, rhs) => {
let lhs_grad = grad.mul(rhs)?;
- let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+ let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?;
- let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+ let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Div(lhs, rhs) => {
let lhs_grad = grad.div(rhs)?;
- let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+ let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
- let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+ let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
- let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+ let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Neg(arg) => {
let arg_grad = grad.neg()?;
- let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+ let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqr(arg) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
- let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+ let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqrt(arg) => {
let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
- let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+ let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
};
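GradStore::or_insert (defined at the bottom of this file) replaces the former grads.entry(id).or_insert_with(|| zeros_like()) pattern: it seeds a zero tensor the first time a node is touched, so contributions from every use of that node accumulate by addition. When both operands of Op::Mul are the same tensor, for instance, the two adds in that arm sum to the usual 2x rule. A sketch, under the same assumptions as the snippets above:

    // x appears as both lhs and rhs of x * x, so backward adds grad*x twice:
    // the stored gradient for x comes out as 2 * x.
    let x = Tensor::ones_var((3, 3), DType::F32, Device::Cpu)?;
    let y = x.mul(&x)?;
    let grads = y.backward()?;
    let dy_dx = grads.get(&x); // Some(tensor with value 2 everywhere here)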
@@ -503,3 +503,39 @@ bin_trait!(Add, add, |_| 1., |v| v);
bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
bin_trait!(Mul, mul, |v| v, |_| 0.);
bin_trait!(Div, div, |v| 1. / v, |_| 0.);
+
+pub struct GradStore(HashMap<TensorId, Tensor>);
+
+impl GradStore {
+ fn new() -> Self {
+ GradStore(HashMap::new())
+ }
+
+ pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
+ self.0.get(&id)
+ }
+
+ pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
+ self.0.get(&tensor.id)
+ }
+
+ pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
+ self.0.remove(&tensor.id)
+ }
+
+ pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
+ self.0.insert(tensor.id, grad)
+ }
+
+ fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
+ use std::collections::hash_map::Entry;
+ let grad = match self.0.entry(tensor.id) {
+ Entry::Occupied(entry) => entry.into_mut(),
+ Entry::Vacant(entry) => {
+ let grad = tensor.zeros_like()?;
+ entry.insert(grad)
+ }
+ };
+ Ok(grad)
+ }
+}
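As a usage sketch for the new public surface (get, get_id, insert, remove), here is a toy gradient-descent step. Everything in it is illustrative: the sgd_step helper, params and lr are not part of this change, and the reassignment yields a plain non-variable tensor, where a real optimizer would update storage in place.

    // Hypothetical helper: w <- w - lr * grad for each tracked parameter.
    fn sgd_step(params: &mut [Tensor], loss: &Tensor, lr: f64) -> Result<()> {
        let grads = loss.backward()?;
        for w in params.iter_mut() {
            if let Some(g) = grads.get(w) {
                // affine(mul, add) scales the gradient by the learning rate.
                *w = w.sub(&g.affine(lr, 0.)?)?;
            }
        }
        Ok(())
    }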