diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 21:32:32 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 21:32:32 +0100 |
commit | 2bfa791336b320b96d392aba83cbd4cee87173e3 (patch) | |
tree | a3127719a64cf5cfbf38f5f8be859afd2dc6118e /candle-core/examples/cuda_sum_benchmark.rs | |
parent | 57be3638d8c10304629f6859d183fb192858f3a3 (diff) | |
download | candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.gz candle-2bfa791336b320b96d392aba83cbd4cee87173e3.tar.bz2 candle-2bfa791336b320b96d392aba83cbd4cee87173e3.zip |
Use the same default as pytorch for sum. (#164)
Diffstat (limited to 'candle-core/examples/cuda_sum_benchmark.rs')
-rw-r--r-- | candle-core/examples/cuda_sum_benchmark.rs | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs index 09d0099d..86a1691d 100644 --- a/candle-core/examples/cuda_sum_benchmark.rs +++ b/candle-core/examples/cuda_sum_benchmark.rs @@ -27,18 +27,18 @@ fn main() -> Result<()> { let xys_cpu = cos_sin(n, &Device::Cpu)?; let xys = cos_sin(n, &device)?; println!("{xys_cpu:?} {xys:?}"); - let sum_cpu = xys_cpu.sum(&[1])?; - println!("{sum_cpu}"); - let sum = xys.sum(&[1])?; - println!("{sum}"); + let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?; + println!("{sum_keepdim_cpu}"); + let sum_keepdim = xys.sum_keepdim(&[1])?; + println!("{sum_keepdim}"); let start = std::time::Instant::now(); let n_iters = 100; let mut v = 0f32; for _i in 0..n_iters { - let sum = xys.sum(&[1])?; - let sum = sum.sum(&[0])?; - let sum: f32 = sum.reshape(&[])?.to_scalar()?; - v += sum; + let sum_keepdim = xys.sum_keepdim(&[1])?; + let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?; + let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?; + v += sum_keepdim; } let elapsed = start.elapsed(); if v > 0. { |