-rw-r--r--   Cargo.toml            |  5
-rw-r--r--   src/device.rs         | 21
-rw-r--r--   src/tensor.rs         | 94
-rw-r--r--   tests/grad_tests.rs   |  2
-rw-r--r--   tests/tensor_tests.rs |  2
5 files changed, 87 insertions(+), 37 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,9 +13,14 @@ readme = "README.md"
 [dependencies]
 safetensors = "0.3.1"
 thiserror = "1"
+cudarc = { version = "0.9.9", optional = true }
 
 [dev-dependencies]
 anyhow = "1"
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = "0.13.3"
+
+[features]
+default = []
+cuda = ["dep:cudarc"]
diff --git a/src/device.rs b/src/device.rs
index af538c6c..c76cc301 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -54,27 +54,36 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
     }
 }
 
 impl Device {
-    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage {
+    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
-            Device::Cpu => Storage::Cpu(CpuStorage::ones_impl(shape, dtype)),
+            Device::Cpu => {
+                let storage = Storage::Cpu(CpuStorage::ones_impl(shape, dtype));
+                Ok(storage)
+            }
             Device::Cuda { gpu_id: _ } => {
                 todo!()
             }
         }
     }
 
-    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
+    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
-            Device::Cpu => Storage::Cpu(CpuStorage::zeros_impl(shape, dtype)),
+            Device::Cpu => {
+                let storage = Storage::Cpu(CpuStorage::zeros_impl(shape, dtype));
+                Ok(storage)
+            }
             Device::Cuda { gpu_id: _ } => {
                 todo!()
             }
         }
     }
 
-    pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Storage {
+    pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
         match self {
-            Device::Cpu => Storage::Cpu(array.to_cpu_storage()),
+            Device::Cpu => {
+                let storage = Storage::Cpu(array.to_cpu_storage());
+                Ok(storage)
+            }
             Device::Cuda { gpu_id: _ } => {
                 todo!()
             }
diff --git a/src/tensor.rs b/src/tensor.rs
index 2d704a65..9ba412f9 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -86,9 +86,9 @@ impl Tensor {
         dtype: DType,
         device: Device,
         is_variable: bool,
-    ) -> Self {
+    ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.ones(&shape, dtype);
+        let storage = device.ones(&shape, dtype)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -98,18 +98,18 @@
             op: None,
             is_variable,
         };
-        Self(Arc::new(tensor_))
+        Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, true)
     }
 
-    pub fn ones_like(&self) -> Self {
+    pub fn ones_like(&self) -> Result<Self> {
         Tensor::ones(self.shape(), self.dtype(), self.device())
     }
 
@@ -118,9 +118,9 @@
         dtype: DType,
         device: Device,
         is_variable: bool,
-    ) -> Self {
+    ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.zeros(&shape, dtype);
+        let storage = device.zeros(&shape, dtype)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -130,18 +130,18 @@
             op: None,
             is_variable,
         };
-        Self(Arc::new(tensor_))
+        Ok(Self(Arc::new(tensor_)))
     }
 
-    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
-    pub fn zeros_like(&self) -> Self {
+    pub fn zeros_like(&self) -> Result<Self> {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
@@ -151,7 +151,7 @@ impl Tensor {
         is_variable: bool,
     ) -> Result<Self> {
         let shape = array.shape()?;
-        let storage = device.tensor(array);
+        let storage = device.tensor(array)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -376,16 +376,16 @@
         nodes
     }
 
-    pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
+    pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
         println!("{}", sorted_nodes.len());
-        let mut grads = HashMap::new();
-        grads.insert(self.id, self.ones_like());
+        let mut grads = GradStore::new();
+        grads.insert(self, self.ones_like()?);
         for node in sorted_nodes.iter() {
             if node.is_variable {
                 continue;
             }
-            let grad = grads.remove(&node.id).unwrap();
+            let grad = grads.remove(node).unwrap();
             // TODO: We should perform all these operations in place (or at least not track the
             // whole graph).
             // The only drawback would be if we wanted to support grad of grad but this is out of
@@ -393,51 +393,51 @@
             if let Some(op) = &node.op {
                 match op {
                     Op::Add(lhs, rhs) => {
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                     }
                     Op::Sub(lhs, rhs) => {
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
                     }
                     Op::Mul(lhs, rhs) => {
                         let lhs_grad = grad.mul(rhs)?;
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                         let rhs_grad = grad.mul(lhs)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
                     Op::Div(lhs, rhs) => {
                         let lhs_grad = grad.div(rhs)?;
-                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        let lhs_sum_grad = grads.or_insert(lhs)?;
                         *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                         let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
                     Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Neg(arg) => {
                         let arg_grad = grad.neg()?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Sqr(arg) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                     Op::Sqrt(arg) => {
                         let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
-                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
                 };
@@ -503,3 +503,39 @@ bin_trait!(Add, add, |_| 1., |v| v);
 bin_trait!(Sub, sub, |_| 1., |v: f64| -v);
 bin_trait!(Mul, mul, |v| v, |_| 0.);
 bin_trait!(Div, div, |v| 1. / v, |_| 0.);
+
+pub struct GradStore(HashMap<TensorId, Tensor>);
+
+impl GradStore {
+    fn new() -> Self {
+        GradStore(HashMap::new())
+    }
+
+    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
+        self.0.get(&id)
+    }
+
+    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
+        self.0.get(&tensor.id)
+    }
+
+    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
+        self.0.remove(&tensor.id)
+    }
+
+    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
+        self.0.insert(tensor.id, grad)
+    }
+
+    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
+        use std::collections::hash_map::Entry;
+        let grad = match self.0.entry(tensor.id) {
+            Entry::Occupied(entry) => entry.into_mut(),
+            Entry::Vacant(entry) => {
+                let grad = tensor.zeros_like()?;
+                entry.insert(grad)
+            }
+        };
+        Ok(grad)
+    }
+}
diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs
index e5ba68e8..432b1520 100644
--- a/tests/grad_tests.rs
+++ b/tests/grad_tests.rs
@@ -6,7 +6,7 @@ fn simple_grad() -> Result<()> {
     let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?;
     let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
     let grads = y.backward()?;
-    let grad_x = grads.get(&x.id()).context("no grad for x")?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
     assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
     // y = x^2 + 5.x + 4
     assert_eq!(y.to_vec1::<f32>()?, [28., 10., 40.]);
diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs
index 01f6f66c..fb2d84d9 100644
--- a/tests/tensor_tests.rs
+++ b/tests/tensor_tests.rs
@@ -2,7 +2,7 @@ use candle::{DType, Device, Result, Tensor};
 
 #[test]
 fn zeros() -> Result<()> {
-    let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu);
+    let tensor = Tensor::zeros((5, 2), DType::F32, Device::Cpu)?;
     let (dim1, dim2) = tensor.shape().r2()?;
     assert_eq!(dim1, 5);
     assert_eq!(dim2, 2);
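
For reference, a minimal sketch of how call sites read after this change, pieced together from the tests in this diff; the example function name below is hypothetical, while Tensor::zeros, Tensor::var, backward and GradStore::get are the APIs changed or added above. Constructors now return Result and are unwrapped with `?`, and gradients are looked up with the tensor itself rather than its TensorId:

    use candle::{DType, Device, Result, Tensor};

    fn example() -> Result<()> {
        // Tensor::zeros is now fallible, so the call site propagates the error with `?`.
        let zeros = Tensor::zeros((5, 2), DType::F32, Device::Cpu)?;
        let (dim1, dim2) = zeros.shape().r2()?;
        assert_eq!((dim1, dim2), (5, 2));

        // backward() now returns a GradStore; gradients are fetched with the tensor
        // itself, and GradStore keys the map by the tensor's id internally.
        let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?;
        let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
        let grads = y.backward()?;
        assert!(grads.get(&x).is_some());
        Ok(())
    }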