summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/var_builder.rs12
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 0d836c7f..2731456d 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> {
let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
-
+ /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
+ /// similar to [`from_pth`] but requires a `state_key`.
+ pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
+ p: P,
+ dtype: DType,
+ state_key: &str,
+ dev: &Device,
+ ) -> Result<Self> {
+ let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
+ Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
+ }
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
/// passing the new names to the inner VarBuilder.
///