53 files changed, 1035 insertions, 1051 deletions
diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..05bcdac6 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 @@ -31,9 +31,17 @@ license = "MIT OR Apache-2.0" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" +candle = { path = "./candle-core", package = "candle-core" } +candle-datasets = { path = "./candle-datasets" } +candle-flash-attn = { path = "./candle-flash-attn" } +candle-kernels = { path = "./candle-kernels" } +candle-metal-kernels = { path = "./candle-metal-kernels" } +candle-nn = { path = "./candle-nn" } +candle-onnx = { path = "./candle-onnx" } +candle-transformers = { path = "./candle-transformers" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "0.9.14", features = ["f16"] } +cudarc = { version = "0.10.0", features = ["f16"] } gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } @@ -42,7 +50,7 @@ imageproc = { version = "0.23.0", default-features = false } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } libc = { version = "0.2.147" } log = "0.4" -memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] } +memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } num_cpus = "1.15.0" num-traits = "0.2.15" parquet = { version = "45.0.0" } @@ -55,7 +63,7 @@ serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" thiserror = "1" -tokenizers = { version = "0.13.4", default-features = false } +tokenizers = { version = "0.15.0", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml index e28e6623..5ccda31e 100644 --- a/candle-book/Cargo.toml +++ b/candle-book/Cargo.toml @@ -11,11 +11,11 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.3" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../candle-transformers", version = "0.3.3" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } +candle = { workspace = true } +candle-datasets = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index afdb67cd..92a04917 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -12,8 +12,8 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true } +candle-kernels = { workspace = true, optional = true } +candle-metal-kernels = { workspace = true, optional = true } metal = { 
workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } diff --git a/candle-core/benches/benchmarks/matmul.rs b/candle-core/benches/benchmarks/matmul.rs index fb173f04..9d67e642 100644 --- a/candle-core/benches/benchmarks/matmul.rs +++ b/candle-core/benches/benchmarks/matmul.rs @@ -1,5 +1,5 @@ -use crate::benchmarks::{bench_name, device, BenchDevice}; -use candle_core::{DType, Tensor}; +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; @@ -7,20 +7,19 @@ fn run(a: &Tensor, b: &Tensor) { a.matmul(&b.t().unwrap()).unwrap(); } -fn criterion_benchmark(c: &mut Criterion) { +fn run_bench(c: &mut Criterion, device: &Device) { let b = 1; let m = 1; let n = 2048; let k = 2048; - let device = device().unwrap(); let dtype = DType::F32; - let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); - let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); + let lhs = Tensor::zeros((b, m, k), dtype, device).unwrap(); + let rhs = Tensor::zeros((b, n, k), dtype, device).unwrap(); let flops = b * m * n * k; - let mut group = c.benchmark_group(bench_name("matmul")); + let mut group = c.benchmark_group(device.bench_name("matmul")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |b| { b.iter_custom(|iters| { @@ -35,4 +34,11 @@ fn criterion_benchmark(c: &mut Criterion) { group.finish(); } +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_bench(c, &device); + } +} + criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 6bb37a70..eb20ea70 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -5,6 +5,8 @@ use candle_core::{Device, Result}; pub(crate) trait BenchDevice { fn sync(&self) -> Result<()>; + + fn bench_name<S: Into<String>>(&self, name: S) -> String; } impl BenchDevice for Device { @@ -25,32 +27,38 @@ impl BenchDevice for Device { } } } -} -pub(crate) fn device() -> Result<Device> { - if cfg!(feature = "metal") { - Device::new_metal(0) - } else if cfg!(feature = "cuda") { - Device::new_cuda(0) - } else { - Ok(Device::Cpu) + fn bench_name<S: Into<String>>(&self, name: S) -> String { + match self { + Device::Cpu => { + let cpu_type = if cfg!(feature = "accelerate") { + "accelerate" + } else if cfg!(feature = "mkl") { + "mkl" + } else { + "cpu" + }; + format!("{}_{}", cpu_type, name.into()) + } + Device::Cuda(_) => format!("cuda_{}", name.into()), + Device::Metal(_) => format!("metal_{}", name.into()), + } } } -pub(crate) fn bench_name<S: Into<String>>(name: S) -> String { - format!("{}_{}", device_variant(), name.into()) +struct BenchDeviceHandler { + devices: Vec<Device>, } -const fn device_variant() -> &'static str { - if cfg!(feature = "metal") { - "metal" - } else if cfg!(feature = "cuda") { - "cuda" - } else if cfg!(feature = "accelerate") { - "accelerate" - } else if cfg!(feature = "mkl") { - "mkl" - } else { - "cpu" +impl BenchDeviceHandler { + pub fn new() -> Result<Self> { + let mut devices = Vec::new(); + if cfg!(feature = "metal") { + devices.push(Device::new_metal(0)?); + } else if cfg!(feature = "cuda") { + devices.push(Device::new_cuda(0)?); + } + devices.push(Device::Cpu); + Ok(Self { devices }) } } diff --git 
a/candle-core/benches/benchmarks/random.rs b/candle-core/benches/benchmarks/random.rs index e4a4a390..22c60ef1 100644 --- a/candle-core/benches/benchmarks/random.rs +++ b/candle-core/benches/benchmarks/random.rs @@ -1,4 +1,4 @@ -use crate::benchmarks::{bench_name, device, BenchDevice}; +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; use candle_core::{DType, Device, Tensor}; use criterion::{black_box, criterion_group, Criterion, Throughput}; use std::time::Instant; @@ -11,19 +11,18 @@ fn rand_normal(a: &Tensor) { a.randn_like(100.0, 15.0).unwrap(); } -fn criterion_benchmark(c: &mut Criterion) { +fn run_random_bench(c: &mut Criterion, device: &Device) { let b = 1; let rows = 2048; let cols = 2048; - let d = device().unwrap(); let dtype = DType::F32; - let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap(); + let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap(); let flops = b * rows * cols * dtype.size_in_bytes(); - let mut group = c.benchmark_group(bench_name("random_uniform")); + let mut group = c.benchmark_group(device.bench_name("random_uniform")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |benches| { benches.iter_custom(|iters| { @@ -31,16 +30,15 @@ fn criterion_benchmark(c: &mut Criterion) { for _i in 0..iters { rand_uniform(black_box(&tensor)); } - d.sync().unwrap(); + device.sync().unwrap(); start.elapsed() }) }); group.finish(); - let d = device().unwrap(); - let tensor = Tensor::zeros((b, rows, cols), dtype, &d).unwrap(); + let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap(); - let mut group = c.benchmark_group(bench_name("random_normal")); + let mut group = c.benchmark_group(device.bench_name("random_normal")); group.throughput(Throughput::Bytes(flops as u64)); group.bench_function("iter", move |benches| { benches.iter_custom(|iters| { @@ -48,11 +46,18 @@ fn criterion_benchmark(c: &mut Criterion) { for _i in 0..iters { rand_normal(black_box(&tensor)); } - d.sync().unwrap(); + device.sync().unwrap(); start.elapsed() }) }); group.finish(); } +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_random_bench(c, &device); + } +} + criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index d06b30d1..337021aa 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -102,7 +102,7 @@ enum Command { }, Quantize { - /// The input file, in gguf format. + /// The input file(s), in safetensors format. in_file: Vec<std::path::PathBuf>, /// The output file, in gguf format. @@ -117,6 +117,15 @@ enum Command { #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)] mode: QuantizationMode, }, + + Dequantize { + /// The input file, in gguf format. + in_file: std::path::PathBuf, + + /// The output file, in safetensors format. 
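The benchmark refactor above replaces the compile-time `device()`/`bench_name()` helpers with a runtime `BenchDeviceHandler`, so a single benchmark binary exercises every backend that is available. A minimal sketch of how a new benchmark plugs into it (the `sqr` payload and shapes are illustrative, not part of this diff):

```rust
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;

// Assumes the BenchDevice/BenchDeviceHandler items added in
// candle-core/benches/benchmarks/mod.rs above.
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};

fn run_bench(c: &mut Criterion, device: &Device) {
    let t = Tensor::zeros((1024, 1024), DType::F32, device).unwrap();
    let mut group = c.benchmark_group(device.bench_name("sqr"));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _ in 0..iters {
                black_box(t.sqr().unwrap());
            }
            // CUDA and Metal execute asynchronously: sync before reading the clock.
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    for device in handler.devices {
        run_bench(c, &device);
    }
}

criterion_group!(benches, criterion_benchmark);
```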
+ #[arg(long)] + out_file: std::path::PathBuf, + }, } #[derive(Parser, Debug, Clone)] @@ -285,6 +294,19 @@ fn run_quantize_safetensors( Ok(()) } +fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> { + let mut in_file = std::fs::File::open(in_file)?; + let content = gguf_file::Content::read(&mut in_file)?; + let mut tensors = std::collections::HashMap::new(); + for (tensor_name, _) in content.tensor_infos.iter() { + let tensor = content.tensor(&mut in_file, tensor_name)?; + let tensor = tensor.dequantize(&Device::Cpu)?; + tensors.insert(tensor_name.to_string(), tensor); + } + candle_core::safetensors::save(&tensors, out_file)?; + Ok(()) +} + fn run_quantize( in_files: &[std::path::PathBuf], out_file: std::path::PathBuf, @@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> { quantization, mode, } => run_quantize(&in_file, out_file, quantization, mode)?, + Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?, } Ok(()) } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 24beeb7a..48250233 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -592,14 +592,26 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::I64) => "cast_u8_i64", + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", - (DType::F16, DType::F32) => "cast_f16_f32", - (DType::I64, DType::F32) => "cast_i64_f32", (DType::F32, DType::BF16) => "cast_f32_bf16", + + (DType::I64, DType::F32) => "cast_i64_f32", + + (DType::F16, DType::BF16) => "cast_f16_bf16", + (DType::F16, DType::F32) => "cast_f16_f32", + + (DType::BF16, DType::U8) => "cast_bf16_u8", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") } @@ -677,6 +689,7 @@ impl BackendStorage for MetalStorage { ("uround", DType::F32) => contiguous::round::FLOAT, ("urecip", DType::F32) => contiguous::recip::FLOAT, ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("urelu", DType::F32) => contiguous::relu::FLOAT, ("ucos", DType::F16) => contiguous::cos::HALF, ("usin", DType::F16) => contiguous::sin::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF, @@ -693,6 +706,7 @@ impl BackendStorage for MetalStorage { ("uround", DType::F16) => contiguous::round::HALF, ("urecip", DType::F16) => contiguous::recip::HALF, ("utanh", DType::F16) => contiguous::tanh::HALF, + ("urelu", DType::F16) => contiguous::relu::HALF, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") } @@ -723,6 +737,7 @@ impl BackendStorage for MetalStorage { ("uabs", DType::F32) => strided::abs::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT, + ("urelu", DType::F32) => strided::relu::FLOAT, ("uround", DType::F32) => strided::round::FLOAT, ("ucos", DType::F16) => strided::cos::HALF, ("usin", DType::F16) => strided::sin::HALF, @@ -737,6 +752,7 @@ impl BackendStorage for MetalStorage { ("uabs", DType::F16) => strided::abs::HALF, ("uceil", DType::F16) => strided::ceil::HALF, ("ufloor", DType::F16) => 
strided::floor::HALF, + ("urelu", DType::F16) => strided::relu::HALF, ("uround", DType::F16) => strided::round::HALF, (name, dtype) => { crate::bail!("Metal strided unary {name} {dtype:?} not implemented") @@ -1129,8 +1145,12 @@ impl BackendStorage for MetalStorage { let device = self.device(); let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", + (DType::U32, DType::BF16) => "is_u32_bf16", + (left, right) => { crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") } @@ -1320,6 +1340,7 @@ impl MetalStorage { ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), @@ -1330,6 +1351,18 @@ impl MetalStorage { ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + + ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), + ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), + ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), + ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1340,6 +1373,7 @@ impl MetalStorage { ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + ("add", DType::U32) => (contiguous::add::U32, self.dtype), ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), @@ -1350,6 +1384,7 @@ impl MetalStorage { ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + ("add", DType::U8) => (contiguous::add::U8, self.dtype), ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), @@ -1360,6 +1395,7 @@ impl MetalStorage { ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") } @@ -1393,6 +1429,7 @@ impl MetalStorage { ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), ("bmul", DType::F16) => (strided::mul::HALF, 
self.dtype), @@ -1405,6 +1442,20 @@ impl MetalStorage { ("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + + ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype), + ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype), + ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype), + ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype), + ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype), + ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype), + ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1417,6 +1468,7 @@ impl MetalStorage { ("lt", DType::I64) => (strided::lt::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8), + ("badd", DType::U32) => (strided::add::U32, self.dtype), ("bsub", DType::U32) => (strided::sub::U32, self.dtype), ("bmul", DType::U32) => (strided::mul::U32, self.dtype), @@ -1429,6 +1481,7 @@ impl MetalStorage { ("lt", DType::U32) => (strided::lt::U32, DType::U8), ("ge", DType::U32) => (strided::ge::U32, DType::U8), ("gt", DType::U32) => (strided::gt::U32, DType::U8), + ("badd", DType::U8) => (strided::add::U8, self.dtype), ("bsub", DType::U8) => (strided::sub::U8, self.dtype), ("bmul", DType::U8) => (strided::mul::U8, self.dtype), @@ -1441,6 +1494,7 @@ impl MetalStorage { ("lt", DType::U8) => (strided::lt::U8, DType::U8), ("ge", DType::U8) => (strided::ge::U8, DType::U8), ("gt", DType::U8) => (strided::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal strided binary {name} {dtype:?} not implemented") } diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 25640d1a..276b30e3 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -703,6 +703,7 @@ impl PthTensors { } pub fn get(&self, name: &str) -> Result<Option<Tensor>> { + use std::io::Read; let tensor_info = match self.tensor_infos.get(name) { None => return Ok(None), Some(tensor_info) => tensor_info, @@ -712,14 +713,21 @@ impl PthTensors { let mut zip = zip::ZipArchive::new(zip_reader)?; let mut reader = zip.by_name(&tensor_info.path)?; - // Reading the data is a bit tricky as it can be strided, use an offset, etc. - // For now only support the basic case. - if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() { + // Reading the data is a bit tricky as it can be strided, for now only support the basic + // case. 
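With the BF16 cast, unary, and binary kernels registered above, BF16 tensors become usable on Metal where they previously hit the `not implemented` bail. A smoke-test sketch (requires the `metal` feature and an Apple GPU; the shapes are arbitrary):

```rust
use candle_core::{DType, Device, Tensor};

fn bf16_metal_smoke() -> candle_core::Result<()> {
    let dev = Device::new_metal(0)?;
    // cast_f32_bf16, newly routed in the dispatch table above.
    let a = Tensor::ones((4, 4), DType::F32, &dev)?.to_dtype(DType::BF16)?;
    // contiguous::add::BFLOAT.
    let b = (&a + &a)?;
    // Comparisons on BF16 inputs produce U8, as the table above encodes.
    let mask = b.gt(&a)?;
    // cast_bf16_f32 / cast_u8_f32 for inspection.
    println!("{}", b.to_dtype(DType::F32)?);
    println!("{}", mask.to_dtype(DType::F32)?);
    Ok(())
}
```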
+ if !tensor_info.layout.is_contiguous() { crate::bail!( "cannot retrieve non-contiguous tensors {:?}", tensor_info.layout ) } + let start_offset = tensor_info.layout.start_offset(); + if start_offset > 0 { + std::io::copy( + &mut reader.by_ref().take(start_offset as u64), + &mut std::io::sink(), + )?; + } let tensor = Tensor::from_reader( tensor_info.layout.shape().clone(), tensor_info.dtype, diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index d16289e6..6210ac1e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K { let d2 = d * sc as f32; let m2 = min * m as f32; for (ql, qh) in ql.iter().zip(qh) { - let to_add = if qh & u1 != 0 { 16 } else { 1 }; - y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1; + let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1; ys_index += 1; } for (ql, qh) in ql.iter().zip(qh) { - let to_add = if qh & u2 != 0 { 16 } else { 1 }; - y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2; + let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2; ys_index += 1; } is += 2; diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 3cb56229..c4d5d6f4 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -13,6 +13,14 @@ use core::arch::arm::*; use core::arch::aarch64::*; #[inline(always)] +unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { + // TODO: dotprod + let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) +} + +#[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> { let qk = QK8_0; let nb = n / qk; @@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); - // TODO: Support dotprod when it's available outside of nightly. - let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); - let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h)); - let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - + let pl0 = vdotq_s32(v0_0ls, v1_0l); + let ph0 = vdotq_s32(v0_0hs, v1_0h); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), @@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - // TODO dotprod once this is the intrinsics are. 
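The `BlockQ5K` change above fixes the reconstruction of the fifth bit: it contributes +16 when set and 0 when clear (the old code added 1), and the low nibble is now converted to `f32` before the addition so the offset is applied in float. Distilled into scalar form (names are illustrative):

```rust
// Corrected Q5 reconstruction: d * (low_nibble + high_bit * 16) - m.
fn q5_value(ql: u8, high_bit_set: bool, d: f32, m: f32) -> f32 {
    let to_add = if high_bit_set { 16f32 } else { 0f32 };
    d * ((ql & 0xF) as f32 + to_add) - m
}

// With d = 1 and m = 0 the values span the full 5-bit range 0..=31.
fn q5_range_check() {
    assert_eq!(q5_value(0x0, false, 1.0, 0.0), 0.0);
    assert_eq!(q5_value(0xF, true, 1.0, 0.0), 31.0);
}
```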
- let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); - let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); - let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + let p0 = vdotq_s32(x0_0, y0_0); + let p1 = vdotq_s32(x0_1, y0_1); sumv0 = vmlaq_n_f32( sumv0, @@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res for i in (0..QK_K).step_by(16) { let xs = vld1q_s8(xs.add(i)); let ys = vld1q_s8(ys.add(i)); - let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys)); - let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys)); - - let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up)); + let xy = vdotq_s32(xs, ys); sum_i = vaddq_s32(sum_i, xy) } sumf += vaddvq_s32(sum_i) as f32 * scale @@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); - // TODO: dotprod - - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)), - ); + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1; + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)), - ); + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1; + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); let q8bytes = vld1q_s8_x4(q8); @@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); - // TODO: dotprod case. 
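The new `vdotq_s32` helper above factors the widening multiply plus pairwise-add fallback into one place, so a native dot-product instruction can later be substituted behind a single function (per the remaining `TODO: dotprod`). What every call site actually relies on is the lane sum, sketched here as scalar reference code:

```rust
// The four i32 lanes returned by vdotq_s32 always sum to the signed-byte
// dot product of a and b. The per-lane grouping differs between the
// fallback above and a native SDOT, but all callers reduce the lanes with
// vaddvq_s32 or a later f32 reduction, so only this total matters.
fn dot_i8(a: &[i8; 16], b: &[i8; 16]) -> i32 {
    a.iter().zip(b.iter()).map(|(&x, &y)| x as i32 * y as i32).sum()
}
```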
- let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)), - ); + let p0 = vdotq_s32(q6bytes_0, q8bytes.0); + let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1; + isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)), - ); + let p2 = vdotq_s32(q6bytes_2, q8bytes.2); + let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); - isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1; + isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); } sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); @@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); - // TODO: dotprod - - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)), - ); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32; + let p0 = vdotq_s32(q5bytes_0, q8bytes.0); + let p1 = vdotq_s32(q5bytes_1, q8bytes.1); + sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; scales = scales.add(1); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)), - vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)), - vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)), - ); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32; + let p2 = vdotq_s32(q5bytes_2, q8bytes.2); + let p3 = vdotq_s32(q5bytes_3, q8bytes.3); + sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; scales = scales.add(1); } sumf += d * sumi as f32 - dmin * sumi_mins as f32; @@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res for j in 0..QK_K / 64 { let q4bits = vld1q_u8_x2(q4); q4 = q4.add(32); - // TODO: dotprod let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); let q4bytes = int8x16x2_t( vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), ); - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)), - ); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32; + let p0 = vdotq_s32(q4bytes.0, q8bytes.0); + let p1 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as 
i32; let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); @@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)), - ); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32; + let p2 = vdotq_s32(q4bytes.0, q8bytes.0); + let p3 = vdotq_s32(q4bytes.1, q8bytes.1); + sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; } sumf += d * (sumi1 + sumi2) as f32; } @@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - // TODO: dotprod - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)), - vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)), - vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)), - vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)), - vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)), - ); - isum += vaddvq_s16(p0) as i32 * *scale as i32 - + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 - + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 - + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); let q3h_0 = vbicq_u8(m2, qhbits.0); @@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res vreinterpretq_s8_u8(q3h_3), ); - // TODO: dotprod - let p0 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)), - vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)), - ); - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)), - vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)), - vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)), - ); - let p3 = vaddq_s16( - vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)), - vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)), - ); - isum += vaddvq_s16(p0) as i32 * *scale as i32 - + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 - + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 - + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); + let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); + let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); + let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); + isum += vaddvq_s32(p0) * *scale as i32 + + vaddvq_s32(p1) * *scale.add(1) as i32 + + vaddvq_s32(p2) * *scale.add(2) as i32 + + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); if j == 0 { @@ -649,7 
+561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res let mut is = 0usize; // TODO: dotprod - for _j in 0..QK_K / 128 { let q2bits = vld1q_u8_x2(q2); q2 = q2.add(32); @@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale( q2bytes: int8x16x2_t, q8bytes: int8x16x2_t, ) -> i32 { - let p1 = vaddq_s16( - vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)), - vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)), - ); - let p2 = vaddq_s16( - vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)), - vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)), - ); - vaddvq_s16(p1) as i32 * aux[is + index] as i32 - + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32 + let p1 = vdotq_s32(q2bytes.0, q8bytes.0); + let p2 = vdotq_s32(q2bytes.1, q8bytes.1); + vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 } diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index 716cca8d..d31e77a7 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,4 +1,5 @@ use candle_core::{ + bail, quantized::{self, GgmlDType}, test_utils::to_vec2_round, Device, Module, Result, Tensor, @@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) { } } -/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 +/// Creates a vector similar to the ones used in GGML unit tests: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 fn create_ggml_like_vector(offset: f32) -> Vec<f32> { (0..GGML_TEST_SIZE) .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) @@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 { sum / a.len() as f32 } -/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 +/// Similar to the GGML quantization unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50 fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> { let src = create_ggml_like_vector(0.0); let mut dst = vec![0.0; GGML_TEST_SIZE]; let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?; let error = calculate_rmse(src.as_slice(), dst.as_slice()); if error > max_error { - candle_core::bail!( + bail!( "Quantization error {} exceeds max error {}", error, max_error @@ -404,7 +407,7 @@ fn quantize_q5k() -> Result<()> { let dst = round_vector(&dst); assert_eq!( [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], - [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499] + [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] ); let (src_big, mut dst_big) = get_test_vector(128.0, 1024); @@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> { GgmlDType::Q5K => 0.000740, GgmlDType::Q6K => 0.000952, GgmlDType::Q4_0 => 0.001143, - GgmlDType::Q4_1 => 0.007784, + GgmlDType::Q4_1 => 0.008, GgmlDType::Q5_0 => 0.001353, - GgmlDType::Q5_1 => 0.001363, + GgmlDType::Q5_1 => 0.00149, GgmlDType::Q8_0 => 0.000092, // Not from the ggml repo. 
GgmlDType::Q8K => 0.00065, - _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",), + _ => bail!("No GGML results for quantization type {dtype:?}",), }; Ok(err) } -/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 +/// Similar to the GGML matmul unit test: +/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> { let a = create_ggml_like_vector(0.0); let b = create_ggml_like_vector(1.0); + ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?; + // Another example that is more likely to trigger the overflow reported in #1526 + let a = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::<Vec<_>>(); + let b = (0..GGML_TEST_SIZE) + .map(|i| i as f32 / GGML_TEST_SIZE as f32) + .collect::<Vec<_>>(); + ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?; + Ok(()) +} + +fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> { let length = a.len(); let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; - T::from_float(&a, &mut a_quant)?; - T::VecDotType::from_float(&b, &mut b_quant)?; + T::from_float(a, &mut a_quant)?; + T::VecDotType::from_float(b, &mut b_quant)?; let result = T::vec_dot(length, &a_quant, &b_quant)?; let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?; - let reference_result = vec_dot_reference(&a, &b); + let reference_result = vec_dot_reference(a, b); if (result - result_unopt).abs() / length as f32 > 1e-6 { - candle_core::bail!( + bail!( "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}" ) } let error = (result - reference_result).abs() / length as f32; - let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; + let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m; if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR { - candle_core::bail!( - "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}", - ); + bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",); } // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML // => we use a slightly higher error threshold const ERROR_LENIENCY: f32 = 0.00001; if error - ERROR_LENIENCY > ggml_error { - candle_core::bail!( + bail!( "Dot product error {} exceeds ggml reference error {}", error, ggml_error @@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> { Ok(()) } +#[test] +fn quantized_mm() -> Result<()> { + ggml_matmul_error_test::<k_quants::BlockQ4_0>()?; + ggml_matmul_error_test::<k_quants::BlockQ4_1>()?; + ggml_matmul_error_test::<k_quants::BlockQ5_0>()?; + ggml_matmul_error_test::<k_quants::BlockQ5_1>()?; + ggml_matmul_error_test::<k_quants::BlockQ8_0>()?; + Ok(()) +} + /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result. 
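The second input pair added to `ggml_matmul_error_test` is a linear ramp that, per the comment above, is more likely to trigger the overflow reported in #1526; our reading is that this is the same overflow the NEON changes fix by accumulating in `i32` instead of `i16`. A scalar sketch of that failure mode (values are illustrative):

```rust
// One i8*i8 product fits in i16, and so does one vaddq_s16 pairing, but
// horizontally summing eight such lanes can exceed i16::MAX, which is
// where the old vaddvq_s16 reductions could wrap.
fn i16_accumulation_overflows() {
    let product = 127i32 * 127; // 16_129, fits in i16
    let paired = product + product; // 32_258, still fits (barely)
    let horizontal: i32 = (0..8).map(|_| paired).sum(); // 258_064
    assert!(paired <= i16::MAX as i32);
    assert!(horizontal > i16::MAX as i32); // an i16 reduction would have wrapped
}
```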
fn get_random_tensors( m: usize, diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 69438e0e..ccabf7ed 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } hf-hub = { workspace = true} intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 7e081530..00340d08 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,12 +11,12 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.3" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../candle-transformers", version = "0.3.3" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } -candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true } +candle = { workspace = true } +candle-datasets = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } +candle-onnx = { workspace = true, optional = true } csv = "1.3.0" cudarc = { workspace = true, optional = true } @@ -49,11 +49,12 @@ tokio = "1.29.1" [build-dependencies] anyhow = { workspace = true } +bindgen_cuda = { version = "0.1.1", optional = true } [features] default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] -cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] +cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"] cudnn = ["candle/cudnn"] flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/candle-examples/build.rs b/candle-examples/build.rs index 0af3a6a4..ba40aeb4 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -4,251 +4,34 @@ use std::io::Write; use std::path::PathBuf; struct KernelDirectories { - kernel_dir: &'static str, + kernel_glob: &'static str, rust_target: &'static str, include_dirs: &'static [&'static str], } -const DIRS: [KernelDirectories; 1] = [KernelDirectories { - kernel_dir: "examples/custom-ops/kernels/", +const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories { + kernel_glob: "examples/custom-ops/kernels/*.cu", rust_target: "examples/custom-ops/cuda_kernels.rs", include_dirs: &[], }]; -impl KernelDirectories { - fn maybe_build_ptx( - &self, - cu_file: &std::path::Path, - ptx_file: &std::path::Path, - compute_cap: usize, - ) -> Result<()> { - let should_compile = if ptx_file.exists() { - let ptx_modified = ptx_file.metadata()?.modified()?; - let cu_modified = cu_file.metadata()?.modified()?; - cu_modified.duration_since(ptx_modified).is_ok() - } else { - true - }; - if should_compile { - #[cfg(feature = "cuda")] - { - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - let mut command = 
std::process::Command::new("nvcc"); - let out_dir = ptx_file.parent().context("no parent for ptx file")?; - let include_dirs: Vec<String> = - self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); - command - .arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("--ptx") - .args(["--default-stream", "per-thread"]) - .args(["--output-directory", out_dir.to_str().unwrap()]) - .arg(format!("-I/{}", self.kernel_dir)) - .args(include_dirs) - .arg(cu_file); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - let output = command - .spawn() - .context("failed spawning nvcc")? - .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - } - #[cfg(not(feature = "cuda"))] - std::fs::OpenOptions::new() - .create(true) - .write(true) - .open(ptx_file)?; - } - Ok(()) - } - fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> { - println!("cargo:rerun-if-changed={}", self.kernel_dir); - let kernel_dir = PathBuf::from(self.kernel_dir); - let out_dir = out_dir.join(self.kernel_dir); - if !out_dir.exists() { - std::fs::create_dir_all(&out_dir)?; - } - let mut cu_files = vec![]; - let mut cuh_files = vec![]; - for file in std::fs::read_dir(kernel_dir)?.flatten() { - let file = file.path(); - match file.extension().and_then(|v| v.to_str()) { - Some("cu") => cu_files.push(file), - Some("cuh") => cuh_files.push(file), - _ => {} - } - } - - let mut ptx_paths = vec![]; - for cu_file in cu_files.iter() { - let file_stem = cu_file - .file_stem() - .with_context(|| format!("no stem {cu_file:?}"))?; - let file_stem = file_stem.to_string_lossy().into_owned(); - let ptx_file = out_dir.join(&format!("{file_stem}.ptx")); - self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?; - ptx_paths.push(ptx_file); - } - - let regenerate_rs_file = true; - if regenerate_rs_file { - let mut file = std::fs::File::create(self.rust_target)?; - for ptx_path in ptx_paths { - let name = ptx_path - .file_stem() - .context("empty stem")? - .to_string_lossy(); - file.write_all(b"#[rustfmt::skip]\n")?; - let const_definition = format!( - r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#, - name.to_uppercase().replace('.', "_"), - self.kernel_dir, - ); - file.write_all(const_definition.as_bytes())?; - file.write_all(b"\n")?; - } - } - Ok(()) - } -} - fn main() -> Result<()> { println!("cargo:rerun-if-changed=build.rs"); - let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; - let out_dir = PathBuf::from(out_dir); - #[cfg(feature = "cuda")] - set_cuda_include_dir()?; #[cfg(feature = "cuda")] - let compute_cap = compute_cap()?; - #[cfg(not(feature = "cuda"))] - let compute_cap = 0; - for d in DIRS { - d.process(&out_dir, compute_cap)? - } - Ok(()) -} - -fn set_cuda_include_dir() -> Result<()> { - // NOTE: copied from cudarc build.rs. 
- let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::<PathBuf>::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - let roots = roots.into_iter().map(Into::<PathBuf>::into); - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .context("cannot find include/cuda.h")?; - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - Ok(()) -} - -#[allow(unused)] -fn compute_cap() -> Result<usize> { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute cap from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::<usize>() - .context("Could not parse code")? - } else { - // Grab compute cap from nvidia-smi - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap.parse::<usize>() - .with_context(|| format!("cannot parse as int {cap}"))? - }; - - // Grab available GPU codes from nvcc and select the highest one - let max_nvcc_code = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::<Vec<&str>>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::<Vec<&str>>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::<usize>() { - codes.push(num); - } - } + { + for kdir in KERNEL_DIRS.iter() { + let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write(kdir.rust_target).unwrap() } - codes.sort(); - if !codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}." - ); - } - *codes.last().unwrap() - }; - - // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc, - // then choose the highest gpu code in nvcc - if compute_cap > max_nvcc_code { - println!( - "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}." 
- ); - compute_cap = max_nvcc_code; } - - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - compute_cap = compute_cap_str - .parse::<usize>() - .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?; - println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP"); + #[cfg(not(feature = "cuda"))] + { + for kdir in KERNEL_DIRS.iter() { + let _file = std::fs::File::create(kdir.rust_target)?; + } } - println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); - Ok(compute_cap) + Ok(()) } diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs index 0bee73aa..c00b601b 100644 --- a/candle-examples/examples/custom-ops/cuda_kernels.rs +++ b/candle-examples/examples/custom-ops/cuda_kernels.rs @@ -1,2 +1 @@ -#[rustfmt::skip] -pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx")); +pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx")); diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs index f2f534dc..30e413c1 100644 --- a/candle-examples/examples/custom-ops/main.rs +++ b/candle-examples/examples/custom-ops/main.rs @@ -6,7 +6,8 @@ #[cfg(feature = "mkl")] extern crate intel_mkl_src; -#[allow(unused)] +#[rustfmt::skip] +#[cfg(feature = "cuda")] mod cuda_kernels; use clap::Parser; diff --git a/candle-examples/examples/repvgg/README.md b/candle-examples/examples/repvgg/README.md new file mode 100644 index 00000000..2cb807c1 --- /dev/null +++ b/candle-examples/examples/repvgg/README.md @@ -0,0 +1,20 @@ +# candle-repvgg + +A candle implementation of inference using a pre-trained [repvgg](https://arxiv.org/abs/2101.03697). +This uses a classification head trained on the ImageNet dataset and returns the +probabilities for the top-5 classes. 
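For reference, the pipeline the new example implements, reduced to a sketch that mirrors the `main.rs` below (fixed `a0` variant, CPU only, error handling simplified):

```rust
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::repvgg;

fn classify(image_path: &str) -> anyhow::Result<Vec<(usize, f32)>> {
    let device = candle_examples::device(true)?;
    let image = candle_examples::imagenet::load_image224(image_path)?;
    let api = hf_hub::api::sync::Api::new()?;
    let file = api
        .model("timm/repvgg_a0.rvgg_in1k".to_string())
        .get("model.safetensors")?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)? };
    let model = repvgg::repvgg(&repvgg::Config::a0(), 1000, vb)?;
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut top: Vec<(usize, f32)> = prs.into_iter().enumerate().collect();
    top.sort_by(|a, b| b.1.total_cmp(&a.1));
    top.truncate(5);
    Ok(top)
}
```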
+ +## Running an example + +``` +$ cargo run --example repvgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg + +loaded image Tensor[dims 3, 224, 224; f32] +model built +mountain bike, all-terrain bike, off-roader: 61.70% +bicycle-built-for-two, tandem bicycle, tandem: 33.14% +unicycle, monocycle : 4.88% +crash helmet : 0.15% +moped : 0.04% + +``` diff --git a/candle-examples/examples/repvgg/main.rs b/candle-examples/examples/repvgg/main.rs new file mode 100644 index 00000000..0864c559 --- /dev/null +++ b/candle-examples/examples/repvgg/main.rs @@ -0,0 +1,111 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::repvgg; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + A0, + A1, + A2, + B0, + B1, + B2, + B3, + B1G4, + B2G4, + B3G4, +} + +impl Which { + fn model_filename(&self) -> String { + let name = match self { + Self::A0 => "a0", + Self::A1 => "a1", + Self::A2 => "a2", + Self::B0 => "b0", + Self::B1 => "b1", + Self::B2 => "b2", + Self::B3 => "b3", + Self::B1G4 => "b1g4", + Self::B2G4 => "b2g4", + Self::B3G4 => "b3g4", + }; + format!("timm/repvgg_{}.rvgg_in1k", name) + } + + fn config(&self) -> repvgg::Config { + match self { + Self::A0 => repvgg::Config::a0(), + Self::A1 => repvgg::Config::a1(), + Self::A2 => repvgg::Config::a2(), + Self::B0 => repvgg::Config::b0(), + Self::B1 => repvgg::Config::b1(), + Self::B2 => repvgg::Config::b2(), + Self::B3 => repvgg::Config::b3(), + Self::B1G4 => repvgg::Config::b1g4(), + Self::B2G4 => repvgg::Config::b2g4(), + Self::B3G4 => repvgg::Config::b3g4(), + } + } +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option<String>, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(value_enum, long, default_value_t=Which::A0)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image224(args.image)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let model_name = args.which.model_filename(); + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(model_name); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = repvgg::repvgg(&args.which.config(), 1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::<f32>()?; + let mut prs = prs.iter().enumerate().collect::<Vec<_>>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. 
* pr + ); + } + Ok(()) +} diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 64e690e6..d8e8da82 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" } +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] +bindgen_cuda = "0.1.1" anyhow = { version = "1", features = ["backtrace"] } -num_cpus = "1.15.0" -rayon = "1.7.0" + [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] } +candle-nn = { path = "../candle-nn", features = ["cuda"] } diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index fde3aeed..4002770b 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -2,44 +2,32 @@ // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment // variable in order to cache the compiled artifacts and avoid recompiling too often. use anyhow::{Context, Result}; -use rayon::prelude::*; use std::path::PathBuf; -use std::str::FromStr; const KERNEL_FILES: [&str; 17] = [ - "flash_api.cu", - "flash_fwd_hdim128_fp16_sm80.cu", - "flash_fwd_hdim160_fp16_sm80.cu", - "flash_fwd_hdim192_fp16_sm80.cu", - "flash_fwd_hdim224_fp16_sm80.cu", - "flash_fwd_hdim256_fp16_sm80.cu", - "flash_fwd_hdim32_fp16_sm80.cu", - "flash_fwd_hdim64_fp16_sm80.cu", - "flash_fwd_hdim96_fp16_sm80.cu", - "flash_fwd_hdim128_bf16_sm80.cu", - "flash_fwd_hdim160_bf16_sm80.cu", - "flash_fwd_hdim192_bf16_sm80.cu", - "flash_fwd_hdim224_bf16_sm80.cu", - "flash_fwd_hdim256_bf16_sm80.cu", - "flash_fwd_hdim32_bf16_sm80.cu", - "flash_fwd_hdim64_bf16_sm80.cu", - "flash_fwd_hdim96_bf16_sm80.cu", + "kernels/flash_api.cu", + "kernels/flash_fwd_hdim128_fp16_sm80.cu", + "kernels/flash_fwd_hdim160_fp16_sm80.cu", + "kernels/flash_fwd_hdim192_fp16_sm80.cu", + "kernels/flash_fwd_hdim224_fp16_sm80.cu", + "kernels/flash_fwd_hdim256_fp16_sm80.cu", + "kernels/flash_fwd_hdim32_fp16_sm80.cu", + "kernels/flash_fwd_hdim64_fp16_sm80.cu", + "kernels/flash_fwd_hdim96_fp16_sm80.cu", + "kernels/flash_fwd_hdim128_bf16_sm80.cu", + "kernels/flash_fwd_hdim160_bf16_sm80.cu", + "kernels/flash_fwd_hdim192_bf16_sm80.cu", + "kernels/flash_fwd_hdim224_bf16_sm80.cu", + "kernels/flash_fwd_hdim256_bf16_sm80.cu", + "kernels/flash_fwd_hdim32_bf16_sm80.cu", + "kernels/flash_fwd_hdim64_bf16_sm80.cu", + "kernels/flash_fwd_hdim96_bf16_sm80.cu", ]; fn main() -> Result<()> { - let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else( - |_| num_cpus::get_physical(), - |s| usize::from_str(&s).unwrap(), - ); - - rayon::ThreadPoolBuilder::new() - .num_threads(num_cpus) - .build_global() - .unwrap(); - println!("cargo:rerun-if-changed=build.rs"); for kernel_file in KERNEL_FILES.iter() { - println!("cargo:rerun-if-changed=kernels/{kernel_file}"); + println!("cargo:rerun-if-changed={kernel_file}"); } println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h"); println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h"); @@ -66,223 +54,30 @@ fn main() -> Result<()> { )) } }; - set_cuda_include_dir()?; - - let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - let compute_cap = compute_cap()?; + let kernels = 
KERNEL_FILES.iter().collect(); + let builder = bindgen_cuda::Builder::default() + .kernel_paths(kernels) + .out_dir(build_dir.clone()) + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("-Icutlass/include") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose"); let out_file = build_dir.join("libflashattention.a"); + builder.build_lib(out_file); - let kernel_dir = PathBuf::from("kernels"); - let cu_files: Vec<_> = KERNEL_FILES - .iter() - .map(|f| { - let mut obj_file = out_dir.join(f); - obj_file.set_extension("o"); - (kernel_dir.join(f), obj_file) - }) - .collect(); - let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified()); - let should_compile = if out_file.exists() { - kernel_dir - .read_dir() - .expect("kernels folder should exist") - .any(|entry| { - if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) { - let in_modified = entry.metadata().unwrap().modified().unwrap(); - in_modified.duration_since(*out_modified).is_ok() - } else { - true - } - }) - } else { - true - }; - if should_compile { - cu_files - .par_iter() - .map(|(cu_file, obj_file)| { - let mut command = std::process::Command::new("nvcc"); - command - .arg("-std=c++17") - .arg("-O3") - .arg("-U__CUDA_NO_HALF_OPERATORS__") - .arg("-U__CUDA_NO_HALF_CONVERSIONS__") - .arg("-U__CUDA_NO_HALF2_OPERATORS__") - .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") - .arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("-c") - .args(["-o", obj_file.to_str().unwrap()]) - .args(["--default-stream", "per-thread"]) - .arg("-Icutlass/include") - .arg("--expt-relaxed-constexpr") - .arg("--expt-extended-lambda") - .arg("--use_fast_math") - .arg("--verbose"); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - command.arg(cu_file); - let output = command - .spawn() - .context("failed spawning nvcc")? - .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - &command, - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - Ok(()) - }) - .collect::<Result<()>>()?; - let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>(); - let mut command = std::process::Command::new("nvcc"); - command - .arg("--lib") - .args(["-o", out_file.to_str().unwrap()]) - .args(obj_files); - let output = command - .spawn() - .context("failed spawning nvcc")? - .wait_with_output()?; - if !output.status.success() { - anyhow::bail!( - "nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - &command, - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - } - } println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=stdc++"); - /* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never - finishing to run for some reason. Calling nvcc manually worked fine. 
- cc::Build::new() - .cuda(true) - .include("cutlass/include") - .flag("--expt-relaxed-constexpr") - .flag("--default-stream") - .flag("per-thread") - .flag(&format!("--gpu-architecture=sm_{compute_cap}")) - .file("kernels/flash_fwd_hdim32_fp16_sm80.cu") - .compile("flashattn"); - */ Ok(()) } - -fn set_cuda_include_dir() -> Result<()> { - // NOTE: copied from cudarc build.rs. - let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::<PathBuf>::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - let roots = roots.into_iter().map(Into::<PathBuf>::into); - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .context("cannot find include/cuda.h")?; - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - Ok(()) -} - -#[allow(unused)] -fn compute_cap() -> Result<usize> { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute caps from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::<usize>() - .context("Could not parse compute cap")? - } else { - // Use nvidia-smi to get the current compute cap - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - let cap = cap - .parse::<usize>() - .with_context(|| format!("cannot parse as int {cap}"))?; - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap - }; - - // Grab available GPU codes from nvcc and select the highest one - let (supported_nvcc_codes, max_nvcc_code) = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::<Vec<&str>>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::<Vec<&str>>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::<usize>() { - codes.push(num); - } - } - } - codes.sort(); - let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?; - (codes, max_nvcc_code) - }; - - // Check that nvcc supports the asked compute caps - if !supported_nvcc_codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." 
- ); - } - if compute_cap > max_nvcc_code { - anyhow::bail!( - "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" - ); - } - - Ok(compute_cap) -} diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index e81fe39c..0cd4a14d 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0" [dependencies] [build-dependencies] -anyhow = { version = "1", features = ["backtrace"] } -glob = "0.3.1" -rayon = "1.7.0" +bindgen_cuda = "0.1.1" diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 17a0bf9c..63d744ca 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -1,243 +1,8 @@ -use std::io::Write; - fn main() { println!("cargo:rerun-if-changed=build.rs"); - cuda::set_include_dir(); - let (write, kernel_paths) = cuda::build_ptx(); - if write { - let mut file = std::fs::File::create("src/lib.rs").unwrap(); - for kernel_path in kernel_paths { - let name = kernel_path.file_stem().unwrap().to_str().unwrap(); - file.write_all( - format!( - r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#, - name.to_uppercase().replace('.', "_"), - name - ) - .as_bytes(), - ) - .unwrap(); - file.write_all(&[b'\n']).unwrap(); - } - } -} - -mod cuda { - use anyhow::{Context, Result}; - - pub fn set_include_dir() { - use std::path::PathBuf; - // NOTE: copied from cudarc build.rs. - // We can't actually set a env!() value from another crate, - // so we have to do that here. - - // use PathBuf; - - let env_vars = [ - "CUDA_PATH", - "CUDA_ROOT", - "CUDA_TOOLKIT_ROOT_DIR", - "CUDNN_LIB", - ]; - #[allow(unused)] - let env_vars = env_vars - .into_iter() - .map(std::env::var) - .filter_map(Result::ok) - .map(Into::<PathBuf>::into); - - let roots = [ - "/usr", - "/usr/local/cuda", - "/opt/cuda", - "/usr/lib/cuda", - "C:/Program Files/NVIDIA GPU Computing Toolkit", - "C:/CUDA", - ]; - #[allow(unused)] - let roots = roots.into_iter().map(Into::<PathBuf>::into); - - #[cfg(feature = "ci-check")] - let root: PathBuf = "ci".into(); - - #[cfg(not(feature = "ci-check"))] - let root = env_vars - .chain(roots) - .find(|path| path.join("include").join("cuda.h").is_file()) - .unwrap(); - - println!( - "cargo:rustc-env=CUDA_INCLUDE_DIR={}", - root.join("include").display() - ); - } - - pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) { - use rayon::prelude::*; - use std::path::PathBuf; - let out_dir = std::env::var("OUT_DIR").unwrap(); - let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu") - .unwrap() - .map(|p| p.unwrap()) - .collect(); - let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh") - .unwrap() - .map(|p| p.unwrap()) - .collect(); - - println!("cargo:rerun-if-changed=src/"); - // for path in &kernel_paths { - // println!("cargo:rerun-if-changed={}", path.display()); - // } - - for path in &mut include_directories { - // println!("cargo:rerun-if-changed={}", path.display()); - let destination = - std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap()); - std::fs::copy(path.clone(), destination).unwrap(); - // remove the filename from the path so it's just the directory - path.pop(); - } - - include_directories.sort(); - include_directories.dedup(); - - let compute_cap = compute_cap().expect("Could not get Cuda compute cap"); - - #[allow(unused)] - let include_options: Vec<String> = include_directories - .into_iter() - .map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap()) - .collect::<Vec<_>>(); - - let 
ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); - println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); - let children = kernel_paths - .par_iter() - .flat_map(|p| { - let mut output = p.clone(); - output.set_extension("ptx"); - let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap()); - - let ignore = if output_filename.exists() { - let out_modified = output_filename.metadata().unwrap().modified().unwrap(); - let in_modified = p.metadata().unwrap().modified().unwrap(); - out_modified.duration_since(in_modified).is_ok() - } else { - false - }; - if ignore { - None - } else { - let mut command = std::process::Command::new("nvcc"); - command.arg(format!("--gpu-architecture=sm_{compute_cap}")) - .arg("--ptx") - .args(["--default-stream", "per-thread"]) - .args(["--output-directory", &out_dir]) - // Flash attention only - // .arg("--expt-relaxed-constexpr") - .args(&include_options); - if let Ok(ccbin_path) = &ccbin_env { - command - .arg("-allow-unsupported-compiler") - .args(["-ccbin", ccbin_path]); - } - command.arg(p); - Some((p, command.spawn() - .expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output())) - } - }) - .collect::<Vec<_>>(); - - let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx")) - .unwrap() - .map(|p| p.unwrap()) - .collect(); - // We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed - // some old ones - let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len(); - for (kernel_path, child) in children { - let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - assert!( - output.status.success(), - "nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ); - } - (write, kernel_paths) - } - - #[allow(unused)] - fn compute_cap() -> Result<usize> { - println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); - - // Try to parse compute caps from env - let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); - compute_cap_str - .parse::<usize>() - .context("Could not parse code")? - } else { - // Use nvidia-smi to get the current compute cap - let out = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=compute_cap") - .arg("--format=csv") - .output() - .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; - let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; - let mut lines = out.lines(); - assert_eq!( - lines.next().context("missing line in stdout")?, - "compute_cap" - ); - let cap = lines - .next() - .context("missing line in stdout")? - .replace('.', ""); - let cap = cap - .parse::<usize>() - .with_context(|| format!("cannot parse as int {cap}"))?; - println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); - cap - }; - - // Grab available GPU codes from nvcc and select the highest one - let (supported_nvcc_codes, max_nvcc_code) = { - let out = std::process::Command::new("nvcc") - .arg("--list-gpu-code") - .output() - .expect("`nvcc` failed. 
Ensure that you have CUDA installed and that `nvcc` is in your PATH."); - let out = std::str::from_utf8(&out.stdout).unwrap(); - - let out = out.lines().collect::<Vec<&str>>(); - let mut codes = Vec::with_capacity(out.len()); - for code in out { - let code = code.split('_').collect::<Vec<&str>>(); - if !code.is_empty() && code.contains(&"sm") { - if let Ok(num) = code[1].parse::<usize>() { - codes.push(num); - } - } - } - codes.sort(); - let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?; - (codes, max_nvcc_code) - }; - - // Check that nvcc supports the asked compute caps - if !supported_nvcc_codes.contains(&compute_cap) { - anyhow::bail!( - "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}." - ); - } - if compute_cap > max_nvcc_code { - anyhow::bail!( - "CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}" - ); - } - - Ok(compute_cap) - } + let builder = bindgen_cuda::Builder::default(); + println!("cargo:info={builder:?}"); + let bindings = builder.build_ptx().unwrap(); + bindings.write("src/lib.rs").unwrap(); } diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 441d2e88..187cb4fd 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"] categories = ["science"] license = "MIT OR Apache-2.0" + [dependencies] -metal = { version = "0.27.0", features = ["mps"]} +metal = { version = "0.27.0", features = ["mps"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" [dev-dependencies] -half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +half = { version = "2.3.1", features = [ + "num-traits", + "use-intrinsics", + "rand_distr", +] } rand = "0.8.5" diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 4166d811..3d8e7f0d 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -117,7 +117,7 @@ ELU(elu_f32, float) ELU(elu_f16, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) AFFINE(affine_bf16, bfloat); POWF(powf_bf16, bfloat); ELU(elu_bf16, bfloat); diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cdc8fef8..eb560f16 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y) INT64_BINARY_OP_OUT(gt, x > y) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index e9ab17b1..9aead139 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -28,7 +28,7 @@ kernel void FN_NAME( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[tid]); \ + output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ + output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \ +} \ + +#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \ +kernel void FN_NAME( \ + 
constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) @@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) +CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) -#endif + +CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) +CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) +#endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 63357428..2a57bdbb 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float) SCATTER_ADD_OP(sa_u32_f16, uint, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) + INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 75f0286d..c427a690 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -178,8 +178,8 @@ macro_rules! ops{ pub mod unary { ops!( - cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh, - recip + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, + tanh, recip ); } pub mod binary { diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 83a56f0a..93dac662 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 067dece8..b15505f7 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Device, MTLResourceOptions}; +use metal::{Buffer, Device, MTLResourceOptions}; fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { let ptr = buffer.contents() as *const T; @@ -248,6 +248,34 @@ fn binary_add_f32() { assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); } +#[test] +fn binary_ops_bf16() { + let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect(); + let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32] + .into_iter() + .map(bf16::from_f32) + .collect(); + + macro_rules! 
binary_op { + ($opname:ident, $opexpr:expr) => {{ + let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT); + let expected: Vec<bf16> = lhs + .iter() + .zip(rhs.iter()) + .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y)) + .collect(); + assert_eq!(results, expected); + }}; + } + + binary_op!(add, |x, y| x + y); + binary_op!(sub, |x, y| x - y); + binary_op!(mul, |x, y| x * y); + binary_op!(div, |x, y| x / y); + binary_op!(min, |x: bf16, y| x.min(y)); + binary_op!(max, |x: bf16, y| x.max(y)); +} + fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let device = device(); let fence = device.new_fence(); @@ -296,6 +324,89 @@ fn cast_u32_f32() { assert_eq!(results, vec![1.0f32; 10_000]); } +#[test] +fn it_cast_bf16_u32() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<u32> = cast(&input, "cast_bf16_u32"); + let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f32() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<f32> = cast(&input, "cast_bf16_f32"); + let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u8_bf16() { + let input: Vec<u8> = (1..=3).map(|v| v as u8).collect(); + + let output: Vec<bf16> = cast(&input, "cast_u8_bf16"); + let expected: Vec<bf16> = input + .iter() + .map(|v| bf16::from_f32(*v as f32)) + .collect::<Vec<_>>(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u32_bf16() { + let input: Vec<u32> = (1..=3).map(|v| v as u32).collect(); + + let output: Vec<bf16> = cast(&input, "cast_u32_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f32_bf16() { + let input: Vec<f32> = (1..=3).map(|v| v as f32).collect(); + + let output: Vec<bf16> = cast(&input, "cast_f32_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_u8() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<u8> = cast(&input, "cast_bf16_u8"); + let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f16() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<f16> = cast(&input, "cast_bf16_f16"); + let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f16_bf16() { + let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); + + let output: Vec<bf16> = cast(&input, "cast_f16_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let device = device(); let fence = device.new_fence(); @@ -396,14 +507,14 @@ fn index_select() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; let ids = [0u32, 1, 0]; let dim = 0; - 
let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] @@ -419,7 +530,7 @@ fn index_select_f16() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -427,12 +538,38 @@ fn index_select_f16() { } #[test] +fn index_select_is_u32_bf16() { + let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_is_u8_bf16() { + let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u8, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( shape: &[usize], ids: &[I], dim: usize, + name: &'static str, ) -> Vec<T> { let device = Device::system_default().expect("no device found"); @@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let name = match core::mem::size_of::<T>() { - 4 => "is_u32_f32", - 2 => "is_u32_f16", - _ => unimplemented!(), - }; - let fence = device.new_fence(); let kernels = Kernels::new(fence); call_index_select( diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 7fbb613d..dcf803d8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) { T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); } +template <typename T> METAL_FUNC T relu(T in){ + if (in < 0) { + return 0; + } + return in; +} #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -110,6 +116,7 @@ UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) +UNARY_OP(relu) UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) @@ -120,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided) UNARY(id, int64_t, copy_i64, copy_i64_strided) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) BFLOAT_UNARY_OP(sqr) @@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) +BFLOAT_UNARY_OP(relu) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) 
#endif diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 5e0e5c2b..214e8a59 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } +candle = { workspace = true } half = { workspace = true } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } @@ -20,7 +20,7 @@ rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } metal = { workspace = true, optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } +candle-metal-kernels = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index ba33b07a..de1e3350 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { path = "../candle-core", package = "candle-core" } +candle-nn = { path = "../candle-nn" } prost = "0.12.1" [build-dependencies] @@ -20,4 +20,3 @@ prost-build = "0.12.1" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } clap = { version = "4.2.4", features = ["derive"] } - diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index a03c7559..7c6fbd68 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -15,9 +15,9 @@ crate-type = ["cdylib"] [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.3" } -candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true} +candle = { workspace = true } +candle-nn = { workspace = true } +candle-onnx = { workspace = true, optional = true } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 83bcff62..1a72c36a 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -12,9 +12,9 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true } -candle-nn = { path = "../candle-nn", version = "0.3.3" } +candle = { workspace = true } +candle-flash-attn = { workspace = true, optional = true } +candle-nn = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } rand = { workspace = true } diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 51c524f5..810f2803 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -1,6 +1,6 @@ use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use candle::{DType, Device, Result, Tensor}; -use candle_nn::{Embedding, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; use 
serde::Deserialize; pub const DTYPE: DType = DType::F32; @@ -112,11 +112,6 @@ impl Config { } } -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - struct Dropout { #[allow(dead_code)] pr: f64, diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs index c4a2d1db..e69f08c8 100644 --- a/candle-transformers/src/models/bigcode.rs +++ b/candle-transformers/src/models/bigcode.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { let weight = vb.get((size2, size1), "weight")?; @@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line Ok(Linear::new(weight, bias)) } -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs index 6ede136a..ef5a92fc 100644 --- a/candle-transformers/src/models/falcon.rs +++ b/candle-transformers/src/models/falcon.rs @@ -1,5 +1,5 @@ use candle::{DType, Device, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder}; const MAX_SEQ_LEN: usize = 5000; @@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { Ok(LayerNorm::new(weight, bias, eps)) } -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py #[derive(Debug)] pub struct Config { diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 7e8c8920..f003866a 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,6 +1,6 @@ use super::with_tracing::{linear_no_bias as linear, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, Module, VarBuilder}; +use candle_nn::{embedding, Embedding, Module, VarBuilder}; use serde::Deserialize; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -136,11 +136,6 @@ impl Cache { } } -fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; - Ok(Embedding::new(embeddings, cfg.hidden_size)) -} - struct RmsNorm { inner: candle_nn::RmsNorm, span: tracing::Span, @@ -409,7 +404,7 @@ impl Llama { } pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { - let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let lm_head = 
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 94a3bd5b..a60b5a06 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -26,6 +26,7 @@ pub mod quantized_mixformer; pub mod quantized_mpt; pub mod quantized_stable_lm; pub mod quantized_t5; +pub mod repvgg; pub mod resnet; pub mod segment_anything; pub mod stable_diffusion; diff --git a/candle-transformers/src/models/repvgg.rs b/candle-transformers/src/models/repvgg.rs new file mode 100644 index 00000000..34016e5b --- /dev/null +++ b/candle-transformers/src/models/repvgg.rs @@ -0,0 +1,306 @@ +//! RepVGG inference implementation +//! +//! See "RepVGG: Making VGG-style ConvNets Great Again" Ding et al. 2021 +//! https://arxiv.org/abs/2101.03697 + +use candle::{Result, Tensor, D}; +use candle_nn::{ + batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder, +}; + +const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512]; + +#[derive(Clone)] +pub struct Config { + a: f32, + b: f32, + groups: usize, + stages: [usize; 4], +} + +impl Config { + pub fn a0() -> Self { + Self { + a: 0.75, + b: 2.5, + groups: 1, + stages: [2, 4, 14, 1], + } + } + + pub fn a1() -> Self { + Self { + a: 1.0, + b: 2.5, + groups: 1, + stages: [2, 4, 14, 1], + } + } + + pub fn a2() -> Self { + Self { + a: 1.5, + b: 2.75, + groups: 1, + stages: [2, 4, 14, 1], + } + } + + pub fn b0() -> Self { + Self { + a: 1.0, + b: 2.5, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b1() -> Self { + Self { + a: 2.0, + b: 4.0, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b2() -> Self { + Self { + a: 2.5, + b: 5.0, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b3() -> Self { + Self { + a: 3.0, + b: 5.0, + groups: 1, + stages: [4, 6, 16, 1], + } + } + + pub fn b1g4() -> Self { + Self { + a: 2.0, + b: 4.0, + groups: 4, + stages: [4, 6, 16, 1], + } + } + + pub fn b2g4() -> Self { + Self { + a: 2.5, + b: 5.0, + groups: 4, + stages: [4, 6, 16, 1], + } + } + + pub fn b3g4() -> Self { + Self { + a: 3.0, + b: 5.0, + groups: 4, + stages: [4, 6, 16, 1], + } + } +} + +// fuses a convolutional kernel and a batchnorm layer into a convolutional layer +// based on the _fuse_bn_tensor method in timm +// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602 +fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> { + let (gamma, beta) = bn.weight_and_bias().unwrap(); + let mu = bn.running_mean(); + let sigma = (bn.running_var() + bn.eps())?.sqrt(); + let gps = (gamma / sigma)?; + let bias = (beta - mu * &gps)?; + let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?; + + Ok((weights, bias)) +} + +// A RepVGG layer has a different training time and inference time architecture. +// The latter is a simple and efficient equivalent transformation of the former +// realized by a structural reparameterization technique, where 3x3 and 1x1 convolutions +// along with identity branches and batchnorm layers are fused into a single 3x3 convolution. 
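+// Concretely, fuse_conv_bn above folds the batchnorm statistics into the
+// convolution as, per output channel:
+//   scale = gamma / sqrt(running_var + eps)
+//   w'    = w * scale
+//   b'    = beta - running_mean * scale
+// The 1x1 branch is zero-padded to a 3x3 kernel, and the identity branch is
+// written as a 3x3 kernel whose center tap is 1, so the three branches sum
+// into a single convolution weight and bias.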
+fn repvgg_layer( + has_identity: bool, + dim: usize, + stride: usize, + in_channels: usize, + out_channels: usize, + groups: usize, + vb: VarBuilder, +) -> Result<Func<'static>> { + let conv2d_cfg = Conv2dConfig { + stride, + groups, + padding: 1, + ..Default::default() + }; + + // read and reparameterize the 1x1 conv and bn into w1 and b1 + // based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L543 + + let conv1x1_bn = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?; + let conv1x1 = conv2d_no_bias( + in_channels, + out_channels, + 1, + conv2d_cfg, + vb.pp("conv_1x1.conv"), + )?; + + let (mut w1, b1) = fuse_conv_bn(conv1x1.weight(), conv1x1_bn)?; + + // resize to 3x3 + w1 = w1.pad_with_zeros(D::Minus1, 1, 1)?; + w1 = w1.pad_with_zeros(D::Minus2, 1, 1)?; + + // read and reparameterize the 3x3 conv and bn into w3 and b3 + let convkxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?; + let conv3x3 = conv2d_no_bias( + in_channels, + out_channels, + 3, + conv2d_cfg, + vb.pp("conv_kxk.conv"), + )?; + + let (w3, b3) = fuse_conv_bn(conv3x3.weight(), convkxk_bn)?; + + let mut w = (w1 + w3)?; + let mut b = (b1 + b3)?; + + // read and reparameterize the identity bn into wi and bi + if has_identity { + let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?; + + // create a 3x3 convolution equivalent to the identity branch + let mut weights: Vec<f32> = vec![0.0; conv3x3.weight().elem_count()]; + + // https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L620 + let in_dim = in_channels / groups; + for i in 0..in_channels { + weights[i * in_dim * 3 * 3 + (i % in_dim) * 3 * 3 + 4] = 1.0; + } + + let weights = &Tensor::from_vec(weights, w.shape(), w.device())?; + let (wi, bi) = fuse_conv_bn(weights, identity_bn)?; + + w = (w + wi)?; + b = (b + bi)?; + } + + // create the 3x3 conv equivalent to the sum of 3x3, 1x1 and identity branches + let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg); + + Ok(Func::new(move |xs| { + let xs = xs.apply(&reparam_conv)?.relu()?; + Ok(xs) + })) +} + +// Get the number of output channels per stage taking into account the multipliers +fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize { + let channels = CHANNELS_PER_STAGE[stage] as f32; + + match stage { + 0 => std::cmp::min(64, (channels * a) as usize), + 4 => (channels * b) as usize, + _ => (channels * a) as usize, + } +} + +// Each stage is made of layers. The first layer always downsamples with stride 2. +// All but the first layer have a residual connection. +// The G4 variants have a groupwise convolution instead of a dense one on odd layers +// counted across stage boundaries, so we keep track of which layer we are in the +// full model. +fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> { + let nlayers = cfg.stages[idx - 1]; + let mut layers = Vec::with_capacity(nlayers); + let prev_layers: usize = cfg.stages[..idx - 1].iter().sum(); + let out_channels_prev = output_channels_per_stage(cfg.a, cfg.b, idx - 1); + let out_channels = output_channels_per_stage(cfg.a, cfg.b, idx); + + for layer_idx in 0..nlayers { + let (has_identity, stride, in_channels) = if layer_idx == 0 { + (false, 2, out_channels_prev) + } else { + (true, 1, out_channels) + }; + + let groups = if (prev_layers + layer_idx) % 2 == 1 { + cfg.groups + } else { + 1 + }; + + layers.push(repvgg_layer( + has_identity, + out_channels, + stride, + in_channels, + out_channels, + groups, + vb.pp(layer_idx), + )?) 
+ } + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + for layer in layers.iter() { + xs = xs.apply(layer)? + } + Ok(xs) + })) +} + +// Build a RepVGG model for a given configuration. +fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> { + let cls = match nclasses { + None => None, + Some(nclasses) => { + let outputs = output_channels_per_stage(config.a, config.b, 4); + let linear = linear(outputs, nclasses, vb.pp("head.fc"))?; + Some(linear) + } + }; + + let stem_dim = output_channels_per_stage(config.a, config.b, 0); + let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?; + let vb = vb.pp("stages"); + let stage1 = repvgg_stage(config, 1, vb.pp(0))?; + let stage2 = repvgg_stage(config, 2, vb.pp(1))?; + let stage3 = repvgg_stage(config, 3, vb.pp(2))?; + let stage4 = repvgg_stage(config, 4, vb.pp(3))?; + + Ok(Func::new(move |xs| { + let xs = xs + .apply(&stem)? + .apply(&stage1)? + .apply(&stage2)? + .apply(&stage3)? + .apply(&stage4)? + .mean(D::Minus1)? + .mean(D::Minus1)?; + match &cls { + None => Ok(xs), + Some(cls) => xs.apply(cls), + } + })) +} + +pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> { + repvgg_model(cfg, Some(nclasses), vb) +} + +pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> { + repvgg_model(cfg, None, vb) +} diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 25454ba6..ea2a59b9 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -1,12 +1,7 @@ use super::Config; use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} +use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; fn conv1d( in_channels: usize, diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml index 59ce1be3..51358e45 100644 --- a/candle-wasm-examples/bert/Cargo.toml +++ b/candle-wasm-examples/bert/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } @@ -27,7 +27,7 @@ safetensors = { workspace = true } # Wasm specific crates. 
console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" serde-wasm-bindgen = "0.6.0" diff --git a/candle-wasm-examples/blip/Cargo.toml b/candle-wasm-examples/blip/Cargo.toml index 904e90e6..f4de054e 100644 --- a/candle-wasm-examples/blip/Cargo.toml +++ b/candle-wasm-examples/blip/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 63f8a9c5..d46cdafa 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } @@ -26,7 +26,7 @@ serde_json = { workspace = true } # Wasm specific crates. console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } -gloo = "0.8" +gloo = "0.11" js-sys = "0.3.64" wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.37" diff --git a/candle-wasm-examples/phi/Cargo.toml b/candle-wasm-examples/phi/Cargo.toml index c4950df9..e437a937 100644 --- a/candle-wasm-examples/phi/Cargo.toml +++ b/candle-wasm-examples/phi/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml index 4d886bc2..1840bb62 100644 --- a/candle-wasm-examples/segment-anything/Cargo.toml +++ b/candle-wasm-examples/segment-anything/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.3" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.3" } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } num-traits = { workspace = true } # App crates. 
diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml
index 237f9e61..5f60d917 100644
--- a/candle-wasm-examples/t5/Cargo.toml
+++ b/candle-wasm-examples/t5/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -27,7 +27,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 5d2b2a38..92e206b2 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -26,7 +26,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml
index eb2c320b..ac76f9a7 100644
--- a/candle-wasm-examples/yolo/Cargo.toml
+++ b/candle-wasm-examples/yolo/Cargo.toml
@@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@@ -26,7 +26,7 @@ safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
-gloo = "0.8"
+gloo = "0.11"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
diff --git a/candle-wasm-tests/Cargo.toml b/candle-wasm-tests/Cargo.toml
index a684f2ce..40c37acd 100644
--- a/candle-wasm-tests/Cargo.toml
+++ b/candle-wasm-tests/Cargo.toml
@@ -7,7 +7,7 @@ keywords.workspace = true
categories.workspace = true
[dependencies]
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
+candle = { workspace = true }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
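A note on the model hunks above: bert, bigcode, falcon, llama and whisper all drop their hand-rolled `embedding` helpers in favor of the shared `candle_nn::embedding` constructor. A minimal sketch of the equivalent call (the `model.embed_tokens` prefix is taken from the llama hunk; the other models use their own prefixes):

```rust
use candle::Result;
use candle_nn::{embedding, Embedding, VarBuilder};

/// Reads a (vocab_size, hidden_size) `weight` tensor from the given prefix,
/// which is exactly what the removed per-model helpers did by hand.
fn token_embeddings(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
    embedding(vocab_size, hidden_size, vb.pp("model.embed_tokens"))
}
```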