diff options
Diffstat (limited to 'candle-core/src/npy.rs')
-rw-r--r-- | candle-core/src/npy.rs | 39 |
1 files changed, 2 insertions, 37 deletions
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index e17ba02a..2e394b06 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -26,7 +26,7 @@ //! values = np.loadz("test.npz") //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use byteorder::{LittleEndian, ReadBytesExt}; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; @@ -307,42 +307,7 @@ impl Tensor { header.push('\n'); f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(header.as_bytes())?; - let vs = self.flatten_all()?; - match self.dtype() { - DType::BF16 => { - let vs = vs.to_vec1::<bf16>()?; - for &v in vs.reinterpret_cast() { - f.write_u16::<LittleEndian>(v)? - } - } - DType::F16 => { - let vs = vs.to_vec1::<f16>()?; - for &v in vs.reinterpret_cast() { - f.write_u16::<LittleEndian>(v)? - } - } - DType::F32 => { - // TODO: Avoid using a buffer when data is already on the CPU. - for v in vs.to_vec1::<f32>()? { - f.write_f32::<LittleEndian>(v)? - } - } - DType::F64 => { - for v in vs.to_vec1::<f64>()? { - f.write_f64::<LittleEndian>(v)? - } - } - DType::U32 => { - for v in vs.to_vec1::<u32>()? { - f.write_u32::<LittleEndian>(v)? - } - } - DType::U8 => { - let vs = vs.to_vec1::<u8>()?; - f.write_all(&vs)?; - } - } - Ok(()) + self.write_bytes(f) } /// Writes a multi-dimensional array in the npy format. |