summaryrefslogtreecommitdiff
path: root/candle-core/src/variable.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-27 07:40:36 +0100
committerGitHub <noreply@github.com>2023-07-27 07:40:36 +0100
commit6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch)
tree60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/variable.rs
parent89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff)
downloadcandle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz
candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2
candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn. * Also switch Tensor::rand to use a generic dtype. * Support sampling for f16. * Cleanup.
Diffstat (limited to 'candle-core/src/variable.rs')
-rw-r--r--candle-core/src/variable.rs18
1 files changed, 8 insertions, 10 deletions
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index e26f1420..0cefee11 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,25 +34,23 @@ impl Var {
Ok(Self(inner))
}
- pub fn rand<S: Into<Shape>>(
+ pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+ lo: T,
+ up: T,
s: S,
- dtype: DType,
device: &Device,
- lo: f64,
- up: f64,
) -> Result<Self> {
- let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
+ let inner = Tensor::rand_impl(lo, up, s, device, true)?;
Ok(Self(inner))
}
- pub fn randn<S: Into<Shape>>(
+ pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+ mean: T,
+ std: T,
s: S,
- dtype: DType,
device: &Device,
- mean: f64,
- std: f64,
) -> Result<Self> {
- let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
+ let inner = Tensor::randn_impl(mean, std, s, device, true)?;
Ok(Self(inner))
}