//! TensorScalar Enum and Trait //! use crate::{Result, Tensor, WithDType}; pub enum TensorScalar { Tensor(Tensor), Scalar(Tensor), } pub trait TensorOrScalar { fn to_tensor_scalar(self) -> Result; } impl TensorOrScalar for &Tensor { fn to_tensor_scalar(self) -> Result { Ok(TensorScalar::Tensor(self.clone())) } } impl TensorOrScalar for T { fn to_tensor_scalar(self) -> Result { let scalar = Tensor::new(self, &crate::Device::Cpu)?; Ok(TensorScalar::Scalar(scalar)) } }