diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-27 07:40:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-27 07:40:36 +0100 |
commit | 6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch) | |
tree | 60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/variable.rs | |
parent | 89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff) | |
download | candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2 candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip |
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
Diffstat (limited to 'candle-core/src/variable.rs')
-rw-r--r-- | candle-core/src/variable.rs | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs index e26f1420..0cefee11 100644 --- a/candle-core/src/variable.rs +++ b/candle-core/src/variable.rs @@ -34,25 +34,23 @@ impl Var { Ok(Self(inner)) } - pub fn rand<S: Into<Shape>>( + pub fn rand<S: Into<Shape>, T: crate::FloatDType>( + lo: T, + up: T, s: S, - dtype: DType, device: &Device, - lo: f64, - up: f64, ) -> Result<Self> { - let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?; + let inner = Tensor::rand_impl(lo, up, s, device, true)?; Ok(Self(inner)) } - pub fn randn<S: Into<Shape>>( + pub fn randn<S: Into<Shape>, T: crate::FloatDType>( + mean: T, + std: T, s: S, - dtype: DType, device: &Device, - mean: f64, - std: f64, ) -> Result<Self> { - let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?; + let inner = Tensor::randn_impl(mean, std, s, device, true)?; Ok(Self(inner)) } |