summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorDilshod Tadjibaev <939125+antimora@users.noreply.github.com>2024-02-06 14:17:33 -0600
committerGitHub <noreply@github.com>2024-02-06 21:17:33 +0100
commitb75e8945bc7c67106be6288b9f357efa8068e62e (patch)
treee9918bcacc07281358f0ecfa66fa700278980cb1 /candle-nn
parenta90fc5ca5a486e988d39ea69ee3d3bb40a39c017 (diff)
downloadcandle-b75e8945bc7c67106be6288b9f357efa8068e62e.tar.gz
candle-b75e8945bc7c67106be6288b9f357efa8068e62e.tar.bz2
candle-b75e8945bc7c67106be6288b9f357efa8068e62e.zip
Enhance pickle to retrieve state_dict with a given key (#1671)
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/var_builder.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 33d94c83..bf090219 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -484,7 +484,7 @@ impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
- let pth = candle::pickle::PthTensors::new(p)?;
+ let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
}