summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs36
1 files changed, 16 insertions, 20 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 28ecc357..09f61340 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -232,55 +232,51 @@ impl Tensor {
Tensor::zeros(self.shape(), self.dtype(), self.device())
}
- pub(crate) fn rand_impl<S: Into<Shape>>(
+ pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
+ lo: T,
+ up: T,
s: S,
- dtype: DType,
device: &Device,
- lo: f64,
- up: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
- let storage = device.rand_uniform(&s, dtype, lo, up)?;
+ let storage = device.rand_uniform(lo, up, &s)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
- pub fn rand<S: Into<Shape>>(
+ pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+ lo: T,
+ up: T,
s: S,
- dtype: DType,
device: &Device,
- lo: f64,
- up: f64,
) -> Result<Self> {
- Self::rand_impl(s, dtype, device, lo, up, false)
+ Self::rand_impl(lo, up, s, device, false)
}
- pub(crate) fn randn_impl<S: Into<Shape>>(
+ pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
+ mean: T,
+ std: T,
s: S,
- dtype: DType,
device: &Device,
- mean: f64,
- std: f64,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
- let storage = device.rand_normal(&s, dtype, mean, std)?;
+ let storage = device.rand_normal(mean, std, &s)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
- pub fn randn<S: Into<Shape>>(
+ pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+ mean: T,
+ std: T,
s: S,
- dtype: DType,
device: &Device,
- mean: f64,
- std: f64,
) -> Result<Self> {
- Self::randn_impl(s, dtype, device, mean, std, false)
+ Self::randn_impl(mean, std, s, device, false)
}
pub(crate) fn new_impl<A: crate::device::NdArray>(