| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-23 08:15:37 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-23 07:15:37 +0100 |
| commit | b8a10425ad550b04ccf3b5ff2493714615d7df4b (patch) | |
| tree | eab9ad609e34e4bad23cedc81ee338fe00961c3f /candle-examples/examples/custom-ops/main.rs | |
| parent | 5f20acf0804a624d6c274e488c897fb88d698f1a (diff) | |
Kernel build example (#224)
* Build example kernels (a hedged build-script sketch follows this list).
* Add a sample custom kernel.
* Get the example kernel to compile.
* Add some CUDA code.
* More CUDA custom op work.
* More CUDA custom ops.
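The commit title refers to building example kernels, but the build script itself is outside this diff, which is limited to main.rs. For context, a common pattern for compiling a `.cu` kernel to PTX from a Cargo build script is sketched below; this is a generic illustration that shells out to `nvcc` directly, with a hypothetical `kernels/layer_norm.cu` path, and is not the script added by this commit.

```rust
// build.rs (illustrative sketch, not from this commit): compile a CUDA
// kernel to PTX at build time by invoking nvcc.
use std::process::Command;

fn main() {
    // Rebuild when the (hypothetical) kernel source changes.
    println!("cargo:rerun-if-changed=kernels/layer_norm.cu");
    let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR is set by cargo");
    let status = Command::new("nvcc")
        .args(["--ptx", "kernels/layer_norm.cu", "-o"])
        .arg(format!("{out_dir}/layer_norm.ptx"))
        .status()
        .expect("failed to spawn nvcc");
    assert!(status.success(), "nvcc failed to compile the kernel");
}
```

The resulting PTX can then be embedded at compile time with `include_str!(concat!(env!("OUT_DIR"), "/layer_norm.ptx"))` and loaded onto the device at runtime.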
Diffstat (limited to 'candle-examples/examples/custom-ops/main.rs')
-rw-r--r-- | candle-examples/examples/custom-ops/main.rs | 65

1 file changed, 65 insertions, 0 deletions
```diff
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
new file mode 100644
index 00000000..adc7abd7
--- /dev/null
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use clap::Parser;
+
+use candle::backend::BackendStorage;
+use candle::cpu_backend;
+use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+}
+
+struct LayerNorm;
+
+impl CustomOp1 for LayerNorm {
+    fn name(&self) -> &'static str {
+        "layer-norm"
+    }
+
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        let s = s.as_slice::<f32>()?;
+        let _s = match l.contiguous_offsets() {
+            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            Some((o1, o2)) => &s[o1..o2],
+        };
+        todo!()
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s: &candle::CudaStorage,
+        l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        let device = s.device().clone();
+        let s = s.as_cuda_slice::<f32>()?;
+        let s = match l.contiguous_offsets() {
+            None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+            Some((o1, o2)) => s, // TODO: slice with o1 and o2
+        };
+        let s: std::result::Result<_, candle::cuda_backend::CudaError> =
+            s.try_clone().map_err(|v| v.into());
+        let s = s?;
+        let s = candle::CudaStorage::wrap_cuda_slice(s, device);
+        Ok((s, l.shape().clone()))
+    }
+}
+
+fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let device = candle_examples::device(args.cpu)?;
+    let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
+    println!("{t}");
+    let t = t.custom_op1(LayerNorm)?;
+    println!("{t}");
+    Ok(())
+}
```
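The `cpu_fwd` in this commit stops at `todo!()` after validating that the input is contiguous. For orientation, here is a minimal sketch of what a completed CPU forward pass could look like, assuming normalization over the last dimension with no learnable scale or bias; the free function `layer_norm_cpu_fwd`, the `EPS` value, and the rank >= 1 assumption are illustrative and not part of the commit.

```rust
use candle::{CpuStorage, Error, Layout, Result, Shape};

/// Illustrative stand-in for LayerNorm::cpu_fwd (hypothetical, not from the
/// commit): normalize each row of a contiguous f32 tensor over its last dim.
fn layer_norm_cpu_fwd(s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
    const EPS: f32 = 1e-5; // illustrative epsilon
    let s = s.as_slice::<f32>()?;
    let src = match l.contiguous_offsets() {
        None => Err(Error::Wrapped("input has to be contiguous".into()))?,
        Some((o1, o2)) => &s[o1..o2],
    };
    // Assumes rank >= 1; the last dimension is the normalization axis.
    let dim_m1 = *l.shape().dims().last().unwrap();
    let mut dst = Vec::with_capacity(src.len());
    for row in src.chunks(dim_m1) {
        let mean = row.iter().sum::<f32>() / dim_m1 as f32;
        let var = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / dim_m1 as f32;
        let inv_std = 1.0 / (var + EPS).sqrt();
        dst.extend(row.iter().map(|x| (x - mean) * inv_std));
    }
    Ok((CpuStorage::F32(dst), l.shape().clone()))
}
```

Note that the committed `cuda_fwd` is similarly incomplete: it ignores the `(o1, o2)` offsets (hence the `// TODO: slice with o1 and o2` comment), clones the whole buffer, and returns the input unchanged, so running the example on GPU only exercises the dispatch plumbing.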