summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- .cargo/config.toml 5
-rw-r--r-- src/shape.rs 2
-rw-r--r-- src/tensor.rs 52
3 files changed, 58 insertions, 1 deletion
diff --git a/.cargo/config.toml b/.cargo/config.toml
new file mode 100644
index 00000000..a6c6276e
--- /dev/null
+++ b/.cargo/config.toml
@@ -0,0 +1,5 @@
+[target.x86_64-unknown-linux-gnu]
+rustflags = ["-C", "target-cpu=native"]
+
+[target.aarch64-apple-darwin]
+rustflags = ["-C", "target-cpu=native"]
diff --git a/src/shape.rs b/src/shape.rs
index 97f0f567..36a48276 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -131,7 +131,7 @@ mod tests {
#[test]
fn stride() {
let shape = Shape::from(());
- assert_eq!(shape.stride_contiguous(), []);
+ assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
let shape = Shape::from(42);
assert_eq!(shape.stride_contiguous(), [1]);
let shape = Shape::from((42, 1337));
diff --git a/src/tensor.rs b/src/tensor.rs
index 2726b326..88f47a15 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -22,6 +22,7 @@ pub struct Tensor_ {
// The strides are given in number of elements and not in bytes.
stride: Vec<usize>,
op: Option<Op>,
+ is_variable: bool,
}
#[derive(Clone)]
@@ -52,6 +53,7 @@ macro_rules! unary_op {
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: Some(Op::$op_name(self.clone())),
+ is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
}
@@ -71,6 +73,7 @@ macro_rules! binary_op {
shape: shape.clone(),
stride: shape.stride_contiguous(),
op: Some(Op::$op_name(self.clone(), rhs.clone())),
+ is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
}
@@ -88,6 +91,7 @@ impl Tensor {
shape,
stride,
op: None,
+ is_variable: false,
};
Self(Arc::new(tensor_))
}
@@ -102,6 +106,7 @@ impl Tensor {
shape,
stride,
op: None,
+ is_variable: false,
};
Ok(Self(Arc::new(tensor_)))
}
@@ -211,4 +216,51 @@ impl Tensor {
pub fn id(&self) -> TensorId {
self.id
}
+
+ /// Returns the nodes leading to this value that depend on a variable, topologically sorted
+ /// so that each node comes before the nodes it depends on (`self`, if included, is first).
+ /// A node is kept when it is flagged `is_variable` or when one of its op arguments is kept.
+ /// This assumes that the op graph is a DAG.
+ pub fn sorted_nodes(&self) -> Vec<&Tensor> {
+ use std::collections::HashMap;
+
+ // Post-order walk returning whether `node` tracks gradients. The vec of sorted nodes is
+ // passed as an owned value rather than a mutable reference to work around lifetime limits.
+ fn walk<'a>(
+ node: &'a Tensor,
+ nodes: Vec<&'a Tensor>,
+ already_seen: &mut HashMap<TensorId, bool>,
+ ) -> (bool, Vec<&'a Tensor>) {
+ if let Some(&tg) = already_seen.get(&node.id) {
+ return (tg, nodes);
+ }
+ let mut track_grad = node.is_variable; // variables seed gradient tracking
+ let mut nodes = if let Some(op) = &node.op {
+ match op {
+ Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
+ let (tg, nodes) = walk(lhs, nodes, already_seen);
+ track_grad |= tg;
+ let (tg, nodes) = walk(rhs, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }
+ Op::Sqr(node) | Op::Sqrt(node) => {
+ let (tg, nodes) = walk(node, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ }
+ }
+ } else {
+ nodes
+ };
+ already_seen.insert(node.id, track_grad); // memoize: shared sub-graphs are walked once
+ if track_grad {
+ nodes.push(node);
+ }
+ (track_grad, nodes)
+ }
+ let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
+ nodes.reverse(); // post-order puts dependencies first; flip so dependents come first
+ nodes
+ }
}