summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-11 08:52:29 +0100
committerGitHub <noreply@github.com>2023-07-11 08:52:29 +0100
commitae79c00e48089d889f900b4c05f90a1201e610c6 (patch)
treea42bc3334791a79b203397c581d4ed03998e4e3d /candle-core/src/tensor.rs
parentb31a3bbdcbf1a75bbb18cdc2aa0fbff2ab931351 (diff)
downloadcandle-ae79c00e48089d889f900b4c05f90a1201e610c6.tar.gz
candle-ae79c00e48089d889f900b4c05f90a1201e610c6.tar.bz2
candle-ae79c00e48089d889f900b4c05f90a1201e610c6.zip
Allow for uniform initialization in a single step. (#136)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs24
1 files changed, 19 insertions, 5 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index aba7b91a..ecc018f9 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -226,19 +226,33 @@ impl Tensor {
s: S,
dtype: DType,
device: &Device,
+ lo: f64,
+ up: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
- let storage = device.rand_uniform(&s, dtype)?;
+ let storage = device.rand_uniform(&s, dtype, lo, up)?;
Ok(from_storage(storage, s, None, is_variable))
}
- pub fn rand_uniform<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
- Self::rand_uniform_impl(s, dtype, device, false)
+ pub fn rand_uniform<S: Into<Shape>>(
+ s: S,
+ dtype: DType,
+ device: &Device,
+ lo: f64,
+ up: f64,
+ ) -> Result<Self> {
+ Self::rand_uniform_impl(s, dtype, device, lo, up, false)
}
- pub fn rand_uniform_var<S: Into<Shape>>(s: S, dtype: DType, device: &Device) -> Result<Self> {
- Self::rand_uniform_impl(s, dtype, device, true)
+ pub fn rand_uniform_var<S: Into<Shape>>(
+ s: S,
+ dtype: DType,
+ device: &Device,
+ lo: f64,
+ up: f64,
+ ) -> Result<Self> {
+ Self::rand_uniform_impl(s, dtype, device, lo, up, true)
}
fn rand_normal_impl<S: Into<Shape>>(