diff options
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 2 | ||||
-rw-r--r-- | candle-core/src/pickle.rs | 54 | ||||
-rw-r--r-- | candle-core/tests/pth.py | 2 | ||||
-rw-r--r-- | candle-core/tests/pth_tests.rs | 10 | ||||
-rw-r--r-- | candle-core/tests/test_with_key.pt | bin | 0 -> 1338 bytes | |||
-rw-r--r-- | candle-nn/src/var_builder.rs | 2 |
6 files changed, 61 insertions, 9 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index eb6ceb1c..1801ac58 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -196,7 +196,7 @@ fn run_ls( } } Format::Pth => { - let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?; + let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { println!( diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 4c76c416..2c189131 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -625,9 +625,16 @@ pub struct TensorInfo { pub storage_size: usize, } +/// Read the tensor info from a .pth file. +/// +/// # Arguments +/// * `file` - The path to the .pth file. +/// * `verbose` - Whether to print debug information. +/// * `key` - Optional key to retrieve `state_dict` from the pth file. pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( file: P, verbose: bool, + key: Option<&str>, ) -> Result<Vec<TensorInfo>> { let file = std::fs::File::open(file)?; let zip_reader = std::io::BufReader::new(file); @@ -649,8 +656,9 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( stack.read_loop(&mut reader)?; let obj = stack.finalize()?; if VERBOSE || verbose { - println!("{obj:?}"); + println!("{obj:#?}"); } + let obj = match obj { Object::Build { callable, args } => match *callable { Object::Reduce { callable, args: _ } => match *callable { @@ -664,6 +672,24 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( }, obj => obj, }; + + // If key is provided, then we need to extract the state_dict from the object. + let obj = if let Some(key) = key { + if let Object::Dict(key_values) = obj { + key_values + .into_iter() + .find(|(k, _)| *k == Object::Unicode(key.to_owned())) + .map(|(_, v)| v) + .ok_or_else(|| E::Msg(format!("key {key} not found")))? + } else { + obj + } + } else { + obj + }; + + // If the object is a dict, then we can extract the tensor info from it. + // NOTE: We are assuming that the `obj` is state_dict by this stage. if let Object::Dict(key_values) = obj { for (name, value) in key_values.into_iter() { match value.into_tensor_info(name, &dir_name) { @@ -686,8 +712,8 @@ pub struct PthTensors { } impl PthTensors { - pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> { - let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?; + pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> { + let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?; let tensor_infos = tensor_infos .into_iter() .map(|ti| (ti.name.to_string(), ti)) @@ -735,9 +761,17 @@ impl PthTensors { } } -/// Read all the tensors from a PyTorch pth file. -pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> { - let pth = PthTensors::new(path)?; +/// Read all the tensors from a PyTorch pth file with a given key. +/// +/// # Arguments +/// * `path` - Path to the pth file. +/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file +/// contains multiple objects and the state_dict is the one we are interested in. +pub fn read_all_with_key<P: AsRef<std::path::Path>>( + path: P, + key: Option<&str>, +) -> Result<Vec<(String, Tensor)>> { + let pth = PthTensors::new(path, key)?; let tensor_names = pth.tensor_infos.keys(); let mut tensors = Vec::with_capacity(tensor_names.len()); for name in tensor_names { @@ -747,3 +781,11 @@ pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tenso } Ok(tensors) } + +/// Read all the tensors from a PyTorch pth file. +/// +/// # Arguments +/// * `path` - Path to the pth file. +pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> { + read_all_with_key(path, None) +} diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py index 97724712..cab94f2c 100644 --- a/candle-core/tests/pth.py +++ b/candle-core/tests/pth.py @@ -6,3 +6,5 @@ a= torch.tensor([[1,2,3,4], [5,6,7,8]]) o = OrderedDict() o["test"] = a torch.save(o, "test.pt") + +torch.save({"model_state_dict": o}, "test_with_key.pt") diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs index b09d1026..ad788ed9 100644 --- a/candle-core/tests/pth_tests.rs +++ b/candle-core/tests/pth_tests.rs @@ -1,6 +1,14 @@ /// Regression test for pth files not loading on Windows. #[test] fn test_pth() { - let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap(); + let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap(); + tensors.get("test").unwrap().unwrap(); +} + +#[test] +fn test_pth_with_key() { + let tensors = + candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict")) + .unwrap(); tensors.get("test").unwrap().unwrap(); } diff --git a/candle-core/tests/test_with_key.pt b/candle-core/tests/test_with_key.pt Binary files differnew file mode 100644 index 00000000..a598e02c --- /dev/null +++ b/candle-core/tests/test_with_key.pt diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 33d94c83..bf090219 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -484,7 +484,7 @@ impl<'a> VarBuilder<'a> { /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> { - let pth = candle::pickle::PthTensors::new(p)?; + let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } } |