summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-01 09:28:35 +0200
committer: GitHub <noreply@github.com> 2023-09-01 08:28:35 +0100
commit: f9f482d4e5378d941bc1e1a7dad1a5dfe0efd24e (patch)
tree: 54dd1cfd3972dd6b59a420adb621de50b49f76ef /candle-nn/src
parent: 9736236175633de60bbc174cadf23f25c37ce653 (diff)
download: candle-f9f482d4e5378d941bc1e1a7dad1a5dfe0efd24e.tar.gz
candle-f9f482d4e5378d941bc1e1a7dad1a5dfe0efd24e.tar.bz2
candle-f9f482d4e5378d941bc1e1a7dad1a5dfe0efd24e.zip
Add some doc to the varbuilder. (#700)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- candle-nn/src/var_builder.rs 24
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 03929681..bf5d5b43 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,3 +1,6 @@
+//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
+//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_safetensors`, or be initialized
+//! for training, e.g. using `VarBuilder::from_varmap`.
use crate::VarMap;
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
use safetensors::{slice::IndexOp, tensor::SafeTensors};
@@ -107,6 +110,15 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
self.path.join(".")
}
+ /// Returns a new `VarBuilder` holding a clone of this builder's data and an empty (root) path.
+ pub fn root(&self) -> Self {
+ Self {
+ data: self.data.clone(),
+ path: vec![], // no path components: the new builder starts at the root
+ _phantom: std::marker::PhantomData,
+ }
+ }
+
/// Returns a new `VarBuilder` with the prefix set to `prefix`.
pub fn set_prefix(&self, prefix: impl ToString) -> Self {
Self {
@@ -327,18 +339,29 @@ impl<'a> VarBuilder<'a> {
}
}
+ /// Initializes a `VarBuilder` backed by `Zeros`, so that zeros are used for any tensor.
pub fn zeros(dtype: DType, dev: &Device) -> Self {
Self::new(Box::new(Zeros), dtype, dev.clone())
}
+ /// Initializes a `VarBuilder` that retrieves tensors stored in the `ts` hash map. Retrieval
+ /// fails with an error if no tensor is available under the requested path or on shape mismatch.
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
Self::new(Box::new(ts), dtype, dev.clone())
}
+ /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
+ /// initialized when a new path is requested; the same tensor is used if the same path is
+ /// requested multiple times. This is commonly used when initializing a model before training.
+ ///
+ /// Note that it is possible to load the tensor values after model creation using the `load`
+ /// method on `varmap`; this can be used to start model training from an existing checkpoint.
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
Self::new(Box::new(varmap.clone()), dtype, dev.clone())
}
+ /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
+ /// files.
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
@@ -353,6 +376,7 @@ impl<'a> VarBuilder<'a> {
Self::new(Box::new(tensors), dtype, dev.clone())
}
+ /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;
Ok(Self::new(Box::new(npz), dtype, dev.clone()))