summary refs log tree commit diff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-07-27 07:40:36 +0100
committer GitHub <noreply@github.com> 2023-07-27 07:40:36 +0100
commit 6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch)
tree 60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/cuda_backend.rs
parent 89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff)
download candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz
download candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2
download candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- candle-core/src/cuda_backend.rs | 4 ++++
1 file changed, 4 insertions(+), 0 deletions(-)
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9cc454f1..b3d542b9 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -255,6 +255,8 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
+ // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+ // cudarc changes.
DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_uniform",
@@ -282,6 +284,8 @@ impl BackendDevice for CudaDevice {
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
+ // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+ // cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {