diff options
Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/src/tensor.rs b/src/tensor.rs index 816308e0..b8fa738a 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -445,7 +445,7 @@ impl Tensor { } macro_rules! bin_trait { - ($trait:ident, $fn1:ident) => { + ($trait:ident, $fn1:ident, $mul:expr, $add:expr) => { impl<B: std::borrow::Borrow<Tensor>> std::ops::$trait<B> for Tensor { type Output = Result<Tensor>; @@ -477,10 +477,26 @@ macro_rules! bin_trait { Tensor::$fn1(&self, rhs?.borrow()) } } + + impl std::ops::$trait<f64> for Tensor { + type Output = Result<Tensor>; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } + + impl std::ops::$trait<f64> for &Tensor { + type Output = Result<Tensor>; + + fn $fn1(self, rhs: f64) -> Self::Output { + self.affine($mul(rhs), $add(rhs)) + } + } }; } -bin_trait!(Add, add); -bin_trait!(Sub, sub); -bin_trait!(Mul, mul); -bin_trait!(Div, div); +bin_trait!(Add, add, |_| 1., |v| v); +bin_trait!(Sub, sub, |_| 1., |v: f64| -v); +bin_trait!(Mul, mul, |v| v, |_| 0.); +bin_trait!(Div, div, |v| 1. / v, |_| 0.); |