diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-21 08:59:08 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-21 08:59:08 +0100 |
commit | f319583530745dfab125bd2d16c2dfa4aa75646d (patch) | |
tree | 5948cbdb464585f4d7b22b7b4f8b3c058293968f /src/tensor.rs | |
parent | 08399547703b2a6cc802729464aacfb4e8ebdd43 (diff) | |
download | candle-f319583530745dfab125bd2d16c2dfa4aa75646d.tar.gz candle-f319583530745dfab125bd2d16c2dfa4aa75646d.tar.bz2 candle-f319583530745dfab125bd2d16c2dfa4aa75646d.zip |
More QOL changes, binary op for constants.
Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/src/tensor.rs b/src/tensor.rs index 816308e0..b8fa738a 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -445,7 +445,7 @@ impl Tensor { } macro_rules! bin_trait { - ($trait:ident, $fn1:ident) => { + ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => { impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor { type Output = Result<Tensor>; @@ -477,10 +477,26 @@ macro_rules! bin_trait { Tensor::$fn1(&self, rhs?.borrow()) } } + + impl std::ops::$trait<f64> for Tensor { + type Output = Result<Tensor>; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } + + impl std::ops::$trait<f64> for &Tensor { + type Output = Result<Tensor>; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } }; } -bin_trait!(Add, add); -bin_trait!(Sub, sub); -bin_trait!(Mul, mul); -bin_trait!(Div, div); +bin_trait!(Add, add, |_| 1., |v| v); +bin_trait!(Sub, sub, |_| 1., |v: f64| -v); +bin_trait!(Mul, mul, |v| v, |_| 0.); +bin_trait!(Div, div, |v| 1. / v, |_| 0.); |