author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 07:48:10 +0100
committer | GitHub <noreply@github.com> | 2023-07-26 07:48:10 +0100
commit | d9f9c859afaeed95df420aca5fdb73f52f9239c5 (patch)
tree | 2ef898b2906a24b57ea42b0294bc51b928f0513c /candle-flash-attn/build.rs
parent | c97d51243c177e0497ea7147f426c4cc1e532c9b (diff)
Add flash attention (#241)
* Add some flash-attn kernels, importing the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files into a different directory.
* Build the static C library with nvcc.
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate (see the sketch after this list).
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
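The feature gate and the llama call site are not part of this build.rs diff. As a rough sketch of what such gating can look like (the crate path `candle_flash_attn`, the `flash_attn` signature, and the softmax helper below are assumptions for illustration, not code from this commit):

```rust
// A minimal sketch, assuming a flash_attn(q, k, v, scale, causal) entry
// point; the real signatures in the repository may differ.
use candle::{Result, Tensor};

#[cfg(feature = "flash-attn")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // Dispatch to the CUDA kernel that build.rs archives into libflashattention.a.
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, /* causal */ true)
}

#[cfg(not(feature = "flash-attn"))]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // Fallback: plain softmax(Q K^T * scale) V when the feature is disabled.
    let att = q.matmul(&k.t()?)?.affine(softmax_scale as f64, 0.0)?;
    let att = candle_nn::ops::softmax_last_dim(&att)?;
    att.matmul(v)
}
```

With a layout like this, the CUDA path is only compiled and linked when building with `cargo build --features flash-attn`, so machines without nvcc can still build the default workspace.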
Diffstat (limited to 'candle-flash-attn/build.rs')
-rw-r--r-- | candle-flash-attn/build.rs | 182
1 file changed, 182 insertions, 0 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
new file mode 100644
index 00000000..dc5b82e8
--- /dev/null
+++ b/candle-flash-attn/build.rs
@@ -0,0 +1,182 @@
+#![allow(unused)]
+use anyhow::{Context, Result};
+use std::io::Write;
+use std::path::PathBuf;
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+    println!("cargo:rerun-if-changed=kernels/flash_fwd_hdim32_fp16_sm80.cu");
+    println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
+    println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
+    println!("cargo:rerun-if-changed=kernels/flash.h");
+    println!("cargo:rerun-if-changed=kernels/philox.cuh");
+    println!("cargo:rerun-if-changed=kernels/softmax.h");
+    println!("cargo:rerun-if-changed=kernels/utils.h");
+    println!("cargo:rerun-if-changed=kernels/kernel_traits.h");
+    println!("cargo:rerun-if-changed=kernels/block_info.h");
+    println!("cargo:rerun-if-changed=kernels/static_switch.h");
+
+    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
+    let mut out_dir = PathBuf::from(out_dir);
+    // TODO: Going up two levels avoids having to recompile this too often, but it's
+    // likely not a safe assumption.
+    out_dir.pop();
+    out_dir.pop();
+    set_cuda_include_dir()?;
+    let compute_cap = compute_cap()?;
+
+    let mut command = std::process::Command::new("nvcc");
+    let out_file = out_dir.join("libflashattention.a");
+
+    let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
+    let should_compile = if out_file.exists() {
+        let out_modified = out_file.metadata()?.modified()?;
+        let in_modified = cu_file.metadata()?.modified()?;
+        in_modified.duration_since(out_modified).is_ok()
+    } else {
+        true
+    };
+    if should_compile {
+        command
+            .arg(format!("--gpu-architecture=sm_{compute_cap}"))
+            .arg("--lib")
+            .args(["-o", out_file.to_str().unwrap()])
+            .args(["--default-stream", "per-thread"])
+            .arg("-Icutlass/include")
+            .arg("--expt-relaxed-constexpr")
+            .arg(cu_file);
+        let output = command
+            .spawn()
+            .context("failed spawning nvcc")?
+            .wait_with_output()?;
+        if !output.status.success() {
+            anyhow::bail!(
+                "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+                String::from_utf8_lossy(&output.stdout),
+                String::from_utf8_lossy(&output.stderr)
+            )
+        }
+    }
+    println!("cargo:rustc-link-search={}", out_dir.display());
+    println!("cargo:rustc-link-lib=flashattention");
+    println!("cargo:rustc-link-lib=dylib=cudart");
+    println!("cargo:rustc-link-lib=dylib=stdc++");
+
+    /* laurent: I tried using the cc cuda integration as below but this led to ptxas never
+    finishing for some reason. Calling nvcc manually worked fine.
+    cc::Build::new()
+        .cuda(true)
+        .include("cutlass/include")
+        .flag("--expt-relaxed-constexpr")
+        .flag("--default-stream")
+        .flag("per-thread")
+        .flag(&format!("--gpu-architecture=sm_{compute_cap}"))
+        .file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
+        .compile("flashattn");
+    */
+    Ok(())
+}
+
+fn set_cuda_include_dir() -> Result<()> {
+    // NOTE: copied from cudarc build.rs.
+    let env_vars = [
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDNN_LIB",
+    ];
+    let env_vars = env_vars
+        .into_iter()
+        .map(std::env::var)
+        .filter_map(Result::ok)
+        .map(Into::<PathBuf>::into);
+
+    let roots = [
+        "/usr",
+        "/usr/local/cuda",
+        "/opt/cuda",
+        "/usr/lib/cuda",
+        "C:/Program Files/NVIDIA GPU Computing Toolkit",
+        "C:/CUDA",
+    ];
+    let roots = roots.into_iter().map(Into::<PathBuf>::into);
+    let root = env_vars
+        .chain(roots)
+        .find(|path| path.join("include").join("cuda.h").is_file())
+        .context("cannot find include/cuda.h")?;
+    println!(
+        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
+        root.join("include").display()
+    );
+    Ok(())
+}
+
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+    // Grab the compute cap from nvidia-smi.
+    let mut compute_cap = {
+        let out = std::process::Command::new("nvidia-smi")
+            .arg("--query-gpu=compute_cap")
+            .arg("--format=csv")
+            .output()
+            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        let mut lines = out.lines();
+        assert_eq!(
+            lines.next().context("missing line in stdout")?,
+            "compute_cap"
+        );
+        let cap = lines
+            .next()
+            .context("missing line in stdout")?
+            .replace('.', "");
+        cap.parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?
+    };
+
+    // Grab the available GPU codes from nvcc and select the highest one.
+    let max_nvcc_code = {
+        let out = std::process::Command::new("nvcc")
+            .arg("--list-gpu-code")
+            .output()
+            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+        let out = std::str::from_utf8(&out.stdout).unwrap();
+
+        let out = out.lines().collect::<Vec<&str>>();
+        let mut codes = Vec::with_capacity(out.len());
+        for code in out {
+            let code = code.split('_').collect::<Vec<&str>>();
+            if !code.is_empty() && code.contains(&"sm") {
+                if let Ok(num) = code[1].parse::<usize>() {
+                    codes.push(num);
+                }
+            }
+        }
+        codes.sort();
+        if !codes.contains(&compute_cap) {
+            anyhow::bail!(
+                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+            );
+        }
+        *codes.last().unwrap()
+    };
+
+    // If the nvidia-smi compute cap is higher than the highest gpu code from nvcc,
+    // then fall back to the highest nvcc gpu code.
+    if compute_cap > max_nvcc_code {
+        println!(
+            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+        );
+        compute_cap = max_nvcc_code;
+    }
+
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        compute_cap = compute_cap_str
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+    }
+    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+    Ok(compute_cap)
+}
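The script's observable outputs are the `cargo:rustc-link-*` directives and the `CUDA_COMPUTE_CAP` compile-time env var. A minimal sketch of how the Rust side of the crate could consume them (the `run_mha` symbol and its signature below are hypothetical; this diff does not show the kernels' actual C API):

```rust
// Sketch only: consuming what build.rs emits.
use std::ffi::c_void;

// `cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{cap}` makes the selected arch
// visible to the crate at compile time.
pub const COMPUTE_CAP: &str = env!("CUDA_COMPUTE_CAP"); // e.g. "sm_80"

// `cargo:rustc-link-lib=flashattention` links libflashattention.a, so its C
// entry points can be declared through FFI. `run_mha` is a placeholder name.
extern "C" {
    fn run_mha(params: *mut c_void, stream: *mut c_void);
}
```

Note that the script also reads `CUDA_COMPUTE_CAP` from the environment, so the nvidia-smi detection can be overridden at build time, e.g. `CUDA_COMPUTE_CAP=90 cargo build`.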