diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 08:04:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-26 08:04:02 +0100 |
commit | 471855e2eec29ffd082dc3ea22157602baae3085 (patch) | |
tree | deeb3b17a478688d6b1896d64a3b2e11a43c0ec0 /candle-flash-attn/build.rs | |
parent | d9f9c859afaeed95df420aca5fdb73f52f9239c5 (diff) | |
download | candle-471855e2eec29ffd082dc3ea22157602baae3085.tar.gz candle-471855e2eec29ffd082dc3ea22157602baae3085.tar.bz2 candle-471855e2eec29ffd082dc3ea22157602baae3085.zip |
Specific cache dir for the flash attn build artifacts. (#242)
Diffstat (limited to 'candle-flash-attn/build.rs')
-rw-r--r-- | candle-flash-attn/build.rs | 20 |
1 file changed, 10 insertions, 10 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index dc5b82e8..05affcbe 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -1,6 +1,7 @@ -#![allow(unused)] +// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel. +// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment +// variable in order to cache the compiled artifacts and avoid recompiling too often. use anyhow::{Context, Result}; -use std::io::Write; use std::path::PathBuf; fn main() -> Result<()> { @@ -16,17 +17,16 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=kernels/block_info.h"); println!("cargo:rerun-if-changed=kernels/static_switch.h"); - let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; - let mut out_dir = PathBuf::from(out_dir); - // TODO: Getting up two levels avoid having to recompile this too often, however it's likely - // not a safe assumption. - out_dir.pop(); - out_dir.pop(); + let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") { + Err(_) => std::env::var("OUT_DIR").context("OUT_DIR not set")?, + Ok(build_dir) => build_dir, + }; + let build_dir = PathBuf::from(build_dir); set_cuda_include_dir()?; let compute_cap = compute_cap()?; let mut command = std::process::Command::new("nvcc"); - let out_file = out_dir.join("libflashattention.a"); + let out_file = build_dir.join("libflashattention.a"); let cu_file = PathBuf::from("kernels/flash_fwd_hdim32_fp16_sm80.cu"); let should_compile = if out_file.exists() { @@ -57,7 +57,7 @@ fn main() -> Result<()> { ) } } - println!("cargo:rustc-link-search={}", out_dir.display()); + println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=stdc++"); |