summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-kernels/build.rs 210
1 files changed, 97 insertions, 113 deletions
diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs
index 1f4a4b79..3c8e96a9 100644
--- a/candle-kernels/build.rs
+++ b/candle-kernels/build.rs
@@ -3,22 +3,22 @@ fn main() {
println!("cargo:rerun-if-changed=build.rs");
cuda::set_include_dir();
- let kernel_paths = cuda::build_ptx();
- // println!("cargo:warning=kernels {kernel_paths:?}");
-
- let mut file = std::fs::File::create("src/lib.rs").unwrap();
- for kernel_path in kernel_paths {
- let name = kernel_path.file_stem().unwrap().to_str().unwrap();
- file.write_all(
- format!(
- r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
- name.to_uppercase().replace('.', "_"),
- name
+ let (write, kernel_paths) = cuda::build_ptx();
+ if write {
+ let mut file = std::fs::File::create("src/lib.rs").unwrap();
+ for kernel_path in kernel_paths {
+ let name = kernel_path.file_stem().unwrap().to_str().unwrap();
+ file.write_all(
+ format!(
+ r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
+ name.to_uppercase().replace('.', "_"),
+ name
+ )
+ .as_bytes(),
)
- .as_bytes(),
- )
- .unwrap();
- file.write_all(&[b'\n']).unwrap();
+ .unwrap();
+ file.write_all(&[b'\n']).unwrap();
+ }
}
}
@@ -70,7 +70,7 @@ mod cuda {
);
}
- pub fn build_ptx() -> Vec<std::path::PathBuf> {
+ pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
use rayon::prelude::*;
use std::path::PathBuf;
let out_dir = std::env::var("OUT_DIR").unwrap();
@@ -83,17 +83,13 @@ mod cuda {
.map(|p| p.unwrap())
.collect();
- for out_path in glob::glob(&format!("{out_dir}/**/*.ptx")).unwrap() {
- std::fs::remove_file(out_path.unwrap()).unwrap();
- }
-
println!("cargo:rerun-if-changed=src/");
- for path in &kernel_paths {
- println!("cargo:rerun-if-changed={}", path.display());
- }
+ // for path in &kernel_paths {
+ // println!("cargo:rerun-if-changed={}", path.display());
+ // }
for path in &mut include_directories {
- println!("cargo:rerun-if-changed={}", path.display());
+ // println!("cargo:rerun-if-changed={}", path.display());
let destination =
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
std::fs::copy(path.clone(), destination).unwrap();
@@ -110,123 +106,111 @@ mod cuda {
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
- #[cfg(feature = "ci-check")]
- {
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP=ci");
-
- for mut kernel_path in kernel_paths.into_iter() {
- kernel_path.set_extension("ptx");
-
- let mut ptx_path: PathBuf = out_dir.clone().into();
- ptx_path.push(kernel_path.as_path().file_name().unwrap());
- std::fs::File::create(ptx_path).unwrap();
- }
- }
-
- #[cfg(not(feature = "ci-check"))]
- {
- // let start = std::time::Instant::now();
+ // let start = std::time::Instant::now();
- // Grab compute code from nvidia-smi
- let mut compute_cap = {
- let out = std::process::Command::new("nvidia-smi")
+ // Grab compute code from nvidia-smi
+ let mut compute_cap = {
+ let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
- let out = std::str::from_utf8(&out.stdout).unwrap();
- let mut lines = out.lines();
- assert_eq!(lines.next().unwrap(), "compute_cap");
- let cap = lines.next().unwrap().replace('.', "");
- cap.parse::<usize>().unwrap()
- };
-
- // Grab available GPU codes from nvcc and select the highest one
- let max_nvcc_code = {
- let out = std::process::Command::new("nvcc")
+ let out = std::str::from_utf8(&out.stdout).unwrap();
+ let mut lines = out.lines();
+ assert_eq!(lines.next().unwrap(), "compute_cap");
+ let cap = lines.next().unwrap().replace('.', "");
+ cap.parse::<usize>().unwrap()
+ };
+
+ // Grab available GPU codes from nvcc and select the highest one
+ let max_nvcc_code = {
+ let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- let out = std::str::from_utf8(&out.stdout).unwrap();
-
- let out = out.lines().collect::<Vec<&str>>();
- let mut codes = Vec::with_capacity(out.len());
- for code in out {
- let code = code.split('_').collect::<Vec<&str>>();
- if !code.is_empty() && code.contains(&"sm") {
- if let Ok(num) = code[1].parse::<usize>() {
- codes.push(num);
- }
+ let out = std::str::from_utf8(&out.stdout).unwrap();
+
+ let out = out.lines().collect::<Vec<&str>>();
+ let mut codes = Vec::with_capacity(out.len());
+ for code in out {
+ let code = code.split('_').collect::<Vec<&str>>();
+ if !code.is_empty() && code.contains(&"sm") {
+ if let Ok(num) = code[1].parse::<usize>() {
+ codes.push(num);
}
}
- codes.sort();
- if !codes.contains(&compute_cap) {
- panic!("nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}.");
- }
- *codes.last().unwrap()
- };
-
- // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
- // then choose the highest gpu code in nvcc
- if compute_cap > max_nvcc_code {
- println!("cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}.");
- compute_cap = max_nvcc_code;
}
-
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
- if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- compute_cap = compute_cap_str.parse::<usize>().unwrap();
- println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+ codes.sort();
+ if !codes.contains(&compute_cap) {
+ panic!("nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}.");
}
+ *codes.last().unwrap()
+ };
+
+ // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
+ // then choose the highest gpu code in nvcc
+ if compute_cap > max_nvcc_code {
+ println!(
+ "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+ );
+ compute_cap = max_nvcc_code;
+ }
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+ println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+ if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+ compute_cap = compute_cap_str.parse::<usize>().unwrap();
+ println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+ }
- kernel_paths
- .iter()
- .for_each(|p| println!("cargo:rerun-if-changed={}", p.display()));
+ println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
- let children = kernel_paths
+ let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
- if output_filename.exists(){
+ let ignore = if output_filename.exists() {
+ let out_modified = output_filename.metadata().unwrap().modified().unwrap();
+ let in_modified = p.metadata().unwrap().modified().unwrap();
+ out_modified.duration_since(in_modified).is_ok()
+ }else{
+ false
+ };
+ if ignore{
None
}else{
- let mut command = std::process::Command::new("nvcc");
- command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
- .arg("--ptx")
- .args(["--default-stream", "per-thread"])
- .args(["--output-directory", &out_dir])
- // Flash attention only
- // .arg("--expt-relaxed-constexpr")
- .args(&include_options)
- .arg(p);
- // println!(
- // "cargo:warning={command:?}");
+ let mut command = std::process::Command::new("nvcc");
+ command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
+ .arg("--ptx")
+ .args(["--default-stream", "per-thread"])
+ .args(["--output-directory", &out_dir])
+ // Flash attention only
+ // .arg("--expt-relaxed-constexpr")
+ .args(&include_options)
+ .arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}})
.collect::<Vec<_>>();
- for (kernel_path, child) in children {
- let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- assert!(
- output.status.success(),
- "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- );
- }
-
- // println!(
- // "cargo:warning=Compiled {:?} cuda kernels in {:?}",
- // n,
- // start.elapsed()
- // );
+ let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
+ .unwrap()
+ .map(|p| p.unwrap())
+ .collect();
+ // We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
+ // some old ones
+ let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
+ for (kernel_path, child) in children {
+ let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+ assert!(
+ output.status.success(),
+ "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
+ String::from_utf8_lossy(&output.stdout),
+ String::from_utf8_lossy(&output.stderr)
+ );
}
- kernel_paths
+ (write, kernel_paths)
}
}