summaryrefslogtreecommitdiff
path: root/candle-examples/examples/custom-ops
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/custom-ops')
-rw-r--r--candle-examples/examples/custom-ops/cuda_kernels.rs1
-rw-r--r--candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu37
-rw-r--r--candle-examples/examples/custom-ops/kernels/reduction_utils.cuh46
-rw-r--r--candle-examples/examples/custom-ops/main.rs65
4 files changed, 149 insertions, 0 deletions
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
new file mode 100644
index 00000000..07d18342
--- /dev/null
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -0,0 +1 @@
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
new file mode 100644
index 00000000..07ab8639
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
@@ -0,0 +1,37 @@
+#include "reduction_utils.cuh"
+
+template <typename scalar_t>
+__device__ void
+rms_norm_kernel(scalar_t *__restrict__ out, // [num_tokens, hidden_size]
+ const scalar_t *__restrict__ input, // [num_tokens, hidden_size]
+ const scalar_t *__restrict__ weight, // [hidden_size]
+ const float epsilon, const int num_tokens,
+ const int hidden_size) {
+ __shared__ float s_variance;
+ float variance = 0.0f;
+
+ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+ const float x = (float)input[blockIdx.x * hidden_size + idx];
+ variance += x * x;
+ }
+ variance = blockReduceSum<float>(variance);
+ if (threadIdx.x == 0) {
+ s_variance = rsqrtf(variance / hidden_size + epsilon);
+ }
+ __syncthreads();
+
+ for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+ float x = (float)input[blockIdx.x * hidden_size + idx];
+ out[blockIdx.x * hidden_size + idx] =
+ ((scalar_t)(x * s_variance)) * weight[idx];
+ }
+}
+extern "C" __global__ void rms_norm_kernel_f32(
+ float *__restrict__ out, // [num_tokens, hidden_size]
+ const float *__restrict__ input, // [num_tokens, hidden_size]
+ const float *__restrict__ weight, // [hidden_size]
+ const float epsilon, const int num_tokens,
+ const int hidden_size) {
+ rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
+}
+
diff --git a/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
new file mode 100644
index 00000000..d5765f4f
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
@@ -0,0 +1,46 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+template <typename T> __inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1)
+ val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+ return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template <typename T> __inline__ __device__ T blockReduceSum(T val) {
+ static __shared__ T shared[32];
+ int lane = threadIdx.x & 0x1f;
+ int wid = threadIdx.x >> 5;
+
+ val = warpReduceSum<T>(val);
+
+ if (lane == 0)
+ shared[wid] = val;
+
+ __syncthreads();
+
+ // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
+ // blockDim.x is not divided by 32
+ val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+ val = warpReduceSum<T>(val);
+ return val;
+}
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
new file mode 100644
index 00000000..adc7abd7
--- /dev/null
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use clap::Parser;
+
+use candle::backend::BackendStorage;
+use candle::cpu_backend;
+use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+}
+
+struct LayerNorm;
+
+impl CustomOp1 for LayerNorm {
+ fn name(&self) -> &'static str {
+ "layer-norm"
+ }
+
+ fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+ let s = s.as_slice::<f32>()?;
+ let _s = match l.contiguous_offsets() {
+ None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+ Some((o1, o2)) => &s[o1..o2],
+ };
+ todo!()
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ s: &candle::CudaStorage,
+ l: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ let device = s.device().clone();
+ let s = s.as_cuda_slice::<f32>()?;
+ let s = match l.contiguous_offsets() {
+ None => Err(Error::Wrapped("input has to be contiguous".into()))?,
+ Some((o1, o2)) => s, // TODO: slice with o1 and o2
+ };
+ let s: std::result::Result<_, candle::cuda_backend::CudaError> =
+ s.try_clone().map_err(|v| v.into());
+ let s = s?;
+ let s = candle::CudaStorage::wrap_cuda_slice(s, device);
+ Ok((s, l.shape().clone()))
+ }
+}
+
+fn main() -> anyhow::Result<()> {
+ let args = Args::parse();
+ let device = candle_examples::device(args.cpu)?;
+ let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
+ println!("{t}");
+ let t = t.custom_op1(LayerNorm)?;
+ println!("{t}");
+ Ok(())
+}