diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-27 07:40:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-27 07:40:36 +0100 |
commit | 6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch) | |
tree | 60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/device.rs | |
parent | 89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff) | |
download | candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2 candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip |
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r-- | candle-core/src/device.rs | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 53e2de43..89df8f84 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray } fn to_cpu_storage(&self) -> CpuStorage { - let mut vec = Vec::new(); - vec.reserve(N1 * N2 * N3); + let mut vec = Vec::with_capacity(N1 * N2 * N3); for i1 in 0..N1 { for i2 in 0..N2 { vec.extend(self[i1][i2]) @@ -117,39 +116,41 @@ impl Device { } } - pub(crate) fn rand_uniform( + pub(crate) fn rand_uniform<T: crate::FloatDType>( &self, + lo: T, + up: T, shape: &Shape, - dtype: DType, - lo: f64, - up: f64, ) -> Result<Storage> { + let lo = lo.to_f64(); + let up = up.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?; + let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_uniform(shape, dtype, lo, up)?; + let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?; Ok(Storage::Cuda(storage)) } } } - pub(crate) fn rand_normal( + pub(crate) fn rand_normal<T: crate::FloatDType>( &self, + mean: T, + std: T, shape: &Shape, - dtype: DType, - mean: f64, - std: f64, ) -> Result<Storage> { + let mean = mean.to_f64(); + let std = std.to_f64(); match self { Device::Cpu => { - let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?; + let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?; Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_normal(shape, dtype, mean, std)?; + let storage = device.rand_normal(shape, T::DTYPE, mean, std)?; Ok(Storage::Cuda(storage)) } } |