summaryrefslogtreecommitdiff
path: root/candle-examples/build.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/build.rs')
-rw-r--r--candle-examples/build.rs20
1 files changed, 18 insertions, 2 deletions
diff --git a/candle-examples/build.rs b/candle-examples/build.rs
index e21f1767..0af3a6a4 100644
--- a/candle-examples/build.rs
+++ b/candle-examples/build.rs
@@ -32,6 +32,8 @@ impl KernelDirectories {
if should_compile {
#[cfg(feature = "cuda")]
{
+ let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
+ println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
@@ -44,6 +46,11 @@ impl KernelDirectories {
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
+ if let Ok(ccbin_path) = &ccbin_env {
+ command
+ .arg("-allow-unsupported-compiler")
+ .args(["-ccbin", ccbin_path]);
+ }
let output = command
.spawn()
.context("failed spawning nvcc")?
@@ -168,8 +175,16 @@ fn set_cuda_include_dir() -> Result<()> {
#[allow(unused)]
fn compute_cap() -> Result<usize> {
- // Grab compute code from nvidia-smi
- let mut compute_cap = {
+ println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
+ // Try to parse compute cap from env
+ let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+ println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
+ compute_cap_str
+ .parse::<usize>()
+ .context("Could not parse code")?
+ } else {
+ // Grab compute cap from nvidia-smi
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
@@ -185,6 +200,7 @@ fn compute_cap() -> Result<usize> {
.next()
.context("missing line in stdout")?
.replace('.', "");
+ println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};