diff options
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r-- | candle-core/src/device.rs | 44 |
1 files changed, 30 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 89df8f84..563d892b 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -116,46 +116,62 @@ impl Device { } } - pub(crate) fn rand_uniform<T: crate::FloatDType>( + pub(crate) fn rand_uniform_f64( &self, - lo: T, - up: T, + lo: f64, + up: f64, shape: &Shape, + dtype: DType, ) -> Result<Storage> { - let lo = lo.to_f64(); - let up = up.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?; + let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?; + let storage = device.rand_uniform(shape, dtype, lo, up)?; Ok(Storage::Cuda(storage)) } } } - pub(crate) fn rand_normal<T: crate::FloatDType>( + pub(crate) fn rand_uniform<T: crate::FloatDType>( &self, - mean: T, - std: T, + lo: T, + up: T, + shape: &Shape, + ) -> Result<Storage> { + self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE) + } + + pub(crate) fn rand_normal_f64( + &self, + mean: f64, + std: f64, shape: &Shape, + dtype: DType, ) -> Result<Storage> { - let mean = mean.to_f64(); - let std = std.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?; + let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_normal(shape, T::DTYPE, mean, std)?; + let storage = device.rand_normal(shape, dtype, mean, std)?; Ok(Storage::Cuda(storage)) } } } + pub(crate) fn rand_normal<T: crate::FloatDType>( + &self, + mean: T, + std: T, + shape: &Shape, + ) -> Result<Storage> { + self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE) + } + pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> { match self { Device::Cpu => { |