diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 36 |
1 files changed, 16 insertions, 20 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 28ecc357..09f61340 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -232,55 +232,51 @@ impl Tensor { Tensor::zeros(self.shape(), self.dtype(), self.device()) } - pub(crate) fn rand_impl<S: Into<Shape>>( + pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>( + lo: T, + up: T, s: S, - dtype: DType, device: &Device, - lo: f64, - up: f64, is_variable: bool, ) -> Result<Self> { let s = s.into(); - let storage = device.rand_uniform(&s, dtype, lo, up)?; + let storage = device.rand_uniform(lo, up, &s)?; let none = BackpropOp::none(); Ok(from_storage(storage, s, none, is_variable)) } /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`. - pub fn rand<S: Into<Shape>>( + pub fn rand<S: Into<Shape>, T: crate::FloatDType>( + lo: T, + up: T, s: S, - dtype: DType, device: &Device, - lo: f64, - up: f64, ) -> Result<Self> { - Self::rand_impl(s, dtype, device, lo, up, false) + Self::rand_impl(lo, up, s, device, false) } - pub(crate) fn randn_impl<S: Into<Shape>>( + pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>( + mean: T, + std: T, s: S, - dtype: DType, device: &Device, - mean: f64, - std: f64, is_variable: bool, ) -> Result<Self> { let s = s.into(); - let storage = device.rand_normal(&s, dtype, mean, std)?; + let storage = device.rand_normal(mean, std, &s)?; let none = BackpropOp::none(); Ok(from_storage(storage, s, none, is_variable)) } /// Creates a new tensor initialized with values sampled from a normal distribution with the /// specified `mean` and standard deviation `std`. - pub fn randn<S: Into<Shape>>( + pub fn randn<S: Into<Shape>, T: crate::FloatDType>( + mean: T, + std: T, s: S, - dtype: DType, device: &Device, - mean: f64, - std: f64, ) -> Result<Self> { - Self::randn_impl(s, dtype, device, mean, std, false) + Self::randn_impl(mean, std, s, device, false) } pub(crate) fn new_impl<A: crate::device::NdArray>( |