summary | refs | log | tree | commit | diff
path: root/candle-core/examples/cuda_sum_benchmark.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-07-13 21:32:32 +0100
committer GitHub <noreply@github.com> 2023-07-13 21:32:32 +0100
commit 2bfa791336b320b96d392aba83cbd4cee87173e3 (patch)
tree a3127719a64cf5cfbf38f5f8be859afd2dc6118e /candle-core/examples/cuda_sum_benchmark.rs
parent 57be3638d8c10304629f6859d183fb192858f3a3 (diff)
download candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.gz
candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.bz2
candle-2bfa791336b320b96d392aba83cbd4cee87173e3.zip
Use the same default as pytorch for sum. (#164)
Diffstat (limited to 'candle-core/examples/cuda_sum_benchmark.rs')
-rw-r--r-- candle-core/examples/cuda_sum_benchmark.rs 16
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs
index 09d0099d..86a1691d 100644
--- a/candle-core/examples/cuda_sum_benchmark.rs
+++ b/candle-core/examples/cuda_sum_benchmark.rs
@@ -27,18 +27,18 @@ fn main() -> Result<()> {
let xys_cpu = cos_sin(n, &Device::Cpu)?;
let xys = cos_sin(n, &device)?;
println!("{xys_cpu:?} {xys:?}");
- let sum_cpu = xys_cpu.sum(&[1])?;
- println!("{sum_cpu}");
- let sum = xys.sum(&[1])?;
- println!("{sum}");
+ let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
+ println!("{sum_keepdim_cpu}");
+ let sum_keepdim = xys.sum_keepdim(&[1])?;
+ println!("{sum_keepdim}");
let start = std::time::Instant::now();
let n_iters = 100;
let mut v = 0f32;
for _i in 0..n_iters {
- let sum = xys.sum(&[1])?;
- let sum = sum.sum(&[0])?;
- let sum: f32 = sum.reshape(&[])?.to_scalar()?;
- v += sum;
+ let sum_keepdim = xys.sum_keepdim(&[1])?;
+ let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
+ let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
+ v += sum_keepdim;
}
let elapsed = start.elapsed();
if v > 0. {