diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9cc454f1..b3d542b9 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice { let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); let slice = match dtype { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", @@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice { } fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); let slice = match dtype { |