summaryrefslogtreecommitdiff
path: root/candle-core/src/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r--candle-core/src/device.rs44
1 files changed, 30 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 89df8f84..563d892b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -116,46 +116,62 @@ impl Device {
}
}
- pub(crate) fn rand_uniform<T: crate::FloatDType>(
+ pub(crate) fn rand_uniform_f64(
&self,
- lo: T,
- up: T,
+ lo: f64,
+ up: f64,
shape: &Shape,
+ dtype: DType,
) -> Result<Storage> {
- let lo = lo.to_f64();
- let up = up.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
+ let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
+ let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
- pub(crate) fn rand_normal<T: crate::FloatDType>(
+ pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
- mean: T,
- std: T,
+ lo: T,
+ up: T,
+ shape: &Shape,
+ ) -> Result<Storage> {
+ self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
+ }
+
+ pub(crate) fn rand_normal_f64(
+ &self,
+ mean: f64,
+ std: f64,
shape: &Shape,
+ dtype: DType,
) -> Result<Storage> {
- let mean = mean.to_f64();
- let std = std.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
+ let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
+ let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}
+ pub(crate) fn rand_normal<T: crate::FloatDType>(
+ &self,
+ mean: T,
+ std: T,
+ shape: &Shape,
+ ) -> Result<Storage> {
+ self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
+ }
+
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {