Diffstat (limited to 'candle-core/src')
-rw-r--r--   candle-core/src/cpu_backend.rs  | 55
-rw-r--r--   candle-core/src/cuda_backend.rs |  4
-rw-r--r--   candle-core/src/device.rs       | 29
-rw-r--r--   candle-core/src/dtype.rs        |  9
-rw-r--r--   candle-core/src/lib.rs          |  2
-rw-r--r--   candle-core/src/tensor.rs       | 36
-rw-r--r--   candle-core/src/variable.rs     | 18
7 files changed, 93 insertions, 60 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8d38b158..83c7080f 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -369,8 +369,7 @@ pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
             block_start_index,
             block_len,
         } => {
-            let mut result = vec![];
-            result.reserve(layout.shape().elem_count());
+            let mut result = Vec::with_capacity(layout.shape().elem_count());
             // Specialize the case where block_len is one to avoid the second loop.
             if block_len == 1 {
                 for index in block_start_index {
@@ -1843,12 +1842,27 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
+            DType::BF16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let uniform =
+                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<bf16, _>(uniform))
+                }
+                Ok(CpuStorage::BF16(data))
+            }
+            DType::F16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let uniform =
+                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f16, _>(uniform))
+                }
+                Ok(CpuStorage::F16(data))
             }
             DType::F32 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(uniform))
@@ -1856,8 +1870,7 @@ impl BackendDevice for CpuDevice {
                 Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 let uniform = rand::distributions::Uniform::new(min, max);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(uniform))
@@ -1873,12 +1886,27 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
-                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
+            DType::BF16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let std = bf16::from_f64(std);
+                let mean = bf16::from_f64(mean);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+                }
+                Ok(CpuStorage::BF16(data))
+            }
+            DType::F16 => {
+                let mut data = Vec::with_capacity(elem_count);
+                let std = f16::from_f64(std);
+                let mean = f16::from_f64(mean);
+                for _i in 0..elem_count {
+                    data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+                }
+                Ok(CpuStorage::F16(data))
             }
             DType::F32 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 let std = std as f32;
                 let mean = mean as f32;
                 for _i in 0..elem_count {
@@ -1887,8 +1915,7 @@ impl BackendDevice for CpuDevice {
                 Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
-                let mut data = Vec::new();
-                data.reserve(elem_count);
+                let mut data = Vec::with_capacity(elem_count);
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
                 }
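The new BF16/F16 arms lean on the half crate's rand support. As a reference point, here is a minimal standalone sketch of the same sampling pattern, assuming rand 0.8 and half with its rand_distr feature enabled (which provides the SampleUniform and Standard impls the diff relies on); note that Standard samples uniformly from [0, 1), so the scale-and-shift below mirrors the diff as written rather than a true Gaussian.

    use half::f16;
    use rand::Rng;

    fn main() {
        let mut rng = rand::thread_rng();
        let elem_count = 8;
        // Uniform sampling directly at half precision, as in the rand_uniform arms.
        let uniform = rand::distributions::Uniform::new(f16::from_f64(0.0), f16::from_f64(1.0));
        let mut data = Vec::with_capacity(elem_count);
        for _i in 0..elem_count {
            data.push(rng.sample::<f16, _>(uniform))
        }
        // Scale-and-shift of a Standard sample, as in the rand_normal arms.
        let (mean, std) = (f16::from_f64(0.0), f16::from_f64(1.0));
        let mut data_n = Vec::with_capacity(elem_count);
        for _i in 0..elem_count {
            data_n.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
        }
        println!("{data:?} {data_n:?}");
    }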
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9cc454f1..b3d542b9 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
         let slice = match dtype {
+            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+            // cudarc changes.
             DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
                 dtype,
                 op: "rand_uniform",
@@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
     }
 
     fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
+        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+        // cudarc changes.
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
         let slice = match dtype {
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 53e2de43..89df8f84 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -71,8 +71,7 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
     }
 
     fn to_cpu_storage(&self) -> CpuStorage {
-        let mut vec = Vec::new();
-        vec.reserve(N1 * N2 * N3);
+        let mut vec = Vec::with_capacity(N1 * N2 * N3);
         for i1 in 0..N1 {
             for i2 in 0..N2 {
                 vec.extend(self[i1][i2])
@@ -117,39 +116,41 @@ impl Device {
         }
     }
 
-    pub(crate) fn rand_uniform(
+    pub(crate) fn rand_uniform<T: crate::FloatDType>(
         &self,
+        lo: T,
+        up: T,
         shape: &Shape,
-        dtype: DType,
-        lo: f64,
-        up: f64,
     ) -> Result<Storage> {
+        let lo = lo.to_f64();
+        let up = up.to_f64();
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
+                let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_uniform(shape, dtype, lo, up)?;
+                let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
                 Ok(Storage::Cuda(storage))
             }
         }
     }
 
-    pub(crate) fn rand_normal(
+    pub(crate) fn rand_normal<T: crate::FloatDType>(
         &self,
+        mean: T,
+        std: T,
         shape: &Shape,
-        dtype: DType,
-        mean: f64,
-        std: f64,
     ) -> Result<Storage> {
+        let mean = mean.to_f64();
+        let std = std.to_f64();
         match self {
             Device::Cpu => {
-                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
+                let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                let storage = device.rand_normal(shape, dtype, mean, std)?;
+                let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
                 Ok(Storage::Cuda(storage))
             }
         }
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index c6befbb8..0e906119 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -120,7 +120,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
 with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
 with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
 
-pub trait IntDType {
+pub trait IntDType: WithDType {
     fn is_true(&self) -> bool;
     fn as_usize(&self) -> usize;
 }
@@ -142,3 +142,10 @@ impl IntDType for u8 {
         *self as usize
     }
 }
+
+pub trait FloatDType: WithDType {}
+
+impl FloatDType for f16 {}
+impl FloatDType for bf16 {}
+impl FloatDType for f32 {}
+impl FloatDType for f64 {}
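The FloatDType marker introduced here is what lets Device::rand_uniform write T::DTYPE and lo.to_f64(): both come from the WithDType supertrait. A small sketch of the idea, using a hypothetical dtype_of helper that is not part of this change (assumes candle_core and half as dependencies):

    use candle_core::{DType, FloatDType};

    // Hypothetical helper, not part of the diff: the WithDType supertrait
    // carries an associated DTYPE constant, so the element type alone is
    // enough to pick the storage dtype.
    fn dtype_of<T: FloatDType>(_sample: T) -> DType {
        T::DTYPE
    }

    fn main() {
        assert_eq!(dtype_of(0.5f32), DType::F32);
        assert_eq!(dtype_of(half::bf16::from_f64(0.5)), DType::BF16);
    }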
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 3dbae7fc..95cc189c 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -61,7 +61,7 @@ mod variable;
 
 pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
-pub use dtype::{DType, IntDType, WithDType};
+pub use dtype::{DType, FloatDType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;
 pub use layout::Layout;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 28ecc357..09f61340 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -232,55 +232,51 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
-    pub(crate) fn rand_impl<S: Into<Shape>>(
+    pub(crate) fn rand_impl<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        lo: f64,
-        up: f64,
         is_variable: bool,
     ) -> Result<Self> {
         let s = s.into();
-        let storage = device.rand_uniform(&s, dtype, lo, up)?;
+        let storage = device.rand_uniform(lo, up, &s)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
-    pub fn rand<S: Into<Shape>>(
+    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        lo: f64,
-        up: f64,
     ) -> Result<Self> {
-        Self::rand_impl(s, dtype, device, lo, up, false)
+        Self::rand_impl(lo, up, s, device, false)
     }
 
-    pub(crate) fn randn_impl<S: Into<Shape>>(
+    pub(crate) fn randn_impl<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        mean: f64,
-        std: f64,
         is_variable: bool,
     ) -> Result<Self> {
         let s = s.into();
-        let storage = device.rand_normal(&s, dtype, mean, std)?;
+        let storage = device.rand_normal(mean, std, &s)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
     /// specified `mean` and standard deviation `std`.
-    pub fn randn<S: Into<Shape>>(
+    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        mean: f64,
-        std: f64,
     ) -> Result<Self> {
-        Self::randn_impl(s, dtype, device, mean, std, false)
+        Self::randn_impl(mean, std, s, device, false)
     }
 
     pub(crate) fn new_impl<A: crate::device::NdArray>(
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index e26f1420..0cefee11 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,25 +34,23 @@ impl Var {
         Ok(Self(inner))
     }
 
-    pub fn rand<S: Into<Shape>>(
+    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
+        lo: T,
+        up: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        lo: f64,
-        up: f64,
     ) -> Result<Self> {
-        let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
+        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
        Ok(Self(inner))
     }
 
-    pub fn randn<S: Into<Shape>>(
+    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
+        mean: T,
+        std: T,
         s: S,
-        dtype: DType,
         device: &Device,
-        mean: f64,
-        std: f64,
     ) -> Result<Self> {
-        let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
+        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
         Ok(Self(inner))
     }
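Taken together, the public constructors now infer the dtype from the Rust float type of their first two arguments instead of taking an explicit DType. A minimal usage sketch under the new signatures (assuming the crate is consumed as candle_core):

    use candle_core::{DType, Device, Result, Tensor, Var};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // f32 bounds select DType::F32; no explicit dtype argument is needed.
        let uniform = Tensor::rand(0f32, 1f32, (2, 3), &device)?;
        assert_eq!(uniform.dtype(), DType::F32);
        // f64 mean/std likewise select DType::F64.
        let normal = Tensor::randn(0f64, 1f64, (2, 3), &device)?;
        assert_eq!(normal.dtype(), DType::F64);
        // Var constructors follow the same calling convention.
        let _var = Var::rand(-1f32, 1f32, (2, 3), &device)?;
        Ok(())
    }

Callers that previously passed DType::BF16 or DType::F16 can now pass half::bf16/half::f16 bounds on the CPU backend; the CUDA backend still rejects the half types, per the TODOs above.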