diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-31 21:27:59 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-31 20:27:59 +0000 |
commit | 36fb84f03802073e137ead8738b01cfd3e470256 (patch) | |
tree | 372870897aca9d9df1fecff7aa0c6cc1d81cd02e /candle-core | |
parent | c12ad45562778ffff0cda6c623b4838a2ed1c57c (diff) | |
download | candle-36fb84f03802073e137ead8738b01cfd3e470256.tar.gz candle-36fb84f03802073e137ead8738b01cfd3e470256.tar.bz2 candle-36fb84f03802073e137ead8738b01cfd3e470256.zip |
Add a hack for generating random uniform/normal for f16/bf16. (#1228)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/device.rs | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index d566ba42..9dfcd7d5 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -185,8 +185,14 @@ impl Device { Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_uniform(shape, dtype, lo, up)?; - Ok(Storage::Cuda(storage)) + // TODO: Remove the special case if we start supporting generating f16/bf16 directly. + if dtype == DType::F16 || dtype == DType::BF16 { + let storage = device.rand_uniform(shape, DType::F32, lo, up)?; + Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype) + } else { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Cuda(storage)) + } } } } @@ -213,8 +219,14 @@ impl Device { Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_normal(shape, dtype, mean, std)?; - Ok(Storage::Cuda(storage)) + // TODO: Remove the special case if we start supporting generating f16/bf16 directly. + if dtype == DType::F16 || dtype == DType::BF16 { + let storage = device.rand_normal(shape, DType::F32, mean, std)?; + Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype) + } else { + let storage = device.rand_normal(shape, dtype, mean, std)?; + Ok(Storage::Cuda(storage)) + } } } } |