summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/error.rs14
-rw-r--r--candle-core/src/npy.rs18
-rw-r--r--candle-core/src/safetensors.rs20
3 files changed, 38 insertions, 14 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 35a33032..c18b43c6 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ /// Adding path information to an error.
+ #[error("path: {path:?} {inner}")]
+ WithPath {
+ inner: Box<Self>,
+ path: std::path::PathBuf,
+ },
+
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
@@ -214,6 +221,13 @@ impl Error {
},
}
}
+
+ pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
+ Self::WithPath {
+ inner: Box::new(self),
+ path: p.as_ref().to_path_buf(),
+ }
+ }
}
#[macro_export]
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 6302cf71..e17ba02a 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -307,39 +307,39 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
- let elem_count = self.elem_count();
+ let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+ let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+ let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
- for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
+ for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
- for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
+ for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
- for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
+ for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
- let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
- f.write_all(&data)?;
+ let vs = vs.to_vec1::<u8>()?;
+ f.write_all(&vs)?;
}
}
Ok(())
@@ -373,7 +373,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
- // re-create a zip reader each time.
+ // re-create a zip reader for each tensor.
}
impl NpzTensors {
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 132fb914..914e5101 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -257,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
-pub struct MmapedFile(memmap2::Mmap);
+pub struct MmapedFile {
+ path: std::path::PathBuf,
+ inner: memmap2::Mmap,
+}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -267,13 +270,20 @@ impl MmapedFile {
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
- let file = std::fs::File::open(p)?;
- let mmap = memmap2::MmapOptions::new().map(&file)?;
- Ok(Self(mmap))
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let inner = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok(Self {
+ inner,
+ path: p.to_path_buf(),
+ })
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
- let st = safetensors::SafeTensors::deserialize(&self.0)?;
+ let st = safetensors::SafeTensors::deserialize(&self.inner)
+ .map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}