summaryrefslogtreecommitdiff
path: root/candle-examples/examples/custom-ops
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/custom-ops')
-rw-r--r--candle-examples/examples/custom-ops/cuda_kernels.rs1
-rw-r--r--candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu18
-rw-r--r--candle-examples/examples/custom-ops/main.rs29
3 files changed, 31 insertions, 17 deletions
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
index 07d18342..0bee73aa 100644
--- a/candle-examples/examples/custom-ops/cuda_kernels.rs
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -1 +1,2 @@
+#[rustfmt::skip]
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
index 07ab8639..a0836392 100644
--- a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
+++ b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
@@ -1,12 +1,12 @@
+#include <stdint.h>
#include "reduction_utils.cuh"
template <typename scalar_t>
__device__ void
rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
- const scalar_t *__restrict__ weight, // [hidden_size]
- const float epsilon, const int num_tokens,
- const int hidden_size) {
+ const float epsilon, const uint32_t num_tokens,
+ const uint32_t hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
@@ -22,16 +22,14 @@ rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
- out[blockIdx.x * hidden_size + idx] =
- ((scalar_t)(x * s_variance)) * weight[idx];
+ out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance));
}
}
-extern "C" __global__ void rms_norm_kernel_f32(
+extern "C" __global__ void rms_f32(
float *__restrict__ out, // [num_tokens, hidden_size]
const float *__restrict__ input, // [num_tokens, hidden_size]
- const float *__restrict__ weight, // [hidden_size]
- const float epsilon, const int num_tokens,
- const int hidden_size) {
- rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
+ const float epsilon, const uint32_t num_tokens,
+ const uint32_t hidden_size) {
+ rms_norm_kernel(out, input, epsilon, num_tokens, hidden_size);
}
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index adc7abd7..9c917cca 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -4,6 +4,8 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
+mod cuda_kernels;
+
use clap::Parser;
use candle::backend::BackendStorage;
@@ -40,17 +42,30 @@ impl CustomOp1 for LayerNorm {
s: &candle::CudaStorage,
l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
- let device = s.device().clone();
+ use candle::cuda_backend::{cudarc, WrapErr};
+ use cudarc::driver::{LaunchAsync, LaunchConfig};
+ let (d1, d2) = l.shape().dims2()?;
+ let d1 = d1 as u32;
+ let d2 = d2 as u32;
+ let dev = s.device().clone();
let s = s.as_cuda_slice::<f32>()?;
let s = match l.contiguous_offsets() {
None => Err(Error::Wrapped("input has to be contiguous".into()))?,
- Some((o1, o2)) => s, // TODO: slice with o1 and o2
+ Some((o1, o2)) => s.slice(o1..o2),
+ };
+ let elem_count = l.shape().elem_count();
+ let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+ let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
+ let params = (&dst, &s, 1e-5f32, d1, d2);
+ let cfg = LaunchConfig {
+ grid_dim: (d1, 1, 1),
+ block_dim: (d2, 1, 1),
+ shared_mem_bytes: 0,
};
- let s: std::result::Result<_, candle::cuda_backend::CudaError> =
- s.try_clone().map_err(|v| v.into());
- let s = s?;
- let s = candle::CudaStorage::wrap_cuda_slice(s, device);
- Ok((s, l.shape().clone()))
+ unsafe { func.launch(cfg, params) }.w()?;
+
+ let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
+ Ok((dst, l.shape().clone()))
}
}