diff options
-rw-r--r-- | candle-core/src/device.rs | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index d566ba42..9dfcd7d5 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -185,8 +185,14 @@ impl Device { Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_uniform(shape, dtype, lo, up)?; - Ok(Storage::Cuda(storage)) + // TODO: Remove the special case if we start supporting generating f16/bf16 directly. + if dtype == DType::F16 || dtype == DType::BF16 { + let storage = device.rand_uniform(shape, DType::F32, lo, up)?; + Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype) + } else { + let storage = device.rand_uniform(shape, dtype, lo, up)?; + Ok(Storage::Cuda(storage)) + } } } } @@ -213,8 +219,14 @@ impl Device { Ok(Storage::Cpu(storage)) } Device::Cuda(device) => { - let storage = device.rand_normal(shape, dtype, mean, std)?; - Ok(Storage::Cuda(storage)) + // TODO: Remove the special case if we start supporting generating f16/bf16 directly. + if dtype == DType::F16 || dtype == DType::BF16 { + let storage = device.rand_normal(shape, DType::F32, mean, std)?; + Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype) + } else { + let storage = device.rand_normal(shape, dtype, mean, std)?; + Ok(Storage::Cuda(storage)) + } } } } |