From 6991036bc50bdfb5312b7145443d413dbd69d77e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 13 Jul 2023 14:09:51 +0100 Subject: Introduce the variables api used for adjusting parameters during the training loop. (#158) * Add the variable api. * And add a comment. --- candle-core/src/lib.rs | 2 ++ candle-core/src/variable.rs | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 candle-core/src/variable.rs (limited to 'candle-core/src') diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 0108c198..af1eb215 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -56,6 +56,7 @@ mod storage; mod strided_index; mod tensor; pub mod utils; +mod variable; pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; @@ -67,6 +68,7 @@ pub use shape::{Shape, D}; pub use storage::Storage; use strided_index::StridedIndex; pub use tensor::{Tensor, TensorId}; +pub use variable::Variable; #[cfg(feature = "cuda")] pub use cuda_backend::{CudaDevice, CudaStorage}; diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs new file mode 100644 index 00000000..67675765 --- /dev/null +++ b/candle-core/src/variable.rs @@ -0,0 +1,30 @@ +// Variables are wrappers around tensors that can be modified, they are typically used for holding +// weights and being modified by gradient descent. +// They are not cloneable by default to avoid having too many potential writers on the data. +// We also do not expose a public way to create variables as this would break the invariant that +// the tensor within a variable is actually with `is_variable` set to `true`. +use crate::Tensor; + +/// A variable is a wrapper around a tensor, however variables can have their content modified +/// whereas tensors are immutable. +#[derive(Debug)] +pub struct Variable(Tensor); + +impl std::ops::Deref for Variable { + type Target = Tensor; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl Variable { + pub fn as_tensor(&self) -> &Tensor { + &self.0 + } + + /// Consumes this `Variable` and return the underlying tensor. + pub fn into_inner(self) -> Tensor { + self.0 + } +} -- cgit v1.2.3