summary | refs | log | tree | commit | diff
path: root/candle-core/src/device.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-27 07:40:36 +0100
committer: GitHub <noreply@github.com> 2023-07-27 07:40:36 +0100
commit: 6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch)
tree: 60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/device.rs
parent: 89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff)
download: candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz
candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2
candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip
Simplify Tensor::randn. (#255)

* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r-- candle-core/src/device.rs | 29
1 file changed, 15 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 53e2de43..89df8f84 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
fn to_cpu_storage(&self) -> CpuStorage {
- let mut vec = Vec::new();
- vec.reserve(N1 * N2 * N3);
+ let mut vec = Vec::with_capacity(N1 * N2 * N3);
for i1 in 0..N1 {
for i2 in 0..N2 {
vec.extend(self[i1][i2])
@@ -117,39 +116,41 @@ impl Device {
}
}
- pub(crate) fn rand_uniform(
+ pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
+ lo: T,
+ up: T,
shape: &Shape,
- dtype: DType,
- lo: f64,
- up: f64,
) -> Result<Storage> {
+ let lo = lo.to_f64();
+ let up = up.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
+ let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_uniform(shape, dtype, lo, up)?;
+ let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
- pub(crate) fn rand_normal(
+ pub(crate) fn rand_normal<T: crate::FloatDType>(
&self,
+ mean: T,
+ std: T,
shape: &Shape,
- dtype: DType,
- mean: f64,
- std: f64,
) -> Result<Storage> {
+ let mean = mean.to_f64();
+ let std = std.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
+ let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_normal(shape, dtype, mean, std)?;
+ let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cuda(storage))
}
}