diff options
Diffstat (limited to 'candle-core/src/scalar.rs')
-rw-r--r-- | candle-core/src/scalar.rs | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs new file mode 100644 index 00000000..43e1f4c8 --- /dev/null +++ b/candle-core/src/scalar.rs @@ -0,0 +1,23 @@ +use crate::{Result, Tensor, WithDType}; + +pub enum TensorScalar { + Tensor(Tensor), + Scalar(Tensor), +} + +pub trait TensorOrScalar { + fn to_tensor_scalar(self) -> Result<TensorScalar>; +} + +impl TensorOrScalar for &Tensor { + fn to_tensor_scalar(self) -> Result<TensorScalar> { + Ok(TensorScalar::Tensor(self.clone())) + } +} + +impl<T: WithDType> TensorOrScalar for T { + fn to_tensor_scalar(self) -> Result<TensorScalar> { + let scalar = Tensor::new(self, &crate::Device::Cpu)?; + Ok(TensorScalar::Scalar(scalar)) + } +} |