diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-08 09:20:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-08 09:20:48 +0100 |
commit | 12b2a337f30f023af157b9ae560b53c3c5bd416c (patch) | |
tree | 2852386070682ac95dda28bcc372070f282959cc | |
parent | 0eb90ed7831d451e2e420ecd158151b44dc5b2ba (diff) | |
download | candle-12b2a337f30f023af157b9ae560b53c3c5bd416c.tar.gz candle-12b2a337f30f023af157b9ae560b53c3c5bd416c.tar.bz2 candle-12b2a337f30f023af157b9ae560b53c3c5bd416c.zip |
Handle start-offset when loading a tensor from a pickle file. (#1546)
-rw-r--r-- | candle-core/src/pickle.rs | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 25640d1a..276b30e3 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -703,6 +703,7 @@ impl PthTensors { } pub fn get(&self, name: &str) -> Result<Option<Tensor>> { + use std::io::Read; let tensor_info = match self.tensor_infos.get(name) { None => return Ok(None), Some(tensor_info) => tensor_info, @@ -712,14 +713,21 @@ impl PthTensors { let mut zip = zip::ZipArchive::new(zip_reader)?; let mut reader = zip.by_name(&tensor_info.path)?; - // Reading the data is a bit tricky as it can be strided, use an offset, etc. - // For now only support the basic case. - if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() { + // Reading the data is a bit tricky as it can be strided, for now only support the basic + // case. + if !tensor_info.layout.is_contiguous() { crate::bail!( "cannot retrieve non-contiguous tensors {:?}", tensor_info.layout ) } + let start_offset = tensor_info.layout.start_offset(); + if start_offset > 0 { + std::io::copy( + &mut reader.by_ref().take(start_offset as u64), + &mut std::io::sink(), + )?; + } let tensor = Tensor::from_reader( tensor_info.layout.shape().clone(), tensor_info.dtype, |