summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/dependabot.yml7
-rw-r--r--Cargo.toml14
-rw-r--r--candle-book/Cargo.toml10
-rw-r--r--candle-core/Cargo.toml4
-rw-r--r--candle-core/benches/benchmarks/matmul.rs20
-rw-r--r--candle-core/benches/benchmarks/mod.rs50
-rw-r--r--candle-core/benches/benchmarks/random.rs25
-rw-r--r--candle-core/examples/tensor-tools.rs25
-rw-r--r--candle-core/src/metal_backend.rs58
-rw-r--r--candle-core/src/pickle.rs14
-rw-r--r--candle-core/src/quantized/k_quants.rs8
-rw-r--r--candle-core/src/quantized/neon.rs208
-rw-r--r--candle-core/tests/quantized_tests.rs59
-rw-r--r--candle-datasets/Cargo.toml4
-rw-r--r--candle-examples/Cargo.toml15
-rw-r--r--candle-examples/build.rs247
-rw-r--r--candle-examples/examples/custom-ops/cuda_kernels.rs3
-rw-r--r--candle-examples/examples/custom-ops/main.rs3
-rw-r--r--candle-examples/examples/repvgg/README.md20
-rw-r--r--candle-examples/examples/repvgg/main.rs111
-rw-r--r--candle-flash-attn/Cargo.toml8
-rw-r--r--candle-flash-attn/build.rs273
-rw-r--r--candle-kernels/Cargo.toml4
-rw-r--r--candle-kernels/build.rs243
-rw-r--r--candle-metal-kernels/Cargo.toml9
-rw-r--r--candle-metal-kernels/src/affine.metal2
-rw-r--r--candle-metal-kernels/src/binary.metal2
-rw-r--r--candle-metal-kernels/src/cast.metal42
-rw-r--r--candle-metal-kernels/src/indexing.metal5
-rw-r--r--candle-metal-kernels/src/lib.rs4
-rw-r--r--candle-metal-kernels/src/reduce.metal2
-rw-r--r--candle-metal-kernels/src/tests.rs154
-rw-r--r--candle-metal-kernels/src/unary.metal10
-rw-r--r--candle-nn/Cargo.toml4
-rw-r--r--candle-onnx/Cargo.toml5
-rw-r--r--candle-pyo3/Cargo.toml6
-rw-r--r--candle-transformers/Cargo.toml6
-rw-r--r--candle-transformers/src/models/bert.rs7
-rw-r--r--candle-transformers/src/models/bigcode.rs7
-rw-r--r--candle-transformers/src/models/falcon.rs7
-rw-r--r--candle-transformers/src/models/llama.rs9
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/repvgg.rs306
-rw-r--r--candle-transformers/src/models/whisper/model.rs7
-rw-r--r--candle-wasm-examples/bert/Cargo.toml8
-rw-r--r--candle-wasm-examples/blip/Cargo.toml6
-rw-r--r--candle-wasm-examples/llama2-c/Cargo.toml8
-rw-r--r--candle-wasm-examples/phi/Cargo.toml6
-rw-r--r--candle-wasm-examples/segment-anything/Cargo.toml6
-rw-r--r--candle-wasm-examples/t5/Cargo.toml8
-rw-r--r--candle-wasm-examples/whisper/Cargo.toml8
-rw-r--r--candle-wasm-examples/yolo/Cargo.toml6
-rw-r--r--candle-wasm-tests/Cargo.toml2
53 files changed, 1035 insertions, 1051 deletions
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 00000000..05bcdac6
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,7 @@
+version: 2
+updates:
+ - package-ecosystem: "cargo"
+ directory: "/"
+ schedule:
+ interval: "weekly"
+ open-pull-requests-limit: 5
diff --git a/Cargo.toml b/Cargo.toml
index 7d61cd74..2225c42e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,9 +31,17 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
+candle = { path = "./candle-core", package = "candle-core" }
+candle-datasets = { path = "./candle-datasets" }
+candle-flash-attn = { path = "./candle-flash-attn" }
+candle-kernels = { path = "./candle-kernels" }
+candle-metal-kernels = { path = "./candle-metal-kernels" }
+candle-nn = { path = "./candle-nn" }
+candle-onnx = { path = "./candle-onnx" }
+candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.9.14", features = ["f16"] }
+cudarc = { version = "0.10.0", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
@@ -42,7 +50,7 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
-memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
+memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "45.0.0" }
@@ -55,7 +63,7 @@ serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
-tokenizers = { version = "0.13.4", default-features = false }
+tokenizers = { version = "0.15.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml
index e28e6623..5ccda31e 100644
--- a/candle-book/Cargo.toml
+++ b/candle-book/Cargo.toml
@@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index afdb67cd..92a04917 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -12,8 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
+candle-kernels = { workspace = true, optional = true }
+candle-metal-kernels = { workspace = true, optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
diff --git a/candle-core/benches/benchmarks/matmul.rs b/candle-core/benches/benchmarks/matmul.rs
index fb173f04..9d67e642 100644
--- a/candle-core/benches/benchmarks/matmul.rs
+++ b/candle-core/benches/benchmarks/matmul.rs
@@ -1,5 +1,5 @@
-use crate::benchmarks::{bench_name, device, BenchDevice};
-use candle_core::{DType, Tensor};
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
@@ -7,20 +7,19 @@ fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
-fn criterion_benchmark(c: &mut Criterion) {
+fn run_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
- let device = device().unwrap();
let dtype = DType::F32;
- let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
- let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
+ let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap();
+ let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap();
let flops = b * m * n * k;
- let mut group = c.benchmark_group(bench_name("matmul"));
+ let mut group = c.benchmark_group(device.bench_name("matmul"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
@@ -35,4 +34,11 @@ fn criterion_benchmark(c: &mut Criterion) {
group.finish();
}
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_bench(c, &device);
+ }
+}
+
criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index 6bb37a70..eb20ea70 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -5,6 +5,8 @@ use candle_core::{Device, Result};
pub(crate) trait BenchDevice {
fn sync(&self) -> Result<()>;
+
+ fn bench_name<S: Into<String>>(&self, name: S) -> String;
}
impl BenchDevice for Device {
@@ -25,32 +27,38 @@ impl BenchDevice for Device {
}
}
}
-}
-pub(crate) fn device() -> Result<Device> {
- if cfg!(feature = "metal") {
- Device::new_metal(0)
- } else if cfg!(feature = "cuda") {
- Device::new_cuda(0)
- } else {
- Ok(Device::Cpu)
+ fn bench_name<S: Into<String>>(&self, name: S) -> String {
+ match self {
+ Device::Cpu => {
+ let cpu_type = if cfg!(feature = "accelerate") {
+ "accelerate"
+ } else if cfg!(feature = "mkl") {
+ "mkl"
+ } else {
+ "cpu"
+ };
+ format!("{}_{}", cpu_type, name.into())
+ }
+ Device::Cuda(_) => format!("cuda_{}", name.into()),
+ Device::Metal(_) => format!("metal_{}", name.into()),
+ }
}
}
-pub(crate) fn bench_name<S: Into<String>>(name: S) -> String {
- format!("{}_{}", device_variant(), name.into())
+struct BenchDeviceHandler {
+ devices: Vec<Device>,
}
-const fn device_variant() -> &'static str {
- if cfg!(feature = "metal") {
- "metal"
- } else if cfg!(feature = "cuda") {
- "cuda"
- } else if cfg!(feature = "accelerate") {
- "accelerate"
- } else if cfg!(feature = "mkl") {
- "mkl"
- } else {
- "cpu"
+impl BenchDeviceHandler {
+ pub fn new() -> Result<Self> {
+ let mut devices = Vec::new();
+ if cfg!(feature = "metal") {
+ devices.push(Device::new_metal(0)?);
+ } else if cfg!(feature = "cuda") {
+ devices.push(Device::new_cuda(0)?);
+ }
+ devices.push(Device::Cpu);
+ Ok(Self { devices })
}
}
diff --git a/candle-core/benches/benchmarks/random.rs b/candle-core/benches/benchmarks/random.rs
index e4a4a390..22c60ef1 100644
--- a/candle-core/benches/benchmarks/random.rs
+++ b/candle-core/benches/benchmarks/random.rs
@@ -1,4 +1,4 @@
-use crate::benchmarks::{bench_name, device, BenchDevice};
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
@@ -11,19 +11,18 @@ fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
-fn criterion_benchmark(c: &mut Criterion) {
+fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
- let d = device().unwrap();
let dtype = DType::F32;
- let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap();
+ let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
- let mut group = c.benchmark_group(bench_name("random_uniform"));
+ let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
@@ -31,16 +30,15 @@ fn criterion_benchmark(c: &mut Criterion) {
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
- d.sync().unwrap();
+ device.sync().unwrap();
start.elapsed()
})
});
group.finish();
- let d = device().unwrap();
- let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap();
+ let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
- let mut group = c.benchmark_group(bench_name("random_normal"));
+ let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
@@ -48,11 +46,18 @@ fn criterion_benchmark(c: &mut Criterion) {
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
- d.sync().unwrap();
+ device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
+fn criterion_benchmark(c: &mut Criterion) {
+ let handler = BenchDeviceHandler::new().unwrap();
+ for device in handler.devices {
+ run_random_bench(c, &device);
+ }
+}
+
criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index d06b30d1..337021aa 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -102,7 +102,7 @@ enum Command {
},
Quantize {
- /// The input file, in gguf format.
+ /// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
@@ -117,6 +117,15 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
+
+ Dequantize {
+ /// The input file, in gguf format.
+ in_file: std::path::PathBuf,
+
+ /// The output file, in safetensors format.
+ #[arg(long)]
+ out_file: std::path::PathBuf,
+ },
}
#[derive(Parser, Debug, Clone)]
@@ -285,6 +294,19 @@ fn run_quantize_safetensors(
Ok(())
}
+fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+ let mut in_file = std::fs::File::open(in_file)?;
+ let content = gguf_file::Content::read(&mut in_file)?;
+ let mut tensors = std::collections::HashMap::new();
+ for (tensor_name, _) in content.tensor_infos.iter() {
+ let tensor = content.tensor(&mut in_file, tensor_name)?;
+ let tensor = tensor.dequantize(&Device::Cpu)?;
+ tensors.insert(tensor_name.to_string(), tensor);
+ }
+ candle_core::safetensors::save(&tensors, out_file)?;
+ Ok(())
+}
+
fn run_quantize(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
@@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode)?,
+ Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
}
Ok(())
}
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 24beeb7a..48250233 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -592,14 +592,26 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
+ (DType::U32, DType::BF16) => "cast_u32_bf16",
+
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
+ (DType::U8, DType::BF16) => "cast_u8_bf16",
+
(DType::F32, DType::F16) => "cast_f32_f16",
- (DType::F16, DType::F32) => "cast_f16_f32",
- (DType::I64, DType::F32) => "cast_i64_f32",
(DType::F32, DType::BF16) => "cast_f32_bf16",
+
+ (DType::I64, DType::F32) => "cast_i64_f32",
+
+ (DType::F16, DType::BF16) => "cast_f16_bf16",
+ (DType::F16, DType::F32) => "cast_f16_f32",
+
+ (DType::BF16, DType::U8) => "cast_bf16_u8",
+ (DType::BF16, DType::U32) => "cast_bf16_u32",
+ (DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
+
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
@@ -677,6 +689,7 @@ impl BackendStorage for MetalStorage {
("uround", DType::F32) => contiguous::round::FLOAT,
("urecip", DType::F32) => contiguous::recip::FLOAT,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
+ ("urelu", DType::F32) => contiguous::relu::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
@@ -693,6 +706,7 @@ impl BackendStorage for MetalStorage {
("uround", DType::F16) => contiguous::round::HALF,
("urecip", DType::F16) => contiguous::recip::HALF,
("utanh", DType::F16) => contiguous::tanh::HALF,
+ ("urelu", DType::F16) => contiguous::relu::HALF,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
@@ -723,6 +737,7 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
+ ("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
@@ -737,6 +752,7 @@ impl BackendStorage for MetalStorage {
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
+ ("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
@@ -1129,8 +1145,12 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
+ (DType::U8, DType::BF16) => "is_u8_bf16",
+
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
+ (DType::U32, DType::BF16) => "is_u32_bf16",
+
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
@@ -1320,6 +1340,7 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
+
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@@ -1330,6 +1351,18 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
+
+ ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
+ ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
+ ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
+ ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
+ ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
+ ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
+ ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
+ ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
+ ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
+ ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
+
("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@@ -1340,6 +1373,7 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
+
("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@@ -1350,6 +1384,7 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
+
("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@@ -1360,6 +1395,7 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
+
(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
@@ -1393,6 +1429,7 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
+
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@@ -1405,6 +1442,20 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
+
+ ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
+ ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
+ ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
+ ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
+ ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
+ ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
+ ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
+ ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
+ ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
+ ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
+ ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
+ ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
+
("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@@ -1417,6 +1468,7 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),
+
("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@@ -1429,6 +1481,7 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),
+
("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@@ -1441,6 +1494,7 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),
+
(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 25640d1a..276b30e3 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -703,6 +703,7 @@ impl PthTensors {
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+ use std::io::Read;
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
@@ -712,14 +713,21 @@ impl PthTensors {
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
- // Reading the data is a bit tricky as it can be strided, use an offset, etc.
- // For now only support the basic case.
- if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
+ // Reading the data is a bit tricky as it can be strided, for now only support the basic
+ // case.
+ if !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
+ let start_offset = tensor_info.layout.start_offset();
+ if start_offset > 0 {
+ std::io::copy(
+ &mut reader.by_ref().take(start_offset as u64),
+ &mut std::io::sink(),
+ )?;
+ }
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index d16289e6..6210ac1e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
let d2 = d * sc as f32;
let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) {
- let to_add = if qh & u1 != 0 { 16 } else { 1 };
- y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
+ let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
+ y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
ys_index += 1;
}
for (ql, qh) in ql.iter().zip(qh) {
- let to_add = if qh & u2 != 0 { 16 } else { 1 };
- y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
+ let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
+ y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
ys_index += 1;
}
is += 2;
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 3cb56229..c4d5d6f4 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -13,6 +13,14 @@ use core::arch::arm::*;
use core::arch::aarch64::*;
#[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+ // TODO: dotprod
+ let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+ let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+ vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
+
+#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
@@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
- // TODO: Support dotprod when it's available outside of nightly.
- let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
- let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
- let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
- let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
- let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-
+ let pl0 = vdotq_s32(v0_0ls, v1_0l);
+ let ph0 = vdotq_s32(v0_0hs, v1_0h);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
- // TODO dotprod once this is the intrinsics are.
- let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
- let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
- let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
- let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
- let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
- let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+ let p0 = vdotq_s32(x0_0, y0_0);
+ let p1 = vdotq_s32(x0_1, y0_1);
sumv0 = vmlaq_n_f32(
sumv0,
@@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
- let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
- let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-
- let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+ let xy = vdotq_s32(xs, ys);
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
- // TODO: dotprod
-
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
- );
+ let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+ let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+ isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
- vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
- vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
- );
+ let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+ let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+ isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
- // TODO: dotprod case.
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
- );
+ let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+ let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+ isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
- vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
- vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
- );
+ let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+ let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+ isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
- // TODO: dotprod
-
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
- );
- sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+ let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+ let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+ sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
scales = scales.add(1);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
- vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
- vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
- );
- sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+ let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+ let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+ sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
- // TODO: dotprod
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
- );
- sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+ let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+ let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+ sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
- );
- sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+ let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+ let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+ sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
- // TODO: dotprod
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
- vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
- vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
- );
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
- vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
- vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
- );
- isum += vaddvq_s16(p0) as i32 * *scale as i32
- + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
- + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
- + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+ let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+ let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+ let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+ let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+ isum += vaddvq_s32(p0) * *scale as i32
+ + vaddvq_s32(p1) * *scale.add(1) as i32
+ + vaddvq_s32(p2) * *scale.add(2) as i32
+ + vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
- // TODO: dotprod
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
- vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
- vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
- );
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
- vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
- vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
- );
- isum += vaddvq_s16(p0) as i32 * *scale as i32
- + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
- + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
- + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+ let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+ let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+ let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+ let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+ isum += vaddvq_s32(p0) * *scale as i32
+ + vaddvq_s32(p1) * *scale.add(1) as i32
+ + vaddvq_s32(p2) * *scale.add(2) as i32
+ + vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
let mut is = 0usize;
// TODO: dotprod
-
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
- );
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
- );
- vaddvq_s16(p1) as i32 * aux[is + index] as i32
- + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+ let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+ let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+ vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 716cca8d..d31e77a7 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,4 +1,5 @@
use candle_core::{
+ bail,
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
@@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
}
}
-/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
+/// Creates a vector similar to the ones used in GGML unit tests:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
(0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
sum / a.len() as f32
}
-/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
+/// Similar to the GGML quantization unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error {
- candle_core::bail!(
+ bail!(
"Quantization error {} exceeds max error {}",
error,
max_error
@@ -404,7 +407,7 @@ fn quantize_q5k() -> Result<()> {
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
- [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
+ [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
@@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143,
- GgmlDType::Q4_1 => 0.007784,
+ GgmlDType::Q4_1 => 0.008,
GgmlDType::Q5_0 => 0.001353,
- GgmlDType::Q5_1 => 0.001363,
+ GgmlDType::Q5_1 => 0.00149,
GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
- _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
+ _ => bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
}
-/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
+/// Similar to the GGML matmul unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0);
+ ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
+ // Another example that is more likely to trigger the overflow reported in #1526
+ let a = (0..GGML_TEST_SIZE)
+ .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+ .collect::<Vec<_>>();
+ let b = (0..GGML_TEST_SIZE)
+ .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+ .collect::<Vec<_>>();
+ ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
+ Ok(())
+}
+
+fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
- T::from_float(&a, &mut a_quant)?;
- T::VecDotType::from_float(&b, &mut b_quant)?;
+ T::from_float(a, &mut a_quant)?;
+ T::VecDotType::from_float(b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
- let reference_result = vec_dot_reference(&a, &b);
+ let reference_result = vec_dot_reference(a, b);
if (result - result_unopt).abs() / length as f32 > 1e-6 {
- candle_core::bail!(
+ bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
)
}
let error = (result - reference_result).abs() / length as f32;
- let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
+ let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
- candle_core::bail!(
- "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
- );
+ bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
- candle_core::bail!(
+ bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
@@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
Ok(())
}
+#[test]
+fn quantized_mm() -> Result<()> {
+ ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
+ Ok(())
+}
+
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors(
m: usize,
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index 69438e0e..ccabf7ed 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 7e081530..00340d08 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -11,12 +11,12 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
-candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
+candle-onnx = { workspace = true, optional = true }
csv = "1.3.0"
cudarc = { workspace = true, optional = true }
@@ -49,11 +49,12 @@ tokio = "1.29.1"
[build-dependencies]
anyhow = { workspace = true }
+bindgen_cuda = { version = "0.1.1", optional = true }
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
diff --git a/candle-examples/build.rs b/candle-examples/build.rs
index 0af3a6a4..ba40aeb4 100644
--- a/candle-examples/build.rs
+++ b/candle-examples/build.rs
@@ -4,251 +4,34 @@ use std::io::Write;
use std::path::PathBuf;
struct KernelDirectories {
- kernel_dir: &'static str,
+ kernel_glob: &'static str,
rust_target: &'static str,
include_dirs: &'static [&'static str],
}
-const DIRS: [KernelDirectories; 1] = [KernelDirectories {
- kernel_dir: "examples/custom-ops/kernels/",
+const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
+ kernel_glob: "examples/custom-ops/kernels/*.cu",
rust_target: "examples/custom-ops/cuda_kernels.rs",
include_dirs: &[],
}];
-impl KernelDirectories {
- fn maybe_build_ptx(
- &self,
- cu_file: &std::path::Path,
- ptx_file: &std::path::Path,
- compute_cap: usize,
- ) -> Result<()> {
- let should_compile = if ptx_file.exists() {
- let ptx_modified = ptx_file.metadata()?.modified()?;
- let cu_modified = cu_file.metadata()?.modified()?;
- cu_modified.duration_since(ptx_modified).is_ok()
- } else {
- true
- };
- if should_compile {
- #[cfg(feature = "cuda")]
- {
- let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
- println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
- let mut command = std::process::Command::new("nvcc");
- let out_dir = ptx_file.parent().context("no parent for ptx file")?;
- let include_dirs: Vec<String> =
- self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
- command
- .arg(format!("--gpu-architecture=sm_{compute_cap}"))
- .arg("--ptx")
- .args(["--default-stream", "per-thread"])
- .args(["--output-directory", out_dir.to_str().unwrap()])
- .arg(format!("-I/{}", self.kernel_dir))
- .args(include_dirs)
- .arg(cu_file);
- if let Ok(ccbin_path) = &ccbin_env {
- command
- .arg("-allow-unsupported-compiler")
- .args(["-ccbin", ccbin_path]);
- }
- let output = command
- .spawn()
- .context("failed spawning nvcc")?
- .wait_with_output()?;
- if !output.status.success() {
- anyhow::bail!(
- "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- )
- }
- }
- #[cfg(not(feature = "cuda"))]
- std::fs::OpenOptions::new()
- .create(true)
- .write(true)
- .open(ptx_file)?;
- }
- Ok(())
- }
- fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
- println!("cargo:rerun-if-changed={}", self.kernel_dir);
- let kernel_dir = PathBuf::from(self.kernel_dir);
- let out_dir = out_dir.join(self.kernel_dir);
- if !out_dir.exists() {
- std::fs::create_dir_all(&out_dir)?;
- }
- let mut cu_files = vec![];
- let mut cuh_files = vec![];
- for file in std::fs::read_dir(kernel_dir)?.flatten() {
- let file = file.path();
- match file.extension().and_then(|v| v.to_str()) {
- Some("cu") => cu_files.push(file),
- Some("cuh") => cuh_files.push(file),
- _ => {}
- }
- }
-
- let mut ptx_paths = vec![];
- for cu_file in cu_files.iter() {
- let file_stem = cu_file
- .file_stem()
- .with_context(|| format!("no stem {cu_file:?}"))?;
- let file_stem = file_stem.to_string_lossy().into_owned();
- let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
- self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
- ptx_paths.push(ptx_file);
- }
-
- let regenerate_rs_file = true;
- if regenerate_rs_file {
- let mut file = std::fs::File::create(self.rust_target)?;
- for ptx_path in ptx_paths {
- let name = ptx_path
- .file_stem()
- .context("empty stem")?
- .to_string_lossy();
- file.write_all(b"#[rustfmt::skip]\n")?;
- let const_definition = format!(
- r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
- name.to_uppercase().replace('.', "_"),
- self.kernel_dir,
- );
- file.write_all(const_definition.as_bytes())?;
- file.write_all(b"\n")?;
- }
- }
- Ok(())
- }
-}
-
fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
- let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
- let out_dir = PathBuf::from(out_dir);
- #[cfg(feature = "cuda")]
- set_cuda_include_dir()?;
#[cfg(feature = "cuda")]
- let compute_cap = compute_cap()?;
- #[cfg(not(feature = "cuda"))]
- let compute_cap = 0;
- for d in DIRS {
- d.process(&out_dir, compute_cap)?
- }
- Ok(())
-}
-
-fn set_cuda_include_dir() -> Result<()> {
- // NOTE: copied from cudarc build.rs.
- let env_vars = [
- "CUDA_PATH",
- "CUDA_ROOT",
- "CUDA_TOOLKIT_ROOT_DIR",
- "CUDNN_LIB",
- ];
- let env_vars = env_vars
- .into_iter()
- .map(std::env::var)
- .filter_map(Result::ok)
- .map(Into::<PathBuf>::into);
-
- let roots = [
- "/usr",
- "/usr/local/cuda",
- "/opt/cuda",
- "/usr/lib/cuda",
- "C:/Program Files/NVIDIA GPU Computing Toolkit",
- "C:/CUDA",
- ];
- let roots = roots.into_iter().map(Into::<PathBuf>::into);
- let root = env_vars
- .chain(roots)
- .find(|path| path.join("include").join("cuda.h").is_file())
- .context("cannot find include/cuda.h")?;
- println!(
- "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
- root.join("include").display()
- );
- Ok(())
-}
-
-#[allow(unused)]
-fn compute_cap() -> Result<usize> {
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
- // Try to parse compute cap from env
- let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
- compute_cap_str
- .parse::<usize>()
- .context("Could not parse code")?
- } else {
- // Grab compute cap from nvidia-smi
- let out = std::process::Command::new("nvidia-smi")
- .arg("--query-gpu=compute_cap")
- .arg("--format=csv")
- .output()
- .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
- let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
- let mut lines = out.lines();
- assert_eq!(
- lines.next().context("missing line in stdout")?,
- "compute_cap"
- );
- let cap = lines
- .next()
- .context("missing line in stdout")?
- .replace('.', "");
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
- cap.parse::<usize>()
- .with_context(|| format!("cannot parse as int {cap}"))?
- };
-
- // Grab available GPU codes from nvcc and select the highest one
- let max_nvcc_code = {
- let out = std::process::Command::new("nvcc")
- .arg("--list-gpu-code")
- .output()
- .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- let out = std::str::from_utf8(&out.stdout).unwrap();
-
- let out = out.lines().collect::<Vec<&str>>();
- let mut codes = Vec::with_capacity(out.len());
- for code in out {
- let code = code.split('_').collect::<Vec<&str>>();
- if !code.is_empty() && code.contains(&"sm") {
- if let Ok(num) = code[1].parse::<usize>() {
- codes.push(num);
- }
- }
+ {
+ for kdir in KERNEL_DIRS.iter() {
+ let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
+ println!("cargo:info={builder:?}");
+ let bindings = builder.build_ptx().unwrap();
+ bindings.write(kdir.rust_target).unwrap()
}
- codes.sort();
- if !codes.contains(&compute_cap) {
- anyhow::bail!(
- "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
- );
- }
- *codes.last().unwrap()
- };
-
- // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
- // then choose the highest gpu code in nvcc
- if compute_cap > max_nvcc_code {
- println!(
- "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
- );
- compute_cap = max_nvcc_code;
}
-
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
- if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- compute_cap = compute_cap_str
- .parse::<usize>()
- .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
- println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+ #[cfg(not(feature = "cuda"))]
+ {
+ for kdir in KERNEL_DIRS.iter() {
+ let _file = std::fs::File::create(kdir.rust_target)?;
+ }
}
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
- Ok(compute_cap)
+ Ok(())
}
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
index 0bee73aa..c00b601b 100644
--- a/candle-examples/examples/custom-ops/cuda_kernels.rs
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -1,2 +1 @@
-#[rustfmt::skip]
-pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index f2f534dc..30e413c1 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -6,7 +6,8 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-#[allow(unused)]
+#[rustfmt::skip]
+#[cfg(feature = "cuda")]
mod cuda_kernels;
use clap::Parser;
diff --git a/candle-examples/examples/repvgg/README.md b/candle-examples/examples/repvgg/README.md
new file mode 100644
index 00000000..2cb807c1
--- /dev/null
+++ b/candle-examples/examples/repvgg/README.md
@@ -0,0 +1,20 @@
+# candle-repvgg
+
+A candle implementation of inference using a pre-trained [repvgg](https://arxiv.org/abs/2101.03697).
+This uses a classification head trained on the ImageNet dataset and returns the
+probabilities for the top-5 classes.
+
+## Running an example
+
+```
+$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
+
+loaded image Tensor[dims 3, 224, 224; f32]
+model built
+mountain bike, all-terrain bike, off-roader: 61.70%
+bicycle-built-for-two, tandem bicycle, tandem: 33.14%
+unicycle, monocycle : 4.88%
+crash helmet : 0.15%
+moped : 0.04%
+
+```
diff --git a/candle-examples/examples/repvgg/main.rs b/candle-examples/examples/repvgg/main.rs
new file mode 100644
index 00000000..0864c559
--- /dev/null
+++ b/candle-examples/examples/repvgg/main.rs
@@ -0,0 +1,111 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::repvgg;
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+ A0,
+ A1,
+ A2,
+ B0,
+ B1,
+ B2,
+ B3,
+ B1G4,
+ B2G4,
+ B3G4,
+}
+
+impl Which {
+ fn model_filename(&self) -> String {
+ let name = match self {
+ Self::A0 => "a0",
+ Self::A1 => "a1",
+ Self::A2 => "a2",
+ Self::B0 => "b0",
+ Self::B1 => "b1",
+ Self::B2 => "b2",
+ Self::B3 => "b3",
+ Self::B1G4 => "b1g4",
+ Self::B2G4 => "b2g4",
+ Self::B3G4 => "b3g4",
+ };
+ format!("timm/repvgg_{}.rvgg_in1k", name)
+ }
+
+ fn config(&self) -> repvgg::Config {
+ match self {
+ Self::A0 => repvgg::Config::a0(),
+ Self::A1 => repvgg::Config::a1(),
+ Self::A2 => repvgg::Config::a2(),
+ Self::B0 => repvgg::Config::b0(),
+ Self::B1 => repvgg::Config::b1(),
+ Self::B2 => repvgg::Config::b2(),
+ Self::B3 => repvgg::Config::b3(),
+ Self::B1G4 => repvgg::Config::b1g4(),
+ Self::B2G4 => repvgg::Config::b2g4(),
+ Self::B3G4 => repvgg::Config::b3g4(),
+ }
+ }
+}
+
+#[derive(Parser)]
+struct Args {
+ #[arg(long)]
+ model: Option<String>,
+
+ #[arg(long)]
+ image: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ #[arg(value_enum, long, default_value_t=Which::A0)]
+ which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+ let args = Args::parse();
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let image = candle_examples::imagenet::load_image224(args.image)?;
+ println!("loaded image {image:?}");
+
+ let model_file = match args.model {
+ None => {
+ let model_name = args.which.model_filename();
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model(model_name);
+ api.get("model.safetensors")?
+ }
+ Some(model) => model.into(),
+ };
+
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+ let model = repvgg::repvgg(&args.which.config(), 1000, vb)?;
+ println!("model built");
+ let logits = model.forward(&image.unsqueeze(0)?)?;
+ let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+ .i(0)?
+ .to_vec1::<f32>()?;
+ let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+ prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+ for &(category_idx, pr) in prs.iter().take(5) {
+ println!(
+ "{:24}: {:.2}%",
+ candle_examples::imagenet::CLASSES[category_idx],
+ 100. * pr
+ );
+ }
+ Ok(())
+}
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 64e690e6..d8e8da82 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
-candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
+bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] }
-num_cpus = "1.15.0"
-rayon = "1.7.0"
+
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
-candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
+candle-nn = { path = "../candle-nn", features = ["cuda"] }
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index fde3aeed..4002770b 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -2,44 +2,32 @@
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
-use rayon::prelude::*;
use std::path::PathBuf;
-use std::str::FromStr;
const KERNEL_FILES: [&str; 17] = [
- "flash_api.cu",
- "flash_fwd_hdim128_fp16_sm80.cu",
- "flash_fwd_hdim160_fp16_sm80.cu",
- "flash_fwd_hdim192_fp16_sm80.cu",
- "flash_fwd_hdim224_fp16_sm80.cu",
- "flash_fwd_hdim256_fp16_sm80.cu",
- "flash_fwd_hdim32_fp16_sm80.cu",
- "flash_fwd_hdim64_fp16_sm80.cu",
- "flash_fwd_hdim96_fp16_sm80.cu",
- "flash_fwd_hdim128_bf16_sm80.cu",
- "flash_fwd_hdim160_bf16_sm80.cu",
- "flash_fwd_hdim192_bf16_sm80.cu",
- "flash_fwd_hdim224_bf16_sm80.cu",
- "flash_fwd_hdim256_bf16_sm80.cu",
- "flash_fwd_hdim32_bf16_sm80.cu",
- "flash_fwd_hdim64_bf16_sm80.cu",
- "flash_fwd_hdim96_bf16_sm80.cu",
+ "kernels/flash_api.cu",
+ "kernels/flash_fwd_hdim128_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim160_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim192_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim224_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim256_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim32_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim64_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim96_fp16_sm80.cu",
+ "kernels/flash_fwd_hdim128_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim160_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim192_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim224_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim256_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim32_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim64_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim96_bf16_sm80.cu",
];
fn main() -> Result<()> {
- let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
- |_| num_cpus::get_physical(),
- |s| usize::from_str(&s).unwrap(),
- );
-
- rayon::ThreadPoolBuilder::new()
- .num_threads(num_cpus)
- .build_global()
- .unwrap();
-
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
- println!("cargo:rerun-if-changed=kernels/{kernel_file}");
+ println!("cargo:rerun-if-changed={kernel_file}");
}
println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
@@ -66,223 +54,30 @@ fn main() -> Result<()> {
))
}
};
- set_cuda_include_dir()?;
-
- let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
- println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
- let compute_cap = compute_cap()?;
+ let kernels = KERNEL_FILES.iter().collect();
+ let builder = bindgen_cuda::Builder::default()
+ .kernel_paths(kernels)
+ .out_dir(build_dir.clone())
+ .arg("-std=c++17")
+ .arg("-O3")
+ .arg("-U__CUDA_NO_HALF_OPERATORS__")
+ .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+ .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+ .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+ .arg("-Icutlass/include")
+ .arg("--expt-relaxed-constexpr")
+ .arg("--expt-extended-lambda")
+ .arg("--use_fast_math")
+ .arg("--verbose");
let out_file = build_dir.join("libflashattention.a");
+ builder.build_lib(out_file);
- let kernel_dir = PathBuf::from("kernels");
- let cu_files: Vec<_> = KERNEL_FILES
- .iter()
- .map(|f| {
- let mut obj_file = out_dir.join(f);
- obj_file.set_extension("o");
- (kernel_dir.join(f), obj_file)
- })
- .collect();
- let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
- let should_compile = if out_file.exists() {
- kernel_dir
- .read_dir()
- .expect("kernels folder should exist")
- .any(|entry| {
- if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
- let in_modified = entry.metadata().unwrap().modified().unwrap();
- in_modified.duration_since(*out_modified).is_ok()
- } else {
- true
- }
- })
- } else {
- true
- };
- if should_compile {
- cu_files
- .par_iter()
- .map(|(cu_file, obj_file)| {
- let mut command = std::process::Command::new("nvcc");
- command
- .arg("-std=c++17")
- .arg("-O3")
- .arg("-U__CUDA_NO_HALF_OPERATORS__")
- .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
- .arg("-U__CUDA_NO_HALF2_OPERATORS__")
- .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
- .arg(format!("--gpu-architecture=sm_{compute_cap}"))
- .arg("-c")
- .args(["-o", obj_file.to_str().unwrap()])
- .args(["--default-stream", "per-thread"])
- .arg("-Icutlass/include")
- .arg("--expt-relaxed-constexpr")
- .arg("--expt-extended-lambda")
- .arg("--use_fast_math")
- .arg("--verbose");
- if let Ok(ccbin_path) = &ccbin_env {
- command
- .arg("-allow-unsupported-compiler")
- .args(["-ccbin", ccbin_path]);
- }
- command.arg(cu_file);
- let output = command
- .spawn()
- .context("failed spawning nvcc")?
- .wait_with_output()?;
- if !output.status.success() {
- anyhow::bail!(
- "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- &command,
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- )
- }
- Ok(())
- })
- .collect::<Result<()>>()?;
- let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
- let mut command = std::process::Command::new("nvcc");
- command
- .arg("--lib")
- .args(["-o", out_file.to_str().unwrap()])
- .args(obj_files);
- let output = command
- .spawn()
- .context("failed spawning nvcc")?
- .wait_with_output()?;
- if !output.status.success() {
- anyhow::bail!(
- "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- &command,
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- )
- }
- }
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
- /* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
- finishing to run for some reason. Calling nvcc manually worked fine.
- cc::Build::new()
- .cuda(true)
- .include("cutlass/include")
- .flag("--expt-relaxed-constexpr")
- .flag("--default-stream")
- .flag("per-thread")
- .flag(&format!("--gpu-architecture=sm_{compute_cap}"))
- .file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
- .compile("flashattn");
- */
Ok(())
}
-
-fn set_cuda_include_dir() -> Result<()> {
- // NOTE: copied from cudarc build.rs.
- let env_vars = [
- "CUDA_PATH",
- "CUDA_ROOT",
- "CUDA_TOOLKIT_ROOT_DIR",
- "CUDNN_LIB",
- ];
- let env_vars = env_vars
- .into_iter()
- .map(std::env::var)
- .filter_map(Result::ok)
- .map(Into::<PathBuf>::into);
-
- let roots = [
- "/usr",
- "/usr/local/cuda",
- "/opt/cuda",
- "/usr/lib/cuda",
- "C:/Program Files/NVIDIA GPU Computing Toolkit",
- "C:/CUDA",
- ];
- let roots = roots.into_iter().map(Into::<PathBuf>::into);
- let root = env_vars
- .chain(roots)
- .find(|path| path.join("include").join("cuda.h").is_file())
- .context("cannot find include/cuda.h")?;
- println!(
- "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
- root.join("include").display()
- );
- Ok(())
-}
-
-#[allow(unused)]
-fn compute_cap() -> Result<usize> {
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
- // Try to parse compute caps from env
- let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
- compute_cap_str
- .parse::<usize>()
- .context("Could not parse compute cap")?
- } else {
- // Use nvidia-smi to get the current compute cap
- let out = std::process::Command::new("nvidia-smi")
- .arg("--query-gpu=compute_cap")
- .arg("--format=csv")
- .output()
- .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
- let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
- let mut lines = out.lines();
- assert_eq!(
- lines.next().context("missing line in stdout")?,
- "compute_cap"
- );
- let cap = lines
- .next()
- .context("missing line in stdout")?
- .replace('.', "");
- let cap = cap
- .parse::<usize>()
- .with_context(|| format!("cannot parse as int {cap}"))?;
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
- cap
- };
-
- // Grab available GPU codes from nvcc and select the highest one
- let (supported_nvcc_codes, max_nvcc_code) = {
- let out = std::process::Command::new("nvcc")
- .arg("--list-gpu-code")
- .output()
- .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- let out = std::str::from_utf8(&out.stdout).unwrap();
-
- let out = out.lines().collect::<Vec<&str>>();
- let mut codes = Vec::with_capacity(out.len());
- for code in out {
- let code = code.split('_').collect::<Vec<&str>>();
- if !code.is_empty() && code.contains(&"sm") {
- if let Ok(num) = code[1].parse::<usize>() {
- codes.push(num);
- }
- }
- }
- codes.sort();
- let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
- (codes, max_nvcc_code)
- };
-
- // Check that nvcc supports the asked compute caps
- if !supported_nvcc_codes.contains(&compute_cap) {
- anyhow::bail!(
- "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
- );
- }
- if compute_cap > max_nvcc_code {
- anyhow::bail!(
- "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
- );
- }
-
- Ok(compute_cap)
-}
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index e81fe39c..0cd4a14d 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
-anyhow = { version = "1", features = ["backtrace"] }
-glob = "0.3.1"
-rayon = "1.7.0"
+bindgen_cuda = "0.1.1"
diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs
index 17a0bf9c..63d744ca 100644
--- a/candle-kernels/build.rs
+++ b/candle-kernels/build.rs
@@ -1,243 +1,8 @@
-use std::io::Write;
-
fn main() {
println!("cargo:rerun-if-changed=build.rs");
- cuda::set_include_dir();
- let (write, kernel_paths) = cuda::build_ptx();
- if write {
- let mut file = std::fs::File::create("src/lib.rs").unwrap();
- for kernel_path in kernel_paths {
- let name = kernel_path.file_stem().unwrap().to_str().unwrap();
- file.write_all(
- format!(
- r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
- name.to_uppercase().replace('.', "_"),
- name
- )
- .as_bytes(),
- )
- .unwrap();
- file.write_all(&[b'\n']).unwrap();
- }
- }
-}
-
-mod cuda {
- use anyhow::{Context, Result};
-
- pub fn set_include_dir() {
- use std::path::PathBuf;
- // NOTE: copied from cudarc build.rs.
- // We can't actually set a env!() value from another crate,
- // so we have to do that here.
-
- // use PathBuf;
-
- let env_vars = [
- "CUDA_PATH",
- "CUDA_ROOT",
- "CUDA_TOOLKIT_ROOT_DIR",
- "CUDNN_LIB",
- ];
- #[allow(unused)]
- let env_vars = env_vars
- .into_iter()
- .map(std::env::var)
- .filter_map(Result::ok)
- .map(Into::<PathBuf>::into);
-
- let roots = [
- "/usr",
- "/usr/local/cuda",
- "/opt/cuda",
- "/usr/lib/cuda",
- "C:/Program Files/NVIDIA GPU Computing Toolkit",
- "C:/CUDA",
- ];
- #[allow(unused)]
- let roots = roots.into_iter().map(Into::<PathBuf>::into);
-
- #[cfg(feature = "ci-check")]
- let root: PathBuf = "ci".into();
-
- #[cfg(not(feature = "ci-check"))]
- let root = env_vars
- .chain(roots)
- .find(|path| path.join("include").join("cuda.h").is_file())
- .unwrap();
-
- println!(
- "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
- root.join("include").display()
- );
- }
-
- pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
- use rayon::prelude::*;
- use std::path::PathBuf;
- let out_dir = std::env::var("OUT_DIR").unwrap();
- let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
- .unwrap()
- .map(|p| p.unwrap())
- .collect();
- let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
- .unwrap()
- .map(|p| p.unwrap())
- .collect();
-
- println!("cargo:rerun-if-changed=src/");
- // for path in &kernel_paths {
- // println!("cargo:rerun-if-changed={}", path.display());
- // }
-
- for path in &mut include_directories {
- // println!("cargo:rerun-if-changed={}", path.display());
- let destination =
- std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
- std::fs::copy(path.clone(), destination).unwrap();
- // remove the filename from the path so it's just the directory
- path.pop();
- }
-
- include_directories.sort();
- include_directories.dedup();
-
- let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
-
- #[allow(unused)]
- let include_options: Vec<String> = include_directories
- .into_iter()
- .map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
- .collect::<Vec<_>>();
-
- let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
- println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
- let children = kernel_paths
- .par_iter()
- .flat_map(|p| {
- let mut output = p.clone();
- output.set_extension("ptx");
- let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
-
- let ignore = if output_filename.exists() {
- let out_modified = output_filename.metadata().unwrap().modified().unwrap();
- let in_modified = p.metadata().unwrap().modified().unwrap();
- out_modified.duration_since(in_modified).is_ok()
- } else {
- false
- };
- if ignore {
- None
- } else {
- let mut command = std::process::Command::new("nvcc");
- command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
- .arg("--ptx")
- .args(["--default-stream", "per-thread"])
- .args(["--output-directory", &out_dir])
- // Flash attention only
- // .arg("--expt-relaxed-constexpr")
- .args(&include_options);
- if let Ok(ccbin_path) = &ccbin_env {
- command
- .arg("-allow-unsupported-compiler")
- .args(["-ccbin", ccbin_path]);
- }
- command.arg(p);
- Some((p, command.spawn()
- .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
- }
- })
- .collect::<Vec<_>>();
-
- let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
- .unwrap()
- .map(|p| p.unwrap())
- .collect();
- // We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
- // some old ones
- let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
- for (kernel_path, child) in children {
- let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- assert!(
- output.status.success(),
- "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
- String::from_utf8_lossy(&output.stdout),
- String::from_utf8_lossy(&output.stderr)
- );
- }
- (write, kernel_paths)
- }
-
- #[allow(unused)]
- fn compute_cap() -> Result<usize> {
- println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
- // Try to parse compute caps from env
- let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
- compute_cap_str
- .parse::<usize>()
- .context("Could not parse code")?
- } else {
- // Use nvidia-smi to get the current compute cap
- let out = std::process::Command::new("nvidia-smi")
- .arg("--query-gpu=compute_cap")
- .arg("--format=csv")
- .output()
- .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
- let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
- let mut lines = out.lines();
- assert_eq!(
- lines.next().context("missing line in stdout")?,
- "compute_cap"
- );
- let cap = lines
- .next()
- .context("missing line in stdout")?
- .replace('.', "");
- let cap = cap
- .parse::<usize>()
- .with_context(|| format!("cannot parse as int {cap}"))?;
- println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
- cap
- };
-
- // Grab available GPU codes from nvcc and select the highest one
- let (supported_nvcc_codes, max_nvcc_code) = {
- let out = std::process::Command::new("nvcc")
- .arg("--list-gpu-code")
- .output()
- .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
- let out = std::str::from_utf8(&out.stdout).unwrap();
-
- let out = out.lines().collect::<Vec<&str>>();
- let mut codes = Vec::with_capacity(out.len());
- for code in out {
- let code = code.split('_').collect::<Vec<&str>>();
- if !code.is_empty() && code.contains(&"sm") {
- if let Ok(num) = code[1].parse::<usize>() {
- codes.push(num);
- }
- }
- }
- codes.sort();
- let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
- (codes, max_nvcc_code)
- };
-
- // Check that nvcc supports the asked compute caps
- if !supported_nvcc_codes.contains(&compute_cap) {
- anyhow::bail!(
- "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
- );
- }
- if compute_cap > max_nvcc_code {
- anyhow::bail!(
- "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
- );
- }
-
- Ok(compute_cap)
- }
+ let builder = bindgen_cuda::Builder::default();
+ println!("cargo:info={builder:?}");
+ let bindings = builder.build_ptx().unwrap();
+ bindings.write("src/lib.rs").unwrap();
}
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 441d2e88..187cb4fd 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
+
[dependencies]
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+half = { version = "2.3.1", features = [
+ "num-traits",
+ "use-intrinsics",
+ "rand_distr",
+] }
rand = "0.8.5"
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index 4166d811..3d8e7f0d 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -117,7 +117,7 @@ ELU(elu_f32, float)
ELU(elu_f16, half)
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index cdc8fef8..eb560f16 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index e9ab17b1..9aead139 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[tid]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
+} \
+
+#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
+} \
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
@@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
+CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
+CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
-#endif
+
+CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
+CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
+CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
+#endif \ No newline at end of file
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 63357428..2a57bdbb 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
+INDEX_OP(is_u32_bf16, uint32_t, bfloat)
+INDEX_OP(is_u8_bf16, uint8_t, bfloat)
+
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 75f0286d..c427a690 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -178,8 +178,8 @@ macro_rules! ops{
pub mod unary {
ops!(
- cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
- recip
+ cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
+ tanh, recip
);
}
pub mod binary {
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 83a56f0a..93dac662 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 067dece8..b15505f7 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
-use metal::{Device, MTLResourceOptions};
+use metal::{Buffer, Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -248,6 +248,34 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
+#[test]
+fn binary_ops_bf16() {
+ let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
+ let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
+ .into_iter()
+ .map(bf16::from_f32)
+ .collect();
+
+ macro_rules! binary_op {
+ ($opname:ident, $opexpr:expr) => {{
+ let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
+ let expected: Vec<bf16> = lhs
+ .iter()
+ .zip(rhs.iter())
+ .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
+ .collect();
+ assert_eq!(results, expected);
+ }};
+ }
+
+ binary_op!(add, |x, y| x + y);
+ binary_op!(sub, |x, y| x - y);
+ binary_op!(mul, |x, y| x * y);
+ binary_op!(div, |x, y| x / y);
+ binary_op!(min, |x: bf16, y| x.min(y));
+ binary_op!(max, |x: bf16, y| x.max(y));
+}
+
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let fence = device.new_fence();
@@ -296,6 +324,89 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]);
}
+#[test]
+fn it_cast_bf16_u32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u32> = cast(&input, "cast_bf16_u32");
+ let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f32> = cast(&input, "cast_bf16_f32");
+ let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u8_bf16() {
+ let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
+ let expected: Vec<bf16> = input
+ .iter()
+ .map(|v| bf16::from_f32(*v as f32))
+ .collect::<Vec<_>>();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u32_bf16() {
+ let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f32_bf16() {
+ let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_u8() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u8> = cast(&input, "cast_bf16_u8");
+ let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f16() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f16> = cast(&input, "cast_bf16_f16");
+ let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f16_bf16() {
+ let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let fence = device.new_fence();
@@ -396,14 +507,14 @@ fn index_select() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@@ -419,7 +530,7 @@ fn index_select_f16() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -427,12 +538,38 @@ fn index_select_f16() {
}
#[test]
+fn index_select_is_u32_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u32, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
+fn index_select_is_u8_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u8, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize],
ids: &[I],
dim: usize,
+ name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
@@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
- let name = match core::mem::size_of::<T>() {
- 4 => "is_u32_f32",
- 2 => "is_u32_f16",
- _ => unimplemented!(),
- };
-
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 7fbb613d..dcf803d8 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
+template <typename T> METAL_FUNC T relu(T in){
+ if (in < 0) {
+ return 0;
+ }
+ return in;
+}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@@ -110,6 +116,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
+UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
@@ -120,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
@@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
+BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index 5e0e5c2b..214e8a59 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
+candle = { workspace = true }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@@ -20,7 +20,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
+candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml
index ba33b07a..de1e3350 100644
--- a/candle-onnx/Cargo.toml
+++ b/candle-onnx/Cargo.toml
@@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { path = "../candle-core", package = "candle-core" }
+candle-nn = { path = "../candle-nn" }
prost = "0.12.1"
[build-dependencies]
@@ -20,4 +20,3 @@ prost-build = "0.12.1"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }
-
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index a03c7559..7c6fbd68 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -15,9 +15,9 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 83bcff62..1a72c36a 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
+candle-nn = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 51c524f5..810f2803 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,6 +1,6 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
@@ -112,11 +112,6 @@ impl Config {
}
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
struct Dropout {
#[allow(dead_code)]
pr: f64,
diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs
index c4a2d1db..e69f08c8 100644
--- a/candle-transformers/src/models/bigcode.rs
+++ b/candle-transformers/src/models/bigcode.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
@@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
Ok(Linear::new(weight, bias))
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 6ede136a..ef5a92fc 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
@@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
Ok(LayerNorm::new(weight, bias, eps))
}
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
#[derive(Debug)]
pub struct Config {
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7e8c8920..f003866a 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,6 +1,6 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -136,11 +136,6 @@ impl Cache {
}
}
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, cfg.hidden_size))
-}
-
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
@@ -409,7 +404,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 94a3bd5b..a60b5a06 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -26,6 +26,7 @@ pub mod quantized_mixformer;
pub mod quantized_mpt;
pub mod quantized_stable_lm;
pub mod quantized_t5;
+pub mod repvgg;
pub mod resnet;
pub mod segment_anything;
pub mod stable_diffusion;
diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs
new file mode 100644
index 00000000..34016e5b
--- /dev/null
+++ b/candle-transformers/src/models/repvgg.rs
@@ -0,0 +1,306 @@
+//! RepVGG inference implementation
+//!
+//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021
+//! https://arxiv.org/abs/2101.03697
+
+use candle::{Result, Tensor, D};
+use candle_nn::{
+ batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
+};
+
+const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
+
+#[derive(Clone)]
+pub struct Config {
+ a: f32,
+ b: f32,
+ groups: usize,
+ stages: [usize; 4],
+}
+
+impl Config {
+ pub fn a0() -> Self {
+ Self {
+ a: 0.75,
+ b: 2.5,
+ groups: 1,
+ stages: [2, 4, 14, 1],
+ }
+ }
+
+ pub fn a1() -> Self {
+ Self {
+ a: 1.0,
+ b: 2.5,
+ groups: 1,
+ stages: [2, 4, 14, 1],
+ }
+ }
+
+ pub fn a2() -> Self {
+ Self {
+ a: 1.5,
+ b: 2.75,
+ groups: 1,
+ stages: [2, 4, 14, 1],
+ }
+ }
+
+ pub fn b0() -> Self {
+ Self {
+ a: 1.0,
+ b: 2.5,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b1() -> Self {
+ Self {
+ a: 2.0,
+ b: 4.0,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b2() -> Self {
+ Self {
+ a: 2.5,
+ b: 5.0,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b3() -> Self {
+ Self {
+ a: 3.0,
+ b: 5.0,
+ groups: 1,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b1g4() -> Self {
+ Self {
+ a: 2.0,
+ b: 4.0,
+ groups: 4,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b2g4() -> Self {
+ Self {
+ a: 2.5,
+ b: 5.0,
+ groups: 4,
+ stages: [4, 6, 16, 1],
+ }
+ }
+
+ pub fn b3g4() -> Self {
+ Self {
+ a: 3.0,
+ b: 5.0,
+ groups: 4,
+ stages: [4, 6, 16, 1],
+ }
+ }
+}
+
+// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
+// based on the _fuse_bn_tensor method in timm
+// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
+fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
+ let (gamma, beta) = bn.weight_and_bias().unwrap();
+ let mu = bn.running_mean();
+ let sigma = (bn.running_var() + bn.eps())?.sqrt();
+ let gps = (gamma / sigma)?;
+ let bias = (beta - mu * &gps)?;
+ let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
+
+ Ok((weights, bias))
+}
+
+// A RepVGG layer has a different training time and inference time architecture.
+// The latter is a simple and efficient equivalent transformation of the former
+// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions
+// along with identity branches and batchnorm layers are fused into a single 3x3 convolution.
+fn repvgg_layer(
+ has_identity: bool,
+ dim: usize,
+ stride: usize,
+ in_channels: usize,
+ out_channels: usize,
+ groups: usize,
+ vb: VarBuilder,
+) -> Result<Func<'static>> {
+ let conv2d_cfg = Conv2dConfig {
+ stride,
+ groups,
+ padding: 1,
+ ..Default::default()
+ };
+
+ // read and reparameterize the 1x1 conv and bn into w1 and b1
+ // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543
+
+ let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
+ let conv1x1 = conv2d_no_bias(
+ in_channels,
+ out_channels,
+ 1,
+ conv2d_cfg,
+ vb.pp("conv_1x1.conv"),
+ )?;
+
+ let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?;
+
+ // resize to 3x3
+ w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?;
+ w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?;
+
+ // read and reparameterize the 3x3 conv and bn into w3 and b3
+ let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
+ let conv3x3 = conv2d_no_bias(
+ in_channels,
+ out_channels,
+ 3,
+ conv2d_cfg,
+ vb.pp("conv_kxk.conv"),
+ )?;
+
+ let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?;
+
+ let mut w = (w1 + w3)?;
+ let mut b = (b1 + b3)?;
+
+ // read and reparameterize the identity bn into wi and bi
+ if has_identity {
+ let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
+
+ // create a 3x3 convolution equivalent to the identity branch
+ let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()];
+
+ // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620
+ let in_dim = in_channels / groups;
+ for i in 0..in_channels {
+ weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0;
+ }
+
+ let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
+ let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
+
+ w = (w + wi)?;
+ b = (b + bi)?;
+ }
+
+ // create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches
+ let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
+
+ Ok(Func::new(move |xs| {
+ let xs = xs.apply(&reparam_conv)?.relu()?;
+ Ok(xs)
+ }))
+}
+
+// Get the number of output channels per stage taking into account the multipliers
+fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
+ let channels = CHANNELS_PER_STAGE[stage] as f32;
+
+ match stage {
+ 0 => std::cmp::min(64, (channels * a) as usize),
+ 4 => (channels * b) as usize,
+ _ => (channels * a) as usize,
+ }
+}
+
+// Each stage is made of layers. The first layer always downsamples with stride 2.
+// All but the first layer have a residual connection.
+// The G4 variants have a groupwise convolution instead of a dense one on odd layers
+// counted across stage boundaries, so we keep track of which layer we are in the
+// full model.
+fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
+ let nlayers = cfg.stages[idx - 1];
+ let mut layers = Vec::with_capacity(nlayers);
+ let prev_layers: usize = cfg.stages[..idx - 1].iter().sum();
+ let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
+ let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx);
+
+ for layer_idx in 0..nlayers {
+ let (has_identity, stride, in_channels) = if layer_idx == 0 {
+ (false, 2, out_channels_prev)
+ } else {
+ (true, 1, out_channels)
+ };
+
+ let groups = if (prev_layers + layer_idx) % 2 == 1 {
+ cfg.groups
+ } else {
+ 1
+ };
+
+ layers.push(repvgg_layer(
+ has_identity,
+ out_channels,
+ stride,
+ in_channels,
+ out_channels,
+ groups,
+ vb.pp(layer_idx),
+ )?)
+ }
+
+ Ok(Func::new(move |xs| {
+ let mut xs = xs.clone();
+ for layer in layers.iter() {
+ xs = xs.apply(layer)?
+ }
+ Ok(xs)
+ }))
+}
+
+// Build a RepVGG model for a given configuration.
+fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
+ let cls = match nclasses {
+ None => None,
+ Some(nclasses) => {
+ let outputs = output_channels_per_stage(config.a, config.b, 4);
+ let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
+ Some(linear)
+ }
+ };
+
+ let stem_dim = output_channels_per_stage(config.a, config.b, 0);
+ let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;
+ let vb = vb.pp("stages");
+ let stage1 = repvgg_stage(config, 1, vb.pp(0))?;
+ let stage2 = repvgg_stage(config, 2, vb.pp(1))?;
+ let stage3 = repvgg_stage(config, 3, vb.pp(2))?;
+ let stage4 = repvgg_stage(config, 4, vb.pp(3))?;
+
+ Ok(Func::new(move |xs| {
+ let xs = xs
+ .apply(&stem)?
+ .apply(&stage1)?
+ .apply(&stage2)?
+ .apply(&stage3)?
+ .apply(&stage4)?
+ .mean(D::Minus1)?
+ .mean(D::Minus1)?;
+ match &cls {
+ None => Ok(xs),
+ Some(cls) => xs.apply(cls),
+ }
+ }))
+}
+
+pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
+ repvgg_model(cfg, Some(nclasses), vb)
+}
+
+pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
+ repvgg_model(cfg, None, vb)
+}
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 25454ba6..ea2a59b9 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,12 +1,7 @@
use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
+use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
fn conv1d(
in_channels: usize,
diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml
index 59ce1be3..51358e45 100644
--- a/candle-wasm-examples/bert/Cargo.toml
+++ b/candle-wasm-examples/bert/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -27,7 +27,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/blip/Cargo.toml b/candle-wasm-examples/blip/Cargo.toml
index 904e90e6..f4de054e 100644
--- a/candle-wasm-examples/blip/Cargo.toml
+++ b/candle-wasm-examples/blip/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
index 63f8a9c5..d46cdafa 100644
--- a/candle-wasm-examples/llama2-c/Cargo.toml
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -26,7 +26,7 @@ serde_json = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
diff --git a/candle-wasm-examples/phi/Cargo.toml b/candle-wasm-examples/phi/Cargo.toml
index c4950df9..e437a937 100644
--- a/candle-wasm-examples/phi/Cargo.toml
+++ b/candle-wasm-examples/phi/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml
index 4d886bc2..1840bb62 100644
--- a/candle-wasm-examples/segment-anything/Cargo.toml
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
# App crates.
diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml
index 237f9e61..5f60d917 100644
--- a/candle-wasm-examples/t5/Cargo.toml
+++ b/candle-wasm-examples/t5/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -27,7 +27,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 5d2b2a38..92e206b2 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -26,7 +26,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml
index eb2c320b..ac76f9a7 100644
--- a/candle-wasm-examples/yolo/Cargo.toml
+++ b/candle-wasm-examples/yolo/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@@ -26,7 +26,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
diff --git a/candle-wasm-tests/Cargo.toml b/candle-wasm-tests/Cargo.toml
index a684f2ce..40c37acd 100644
--- a/candle-wasm-tests/Cargo.toml
+++ b/candle-wasm-tests/Cargo.toml
@@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true
[dependencies]
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
+candle = { workspace = true }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }