summaryrefslogtreecommitdiff
path: root/candle-core/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-29 16:12:11 +0100
committerGitHub <noreply@github.com>2023-08-29 16:12:11 +0100
commita044907ffce553a0394db3a1204f21e3691e54af (patch)
tree8ce11fae8ee11e4eb181f7240344994356625791 /candle-core/examples
parentee8bb1bde1a44738c314dfaacba743f4eabf917c (diff)
downloadcandle-a044907ffce553a0394db3a1204f21e3691e54af.tar.gz
candle-a044907ffce553a0394db3a1204f21e3691e54af.tar.bz2
candle-a044907ffce553a0394db3a1204f21e3691e54af.zip
Dilated convolutions (#657)
* Add the dilation parameter. * Restore the basic optimizer example. * Dilation support in cudnn. * Use the dilation parameter in the cpu backend. * More dilation support. * No support for dilation in transposed convolutions. * Add dilation to a test. * Remove a print. * Helper function.
Diffstat (limited to 'candle-core/examples')
-rw-r--r--candle-core/examples/basics.rs2
-rw-r--r--candle-core/examples/cpu_benchmarks.rs4
-rw-r--r--candle-core/examples/cuda_basics.rs6
3 files changed, 6 insertions, 6 deletions
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index 9d4734de..ad008177 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -11,7 +11,7 @@ fn main() -> Result<()> {
let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
let start = std::time::Instant::now();
- let res = inp.conv2d(&w, 0, 1, 1)?;
+ let res = inp.conv2d(&w, 0, 1, 1, 1)?;
println!("{:?}", start.elapsed());
println!("{res:?}");
Ok(())
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index 1ebd9b75..13175ac1 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -40,7 +40,7 @@ impl Benchmark for Conv1d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv1d(&d.1, 0, 1, 1)
+ d.0.conv1d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 5;
@@ -59,7 +59,7 @@ impl Benchmark for Conv2d {
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
- d.0.conv2d(&d.1, 0, 1, 1)
+ d.0.conv2d(&d.1, 0, 1, 1, 1)
}
const ITERS: usize = 1;
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs
index cbdafd64..ad207461 100644
--- a/candle-core/examples/cuda_basics.rs
+++ b/candle-core/examples/cuda_basics.rs
@@ -11,11 +11,11 @@ fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
- let out_t = in_t.conv2d(&k_t, 0, 1, 1)?;
+ let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
- let out_t2 = in_t.conv2d(&k_t, 0, 1, 1)?;
+ let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
@@ -23,7 +23,7 @@ fn main() -> Result<()> {
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
- let res = t.conv2d(&w, 1, 1, 1)?;
+ let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(())
}