summaryrefslogtreecommitdiff
path: root/candle-core/examples/cuda_basics.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-14 21:30:41 +0100
committerGitHub <noreply@github.com>2023-08-14 21:30:41 +0100
commit90374097dc99b14dfc935318a18c21fc5909291f (patch)
tree61a45bfa840f4551a5581a4ab169dce23b13db4b /candle-core/examples/cuda_basics.rs
parentc84883ecf2c240792392353175b634f6ec92a011 (diff)
downloadcandle-90374097dc99b14dfc935318a18c21fc5909291f.tar.gz
candle-90374097dc99b14dfc935318a18c21fc5909291f.tar.bz2
candle-90374097dc99b14dfc935318a18c21fc5909291f.zip
Cudnn support (#445)
* Add a cudnn feature to be used for conv2d. * Allocate the proper workspace. * Only create a single cudnn handle per cuda device. * Proper cudnn usage. * Bugfix.
Diffstat (limited to 'candle-core/examples/cuda_basics.rs')
-rw-r--r--candle-core/examples/cuda_basics.rs9
1 files changed, 4 insertions, 5 deletions
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index d902b9d5..12febb60 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -9,10 +9,9 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
- let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
- let sum = t.sum_keepdim(0)?;
- println!("{sum}");
- let sum = t.sum_keepdim(1)?;
- println!("{sum}");
+ let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
+ let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
+ let res = t.conv2d(&w, 1, 1)?;
+ println!("{res:?}");
Ok(())
}