summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9cc454f1..b3d542b9 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
+ // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+ // cudarc changes.
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
@@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
+ // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+ // cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {