summaryrefslogtreecommitdiff
path: root/candle-core/examples/cuda_basics.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-24 10:16:37 +0100
committerGitHub <noreply@github.com>2023-08-24 10:16:37 +0100
commitdd64465899f4b58628642b406c465d35ddfe8f79 (patch)
tree917c410d8d9ab07786091bed14672d9892ddeba0 /candle-core/examples/cuda_basics.rs
parent79916c2edbab024c37918b5c27c9d675cd444410 (diff)
downloadcandle-dd64465899f4b58628642b406c465d35ddfe8f79.tar.gz
candle-dd64465899f4b58628642b406c465d35ddfe8f79.tar.bz2
candle-dd64465899f4b58628642b406c465d35ddfe8f79.zip
Add a test for conv2d with padding + bugfix the random number generation on cuda. (#578)
* Add a test for conv2d with padding. * Cosmetic changes. * Bugfix the rand function on the cuda backend.
Diffstat (limited to 'candle-core/examples/cuda_basics.rs')
-rw-r--r--candle-core/examples/cuda_basics.rs3
1 files changed, 3 insertions, 0 deletions
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index ac435488..6a3aaacc 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -9,6 +9,9 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
+ let t = Tensor::rand(-1f32, 1f32, 96, &device)?;
+ println!("{t}");
+
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1)?;