diff options
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r-- | candle-core/src/device.rs | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 53e2de43..89df8f84 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray } fn to_cpu_storage(&self) -> CpuStorage { - let mut vec = Vec::new(); - vec.reserve(N1 * N2 * N3); + let mut vec = Vec::with_capacity(N1 * N2 * N3); for i1 in 0..N1 { for i2 in 0..N2 { vec.extend(self[i1][i2]) @@ -117,39 +116,41 @@ impl Device { } } - pub(crate) fn rand_uniform( + pub(crate) fn rand_uniform<T: crate::FloatDType>( &self, + lo: T, + up: T, shape: &Shape, - dtype: DType, - lo: f64, - up: f64, ) -> Result<Storage> { + let lo = lo.to_f64(); + let up = up.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?; + let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_uniform(shape, dtype, lo, up)?; + let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?; Ok(Storage::Cuda(storage)) } } } - pub(crate) fn rand_normal( + pub(crate) fn rand_normal<T: crate::FloatDType>( &self, + mean: T, + std: T, shape: &Shape, - dtype: DType, - mean: f64, - std: f64, ) -> Result<Storage> { + let mean = mean.to_f64(); + let std = std.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?; + let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_normal(shape, dtype, mean, std)?; + let storage = device.rand_normal(shape, T::DTYPE, mean, std)?; Ok(Storage::Cuda(storage)) } } |