diff options
author | zachcp <zachcp@users.noreply.github.com> | 2024-11-26 16:52:53 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-26 22:52:53 +0100 |
commit | b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd (patch) | |
tree | 6ddb6323edf91ce11a3653ce57e887ab9fc83595 | |
parent | c12db594e389610c2b0d20fc90ecffd32c2f8d40 (diff) | |
download | candle-b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd.tar.gz candle-b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd.tar.bz2 candle-b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd.zip |
Provide a method to allow PTH files with state maps to be loaded. (#2639)
* Provide a method to allow PTH files iwth state maps to be loaded.
* add a line to the doc
* String-. &str
-rw-r--r-- | candle-nn/src/var_builder.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0d836c7f..2731456d 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -544,7 +544,17 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } - + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + /// similar to [`from_pth`] but requires a `state_key`. + pub fn from_pth_with_state<P: AsRef<std::path::Path>>( + p: P, + dtype: DType, + state_key: &str, + dev: &Device, + ) -> Result<Self> { + let pth = candle::pickle::PthTensors::new(p, Some(state_key))?; + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) + } /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before /// passing the new names to the inner VarBuilder. /// |