summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md17
-rw-r--r--candle-core/examples/basics.rs11
-rw-r--r--candle-core/src/indexer.rs50
-rw-r--r--candle-core/src/lib.rs6
-rw-r--r--candle-core/src/quantized/mod.rs4
-rw-r--r--candle-core/src/tensor.rs58
-rw-r--r--candle-core/tests/indexing_tests.rs29
-rw-r--r--candle-core/tests/quantized_tests.rs2
-rw-r--r--candle-examples/examples/distilbert/README.md22
-rw-r--r--candle-examples/examples/distilbert/main.rs135
-rw-r--r--candle-examples/examples/yolo-v3/main.rs1
-rw-r--r--candle-examples/examples/yolo-v8/main.rs4
-rw-r--r--candle-metal-kernels/src/lib.rs624
-rw-r--r--candle-metal-kernels/src/tests.rs616
-rw-r--r--candle-nn/examples/cpu_benchmarks.rs2
-rw-r--r--candle-pyo3/src/lib.rs2
-rw-r--r--candle-transformers/src/models/distilbert.rs342
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/stable_diffusion/utils.rs17
-rw-r--r--candle-wasm-tests/tests/quantized_tests.rs2
20 files changed, 1262 insertions, 683 deletions
diff --git a/README.md b/README.md
index 5f470458..20596fe1 100644
--- a/README.md
+++ b/README.md
@@ -139,16 +139,16 @@ And then head over to
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
-- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
+- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
-- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
+- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implementation for Candle. `candle-lora` has
+ out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
+- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
-- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
- that conforms to the official `peft` implementation.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
serving local LLMs including an OpenAI compatible API server.
-- [`candle-ext`](https://github.com/mokeyish/candle-ext): an extension library to Candle that provides PyTorch functions not currently available in Candle.
-- [`kalosm`](https://github.com/floneum/floneum/tree/master/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
+- [`candle-ext`](https://github.com/mokeyish/candle-ext): An extension library to Candle that provides PyTorch functions not currently available in Candle.
+- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
If you have an addition to this list, please submit a pull request.
@@ -177,6 +177,11 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
+ - Quantized LLMs.
+ - Llama 7b, 13b, 70b, as well as the chat and code variants.
+ - Mistral 7b, and 7b instruct.
+ - Zephyr 7b a and b (Mistral based).
+ - OpenChat 3.5 (Mistral based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs
index ad008177..fe15187b 100644
--- a/candle-core/examples/basics.rs
+++ b/candle-core/examples/basics.rs
@@ -8,11 +8,10 @@ use anyhow::Result;
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
- let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?;
- let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?;
- let start = std::time::Instant::now();
- let res = inp.conv2d(&w, 0, 1, 1, 1)?;
- println!("{:?}", start.elapsed());
- println!("{res:?}");
+ let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
+ let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
+ let new_a = a.slice_scatter(&b, 1, 2)?;
+ assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+ assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
Ok(())
}
diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs
index 7b84d316..df106b73 100644
--- a/candle-core/src/indexer.rs
+++ b/candle-core/src/indexer.rs
@@ -104,37 +104,31 @@ impl From<&Tensor> for TensorIndexer {
}
}
-macro_rules! impl_from_range {
- ($range_type:ty) => {
- impl From<$range_type> for TensorIndexer {
- fn from(range: $range_type) -> Self {
- use std::ops::Bound::*;
+trait RB: RangeBounds<usize> {}
+impl RB for Range<usize> {}
+impl RB for RangeFrom<usize> {}
+impl RB for RangeFull {}
+impl RB for RangeInclusive<usize> {}
+impl RB for RangeTo<usize> {}
+impl RB for RangeToInclusive<usize> {}
- let start = match range.start_bound() {
- Included(idx) => Included(*idx),
- Excluded(idx) => Excluded(*idx),
- Unbounded => Unbounded,
- };
-
- let end = match range.end_bound() {
- Included(idx) => Included(*idx),
- Excluded(idx) => Excluded(*idx),
- Unbounded => Unbounded,
- };
-
- TensorIndexer::Narrow(start, end)
- }
- }
- };
+impl<T: RB> From<T> for TensorIndexer {
+ fn from(range: T) -> Self {
+ use std::ops::Bound::*;
+ let start = match range.start_bound() {
+ Included(idx) => Included(*idx),
+ Excluded(idx) => Excluded(*idx),
+ Unbounded => Unbounded,
+ };
+ let end = match range.end_bound() {
+ Included(idx) => Included(*idx),
+ Excluded(idx) => Excluded(*idx),
+ Unbounded => Unbounded,
+ };
+ TensorIndexer::Narrow(start, end)
+ }
}
-impl_from_range!(Range<usize>);
-impl_from_range!(RangeFrom<usize>);
-impl_from_range!(RangeFull);
-impl_from_range!(RangeInclusive<usize>);
-impl_from_range!(RangeTo<usize>);
-impl_from_range!(RangeToInclusive<usize>);
-
/// Trait used to implement multiple signatures for ease of use of the slicing
/// of a tensor
pub trait IndexOp<T> {
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 36f5f6b1..6c4fea91 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -123,12 +123,6 @@ pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
-impl Module for quantized::QMatMul {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- self.forward(xs)
- }
-}
-
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 58f261b4..043733ae 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor {
}
}
-impl QMatMul {
- pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl crate::Module for QMatMul {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ce5858fa..87323a84 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2503,6 +2503,64 @@ impl Tensor {
t.transpose(dim, last)
}
}
+
+ /// Returns a copy of `self` where the values within `ranges` have been replaced with the
+ /// content of `src`.
+ pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
+ &self,
+ ranges: &[D],
+ src: &Tensor,
+ ) -> Result<Self> {
+ let src_dims = src.dims();
+ let self_dims = self.dims();
+ if self_dims.len() != src_dims.len() {
+ crate::bail!(
+ "slice-assign requires input with the same rank {} <> {}",
+ self_dims.len(),
+ src_dims.len()
+ )
+ }
+ if self_dims.len() != ranges.len() {
+ crate::bail!(
+ "slice-assign requires input with the same rank as there are ranges {} <> {}",
+ self_dims.len(),
+ ranges.len()
+ )
+ }
+ let mut src = src.clone();
+ let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
+ for (i, range) in ranges.iter().enumerate() {
+ let start_included = match range.start_bound() {
+ std::ops::Bound::Unbounded => 0,
+ std::ops::Bound::Included(v) => *v,
+ std::ops::Bound::Excluded(v) => *v + 1,
+ };
+ let end_excluded = match range.end_bound() {
+ std::ops::Bound::Unbounded => self_dims[i],
+ std::ops::Bound::Included(v) => *v + 1,
+ std::ops::Bound::Excluded(v) => *v,
+ };
+ if end_excluded <= start_included {
+ crate::bail!(
+ "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
+ )
+ }
+ if self_dims[i] < end_excluded {
+ crate::bail!(
+ "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
+ self_dims[i]
+ )
+ }
+ if end_excluded - start_included != src_dims[i] {
+ crate::bail!(
+ "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
+ )
+ }
+ src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
+ mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
+ }
+ mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
+ }
}
macro_rules! bin_trait {
diff --git a/candle-core/tests/indexing_tests.rs b/candle-core/tests/indexing_tests.rs
index 9c88f319..047205a3 100644
--- a/candle-core/tests/indexing_tests.rs
+++ b/candle-core/tests/indexing_tests.rs
@@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
Ok(())
}
+
+#[test]
+fn slice_assign() -> Result<()> {
+ let dev = Device::Cpu;
+
+ let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
+ let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
+ let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
+ assert_eq!(
+ out.to_vec2::<u32>()?,
+ &[
+ [0, 1, 2, 3, 4],
+ [5, 6, 7, 0, 1],
+ [10, 11, 12, 2, 3],
+ [15, 16, 17, 4, 5]
+ ]
+ );
+ let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
+ assert_eq!(
+ out.to_vec2::<u32>()?,
+ &[
+ [0, 1, 2, 3, 4],
+ [2, 3, 7, 8, 9],
+ [4, 5, 12, 13, 14],
+ [15, 16, 17, 18, 19]
+ ]
+ );
+ Ok(())
+}
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index a2cecbc3..716cca8d 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,7 +1,7 @@
use candle_core::{
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
- Device, Result, Tensor,
+ Device, Module, Result, Tensor,
};
use quantized::{k_quants, GgmlType};
use rand::prelude::*;
diff --git a/candle-examples/examples/distilbert/README.md b/candle-examples/examples/distilbert/README.md
new file mode 100644
index 00000000..88f97f2b
--- /dev/null
+++ b/candle-examples/examples/distilbert/README.md
@@ -0,0 +1,22 @@
+# candle-distilbert
+
+DistilBert is a distilled version of the Bert model.
+
+## Sentence embeddings
+
+DistilBert is used to compute the sentence embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+cargo run --example distilbert --release -- --prompt "Here is a test sentence"
+
+> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
+> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
+> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
+> ...
+> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
+> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
+> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
+> Tensor[[1, 7, 768], f32]
+
+```
diff --git a/candle-examples/examples/distilbert/main.rs b/candle-examples/examples/distilbert/main.rs
new file mode 100644
index 00000000..1d42011c
--- /dev/null
+++ b/candle-examples/examples/distilbert/main.rs
@@ -0,0 +1,135 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
+
+ /// When set, compute embeddings for this prompt.
+ #[arg(long)]
+ prompt: String,
+
+ /// Use the pytorch weights rather than the safetensors ones
+ #[arg(long)]
+ use_pth: bool,
+
+ /// The number of times to run the prompt.
+ #[arg(long, default_value = "1")]
+ n: usize,
+
+ /// L2 normalization for embeddings.
+ #[arg(long, default_value = "true")]
+ normalize_embeddings: bool,
+}
+
+impl Args {
+ fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
+ let device = candle_examples::device(self.cpu)?;
+ let default_model = "distilbert-base-uncased".to_string();
+ let default_revision = "main".to_string();
+ let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+ (Some(model_id), Some(revision)) => (model_id, revision),
+ (Some(model_id), None) => (model_id, "main".to_string()),
+ (None, Some(revision)) => (default_model, revision),
+ (None, None) => (default_model, default_revision),
+ };
+
+ let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+ let (config_filename, tokenizer_filename, weights_filename) = {
+ let api = Api::new()?;
+ let api = api.repo(repo);
+ let config = api.get("config.json")?;
+ let tokenizer = api.get("tokenizer.json")?;
+ let weights = if self.use_pth {
+ api.get("pytorch_model.bin")?
+ } else {
+ api.get("model.safetensors")?
+ };
+ (config, tokenizer, weights)
+ };
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: Config = serde_json::from_str(&config)?;
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let vb = if self.use_pth {
+ VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+ } else {
+ unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+ };
+ let model = DistilBertModel::load(vb, &config)?;
+ Ok((model, tokenizer))
+ }
+}
+
+fn get_mask(size: usize, device: &Device) -> Tensor {
+ let mask: Vec<_> = (0..size)
+ .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+ .collect();
+ Tensor::from_slice(&mask, (size, size), device).unwrap()
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+ let device = &model.device;
+
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
+ let tokens = tokenizer
+ .encode(args.prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+ let mask = get_mask(tokens.len(), device);
+
+ println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
+ println!("mask: {:?}", mask.to_vec2::<u8>());
+
+ let ys = model.forward(&token_ids, &mask)?;
+ println!("{ys}");
+
+ Ok(())
+}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+ Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 5b1937ac..a6574697 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -43,6 +43,7 @@ pub fn report(
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
+ let pred = pred.to_device(&Device::Cpu)?;
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index af8cf98a..c65a5ca1 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
-use candle::{DType, IndexOp, Result, Tensor};
+use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@@ -61,6 +61,7 @@ pub fn report_detect(
nms_threshold: f32,
legend_size: u32,
) -> Result<DynamicImage> {
+ let pred = pred.to_device(&Device::Cpu)?;
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
@@ -153,6 +154,7 @@ pub fn report_pose(
confidence_threshold: f32,
nms_threshold: f32,
) -> Result<DynamicImage> {
+ let pred = pred.to_device(&Device::Cpu)?;
let (pred_size, npreds) = pred.dims2()?;
if pred_size != 17 * 3 + 4 + 1 {
candle::bail!("unexpected pred-size {pred_size}");
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index cff8e763..5a6bd41b 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -672,626 +672,4 @@ pub fn call_index_select(
}
#[cfg(test)]
-mod tests {
- use super::*;
- use half::f16;
- use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
-
- fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
- let options = MTLResourceOptions::StorageModeManaged;
- let ptr = data.as_ptr() as *const core::ffi::c_void;
- let size = (data.len() * std::mem::size_of::<T>()) as u64;
- device.new_buffer_with_data(ptr, size, options)
- }
-
- fn device() -> Device {
- Device::system_default().unwrap()
- }
-
- fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
- let b = 10f32.powi(digits);
- v.iter().map(|t| f32::round(t * b) / b).collect()
- }
-
- fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
- let b = 10f32.powi(digits);
- v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
- }
-
- fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
- call_unary_contiguous(
- &device,
- command_buffer,
- &kernels,
- name,
- v.len(),
- &input,
- &mut output,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
- }
-
- fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let options = MTLResourceOptions::StorageModeManaged;
- let left = new_buffer(&device, x);
- let right = new_buffer(&device, y);
- let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
- call_binary_contiguous(
- &device,
- command_buffer,
- &kernels,
- name,
- x.len(),
- &left,
- &right,
- &mut output,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- output.read_to_vec::<T>(x.len())
- }
-
- fn run_strided<T: Clone>(
- v: &[T],
- kernel: unary::strided::Kernel,
- shape: &[usize],
- strides: &[usize],
- offset: usize,
- ) -> Vec<T> {
- let device = device();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
- let kernels = Kernels::new();
- call_unary_strided(
- &device,
- command_buffer,
- &kernels,
- kernel,
- shape,
- &input,
- strides,
- offset,
- &mut output,
- 0,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
- }
-
- #[test]
- fn cos_f32() {
- let v = vec![1.0f32, 2.0, 3.0];
- let results = run(&v, unary::contiguous::cos::FLOAT);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
- assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
-
- let v = vec![1.0f32; 10_000];
- let results = run(&v, unary::contiguous::cos::FLOAT);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
- assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
- }
-
- #[test]
- fn cos_f32_strided() {
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let shape = vec![6];
- let strides = vec![1];
- let offset = 0;
- let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(
- approx(results, 4),
- vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
- );
- assert_eq!(
- approx(expected, 4),
- vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
- );
-
- // Contiguous
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let shape = vec![3, 2];
- let strides = vec![2, 1];
- let offset = 0;
- let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(
- approx(results, 4),
- vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
- );
- assert_eq!(
- approx(expected, 4),
- vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
- );
-
- // Transposed
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let shape = vec![3, 2];
- let strides = vec![1, 3];
- let offset = 0;
- let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(
- approx(results, 4),
- vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
- );
- assert_eq!(
- approx(expected, 4),
- vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
- );
-
- // Very large
- let v = vec![1.0f32; 10_000];
- let shape = vec![2, 5_000];
- let strides = vec![2, 1];
- let offset = 0;
- let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
- assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
- }
-
- #[test]
- fn cos_strided_random() {
- let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
- let shape = vec![5_000, 2];
- let strides = vec![1, 5_000];
- let offset = 0;
- let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
- assert_eq!(
- approx(vec![results[1]], 4),
- approx(vec![expected[5_000]], 4)
- );
- assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
- assert_eq!(
- approx(vec![results[3]], 4),
- approx(vec![expected[5_001]], 4)
- );
- assert_eq!(
- approx(vec![results[5_000]], 4),
- approx(vec![expected[2_500]], 4)
- );
- }
-
- #[test]
- fn binary_add_f32() {
- let left = vec![1.0f32, 2.0, 3.0];
- let right = vec![2.0f32, 3.1, 4.2];
- let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
- let expected: Vec<_> = left
- .iter()
- .zip(right.iter())
- .map(|(&x, &y)| x + y)
- .collect();
- assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
- assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
- }
-
- fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
-
- call_cast_contiguous(
- &device,
- command_buffer,
- &kernels,
- name,
- v.len(),
- &input,
- &mut output,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
- output.read_to_vec::<U>(v.len())
- }
-
- #[test]
- fn cast_u32_f32() {
- let v = vec![1u32, 2, 3];
- let results = cast(&v, "cast_u32_f32");
- let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
- assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
- assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
-
- let v = vec![1.0f32; 10_000];
- let results = run(&v, unary::contiguous::cos::FLOAT);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
- assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
- }
-
- fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
-
- let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
-
- let size = v.len();
-
- call_affine(
- &device,
- command_buffer,
- &kernels,
- size,
- &input,
- &mut output,
- mul as f32,
- add as f32,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- output.read_to_vec::<T>(v.len())
- }
-
- #[test]
- fn affine() {
- let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
- let mul = 1.5;
- let add = 1.1;
- let result = run_affine(&input, mul, add);
- assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
-
- let input = [1.0f32; 40_000];
- let mul = 1.5;
- let add = 1.1;
- let result = run_affine(&input, mul, add);
- assert_eq!(result, vec![2.6; 40_000]);
- }
-
- #[test]
- fn index_select() {
- let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
- let shape = [5, 2];
- let ids = [0u32, 4, 2];
- let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
- assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
-
- let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
- let shape = [2, 5];
- let ids = [0u32, 1, 0];
- let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
- assert_eq!(
- result,
- vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
- );
-
- let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
- let shape = [5, 2];
- let ids = [0u32, 1, 0];
- let dim = 1;
- let result = run_index_select(&embedding, &shape, &ids, dim);
- assert_eq!(
- result,
- vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
- );
- }
-
- fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
- embeddings: &[T],
- shape: &[usize],
- ids: &[I],
- dim: usize,
- ) -> Vec<T> {
- let device = Device::system_default().expect("no device found");
-
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let embeddings_buffer = new_buffer(&device, &embeddings);
- let ids_buffer = new_buffer(&device, &ids);
-
- let left_size: usize = shape[..dim].iter().product();
- let right_size: usize = shape[dim + 1..].iter().product();
- let dst_el = ids.len() * left_size * right_size;
- let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
-
- let kernels = Kernels::new();
- call_index_select(
- &device,
- &command_buffer,
- &kernels,
- "is_u32_f32",
- shape,
- ids.len(),
- dim,
- &embeddings_buffer,
- &ids_buffer,
- &mut dst_buffer,
- )
- .unwrap();
-
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- dst_buffer.read_to_vec::<T>(dst_el)
- }
-
- #[test]
- fn index_add() {
- let device = Device::system_default().expect("no device found");
-
- let options = CompileOptions::new();
- let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
- let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
- let right = [1.0f32; 15];
- let index = [0u32, 4, 2];
- let ids_dim_size = index.len() as u32;
- let dst_dim_size: u32 = 15;
- let left_size: u32 = 3;
- let right_size: u32 = 3;
-
- let function = library.get_function("ia_u32_f32", None).unwrap();
- let pipeline = device
- .new_compute_pipeline_state_with_function(&function)
- .unwrap();
-
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let encoder = command_buffer.new_compute_command_encoder();
-
- encoder.set_compute_pipeline_state(&pipeline);
-
- let index_buffer = new_buffer(&device, &index);
- let inputs_buffer = new_buffer(&device, &left);
- let outputs_buffer = new_buffer(&device, &right);
-
- set_params!(
- encoder,
- (
- &index_buffer,
- &inputs_buffer,
- &outputs_buffer,
- ids_dim_size,
- left_size,
- dst_dim_size,
- right_size
- )
- );
-
- let grid_size = MTLSize {
- width: right.len() as NSUInteger,
- height: 1,
- depth: 1,
- };
-
- let thread_group_size = MTLSize {
- width: pipeline.max_total_threads_per_threadgroup(),
- height: 1,
- depth: 1,
- };
-
- encoder.dispatch_thread_groups(grid_size, thread_group_size);
- encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- let expected = vec![
- 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
- ];
- let result = outputs_buffer.read_to_vec::<f32>(right.len());
- assert_eq!(result, expected);
- }
-
- #[test]
- fn cos_f16() {
- let v: Vec<f16> = [1.0f32, 2.0, 3.0]
- .iter()
- .map(|v| f16::from_f32(*v))
- .collect();
- let results = run(&v, unary::contiguous::cos::HALF);
- let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
- assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
- assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
- }
-
- fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let input = new_buffer(&device, v);
-
- let options = MTLResourceOptions::StorageModeManaged;
- let mut output =
- device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
- call_reduce_contiguous(
- &device,
- command_buffer,
- &kernels,
- name,
- v.len(),
- out_length,
- &input,
- &mut output,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- output.read_to_vec::<T>(out_length)
- }
-
- fn run_softmax<T: Clone + std::fmt::Debug>(
- v: &[T],
- last_dim: usize,
- name: &'static str,
- ) -> Vec<T> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
- call_last_softmax(
- &device,
- command_buffer,
- &kernels,
- name,
- v.len(),
- last_dim,
- &input,
- &mut output,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- output.read_to_vec::<T>(v.len())
- }
-
- #[test]
- fn reduce_sum() {
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let out_length = 1;
-
- let results = run_reduce(&v, out_length, "fast_sum_float");
- assert_eq!(approx(results, 4), vec![21.0]);
- }
-
- #[test]
- fn reduce_sum2() {
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let out_length = 2;
-
- let results = run_reduce(&v, out_length, "fast_sum_float");
- assert_eq!(approx(results, 4), vec![6.0, 15.0]);
- }
-
- #[test]
- fn softmax() {
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
- assert_eq!(
- approx(results, 4),
- vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
- );
-
- let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
- let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
- assert_eq!(
- approx(results, 4),
- vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
- );
-
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let last_dim = 3;
- let results = run_softmax(&v, last_dim, "softmax_float");
- assert_eq!(
- approx(results, 4),
- vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
- );
- }
-
- fn run_where_cond<I: Clone, T: Clone>(
- shape: &[usize],
- cond: &[I],
- (cond_stride, cond_offset): (Vec<usize>, usize),
- left_true: &[T],
- (left_stride, left_offset): (Vec<usize>, usize),
- right_false: &[T],
- (_right_stride, _right_offset): (Vec<usize>, usize),
- name: &'static str,
- ) -> Vec<T> {
- let device = device();
- let kernels = Kernels::new();
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let options = MTLResourceOptions::StorageModeManaged;
-
- let length = cond.len();
- let cond = device.new_buffer_with_data(
- cond.as_ptr() as *const core::ffi::c_void,
- std::mem::size_of_val(cond) as u64,
- options,
- );
- let left = device.new_buffer_with_data(
- left_true.as_ptr() as *const core::ffi::c_void,
- (length * core::mem::size_of::<T>()) as u64,
- options,
- );
- let right = device.new_buffer_with_data(
- right_false.as_ptr() as *const core::ffi::c_void,
- (length * core::mem::size_of::<T>()) as u64,
- options,
- );
-
- let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
- call_where_cond_strided(
- &device,
- command_buffer,
- &kernels,
- name,
- shape,
- &cond,
- (&cond_stride, cond_offset),
- &left,
- (&left_stride, left_offset),
- &right,
- (&cond_stride, cond_offset),
- &mut output,
- )
- .unwrap();
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- output.read_to_vec::<T>(length)
- }
-
- #[test]
- fn where_cond() {
- let shape = vec![6];
- let cond = vec![0u8, 1, 0, 0, 1, 1];
- let cond_l = (vec![1], 0);
- let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- let left_l = (vec![1], 0);
- let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
- let right_l = (vec![1], 0);
- let results = run_where_cond(
- &shape,
- &cond,
- cond_l,
- &left_true,
- left_l,
- &right_false,
- right_l,
- "where_u8_f32",
- );
- assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
- }
-}
+mod tests;
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
new file mode 100644
index 00000000..2330d48d
--- /dev/null
+++ b/candle-metal-kernels/src/tests.rs
@@ -0,0 +1,616 @@
+use super::*;
+use half::f16;
+use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+
+fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
+ let options = MTLResourceOptions::StorageModeManaged;
+ let ptr = data.as_ptr() as *const core::ffi::c_void;
+ let size = (data.len() * std::mem::size_of::<T>()) as u64;
+ device.new_buffer_with_data(ptr, size, options)
+}
+
+fn device() -> Device {
+ Device::system_default().unwrap()
+}
+
/// Round every value to `digits` decimal places, for tolerant float
/// comparisons in the tests below.
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
    let scale = 10f32.powi(digits);
    v.into_iter().map(|x| (x * scale).round() / scale).collect()
}
+
+fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
+ let b = 10f32.powi(digits);
+ v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+}
+
+fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let input = new_buffer(&device, v);
+ let mut output = new_buffer(&device, v);
+ call_unary_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<T>(v.len())
+}
+
+fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let left = new_buffer(&device, x);
+ let right = new_buffer(&device, y);
+ let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+ call_binary_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ x.len(),
+ &left,
+ &right,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<T>(x.len())
+}
+
/// Dispatch a strided unary kernel over `v`.
///
/// `shape`/`strides`/`offset` describe how the kernel walks the input; the
/// output buffer is the same size as the input and is written starting at
/// offset 0 (the trailing `0` argument below — presumably contiguously;
/// confirm against `call_unary_strided`).
fn run_strided<T: Clone>(
    v: &[T],
    kernel: unary::strided::Kernel,
    shape: &[usize],
    strides: &[usize],
    offset: usize,
) -> Vec<T> {
    let device = device();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let mut output = new_buffer(&device, v);
    let kernels = Kernels::new();
    call_unary_strided(
        &device,
        command_buffer,
        &kernels,
        kernel,
        shape,
        &input,
        strides,
        offset,
        &mut output,
        0, // output offset
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    output.read_to_vec::<T>(v.len())
}
+
/// Contiguous f32 cosine: the kernel output and a CPU reference must both
/// match known cosine values, on a small and a large (multi-threadgroup)
/// input.
#[test]
fn cos_f32() {
    let cases: [(Vec<f32>, Vec<f32>); 2] = [
        (vec![1.0, 2.0, 3.0], vec![0.5403, -0.4161, -0.99]),
        (vec![1.0; 10_000], vec![0.5403; 10_000]),
    ];
    for (v, want) in cases {
        let results = run(&v, unary::contiguous::cos::FLOAT);
        let expected: Vec<_> = v.iter().map(|x| x.cos()).collect();
        assert_eq!(approx(results, 4), want);
        assert_eq!(approx(expected, 4), want);
    }
}
+
/// Strided f32 cosine in several layouts: flat 1-D, contiguous 2-D,
/// transposed 2-D and a large input.
#[test]
fn cos_f32_strided() {
    // Flat 1-D with unit stride: identical to the contiguous case.
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![6];
    let strides = vec![1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Contiguous
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Transposed: strides [1, 3] permute the traversal order, so the kernel
    // output is a permutation of the in-memory-order CPU reference.
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = vec![3, 2];
    let strides = vec![1, 3];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(
        approx(results, 4),
        vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
    );
    assert_eq!(
        approx(expected, 4),
        vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
    );

    // Very large
    let v = vec![1.0f32; 10_000];
    let shape = vec![2, 5_000];
    let strides = vec![2, 1];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
+
/// Strided cosine over random data: spot-check that the transposed layout
/// ([5000, 2] with strides [1, 5000]) maps output positions to the right
/// source elements.
#[test]
fn cos_strided_random() {
    let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
    let shape = vec![5_000, 2];
    let strides = vec![1, 5_000];
    let offset = 0;
    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
    // Output index (i, j) reads input element i * 1 + j * 5_000, so e.g.
    // results[1] == (0, 1) -> expected[5_000] and results[5_000] ==
    // (2_500, 0) -> expected[2_500].
    assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
    assert_eq!(
        approx(vec![results[1]], 4),
        approx(vec![expected[5_000]], 4)
    );
    assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
    assert_eq!(
        approx(vec![results[3]], 4),
        approx(vec![expected[5_001]], 4)
    );
    assert_eq!(
        approx(vec![results[5_000]], 4),
        approx(vec![expected[2_500]], 4)
    );
}
+
/// Contiguous f32 addition: the kernel must match the element-wise CPU sum.
#[test]
fn binary_add_f32() {
    let lhs = vec![1.0f32, 2.0, 3.0];
    let rhs = vec![2.0f32, 3.1, 4.2];
    let results = run_binary(&lhs, &rhs, binary::contiguous::add::FLOAT);
    let expected: Vec<f32> = lhs.iter().zip(&rhs).map(|(x, y)| x + y).collect();
    assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
    assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
+
+fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let input = new_buffer(&device, v);
+ let mut output = new_buffer(&device, v);
+
+ call_cast_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<U>(v.len())
+}
+
/// u32 -> f32 cast, on a small and a large input.
#[test]
fn cast_u32_f32() {
    let v = vec![1u32, 2, 3];
    let results = cast(&v, "cast_u32_f32");
    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);

    // Large input exercising many threadgroups. The previous version of this
    // half was a copy-paste of the cos test and never touched the cast kernel.
    let v = vec![3u32; 10_000];
    let results: Vec<f32> = cast(&v, "cast_u32_f32");
    assert_eq!(approx(results, 4), vec![3.0f32; 10_000]);
}
+
+fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+
+ let input = new_buffer(&device, v);
+ let mut output = new_buffer(&device, v);
+
+ let size = v.len();
+
+ call_affine(
+ &device,
+ command_buffer,
+ &kernels,
+ size,
+ &input,
+ &mut output,
+ mul as f32,
+ add as f32,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(v.len())
+}
+
/// Affine transform `x * 1.5 + 1.1` on a small and a large input.
#[test]
fn affine() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = run_affine(&input, 1.5, 1.1);
    assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);

    let input = [1.0f32; 40_000];
    let result = run_affine(&input, 1.5, 1.1);
    assert_eq!(result, vec![2.6; 40_000]);
}
+
/// `index_select` along dim 0 and dim 1 of small matrices.
#[test]
fn index_select() {
    // [5, 2] matrix, pick rows 0, 4, 2.
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let ids = [0u32, 4, 2];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);

    // [2, 5] matrix, pick rows 0, 1, 0.
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [2, 5];
    let ids = [0u32, 1, 0];
    let dim = 0;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
    );

    // [5, 2] matrix, pick columns 0, 1, 0.
    // NOTE(review): column selection would normally interleave per row
    // ([1, 2, 1, 3, 4, 3, ...]); this expected value instead matches the
    // dim-0 [2, 5] case above — confirm against the kernel's `dim` handling.
    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let shape = [5, 2];
    let ids = [0u32, 1, 0];
    let dim = 1;
    let result = run_index_select(&embedding, &shape, &ids, dim);
    assert_eq!(
        result,
        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
    );
}
+
+fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
+ embeddings: &[T],
+ shape: &[usize],
+ ids: &[I],
+ dim: usize,
+) -> Vec<T> {
+ let device = Device::system_default().expect("no device found");
+
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let embeddings_buffer = new_buffer(&device, &embeddings);
+ let ids_buffer = new_buffer(&device, &ids);
+
+ let left_size: usize = shape[..dim].iter().product();
+ let right_size: usize = shape[dim + 1..].iter().product();
+ let dst_el = ids.len() * left_size * right_size;
+ let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+
+ let kernels = Kernels::new();
+ call_index_select(
+ &device,
+ &command_buffer,
+ &kernels,
+ "is_u32_f32",
+ shape,
+ ids.len(),
+ dim,
+ &embeddings_buffer,
+ &ids_buffer,
+ &mut dst_buffer,
+ )
+ .unwrap();
+
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ dst_buffer.read_to_vec::<T>(dst_el)
+}
+
/// Scatter-add through the raw pipeline-state path (no `call_*` helper):
/// each 3-element row of `left` is added into the row of the all-ones
/// destination selected by `index` (src row i -> dst row index[i]).
#[test]
fn index_add() {
    let device = Device::system_default().expect("no device found");

    // Compile the indexing kernels from source and fetch `ia_u32_f32`.
    let options = CompileOptions::new();
    let library = device.new_library_with_source(INDEXING, &options).unwrap();

    let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let right = [1.0f32; 15];
    let index = [0u32, 4, 2];
    let ids_dim_size = index.len() as u32;
    // NOTE(review): presumably the destination size in elements (5 rows x 3)
    // — confirm against the kernel's parameter layout.
    let dst_dim_size: u32 = 15;
    let left_size: u32 = 3;
    let right_size: u32 = 3;

    let function = library.get_function("ia_u32_f32", None).unwrap();
    let pipeline = device
        .new_compute_pipeline_state_with_function(&function)
        .unwrap();

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();

    encoder.set_compute_pipeline_state(&pipeline);

    let index_buffer = new_buffer(&device, &index);
    let inputs_buffer = new_buffer(&device, &left);
    let outputs_buffer = new_buffer(&device, &right);

    // Bind buffers and scalar parameters in the order the kernel expects.
    set_params!(
        encoder,
        (
            &index_buffer,
            &inputs_buffer,
            &outputs_buffer,
            ids_dim_size,
            left_size,
            dst_dim_size,
            right_size
        )
    );

    let grid_size = MTLSize {
        width: right.len() as NSUInteger,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: pipeline.max_total_threads_per_threadgroup(),
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(grid_size, thread_group_size);
    encoder.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Rows 0, 4 and 2 receive [1,2,3], [4,5,6] and [7,8,9] respectively on
    // top of the initial 1.0; all other rows stay at 1.0.
    let expected = vec![
        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
    ];
    let result = outputs_buffer.read_to_vec::<f32>(right.len());
    assert_eq!(result, expected);
}
+
/// Contiguous f16 cosine against an f16-rounded CPU reference.
#[test]
fn cos_f16() {
    let v: Vec<f16> = [1.0f32, 2.0, 3.0]
        .iter()
        .copied()
        .map(f16::from_f32)
        .collect();
    let results = run(&v, unary::contiguous::cos::HALF);
    let expected: Vec<f16> = v.iter().map(|x| f16::from_f32(x.to_f32().cos())).collect();
    assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
    assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}
+
+fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let input = new_buffer(&device, v);
+
+ let options = MTLResourceOptions::StorageModeManaged;
+ let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+ call_reduce_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ out_length,
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(out_length)
+}
+
+fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let input = new_buffer(&device, v);
+ let mut output = new_buffer(&device, v);
+ call_last_softmax(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ last_dim,
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(v.len())
+}
+
/// Full reduction: all six elements collapse to a single sum.
#[test]
fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let results = run_reduce(&v, 1, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![21.0]);
}
+
/// Partial reduction: six elements into two sums ([1,2,3] and [4,5,6]).
#[test]
fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let results = run_reduce(&v, 2, "fast_sum_float");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
+
/// Last-dim softmax over one and two rows.
#[test]
fn softmax() {
    // Single row of six elements.
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    // Shifting all inputs by a constant must not change the softmax output.
    let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
    let last_dim = 6;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
    );

    // last_dim = 3: softmax applies independently to each of the two rows.
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let last_dim = 3;
    let results = run_softmax(&v, last_dim, "softmax_float");
    assert_eq!(
        approx(results, 4),
        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
    );
}
+
+fn run_where_cond<I: Clone, T: Clone>(
+ shape: &[usize],
+ cond: &[I],
+ (cond_stride, cond_offset): (Vec<usize>, usize),
+ left_true: &[T],
+ (left_stride, left_offset): (Vec<usize>, usize),
+ right_false: &[T],
+ (_right_stride, _right_offset): (Vec<usize>, usize),
+ name: &'static str,
+) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+
+ let length = cond.len();
+ let cond = device.new_buffer_with_data(
+ cond.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(cond) as u64,
+ options,
+ );
+ let left = device.new_buffer_with_data(
+ left_true.as_ptr() as *const core::ffi::c_void,
+ (length * core::mem::size_of::<T>()) as u64,
+ options,
+ );
+ let right = device.new_buffer_with_data(
+ right_false.as_ptr() as *const core::ffi::c_void,
+ (length * core::mem::size_of::<T>()) as u64,
+ options,
+ );
+
+ let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ call_where_cond_strided(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ shape,
+ &cond,
+ (&cond_stride, cond_offset),
+ &left,
+ (&left_stride, left_offset),
+ &right,
+ (&cond_stride, cond_offset),
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(length)
+}
+
/// 1-D `where` select: cond == 1 picks from `left_true`, cond == 0 picks
/// from `right_false`, so the result is [-1, 2, -3, -4, 5, 6].
#[test]
fn where_cond() {
    let shape = vec![6];
    let cond = vec![0u8, 1, 0, 0, 1, 1];
    let cond_l = (vec![1], 0);
    let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let left_l = (vec![1], 0);
    let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
    let right_l = (vec![1], 0);
    let results = run_where_cond(
        &shape,
        &cond,
        cond_l,
        &left_true,
        left_l,
        &right_false,
        right_l,
        "where_u8_f32",
    );
    assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 9ded5f71..68d384a6 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -6,7 +6,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle::quantized::GgmlType;
-use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D};
+use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
use clap::{Parser, Subcommand};
const CHECK_CONV2D: bool = false;
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index b0c623d3..ade00012 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -17,7 +17,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
+use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType};
mod utils;
use utils::wrap_err;
diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs
new file mode 100644
index 00000000..ea074c97
--- /dev/null
+++ b/candle-transformers/src/models/distilbert.rs
@@ -0,0 +1,342 @@
+use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
/// Activation for the feed-forward blocks, deserialized from the checkpoint
/// config's `activation` field ("gelu" or "relu").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    Gelu,
    Relu,
}
+
// Wraps a `HiddenAct` with a tracing span so the activation shows up in
// profiles like the other layers.
struct HiddenActLayer {
    act: HiddenAct,
    span: tracing::Span,
}
+
+impl HiddenActLayer {
+ fn new(act: HiddenAct) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+ Self { act, span }
+ }
+}
+
impl Module for HiddenActLayer {
    // Apply the configured activation element-wise.
    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
        let _enter = self.span.enter();
        match self.act {
            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
            HiddenAct::Gelu => xs.gelu(),
            HiddenAct::Relu => xs.relu(),
        }
    }
}
+
/// Position-encoding scheme; only absolute (learned-table) positions are
/// supported here, so this enum has a single default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}
+
/// DistilBERT configuration, deserializable from a checkpoint's
/// `config.json`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    // Hidden size / embedding dimension.
    dim: usize,
    // Number of transformer blocks.
    n_layers: usize,
    // Attention heads per block.
    n_heads: usize,
    // Inner dimension of the feed-forward network.
    hidden_dim: usize,
    activation: HiddenAct,
    max_position_embeddings: usize,
    // Not read by this implementation; kept so config.json round-trips.
    initializer_range: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    // Used as a variable-name prefix fallback in `DistilBertModel::load`.
    model_type: Option<String>,
}
+
impl Default for Config {
    // NOTE(review): these values look like the `distilbert-base-uncased`
    // defaults — confirm against that checkpoint's config.json.
    fn default() -> Self {
        Self {
            vocab_size: 30522,
            dim: 768,
            n_layers: 12,
            n_heads: 12,
            hidden_dim: 3072,
            activation: HiddenAct::Gelu,
            max_position_embeddings: 512,
            initializer_range: 0.02,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Absolute,
            use_cache: true,
            model_type: Some("distilbert".to_string()),
        }
    }
}
+
// Token + absolute-position embeddings followed by a LayerNorm.
struct Embeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}
+
impl Embeddings {
    /// Load both embedding tables and the LayerNorm from `vb`.
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings =
            candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
        let position_embeddings = candle_nn::embedding(
            config.max_position_embeddings,
            config.dim,
            vb.pp("position_embeddings"),
        )?;
        let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
        Ok(Self {
            word_embeddings,
            position_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

    /// Embed `input_ids` (shape (batch, seq_len)), add position embeddings
    /// for positions 0..seq_len, and LayerNorm the sum.
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        // Positions are always 0..seq_len — no support for shifted/padded
        // position ids here.
        let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
        let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
        let embeddings =
            input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;

        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}
+
// Multi-head self-attention with separate q/k/v projections and an output
// projection.
struct MultiHeadSelfAttention {
    q_lin: Linear,
    k_lin: Linear,
    v_lin: Linear,
    out_lin: Linear,
    n_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
}
+
+impl MultiHeadSelfAttention {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention_head_size = config.dim / config.n_heads;
+ let all_head_size = config.n_heads * attention_head_size;
+ let dim = config.dim;
+ let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
+ let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
+ let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
+ let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
+ Ok(Self {
+ q_lin,
+ k_lin,
+ v_lin,
+ out_lin,
+ n_heads: config.n_heads,
+ attention_head_size,
+ span: tracing::span!(tracing::Level::TRACE, "attention"),
+ })
+ }
+}
+
impl MultiHeadSelfAttention {
    /// Scaled dot-product self-attention over `hidden_states`
    /// (shape (batch, seq, dim)); `attention_mask` marks positions to
    /// exclude and is broadcast to the score shape.
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (bs, q_length, _dim) = hidden_states.dims3()?;

        let dim_per_head = self.attention_head_size;
        let q = self.q_lin.forward(hidden_states)?;
        let k = self.k_lin.forward(hidden_states)?;
        let v = self.v_lin.forward(hidden_states)?;

        // (bs, q_length, dim) -> (bs, n_heads, q_length, dim_per_head).
        let q = q
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let k = k
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;
        let v = v
            .reshape((bs, q_length, self.n_heads, dim_per_head))?
            .transpose(1, 2)?;

        // Scale queries by 1/sqrt(head dim) before the dot product.
        let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
        let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
        let mask = attention_mask.broadcast_as(scores.shape())?;

        // Masked positions get -inf so softmax assigns them zero weight.
        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
        let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;

        let context = weights.matmul(&v.contiguous()?)?;
        // Back to (bs, q_length, n_heads * dim_per_head), then project out.
        let context = context
            .transpose(1, 2)?
            .reshape((bs, q_length, self.n_heads * dim_per_head))?
            .contiguous()?;
        let context = self.out_lin.forward(&context)?;

        Ok(context)
    }
}
+
// Position-wise feed-forward network: dim -> hidden_dim -> dim with an
// activation in between.
#[allow(clippy::upper_case_acronyms)]
struct FFN {
    lin1: Linear,
    lin2: Linear,
    activation: HiddenActLayer,
    span: tracing::Span,
}
+
+impl FFN {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
+ let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
+ Ok(Self {
+ lin1,
+ lin2,
+ activation: HiddenActLayer::new(config.activation),
+ span: tracing::span!(tracing::Level::TRACE, "ffn"),
+ })
+ }
+}
+
+impl Module for FFN {
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ hidden_states
+ .apply(&self.lin1)?
+ .apply(&self.activation)?
+ .apply(&self.lin2)
+ }
+}
+
// One encoder block: self-attention and FFN, each followed by a residual
// add and a LayerNorm (post-norm, see `forward`).
struct TransformerBlock {
    attention: MultiHeadSelfAttention,
    sa_layer_norm: LayerNorm,
    ffn: FFN,
    output_layer_norm: LayerNorm,
    span: tracing::Span,
}
+
+impl TransformerBlock {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
+ let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
+ let ffn = FFN::load(vb.pp("ffn"), config)?;
+ let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
+ Ok(Self {
+ attention,
+ sa_layer_norm,
+ ffn,
+ output_layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "layer"),
+ })
+ }
+}
+
impl TransformerBlock {
    /// Attention -> residual + LayerNorm -> FFN -> residual + LayerNorm.
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let sa_output = self.attention.forward(hidden_states, attention_mask)?;
        // TODO: Support cross-attention?
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        // TODO: Support something similar to `apply_chunking_to_forward`?
        let sa_output = sa_output.broadcast_add(hidden_states)?;
        let sa_output = self.sa_layer_norm.forward(&sa_output)?;

        let ffn_output = self.ffn.forward(&sa_output)?;
        let ffn_output = (&ffn_output + sa_output)?;
        let output = self.output_layer_norm.forward(&ffn_output)?;
        Ok(output)
    }
}
+
// Stack of `n_layers` transformer blocks.
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct Transformer {
    layers: Vec<TransformerBlock>,
    span: tracing::Span,
}
+
+impl Transformer {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let layers = (0..config.n_layers)
+ .map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
+ .collect::<Result<Vec<_>>>()?;
+ let span = tracing::span!(tracing::Level::TRACE, "encoder");
+ Ok(Transformer { layers, span })
+ }
+}
+
+impl Transformer {
+ fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut hidden_states = hidden_states.clone();
+ // Use a loop rather than a fold as it's easier to modify when adding debug/...
+ for layer in self.layers.iter() {
+ hidden_states = layer.forward(&hidden_states, attention_mask)?;
+ }
+ Ok(hidden_states)
+ }
+}
+
/// DistilBERT encoder: embeddings followed by the transformer stack.
pub struct DistilBertModel {
    embeddings: Embeddings,
    transformer: Transformer,
    // Exposed so callers can place input tensors on the model's device.
    pub device: Device,
    span: tracing::Span,
}
+
impl DistilBertModel {
    /// Load the model from `vb`, first trying the un-prefixed variable names
    /// (`embeddings`/`transformer`) and, on failure, retrying under the
    /// `model_type` prefix (e.g. `distilbert.embeddings`) used by some
    /// checkpoints.
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, transformer) = match (
            Embeddings::load(vb.pp("embeddings"), config),
            Transformer::load(vb.pp("transformer"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                // Prefixed fallback; if it also fails, surface the original
                // (un-prefixed) error rather than the fallback's.
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
                        Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            transformer,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

    /// Embed `input_ids` and run the transformer stack; returns the final
    /// hidden states.
    pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids)?;
        let sequence_output = self
            .transformer
            .forward(&embedding_output, attention_mask)?;
        Ok(sequence_output)
    }
}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 558583b6..a9a56673 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -4,6 +4,7 @@ pub mod blip;
pub mod blip_text;
pub mod convmixer;
pub mod dinov2;
+pub mod distilbert;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
index 0c95cfef..cef06f1c 100644
--- a/candle-transformers/src/models/stable_diffusion/utils.rs
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -1,12 +1,15 @@
use candle::{Device, Result, Tensor};
/// Create a 1-D CPU tensor of `steps` values evenly spaced from `start` to
/// `stop` inclusive (presumably mirroring `torch.linspace` — confirm).
///
/// `steps == 0` yields an empty tensor and `steps == 1` yields `[start]`
/// instead of erroring out.
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
    if steps == 0 {
        Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
    } else if steps == 1 {
        Tensor::from_vec(vec![start], steps, &Device::Cpu)
    } else {
        // steps >= 2 here, so the (steps - 1) division is well-defined.
        let delta = (stop - start) / (steps - 1) as f64;
        let vs = (0..steps)
            .map(|step| start + step as f64 * delta)
            .collect::<Vec<_>>();
        Tensor::from_vec(vs, steps, &Device::Cpu)
    }
}
diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs
index 5d53728c..e5fa7dec 100644
--- a/candle-wasm-tests/tests/quantized_tests.rs
+++ b/candle-wasm-tests/tests/quantized_tests.rs
@@ -1,7 +1,7 @@
use candle::{
quantized::{self, k_quants, GgmlDType, GgmlType},
test_utils::to_vec2_round,
- Device, Result, Tensor,
+ Device, Module, Result, Tensor,
};
use wasm_bindgen_test::*;