diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-27 07:40:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-27 07:40:36 +0100 |
commit | 6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch) | |
tree | 60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/cuda_backend.rs | |
parent | 89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff) | |
download | candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2 candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip |
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9cc454f1..b3d542b9 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice { let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); let slice = match dtype { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", @@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice { } fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. let elem_count = shape.elem_count(); let curand = self.curand.lock().unwrap(); let slice = match dtype { |