summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-13 14:09:51 +0100
committerGitHub <noreply@github.com>2023-07-13 14:09:51 +0100
commit6991036bc50bdfb5312b7145443d413dbd69d77e (patch)
treec90d0dab8f9e829868dfedbaad6965cf21133959
parent7adc8c903a2963fd35a8a2d2e353ed086387396c (diff)
downloadcandle-6991036bc50bdfb5312b7145443d413dbd69d77e.tar.gz
candle-6991036bc50bdfb5312b7145443d413dbd69d77e.tar.bz2
candle-6991036bc50bdfb5312b7145443d413dbd69d77e.zip
Introduce the variables api used for adjusting parameters during the training loop. (#158)
* Add the variable api. * And add a comment.
-rw-r--r--candle-core/src/lib.rs2
-rw-r--r--candle-core/src/variable.rs30
2 files changed, 32 insertions, 0 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 0108c198..af1eb215 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -56,6 +56,7 @@ mod storage;
mod strided_index;
mod tensor;
pub mod utils;
+mod variable;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
@@ -67,6 +68,7 @@ pub use shape::{Shape, D};
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
+pub use variable::Variable;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
new file mode 100644
index 00000000..67675765
--- /dev/null
+++ b/candle-core/src/variable.rs
@@ -0,0 +1,30 @@
+// Variables are wrappers around tensors that can be modified, they are typically used for holding
+// weights and being modified by gradient descent.
+// They are not cloneable by default to avoid having too many potential writers on the data.
+// We also do not expose a public way to create variables as this would break the invariant that
+// the tensor within a variable is actually with `is_variable` set to `true`.
+use crate::Tensor;
+
+/// A variable is a wrapper around a tensor, however variables can have their content modified
+/// whereas tensors are immutable.
+#[derive(Debug)]
+pub struct Variable(Tensor);
+
+impl std::ops::Deref for Variable {
+ type Target = Tensor;
+
+ fn deref(&self) -> &Self::Target {
+ self.0.as_ref()
+ }
+}
+
+impl Variable {
+ pub fn as_tensor(&self) -> &Tensor {
+ &self.0
+ }
+
+ /// Consumes this `Variable` and return the underlying tensor.
+ pub fn into_inner(self) -> Tensor {
+ self.0
+ }
+}