summaryrefslogtreecommitdiff
path: root/candle-flash-attn
diff options
context:
space:
mode:
authorZsombor <gzsombor@gmail.com>2023-09-08 09:15:14 +0200
committerGitHub <noreply@github.com>2023-09-08 08:15:14 +0100
commitcfcbec9fc70aca2b0e08f382dec8634f88b61bce (patch)
tree835cba4824ea05ead5164072eea20939e8c30540 /candle-flash-attn
parent3898e500debd632b520054ebfa42f8333323a20e (diff)
downloadcandle-cfcbec9fc70aca2b0e08f382dec8634f88b61bce.tar.gz
candle-cfcbec9fc70aca2b0e08f382dec8634f88b61bce.tar.bz2
candle-cfcbec9fc70aca2b0e08f382dec8634f88b61bce.zip
Add small customization to the build (#768)
* Add ability to override the compiler used by NVCC from an environment variable * Allow relative paths in CANDLE_FLASH_ATTN_BUILD_DIR * Add the compilation failure to the readme, with a possible solution * Adjust the error message, and remove the special handling of the relative paths
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--candle-flash-attn/build.rs24
1 files changed, 20 insertions, 4 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 2a3b7eb1..4cc7e5fb 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -57,9 +57,17 @@ fn main() -> Result<()> {
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
- Ok(build_dir) => PathBuf::from(build_dir),
+ Ok(build_dir) =>
+ {
+ let path = PathBuf::from(build_dir);
+ path.canonicalize().expect(&format!("Directory doesn't exists: {} (the current directory is {})", &path.display(), std::env::current_dir()?.display()))
+ }
};
set_cuda_include_dir()?;
+
+ let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+ println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
+
let compute_cap = compute_cap()?;
let out_file = build_dir.join("libflashattention.a");
@@ -95,14 +103,21 @@ fn main() -> Result<()> {
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
- .arg(cu_file);
+ .arg("--verbose");
+ if let Ok(ccbin_path) = &ccbin_env {
+ command
+ .arg("-allow-unsupported-compiler")
+ .args(["-ccbin", ccbin_path]);
+ }
+ command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
- "nvcc error while compiling:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ &command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
@@ -122,7 +137,8 @@ fn main() -> Result<()> {
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
- "nvcc error while linking:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ &command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)