summaryrefslogtreecommitdiff
path: root/candle-flash-attn
diff options
context:
space:
mode:
author: OlivierDehaene <Olivier.dehaene@gmail.com> 2023-10-16 16:37:38 +0200
committer: GitHub <noreply@github.com> 2023-10-16 15:37:38 +0100
commit: 75629981bc2b101400a301803c027da2362a4ff9 (patch)
tree: 105781868b6024facddbf05492dacd33873d4903 /candle-flash-attn
parent: 0106b0b04c3505a1155b3eab65ac212977c6c3dd (diff)
download: candle-75629981bc2b101400a301803c027da2362a4ff9.tar.gz
          candle-75629981bc2b101400a301803c027da2362a4ff9.tar.bz2
          candle-75629981bc2b101400a301803c027da2362a4ff9.zip
feat: parse Cuda compute cap from env (#1066)
* feat: add support for multiple compute caps
* Revert to one compute cap
* fmt
* fix
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--  candle-flash-attn/build.rs  88
1 file changed, 52 insertions(+), 36 deletions(-)
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 64275fda..fde3aeed 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -84,12 +84,19 @@ fn main() -> Result<()> {
(kernel_dir.join(f), obj_file)
})
.collect();
+ let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
- cu_files.iter().any(|(cu_file, _)| {
- let out_modified = out_file.metadata().unwrap().modified().unwrap();
- let in_modified = cu_file.metadata().unwrap().modified().unwrap();
- in_modified.duration_since(out_modified).is_ok()
- })
+ kernel_dir
+ .read_dir()
+ .expect("kernels folder should exist")
+ .any(|entry| {
+ if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
+ let in_modified = entry.metadata().unwrap().modified().unwrap();
+ in_modified.duration_since(*out_modified).is_ok()
+ } else {
+ true
+ }
+ })
} else {
true
};
@@ -100,12 +107,19 @@ fn main() -> Result<()> {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
+ .arg("-O3")
+ .arg("-U__CUDA_NO_HALF_OPERATORS__")
+ .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+ .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+ .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
+ .arg("--expt-extended-lambda")
+ .arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
@@ -203,13 +217,21 @@ fn set_cuda_include_dir() -> Result<()> {
#[allow(unused)]
fn compute_cap() -> Result<usize> {
- // Grab compute code from nvidia-smi
- let mut compute_cap = {
+ println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
+ // Try to parse compute caps from env
+ let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+ println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
+ compute_cap_str
+ .parse::<usize>()
+ .context("Could not parse compute cap")?
+ } else {
+ // Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
- .arg("--query-gpu=compute_cap")
- .arg("--format=csv")
- .output()
- .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+ .arg("--query-gpu=compute_cap")
+ .arg("--format=csv")
+ .output()
+ .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
@@ -220,16 +242,19 @@ fn compute_cap() -> Result<usize> {
.next()
.context("missing line in stdout")?
.replace('.', "");
- cap.parse::<usize>()
- .with_context(|| format!("cannot parse as int {cap}"))?
+ let cap = cap
+ .parse::<usize>()
+ .with_context(|| format!("cannot parse as int {cap}"))?;
+ println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
+ cap
};
// Grab available GPU codes from nvcc and select the highest one
- let max_nvcc_code = {
+ let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
- .arg("--list-gpu-code")
- .output()
- .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+ .arg("--list-gpu-code")
+ .output()
+ .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
@@ -243,30 +268,21 @@ fn compute_cap() -> Result<usize> {
}
}
codes.sort();
- if !codes.contains(&compute_cap) {
- anyhow::bail!(
- "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
- );
- }
- *codes.last().unwrap()
+ let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
+ (codes, max_nvcc_code)
};
- // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
- // then choose the highest gpu code in nvcc
+ // Check that nvcc supports the asked compute caps
+ if !supported_nvcc_codes.contains(&compute_cap) {
+ anyhow::bail!(
+ "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
+ );
+ }
if compute_cap > max_nvcc_code {
- println!(
- "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+ anyhow::bail!(
+ "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
- compute_cap = max_nvcc_code;
}
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
- if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- compute_cap = compute_cap_str
- .parse::<usize>()
- .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
- println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
- }
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}