summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/pickle.rs92
1 files changed, 48 insertions, 44 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 0013113a..4a2c65fd 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -193,6 +193,50 @@ impl Object {
_ => Err(self),
}
}
+
+ pub fn into_tensor_info(
+ self,
+ name: Self,
+ dir_name: &std::path::Path,
+ ) -> Result<Option<TensorInfo>> {
+ let name = match name.unicode() {
+ Ok(name) => name,
+ Err(_) => return Ok(None),
+ };
+ let (callable, args) = match self.reduce() {
+ Ok(callable_args) => callable_args,
+ _ => return Ok(None),
+ };
+ let (callable, args) = match callable {
+ Object::Class {
+ module_name,
+ class_name,
+ } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
+ let mut args = args.tuple()?;
+ let callable = args.remove(0);
+ let args = args.remove(1);
+ (callable, args)
+ }
+ _ => (callable, args),
+ };
+ match callable {
+ Object::Class {
+ module_name,
+ class_name,
+ } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
+ _ => return Ok(None),
+ };
+ let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
+ let mut path = dir_name.to_path_buf();
+ path.push(file_path);
+ Ok(Some(TensorInfo {
+ name,
+ dtype,
+ layout,
+ path: path.to_string_lossy().into_owned(),
+ storage_size,
+ }))
+ }
}
impl TryFrom<Object> for String {
@@ -623,50 +667,10 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
};
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
- let name = match name.unicode() {
- Ok(name) => name,
- Err(_) => continue,
- };
- let (callable, args) = match value.reduce() {
- Ok(callable_args) => callable_args,
- _ => continue,
- };
- let (callable, args) = match callable {
- Object::Class {
- module_name,
- class_name,
- } if module_name == "torch._tensor"
- && class_name == "_rebuild_from_type_v2" =>
- {
- let mut args = args.tuple()?;
- let callable = args.remove(0);
- let args = args.remove(1);
- (callable, args)
- }
- _ => (callable, args),
- };
- match callable {
- Object::Class {
- module_name,
- class_name,
- } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
- _ => continue,
- };
- match rebuild_args(args) {
- Ok((layout, dtype, file_path, storage_size)) => {
- let mut path = dir_name.clone();
- path.push(file_path);
- tensor_infos.push(TensorInfo {
- name,
- dtype,
- layout,
- path: path.to_string_lossy().into_owned(),
- storage_size,
- })
- }
- Err(err) => {
- eprintln!("skipping {name}: {err:?}")
- }
+ match value.into_tensor_info(name, &dir_name) {
+ Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
+ Ok(None) => {}
+ Err(err) => eprintln!("skipping: {err:?}"),
}
}
}