diff options
author | Zsombor <gzsombor@gmail.com> | 2023-09-08 09:15:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 08:15:14 +0100 |
commit | cfcbec9fc70aca2b0e08f382dec8634f88b61bce (patch) | |
tree | 835cba4824ea05ead5164072eea20939e8c30540 /candle-flash-attn | |
parent | 3898e500debd632b520054ebfa42f8333323a20e (diff) | |
download | candle-cfcbec9fc70aca2b0e08f382dec8634f88b61bce.tar.gz candle-cfcbec9fc70aca2b0e08f382dec8634f88b61bce.tar.bz2 candle-cfcbec9fc70aca2b0e08f382dec8634f88b61bce.zip |
Add small customization to the build (#768)
* Add ability to override the compiler used by NVCC from an environment variable
* Allow relative paths in CANDLE_FLASH_ATTN_BUILD_DIR
* Add the compilation failure to the readme, with a possible solution
* Adjust the error message, and remove the special handling of the relative paths
Diffstat (limited to 'candle-flash-attn')
-rw-r--r-- | candle-flash-attn/build.rs | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 2a3b7eb1..4cc7e5fb 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -57,9 +57,17 @@ fn main() -> Result<()> { #[allow(clippy::redundant_clone)] out_dir.clone() } - Ok(build_dir) => PathBuf::from(build_dir), + Ok(build_dir) => + { + let path = PathBuf::from(build_dir); + path.canonicalize().expect(&format!("Directory doesn't exists: {} (the current directory is {})", &path.display(), std::env::current_dir()?.display())) + } }; set_cuda_include_dir()?; + + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); + let compute_cap = compute_cap()?; let out_file = build_dir.join("libflashattention.a"); @@ -95,14 +103,21 @@ fn main() -> Result<()> { .args(["--default-stream", "per-thread"]) .arg("-Icutlass/include") .arg("--expt-relaxed-constexpr") - .arg(cu_file); + .arg("--verbose"); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } + command.arg(cu_file); let output = command .spawn() .context("failed spawning nvcc")? .wait_with_output()?; if !output.status.success() { anyhow::bail!( - "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + &command, String::from_utf8_lossy(&output.stdout), String::from_utf8_lossy(&output.stderr) ) @@ -122,7 +137,8 @@ fn main() -> Result<()> { .wait_with_output()?; if !output.status.success() { anyhow::bail!( - "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + &command, String::from_utf8_lossy(&output.stdout), String::from_utf8_lossy(&output.stderr) ) |