diff options
-rw-r--r-- | candle-core/src/convert.rs | 47 | ||||
-rw-r--r-- | candle-core/src/npy.rs | 39 |
2 files changed, 47 insertions, 39 deletions
diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 41a9c4ee..744982fc 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,6 +1,6 @@ //! Implement conversion traits for tensors -use crate::{Device, Error, Tensor, WithDType}; -use half::{bf16, f16}; +use crate::{DType, Device, Error, Tensor, WithDType}; +use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; impl<T: WithDType> TryFrom<&Tensor> for Vec<T> { @@ -94,3 +94,46 @@ from_tensor!(f16); from_tensor!(bf16); from_tensor!(u32); from_tensor!(u8); + +impl Tensor { + pub fn write_bytes<W: std::io::Write>(&self, f: &mut W) -> crate::Result<()> { + use byteorder::{LittleEndian, WriteBytesExt}; + + let vs = self.flatten_all()?; + match self.dtype() { + DType::BF16 => { + let vs = vs.to_vec1::<bf16>()?; + for &v in vs.reinterpret_cast() { + f.write_u16::<LittleEndian>(v)? + } + } + DType::F16 => { + let vs = vs.to_vec1::<f16>()?; + for &v in vs.reinterpret_cast() { + f.write_u16::<LittleEndian>(v)? + } + } + DType::F32 => { + // TODO: Avoid using a buffer when data is already on the CPU. + for v in vs.to_vec1::<f32>()? { + f.write_f32::<LittleEndian>(v)? + } + } + DType::F64 => { + for v in vs.to_vec1::<f64>()? { + f.write_f64::<LittleEndian>(v)? + } + } + DType::U32 => { + for v in vs.to_vec1::<u32>()? { + f.write_u32::<LittleEndian>(v)? + } + } + DType::U8 => { + let vs = vs.to_vec1::<u8>()?; + f.write_all(&vs)?; + } + } + Ok(()) + } +} diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index e17ba02a..2e394b06 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -26,7 +26,7 @@ //! values = np.loadz("test.npz") //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use byteorder::{LittleEndian, ReadBytesExt}; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; @@ -307,42 +307,7 @@ impl Tensor { header.push('\n'); f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(header.as_bytes())?; - let vs = self.flatten_all()?; - match self.dtype() { - DType::BF16 => { - let vs = vs.to_vec1::<bf16>()?; - for &v in vs.reinterpret_cast() { - f.write_u16::<LittleEndian>(v)? - } - } - DType::F16 => { - let vs = vs.to_vec1::<f16>()?; - for &v in vs.reinterpret_cast() { - f.write_u16::<LittleEndian>(v)? - } - } - DType::F32 => { - // TODO: Avoid using a buffer when data is already on the CPU. - for v in vs.to_vec1::<f32>()? { - f.write_f32::<LittleEndian>(v)? - } - } - DType::F64 => { - for v in vs.to_vec1::<f64>()? { - f.write_f64::<LittleEndian>(v)? - } - } - DType::U32 => { - for v in vs.to_vec1::<u32>()? { - f.write_u32::<LittleEndian>(v)? - } - } - DType::U8 => { - let vs = vs.to_vec1::<u8>()?; - f.write_all(&vs)?; - } - } - Ok(()) + self.write_bytes(f) } /// Writes a multi-dimensional array in the npy format. |