summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/tensor-tools.rs2
-rw-r--r--candle-core/src/pickle.rs54
-rw-r--r--candle-core/tests/pth.py2
-rw-r--r--candle-core/tests/pth_tests.rs10
-rw-r--r--candle-core/tests/test_with_key.ptbin0 -> 1338 bytes
-rw-r--r--candle-nn/src/var_builder.rs2
6 files changed, 61 insertions, 9 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index eb6ceb1c..1801ac58 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -196,7 +196,7 @@ fn run_ls(
}
}
Format::Pth => {
- let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
+ let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 4c76c416..2c189131 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -625,9 +625,16 @@ pub struct TensorInfo {
pub storage_size: usize,
}
+/// Read the tensor info from a .pth file.
+///
+/// # Arguments
+/// * `file` - The path to the .pth file.
+/// * `verbose` - Whether to print debug information.
+/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
+ key: Option<&str>,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
@@ -649,8 +656,9 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE || verbose {
- println!("{obj:?}");
+ println!("{obj:#?}");
}
+
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
@@ -664,6 +672,24 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
},
obj => obj,
};
+
+ // If key is provided, then we need to extract the state_dict from the object.
+ let obj = if let Some(key) = key {
+ if let Object::Dict(key_values) = obj {
+ key_values
+ .into_iter()
+ .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
+ .map(|(_, v)| v)
+ .ok_or_else(|| E::Msg(format!("key {key} not found")))?
+ } else {
+ obj
+ }
+ } else {
+ obj
+ };
+
+ // If the object is a dict, then we can extract the tensor info from it.
+ // NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) {
@@ -686,8 +712,8 @@ pub struct PthTensors {
}
impl PthTensors {
- pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
- let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
+ pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
+ let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
@@ -735,9 +761,17 @@ impl PthTensors {
}
}
-/// Read all the tensors from a PyTorch pth file.
-pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
- let pth = PthTensors::new(path)?;
+/// Read all the tensors from a PyTorch pth file with a given key.
+///
+/// # Arguments
+/// * `path` - Path to the pth file.
+/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
+/// contains multiple objects and the state_dict is the one we are interested in.
+pub fn read_all_with_key<P: AsRef<std::path::Path>>(
+ path: P,
+ key: Option<&str>,
+) -> Result<Vec<(String, Tensor)>> {
+ let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
@@ -747,3 +781,11 @@ pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tenso
}
Ok(tensors)
}
+
+/// Read all the tensors from a PyTorch pth file.
+///
+/// # Arguments
+/// * `path` - Path to the pth file.
+pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
+ read_all_with_key(path, None)
+}
diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py
index 97724712..cab94f2c 100644
--- a/candle-core/tests/pth.py
+++ b/candle-core/tests/pth.py
@@ -6,3 +6,5 @@ a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
torch.save(o, "test.pt")
+
+torch.save({"model_state_dict": o}, "test_with_key.pt")
diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs
index b09d1026..ad788ed9 100644
--- a/candle-core/tests/pth_tests.rs
+++ b/candle-core/tests/pth_tests.rs
@@ -1,6 +1,14 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
- let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap();
+ let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
+ tensors.get("test").unwrap().unwrap();
+}
+
+#[test]
+fn test_pth_with_key() {
+ let tensors =
+ candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
+ .unwrap();
tensors.get("test").unwrap().unwrap();
}
diff --git a/candle-core/tests/test_with_key.pt b/candle-core/tests/test_with_key.pt
new file mode 100644
index 00000000..a598e02c
--- /dev/null
+++ b/candle-core/tests/test_with_key.pt
Binary files differ
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 33d94c83..bf090219 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -484,7 +484,7 @@ impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
- let pth = candle::pickle::PthTensors::new(p)?;
+ let pth = candle::pickle::PthTensors::new(p, None)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
}
}