diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-23 12:31:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-23 11:31:17 +0100 |
commit | e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca (patch) | |
tree | b908de30c669bbbf20edb1604caec3450f76a01d /candle-examples/examples/custom-ops/main.rs | |
parent | b8a10425ad550b04ccf3b5ff2493714615d7df4b (diff) | |
download | candle-e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca.tar.gz candle-e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca.tar.bz2 candle-e449ce53a2f3c85f23ca0f2e7d557a0d0003e0ca.zip |
Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op.
* Get the rms example to work.
* Get around rustfmt failing in the CI.
* Fix the rms computation.
Diffstat (limited to 'candle-examples/examples/custom-ops/main.rs')
-rw-r--r-- | candle-examples/examples/custom-ops/main.rs | 29 |
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs index adc7abd7..9c917cca 100644 --- a/candle-examples/examples/custom-ops/main.rs +++ b/candle-examples/examples/custom-ops/main.rs @@ -4,6 +4,8 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; +mod cuda_kernels; + use clap::Parser; use candle::backend::BackendStorage; @@ -40,17 +42,30 @@ impl CustomOp1 for LayerNorm { s: &candle::CudaStorage, l: &Layout, ) -> Result<(candle::CudaStorage, Shape)> { - let device = s.device().clone(); + use candle::cuda_backend::{cudarc, WrapErr}; + use cudarc::driver::{LaunchAsync, LaunchConfig}; + let (d1, d2) = l.shape().dims2()?; + let d1 = d1 as u32; + let d2 = d2 as u32; + let dev = s.device().clone(); let s = s.as_cuda_slice::<f32>()?; let s = match l.contiguous_offsets() { None => Err(Error::Wrapped("input has to be contiguous".into()))?, - Some((o1, o2)) => s, // TODO: slice with o1 and o2 + Some((o1, o2)) => s.slice(o1..o2), + }; + let elem_count = l.shape().elem_count(); + let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?; + let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?; + let params = (&dst, &s, 1e-5f32, d1, d2); + let cfg = LaunchConfig { + grid_dim: (d1, 1, 1), + block_dim: (d2, 1, 1), + shared_mem_bytes: 0, }; - let s: std::result::Result<_, candle::cuda_backend::CudaError> = - s.try_clone().map_err(|v| v.into()); - let s = s?; - let s = candle::CudaStorage::wrap_cuda_slice(s, device); - Ok((s, l.shape().clone())) + unsafe { func.launch(cfg, params) }.w()?; + + let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev); + Ok((dst, l.shape().clone())) } } |