summary refs log tree commit diff
path: root/candle-flash-attn
diff options
context:
space:
mode:
author: Nicolas Patry <patry.nicolas@protonmail.com> 2024-01-07 12:29:24 +0100
committer: GitHub <noreply@github.com> 2024-01-07 12:29:24 +0100
commit: 30313c308106fff7b20fc8cb2b27eb79800cb818 (patch)
tree: 993098c3494b335064906eeb17dfdca0c3c543c1 /candle-flash-attn
parent: e72d52b1a2118f8773866e87237586bab762a9c6 (diff)
download: candle-30313c308106fff7b20fc8cb2b27eb79800cb818.tar.gz
candle-30313c308106fff7b20fc8cb2b27eb79800cb818.tar.bz2
candle-30313c308106fff7b20fc8cb2b27eb79800cb818.zip
Moving to a proper build crate `bindgen_cuda`. (#1531)
* Moving to a proper build crate `bindgen_cuda`.
* Fmt.
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--  candle-flash-attn/Cargo.toml |   4
-rw-r--r--  candle-flash-attn/build.rs   | 273
2 files changed, 36 insertions, 241 deletions
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 0d3af91d..d8e8da82 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -15,9 +15,9 @@ candle = { path = "../candle-core", features = ["cuda"], package = "candle-core"
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
+bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] }
-num_cpus = "1.15.0"
-rayon = "1.7.0"
+
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index fde3aeed..4002770b 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -2,44 +2,32 @@
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
-use rayon::prelude::*;
use std::path::PathBuf;
-use std::str::FromStr;
const KERNEL_FILES: [&str; 17] = [
- "flash_api.cu",
- "flash_fwd_hdim128_fp16_sm80.cu",
- "flash_fwd_hdim160_fp16_sm80.cu",
- "flash_fwd_hdim192_fp16_sm80.cu",
- "flash_fwd_hdim224_fp16_sm80.cu",
- "flash_fwd_hdim256_fp16_sm80.cu",
- "flash_fwd_hdim32_fp16_sm80.cu",
- "flash_fwd_hdim64_fp16_sm80.cu",
- "flash_fwd_hdim96_fp16_sm80.cu",
- "flash_fwd_hdim128_bf16_sm80.cu",
- "flash_fwd_hdim160_bf16_sm80.cu",
- "flash_fwd_hdim192_bf16_sm80.cu",
- "flash_fwd_hdim224_bf16_sm80.cu",
- "flash_fwd_hdim256_bf16_sm80.cu",
- "flash_fwd_hdim32_bf16_sm80.cu",
- "flash_fwd_hdim64_bf16_sm80.cu",
- "flash_fwd_hdim96_bf16_sm80.cu",
+ "kernels/flash_api.cu",
+ "kernels/flash_fwd_hdim128_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim160_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim192_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim224_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim256_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim32_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim64_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim96_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim128_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim160_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim192_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim224_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim256_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim32_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim64_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
- let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
- |_| num_cpus::get_physical(),
- |s| usize::from_str(&s).unwrap(),
- );
-
- rayon::ThreadPoolBuilder::new()
- .num_threads(num_cpus)
- .build_global()
- .unwrap();
-
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
- println!("cargo:rerun-if-changed=kernels/{kernel_file}");
+ println!("cargo:rerun-if-changed={kernel_file}");
}
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
@@ -66,223 +54,30 @@ fn main() -> Result<()> {
))
}
};
- set_cuda_include_dir()?;
-
- let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
- println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
- let compute_cap = compute_cap()?;
+ let kernels = KERNEL_FILES.iter().collect();
+ let builder = bindgen_cuda::Builder::default()
+ .kernel_paths(kernels)
+ .out_dir(build_dir.clone())
+ .arg("-std=c++17")
+ .arg("-O3")
+ .arg("-U__CUDA_NO_HALF_OPERATORS__")
+ .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+ .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+ .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+ .arg("-Icutlass/include")
+ .arg("--expt-relaxed-constexpr")
+ .arg("--expt-extended-lambda")
+ .arg("--use_fast_math")
+ .arg("--verbose");
let out_file = build_dir.join("libflashattention.a");
+ builder.build_lib(out_file);
- let kernel_dir = PathBuf::from("kernels");
- let cu_files: Vec<_> = KERNEL_FILES
- .iter()
- .map(|f| {
- let mut obj_file = out_dir.join(f);
- obj_file.set_extension("o");
- (kernel_dir.join(f), obj_file)
- })
- .collect();
- let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
- let should_compile = if out_file.exists() {
- kernel_dir
- .read_dir()
- .expect("kernels folder should exist")
- .any(|entry| {
- if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
- let in_modified = entry.metadata().unwrap().modified().unwrap();
- in_modified.duration_since(*out_modified).is_ok()
- } else {
- true
- }
- })
- } else {
- true
- };
- if should_compile {
- cu_files
- .par_iter()
- .map(|(cu_file, obj_file)| {
- let mut command = std::process::Command::new("nvcc");
- command
- .arg("-std=c++17")
- .arg("-O3")
- .arg("-U__CUDA_NO_HALF_OPERATORS__")
- .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
- .arg("-U__CUDA_NO_HALF2_OPERATORS__")
- .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
- .arg(format!("--gpu-architecture=sm_{compute_cap}"))
- .arg("-c")
- .args(["-o", obj_file.to_str().unwrap()])
- .args(["--default-stream", "per-thread"])
- .arg("-Icutlass/include")
- .arg("--expt-relaxed-constexpr")
- .arg("--expt-extended-lambda")
- .arg("--use_fast_math")
- .arg("--verbose");
- if let Ok(ccbin_path) = &ccbin_env {
- command
- .arg("-allow-unsupported-compiler")
- .args(["-ccbin", ccbin_path]);
- }
- command.arg(cu_file);
- let output = command
- .spawn()
- .context("failed spawning nvcc")?
- .wait_with_output()?;
- if !output.status.success() {
- anyhow::bail!(
- "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- &command,
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- )
- }
- Ok(())
- })
- .collect::<Result<()>>()?;
- let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
- let mut command = std::process::Command::new("nvcc");
- command
- .arg("--lib")
- .args(["-o", out_file.to_str().unwrap()])
- .args(obj_files);
- let output = command
- .spawn()
- .context("failed spawning nvcc")?
- .wait_with_output()?;
- if !output.status.success() {
- anyhow::bail!(
- "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- &command,
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- )
- }
- }
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
- /* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
- finishing to run for some reason. Calling nvcc manually worked fine.
- cc::Build::new()
- .cuda(true)
- .include("cutlass/include")
- .flag("--expt-relaxed-constexpr")
- .flag("--default-stream")
- .flag("per-thread")
- .flag(&format!("--gpu-architecture=sm_{compute_cap}"))
- .file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
- .compile("flashattn");
- */
Ok(())
}
-
-fn set_cuda_include_dir() -> Result<()> {
- // NOTE: copied from cudarc build.rs.
- let env_vars = [
- "CUDA_PATH",
- "CUDA_ROOT",
- "CUDA_TOOLKIT_ROOT_DIR",
- "CUDNN_LIB",
- ];
- let env_vars = env_vars
- .into_iter()
- .map(std::env::var)
- .filter_map(Result::ok)
- .map(Into::<PathBuf>::into);
-
- let roots = [
- "/usr",
- "/usr/local/cuda",
- "/opt/cuda",
- "/usr/lib/cuda",
- "C:/Program Files/NVIDIA GPU Computing Toolkit",
- "C:/CUDA",
- ];
- let roots = roots.into_iter().map(Into::<PathBuf>::into);
- let root = env_vars
- .chain(roots)
- .find(|path| path.join("include").join("cuda.h").is_file())
- .context("cannot find include/cuda.h")?;
- println!(
- "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
- root.join("include").display()
- );
- Ok(())
-}
-
-#[allow(unused)]
-fn compute_cap() -> Result<usize> {
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
- // Try to parse compute caps from env
- let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
- compute_cap_str
- .parse::<usize>()
- .context("Could not parse compute cap")?
- } else {
- // Use nvidia-smi to get the current compute cap
- let out = std::process::Command::new("nvidia-smi")
- .arg("--query-gpu=compute_cap")
- .arg("--format=csv")
- .output()
- .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
- let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
- let mut lines = out.lines();
- assert_eq!(
- lines.next().context("missing line in stdout")?,
- "compute_cap"
- );
- let cap = lines
- .next()
- .context("missing line in stdout")?
- .replace('.', "");
- let cap = cap
- .parse::<usize>()
- .with_context(|| format!("cannot parse as int {cap}"))?;
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
- cap
- };
-
- // Grab available GPU codes from nvcc and select the highest one
- let (supported_nvcc_codes, max_nvcc_code) = {
- let out = std::process::Command::new("nvcc")
- .arg("--list-gpu-code")
- .output()
- .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- let out = std::str::from_utf8(&out.stdout).unwrap();
-
- let out = out.lines().collect::<Vec<&str>>();
- let mut codes = Vec::with_capacity(out.len());
- for code in out {
- let code = code.split('_').collect::<Vec<&str>>();
- if !code.is_empty() && code.contains(&"sm") {
- if let Ok(num) = code[1].parse::<usize>() {
- codes.push(num);
- }
- }
- }
- codes.sort();
- let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
- (codes, max_nvcc_code)
- };
-
- // Check that nvcc supports the asked compute caps
- if !supported_nvcc_codes.contains(&compute_cap) {
- anyhow::bail!(
- "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
- );
- }
- if compute_cap > max_nvcc_code {
- anyhow::bail!(
- "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
- );
- }
-
- Ok(compute_cap)
-}