diff options
-rw-r--r-- | README.md | 7 | ||||
-rw-r--r-- | candle-core/benches/bench_main.rs | 7 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/affine.rs | 43 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/mod.rs | 2 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/where_cond.rs | 64 | ||||
-rw-r--r-- | candle-core/src/metal_backend.rs | 3 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 12 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 16 | ||||
-rw-r--r-- | candle-examples/build.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/custom-ops/cuda_kernels.rs | 1 | ||||
-rw-r--r-- | candle-examples/examples/phi/main.rs | 47 | ||||
-rw-r--r-- | candle-examples/examples/repvgg/README.md | 6 | ||||
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 14 | ||||
-rw-r--r-- | candle-metal-kernels/src/ternary.metal | 66 | ||||
-rw-r--r-- | candle-nn/src/activation.rs | 1 | ||||
-rw-r--r-- | candle-onnx/src/eval.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/phi.rs | 363 |
18 files changed, 608 insertions, 57 deletions
@@ -66,7 +66,7 @@ We also provide a some command line based examples using state of the art models - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. -- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal +- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal implementation of the Mamba state space model. - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with better performance than all publicly available 13b models as of 2023-09-28. @@ -109,6 +109,9 @@ We also provide a some command line based examples using state of the art models - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained using self-supervision (can be used for imagenet classification, depth evaluation, segmentation). +- [VGG](./candle-examples/examples/vgg/), + [RepVGG](./candle-examples/examples/repvgg): computer vision models. +- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to generate captions for an image. - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation @@ -204,7 +207,7 @@ If you have an addition to this list, please submit a pull request. - Image to text. - BLIP. - Computer Vision Models. - - DINOv2, ConvMixer, EfficientNet, ResNet, ViT. + - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG. - yolo-v3, yolo-v8. - Segment-Anything Model (SAM). - File formats: load models from safetensors, npz, ggml, or PyTorch files. diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 8913df4f..162e3f2b 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,9 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches, benchmarks::random::benches); +criterion_main!( + benchmarks::affine::benches, + benchmarks::matmul::benches, + benchmarks::random::benches, + benchmarks::where_cond::benches +); diff --git a/candle-core/benches/benchmarks/affine.rs b/candle-core/benches/benchmarks/affine.rs new file mode 100644 index 00000000..eded9f57 --- /dev/null +++ b/candle-core/benches/benchmarks/affine.rs @@ -0,0 +1,43 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor) { + a.affine(12.34, 56.78).unwrap(); +} + +fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + + let flops = b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&tensor)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + run_affine_benchmark(c, &device, DType::F32, "affine_f32"); + run_affine_benchmark(c, &device, DType::F16, "affine_f16"); + run_affine_benchmark(c, &device, DType::BF16, "affine_bf16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index eb20ea70..c45effee 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,5 +1,7 @@ +pub(crate) mod affine; pub(crate) mod matmul; pub(crate) mod random; +pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs new file mode 100644 index 00000000..c517dcf5 --- /dev/null +++ b/candle-core/benches/benchmarks/where_cond.rs @@ -0,0 +1,64 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor, c: &Tensor) { + a.where_cond(b, c).unwrap(); +} + +const fn create_cond_arr<const N: usize>() -> [u8; N] { + let mut arr = [0u8; N]; + let mut i = 0; + while i < N { + arr[i] = (i % 2) as u8; + i += 1; + } + arr +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; +const SIZE: usize = B * M * K; + +const DATA: [u8; SIZE] = create_cond_arr::<SIZE>(); + +fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap(); + let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap(); + let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap(); + + let elements = B * M * K; + // E.g. 2 f32 tensors + 1 u8 tensor + let flops = (2 * elements * dtype.size_in_bytes()) + elements; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(&tensor), + black_box(&on_true), + black_box(&on_false), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32"); + run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16"); + run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 48250233..8a75bd7c 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -355,6 +355,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", + DType::BF16 => "affine_bf16", dtype => crate::bail!("Metal contiguous affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine( @@ -373,6 +374,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "affine_f32_strided", DType::F16 => "affine_f16_strided", + DType::BF16 => "affine_bf16_strided", dtype => crate::bail!("Metal strided affine {dtype:?} not implemented"), }; candle_metal_kernels::call_affine_strided( @@ -808,6 +810,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e83fb55b..33bab1b6 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { } #[test] -fn logsumexp() -> Result<()> { +fn log_sum_exp() -> Result<()> { let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; - let output = input.logsumexp(D::Minus1)?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; assert_close(&output, &expected, 0.00001)?; Ok(()) } + +#[test] +fn pow() -> Result<()> { + let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let rhs = (&lhs - 2.)?; + let res = lhs.pow(&rhs)?; + assert_eq!( + test_utils::to_vec2_round(&res, 4)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + ); + Ok(()) +} diff --git a/candle-examples/build.rs b/candle-examples/build.rs index ba40aeb4..33497714 100644 --- a/candle-examples/build.rs +++ b/candle-examples/build.rs @@ -27,11 +27,5 @@ fn main() -> Result<()> { bindings.write(kdir.rust_target).unwrap() } } - #[cfg(not(feature = "cuda"))] - { - for kdir in KERNEL_DIRS.iter() { - let _file = std::fs::File::create(kdir.rust_target)?; - } - } Ok(()) } diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs index c00b601b..e69de29b 100644 --- a/candle-examples/examples/custom-ops/cuda_kernels.rs +++ b/candle-examples/examples/custom-ops/cuda_kernels.rs @@ -1 +0,0 @@ -pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx")); diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index c529867b..ea99c706 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -8,6 +8,7 @@ use anyhow::{Error as E, Result}; use clap::{Parser, ValueEnum}; use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer}; +use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi}; use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer; use candle::{DType, Device, Tensor}; @@ -18,6 +19,7 @@ use tokenizers::Tokenizer; enum Model { MixFormer(MixFormer), + Phi(Phi), Quantized(QMixFormer), } @@ -84,6 +86,7 @@ impl TextGeneration { let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; let logits = match &mut self.model { Model::MixFormer(m) => m.forward(&input)?, + Model::Phi(m) => m.forward(&input)?, Model::Quantized(m) => m.forward(&input)?, }; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; @@ -117,7 +120,7 @@ impl TextGeneration { } } -#[derive(Clone, Copy, Debug, ValueEnum)] +#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)] enum WhichModel { #[value(name = "1")] V1, @@ -125,6 +128,9 @@ enum WhichModel { V1_5, #[value(name = "2")] V2, + // TODO: Make this the default once it has been battle tested. + #[value(name = "2-new")] + V2New, PuffinPhiV2, PhiHermes, } @@ -169,7 +175,7 @@ struct Args { #[arg(long)] model_id: Option<String>, - #[arg(long, default_value = "1.5")] + #[arg(long, default_value = "2")] model: WhichModel, #[arg(long)] @@ -230,7 +236,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), - WhichModel::V2 => "microsoft/phi-2".to_string(), + WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -247,7 +253,8 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), + WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "main".to_string() } } @@ -258,7 +265,9 @@ fn main() -> Result<()> { let tokenizer_filename = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => match args.model { - WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?, + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => { + repo.get("tokenizer.json")? + } WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { repo.get("tokenizer-puffin-phi-v2.json")? } @@ -271,14 +280,14 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], - WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?], + WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } } else { match args.model { WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], - WhichModel::V2 => candle_examples::hub_load_safetensors( + WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors( &repo, "model.safetensors.index.json", )?, @@ -292,25 +301,35 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let config = match args.model { + let config = || match args.model { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), - WhichModel::V2 => Config::v2(), + WhichModel::V2 | WhichModel::V2New => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; - let (model, device) = if args.quantized { + let (model, device) = if args.model == WhichModel::V2New { + let device = candle_examples::device(args.cpu)?; + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: PhiConfig = serde_json::from_str(&config)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let phi = Phi::new(&config, vb)?; + (Model::Phi(phi), device) + } else if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; + let config = config(); let model = match args.model { - WhichModel::V2 => QMixFormer::new_v2(&config, vb)?, + WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; + let config = config(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; let model = match args.model { - WhichModel::V2 => MixFormer::new_v2(&config, vb)?, + WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?, _ => MixFormer::new(&config, vb)?, }; (Model::MixFormer(model), device) @@ -393,6 +412,10 @@ fn mmlu<P: AsRef<std::path::Path>>( m.clear_kv_cache(); m.forward(&input)? } + Model::Phi(m) => { + m.clear_kv_cache(); + m.forward(&input)? + } Model::Quantized(m) => { m.clear_kv_cache(); m.forward(&input)? diff --git a/candle-examples/examples/repvgg/README.md b/candle-examples/examples/repvgg/README.md index 2cb807c1..d24bcd6d 100644 --- a/candle-examples/examples/repvgg/README.md +++ b/candle-examples/examples/repvgg/README.md @@ -1,7 +1,9 @@ # candle-repvgg -A candle implementation of inference using a pre-trained [repvgg](https://arxiv.org/abs/2101.03697). -This uses a classification head trained on the ImageNet dataset and returns the +[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/abs/2101.03697). + +This candle implementation uses a pre-trained RepVGG network for inference. The +classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes. ## Running an example diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 3d8e7f0d..a4484998 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -17,19 +17,19 @@ METAL_FUNC uint get_strided_index( using namespace metal; -#define AFFINE(FN_NAME, TYPENAME) \ +#define AFFINE(FN_NAME, T) \ kernel void FN_NAME( \ constant size_t &dim, \ constant float &mul, \ constant float &add, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ + device const T *input, \ + device T *output, \ uint id [[ thread_position_in_grid ]] \ ) { \ if (id >= dim) { \ return; \ } \ - output[id] = TYPENAME(float(input[id]) * mul + add); \ + output[id] = T(fma(float(input[id]), mul, add)); \ } \ kernel void FN_NAME##_strided( \ constant size_t &dim, \ @@ -38,14 +38,14 @@ kernel void FN_NAME##_strided( \ constant size_t *strides, \ constant float &mul, \ constant float &add, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ + device const T *input, \ + device T *output, \ uint id [[ thread_position_in_grid ]] \ ) { \ if (id >= dim) { \ return; \ } \ - output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \ + output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \ } #define POWF(FN_NAME, TYPENAME) \ diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 40b4bcf4..7b3b8ca9 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -17,29 +17,45 @@ METAL_FUNC uint get_strided_index( return strided_i; } +template<typename T, typename ID> +METAL_FUNC void where_cond( + constant size_t &numel, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant size_t *strides_t, + constant size_t *strides_f, + device const ID *ids, + device const T *t, + device const T *f, + device T *out, + uint i [[ thread_position_in_grid ]] +) { + if (i >= numel){ + return; + } + uint strided_i = get_strided_index(i, num_dims, dims, strides); + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); + out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; +} -#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ -kernel void FN_NAME( \ - constant size_t &numel, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant size_t *strides_t, \ - constant size_t *strides_f, \ - device const ID_TYPENAME *ids, \ - device const TYPENAME *t, \ - device const TYPENAME *f, \ - device TYPENAME *out ,\ - uint i [[ thread_position_in_grid ]] \ -) { \ - if (i >= numel){ \ - return; \ - } \ - uint strided_i = get_strided_index(i, num_dims, dims, strides); \ - uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ - uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ - out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ -} \ +#define WHERE_OP(T, ID, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID *ids, \ + device const T *t, \ + device const T *f, \ + device T *out, \ + uint i [[ thread_position_in_grid ]] \ +) { \ + where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \ +} \ // WHERE_OP(float, int64_t, where_i64_f32) // WHERE_OP(double, int64_t, where_i64_f64) @@ -54,10 +70,14 @@ kernel void FN_NAME( \ // WHERE_OP(int64_t, uint32_t, where_u32_i64) WHERE_OP(float, uint8_t, where_u8_f32) -// WHERE_OP(double, uint8_t, where_u8_f64) +WHERE_OP(half, uint8_t, where_u8_f16) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) #if __METAL_VERSION__ >= 220 WHERE_OP(int64_t, uint8_t, where_u8_i64) #endif + +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, uint8_t, where_u8_bf16) +#endif
\ No newline at end of file diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index 80b750ed..e00463f0 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -6,6 +6,7 @@ use serde::Deserialize; pub enum Activation { #[default] Gelu, + #[serde(alias = "gelu_new")] NewGelu, Relu, Relu2, diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 684776c2..c0ad8668 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -254,6 +254,12 @@ pub fn simple_eval( let output = input0.broadcast_div(input1)?; values.insert(node.output[0].clone(), output); } + "Pow" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; + let output = input0.broadcast_pow(input1)?; + values.insert(node.output[0].clone(), output); + } "Equal" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index a60b5a06..9af6df69 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -17,6 +17,7 @@ pub mod mixformer; pub mod mixtral; pub mod mpt; pub mod persimmon; +pub mod phi; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs new file mode 100644 index 00000000..8bf357e7 --- /dev/null +++ b/candle-transformers/src/models/phi.rs @@ -0,0 +1,363 @@ +use crate::models::with_tracing::{layer_norm, linear, Embedding, LayerNorm, Linear}; +/// Phi model. +/// https://huggingface.co/microsoft/phi-2 +/// There is an alternative implementation of the phi model in mixformers.rs. +/// This corresponds to the model update made with the following commit: +/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869 +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; + +// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: Option<usize>, + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: usize, + pub(crate) layer_norm_eps: f64, + pub(crate) tie_word_embeddings: bool, + pub(crate) rope_theta: f32, + pub(crate) partial_rotary_factor: f64, + pub(crate) qk_layernorm: bool, +} + +impl Config { + fn num_key_value_heads(&self) -> usize { + self.num_key_value_heads.unwrap_or(self.num_attention_heads) + } + + fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + dim: usize, + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(cfg: &Config, dev: &Device) -> Result<Self> { + let dim = (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)? + .to_dtype(DType::F32)? + .reshape((cfg.max_position_embeddings, 1))?; + let freqs = t.matmul(&inv_freq)?; + let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + dim, + sin: emb.sin()?, + cos: emb.cos()?, + }) + } + + fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?; + let xs_rot = xs.i((.., .., .., ..self.dim))?; + let xs_pass = xs.i((.., .., .., self.dim..))?; + let xs12 = xs_rot.chunk(2, D::Minus1)?; + let (xs1, xs2) = (&xs12[0], &xs12[1]); + let c = self.cos.narrow(0, seqlen_offset, seq_len)?; + let s = self.sin.narrow(0, seqlen_offset, seq_len)?; + let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?; + let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?; + Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + fc1: Linear, + fc2: Linear, + act: Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?; + let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; + Ok(Self { + fc1, + fc2, + // This does not match the mixformers implementation where Gelu is used rather than + // GeluNew. + act: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2) + } +} + +#[derive(Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + dense: Linear, + kv_cache: Option<(Tensor, Tensor)>, + q_layernorm: Option<LayerNorm>, + k_layernorm: Option<LayerNorm>, + rotary_emb: RotaryEmbedding, + softmax_scale: f64, + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + span: tracing::Span, +} + +fn get_mask(size: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +impl Attention { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads(); + let head_dim = cfg.head_dim(); + let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?; + // Alternative rope scalings are not supported. + let rotary_emb = RotaryEmbedding::new(cfg, vb.device())?; + let (q_layernorm, k_layernorm) = if cfg.qk_layernorm { + let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?; + let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?; + (Some(q_layernorm), Some(k_layernorm)) + } else { + (None, None) + }; + let softmax_scale = 1f64 / (head_dim as f64).sqrt(); + Ok(Self { + q_proj, + k_proj, + v_proj, + dense, + kv_cache: None, + q_layernorm, + k_layernorm, + rotary_emb, + softmax_scale, + num_heads, + num_kv_heads, + head_dim, + span: tracing::span!(tracing::Level::TRACE, "attention"), + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> { + let n_rep = self.num_heads / self.num_kv_heads; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b_size, seq_len, _n_embd) = xs.dims3()?; + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = match &self.q_layernorm { + None => query_states, + Some(ln) => query_states.apply(ln)?, + }; + let key_states = match &self.k_layernorm { + None => key_states, + Some(ln) => key_states.apply(ln)?, + }; + + let query_states = query_states + .reshape((b_size, seq_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + // Rotary embeddings. + let seqlen_offset = match &self.kv_cache { + None => 0, + Some((prev_k, _)) => prev_k.dim(2)?, + }; + let query_states = self + .rotary_emb + .apply_rotary_emb(&query_states, seqlen_offset)?; + let key_states = self + .rotary_emb + .apply_rotary_emb(&key_states, seqlen_offset)?; + + // KV cache. + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let k = Tensor::cat(&[prev_k, &key_states], 2)?; + let v = Tensor::cat(&[prev_v, &value_states], 2)?; + (k, v) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + // Repeat kv. + let key_states = self.repeat_kv(key_states)?.contiguous()?; + let value_states = self.repeat_kv(value_states)?.contiguous()?; + + let attn_weights = (query_states + .to_dtype(DType::F32)? + .contiguous()? + .matmul(&key_states.to_dtype(DType::F32)?.t()?)? + * self.softmax_scale)?; + let attn_weights = match mask { + None => attn_weights, + Some(mask) => masked_fill( + &attn_weights, + &mask.broadcast_left((b_size, self.num_heads))?, + f32::NEG_INFINITY, + )?, + }; + let attn_weights = + candle_nn::ops::softmax_last_dim(&attn_weights)?.to_dtype(value_states.dtype())?; + let attn_output = attn_weights.matmul(&value_states)?; + let attn_output = attn_output + .transpose(1, 2)? + .reshape((b_size, seq_len, ()))?; + attn_output.apply(&self.dense) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: LayerNorm, + span: tracing::Span, +} + +impl DecoderLayer { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let self_attn = Attention::new(cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = layer_norm( + cfg.hidden_size, + cfg.layer_norm_eps, + vb.pp("input_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let residual = xs; + let xs = xs.apply(&self.input_layernorm)?; + let attn_outputs = self.self_attn.forward(&xs, mask)?; + let feed_forward_hidden_states = self.mlp.forward(&xs)?; + attn_outputs + feed_forward_hidden_states + residual + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Clone)] +pub struct Model { + embed_tokens: Embedding, + layers: Vec<DecoderLayer>, + final_layernorm: LayerNorm, + lm_head: Linear, + span: tracing::Span, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_m = vb.pp("model"); + let embed_tokens = + Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let final_layernorm = layer_norm( + cfg.hidden_size, + cfg.layer_norm_eps, + vb_m.pp("final_layernorm"), + )?; + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_m = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(cfg, vb_m.pp(layer_idx))?; + layers.push(layer) + } + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + final_layernorm, + lm_head, + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (_b_size, seq_len) = xs.dims2()?; + let mut xs = xs.apply(&self.embed_tokens)?; + let mask = if seq_len <= 1 { + None + } else { + Some(get_mask(seq_len, xs.device())?) + }; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, mask.as_ref())?; + } + xs.apply(&self.final_layernorm)? + .narrow(1, seq_len - 1, 1)? + .apply(&self.lm_head)? + .squeeze(1) + } + + pub fn clear_kv_cache(&mut self) { + self.layers.iter_mut().for_each(|b| b.clear_kv_cache()) + } +} |