diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 14:09:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 14:09:51 +0100 |
commit | 6991036bc50bdfb5312b7145443d413dbd69d77e (patch) | |
tree | c90d0dab8f9e829868dfedbaad6965cf21133959 | |
parent | 7adc8c903a2963fd35a8a2d2e353ed086387396c (diff) | |
download | candle-6991036bc50bdfb5312b7145443d413dbd69d77e.tar.gz candle-6991036bc50bdfb5312b7145443d413dbd69d77e.tar.bz2 candle-6991036bc50bdfb5312b7145443d413dbd69d77e.zip |
Introduce the variables api used for adjusting parameters during the training loop. (#158)
* Add the variable api.
* And add a comment.
-rw-r--r-- | candle-core/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-core/src/variable.rs | 30 |
2 files changed, 32 insertions, 0 deletions
```diff
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 0108c198..af1eb215 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -56,6 +56,7 @@ mod storage;
 mod strided_index;
 mod tensor;
 pub mod utils;
+mod variable;
 
 pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
@@ -67,6 +68,7 @@ pub use shape::{Shape, D};
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
+pub use variable::Variable;
 
 #[cfg(feature = "cuda")]
 pub use cuda_backend::{CudaDevice, CudaStorage};
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
new file mode 100644
index 00000000..67675765
--- /dev/null
+++ b/candle-core/src/variable.rs
@@ -0,0 +1,30 @@
+// Variables are wrappers around tensors that can be modified, they are typically used for holding
+// weights and being modified by gradient descent.
+// They are not cloneable by default to avoid having too many potential writers on the data.
+// We also do not expose a public way to create variables as this would break the invariant that
+// the tensor within a variable is actually with `is_variable` set to `true`.
+use crate::Tensor;
+
+/// A variable is a wrapper around a tensor, however variables can have their content modified
+/// whereas tensors are immutable.
+#[derive(Debug)]
+pub struct Variable(Tensor);
+
+impl std::ops::Deref for Variable {
+    type Target = Tensor;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref()
+    }
+}
+
+impl Variable {
+    pub fn as_tensor(&self) -> &Tensor {
+        &self.0
+    }
+
+    /// Consumes this `Variable` and return the underlying tensor.
+    pub fn into_inner(self) -> Tensor {
+        self.0
+    }
+}
```