diff options
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/var_builder.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0d836c7f..2731456d 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } - + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + /// similar to [`from_pth`] but requires a `state_key`. + pub fn from_pth_with_state<P: AsRef<std::path::Path>>( + p: P, + dtype: DType, + state_key: &str, + dev: &Device, + ) -> Result<Self> { + let pth = candle::pickle::PthTensors::new(p, Some(state_key))?; + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) + } /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before /// passing the new names to the inner VarBuilder. /// |