summary | refs | log | tree | commit | diff
path: root/candle-flash-attn/build.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-26 08:04:02 +0100
committer: GitHub <noreply@github.com> 2023-07-26 08:04:02 +0100
commit: 471855e2eec29ffd082dc3ea22157602baae3085 (patch)
tree: deeb3b17a478688d6b1896d64a3b2e11a43c0ec0 /candle-flash-attn/build.rs
parent: d9f9c859afaeed95df420aca5fdb73f52f9239c5 (diff)
download: candle-471855e2eec29ffd082dc3ea22157602baae3085.tar.gz
candle-471855e2eec29ffd082dc3ea22157602baae3085.tar.bz2
candle-471855e2eec29ffd082dc3ea22157602baae3085.zip
Specific cache dir for the flash attn build artifacts. (#242)
Diffstat (limited to 'candle-flash-attn/build.rs')
-rw-r--r-- candle-flash-attn/build.rs 20
1 files changed, 10 insertions, 10 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index dc5b82e8..05affcbe 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -1,6 +1,7 @@
-#![allow(unused)]
+// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
+// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
+// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
-use std::io::Write;
use std::path::PathBuf;
fn main() -> Result<()> {
@@ -16,17 +17,16 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/block_info.h");
println!("cargo:rerun-if-changed=kernels/static_switch.h");
- let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
- let mut out_dir = PathBuf::from(out_dir);
- // TODO: Getting up two levels avoid having to recompile this too often, however it's likely
- // not a safe assumption.
- out_dir.pop();
- out_dir.pop();
+ let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
+ Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?,
+ Ok(build_dir) => build_dir,
+ };
+ let build_dir = PathBuf::from(build_dir);
set_cuda_include_dir()?;
let compute_cap = compute_cap()?;
let mut command = std::process::Command::new("nvcc");
- let out_file = out_dir.join("libflashattention.a");
+ let out_file = build_dir.join("libflashattention.a");
let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu");
let should_compile = if out_file.exists() {
@@ -57,7 +57,7 @@ fn main() -> Result<()> {
)
}
}
- println!("cargo:rustc-link-search={}", out_dir.display());
+ println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");