diff options
-rw-r--r-- | .cargo/config.toml | 5 | ||||
-rw-r--r-- | src/shape.rs | 2 | ||||
-rw-r--r-- | src/tensor.rs | 52 |
3 files changed, 58 insertions, 1 deletions
diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..a6c6276e --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.x86_64-unknown-linux-gnu] +rustflags = ["-C", "target-cpu=native"] + +[target.aarch64-apple-darwin] +rustflags = ["-C", "target-cpu=native"] diff --git a/src/shape.rs b/src/shape.rs index 97f0f567..36a48276 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -131,7 +131,7 @@ mod tests { #[test] fn stride() { let shape = Shape::from(()); - assert_eq!(shape.stride_contiguous(), []); + assert_eq!(shape.stride_contiguous(), Vec::<usize>::new()); let shape = Shape::from(42); assert_eq!(shape.stride_contiguous(), [1]); let shape = Shape::from((42, 1337)); diff --git a/src/tensor.rs b/src/tensor.rs index 2726b326..88f47a15 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -22,6 +22,7 @@ pub struct Tensor_ { // The strides are given in number of elements and not in bytes. stride: Vec<usize>, op: Option<Op>, + is_variable: bool, } #[derive(Clone)] @@ -52,6 +53,7 @@ macro_rules! unary_op { shape: shape.clone(), stride: shape.stride_contiguous(), op: Some(Op::$op_name(self.clone())), + is_variable: false, }; Ok(Self(Arc::new(tensor_))) } @@ -71,6 +73,7 @@ macro_rules! binary_op { shape: shape.clone(), stride: shape.stride_contiguous(), op: Some(Op::$op_name(self.clone(), rhs.clone())), + is_variable: false, }; Ok(Self(Arc::new(tensor_))) } @@ -88,6 +91,7 @@ impl Tensor { shape, stride, op: None, + is_variable: false, }; Self(Arc::new(tensor_)) } @@ -102,6 +106,7 @@ impl Tensor { shape, stride, op: None, + is_variable: false, }; Ok(Self(Arc::new(tensor_))) } @@ -211,4 +216,51 @@ impl Tensor { pub fn id(&self) -> TensorId { self.id } + + /// Return all the nodes that lead to this value in a topologically sorted vec, the first + /// elements having dependencies on the latter ones, e.g. the first element if any is the + /// argument. + /// This assumes that the op graph is a DAG. + pub fn sorted_nodes(&self) -> Vec<&Tensor> { + use std::collections::HashMap; + + // The vec of sorted nodes is passed as an owned value rather than a mutable reference + // to get around some lifetime limitations. + fn walk<'a>( + node: &'a Tensor, + nodes: Vec<&'a Tensor>, + already_seen: &mut HashMap<TensorId, bool>, + ) -> (bool, Vec<&'a Tensor>) { + if let Some(&tg) = already_seen.get(&node.id) { + return (tg, nodes); + } + let mut track_grad = false; + let mut nodes = if let Some(op) = &node.op { + match op { + Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => { + let (tg, nodes) = walk(lhs, nodes, already_seen); + track_grad |= tg; + let (tg, nodes) = walk(rhs, nodes, already_seen); + track_grad |= tg; + nodes + } + Op::Sqr(node) | Op::Sqrt(node) => { + let (tg, nodes) = walk(node, nodes, already_seen); + track_grad |= tg; + nodes + } + } + } else { + nodes + }; + already_seen.insert(node.id, track_grad); + if track_grad { + nodes.push(node); + } + (track_grad, nodes) + } + let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new()); + nodes.reverse(); + nodes + } } |