summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/cuda_backend.rs 44
-rw-r--r-- candle-examples/Cargo.toml 3
-rw-r--r-- candle-examples/build.rs 231
-rw-r--r-- candle-examples/examples/custom-ops/cuda_kernels.rs 1
-rw-r--r-- candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu 37
-rw-r--r-- candle-examples/examples/custom-ops/kernels/reduction_utils.cuh 46
-rw-r--r-- candle-examples/examples/custom-ops/main.rs 65
7 files changed, 427 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f9fefe17..d2cc3e41 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -771,6 +771,50 @@ pub struct CudaStorage {
device: CudaDevice,
}
+/// Element types that have a matching `CudaStorageSlice` variant.
+pub trait CudaDType: Sized {
+ /// Borrows the typed slice stored in `s`; the implementations generated by
+ /// `cuda_dtype!` return an `UnexpectedDType` error when `s` holds another dtype.
+ fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
+ /// Wraps a typed slice together with its device into a `CudaStorage`.
+ fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
+}
+
+// Implements `CudaDType` for the element type `$ty`, backed by the
+// `CudaStorageSlice::$dtype` variant (and reported as `DType::$dtype` in errors).
+macro_rules! cuda_dtype {
+    ($ty:ty, $dtype:ident) => {
+        impl CudaDType for $ty {
+            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
+                match &s.slice {
+                    // Matching on `&s.slice` already binds `data` by reference.
+                    CudaStorageSlice::$dtype(data) => Ok(data),
+                    _ => Err(crate::Error::UnexpectedDType {
+                        expected: DType::$dtype,
+                        got: s.dtype(),
+                        msg: "unexpected dtype",
+                    }
+                    .bt()),
+                }
+            }
+
+            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
+                let slice = CudaStorageSlice::$dtype(slice);
+                CudaStorage { slice, device }
+            }
+        }
+    };
+}
+cuda_dtype!(u8, U8);
+cuda_dtype!(u32, U32);
+cuda_dtype!(f16, F16);
+cuda_dtype!(bf16, BF16);
+cuda_dtype!(f32, F32);
+cuda_dtype!(f64, F64);
+
+impl CudaStorage {
+    /// Builds a `CudaStorage` from a typed slice by delegating to the
+    /// `CudaDType` implementation of `T`.
+    pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
+        <T as CudaDType>::wrap_cuda_slice(slice, device)
+    }
+
+    /// Borrows the stored data as a `CudaSlice<T>` by delegating to the
+    /// `CudaDType` implementation of `T`.
+    pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
+        <T as CudaDType>::as_cuda_slice(self)
+    }
+}
+
fn gemm_config<T>(
alpha: T,
beta: T,
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 24435e81..f940a937 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -30,6 +30,9 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
+[build-dependencies]
+anyhow = { workspace = true }
+
[features]
default = []
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
diff --git a/candle-examples/build.rs b/candle-examples/build.rs
new file mode 100644
index 00000000..7f69fa77
--- /dev/null
+++ b/candle-examples/build.rs
@@ -0,0 +1,231 @@
+#![allow(unused)]
+use anyhow::{Context, Result};
+use std::io::Write;
+use std::path::PathBuf;
+
+// Pairs a directory of CUDA sources with the Rust file generated for it.
+struct KernelDirectories {
+ // Directory scanned for `.cu` / `.cuh` files; also joined onto the build
+ // output directory to decide where the PTX files go.
+ kernel_dir: &'static str,
+ // Rust source file overwritten with one `include_str!` constant per kernel file.
+ rust_target: &'static str,
+}
+
+// Kernel directories handled by this build script.
+const DIRS: [KernelDirectories; 1] = [KernelDirectories {
+ kernel_dir: "examples/custom-ops/kernels/",
+ rust_target: "examples/custom-ops/cuda_kernels.rs",
+}];
+
+impl KernelDirectories {
+    /// Compiles `cu_file` into `ptx_file` with `nvcc` when the PTX is missing or
+    /// when the source is at least as recent as the PTX.
+    ///
+    /// Without the `cuda` feature no compiler is invoked: `ptx_file` is merely
+    /// created (left untouched if it already exists).
+    ///
+    /// Errors if file metadata cannot be read, if `nvcc` cannot be started or
+    /// exits unsuccessfully, or if the placeholder cannot be created.
+    fn maybe_build_ptx(
+        &self,
+        cu_file: &std::path::Path,
+        ptx_file: &std::path::Path,
+        compute_cap: usize,
+    ) -> Result<()> {
+        let should_compile = if ptx_file.exists() {
+            let ptx_modified = ptx_file.metadata()?.modified()?;
+            let cu_modified = cu_file.metadata()?.modified()?;
+            // `duration_since` succeeds when the source is not older than the PTX.
+            cu_modified.duration_since(ptx_modified).is_ok()
+        } else {
+            true
+        };
+        if !should_compile {
+            return Ok(());
+        }
+        #[cfg(feature = "cuda")]
+        {
+            let out_dir = ptx_file.parent().context("no parent for ptx file")?;
+            // `output()` pipes stdout/stderr so they can be reported below; with
+            // `spawn()` both streams are inherited and the captured buffers stay empty.
+            let output = std::process::Command::new("nvcc")
+                .arg(format!("--gpu-architecture=sm_{compute_cap}"))
+                .arg("--ptx")
+                .args(["--default-stream", "per-thread"])
+                // Passed as a path so that non-UTF-8 directories do not panic.
+                .arg("--output-directory")
+                .arg(out_dir)
+                // `kernel_dir` is a relative path; a leading `/` would turn it
+                // into a directory under the filesystem root.
+                .arg(format!("-I{}", self.kernel_dir))
+                .arg(cu_file)
+                .output()
+                .context("failed spawning nvcc")?;
+            if !output.status.success() {
+                anyhow::bail!(
+                    "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+                    String::from_utf8_lossy(&output.stdout),
+                    String::from_utf8_lossy(&output.stderr)
+                )
+            }
+        }
+        // Keeps the generated `include_str!` constants compiling in CPU-only builds.
+        #[cfg(not(feature = "cuda"))]
+        std::fs::OpenOptions::new()
+            .create(true)
+            .write(true)
+            .open(ptx_file)?;
+        Ok(())
+    }
+    /// Builds every `.cu` file found in `kernel_dir` into `out_dir/kernel_dir`
+    /// and rewrites `rust_target` with one `pub const` per kernel file, each
+    /// `include_str!`-ing the matching PTX from `OUT_DIR`.
+    ///
+    /// Errors on any filesystem failure or when a kernel fails to build.
+    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
+        println!("cargo:rerun-if-changed={}", self.kernel_dir);
+        let kernel_dir = PathBuf::from(self.kernel_dir);
+        let out_dir = out_dir.join(self.kernel_dir);
+        // `create_dir_all` succeeds when the directory already exists.
+        std::fs::create_dir_all(&out_dir)?;
+
+        let mut cu_files = vec![];
+        let mut cuh_files = vec![];
+        for file in std::fs::read_dir(kernel_dir)?.flatten() {
+            let file = file.path();
+            match file.extension().and_then(|v| v.to_str()) {
+                Some("cu") => cu_files.push(file),
+                Some("cuh") => cuh_files.push(file),
+                _ => {}
+            }
+        }
+        // `read_dir` order is platform dependent; sort so the generated file is stable.
+        cu_files.sort();
+
+        // Editing a header does not touch the `.cu` files that include it, so
+        // remember the most recent header change and drop any PTX that is not
+        // newer than it: `maybe_build_ptx` rebuilds whatever is missing.
+        let newest_header = cuh_files
+            .iter()
+            .filter_map(|path| path.metadata().and_then(|m| m.modified()).ok())
+            .max();
+
+        let mut ptx_paths = vec![];
+        for cu_file in cu_files.iter() {
+            let file_stem = cu_file
+                .file_stem()
+                .with_context(|| format!("no stem {cu_file:?}"))?;
+            let file_stem = file_stem.to_string_lossy().into_owned();
+            let ptx_file = out_dir.join(format!("{file_stem}.ptx"));
+            if let Some(header_time) = newest_header {
+                let ptx_time = ptx_file.metadata().and_then(|m| m.modified()).ok();
+                if ptx_time.map_or(false, |ptx_time| ptx_time <= header_time) {
+                    std::fs::remove_file(&ptx_file)?;
+                }
+            }
+            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
+            ptx_paths.push(ptx_file);
+        }
+
+        // Avoid a doubled `/` in the generated path when `kernel_dir` ends with one.
+        let include_dir = self.kernel_dir.trim_end_matches('/');
+        let mut file = std::fs::File::create(self.rust_target)?;
+        for ptx_path in ptx_paths {
+            let name = ptx_path
+                .file_stem()
+                .context("empty stem")?
+                .to_string_lossy();
+            // `.` and `-` cannot appear in a Rust identifier.
+            let const_name = name
+                .to_uppercase()
+                .replace(|c: char| c == '.' || c == '-', "_");
+            let const_definition = format!(
+                r#"pub const {const_name}: &str = include_str!(concat!(env!("OUT_DIR"), "/{include_dir}/{name}.ptx"));"#,
+            );
+            writeln!(file, "{const_definition}")?;
+        }
+        Ok(())
+    }
+}
+
+/// Build-script entry point: resolves the compute capability (only with the
+/// `cuda` feature) and processes every entry of `DIRS`.
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+
+    let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
+
+    // CPU-only builds skip the CUDA probing and use a dummy capability.
+    #[cfg(feature = "cuda")]
+    let compute_cap = {
+        set_cuda_include_dir()?;
+        compute_cap()?
+    };
+    #[cfg(not(feature = "cuda"))]
+    let compute_cap = 0;
+
+    for kernel_dirs in DIRS.iter() {
+        kernel_dirs.process(&out_dir, compute_cap)?;
+    }
+    Ok(())
+}
+
+/// Finds a CUDA root that contains `include/cuda.h` and exports its `include`
+/// directory to rustc through the `CUDA_INCLUDE_DIR` environment variable.
+///
+/// Errors when none of the candidate roots contains the header.
+fn set_cuda_include_dir() -> Result<()> {
+    // NOTE: copied from cudarc build.rs.
+    const ENV_VARS: [&str; 4] = [
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDNN_LIB",
+    ];
+    const DEFAULT_ROOTS: [&str; 6] = [
+        "/usr",
+        "/usr/local/cuda",
+        "/opt/cuda",
+        "/usr/lib/cuda",
+        "C:/Program Files/NVIDIA GPU Computing Toolkit",
+        "C:/CUDA",
+    ];
+
+    // Roots named by the environment are tried before the well-known defaults.
+    let from_env = ENV_VARS
+        .iter()
+        .copied()
+        .filter_map(|name| std::env::var(name).ok())
+        .map(PathBuf::from);
+    let defaults = DEFAULT_ROOTS.iter().copied().map(PathBuf::from);
+
+    let include_dir = from_env
+        .chain(defaults)
+        .map(|root| root.join("include"))
+        .find(|dir| dir.join("cuda.h").is_file())
+        .context("cannot find include/cuda.h")?;
+    println!("cargo:rustc-env=CUDA_INCLUDE_DIR={}", include_dir.display());
+    Ok(())
+}
+
+/// Selects the compute capability (e.g. `86` for `sm_86`) the kernels are built for.
+///
+/// `$CUDA_COMPUTE_CAP` is used as-is when set. Otherwise the capability of the
+/// first GPU listed by `nvidia-smi` is taken and checked against the `sm_*`
+/// targets printed by `nvcc --list-gpu-code`: a GPU newer than every target is
+/// lowered to the highest one, any other unsupported value is an error.
+/// The result is exported to rustc as `CUDA_COMPUTE_CAP=sm_<cap>`.
+// Only called when the `cuda` feature is enabled.
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+    // Honour an explicit override before probing, so that machines without a
+    // visible GPU (no working `nvidia-smi`) can still build.
+    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        let compute_cap = compute_cap_str
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+        println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+        return Ok(compute_cap);
+    }
+
+    // Grab compute code from nvidia-smi
+    let mut compute_cap = {
+        let out = std::process::Command::new("nvidia-smi")
+            .arg("--query-gpu=compute_cap")
+            .arg("--format=csv")
+            .output()
+            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        let mut lines = out.lines();
+        let header = lines.next().context("missing line in stdout")?;
+        if header.trim() != "compute_cap" {
+            anyhow::bail!("unexpected nvidia-smi header '{header}', expected 'compute_cap'");
+        }
+        // "8.6" -> "86"
+        let cap = lines
+            .next()
+            .context("missing line in stdout")?
+            .trim()
+            .replace('.', "");
+        cap.parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?
+    };
+
+    // Grab available GPU codes from nvcc; entries that are not `sm_<number>` are skipped.
+    let mut codes = {
+        let out = std::process::Command::new("nvcc")
+            .arg("--list-gpu-code")
+            .output()
+            .context("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        out.lines()
+            .filter_map(|line| line.trim().strip_prefix("sm_"))
+            .filter_map(|code| code.parse::<usize>().ok())
+            .collect::<Vec<usize>>()
+    };
+    codes.sort_unstable();
+    let max_nvcc_code = *codes
+        .last()
+        .context("`nvcc --list-gpu-code` did not report any sm_* target")?;
+
+    if compute_cap > max_nvcc_code {
+        // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
+        // then choose the highest gpu code in nvcc. This has to be decided before
+        // the membership check, which would otherwise always reject this case.
+        println!(
+            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+        );
+        compute_cap = max_nvcc_code;
+    } else if !codes.contains(&compute_cap) {
+        anyhow::bail!(
+            "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+        );
+    }
+
+    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+    Ok(compute_cap)
+}
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
new file mode 100644
index 00000000..07d18342
--- /dev/null
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -0,0 +1 @@
+// Generated by `candle-examples/build.rs`, which overwrites this file.
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels/layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
new file mode 100644
index 00000000..07ab8639
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/layernorm_kernels.cu
@@ -0,0 +1,37 @@
+#include "reduction_utils.cuh"
+
+// RMS-normalizes one row of `input` per block (row index = blockIdx.x):
+//   out[row, i] = (scalar_t)(input[row, i] * rsqrt(mean(input[row, :]^2) + epsilon)) * weight[i]
+// The sum of squares is accumulated in float. `num_tokens` is not read here.
+template <typename scalar_t>
+__device__ void
+rms_norm_kernel(scalar_t *__restrict__ out,          // [num_tokens, hidden_size]
+                const scalar_t *__restrict__ input,  // [num_tokens, hidden_size]
+                const scalar_t *__restrict__ weight, // [hidden_size]
+                const float epsilon, const int num_tokens,
+                const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  // Widen before multiplying: `blockIdx.x * hidden_size` is evaluated in
+  // 32-bit arithmetic and wraps once a tensor holds more than 2^32 elements.
+  const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
+  const scalar_t *__restrict__ row_in = input + row_offset;
+  scalar_t *__restrict__ row_out = out + row_offset;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)row_in[idx];
+    variance += x * x;
+  }
+  variance = blockReduceSum<float>(variance);
+  // Thread 0 publishes the scale factor to the whole block via shared memory.
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float)row_in[idx];
+    row_out[idx] = ((scalar_t)(x * s_variance)) * weight[idx];
+  }
+}
+// Unmangled float entry point forwarding to the templated implementation.
+extern "C" __global__ void
+rms_norm_kernel_f32(float *__restrict__ out,          // [num_tokens, hidden_size]
+                    const float *__restrict__ input,  // [num_tokens, hidden_size]
+                    const float *__restrict__ weight, // [hidden_size]
+                    const float epsilon, const int num_tokens,
+                    const int hidden_size) {
+  rms_norm_kernel<float>(out, input, weight, epsilon, num_tokens, hidden_size);
+}
+
diff --git a/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
new file mode 100644
index 00000000..d5765f4f
--- /dev/null
+++ b/candle-examples/examples/custom-ops/kernels/reduction_utils.cuh
@@ -0,0 +1,46 @@
+/*
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+/* Butterfly sum across 32 lanes: every lane ends up with the total. */
+template <typename T> __inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+  for (int lane_mask = 16; lane_mask >= 1; lane_mask /= 2) {
+    const T other = __shfl_xor_sync(0xffffffff, val, lane_mask, 32);
+    val += other;
+  }
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template <typename T> __inline__ __device__ T blockReduceSum(T val) {
+  // One partial sum per warp.
+  static __shared__ T shared[32];
+  const int lane = threadIdx.x % 32;
+  const int wid = threadIdx.x / 32;
+
+  // Step 1: reduce inside each warp; lane 0 stores the warp's partial sum.
+  val = warpReduceSum<T>(val);
+  if (lane == 0) {
+    shared[wid] = val;
+  }
+
+  __syncthreads();
+
+  // Step 2: reduce the partial sums. The bound is `blockDim.x / 32.f`
+  // (a float division) so that a trailing partial warp is still counted
+  // when blockDim.x is not a multiple of 32.
+  if (threadIdx.x < (blockDim.x / 32.f)) {
+    val = shared[lane];
+  } else {
+    val = (T)(0.0f);
+  }
+  return warpReduceSum<T>(val);
+}
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
new file mode 100644
index 00000000..adc7abd7
--- /dev/null
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+#![allow(unused)]
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use clap::Parser;
+
+use candle::backend::BackendStorage;
+use candle::cpu_backend;
+use candle::{CpuStorage, CustomOp1, DType, Device, Error, Layout, Result, Shape, Tensor};
+
+// Command-line arguments of the custom-op example, parsed through clap's derive API.
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+}
+
+/// Placeholder custom op used to exercise the `CustomOp1` plumbing.
+struct LayerNorm;
+
+impl CustomOp1 for LayerNorm {
+    fn name(&self) -> &'static str {
+        "layer-norm"
+    }
+
+    /// CPU path: checks that the input is `f32` and contiguous; the
+    /// computation itself is not implemented yet.
+    fn cpu_fwd(&self, s: &CpuStorage, l: &Layout) -> Result<(CpuStorage, Shape)> {
+        let data = s.as_slice::<f32>()?;
+        let (start, end) = l
+            .contiguous_offsets()
+            .ok_or_else(|| Error::Wrapped("input has to be contiguous".into()))?;
+        let _data = &data[start..end];
+        todo!()
+    }
+
+    /// CUDA path: for now returns a device-side copy of the input buffer with
+    /// the shape of `l`.
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s: &candle::CudaStorage,
+        l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        let device = s.device().clone();
+        let slice = s.as_cuda_slice::<f32>()?;
+        // TODO: slice with the offsets; the whole buffer is cloned for now.
+        let (_start, _end) = l
+            .contiguous_offsets()
+            .ok_or_else(|| Error::Wrapped("input has to be contiguous".into()))?;
+        let copy = slice
+            .try_clone()
+            .map_err(Into::<candle::cuda_backend::CudaError>::into)?;
+        let storage = candle::CudaStorage::wrap_cuda_slice(copy, device);
+        Ok((storage, l.shape().clone()))
+    }
+}
+
+/// Builds a 2x7 `f32` tensor holding 0..14, prints it, runs it through the
+/// `LayerNorm` custom op and prints the result.
+fn main() -> anyhow::Result<()> {
+    let Args { cpu } = Args::parse();
+    let device = candle_examples::device(cpu)?;
+    let input = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
+    println!("{input}");
+    let output = input.custom_op1(LayerNorm)?;
+    println!("{output}");
+    Ok(())
+}