summaryrefslogtreecommitdiff
path: root/candle-core/src/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r--candle-core/src/device.rs29
1 files changed, 15 insertions, 14 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 53e2de43..89df8f84 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
}
fn to_cpu_storage(&self) -> CpuStorage {
- let mut vec = Vec::new();
- vec.reserve(N1 * N2 * N3);
+ let mut vec = Vec::with_capacity(N1 * N2 * N3);
for i1 in 0..N1 {
for i2 in 0..N2 {
vec.extend(self[i1][i2])
@@ -117,39 +116,41 @@ impl Device {
}
}
- pub(crate) fn rand_uniform(
+ pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
+ lo: T,
+ up: T,
shape: &Shape,
- dtype: DType,
- lo: f64,
- up: f64,
) -> Result<Storage> {
+ let lo = lo.to_f64();
+ let up = up.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
+ let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_uniform(shape, dtype, lo, up)?;
+ let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
- pub(crate) fn rand_normal(
+ pub(crate) fn rand_normal<T: crate::FloatDType>(
&self,
+ mean: T,
+ std: T,
shape: &Shape,
- dtype: DType,
- mean: f64,
- std: f64,
) -> Result<Storage> {
+ let mean = mean.to_f64();
+ let std = std.to_f64();
match self {
Device::Cpu => {
- let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
+ let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_normal(shape, dtype, mean, std)?;
+ let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
Ok(Storage::Cuda(storage))
}
}