//! Implement conversion traits for tensors use crate::{DType, Device, Error, Tensor, WithDType}; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; impl TryFrom<&Tensor> for Vec { type Error = Error; fn try_from(tensor: &Tensor) -> Result { tensor.to_vec1::() } } impl TryFrom<&Tensor> for Vec> { type Error = Error; fn try_from(tensor: &Tensor) -> Result { tensor.to_vec2::() } } impl TryFrom<&Tensor> for Vec>> { type Error = Error; fn try_from(tensor: &Tensor) -> Result { tensor.to_vec3::() } } impl TryFrom for Vec { type Error = Error; fn try_from(tensor: Tensor) -> Result { Vec::::try_from(&tensor) } } impl TryFrom for Vec> { type Error = Error; fn try_from(tensor: Tensor) -> Result { Vec::>::try_from(&tensor) } } impl TryFrom for Vec>> { type Error = Error; fn try_from(tensor: Tensor) -> Result { Vec::>>::try_from(&tensor) } } impl TryFrom<&[T]> for Tensor { type Error = Error; fn try_from(v: &[T]) -> Result { Tensor::from_slice(v, v.len(), &Device::Cpu) } } impl TryFrom> for Tensor { type Error = Error; fn try_from(v: Vec) -> Result { let len = v.len(); Tensor::from_vec(v, len, &Device::Cpu) } } macro_rules! from_tensor { ($typ:ident) => { impl TryFrom<&Tensor> for $typ { type Error = Error; fn try_from(tensor: &Tensor) -> Result { tensor.to_scalar::<$typ>() } } impl TryFrom for $typ { type Error = Error; fn try_from(tensor: Tensor) -> Result { $typ::try_from(&tensor) } } impl TryFrom<$typ> for Tensor { type Error = Error; fn try_from(v: $typ) -> Result { Tensor::new(v, &Device::Cpu) } } }; } from_tensor!(f64); from_tensor!(f32); from_tensor!(f16); from_tensor!(bf16); from_tensor!(i64); from_tensor!(u32); from_tensor!(u8); impl Tensor { pub fn write_bytes(&self, f: &mut W) -> crate::Result<()> { use byteorder::{LittleEndian, WriteBytesExt}; let vs = self.flatten_all()?; match self.dtype() { DType::BF16 => { let vs = vs.to_vec1::()?; for &v in vs.reinterpret_cast() { f.write_u16::(v)? } } DType::F16 => { let vs = vs.to_vec1::()?; for &v in vs.reinterpret_cast() { f.write_u16::(v)? } } DType::F32 => { // TODO: Avoid using a buffer when data is already on the CPU. for v in vs.to_vec1::()? { f.write_f32::(v)? } } DType::F64 => { for v in vs.to_vec1::()? { f.write_f64::(v)? } } DType::U32 => { for v in vs.to_vec1::()? { f.write_u32::(v)? } } DType::I64 => { for v in vs.to_vec1::()? { f.write_i64::(v)? } } DType::U8 => { let vs = vs.to_vec1::()?; f.write_all(&vs)?; } } Ok(()) } }