summaryrefslogtreecommitdiff
path: root/candle-core/src/scalar.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/scalar.rs')
-rw-r--r--candle-core/src/scalar.rs23
1 files changed, 23 insertions, 0 deletions
diff --git a/candle-core/src/scalar.rs b/candle-core/src/scalar.rs
new file mode 100644
index 00000000..43e1f4c8
--- /dev/null
+++ b/candle-core/src/scalar.rs
@@ -0,0 +1,23 @@
+use crate::{Result, Tensor, WithDType};
+
+pub enum TensorScalar {
+ Tensor(Tensor),
+ Scalar(Tensor),
+}
+
+pub trait TensorOrScalar {
+ fn to_tensor_scalar(self) -> Result<TensorScalar>;
+}
+
+impl TensorOrScalar for &Tensor {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ Ok(TensorScalar::Tensor(self.clone()))
+ }
+}
+
+impl<T: WithDType> TensorOrScalar for T {
+ fn to_tensor_scalar(self) -> Result<TensorScalar> {
+ let scalar = Tensor::new(self, &crate::Device::Cpu)?;
+ Ok(TensorScalar::Scalar(scalar))
+ }
+}