summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-31 21:27:59 +0100
committerGitHub <noreply@github.com>2023-10-31 20:27:59 +0000
commit36fb84f03802073e137ead8738b01cfd3e470256 (patch)
tree372870897aca9d9df1fecff7aa0c6cc1d81cd02e /candle-core
parentc12ad45562778ffff0cda6c623b4838a2ed1c57c (diff)
downloadcandle-36fb84f03802073e137ead8738b01cfd3e470256.tar.gz
candle-36fb84f03802073e137ead8738b01cfd3e470256.tar.bz2
candle-36fb84f03802073e137ead8738b01cfd3e470256.zip
Add a hack for generating random uniform/normal for f16/bf16. (#1228)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/device.rs20
1 files changed, 16 insertions, 4 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index d566ba42..9dfcd7d5 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -185,8 +185,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_uniform(shape, dtype, lo, up)?;
- Ok(Storage::Cuda(storage))
+ // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
+ if dtype == DType::F16 || dtype == DType::BF16 {
+ let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
+ Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
+ } else {
+ let storage = device.rand_uniform(shape, dtype, lo, up)?;
+ Ok(Storage::Cuda(storage))
+ }
}
}
}
@@ -213,8 +219,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- let storage = device.rand_normal(shape, dtype, mean, std)?;
- Ok(Storage::Cuda(storage))
+ // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
+ if dtype == DType::F16 || dtype == DType::BF16 {
+ let storage = device.rand_normal(shape, DType::F32, mean, std)?;
+ Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
+ } else {
+ let storage = device.rand_normal(shape, dtype, mean, std)?;
+ Ok(Storage::Cuda(storage))
+ }
}
}
}