diff options
Diffstat (limited to 'candle-examples/examples')
4 files changed, 149 insertions, 0 deletions
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
new file mode 100644
index 00000000..07d18342
--- /dev/null
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -0,0 +1,2 @@
+/// PTX for the layer-norm kernels, embedded at compile time from the build output directory (`OUT_DIR`).
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
new file mode 100644
index 00000000..07ab8639
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
@@ -0,0 +1,55 @@
+#include "reduction_utils.cuh"
+
+// RMS normalization, one thread block per row:
+//   out[row, i] = input[row, i] * rsqrt(mean(input[row, :]^2) + epsilon) * weight[i]
+//
+// Launch layout: blockIdx.x selects the row and the threads of the block stride
+// over the hidden dimension, so the grid must have exactly one block per row.
+// num_tokens is part of the signature but is not read here.
+// blockReduceSum shuffles with a full-warp mask, so blockDim.x is expected to be
+// a multiple of 32.
+// Shared memory: one float here plus the per-warp scratch of blockReduceSum.
+template <typename scalar_t>
+__device__ void
+rms_norm_kernel(scalar_t *__restrict__ out,          // [num_tokens, hidden_size]
+                const scalar_t *__restrict__ input,  // [num_tokens, hidden_size]
+                const scalar_t *__restrict__ weight, // [hidden_size]
+                const float epsilon, const int num_tokens,
+                const int hidden_size) {
+  __shared__ float s_variance;
+
+  // Widen before multiplying: blockIdx.x * hidden_size would be evaluated in
+  // 32-bit unsigned arithmetic and wrap once a row starts at element 2^32 or beyond.
+  const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
+
+  // Per-thread partial sum of squares over the strided slice of the row.
+  float variance = 0.0f;
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)input[row_offset + idx];
+    variance += x * x;
+  }
+
+  // Thread 0 turns the block-wide sum into 1 / sqrt(mean square + epsilon) and
+  // publishes it through shared memory; the barrier makes it visible to all.
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)input[row_offset + idx];
+    out[row_offset + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
+  }
+}
+
+// f32 entry point. extern "C" keeps the symbol name unmangled.
+extern "C" __global__ void rms_norm_kernel_f32(
+    float *__restrict__ out,          // [num_tokens, hidden_size]
+    const float *__restrict__ input,  // [num_tokens, hidden_size]
+    const float *__restrict__ weight, // [hidden_size]
+    const float epsilon, const int num_tokens,
+    const int hidden_size) {
+  rms_norm_kernel(out, input, weight, epsilon, num_tokens, hidden_size);
+}
+
diff 
--git a/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
new file mode 100644
index 00000000..d5765f4f
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
@@ -0,0 +1,60 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+/* Sum val across the 32 lanes of a warp with a butterfly (XOR) shuffle; every
+ * lane ends up holding the total. The full 0xffffffff mask names all 32 lanes
+ * as participants. */
+template <typename T> __inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+  for (int lane_mask = 16; lane_mask >= 1; lane_mask /= 2) {
+    val = val + __shfl_xor_sync(0xffffffff, val, lane_mask, 32);
+  }
+  return val;
+}
+
+/* Calculate the sum of all elements in a block.
+ *
+ * Each warp reduces its own values, lane 0 of every warp stores the warp total
+ * in shared memory, and after the barrier the per-warp totals are reduced once
+ * more. Threads whose index is at least the warp count feed 0 into that second
+ * pass, so only the threads of the first warp are guaranteed to return the
+ * block-wide sum.
+ *
+ * Must be reached by every thread of the block because of __syncthreads(). */
+template <typename T> __inline__ __device__ T blockReduceSum(T val) {
+  static __shared__ T warp_sums[32];
+  const int lane = threadIdx.x % 32;
+  const int wid = threadIdx.x / 32;
+
+  val = warpReduceSum<T>(val);
+  if (lane == 0) {
+    warp_sums[wid] = val;
+  }
+  __syncthreads();
+
+  // Round the warp count up so that a trailing partial warp (blockDim.x not a
+  // multiple of 32) still contributes its total.
+  const unsigned int num_warps = (blockDim.x + 31u) / 32u;
+  T warp_total = (T)(0.0f);
+  if (threadIdx.x < num_warps) {
+    warp_total = warp_sums[lane];
+  }
+  return warpReduceSum<T>(warp_total);
+}
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
new file mode 100644
index 00000000..adc7abd7
--- /dev/null
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -0,0 +1,72 @@
+#![allow(dead_code)]
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use clap::Parser;
+
+use candle::backend::BackendStorage;
+use candle::cpu_backend;
+use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+
+// Command-line options. Plain comment on purpose: clap derives help text from
+// doc comments.
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+}
+
+/// Marker type carrying the custom op implementation.
+struct LayerNorm;
+
+impl CustomOp1 for LayerNorm {
+    fn name(&self) -> &'static str {
+        "layer-norm"
+    }
+
+    /// CPU forward pass: validates that the f32 input is contiguous, then hits
+    /// `todo!()` because the computation itself is not implemented yet.
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        let data = s.as_slice::<f32>()?;
+        let (start, end) = l
+            .contiguous_offsets()
+            .ok_or_else(|| Error::Wrapped("input has to be contiguous".into()))?;
+        let _data = &data[start..end];
+        todo!()
+    }
+
+    /// CUDA forward pass: currently returns a clone (`try_clone`) of the whole
+    /// input buffer together with the input shape.
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s: &candle::CudaStorage,
+        l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        let device = s.device().clone();
+        let src = s.as_cuda_slice::<f32>()?;
+        // TODO: slice with o1 and o2
+        let (_o1, _o2) = l
+            .contiguous_offsets()
+            .ok_or_else(|| Error::Wrapped("input has to be contiguous".into()))?;
+        // The annotation pins the intermediate error type; `?` then converts it
+        // to the error type of this function.
+        let copy: std::result::Result<_, candle::cuda_backend::CudaError> =
+            src.try_clone().map_err(|e| e.into());
+        let storage = candle::CudaStorage::wrap_cuda_slice(copy?, device);
+        Ok((storage, l.shape().clone()))
+    }
+}
+
+fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let device = candle_examples::device(args.cpu)?;
+    let input = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
+    println!("{input}");
+    let output = input.custom_op1(LayerNorm)?;
+    println!("{output}");
+    Ok(())
+}
|