diff options
Diffstat (limited to 'candle-core/src/npy.rs')
-rw-r--r-- | candle-core/src/npy.rs | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 6302cf71..e17ba02a 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -307,39 +307,39 @@ impl Tensor { header.push('\n'); f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?; f.write_all(header.as_bytes())?; - let elem_count = self.elem_count(); + let vs = self.flatten_all()?; match self.dtype() { DType::BF16 => { - let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?; + let vs = vs.to_vec1::<bf16>()?; for &v in vs.reinterpret_cast() { f.write_u16::<LittleEndian>(v)? } } DType::F16 => { - let vs = self.reshape(elem_count)?.to_vec1::<f16>()?; + let vs = vs.to_vec1::<f16>()?; for &v in vs.reinterpret_cast() { f.write_u16::<LittleEndian>(v)? } } DType::F32 => { // TODO: Avoid using a buffer when data is already on the CPU. - for v in self.reshape(elem_count)?.to_vec1::<f32>()? { + for v in vs.to_vec1::<f32>()? { f.write_f32::<LittleEndian>(v)? } } DType::F64 => { - for v in self.reshape(elem_count)?.to_vec1::<f64>()? { + for v in vs.to_vec1::<f64>()? { f.write_f64::<LittleEndian>(v)? } } DType::U32 => { - for v in self.reshape(elem_count)?.to_vec1::<u32>()? { + for v in vs.to_vec1::<u32>()? { f.write_u32::<LittleEndian>(v)? } } DType::U8 => { - let data = self.reshape(elem_count)?.to_vec1::<u8>()?; - f.write_all(&data)?; + let vs = vs.to_vec1::<u8>()?; + f.write_all(&vs)?; } } Ok(()) @@ -373,7 +373,7 @@ pub struct NpzTensors { index_per_name: HashMap<String, usize>, path: std::path::PathBuf, // We do not store a zip reader as it needs mutable access to extract data. Instead we - // re-create a zip reader each time. + // re-create a zip reader for each tensor. } impl NpzTensors { |