diff options
-rw-r--r-- | README.md | 17 | ||||
-rw-r--r-- | candle-core/examples/basics.rs | 11 | ||||
-rw-r--r-- | candle-core/src/indexer.rs | 50 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 6 | ||||
-rw-r--r-- | candle-core/src/quantized/mod.rs | 4 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 58 | ||||
-rw-r--r-- | candle-core/tests/indexing_tests.rs | 29 | ||||
-rw-r--r-- | candle-core/tests/quantized_tests.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/distilbert/README.md | 22 | ||||
-rw-r--r-- | candle-examples/examples/distilbert/main.rs | 135 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v3/main.rs | 1 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v8/main.rs | 4 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 624 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 616 | ||||
-rw-r--r-- | candle-nn/examples/cpu_benchmarks.rs | 2 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/distilbert.rs | 342 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/utils.rs | 17 | ||||
-rw-r--r-- | candle-wasm-tests/tests/quantized_tests.rs | 2 |
20 files changed, 1262 insertions, 683 deletions
@@ -139,16 +139,16 @@ And then head over to <!--- ANCHOR: useful_libraries ---> ## Useful External Resources -- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a +- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A very detailed tutorial showing how to convert a PyTorch model to Candle. -- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers +- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has + out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples). +- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop. -- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation - that conforms to the official `peft` implementation. - [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and serving local LLMs including an OpenAI compatible API server. -- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle. -- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more. +- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle. 
+- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more. - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle. If you have an addition to this list, please submit a pull request. @@ -177,6 +177,11 @@ If you have an addition to this list, please submit a pull request. - Replit-code-v1.5-3B. - Bert. - Yi-6B and Yi-34B. + - Quantized LLMs. + - Llama 7b, 13b, 70b, as well as the chat and code variants. + - Mistral 7b, and 7b instruct. + - Zephyr 7b a and b (Mistral based). + - OpenChat 3.5 (Mistral based). - Text to text. - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction). - Marian MT (Machine Translation). diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs index ad008177..fe15187b 100644 --- a/candle-core/examples/basics.rs +++ b/candle-core/examples/basics.rs @@ -8,11 +8,10 @@ use anyhow::Result; use candle_core::{Device, Tensor}; fn main() -> Result<()> { - let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; - let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; - let start = std::time::Instant::now(); - let res = inp.conv2d(&w, 0, 1, 1, 1)?; - println!("{:?}", start.elapsed()); - println!("{res:?}"); + let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?; + let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?; + let new_a = a.slice_scatter(&b, 1, 2)?; + assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); Ok(()) } diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 7b84d316..df106b73 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -104,37 
+104,31 @@ impl From<&Tensor> for TensorIndexer { } } -macro_rules! impl_from_range { - ($range_type:ty) => { - impl From<$range_type> for TensorIndexer { - fn from(range: $range_type) -> Self { - use std::ops::Bound::*; +trait RB: RangeBounds<usize> {} +impl RB for Range<usize> {} +impl RB for RangeFrom<usize> {} +impl RB for RangeFull {} +impl RB for RangeInclusive<usize> {} +impl RB for RangeTo<usize> {} +impl RB for RangeToInclusive<usize> {} - let start = match range.start_bound() { - Included(idx) => Included(*idx), - Excluded(idx) => Excluded(*idx), - Unbounded => Unbounded, - }; - - let end = match range.end_bound() { - Included(idx) => Included(*idx), - Excluded(idx) => Excluded(*idx), - Unbounded => Unbounded, - }; - - TensorIndexer::Narrow(start, end) - } - } - }; +impl<T: RB> From<T> for TensorIndexer { + fn from(range: T) -> Self { + use std::ops::Bound::*; + let start = match range.start_bound() { + Included(idx) => Included(*idx), + Excluded(idx) => Excluded(*idx), + Unbounded => Unbounded, + }; + let end = match range.end_bound() { + Included(idx) => Included(*idx), + Excluded(idx) => Excluded(*idx), + Unbounded => Unbounded, + }; + TensorIndexer::Narrow(start, end) + } } -impl_from_range!(Range<usize>); -impl_from_range!(RangeFrom<usize>); -impl_from_range!(RangeFull); -impl_from_range!(RangeInclusive<usize>); -impl_from_range!(RangeTo<usize>); -impl_from_range!(RangeToInclusive<usize>); - /// Trait used to implement multiple signatures for ease of use of the slicing /// of a tensor pub trait IndexOp<T> { diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 36f5f6b1..6c4fea91 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -123,12 +123,6 @@ pub trait Module { fn forward(&self, xs: &Tensor) -> Result<Tensor>; } -impl Module for quantized::QMatMul { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - self.forward(xs) - } -} - impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T { fn forward(&self, xs: &Tensor) -> 
Result<Tensor> { self(xs) diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 58f261b4..043733ae 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor { } } -impl QMatMul { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl crate::Module for QMatMul { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { match self { Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()), Self::Tensor(w) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ce5858fa..87323a84 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2503,6 +2503,64 @@ impl Tensor { t.transpose(dim, last) } } + + /// Returns a copy of `self` where the values within `ranges` have been replaced with the + /// content of `src`. + pub fn slice_assign<D: std::ops::RangeBounds<usize>>( + &self, + ranges: &[D], + src: &Tensor, + ) -> Result<Self> { + let src_dims = src.dims(); + let self_dims = self.dims(); + if self_dims.len() != src_dims.len() { + crate::bail!( + "slice-assign requires input with the same rank {} <> {}", + self_dims.len(), + src_dims.len() + ) + } + if self_dims.len() != ranges.len() { + crate::bail!( + "slice-assign requires input with the same rank as there are ranges {} <> {}", + self_dims.len(), + ranges.len() + ) + } + let mut src = src.clone(); + let mut mask = Self::ones(src.shape(), DType::U8, src.device())?; + for (i, range) in ranges.iter().enumerate() { + let start_included = match range.start_bound() { + std::ops::Bound::Unbounded => 0, + std::ops::Bound::Included(v) => *v, + std::ops::Bound::Excluded(v) => *v + 1, + }; + let end_excluded = match range.end_bound() { + std::ops::Bound::Unbounded => self_dims[i], + std::ops::Bound::Included(v) => *v + 1, + std::ops::Bound::Excluded(v) => *v, + }; + if end_excluded <= start_included { + crate::bail!( + "slice-assign: empty range for dim {i}, {start_included} 
{end_excluded}" + ) + } + if self_dims[i] < end_excluded { + crate::bail!( + "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}", + self_dims[i] + ) + } + if end_excluded - start_included != src_dims[i] { + crate::bail!( + "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i] + ) + } + src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?; + mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)? + } + mask.where_cond(/* on_true= */ &src, /* on_false= */ self) + } } macro_rules! bin_trait { diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs index 9c88f319..047205a3 100644 --- a/candle-core/tests/indexing_tests.rs +++ b/candle-core/tests/indexing_tests.rs @@ -91,3 +91,32 @@ fn index_3d() -> Result<()> { assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]); Ok(()) } + +#[test] +fn slice_assign() -> Result<()> { + let dev = Device::Cpu; + + let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?; + let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?; + let out = tensor.slice_assign(&[1..4, 3..5], &src)?; + assert_eq!( + out.to_vec2::<u32>()?, + &[ + [0, 1, 2, 3, 4], + [5, 6, 7, 0, 1], + [10, 11, 12, 2, 3], + [15, 16, 17, 4, 5] + ] + ); + let out = tensor.slice_assign(&[0..3, 0..2], &src)?; + assert_eq!( + out.to_vec2::<u32>()?, + &[ + [0, 1, 2, 3, 4], + [2, 3, 7, 8, 9], + [4, 5, 12, 13, 14], + [15, 16, 17, 18, 19] + ] + ); + Ok(()) +} diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index a2cecbc3..716cca8d 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,7 +1,7 @@ use candle_core::{ quantized::{self, GgmlDType}, test_utils::to_vec2_round, - Device, Result, Tensor, + Device, Module, Result, Tensor, }; use quantized::{k_quants, GgmlType}; use rand::prelude::*; diff --git 
a/candle-examples/examples/distilbert/README.md b/candle-examples/examples/distilbert/README.md new file mode 100644 index 00000000..88f97f2b --- /dev/null +++ b/candle-examples/examples/distilbert/README.md @@ -0,0 +1,22 @@ +# candle-distilbert + +DistilBert is a distilled version of the Bert model. + +## Sentence embeddings + +DistilBert is used to compute the sentence embeddings for a prompt. The model weights +are downloaded from the hub on the first run. + +```bash +cargo run --example distilbert --release -- --prompt "Here is a test sentence" + +> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441], +> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244], +> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829], +> ... +> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042], +> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190], +> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]] +> Tensor[[1, 7, 768], f32] + +``` diff --git a/candle-examples/examples/distilbert/main.rs b/candle-examples/examples/distilbert/main.rs new file mode 100644 index 00000000..1d42011c --- /dev/null +++ b/candle-examples/examples/distilbert/main.rs @@ -0,0 +1,135 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; +use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE}; + +use anyhow::{Error as E, Result}; +use candle::{Device, Tensor}; +use candle_nn::VarBuilder; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). 
+ #[arg(long)] + tracing: bool, + + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending + #[arg(long)] + model_id: Option<String>, + + #[arg(long)] + revision: Option<String>, + + /// When set, compute embeddings for this prompt. + #[arg(long)] + prompt: String, + + /// Use the pytorch weights rather than the safetensors ones + #[arg(long)] + use_pth: bool, + + /// The number of times to run the prompt. + #[arg(long, default_value = "1")] + n: usize, + + /// L2 normalization for embeddings. + #[arg(long, default_value = "true")] + normalize_embeddings: bool, +} + +impl Args { + fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> { + let device = candle_examples::device(self.cpu)?; + let default_model = "distilbert-base-uncased".to_string(); + let default_revision = "main".to_string(); + let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { + (Some(model_id), Some(revision)) => (model_id, revision), + (Some(model_id), None) => (model_id, "main".to_string()), + (None, Some(revision)) => (default_model, revision), + (None, None) => (default_model, default_revision), + }; + + let repo = Repo::with_revision(model_id, RepoType::Model, revision); + let (config_filename, tokenizer_filename, weights_filename) = { + let api = Api::new()?; + let api = api.repo(repo); + let config = api.get("config.json")?; + let tokenizer = api.get("tokenizer.json")?; + let weights = if self.use_pth { + api.get("pytorch_model.bin")? + } else { + api.get("model.safetensors")? + }; + (config, tokenizer, weights) + }; + let config = std::fs::read_to_string(config_filename)?; + let config: Config = serde_json::from_str(&config)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let vb = if self.use_pth { + VarBuilder::from_pth(&weights_filename, DTYPE, &device)? 
+ } else { + unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } + }; + let model = DistilBertModel::load(vb, &config)?; + Ok((model, tokenizer)) + } +} + +fn get_mask(size: usize, device: &Device) -> Tensor { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device).unwrap() +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + let (model, mut tokenizer) = args.build_model_and_tokenizer()?; + let device = &model.device; + + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; + let tokens = tokenizer + .encode(args.prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; + let mask = get_mask(tokens.len(), device); + + println!("token_ids: {:?}", token_ids.to_vec2::<u32>()); + println!("mask: {:?}", mask.to_vec2::<u8>()); + + let ys = model.forward(&token_ids, &mask)?; + println!("{ys}"); + + Ok(()) +} + +pub fn normalize_l2(v: &Tensor) -> Result<Tensor> { + Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) +} diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index 5b1937ac..a6574697 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -43,6 +43,7 @@ pub fn report( confidence_threshold: f32, nms_threshold: f32, ) -> Result<DynamicImage> { + let pred = pred.to_device(&Device::Cpu)?; let (npreds, pred_size) = pred.dims2()?; let nclasses = pred_size - 5; // The bounding boxes grouped by (maximum) class index. 
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index af8cf98a..c65a5ca1 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod model; use model::{Multiples, YoloV8, YoloV8Pose}; -use candle::{DType, IndexOp, Result, Tensor}; +use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; @@ -61,6 +61,7 @@ pub fn report_detect( nms_threshold: f32, legend_size: u32, ) -> Result<DynamicImage> { + let pred = pred.to_device(&Device::Cpu)?; let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; // The bounding boxes grouped by (maximum) class index. @@ -153,6 +154,7 @@ pub fn report_pose( confidence_threshold: f32, nms_threshold: f32, ) -> Result<DynamicImage> { + let pred = pred.to_device(&Device::Cpu)?; let (pred_size, npreds) = pred.dims2()?; if pred_size != 17 * 3 + 4 + 1 { candle::bail!("unexpected pred-size {pred_size}"); diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index cff8e763..5a6bd41b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -672,626 +672,4 @@ pub fn call_index_select( } #[cfg(test)] -mod tests { - use super::*; - use half::f16; - use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; - - fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { - let options = MTLResourceOptions::StorageModeManaged; - let ptr = data.as_ptr() as *const core::ffi::c_void; - let size = (data.len() * std::mem::size_of::<T>()) as u64; - device.new_buffer_with_data(ptr, size, options) - } - - fn device() -> Device { - Device::system_default().unwrap() - } - - fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> { - let b = 10f32.powi(digits); - v.iter().map(|t| f32::round(t * b) / 
b).collect() - } - - fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> { - let b = 10f32.powi(digits); - v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() - } - - fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - call_unary_contiguous( - &device, - command_buffer, - &kernels, - name, - v.len(), - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) - } - - fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let left = new_buffer(&device, x); - let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); - call_binary_contiguous( - &device, - command_buffer, - &kernels, - name, - x.len(), - &left, - &right, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::<T>(x.len()) - } - - fn run_strided<T: Clone>( - v: &[T], - kernel: unary::strided::Kernel, - shape: &[usize], - strides: &[usize], - offset: usize, - ) -> Vec<T> { - let device = device(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - let kernels = Kernels::new(); - call_unary_strided( - &device, - command_buffer, - &kernels, - kernel, - shape, - &input, - strides, - offset, - &mut output, - 0, - ) - .unwrap(); - 
command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) - } - - #[test] - fn cos_f32() { - let v = vec![1.0f32, 2.0, 3.0]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); - assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); - - let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); - } - - #[test] - fn cos_f32_strided() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let shape = vec![6]; - let strides = vec![1]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!( - approx(results, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - assert_eq!( - approx(expected, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - - // Contiguous - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let shape = vec![3, 2]; - let strides = vec![2, 1]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!( - approx(results, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - assert_eq!( - approx(expected, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - - // Transposed - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let shape = vec![3, 2]; - let strides = vec![1, 3]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!( - approx(results, 4), - 
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] - ); - assert_eq!( - approx(expected, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - - // Very large - let v = vec![1.0f32; 10_000]; - let shape = vec![2, 5_000]; - let strides = vec![2, 1]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); - } - - #[test] - fn cos_strided_random() { - let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect(); - let shape = vec![5_000, 2]; - let strides = vec![1, 5_000]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4)); - assert_eq!( - approx(vec![results[1]], 4), - approx(vec![expected[5_000]], 4) - ); - assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4)); - assert_eq!( - approx(vec![results[3]], 4), - approx(vec![expected[5_001]], 4) - ); - assert_eq!( - approx(vec![results[5_000]], 4), - approx(vec![expected[2_500]], 4) - ); - } - - #[test] - fn binary_add_f32() { - let left = vec![1.0f32, 2.0, 3.0]; - let right = vec![2.0f32, 3.1, 4.2]; - let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); - let expected: Vec<_> = left - .iter() - .zip(right.iter()) - .map(|(&x, &y)| x + y) - .collect(); - assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); - assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); - } - - fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output 
= new_buffer(&device, v); - - call_cast_contiguous( - &device, - command_buffer, - &kernels, - name, - v.len(), - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::<U>(v.len()) - } - - #[test] - fn cast_u32_f32() { - let v = vec![1u32, 2, 3]; - let results = cast(&v, "cast_u32_f32"); - let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); - assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); - assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); - - let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); - } - - fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - - let size = v.len(); - - call_affine( - &device, - command_buffer, - &kernels, - size, - &input, - &mut output, - mul as f32, - add as f32, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::<T>(v.len()) - } - - #[test] - fn affine() { - let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; - let mul = 1.5; - let add = 1.1; - let result = run_affine(&input, mul, add); - assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); - - let input = [1.0f32; 40_000]; - let mul = 1.5; - let add = 1.1; - let result = run_affine(&input, mul, add); - assert_eq!(result, vec![2.6; 40_000]); - } - - #[test] - fn index_select() { - let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - let shape = [5, 2]; - let ids = [0u32, 4, 2]; - let dim = 0; - let result = 
run_index_select(&embedding, &shape, &ids, dim); - assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); - - let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - let shape = [2, 5]; - let ids = [0u32, 1, 0]; - let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); - assert_eq!( - result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] - ); - - let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - let shape = [5, 2]; - let ids = [0u32, 1, 0]; - let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); - assert_eq!( - result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] - ); - } - - fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( - embeddings: &[T], - shape: &[usize], - ids: &[I], - dim: usize, - ) -> Vec<T> { - let device = Device::system_default().expect("no device found"); - - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); - - let left_size: usize = shape[..dim].iter().product(); - let right_size: usize = shape[dim + 1..].iter().product(); - let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - - let kernels = Kernels::new(); - call_index_select( - &device, - &command_buffer, - &kernels, - "is_u32_f32", - shape, - ids.len(), - dim, - &embeddings_buffer, - &ids_buffer, - &mut dst_buffer, - ) - .unwrap(); - - command_buffer.commit(); - command_buffer.wait_until_completed(); - - dst_buffer.read_to_vec::<T>(dst_el) - } - - #[test] - fn index_add() { - let device = Device::system_default().expect("no device found"); - - let options = CompileOptions::new(); - let library = device.new_library_with_source(INDEXING, &options).unwrap(); - - let left = [1.0f32, 2.0, 
3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; - let right = [1.0f32; 15]; - let index = [0u32, 4, 2]; - let ids_dim_size = index.len() as u32; - let dst_dim_size: u32 = 15; - let left_size: u32 = 3; - let right_size: u32 = 3; - - let function = library.get_function("ia_u32_f32", None).unwrap(); - let pipeline = device - .new_compute_pipeline_state_with_function(&function) - .unwrap(); - - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let encoder = command_buffer.new_compute_command_encoder(); - - encoder.set_compute_pipeline_state(&pipeline); - - let index_buffer = new_buffer(&device, &index); - let inputs_buffer = new_buffer(&device, &left); - let outputs_buffer = new_buffer(&device, &right); - - set_params!( - encoder, - ( - &index_buffer, - &inputs_buffer, - &outputs_buffer, - ids_dim_size, - left_size, - dst_dim_size, - right_size - ) - ); - - let grid_size = MTLSize { - width: right.len() as NSUInteger, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: pipeline.max_total_threads_per_threadgroup(), - height: 1, - depth: 1, - }; - - encoder.dispatch_thread_groups(grid_size, thread_group_size); - encoder.end_encoding(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - let expected = vec![ - 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, - ]; - let result = outputs_buffer.read_to_vec::<f32>(right.len()); - assert_eq!(result, expected); - } - - #[test] - fn cos_f16() { - let v: Vec<f16> = [1.0f32, 2.0, 3.0] - .iter() - .map(|v| f16::from_f32(*v)) - .collect(); - let results = run(&v, unary::contiguous::cos::HALF); - let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); - } - - fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> { - let 
device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - - let options = MTLResourceOptions::StorageModeManaged; - let mut output = - device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); - call_reduce_contiguous( - &device, - command_buffer, - &kernels, - name, - v.len(), - out_length, - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::<T>(out_length) - } - - fn run_softmax<T: Clone + std::fmt::Debug>( - v: &[T], - last_dim: usize, - name: &'static str, - ) -> Vec<T> { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - call_last_softmax( - &device, - command_buffer, - &kernels, - name, - v.len(), - last_dim, - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::<T>(v.len()) - } - - #[test] - fn reduce_sum() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 1; - - let results = run_reduce(&v, out_length, "fast_sum_float"); - assert_eq!(approx(results, 4), vec![21.0]); - } - - #[test] - fn reduce_sum2() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 2; - - let results = run_reduce(&v, out_length, "fast_sum_float"); - assert_eq!(approx(results, 4), vec![6.0, 15.0]); - } - - #[test] - fn softmax() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); - assert_eq!( - approx(results, 4), - vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] - ); - - let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; - let last_dim = 6; - let results = 
run_softmax(&v, last_dim, "softmax_float"); - assert_eq!( - approx(results, 4), - vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] - ); - - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let last_dim = 3; - let results = run_softmax(&v, last_dim, "softmax_float"); - assert_eq!( - approx(results, 4), - vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] - ); - } - - fn run_where_cond<I: Clone, T: Clone>( - shape: &[usize], - cond: &[I], - (cond_stride, cond_offset): (Vec<usize>, usize), - left_true: &[T], - (left_stride, left_offset): (Vec<usize>, usize), - right_false: &[T], - (_right_stride, _right_offset): (Vec<usize>, usize), - name: &'static str, - ) -> Vec<T> { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - - let length = cond.len(); - let cond = device.new_buffer_with_data( - cond.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(cond) as u64, - options, - ); - let left = device.new_buffer_with_data( - left_true.as_ptr() as *const core::ffi::c_void, - (length * core::mem::size_of::<T>()) as u64, - options, - ); - let right = device.new_buffer_with_data( - right_false.as_ptr() as *const core::ffi::c_void, - (length * core::mem::size_of::<T>()) as u64, - options, - ); - - let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); - call_where_cond_strided( - &device, - command_buffer, - &kernels, - name, - shape, - &cond, - (&cond_stride, cond_offset), - &left, - (&left_stride, left_offset), - &right, - (&cond_stride, cond_offset), - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::<T>(length) - } - - #[test] - fn where_cond() { - let shape = vec![6]; - let cond = vec![0u8, 1, 0, 0, 1, 1]; - let cond_l = (vec![1], 0); - let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 
5.0, 6.0]; - let left_l = (vec![1], 0); - let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; - let right_l = (vec![1], 0); - let results = run_where_cond( - &shape, - &cond, - cond_l, - &left_true, - left_l, - &right_false, - right_l, - "where_u8_f32", - ); - assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); - } -} +mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs new file mode 100644 index 00000000..2330d48d --- /dev/null +++ b/candle-metal-kernels/src/tests.rs @@ -0,0 +1,616 @@ +use super::*; +use half::f16; +use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; + +fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { + let options = MTLResourceOptions::StorageModeManaged; + let ptr = data.as_ptr() as *const core::ffi::c_void; + let size = (data.len() * std::mem::size_of::<T>()) as u64; + device.new_buffer_with_data(ptr, size, options) +} + +fn device() -> Device { + Device::system_default().unwrap() +} + +fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t * b) / b).collect() +} + +fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + +fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + call_unary_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::<T>(v.len()) +} + +fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> { + let device = device(); + let 
kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let left = new_buffer(&device, x); + let right = new_buffer(&device, y); + let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + call_binary_contiguous( + &device, + command_buffer, + &kernels, + name, + x.len(), + &left, + &right, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::<T>(x.len()) +} + +fn run_strided<T: Clone>( + v: &[T], + kernel: unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + offset: usize, +) -> Vec<T> { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + let kernels = Kernels::new(); + call_unary_strided( + &device, + command_buffer, + &kernels, + kernel, + shape, + &input, + strides, + offset, + &mut output, + 0, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::<T>(v.len()) +} + +#[test] +fn cos_f32() { + let v = vec![1.0f32, 2.0, 3.0]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); + assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, 
unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Transposed + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![1, 3]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Very large + let v = vec![1.0f32; 10_000]; + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_strided_random() { + let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect(); + let shape = vec![5_000, 2]; + let strides = vec![1, 5_000]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| 
v.cos()).collect(); + assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4)); + assert_eq!( + approx(vec![results[1]], 4), + approx(vec![expected[5_000]], 4) + ); + assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4)); + assert_eq!( + approx(vec![results[3]], 4), + approx(vec![expected[5_001]], 4) + ); + assert_eq!( + approx(vec![results[5_000]], 4), + approx(vec![expected[2_500]], 4) + ); +} + +#[test] +fn binary_add_f32() { + let left = vec![1.0f32, 2.0, 3.0]; + let right = vec![2.0f32, 3.1, 4.2]; + let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); + let expected: Vec<_> = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| x + y) + .collect(); + assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); + assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); +} + +fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + + call_cast_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::<U>(v.len()) +} + +#[test] +fn cast_u32_f32() { + let v = vec![1u32, 2, 3]; + let results = cast(&v, "cast_u32_f32"); + let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); + assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); + assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> 
Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + + let size = v.len(); + + call_affine( + &device, + command_buffer, + &kernels, + size, + &input, + &mut output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::<T>(v.len()) +} + +#[test] +fn affine() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + + let input = [1.0f32; 40_000]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6; 40_000]); +} + +#[test] +fn index_select() { + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [2, 5]; + let ids = [0u32, 1, 0]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 1, 0]; + let dim = 1; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); +} + +fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( + embeddings: &[T], + shape: &[usize], + ids: &[I], + dim: usize, +) -> Vec<T> 
{ + let device = Device::system_default().expect("no device found"); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let embeddings_buffer = new_buffer(&device, &embeddings); + let ids_buffer = new_buffer(&device, &ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; + let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let kernels = Kernels::new(); + call_index_select( + &device, + &command_buffer, + &kernels, + "is_u32_f32", + shape, + ids.len(), + dim, + &embeddings_buffer, + &ids_buffer, + &mut dst_buffer, + ) + .unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + dst_buffer.read_to_vec::<T>(dst_el) +} + +#[test] +fn index_add() { + let device = Device::system_default().expect("no device found"); + + let options = CompileOptions::new(); + let library = device.new_library_with_source(INDEXING, &options).unwrap(); + + let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let right = [1.0f32; 15]; + let index = [0u32, 4, 2]; + let ids_dim_size = index.len() as u32; + let dst_dim_size: u32 = 15; + let left_size: u32 = 3; + let right_size: u32 = 3; + + let function = library.get_function("ia_u32_f32", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + let index_buffer = new_buffer(&device, &index); + let inputs_buffer = new_buffer(&device, &left); + let outputs_buffer = new_buffer(&device, &right); + + set_params!( + encoder, + ( + &index_buffer, + &inputs_buffer, + &outputs_buffer, + ids_dim_size, + left_size, + dst_dim_size, + right_size + ) + 
); + + let grid_size = MTLSize { + width: right.len() as NSUInteger, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: pipeline.max_total_threads_per_threadgroup(), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = vec![ + 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, + ]; + let result = outputs_buffer.read_to_vec::<f32>(right.len()); + assert_eq!(result, expected); +} + +#[test] +fn cos_f16() { + let v: Vec<f16> = [1.0f32, 2.0, 3.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let results = run(&v, unary::contiguous::cos::HALF); + let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); + assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); +} + +fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + + let options = MTLResourceOptions::StorageModeManaged; + let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); + call_reduce_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + out_length, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::<T>(out_length) +} + +fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, 
v); + let mut output = new_buffer(&device, v); + call_last_softmax( + &device, + command_buffer, + &kernels, + name, + v.len(), + last_dim, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::<T>(v.len()) +} + +#[test] +fn reduce_sum() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 1; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![21.0]); +} + +#[test] +fn reduce_sum2() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 2; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![6.0, 15.0]); +} + +#[test] +fn softmax() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 3; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] + ); +} + +fn run_where_cond<I: Clone, T: Clone>( + shape: &[usize], + cond: &[I], + (cond_stride, cond_offset): (Vec<usize>, usize), + left_true: &[T], + (left_stride, left_offset): (Vec<usize>, usize), + right_false: &[T], + (_right_stride, _right_offset): (Vec<usize>, usize), + name: &'static str, +) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let length = cond.len(); + let 
cond = device.new_buffer_with_data( + cond.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(cond) as u64, + options, + ); + let left = device.new_buffer_with_data( + left_true.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::<T>()) as u64, + options, + ); + let right = device.new_buffer_with_data( + right_false.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::<T>()) as u64, + options, + ); + + let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + call_where_cond_strided( + &device, + command_buffer, + &kernels, + name, + shape, + &cond, + (&cond_stride, cond_offset), + &left, + (&left_stride, left_offset), + &right, + (&cond_stride, cond_offset), + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::<T>(length) +} + +#[test] +fn where_cond() { + let shape = vec![6]; + let cond = vec![0u8, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u8_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); +} diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 9ded5f71..68d384a6 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -6,7 +6,7 @@ extern crate intel_mkl_src; extern crate accelerate_src; use candle::quantized::GgmlType; -use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D}; +use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D}; use clap::{Parser, Subcommand}; const CHECK_CONV2D: bool = false; diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 
b0c623d3..ade00012 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -17,7 +17,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType}; mod utils; use utils::wrap_err; diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs new file mode 100644 index 00000000..ea074c97 --- /dev/null +++ b/candle-transformers/src/models/distilbert.rs @@ -0,0 +1,342 @@ +use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; +use candle::{DType, Device, Result, Tensor}; +use candle_nn::{Embedding, Module, VarBuilder}; +use serde::Deserialize; + +pub const DTYPE: DType = DType::F32; + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +enum HiddenAct { + Gelu, + Relu, +} + +struct HiddenActLayer { + act: HiddenAct, + span: tracing::Span, +} + +impl HiddenActLayer { + fn new(act: HiddenAct) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); + Self { act, span } + } +} + +impl Module for HiddenActLayer { + fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { + let _enter = self.span.enter(); + match self.act { + // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 + HiddenAct::Gelu => xs.gelu(), + HiddenAct::Relu => xs.relu(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +enum PositionEmbeddingType { + #[default] + Absolute, +} + +#[derive(Debug, Clone, 
PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + dim: usize, + n_layers: usize, + n_heads: usize, + hidden_dim: usize, + activation: HiddenAct, + max_position_embeddings: usize, + initializer_range: f64, + pad_token_id: usize, + #[serde(default)] + position_embedding_type: PositionEmbeddingType, + #[serde(default)] + use_cache: bool, + model_type: Option<String>, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 30522, + dim: 768, + n_layers: 12, + n_heads: 12, + hidden_dim: 3072, + activation: HiddenAct::Gelu, + max_position_embeddings: 512, + initializer_range: 0.02, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + model_type: Some("distilbert".to_string()), + } + } +} + +struct Embeddings { + word_embeddings: Embedding, + position_embeddings: Embedding, + layer_norm: LayerNorm, + span: tracing::Span, +} + +impl Embeddings { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let word_embeddings = + candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?; + let position_embeddings = candle_nn::embedding( + config.max_position_embeddings, + config.dim, + vb.pp("position_embeddings"), + )?; + let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?; + Ok(Self { + word_embeddings, + position_embeddings, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "embeddings"), + }) + } + + fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (_bsize, seq_len) = input_ids.dims2()?; + let input_embeddings = self.word_embeddings.forward(input_ids)?; + let position_ids = (0..seq_len as u32).collect::<Vec<_>>(); + let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; + let embeddings = + input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?; + + let embeddings = self.layer_norm.forward(&embeddings)?; + Ok(embeddings) + } +} + +struct 
MultiHeadSelfAttention { + q_lin: Linear, + k_lin: Linear, + v_lin: Linear, + out_lin: Linear, + n_heads: usize, + attention_head_size: usize, + span: tracing::Span, +} + +impl MultiHeadSelfAttention { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention_head_size = config.dim / config.n_heads; + let all_head_size = config.n_heads * attention_head_size; + let dim = config.dim; + let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?; + let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?; + let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?; + let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?; + Ok(Self { + q_lin, + k_lin, + v_lin, + out_lin, + n_heads: config.n_heads, + attention_head_size, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } +} + +impl MultiHeadSelfAttention { + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (bs, q_length, _dim) = hidden_states.dims3()?; + + let dim_per_head = self.attention_head_size; + let q = self.q_lin.forward(hidden_states)?; + let k = self.k_lin.forward(hidden_states)?; + let v = self.v_lin.forward(hidden_states)?; + + let q = q + .reshape((bs, q_length, self.n_heads, dim_per_head))? + .transpose(1, 2)?; + let k = k + .reshape((bs, q_length, self.n_heads, dim_per_head))? + .transpose(1, 2)?; + let v = v + .reshape((bs, q_length, self.n_heads, dim_per_head))? + .transpose(1, 2)?; + + let q: Tensor = (q / (dim_per_head as f64).sqrt())?; + let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?; + let mask = attention_mask.broadcast_as(scores.shape())?; + + let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?; + let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?; + + let context = weights.matmul(&v.contiguous()?)?; + let context = context + .transpose(1, 2)? + .reshape((bs, q_length, self.n_heads * dim_per_head))? 
+ .contiguous()?; + let context = self.out_lin.forward(&context)?; + + Ok(context) + } +} + +#[allow(clippy::upper_case_acronyms)] +struct FFN { + lin1: Linear, + lin2: Linear, + activation: HiddenActLayer, + span: tracing::Span, +} + +impl FFN { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?; + let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?; + Ok(Self { + lin1, + lin2, + activation: HiddenActLayer::new(config.activation), + span: tracing::span!(tracing::Level::TRACE, "ffn"), + }) + } +} + +impl Module for FFN { + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + hidden_states + .apply(&self.lin1)? + .apply(&self.activation)? + .apply(&self.lin2) + } +} + +struct TransformerBlock { + attention: MultiHeadSelfAttention, + sa_layer_norm: LayerNorm, + ffn: FFN, + output_layer_norm: LayerNorm, + span: tracing::Span, +} + +impl TransformerBlock { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?; + let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?; + let ffn = FFN::load(vb.pp("ffn"), config)?; + let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?; + Ok(Self { + attention, + sa_layer_norm, + ffn, + output_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } +} + +impl TransformerBlock { + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let sa_output = self.attention.forward(hidden_states, attention_mask)?; + // TODO: Support cross-attention? + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + // TODO: Support something similar to `apply_chunking_to_forward`? 
+ let sa_output = sa_output.broadcast_add(hidden_states)?; + let sa_output = self.sa_layer_norm.forward(&sa_output)?; + + let ffn_output = self.ffn.forward(&sa_output)?; + let ffn_output = (&ffn_output + sa_output)?; + let output = self.output_layer_norm.forward(&ffn_output)?; + Ok(output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556 +struct Transformer { + layers: Vec<TransformerBlock>, + span: tracing::Span, +} + +impl Transformer { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let layers = (0..config.n_layers) + .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config)) + .collect::<Result<Vec<_>>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + Ok(Transformer { layers, span }) + } +} + +impl Transformer { + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut hidden_states = hidden_states.clone(); + // Use a loop rather than a fold as it's easier to modify when adding debug/... 
+ for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, attention_mask)?; + } + Ok(hidden_states) + } +} + +pub struct DistilBertModel { + embeddings: Embeddings, + transformer: Transformer, + pub device: Device, + span: tracing::Span, +} + +impl DistilBertModel { + pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let (embeddings, transformer) = match ( + Embeddings::load(vb.pp("embeddings"), config), + Transformer::load(vb.pp("transformer"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let Some(model_type) = &config.model_type { + if let (Ok(embeddings), Ok(encoder)) = ( + Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), + Transformer::load(vb.pp(&format!("{model_type}.transformer")), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } else { + return Err(err); + } + } + }; + Ok(Self { + embeddings, + transformer, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let embedding_output = self.embeddings.forward(input_ids)?; + let sequence_output = self + .transformer + .forward(&embedding_output, attention_mask)?; + Ok(sequence_output) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 558583b6..a9a56673 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -4,6 +4,7 @@ pub mod blip; pub mod blip_text; pub mod convmixer; pub mod dinov2; +pub mod distilbert; pub mod efficientnet; pub mod falcon; pub mod jina_bert; diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index 0c95cfef..cef06f1c 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ 
b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -1,12 +1,15 @@ use candle::{Device, Result, Tensor}; pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { - if steps < 1 { - candle::bail!("cannot use linspace with steps {steps} <= 1") + if steps == 0 { + Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu) + } else if steps == 1 { + Tensor::from_vec(vec![start], steps, &Device::Cpu) + } else { + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) } - let delta = (stop - start) / (steps - 1) as f64; - let vs = (0..steps) - .map(|step| start + step as f64 * delta) - .collect::<Vec<_>>(); - Tensor::from_vec(vs, steps, &Device::Cpu) } diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index 5d53728c..e5fa7dec 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -1,7 +1,7 @@ use candle::{ quantized::{self, k_quants, GgmlDType, GgmlType}, test_utils::to_vec2_round, - Device, Result, Tensor, + Device, Module, Result, Tensor, }; use wasm_bindgen_test::*; |