summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzachcp <zachcp@users.noreply.github.com>2024-11-26 16:52:53 -0500
committerGitHub <noreply@github.com>2024-11-26 22:52:53 +0100
commitb4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd (patch)
tree6ddb6323edf91ce11a3653ce57e887ab9fc83595
parentc12db594e389610c2b0d20fc90ecffd32c2f8d40 (diff)
downloadcandle-b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd.tar.gz
candle-b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd.tar.bz2
candle-b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd.zip
Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files iwth state maps to be loaded. * add a line to the doc * String-. &str
-rw-r--r--candle-nn/src/var_builder.rs12
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 0d836c7f..2731456d 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
-
+ /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
+ /// similar to [`from_pth`] but requires a `state_key`.
+ pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
+ p: P,
+ dtype: DType,
+ state_key: &str,
+ dev: &Device,
+ ) -> Result<Self> {
+ let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
+ Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
+ }
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
/// passing the new names to the inner VarBuilder.
///