summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-01-08 09:20:48 +0100
committerGitHub <noreply@github.com>2024-01-08 09:20:48 +0100
commit12b2a337f30f023af157b9ae560b53c3c5bd416c (patch)
tree2852386070682ac95dda28bcc372070f282959cc
parent0eb90ed7831d451e2e420ecd158151b44dc5b2ba (diff)
downloadcandle-12b2a337f30f023af157b9ae560b53c3c5bd416c.tar.gz
candle-12b2a337f30f023af157b9ae560b53c3c5bd416c.tar.bz2
candle-12b2a337f30f023af157b9ae560b53c3c5bd416c.zip
Handle start-offset when loading a tensor from a pickle file. (#1546)
-rw-r--r--candle-core/src/pickle.rs14
1 files changed, 11 insertions, 3 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 25640d1a..276b30e3 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -703,6 +703,7 @@ impl PthTensors {
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+ use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
@@ -712,14 +713,21 @@ impl PthTensors {
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
- // Reading the data is a bit tricky as it can be strided, use an offset, etc.
- // For now only support the basic case.
- if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
+ // Reading the data is a bit tricky as it can be strided, for now only support the basic
+ // case.
+ if !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
+ let start_offset = tensor_info.layout.start_offset();
+ if start_offset > 0 {
+ std::io::copy(
+ &mut reader.by_ref().take(start_offset as u64),
+ &mut std::io::sink(),
+ )?;
+ }
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,