summaryrefslogtreecommitdiff
path: root/candle-core/src/npy.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/npy.rs')
-rw-r--r--candle-core/src/npy.rs39
1 files changed, 2 insertions, 37 deletions
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index e17ba02a..2e394b06 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -26,7 +26,7 @@
//! values = np.loadz("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
-use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
@@ -307,42 +307,7 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
- let vs = self.flatten_all()?;
- match self.dtype() {
- DType::BF16 => {
- let vs = vs.to_vec1::<bf16>()?;
- for &v in vs.reinterpret_cast() {
- f.write_u16::<LittleEndian>(v)?
- }
- }
- DType::F16 => {
- let vs = vs.to_vec1::<f16>()?;
- for &v in vs.reinterpret_cast() {
- f.write_u16::<LittleEndian>(v)?
- }
- }
- DType::F32 => {
- // TODO: Avoid using a buffer when data is already on the CPU.
- for v in vs.to_vec1::<f32>()? {
- f.write_f32::<LittleEndian>(v)?
- }
- }
- DType::F64 => {
- for v in vs.to_vec1::<f64>()? {
- f.write_f64::<LittleEndian>(v)?
- }
- }
- DType::U32 => {
- for v in vs.to_vec1::<u32>()? {
- f.write_u32::<LittleEndian>(v)?
- }
- }
- DType::U8 => {
- let vs = vs.to_vec1::<u8>()?;
- f.write_all(&vs)?;
- }
- }
- Ok(())
+ self.write_bytes(f)
}
/// Writes a multi-dimensional array in the npy format.