diff options
Diffstat (limited to 'candle-examples/build.rs')
| -rw-r--r-- | candle-examples/build.rs | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/candle-examples/build.rs b/candle-examples/build.rs index e21f1767..0af3a6a4 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -32,6 +32,8 @@ impl KernelDirectories { if should_compile { #[cfg(feature = "cuda")] { + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); let mut command = std::process::Command::new("nvcc"); let out_dir = ptx_file.parent().context("no parent for ptx file")?; let include_dirs: Vec<String> = @@ -44,6 +46,11 @@ impl KernelDirectories { .arg(format!("-I/{}", self.kernel_dir)) .args(include_dirs) .arg(cu_file); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } let output = command .spawn() .context("failed spawning nvcc")? @@ -168,8 +175,16 @@ fn set_cuda_include_dir() -> Result<()> { #[allow(unused)] fn compute_cap() -> Result<usize> { - // Grab compute code from nvidia-smi - let mut compute_cap = { + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + + // Try to parse compute cap from env + let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); + compute_cap_str + .parse::<usize>() + .context("Could not parse code")? + } else { + // Grab compute cap from nvidia-smi let out = std::process::Command::new("nvidia-smi") .arg("--query-gpu=compute_cap") .arg("--format=csv") @@ -185,6 +200,7 @@ fn compute_cap() -> Result<usize> { .next() .context("missing line in stdout")? .replace('.', ""); + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); cap.parse::<usize>() .with_context(|| format!("cannot parse as int {cap}"))? }; |
