diff options
77 files changed, 2327 insertions, 276 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 589accf9..86366c21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate. [760](https://github.com/huggingface/candle/pull/760). - Add the Segment-Anything Model (SAM) as an example [773](https://github.com/huggingface/candle/pull/773). -- TinyViT backbone for the segemnt anything example +- TinyViT backbone for the segment anything example [787](https://github.com/huggingface/candle/pull/787). - Shape with holes support [770](https://github.com/huggingface/candle/pull/770). @@ -19,7 +19,7 @@ exclude = [ resolver = "2" [workspace.package] -version = "0.3.1" +version = "0.3.2" edition = "2021" description = "Minimalist ML framework." repository = "https://github.com/huggingface/candle" @@ -54,7 +54,7 @@ These online demos run entirely in your browser: - [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition. - [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation. - [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation. -- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation. +- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation. - [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation. - [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning. @@ -62,11 +62,14 @@ We also provide a some command line based examples using state of the art models - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM. - [Falcon](./candle-examples/examples/falcon/): general LLM. -- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b. +- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with - performance larger than all publicly available 13b models as of 2023-09-28. + better performance than all publicly available 13b models as of 2023-09-28. +- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of + experts 8x7b general LLM with better performance than a Llama 2 70B model with + much faster inference. - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation. - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion. - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual @@ -78,7 +81,7 @@ We also provide a some command line based examples using state of the art models <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600"> - [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to - image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions. + image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions. <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200"> @@ -122,7 +125,7 @@ There are also some wasm examples for whisper and [whisper](https://huggingface.co/spaces/lmz/candle-whisper), [llama2](https://huggingface.co/spaces/lmz/candle-llama2), [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm), -[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm), +[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm), [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm). For LLaMA2, run the following command to retrieve the weight files and start a @@ -141,8 +144,10 @@ And then head over to ## Useful External Resources - [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A very detailed tutorial showing how to convert a PyTorch model to Candle. -- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has - out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples). +- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and + ergonomic LoRA implementation for Candle. `candle-lora` has + out-of-the-box LoRA support for many models from Candle, which can be found + [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples). - [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop. - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and @@ -171,8 +176,9 @@ If you have an addition to this list, please submit a pull request. - LLaMA v1 and v2. - Falcon. - StarCoder. - - Phi v1.5. + - Phi 1, 1.5, and 2. - Mistral 7b v0.1. + - Mixtral 8x7b v0.1. - StableLM-3B-4E1T. - Replit-code-v1.5-3B. - Bert. @@ -180,8 +186,9 @@ If you have an addition to this list, please submit a pull request. - Quantized LLMs. - Llama 7b, 13b, 70b, as well as the chat and code variants. - Mistral 7b, and 7b instruct. - - Zephyr 7b a and b (Mistral based). - - OpenChat 3.5 (Mistral based). + - Mixtral 8x7b. + - Zephyr 7b a and b (Mistral-7b based). + - OpenChat 3.5 (Mistral-7b based). - Text to text. - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction). - Marian MT (Machine Translation). diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml index d384019b..94f10c08 100644 --- a/candle-book/Cargo.toml +++ b/candle-book/Cargo.toml @@ -11,11 +11,11 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.1" } -candle-nn = { path = "../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../candle-transformers", version = "0.3.1" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.3.2" } +candle-nn = { path = "../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../candle-transformers", version = "0.3.2" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.2", optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-book/src/apps/dekstop.md b/candle-book/src/apps/desktop.md index 32cc4441..32cc4441 100644 --- a/candle-book/src/apps/dekstop.md +++ b/candle-book/src/apps/desktop.md diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 0f8c1a9f..52e79a5a 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -12,8 +12,8 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true } +candle-kernels = { path = "../candle-kernels", version = "0.3.2", optional = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.2", optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index fc0c79a2..c152f31f 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -114,7 +114,7 @@ impl Tensor { | Op::Unary(_node, UnaryOp::Round) => nodes, Op::Reshape(node) | Op::UpsampleNearest1D(node) - | Op::UpsampleNearest2D(node) + | Op::UpsampleNearest2D { arg: node, .. } | Op::AvgPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. } | Op::Copy(node) @@ -350,9 +350,27 @@ impl Tensor { Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest1d", })?, - Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { - op: "upsample-nearest2d", - })?, + Op::UpsampleNearest2D { + arg, + target_h, + target_w, + } => { + let (_n, c, h, w) = arg.dims4()?; + if target_h % h != 0 || target_w % w != 0 { + crate::bail!("backward not supported for non integer upscaling factors") + } + let scale_h = target_h / h; + let scale_w = target_w / w; + + if scale_h != scale_w { + crate::bail!("backward not supported for non uniform upscaling factors") + }; + let kernel = + Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?; + let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?; + let sum_grad = grads.or_insert(arg)?; + *sum_grad = conv_sum; + } Op::SliceScatter0(lhs, rhs, start_rhs) => { let rhs_sum_grad = grads.or_insert(rhs)?; let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?; diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index df106b73..e3ed41e5 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -64,7 +64,7 @@ impl Tensor { #[derive(Debug)] /// Generic structure used to index a slice of the tensor pub enum TensorIndexer { - /// This selects the elemnts for which an index has some specific value. + /// This selects the elements for which an index has some specific value. Select(usize), /// This is a regular slice, purely indexing a chunk of the tensor Narrow(Bound<usize>, Bound<usize>), diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index fbb20f6c..868673e7 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -132,7 +132,11 @@ pub enum Op { }, UpsampleNearest1D(Tensor), - UpsampleNearest2D(Tensor), + UpsampleNearest2D { + arg: Tensor, + target_h: usize, + target_w: usize, + }, Cat(Vec<Tensor>, usize), diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs index 5c3ac822..664f7653 100644 --- a/candle-core/src/quantized/avx.rs +++ b/candle-core/src/quantized/avx.rs @@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res q3 = q3.add(32); // Prepare low and high bits - // We hardcode the shifts here to avoid loading them into a seperate register + // We hardcode the shifts here to avoid loading them into a separate register let q3l_0 = _mm256_and_si256(q3bits, m3); let q3h_0 = if j == 0 { _mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0) @@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res let q5bits = _mm256_loadu_si256(q5 as *const __m256i); q5 = q5.add(32); - //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register + //Similar to q3k we hardcode the shifts here to avoid loading them into a separate register let q5l_0 = _mm256_and_si256(q5bits, m4); let q5l_0_shift_input = _mm256_and_si256(hbits, hmask); let q5l_0_right_shift = match j { diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index 620bc037..1e9dc517 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -463,7 +463,7 @@ impl Content { ) -> Result<QTensor> { let tensor_info = match self.tensor_infos.get(name) { Some(tensor_info) => tensor_info, - None => crate::bail!("cannot find tensor-infor for {name}"), + None => crate::bail!("cannot find tensor info for {name}"), }; tensor_info.read(reader, self.tensor_data_offset) } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index e478869a..f15f8c1c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,4 +1,4 @@ -//! Tensors are N-dimenional matrixes of elements using a single data type. +//! Tensors are N-dimensional matrixes of elements using a single data type. #![allow(clippy::redundant_closure_call)] use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{ @@ -361,6 +361,16 @@ impl Tensor { Self::new_impl(array, shape, device, false) } + /// Returns a new tensor with all the elements having the same specified value. Note that + /// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed. + pub fn full<D: crate::WithDType, S: Into<Shape>>( + value: D, + shape: S, + device: &Device, + ) -> Result<Self> { + Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape) + } + /// Creates a new 1D tensor from an iterator. pub fn from_iter<D: crate::WithDType>( iter: impl IntoIterator<Item = D>, @@ -669,7 +679,7 @@ impl Tensor { } /// Split a tensor into the specified number of chunks, this may return less chunks than - /// specificed. + /// specified. pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> { let dim = dim.to_index(self.shape(), "chunk")?; let size = self.dim(dim)?; @@ -994,7 +1004,11 @@ impl Tensor { /// tensor also has four dimensions, `(batch, channels, target_h, target_w)`. pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> { let (n, c, _h, _w) = self.dims4()?; - let op = BackpropOp::new1(self, Op::UpsampleNearest2D); + let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D { + arg, + target_h, + target_w, + }); let storage = self .storage() .upsample_nearest2d(self.layout(), target_h, target_w)?; @@ -2558,6 +2572,13 @@ impl Tensor { } mask.where_cond(/* on_true= */ &src, /* on_false= */ self) } + + /// Returns log(sum(exp(tensor), dim)). + pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + let exp = self.exp()?; + let sum = exp.sum(sum_dims)?; + sum.log() + } } macro_rules! bin_trait { diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 791532f2..16e7a82f 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -270,6 +270,166 @@ fn unary_grad(device: &Device) -> Result<()> { [0.7358, 2.0000, 0.2707, 1.0000] ); + // manually checked: see comments + let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?; + let y = x.interpolate2d(6, 6)?.reshape(36)?; + + #[rustfmt::skip] + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., 05., 06., + 07., 08., 09., 10., 11., 12., + 13., 14., 15., 16., 17., 18., + 19., 20., 21., 22., 23., 24., + 25., 26., 27., 28., 29., 30., + 31., 32., 33., 34., 35., 36., + ], + device, + )?; + // gradient should be + // row 1 + // 1+2+7+8 = 18 + // 3+4+9+10 = 26 + // 5+6+11+12 = 34 + // row 2 + // 13+14+19+20 = 66 + // 15+16+21+22 = 74 + // 17+18+23+24 = 82 + // row 3 + // 25+26+31+32 = 114 + // 27+28+33+34 = 122 + // 29+30+35+36 = 130 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?, + [[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]] + ); + + // manually checked: see comments + let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?; + let y = x.interpolate2d(6, 6)?.reshape(36)?; + + #[rustfmt::skip] + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., 05., 06., + 07., 08., 09., 10., 11., 12., + 13., 14., 15., 16., 17., 18., + 19., 20., 21., 22., 23., 24., + 25., 26., 27., 28., 29., 30., + 31., 32., 33., 34., 35., 36., + ], + device, + )?; + // gradient should be + // row 1 + // 1+2+3+7+8+9+13+14+15 = 72 + // 4+5+6+10+11+12+16+17+18 = 99 + // row 2 + // 19+20+21+25+26+27+31+32+33 = 234 + // 22+23+24+28+29+30+34+35+36 = 243 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + assert_eq!( + test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?, + [[72_f32, 99.], [234., 261.]] + ); + + // manually checked: see comments + let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?; + + let y = x.interpolate2d(4, 4)?.reshape(32)?; + + #[rustfmt::skip] + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., + 05., 06., 07., 08., + 09., 10., 11., 12., + 13., 14., 15., 16., + 17., 18., 19., 20., + 21., 22., 23., 24., + 25., 26., 27., 28., + 29., 30., 31., 32. + ], + device, + )?; + // gradient should be + // m1r1 + // 1+2+5+6=14 + // 3+4+7+8=22 + // m1r2 + // 9+10+13+14=46 + // 11+12+15+16=54 + // m2r1 + // 17+18+21+22=78 + // 19+20+23+24=86 + // m2r2 + // 25+26+29+30=110 + // 27+28+31+32=118 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + + assert_eq!( + test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?, + [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]] + ); + + // manually checked: see comments + let x = Var::new( + &[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]], + device, + )?; + + let y = x.interpolate2d(4, 4)?.reshape(32)?; + + #[rustfmt::skip] + let z = Tensor::new( + &[ + 1_f32, 02., 03., 04., + 05., 06., 07., 08., + 09., 10., 11., 12., + 13., 14., 15., 16., + 17., 18., 19., 20., + 21., 22., 23., 24., + 25., 26., 27., 28., + 29., 30., 31., 32. + ], + device, + )?; + // gradient should be + // m1r1 + // 1+2+5+6=14 + // 3+4+7+8=22 + // m1r2 + // 9+10+13+14=46 + // 11+12+15+16=54 + // m2r1 + // 17+18+21+22=78 + // 19+20+23+24=86 + // m2r2 + // 25+26+29+30=110 + // 27+28+31+32=118 + let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?; + + let grads = loss.backward()?; + + let grad_x = grads.get(&x).context("no grad for x")?; + + assert_eq!( + test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?, + [[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]] + ); Ok(()) } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c871dc96..e83fb55b 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1,4 +1,4 @@ -use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor}; +use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D}; fn zeros(device: &Device) -> Result<()> { let tensor = Tensor::zeros((5, 2), DType::F32, device)?; @@ -32,6 +32,14 @@ fn ones(device: &Device) -> Result<()> { Ok(()) } +fn full(device: &Device) -> Result<()> { + assert_eq!( + Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?, + [[42, 42, 42], [42, 42, 42]], + ); + Ok(()) +} + fn arange(device: &Device) -> Result<()> { assert_eq!( Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?, @@ -1072,6 +1080,7 @@ fn randn(device: &Device) -> Result<()> { test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal); test_device!(ones, ones_cpu, ones_gpu, ones_metal); +test_device!(full, full_cpu, full_gpu, full_metal); test_device!(arange, arange_cpu, arange_gpu, arange_metal); test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal); @@ -1221,3 +1230,26 @@ fn cumsum() -> Result<()> { ); Ok(()) } + +/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data. +/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon. +fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { + let a_vec: Vec<f64> = a.to_vec1()?; + let b_vec: Vec<f64> = b.to_vec1()?; + + assert_eq!(a_vec.len(), b_vec.len()); + for (a, b) in a_vec.iter().zip(b_vec.iter()) { + assert!((a - b).abs() < epsilon); + } + Ok(()) +} + +#[test] +fn logsumexp() -> Result<()> { + let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let output = input.logsumexp(D::Minus1)?; + // The expectations obtained from pytorch. + let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; + assert_close(&output, &expected, 0.00001)?; + Ok(()) +} diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 103f77d6..5b0b9dee 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.1" } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.3.2" } hf-hub = { workspace = true} intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index adfa529e..0c4bf20e 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,12 +11,12 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.3.1" } -candle-nn = { path = "../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../candle-transformers", version = "0.3.1" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true } -candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.3.2" } +candle-nn = { path = "../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../candle-transformers", version = "0.3.2" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.2", optional = true } +candle-onnx = { path = "../candle-onnx", version = "0.3.2", optional = true } cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } image = { workspace = true } diff --git a/candle-examples/build.rs b/candle-examples/build.rs index e21f1767..0af3a6a4 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -32,6 +32,8 @@ impl KernelDirectories { if should_compile { #[cfg(feature = "cuda")] { + let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN"); + println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN"); let mut command = std::process::Command::new("nvcc"); let out_dir = ptx_file.parent().context("no parent for ptx file")?; let include_dirs: Vec<String> = @@ -44,6 +46,11 @@ impl KernelDirectories { .arg(format!("-I/{}", self.kernel_dir)) .args(include_dirs) .arg(cu_file); + if let Ok(ccbin_path) = &ccbin_env { + command + .arg("-allow-unsupported-compiler") + .args(["-ccbin", ccbin_path]); + } let output = command .spawn() .context("failed spawning nvcc")? @@ -168,8 +175,16 @@ fn set_cuda_include_dir() -> Result<()> { #[allow(unused)] fn compute_cap() -> Result<usize> { - // Grab compute code from nvidia-smi - let mut compute_cap = { + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + + // Try to parse compute cap from env + let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}"); + compute_cap_str + .parse::<usize>() + .context("Could not parse code")? + } else { + // Grab compute cap from nvidia-smi let out = std::process::Command::new("nvidia-smi") .arg("--query-gpu=compute_cap") .arg("--format=csv") @@ -185,6 +200,7 @@ fn compute_cap() -> Result<usize> { .next() .context("missing line in stdout")? .replace('.', ""); + println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}"); cap.parse::<usize>() .with_context(|| format!("cannot parse as int {cap}"))? }; diff --git a/candle-examples/examples/bert/README.md b/candle-examples/examples/bert/README.md index 82ca5f40..5a75b516 100644 --- a/candle-examples/examples/bert/README.md +++ b/candle-examples/examples/bert/README.md @@ -2,10 +2,10 @@ Bert is a general large language model. In this example it can be used for two different tasks: + - Compute sentence embeddings for a prompt. - Compute similarities between a set of sentences. - ## Sentence embeddings Bert is used to compute the sentence embeddings for a prompt. The model weights @@ -24,6 +24,48 @@ cargo run --example bert --release -- --prompt "Here is a test sentence" > Tensor[[1, 7, 384], f32] ``` +### Custom models + +You can specify different models, such as BGE, with the `--model-id` flag: + +```bash +cargo run --example bert --release -- \ +--model-id BAAI/bge-large-zh-v1.5 \ +--prompt "Here is a test sentence" +Loaded and encoded 435.70775ms +[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1], + [-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0], + [ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1], + ... + [ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1], + [ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1], + [ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]] +Tensor[[1, 9, 1024], f32] +Took 176.744667ms +``` + +### Gelu approximation + +You can get a speedup by using an approximation of the gelu activation, with a +small loss of precision, by passing the `--approximate-gelu` flag: + +```bash +$ cargo run --example bert --release -- \ +--model-id BAAI/bge-large-zh-v1.5 \ +--prompt "Here is a test sentence" \ +--approximate-gelu +Loaded and encoded 244.388042ms +[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1], + [-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0], + [ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1], + ... + [ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1], + [ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1], + [ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]] +Tensor[[1, 9, 1024], f32] +Took 116.840791ms +``` + ## Similarities In this example, Bert is used to compute the sentence embeddings for a set of diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index fcd2eab9..88e29718 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -3,7 +3,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE}; use anyhow::{Error as E, Result}; use candle::Tensor; @@ -45,6 +45,10 @@ struct Args { /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, + + /// Use tanh based approximation for Gelu instead of erf implementation. + #[arg(long, default_value = "false")] + approximate_gelu: bool, } impl Args { @@ -73,7 +77,7 @@ impl Args { (config, tokenizer, weights) }; let config = std::fs::read_to_string(config_filename)?; - let config: Config = serde_json::from_str(&config)?; + let mut config: Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let vb = if self.use_pth { @@ -81,6 +85,9 @@ impl Args { } else { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } }; + if self.approximate_gelu { + config.hidden_act = HiddenAct::GeluApproximate; + } let model = BertModel::load(vb, &config)?; Ok((model, tokenizer)) } diff --git a/candle-examples/examples/mixtral/README.md b/candle-examples/examples/mixtral/README.md new file mode 100644 index 00000000..aec5c148 --- /dev/null +++ b/candle-examples/examples/mixtral/README.md @@ -0,0 +1,25 @@ +# candle-mixtral: 8x7b LLM using a sparse mixture of experts. + +Mixtral-8x7B-v0.1 is a pretrained generative LLM with 56 billion parameters. + +- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release. +- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub. + +## Running the example + +```bash +$ cargo run --example mixtral --release -- --prompt "def print_prime(n): " +def print_prime(n): # n is the number of prime numbers to be printed + i = 2 + count = 0 + while (count < n): + if (isPrime(i)): + print(i) + count += 1 + i += 1 + +def isPrime(n): + for x in range(2, int(n**0.5)+1): + if (n % x == 0): + ... +``` diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs new file mode 100644 index 00000000..fcde03c1 --- /dev/null +++ b/candle-examples/examples/mixtral/main.rs @@ -0,0 +1,263 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle_transformers::models::mixtral::{Config, Model}; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option<f64>, + top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + for &t in tokens.iter() { + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("</s>") { + Some(token) => token, + None => anyhow::bail!("cannot find the </s> token"), + }; + let start_gen = std::time::Instant::now(); + for index in 0..sample_len { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let start_pos = tokens.len().saturating_sub(context_size); + let ctxt = &tokens[start_pos..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, start_pos)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + use_flash_attn: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option<f64>, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 100)] + sample_len: usize, + + #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")] + model_id: String, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + tokenizer_file: Option<String>, + + #[arg(long)] + weight_files: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let repo = api.repo(Repo::with_revision( + args.model_id, + RepoType::Model, + args.revision, + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::<Vec<_>>(), + None => { + vec![ + repo.get("model-00001-of-00019.safetensors")?, + repo.get("model-00002-of-00019.safetensors")?, + repo.get("model-00003-of-00019.safetensors")?, + repo.get("model-00004-of-00019.safetensors")?, + repo.get("model-00005-of-00019.safetensors")?, + repo.get("model-00006-of-00019.safetensors")?, + repo.get("model-00007-of-00019.safetensors")?, + repo.get("model-00008-of-00019.safetensors")?, + repo.get("model-00009-of-00019.safetensors")?, + repo.get("model-00010-of-00019.safetensors")?, + repo.get("model-00011-of-00019.safetensors")?, + repo.get("model-00012-of-00019.safetensors")?, + repo.get("model-00013-of-00019.safetensors")?, + repo.get("model-00014-of-00019.safetensors")?, + repo.get("model-00015-of-00019.safetensors")?, + repo.get("model-00016-of-00019.safetensors")?, + repo.get("model-00017-of-00019.safetensors")?, + repo.get("model-00018-of-00019.safetensors")?, + repo.get("model-00019-of-00019.safetensors")?, + ] + } + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let start = std::time::Instant::now(); + let config = Config::v0_1_8x7b(args.use_flash_attn); + let device = candle_examples::device(args.cpu)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + let model = Model::new(&config, vb)?; + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} diff --git a/candle-examples/examples/phi/README.md b/candle-examples/examples/phi/README.md index 566411d1..70af6650 100644 --- a/candle-examples/examples/phi/README.md +++ b/candle-examples/examples/phi/README.md @@ -1,14 +1,33 @@ -# candle-phi: 1.3b LLM with state of the art performance for <10b models. +# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models. -[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using -only 1.3 billion parameters but with state of the art performance compared to +[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and +[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using +only 1.3 and 2.7 billion parameters but with state of the art performance compared to models with up to 10 billion parameters. The candle implementation provides both the standard version as well as a quantized variant. -## Running some example +## Running some examples +For the v2 version. +```bash +$ cargo run --example phi --release -- --model 2 \ + --prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?" + +A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom? + +Solution: +The potential energy of the skier is converted into kinetic energy as it slides down the slope. The formula for potential energy is mgh, where m is mass, g is acceleration due to gravity (9.8 m/s^2), and h is height. Since there's no friction, all the potential energy is converted into kinetic energy at the bottom of the slope. The formula for kinetic energy is 1/2mv^2, where v is velocity. We can equate these two formulas: +mgh = 1/2mv^2 +Solving for v, we get: +v = sqrt(2gh) +Substituting the given values, we get: +v = sqrt(2*9.8*40) = 28 m/s +Therefore, the skier speed at the bottom of the slope is 28 m/s. +``` + +For the v1.5 version. ```bash $ cargo run --example phi --release -- --prompt "def print_prime(n): " diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 720a4441..52d453b5 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -123,6 +123,8 @@ enum WhichModel { V1, #[value(name = "1.5")] V1_5, + #[value(name = "2")] + V2, PuffinPhiV2, PhiHermes, } @@ -158,7 +160,7 @@ struct Args { seed: u64, /// The length of the sample to generate (in tokens). - #[arg(long, short = 'n', default_value_t = 100)] + #[arg(long, short = 'n', default_value_t = 5000)] sample_len: usize, #[arg(long)] @@ -225,6 +227,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), + WhichModel::V2 => "microsoft/phi-2".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -241,7 +244,9 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(), + WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + "main".to_string() + } } } } @@ -250,27 +255,32 @@ fn main() -> Result<()> { let tokenizer_filename = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => match args.model { - WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?, + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?, WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { repo.get("tokenizer-puffin-phi-v2.json")? } }, }; - let filename = match args.weight_file { - Some(weight_file) => std::path::PathBuf::from(weight_file), + let filenames = match args.weight_file { + Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], None => { if args.quantized { match args.model { - WhichModel::V1 => repo.get("model-v1-q4k.gguf")?, - WhichModel::V1_5 => repo.get("model-q4k.gguf")?, - WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?, - WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?, + WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], + WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], + WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?], + WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], + WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } } else { match args.model { - WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?, - WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?, - WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?, + WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], + WhichModel::V2 => vec![ + repo.get("model-00001-of-00002.safetensors")?, + repo.get("model-00002-of-00002.safetensors")?, + ], + WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?], + WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?], } } } @@ -282,17 +292,24 @@ fn main() -> Result<()> { let config = match args.model { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), + WhichModel::V2 => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; - let model = QMixFormer::new(&config, vb)?; + let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; + let model = match args.model { + WhichModel::V2 => QMixFormer::new_v2(&config, vb)?, + _ => QMixFormer::new(&config, vb)?, + }; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; - let model = MixFormer::new(&config, vb)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let model = match args.model { + WhichModel::V2 => MixFormer::new_v2(&config, vb)?, + _ => MixFormer::new(&config, vb)?, + }; (Model::MixFormer(model), device) }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/quantized/README.md b/candle-examples/examples/quantized/README.md index bed09243..8144bffe 100644 --- a/candle-examples/examples/quantized/README.md +++ b/candle-examples/examples/quantized/README.md @@ -26,6 +26,19 @@ cargo run --example quantized --release -- --prompt "The best thing about coding > The best thing about coding in rust is 1.) that I don’t need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines. ``` +Using the mixtral sparse mixture of expert model: +```bash + +$ cargo run --example quantized --release -- --which mixtral --prompt "Lebesgue's integral is superior to Riemann's because " +> avx: true, neon: false, simd128: false, f16c: true +> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64 +> loaded 995 tensors (26.44GB) in 0.03s +Lebesgue's integral is superior to Riemann's because 1. it is defined for a wider class of functions, those which are absolutely integrable; 2. the definition does not involve limits in two variables---one being computed before the other (which makes some computations more difficult); and 3. interchange of order of integration is easier to establish than with Riemann's integral. On the other hand, Lebesgue's integral applies only for bounded functions defined on finite intervals; it does not provide numerical values for improper integrals. The latter are best evaluated using Cauchy's limit definition. + +The reason $f(x) = x^2$ is discontinuous at the ends of its interval of definition, and Riemann's integral requires continuity on the whole of an open interval containing it (see our earlier post), sine no such function exists with this property, is that the endpoints are infinite in measure for Lebesgue's integral. + ``` + + ## Command-line flags Run with `--help` to see all options. diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index ab8a56ba..df758b4f 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -45,6 +45,10 @@ enum Which { L13bCode, #[value(name = "32b-code")] L34bCode, + #[value(name = "7b-leo")] + Leo7b, + #[value(name = "13b-leo")] + Leo13b, #[value(name = "7b-mistral")] Mistral7b, #[value(name = "7b-mistral-instruct")] @@ -55,6 +59,12 @@ enum Which { Zephyr7bBeta, #[value(name = "7b-open-chat-3.5")] OpenChat35, + #[value(name = "7b-starling-a")] + Starling7bAlpha, + #[value(name = "mixtral")] + Mixtral, + #[value(name = "mixtral-instruct")] + MixtralInstruct, } impl Which { @@ -68,12 +78,17 @@ impl Which { | Self::L70bChat | Self::L7bCode | Self::L13bCode - | Self::L34bCode => false, + | Self::L34bCode + | Self::Leo7b + | Self::Leo13b => false, // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the - // same way. + // same way. Starling is a fine tuned version of OpenChat. Self::OpenChat35 + | Self::Starling7bAlpha | Self::Zephyr7bAlpha | Self::Zephyr7bBeta + | Self::Mixtral + | Self::MixtralInstruct | Self::Mistral7b | Self::Mistral7bInstruct => true, } @@ -90,15 +105,43 @@ impl Which { | Self::L7bCode | Self::L13bCode | Self::L34bCode + | Self::Leo7b + | Self::Leo13b + | Self::Mixtral + | Self::MixtralInstruct | Self::Mistral7b | Self::Mistral7bInstruct - | Self::OpenChat35 => false, + | Self::OpenChat35 + | Self::Starling7bAlpha => false, Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true, } } fn is_open_chat(&self) -> bool { match self { + Self::L7b + | Self::L13b + | Self::L70b + | Self::L7bChat + | Self::L13bChat + | Self::L70bChat + | Self::L7bCode + | Self::L13bCode + | Self::L34bCode + | Self::Leo7b + | Self::Leo13b + | Self::Mixtral + | Self::MixtralInstruct + | Self::Mistral7b + | Self::Mistral7bInstruct + | Self::Zephyr7bAlpha + | Self::Zephyr7bBeta => false, + Self::OpenChat35 | Self::Starling7bAlpha => true, + } + } + + fn tokenizer_repo(&self) -> &'static str { + match self { Which::L7b | Which::L13b | Which::L70b @@ -107,12 +150,17 @@ impl Which { | Which::L70bChat | Which::L7bCode | Which::L13bCode - | Which::L34bCode - | Which::Mistral7b + | Which::L34bCode => "hf-internal-testing/llama-tokenizer", + Which::Leo7b => "LeoLM/leo-hessianai-7b", + Which::Leo13b => "LeoLM/leo-hessianai-13b", + Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1", + Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1", + Which::Mistral7b | Which::Mistral7bInstruct | Which::Zephyr7bAlpha - | Which::Zephyr7bBeta => false, - Which::OpenChat35 => true, + | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1", + Which::OpenChat35 => "openchat/openchat_3.5", + Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha", } } } @@ -181,13 +229,7 @@ impl Args { Some(config) => std::path::PathBuf::from(config), None => { let api = hf_hub::api::sync::Api::new()?; - let repo = if self.which.is_open_chat() { - "openchat/openchat_3.5" - } else if self.which.is_mistral() { - "mistralai/Mistral-7B-v0.1" - } else { - "hf-internal-testing/llama-tokenizer" - }; + let repo = self.which.tokenizer_repo(); let api = api.model(repo.to_string()); api.get("tokenizer.json")? } @@ -218,6 +260,22 @@ impl Args { Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"), Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"), Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"), + Which::Leo7b => ( + "TheBloke/leo-hessianai-7B-GGUF", + "leo-hessianai-7b.Q4_K_M.gguf", + ), + Which::Leo13b => ( + "TheBloke/leo-hessianai-13B-GGUF", + "leo-hessianai-13b.Q4_K_M.gguf", + ), + Which::Mixtral => ( + "TheBloke/Mixtral-8x7B-v0.1-GGUF", + "mixtral-8x7b-v0.1.Q4_K_M.gguf", + ), + Which::MixtralInstruct => ( + "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF", + "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf", + ), Which::Mistral7b => ( "TheBloke/Mistral-7B-v0.1-GGUF", "mistral-7b-v0.1.Q4_K_S.gguf", @@ -234,6 +292,10 @@ impl Args { ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf") } Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"), + Which::Starling7bAlpha => ( + "TheBloke/Starling-LM-7B-alpha-GGUF", + "starling-lm-7b-alpha.Q4_K_M.gguf", + ), }; let api = hf_hub::api::sync::Api::new()?; let api = api.model(repo.to_string()); @@ -329,14 +391,19 @@ fn main() -> anyhow::Result<()> { | Which::L13bChat | Which::L7bCode | Which::L13bCode - | Which::L34bCode => 1, - Which::Mistral7b + | Which::L34bCode + | Which::Leo7b + | Which::Leo13b => 1, + Which::Mixtral + | Which::MixtralInstruct + | Which::Mistral7b | Which::Mistral7bInstruct | Which::Zephyr7bAlpha | Which::Zephyr7bBeta | Which::L70b | Which::L70bChat - | Which::OpenChat35 => 8, + | Which::OpenChat35 + | Which::Starling7bAlpha => 8, }; ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? } @@ -369,7 +436,7 @@ fn main() -> anyhow::Result<()> { } } if args.which.is_open_chat() { - format!("User: {prompt}<|end_of_turn|>Assistant: ") + format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:") } else if args.which.is_zephyr() { if prompt_index == 0 || is_interactive { format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",) diff --git a/candle-examples/examples/reinforcement-learning/atari_wrappers.py b/candle-examples/examples/reinforcement-learning/atari_wrappers.py index b5c4665d..b76fb85d 100644 --- a/candle-examples/examples/reinforcement-learning/atari_wrappers.py +++ b/candle-examples/examples/reinforcement-learning/atari_wrappers.py @@ -78,7 +78,7 @@ class EpisodicLifeEnv(gym.Wrapper): # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() if lives < self.lives and lives > 0: - # for Qbert somtimes we stay in lives == 0 condtion for a few frames + # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. done = True diff --git a/candle-examples/examples/stable-diffusion/README.md b/candle-examples/examples/stable-diffusion/README.md index b8736a2a..feb7ca56 100644 --- a/candle-examples/examples/stable-diffusion/README.md +++ b/candle-examples/examples/stable-diffusion/README.md @@ -8,7 +8,7 @@ XL using Rust and [candle](https://github.com/huggingface/candle). The `stable-diffusion` example is a conversion of [diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1, -as well as Stable Diffusion XL 1.0. +as well as Stable Diffusion XL 1.0, and Turbo. ## Getting the weights @@ -23,16 +23,26 @@ cargo run --example stable-diffusion --release --features=cuda,cudnn \ -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" ``` -The final image is named `sd_final.png` by default. -The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The -original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). +The final image is named `sd_final.png` by default. The Turbo version is much +faster than previous versions, to give it a try add a `--sd-version turbo` flag, +e.g.: + +```bash +cargo run --example stable-diffusion --release --features=cuda,cudnn \ + -- --prompt "a cosmonaut on a horse (hd, realistic, high-def) --sd-version turbo" +``` + +The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising +Diffusion Implicit Model scheduler (DDIM). The original paper and some code can +be found in the [associated repo](https://github.com/ermongroup/ddim). +The default scheduler for the XL Turbo version is the Euler Ancestral scheduler. ### Command-line flags - `--prompt`: the prompt to be used to generate the image. - `--uncond-prompt`: the optional unconditional prompt. -- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or - `xl`. +- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, + `xl`, or `turbo`. - `--cpu`: use the cpu rather than the gpu (much slower). - `--height`, `--width`: set the height and width for the generated image. - `--n-steps`: the number of steps to be used in the diffusion process. diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 3e6de34d..8c3ca2ee 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D}; use clap::Parser; use tokenizers::Tokenizer; -const GUIDANCE_SCALE: f64 = 7.5; - #[derive(Parser)] #[command(author, version, about, long_about = None)] struct Args { @@ -63,8 +61,8 @@ struct Args { sliced_attention_size: Option<usize>, /// The number of steps to run the diffusion for. - #[arg(long, default_value_t = 30)] - n_steps: usize, + #[arg(long)] + n_steps: Option<usize>, /// The number of samples to generate. #[arg(long, default_value_t = 1)] @@ -87,6 +85,9 @@ struct Args { #[arg(long)] use_f16: bool, + #[arg(long)] + guidance_scale: Option<f64>, + #[arg(long, value_name = "FILE")] img2img: Option<String>, @@ -102,6 +103,7 @@ enum StableDiffusionVersion { V1_5, V2_1, Xl, + Turbo, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -120,12 +122,13 @@ impl StableDiffusionVersion { Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", Self::V2_1 => "stabilityai/stable-diffusion-2-1", Self::V1_5 => "runwayml/stable-diffusion-v1-5", + Self::Turbo => "stabilityai/sdxl-turbo", } } fn unet_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "unet/diffusion_pytorch_model.fp16.safetensors" } else { @@ -137,7 +140,7 @@ impl StableDiffusionVersion { fn vae_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "vae/diffusion_pytorch_model.fp16.safetensors" } else { @@ -149,7 +152,7 @@ impl StableDiffusionVersion { fn clip_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "text_encoder/model.fp16.safetensors" } else { @@ -161,7 +164,7 @@ impl StableDiffusionVersion { fn clip2_file(&self, use_f16: bool) -> &'static str { match self { - Self::V1_5 | Self::V2_1 | Self::Xl => { + Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { if use_f16 { "text_encoder_2/model.fp16.safetensors" } else { @@ -189,7 +192,7 @@ impl ModelFile { StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { "openai/clip-vit-base-patch32" } - StableDiffusionVersion::Xl => { + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { // This seems similar to the patch32 version except some very small // difference in the split regex. "openai/clip-vit-large-patch14" @@ -206,7 +209,11 @@ impl ModelFile { Self::Vae => { // Override for SDXL when using f16 weights. // See https://github.com/huggingface/candle/issues/1060 - if version == StableDiffusionVersion::Xl && use_f16 { + if matches!( + version, + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, + ) && use_f16 + { ( "madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", @@ -261,6 +268,7 @@ fn text_embeddings( use_f16: bool, device: &Device, dtype: DType, + use_guide_scale: bool, first: bool, ) -> Result<Tensor> { let tokenizer_file = if first { @@ -285,16 +293,6 @@ fn text_embeddings( } let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; - let mut uncond_tokens = tokenizer - .encode(uncond_prompt, true) - .map_err(E::msg)? - .get_ids() - .to_vec(); - while uncond_tokens.len() < sd_config.clip.max_position_embeddings { - uncond_tokens.push(pad_id) - } - let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; - println!("Building the Clip transformer."); let clip_weights_file = if first { ModelFile::Clip @@ -310,8 +308,24 @@ fn text_embeddings( let text_model = stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?; let text_embeddings = text_model.forward(&tokens)?; - let uncond_embeddings = text_model.forward(&uncond_tokens)?; - let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?; + + let text_embeddings = if use_guide_scale { + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + + Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)? + } else { + text_embeddings.to_dtype(dtype)? + }; Ok(text_embeddings) } @@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> { unet_weights, tracing, use_f16, + guidance_scale, use_flash_attn, img2img, img2img_strength, @@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> { None }; + let guidance_scale = match guidance_scale { + Some(guidance_scale) => guidance_scale, + None => match sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 7.5, + StableDiffusionVersion::Turbo => 0., + }, + }; + let n_steps = match n_steps { + Some(n_steps) => n_steps, + None => match sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 30, + StableDiffusionVersion::Turbo => 1, + }, + }; let dtype = if use_f16 { DType::F16 } else { DType::F32 }; let sd_config = match sd_version { StableDiffusionVersion::V1_5 => { @@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> { StableDiffusionVersion::Xl => { stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width) } + StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo( + sliced_attention_size, + height, + width, + ), }; let scheduler = sd_config.build_scheduler(n_steps)?; let device = candle_examples::device(cpu)?; + let use_guide_scale = guidance_scale > 1.0; let which = match sd_version { - StableDiffusionVersion::Xl => vec![true, false], + StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false], _ => vec![true], }; let text_embeddings = which @@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> { use_f16, &device, dtype, + use_guide_scale, *first, ) }) .collect::<Result<Vec<_>>>()?; + let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; println!("{text_embeddings:?}"); @@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> { 0 }; let bsize = 1; + + let vae_scale = match sd_version { + StableDiffusionVersion::V1_5 + | StableDiffusionVersion::V2_1 + | StableDiffusionVersion::Xl => 0.18215, + StableDiffusionVersion::Turbo => 0.13025, + }; + for idx in 0..num_samples { let timesteps = scheduler.timesteps(); let latents = match &init_latent_dist { Some(init_latent_dist) => { - let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?; + let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?; if t_start < timesteps.len() { let noise = latents.randn_like(0f64, 1f64)?; scheduler.add_noise(&latents, noise, timesteps[t_start])? @@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> { continue; } let start_time = std::time::Instant::now(); - let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + let latent_model_input = if use_guide_scale { + Tensor::cat(&[&latents, &latents], 0)? + } else { + latents.clone() + }; let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; - let noise_pred = noise_pred.chunk(2, 0)?; - let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); - let noise_pred = - (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?; + + let noise_pred = if use_guide_scale { + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + + (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)? + } else { + noise_pred + }; + latents = scheduler.step(&noise_pred, timestep, &latents)?; let dt = start_time.elapsed().as_secs_f32(); println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); if args.intermediary_images { - let image = vae.decode(&(&latents / 0.18215)?)?; + let image = vae.decode(&(&latents / vae_scale)?)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = @@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> { idx + 1, num_samples ); - let image = vae.decode(&(&latents / 0.18215)?)?; + let image = vae.decode(&(&latents / vae_scale)?)?; let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = output_filename(&final_image, idx + 1, num_samples, None); diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index ef7b16a0..a84cb3dc 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.3.1" +version = "0.3.2" edition = "2021" description = "Flash attention layer for the candle ML framework." @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" } +candle = { path = "../candle-core", features = ["cuda"], version = "0.3.2", package = "candle-core" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] @@ -21,4 +21,4 @@ rayon = "1.7.0" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] } +candle-nn = { path = "../candle-nn", version = "0.3.2", features = ["cuda"] } diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index 8117ae13..df7df346 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-kernels" -version = "0.3.1" +version = "0.3.2" edition = "2021" description = "CUDA kernels for Candle" diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 012695dd..7ab45a90 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.3.1" +version = "0.3.2" edition = "2021" description = "Metal kernels for Candle" diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 03622752..e0daabef 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } half = { workspace = true } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index a2650634..80b750ed 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -1,4 +1,4 @@ -use candle::Tensor; +use candle::{Result, Tensor}; use serde::Deserialize; #[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)] @@ -21,7 +21,7 @@ pub enum Activation { } impl super::Module for Activation { - fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { match self { Self::Gelu => xs.gelu_erf(), // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 @@ -40,3 +40,60 @@ impl super::Module for Activation { } } } + +#[derive(Clone, Debug)] +pub struct PReLU { + weight: Tensor, + is_scalar: bool, +} + +impl PReLU { + pub fn new(weight: Tensor, is_scalar: bool) -> Self { + Self { weight, is_scalar } + } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn is_scalar(&self) -> bool { + self.is_scalar + } +} + +impl candle::Module for PReLU { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let weight = if self.is_scalar { + self.weight.reshape(())? + } else if xs.rank() >= 2 { + let num_channels = xs.dim(1)?; + let num_weights = self.weight.elem_count(); + if num_weights != num_channels { + candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}") + } + let mut s = vec![1; xs.rank()]; + s[1] = self.weight.elem_count(); + self.weight.reshape(s)? + } else { + self.weight.clone() + }; + let zeros = xs.zeros_like()?; + xs.maximum(&zeros)? + xs.minimum(&zeros)?.broadcast_mul(&weight)? + } +} + +/// Create or initialize a new PReLU layer. +/// +/// This uses some default name for weights, namely `"weight"`. +/// # Arguments +/// +/// * `num_channels` - The number of channels. Use `None` to have as single trainable value and +/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward` +/// function, the input tensor shape `s` should either be one dimension with this number of +/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number. +pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> { + let init_ws = crate::init::Init::Const(0.25); + // When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1. + let ws = vs.get_with_hints((num_channels.unwrap_or(1),), "weight", init_ws)?; + Ok(PReLU::new(ws, num_channels.is_none())) +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 52d8f0c5..8f00e54c 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -15,7 +15,7 @@ pub mod sequential; pub mod var_builder; pub mod var_map; -pub use activation::Activation; +pub use activation::{prelu, Activation, PReLU}; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; pub use conv::{ conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d, diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index 94632296..59a4db8a 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -56,7 +56,7 @@ impl super::Module for Linear { /// Create or initialize a new linear layer. /// -/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`. +/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`. pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?; @@ -69,6 +69,7 @@ pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Li Ok(Linear::new(ws, Some(bs))) } +/// Create or initialize a new linear layer without biases. pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?; diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs index 7704bb48..2c671fc5 100644 --- a/candle-nn/src/optim.rs +++ b/candle-nn/src/optim.rs @@ -190,4 +190,12 @@ impl AdamW { }; Self::new(vars, params) } + + pub fn params(&self) -> &ParamsAdamW { + &self.params + } + + pub fn set_params(&mut self, params: ParamsAdamW) { + self.params = params; + } } diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index cbd238dd..83c86a6f 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -40,7 +40,7 @@ struct TensorData<B: Backend> { /// A trait that defines how tensor data is retrieved. /// /// Typically this would use disk storage in some specific format, or random initialization. -/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most +/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most /// of the time. The main restriction is that it doesn't allow for specific args (besides /// initialization hints). pub trait Backend: Send + Sync { @@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors { fn get( &self, - _target_shape: Shape, // The size is not checked for ShardedTensors + target_shape: Shape, // The size is only checked when the world size is 1. path: &str, h: Self::Hints, dtype: DType, dev: &Device, ) -> Result<Tensor> { + if h.world_size == 1 { + // There is no sharding to be applied here so we use the default backend to speed + // things up. + return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev); + } + let Shard { dim, rank, diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index e6fe6d85..c5cb56cf 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-onnx" -version = "0.3.1" +version = "0.3.2" edition = "2021" description = "ONNX support for Candle" @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.1" } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.3.2" } prost = "0.12.1" [build-dependencies] diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 05cd3050..e6662a51 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -15,9 +15,9 @@ crate-type = ["cdylib"] [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.3.1" } -candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true} +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.3.2" } +candle-onnx = {path= "../candle-onnx", version = "0.3.2", optional = true} half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py index 38718a46..e0e188cb 100644 --- a/candle-pyo3/py_src/candle/__init__.py +++ b/candle-pyo3/py_src/candle/__init__.py @@ -4,7 +4,8 @@ try: from .candle import * except ImportError as e: # If we are in development mode, or we did not bundle the DLLs, we try to locate them here - # PyO3 wont give us any infomration about what DLLs are missing, so we can only try to load the DLLs and re-import the module + # PyO3 wont give us any information about what DLLs are missing, so we can only try to load + # the DLLs and re-import the module logging.warning("DLLs were not bundled with this package. Trying to locate them...") import os import platform diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py index 15ed8dd2..6ece31b6 100644 --- a/candle-pyo3/py_src/candle/nn/container.py +++ b/candle-pyo3/py_src/candle/nn/container.py @@ -363,7 +363,7 @@ class ModuleList(Module): self.add_module(str(offset + i), module) return self - # remove forward alltogether to fallback on Module's _forward_unimplemented + # remove forward altogether to fallback on Module's _forward_unimplemented class ModuleDict(Module): @@ -480,4 +480,4 @@ class ModuleDict(Module): # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] - # remove forward alltogether to fallback on Module's _forward_unimplemented + # remove forward altogether to fallback on Module's _forward_unimplemented diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ade00012..90826b98 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -212,7 +212,7 @@ trait MapDType { enum Indexer { Index(usize), Slice(usize, usize), - Elipsis, + Ellipsis, Expand, IndexSelect(Tensor), } @@ -568,7 +568,7 @@ impl PyTensor { "Ellipsis ('...') can only be used at the start of an indexing operation", )); } - Ok((Indexer::Elipsis, dims.len() - (index_argument_count - 1))) + Ok((Indexer::Ellipsis, dims.len() - (index_argument_count - 1))) } else if py_indexer.is_none() { // Handle None e.g. tensor[None, 0] Ok((Indexer::Expand, current_dim)) @@ -616,8 +616,9 @@ impl PyTensor { current_dim += 1; out } - Indexer::Elipsis => { - // Elipsis is a special case, it means that all remaining dimensions should be selected => advance the current_dim to the last dimension we have indexers for + Indexer::Ellipsis => { + // Ellipsis is a special case, it means that all remaining dimensions should be + // selected => advance the current_dim to the last dimension we have indexers for current_dim += dims.len() - (indexers.len() - 1); x } @@ -960,11 +961,11 @@ impl PyTensor { extraction_result: PyResult<T>, err_msg: &'static str, ) -> PyResult<()> { - if let Ok(sucessfull_extraction) = extraction_result { + if let Ok(successful_extraction) = extraction_result { if opt.is_some() { return Err(PyValueError::new_err(err_msg)); } - *opt = Some(sucessfull_extraction); + *opt = Some(successful_extraction); } Ok(()) } @@ -1045,9 +1046,7 @@ impl PyTensor { .map_err(wrap_err)?, (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?, (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?, - (None, None) => { - return Err(PyTypeError::new_err("No valide dtype or device specified")) - } + (None, None) => return Err(PyTypeError::new_err("No valid dtype or device specified")), }; Ok(PyTensor(result)) diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index 336f674b..c459ebb3 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -156,7 +156,7 @@ def pyi_file(obj, indent=""): string += function(obj, indent) elif inspect.isgetsetdescriptor(obj): - # TODO it would be interesing to add the setter maybe ? + # TODO it would be interesting to add the setter maybe ? string += f"{indent}@property\n" string += function(obj, indent, text_signature="(self)") diff --git a/candle-pyo3/tests/bindings/test_module.py b/candle-pyo3/tests/bindings/test_module.py index 819dae5b..8d7a670d 100644 --- a/candle-pyo3/tests/bindings/test_module.py +++ b/candle-pyo3/tests/bindings/test_module.py @@ -74,7 +74,7 @@ def test_module_can_load_statedict(): a.load_state_dict(statedict) -def test_module_throws_on_shape_missmatch(): +def test_module_throws_on_shape_mismatch(): class A(Module): def __init__(self): super().__init__() @@ -121,7 +121,7 @@ def test_module_can_load_quantized_tensors(): assert a.t.ggml_dtype == "Q4_0" -def test_module_dequantizes_tensors_automaticaly(): +def test_module_dequantizes_tensors_automatically(): class A(Module): def __init__(self): super().__init__() diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index ef44fc4c..04eaca0d 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -84,7 +84,7 @@ def assert_bool(t: Tensor, expected: bool): assert bool(t.values()) == expected -def test_tensor_supports_equality_opperations_with_scalars(): +def test_tensor_supports_equality_operations_with_scalars(): t = Tensor(42.0) assert_bool(t == 42.0, True) @@ -106,7 +106,7 @@ def test_tensor_supports_equality_opperations_with_scalars(): assert_bool(t <= 42.0, True) -def test_tensor_supports_equality_opperations_with_tensors(): +def test_tensor_supports_equality_operations_with_tensors(): t = Tensor(42.0) same = Tensor(42.0) other = Tensor(43.0) @@ -130,7 +130,7 @@ def test_tensor_supports_equality_opperations_with_tensors(): assert_bool(t <= other, True) -def test_tensor_equality_opperations_can_broadcast(): +def test_tensor_equality_operations_can_broadcast(): # Create a decoder attention mask as a test case # e.g. # [[1,0,0] diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index e72cab69..000702f9 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -12,9 +12,9 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true } -candle-nn = { path = "../candle-nn", version = "0.3.1" } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.2", optional = true } +candle-nn = { path = "../candle-nn", version = "0.3.2" } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } rand = { workspace = true } diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index d6826a16..51c524f5 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32; #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] #[serde(rename_all = "lowercase")] -enum HiddenAct { +pub enum HiddenAct { Gelu, + GeluApproximate, Relu, } @@ -28,6 +29,7 @@ impl HiddenActLayer { match self.act { // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 HiddenAct::Gelu => xs.gelu_erf(), + HiddenAct::GeluApproximate => xs.gelu(), HiddenAct::Relu => xs.relu(), } } @@ -48,7 +50,7 @@ pub struct Config { num_hidden_layers: usize, num_attention_heads: usize, intermediate_size: usize, - hidden_act: HiddenAct, + pub hidden_act: HiddenAct, hidden_dropout_prob: f64, max_position_embeddings: usize, type_vocab_size: usize, diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..b0e2fb88 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -57,6 +57,22 @@ impl Config { } } + pub fn v2() -> Self { + Self { + vocab_size: 51200, + n_positions: 2048, + n_embd: 2560, + n_layer: 32, + n_inner: None, + n_head: 32, + rotary_dim: usize::min(32, 2560 / 32), + activation_function: Activation::Gelu, + layer_norm_epsilon: 1e-5, + tie_word_embeddings: false, + pad_vocab_size_multiple: 64, + } + } + // https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json pub fn puffin_phi_v2() -> Self { Self { @@ -372,6 +388,24 @@ pub struct MixFormerSequentialForCausalLM { } impl MixFormerSequentialForCausalLM { + pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_head = vb.pp("lm_head"); + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embd"))?; + let mut blocks = Vec::new(); + for i in 0..cfg.n_layer { + let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?; + blocks.push(block) + } + let head = CausalLMHead::new(cfg, vb_head)?; + Ok(Self { + embedding, + blocks, + head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), + }) + } + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let vb = vb.pp("layers"); let embedding = Embedding::new(cfg, vb.pp(0))?; diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs new file mode 100644 index 00000000..ede74d3f --- /dev/null +++ b/candle-transformers/src/models/mixtral.rs @@ -0,0 +1,499 @@ +use crate::models::with_tracing::{linear_no_bias, Linear}; +/// Mixtral Model +/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py +/// https://mistral.ai/news/mixtral-of-experts/ +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; +use std::sync::Arc; + +/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: usize, + pub(crate) rms_norm_eps: f64, + pub(crate) rope_theta: f64, + pub(crate) sliding_window: usize, + pub(crate) num_experts_per_tok: usize, + pub(crate) num_local_experts: usize, + pub(crate) use_flash_attn: bool, +} + +impl Config { + /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json + pub fn v0_1_8x7b(use_flash_attn: bool) -> Self { + Self { + vocab_size: 32000, + hidden_size: 4096, + intermediate_size: 14336, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 8, + hidden_act: Activation::Silu, + max_position_embeddings: 32768, + rms_norm_eps: 1e-5, + rope_theta: 1e6, + sliding_window: 4096, + num_experts_per_tok: 2, + num_local_experts: 8, + use_flash_attn, + } + } +} + +#[derive(Debug, Clone)] +struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let inner = candle_nn::rms_norm(size, eps, vb)?; + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +fn rotate_half(xs: &Tensor) -> Result<Tensor> { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc<RotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, +} + +impl Attention { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + kv_cache: None, + use_flash_attn: cfg.use_flash_attn, + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> { + let n_rep = self.num_kv_groups; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?; + let value_states = self.repeat_kv(value_states)?; + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct BlockSparseTop2MLP { + w1: Linear, + w2: Linear, + w3: Linear, + act_fn: Activation, +} + +impl BlockSparseTop2MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?; + let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?; + let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?; + Ok(Self { + w1, + w2, + w3, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for BlockSparseTop2MLP { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.w3)?; + (lhs * rhs)?.apply(&self.w2) + } +} + +#[derive(Debug, Clone)] +struct SparseMoeBlock { + gate: Linear, + experts: Vec<BlockSparseTop2MLP>, + num_experts_per_tok: usize, +} + +impl SparseMoeBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?; + let mut experts = Vec::with_capacity(cfg.num_local_experts); + let vb = vb.pp("experts"); + for idx in 0..cfg.num_local_experts { + let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?; + experts.push(expert) + } + Ok(SparseMoeBlock { + gate, + experts, + num_experts_per_tok: cfg.num_experts_per_tok, + }) + } +} + +impl Module for SparseMoeBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = xs.apply(&self.gate)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. + let mut top_x = vec![vec![]; self.experts.len()]; + let mut selected_rws = vec![vec![]; self.experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in self.experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + block_sparse_moe: SparseMoeBlock, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + block_sparse_moe, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs + .apply(&self.post_attention_layernorm)? + .apply(&self.block_sparse_moe)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + lm_head: Linear, + sliding_window: usize, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.sliding_window, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result<Tensor> { + // Sliding window mask? + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + self.sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index a9a56673..94a3bd5b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -14,6 +14,7 @@ pub mod llama2_c_weights; pub mod marian; pub mod mistral; pub mod mixformer; +pub mod mixtral; pub mod mpt; pub mod persimmon; pub mod quantized_blip; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 44d89f40..1fb2d9e2 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -48,15 +48,109 @@ impl QMatMul { } #[derive(Debug, Clone)] +struct Mlp { + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let w1 = self.feed_forward_w1.forward(xs)?; + let w3 = self.feed_forward_w3.forward(xs)?; + self.feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?) + } +} + +#[derive(Debug, Clone)] +enum MlpOrMoe { + Mlp(Mlp), + MoE { + n_expert_used: usize, + feed_forward_gate_inp: QMatMul, + experts: Vec<Mlp>, + }, +} + +impl Module for MlpOrMoe { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::MoE { + feed_forward_gate_inp, + experts, + n_expert_used, + } => { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = feed_forward_gate_inp.forward(&xs)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. + let mut top_x = vec![vec![]; experts.len()]; + let mut selected_rws = vec![vec![]; experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = + current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } + Self::Mlp(mlp) => mlp.forward(xs), + } + } +} + +#[derive(Debug, Clone)] struct LayerWeights { attention_wq: QMatMul, attention_wk: QMatMul, attention_wv: QMatMul, attention_wo: QMatMul, attention_norm: RmsNorm, - feed_forward_w1: QMatMul, - feed_forward_w2: QMatMul, - feed_forward_w3: QMatMul, + mlp_or_moe: MlpOrMoe, ffn_norm: RmsNorm, n_head: usize, n_kv_head: usize, @@ -212,9 +306,16 @@ impl ModelWeights { let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; - let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; - let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; - let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let mlp_or_moe = { + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + }) + }; let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); @@ -226,9 +327,7 @@ impl ModelWeights { attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, attention_norm: RmsNorm::new(attention_norm, 1e-5)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + mlp_or_moe, ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, n_head: ct.hparams.n_head as usize, n_kv_head: ct.hparams.n_head as usize / gqa, @@ -265,6 +364,12 @@ impl ModelWeights { }; // Parameter extraction from metadata. + let n_expert = md_get("llama.expert_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; + let n_expert_used = md_get("llama.expert_used_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; let block_count = md_get("llama.block_count")?.to_u32()? as usize; @@ -289,9 +394,38 @@ impl ModelWeights { let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let mlp_or_moe = if n_expert <= 1 { + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + }) + } else { + let feed_forward_gate_inp = + ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?; + let mut experts = Vec::with_capacity(n_expert); + for i in 0..n_expert { + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?; + experts.push(Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + }) + } + MlpOrMoe::MoE { + n_expert_used, + feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?, + experts, + } + }; let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); @@ -303,9 +437,7 @@ impl ModelWeights { attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + mlp_or_moe, ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, n_head: head_count, n_kv_head: head_count_kv, @@ -360,12 +492,9 @@ impl ModelWeights { let _enter = layer.span_mlp.enter(); let residual = &x; let x = layer.ffn_norm.forward(&x)?; - let w1 = layer.feed_forward_w1.forward(&x)?; - let w3 = layer.feed_forward_w3.forward(&x)?; - let mlp = layer - .feed_forward_w2 - .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; - layer_in = (mlp + residual)?; + let x = layer.mlp_or_moe.forward(&x)?; + let x = (x + residual)?; + layer_in = x } let x = self.norm.forward(&layer_in)?; let x = x.i((.., seq_len - 1, ..))?; diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index f11f2036..1a3cd4ac 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -287,6 +287,24 @@ pub struct MixFormerSequentialForCausalLM { } impl MixFormerSequentialForCausalLM { + pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_head = vb.pp("lm_head"); + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embd"))?; + let mut blocks = Vec::new(); + for i in 0..cfg.n_layer { + let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?; + blocks.push(block) + } + let head = CausalLMHead::new(cfg, vb_head)?; + Ok(Self { + embedding, + blocks, + head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), + }) + } + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let vb = vb.pp("layers"); let embedding = Embedding::new(cfg, vb.pp(0))?; diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs index 2a91cd44..1703c809 100644 --- a/candle-transformers/src/models/segment_anything/mask_decoder.rs +++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs @@ -182,7 +182,7 @@ impl MaskDecoder { sparse_prompt_embeddings: &Tensor, dense_prompt_embeddings: &Tensor, ) -> Result<(Tensor, Tensor)> { - // Concatenate ouput tokens. + // Concatenate output tokens. let output_tokens = Tensor::cat( &[self.iou_token.embeddings(), self.mask_tokens.embeddings()], 0, diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs index 9d0074b1..16e8a4e8 100644 --- a/candle-transformers/src/models/segment_anything/prompt_encoder.rs +++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs @@ -2,11 +2,11 @@ use candle::{DType, IndexOp, Result, Tensor, D}; use candle_nn::VarBuilder; #[derive(Debug)] -struct PostionEmbeddingRandom { +struct PositionEmbeddingRandom { positional_encoding_gaussian_matrix: Tensor, } -impl PostionEmbeddingRandom { +impl PositionEmbeddingRandom { fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> { let positional_encoding_gaussian_matrix = vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?; @@ -52,7 +52,7 @@ impl PostionEmbeddingRandom { #[derive(Debug)] pub struct PromptEncoder { - pe_layer: PostionEmbeddingRandom, + pe_layer: PositionEmbeddingRandom, point_embeddings: Vec<candle_nn::Embedding>, not_a_point_embed: candle_nn::Embedding, mask_downscaling_conv1: candle_nn::Conv2d, @@ -76,7 +76,7 @@ impl PromptEncoder { vb: VarBuilder, ) -> Result<Self> { let num_points_embeddings = 4; - let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; + let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?; let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?; let cfg = candle_nn::Conv2dConfig { diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs index 916b7349..d804ed56 100644 --- a/candle-transformers/src/models/stable_diffusion/ddim.rs +++ b/candle-transformers/src/models/stable_diffusion/ddim.rs @@ -7,7 +7,9 @@ //! //! Denoising Diffusion Implicit Models, J. Song et al, 2020. //! https://arxiv.org/abs/2010.02502 -use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use super::schedulers::{ + betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing, +}; use candle::{Result, Tensor}; /// The configuration for the DDIM scheduler. @@ -29,6 +31,8 @@ pub struct DDIMSchedulerConfig { pub prediction_type: PredictionType, /// number of diffusion steps used to train the model pub train_timesteps: usize, + /// time step spacing for the diffusion process + pub timestep_spacing: TimestepSpacing, } impl Default for DDIMSchedulerConfig { @@ -41,10 +45,17 @@ impl Default for DDIMSchedulerConfig { steps_offset: 1, prediction_type: PredictionType::Epsilon, train_timesteps: 1000, + timestep_spacing: TimestepSpacing::Leading, } } } +impl SchedulerConfig for DDIMSchedulerConfig { + fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> { + Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?)) + } +} + /// The DDIM scheduler. #[derive(Debug, Clone)] pub struct DDIMScheduler { @@ -60,12 +71,32 @@ impl DDIMScheduler { /// Creates a new DDIM scheduler given the number of steps to be /// used for inference as well as the number of steps that was used /// during training. - pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { + fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { let step_ratio = config.train_timesteps / inference_steps; - let timesteps: Vec<usize> = (0..(inference_steps)) - .map(|s| s * step_ratio + config.steps_offset) - .rev() - .collect(); + let timesteps: Vec<usize> = match config.timestep_spacing { + TimestepSpacing::Leading => (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(), + TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| { + if *n > step_ratio { + Some(n - step_ratio) + } else { + None + } + }) + .map(|n| n - 1) + .collect(), + TimestepSpacing::Linspace => { + super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)? + .to_vec1::<f64>()? + .iter() + .map(|&f| f as usize) + .rev() + .collect() + } + }; + let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => super::utils::linspace( config.beta_start.sqrt(), @@ -92,19 +123,11 @@ impl DDIMScheduler { config, }) } +} - pub fn timesteps(&self) -> &[usize] { - self.timesteps.as_slice() - } - - /// Ensures interchangeability with schedulers that need to scale the denoising model input - /// depending on the current timestep. - pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> { - Ok(sample) - } - +impl Scheduler for DDIMScheduler { /// Performs a backward step during inference. - pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { @@ -163,7 +186,17 @@ impl DDIMScheduler { } } - pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> { + Ok(sample) + } + + fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { @@ -174,7 +207,7 @@ impl DDIMScheduler { (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)? } - pub fn init_noise_sigma(&self) -> f64 { + fn init_noise_sigma(&self) -> f64 { self.init_noise_sigma } } diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs new file mode 100644 index 00000000..9576c2de --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs @@ -0,0 +1,235 @@ +//! Ancestral sampling with Euler method steps. +//! +//! Reference implementation in Rust: +//! +//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs +//! +//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd]. +/// +/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 +use super::{ + schedulers::{ + betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, + TimestepSpacing, + }, + utils::interp, +}; +use candle::{bail, Error, Result, Tensor}; + +/// The configuration for the EulerAncestral Discrete scheduler. +#[derive(Debug, Clone, Copy)] +pub struct EulerAncestralDiscreteSchedulerConfig { + /// The value of beta at the beginning of training.n + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Adjust the indexes of the inference schedule by this value. + pub steps_offset: usize, + /// prediction type of the scheduler function, one of `epsilon` (predicting + /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) + /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model + pub train_timesteps: usize, + /// time step spacing for the diffusion process + pub timestep_spacing: TimestepSpacing, +} + +impl Default for EulerAncestralDiscreteSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085f64, + beta_end: 0.012f64, + beta_schedule: BetaSchedule::ScaledLinear, + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + timestep_spacing: TimestepSpacing::Leading, + } + } +} + +impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig { + fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> { + Ok(Box::new(EulerAncestralDiscreteScheduler::new( + inference_steps, + *self, + )?)) + } +} + +/// The EulerAncestral Discrete scheduler. +#[derive(Debug, Clone)] +pub struct EulerAncestralDiscreteScheduler { + timesteps: Vec<usize>, + sigmas: Vec<f64>, + init_noise_sigma: f64, + pub config: EulerAncestralDiscreteSchedulerConfig, +} + +// clip_sample: False, set_alpha_to_one: False +impl EulerAncestralDiscreteScheduler { + /// Creates a new EulerAncestral Discrete scheduler given the number of steps to be + /// used for inference as well as the number of steps that was used + /// during training. + pub fn new( + inference_steps: usize, + config: EulerAncestralDiscreteSchedulerConfig, + ) -> Result<Self> { + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = match config.timestep_spacing { + TimestepSpacing::Leading => (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(), + TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| { + if *n > step_ratio { + Some(n - step_ratio) + } else { + None + } + }) + .map(|n| n - 1) + .collect(), + TimestepSpacing::Linspace => { + super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)? + .to_vec1::<f64>()? + .iter() + .map(|&f| f as usize) + .rev() + .collect() + } + }; + + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => super::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + let sigmas: Vec<f64> = alphas_cumprod + .iter() + .map(|&f| ((1. - f) / f).sqrt()) + .collect(); + + let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect(); + + let mut sigmas_int = interp( + ×teps.iter().map(|&t| t as f64).collect::<Vec<_>>(), + &sigmas_xa, + &sigmas, + ); + sigmas_int.push(0.0); + + // standard deviation of the initial noise distribution + // f64 does not implement Ord such that there is no `max`, so we need to use this workaround + let init_noise_sigma = *sigmas_int + .iter() + .chain(std::iter::once(&0.0)) + .reduce(|a, b| if a > b { a } else { b }) + .expect("init_noise_sigma could not be reduced from sigmas - this should never happen"); + + Ok(Self { + sigmas: sigmas_int, + timesteps, + init_noise_sigma, + config, + }) + } +} + +impl Scheduler for EulerAncestralDiscreteScheduler { + fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + /// + /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm + fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> { + let step_index = match self.timesteps.iter().position(|&t| t == timestep) { + Some(i) => i, + None => bail!("timestep out of this schedulers bounds: {timestep}"), + }; + + let sigma = self + .sigmas + .get(step_index) + .expect("step_index out of sigma bounds - this shouldn't happen"); + + sample / ((sigma.powi(2) + 1.).sqrt()) + } + + /// Performs a backward step during inference. + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let step_index = self + .timesteps + .iter() + .position(|&p| p == timestep) + .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?; + + let sigma_from = &self.sigmas[step_index]; + let sigma_to = &self.sigmas[step_index + 1]; + + // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + let pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => (sample - (model_output * *sigma_from))?, + PredictionType::VPrediction => { + ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))? + + (sample / (sigma_from.powi(2) + 1.0))?)? + } + PredictionType::Sample => bail!("prediction_type not implemented yet: sample"), + }; + + let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2)) + / sigma_from.powi(2)) + .sqrt(); + let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt(); + + // 2. convert to a ODE derivative + let derivative = ((sample - pred_original_sample)? / *sigma_from)?; + let dt = sigma_down - *sigma_from; + let prev_sample = (sample + derivative * dt)?; + + let noise = prev_sample.randn_like(0.0, 1.0)?; + + prev_sample + noise * sigma_up + } + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + let step_index = self + .timesteps + .iter() + .position(|&p| p == timestep) + .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?; + + let sigma = self + .sigmas + .get(step_index) + .expect("step_index out of sigma bounds - this shouldn't happen"); + + original + (noise * *sigma)? + } + + fn init_noise_sigma(&self) -> f64 { + match self.config.timestep_spacing { + TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma, + TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(), + } + } +} diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 66ef7149..30f23975 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -3,6 +3,7 @@ pub mod clip; pub mod ddim; pub mod ddpm; pub mod embeddings; +pub mod euler_ancestral_discrete; pub mod resnet; pub mod schedulers; pub mod unet_2d; @@ -10,9 +11,13 @@ pub mod unet_2d_blocks; pub mod utils; pub mod vae; +use std::sync::Arc; + use candle::{DType, Device, Result}; use candle_nn as nn; +use self::schedulers::{Scheduler, SchedulerConfig}; + #[derive(Clone, Debug)] pub struct StableDiffusionConfig { pub width: usize, @@ -21,7 +26,7 @@ pub struct StableDiffusionConfig { pub clip2: Option<clip::Config>, autoencoder: vae::AutoEncoderKLConfig, unet: unet_2d::UNet2DConditionModelConfig, - scheduler: ddim::DDIMSchedulerConfig, + scheduler: Arc<dyn SchedulerConfig>, } impl StableDiffusionConfig { @@ -75,13 +80,18 @@ impl StableDiffusionConfig { 512 }; - Self { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { + prediction_type: schedulers::PredictionType::Epsilon, + ..Default::default() + }); + + StableDiffusionConfig { width, height, clip: clip::Config::v1_5(), clip2: None, autoencoder, - scheduler: Default::default(), + scheduler, unet, } } @@ -124,10 +134,10 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -143,7 +153,7 @@ impl StableDiffusionConfig { 768 }; - Self { + StableDiffusionConfig { width, height, clip: clip::Config::v2_1(), @@ -205,10 +215,10 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -224,6 +234,76 @@ impl StableDiffusionConfig { 1024 }; + StableDiffusionConfig { + width, + height, + clip: clip::Config::sdxl(), + clip2: Some(clip::Config::sdxl2()), + autoencoder, + scheduler, + unet, + } + } + + fn sdxl_turbo_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: schedulers::PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, None, 5), + bc(640, Some(2), 10), + bc(1280, Some(10), 20), + ], + center_input_sample: false, + cross_attention_dim: 2048, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = Arc::new( + euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig { + prediction_type, + timestep_spacing: schedulers::TimestepSpacing::Trailing, + ..Default::default() + }, + ); + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 512 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 512 + }; + Self { width, height, @@ -249,6 +329,20 @@ impl StableDiffusionConfig { ) } + pub fn sdxl_turbo( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + Self::sdxl_turbo_( + sliced_attention_size, + height, + width, + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json + schedulers::PredictionType::Epsilon, + ) + } + pub fn ssd1b( sliced_attention_size: Option<usize>, height: Option<usize>, @@ -285,9 +379,9 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -347,8 +441,8 @@ impl StableDiffusionConfig { Ok(unet) } - pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> { - ddim::DDIMScheduler::new(n_steps, self.scheduler) + pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> { + self.scheduler.build(n_steps) } } diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs index 3f6a1d72..0f0441e0 100644 --- a/candle-transformers/src/models/stable_diffusion/schedulers.rs +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs @@ -3,9 +3,25 @@ //! //! Noise schedulers can be used to set the trade-off between //! inference speed and quality. - use candle::{Result, Tensor}; +pub trait SchedulerConfig: std::fmt::Debug { + fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>; +} + +/// This trait represents a scheduler for the diffusion process. +pub trait Scheduler { + fn timesteps(&self) -> &[usize]; + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>; + + fn init_noise_sigma(&self) -> f64; + + fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>; + + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>; +} + /// This represents how beta ranges from its minimum value to the maximum /// during training. #[derive(Debug, Clone, Copy)] @@ -25,6 +41,22 @@ pub enum PredictionType { Sample, } +/// Time step spacing for the diffusion process. +/// +/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 +#[derive(Debug, Clone, Copy)] +pub enum TimestepSpacing { + Leading, + Linspace, + Trailing, +} + +impl Default for TimestepSpacing { + fn default() -> Self { + Self::Leading + } +} + /// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of /// `(1-beta)` over time from `t = [0,1]`. /// diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index cef06f1c..5b5fa0f7 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -13,3 +13,49 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { Tensor::from_vec(vs, steps, &Device::Cpu) } } + +/// A linear interpolator for a sorted array of x and y values. +struct LinearInterpolator<'x, 'y> { + xp: &'x [f64], + fp: &'y [f64], + cache: usize, +} + +impl<'x, 'y> LinearInterpolator<'x, 'y> { + fn accel_find(&mut self, x: f64) -> usize { + let xidx = self.cache; + if x < self.xp[xidx] { + self.cache = self.xp[0..xidx].partition_point(|o| *o < x); + self.cache = self.cache.saturating_sub(1); + } else if x >= self.xp[xidx + 1] { + self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx; + self.cache = self.cache.saturating_sub(1); + } + + self.cache + } + + fn eval(&mut self, x: f64) -> f64 { + if x < self.xp[0] || x > self.xp[self.xp.len() - 1] { + return f64::NAN; + } + + let idx = self.accel_find(x); + + let x_l = self.xp[idx]; + let x_h = self.xp[idx + 1]; + let y_l = self.fp[idx]; + let y_h = self.fp[idx + 1]; + let dx = x_h - x_l; + if dx > 0.0 { + y_l + (x - x_l) / dx * (y_h - y_l) + } else { + f64::NAN + } + } +} + +pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> { + let mut interpolator = LinearInterpolator { xp, fp, cache: 0 }; + x.iter().map(|&x| interpolator.eval(x)).collect() +} diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml index 0a3e13e0..33edaaad 100644 --- a/candle-wasm-examples/bert/Cargo.toml +++ b/candle-wasm-examples/bert/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/blip/Cargo.toml b/candle-wasm-examples/blip/Cargo.toml index e77e8595..e73200c5 100644 --- a/candle-wasm-examples/blip/Cargo.toml +++ b/candle-wasm-examples/blip/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 3dfcc4c2..77780c85 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs index ea04a810..f254a5ae 100644 --- a/candle-wasm-examples/llama2-c/src/app.rs +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -108,7 +108,7 @@ impl Component for App { fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { match msg { Msg::SetModel(md) => { - self.status = "weights loaded succesfully!".to_string(); + self.status = "weights loaded successfully!".to_string(); self.loaded = true; console_log!("loaded weights"); self.worker.send(WorkerInput::ModelData(md)); diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index 79dd2f32..e38561b9 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -24,7 +24,7 @@ macro_rules! console_log { } // Communication to the worker happens through bincode, the model weights and configs are fetched -// on the main thread and transfered via the following structure. +// on the main thread and transferred via the following structure. #[derive(Serialize, Deserialize)] pub struct ModelData { pub tokenizer: Vec<u8>, diff --git a/candle-wasm-examples/phi/Cargo.toml b/candle-wasm-examples/phi/Cargo.toml index 2969420a..ca07bd99 100644 --- a/candle-wasm-examples/phi/Cargo.toml +++ b/candle-wasm-examples/phi/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } tokenizers = { workspace = true, features = ["unstable_wasm"] } num-traits = { workspace = true } diff --git a/candle-wasm-examples/phi/index.html b/candle-wasm-examples/phi/index.html index 19c6a586..dbef698a 100644 --- a/candle-wasm-examples/phi/index.html +++ b/candle-wasm-examples/phi/index.html @@ -1,7 +1,7 @@ <html> <head> <meta content="text/html;charset=utf-8" http-equiv="Content-Type" /> - <title>Candle Phi 1.5 Rust/WASM</title> + <title>Candle Phi 1.5 / Phi 2.0 Rust/WASM</title> </head> <body></body> </html> @@ -39,7 +39,7 @@ import hljs from "https://cdn.skypack.dev/highlight.js"; // models base url const MODELS = { - phi_1_5_quantized: { + phi_1_5_q4k: { base_url: "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", model: "model-q4k.gguf", @@ -49,7 +49,7 @@ seq_len: 2048, size: "800 MB", }, - phi_1_5_quantized_2: { + phi_1_5_q80: { base_url: "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", model: "model-q80.gguf", @@ -59,7 +59,21 @@ seq_len: 2048, size: "1.51 GB", }, - puffin_phi_v2_quantized: { + phi_2_0_q4k: { + base_url: + "https://huggingface.co/radames/phi-2-quantized/resolve/main/", + model: [ + "model-v2-q4k.gguf_aa.part", + "model-v2-q4k.gguf_ab.part", + "model-v2-q4k.gguf_ac.part", + ], + tokenizer: "tokenizer.json", + config: "config.json", + quantized: true, + seq_len: 2048, + size: "1.57GB", + }, + puffin_phi_v2_q4k: { base_url: "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", model: "model-puffin-phi-v2-q4k.gguf", @@ -69,7 +83,7 @@ seq_len: 2048, size: "798 MB", }, - puffin_phi_v2_quantized_2: { + puffin_phi_v2_q80: { base_url: "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", model: "model-puffin-phi-v2-q80.gguf", @@ -106,8 +120,8 @@ Let’s think step by step.`, }, { title: "Question answering", - prompt: `What is the capital of France? -Answer:`, + prompt: `Instruct: What is the capital of France? +Output:`, }, { title: "Chat mode", @@ -148,7 +162,10 @@ Very polite review:`, const getValue = (id) => document.querySelector(`#${id}`).value; const modelID = getValue("model"); const model = MODELS[modelID]; - const weightsURL = model.base_url + model.model; + const weightsURL = + model.model instanceof Array + ? model.model.map((m) => model.base_url + m) + : model.base_url + model.model; const tokenizerURL = model.base_url + model.tokenizer; const configURL = model.base_url + model.config; @@ -246,6 +263,13 @@ Very polite review:`, option.innerText = `${id} (${model.size})`; modelSelect.appendChild(option); } + const query = new URLSearchParams(window.location.search); + const modelID = query.get("model"); + if (modelID) { + modelSelect.value = modelID; + } else { + modelSelect.value = "phi_1_5_q4k"; + } for (const [i, { title, prompt }] of TEMPLATES.entries()) { const div = document.createElement("div"); @@ -270,8 +294,18 @@ Very polite review:`, prompt.value = template; prompt.style.height = "auto"; prompt.style.height = prompt.scrollHeight + "px"; + runBtn.disabled = false; + clearBtn.classList.remove("invisible"); }); modelSelect.addEventListener("change", (e) => { + const query = new URLSearchParams(window.location.search); + query.set("model", e.target.value); + window.history.replaceState( + {}, + "", + `${window.location.pathname}?${query}` + ); + window.parent.postMessage({ queryString: "?" + query }, "*"); const model = MODELS[e.target.value]; document.querySelector("#max-seq").max = model.seq_len; document.querySelector("#max-seq").nextElementSibling.value = 200; @@ -320,7 +354,7 @@ Very polite review:`, <main class="grid grid-cols-1 gap-8 relative"> <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span> <div> - <h1 class="text-5xl font-bold">Candle Phi 1.5</h1> + <h1 class="text-5xl font-bold">Candle Phi 1.5 / Phi 2.0</h1> <h2 class="text-2xl font-bold">Rust/WASM Demo</h2> <p class="max-w-lg"> The @@ -330,10 +364,17 @@ Very polite review:`, target="_blank" >Phi-1.5</a > - model achieves state-of-the-art performance with only 1.3 billion - parameters, compared to models with up to 10 billion. You can try the - quantized version of the model here. Additional prompt examples are - available in the + and + <a + href="https://huggingface.co/microsoft/phi-2" + class="link" + target="_blank" + >Phi-2</a + > + models achieve state-of-the-art performance with only 1.3 billion and + 2.7 billion parameters, compared to larger models with up to 13 + billion parameters. Here you can try the quantized versions. + Additional prompt examples are available in the <a href="https://arxiv.org/pdf/2309.05463.pdf#page=8" class="link" @@ -350,7 +391,7 @@ Very polite review:`, target="_blank" >Puffin-Phi V2 </a> - quantized version model, a fine-tuned version of Phi-1.5 on the + quantized version, a fine-tuned version of Phi-1.5 on the <a href="https://huggingface.co/datasets/LDJnr/Puffin" class="link" @@ -363,7 +404,7 @@ Very polite review:`, <p class="text-xs italic max-w-lg"> <b>Note:</b> When first run, the app will download and cache the model, which could - take a few minutes. The models are <b>~800MB</b> or <b>~1.51GB</b> in + take a few minutes. The models are <b>~800MB</b> or <b>~1.57GB</b> in size. </p> </div> @@ -375,8 +416,13 @@ Very polite review:`, ></select> </div> <div> - <h3 class="font-medium">Prompt Templates</h3> - <form id="prompt-templates" class="flex flex-col gap-1 my-2"></form> + <details> + <summary class="font-medium cursor-pointer">Prompt Templates</summary> + <form + id="prompt-templates" + class="grid grid-cols-1 sm:grid-cols-2 gap-1 my-2" + ></form> + </details> </div> <form id="form" @@ -386,12 +432,12 @@ Very polite review:`, <textarea type="text" id="prompt" - class="font-light w-full px-3 py-2 mx-1 resize-none outline-none" + class="font-light text-lg w-full px-3 py-2 mx-1 resize-none outline-none" oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'" placeholder="Add your prompt here..." > -Write a detailed analogy between mathematics and a lighthouse. -Answer:</textarea +Instruct: Write a detailed analogy between mathematics and a lighthouse. +Output:</textarea > <button id="clear-btn"> <svg @@ -517,9 +563,9 @@ Answer:</textarea <div id="output-counter" hidden - class="ml-auto font-semibold grid-rows-1 text-sm" + class="ml-auto font-semibold grid-rows-1" ></div> - <p hidden id="output-generation" class="grid-rows-2"></p> + <p hidden id="output-generation" class="grid-rows-2 text-lg"></p> <span id="output-status" class="m-auto font-light" >No output yet</span > diff --git a/candle-wasm-examples/phi/phiWorker.js b/candle-wasm-examples/phi/phiWorker.js index 5c030f1d..bb71b409 100644 --- a/candle-wasm-examples/phi/phiWorker.js +++ b/candle-wasm-examples/phi/phiWorker.js @@ -12,6 +12,20 @@ async function fetchArrayBuffer(url) { cache.put(url, res.clone()); return new Uint8Array(await res.arrayBuffer()); } +async function concatenateArrayBuffers(urls) { + const arrayBuffers = await Promise.all(urls.map(url => fetchArrayBuffer(url))); + + let totalLength = arrayBuffers.reduce((acc, arrayBuffer) => acc + arrayBuffer.byteLength, 0); + let concatenatedBuffer = new Uint8Array(totalLength); + + let offset = 0; + arrayBuffers.forEach(buffer => { + concatenatedBuffer.set(new Uint8Array(buffer), offset); + offset += buffer.byteLength; + }); + return concatenatedBuffer; +} + class Phi { static instance = {}; @@ -27,10 +41,9 @@ class Phi { await init(); self.postMessage({ status: "loading", message: "Loading Model" }); - const [weightsArrayU8, tokenizerArrayU8, configArrayU8] = await Promise.all([ - fetchArrayBuffer(weightsURL), + weightsURL instanceof Array ? concatenateArrayBuffers(weightsURL) : fetchArrayBuffer(weightsURL), fetchArrayBuffer(tokenizerURL), fetchArrayBuffer(configURL), ]); diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs index c18e6c38..999f276d 100644 --- a/candle-wasm-examples/phi/src/bin/m.rs +++ b/candle-wasm-examples/phi/src/bin/m.rs @@ -5,6 +5,7 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer; use candle_wasm_example_phi::console_log; use js_sys::Date; +use serde::Deserialize; use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; @@ -23,6 +24,12 @@ pub struct Model { repeat_last_n: usize, } +#[derive(Debug, Clone, PartialEq, Deserialize)] + +pub struct ModelName { + pub _name_or_path: String, +} + #[wasm_bindgen] impl Model { #[wasm_bindgen(constructor)] @@ -34,15 +41,25 @@ impl Model { ) -> Result<Model, JsError> { console_error_panic_hook::set_once(); console_log!("loading model"); + let name: ModelName = serde_json::from_slice(&config)?; let config: Config = serde_json::from_slice(&config)?; + + console_log!("config loaded {:?}", name); let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; let start = Date::now(); + console_log!("weights len: {:?}", weights.len()); let model = if quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?; - let model = QMixFormer::new(&config, vb)?; - SelectedModel::Quantized(model) + console_log!("weights loaded"); + if name._name_or_path == "microsoft/phi-2" { + let model = QMixFormer::new_v2(&config, vb)?; + SelectedModel::Quantized(model) + } else { + let model = QMixFormer::new(&config, vb)?; + SelectedModel::Quantized(model) + } } else { let device = &Device::Cpu; let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?; diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml index c2787d64..06653b1f 100644 --- a/candle-wasm-examples/segment-anything/Cargo.toml +++ b/candle-wasm-examples/segment-anything/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } num-traits = { workspace = true } # App crates. diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml index 5fec7ea5..86a69fae 100644 --- a/candle-wasm-examples/t5/Cargo.toml +++ b/candle-wasm-examples/t5/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 864d0688..63ecefd6 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -9,9 +9,9 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } -candle-transformers = { path = "../../candle-transformers", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } +candle-transformers = { path = "../../candle-transformers", version = "0.3.2" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/whisper/src/app.rs b/candle-wasm-examples/whisper/src/app.rs index 1bc913d5..1cb31193 100644 --- a/candle-wasm-examples/whisper/src/app.rs +++ b/candle-wasm-examples/whisper/src/app.rs @@ -145,7 +145,7 @@ impl Component for App { fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { match msg { Msg::SetDecoder(md) => { - self.status = "weights loaded succesfully!".to_string(); + self.status = "weights loaded successfully!".to_string(); self.loaded = true; console_log!("loaded weights"); self.worker.send(WorkerInput::ModelData(md)); diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs index db5e8bb1..fd91fa8c 100644 --- a/candle-wasm-examples/whisper/src/worker.rs +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -414,7 +414,7 @@ pub enum Task { } // Communication to the worker happens through bincode, the model weights and configs are fetched -// on the main thread and transfered via the following structure. +// on the main thread and transferred via the following structure. #[derive(Serialize, Deserialize)] pub struct ModelData { pub weights: Vec<u8>, diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index 9c697fd1..3d8e0a07 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -9,8 +9,8 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.3.1" } +candle = { path = "../../candle-core", version = "0.3.2", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.3.2" } num-traits = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-wasm-examples/yolo/src/app.rs b/candle-wasm-examples/yolo/src/app.rs index 0df61f0f..3a88a5f1 100644 --- a/candle-wasm-examples/yolo/src/app.rs +++ b/candle-wasm-examples/yolo/src/app.rs @@ -146,7 +146,7 @@ impl Component for App { fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { match msg { Msg::SetModel(md) => { - self.status = "weights loaded succesfully!".to_string(); + self.status = "weights loaded successfully!".to_string(); self.loaded = true; console_log!("loaded weights"); self.worker.send(WorkerInput::ModelData(md)); diff --git a/candle-wasm-examples/yolo/src/worker.rs b/candle-wasm-examples/yolo/src/worker.rs index 5733a3fd..1ecef341 100644 --- a/candle-wasm-examples/yolo/src/worker.rs +++ b/candle-wasm-examples/yolo/src/worker.rs @@ -21,7 +21,7 @@ macro_rules! console_log { } // Communication to the worker happens through bincode, the model weights and configs are fetched -// on the main thread and transfered via the following structure. +// on the main thread and transferred via the following structure. #[derive(Serialize, Deserialize)] pub struct ModelData { pub weights: Vec<u8>, diff --git a/candle-wasm-tests/Cargo.toml b/candle-wasm-tests/Cargo.toml index f5fd0eff..5641ccfb 100644 --- a/candle-wasm-tests/Cargo.toml +++ b/candle-wasm-tests/Cargo.toml @@ -7,7 +7,7 @@ keywords.workspace = true categories.workspace = true [dependencies] -candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" } +candle = { path = "../candle-core", version = "0.3.2", package = "candle-core" } rand = { workspace = true } getrandom = { version = "0.2", features = ["js"] } |