diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-21 22:02:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-21 22:02:50 +0100 |
commit | 45d5322d62d59a2c9be5fe8c642d0fa56fbb73b1 (patch) | |
tree | 8e2965b96fe2769095b88ffb7d28fde12d07f3c6 | |
parent | a2cb2edead523f9ee65ddac34f6e1946d52236b3 (diff) | |
download | candle-45d5322d62d59a2c9be5fe8c642d0fa56fbb73b1.tar.gz candle-45d5322d62d59a2c9be5fe8c642d0fa56fbb73b1.tar.bz2 candle-45d5322d62d59a2c9be5fe8c642d0fa56fbb73b1.zip |
Add the Gemma models. (#1741)
* Add the Gemma models.
* Add the gemma example.
* Adapt the RmsNorm.
* Get the 2b model to work.
* 7b support.
* Use the config head dim.
* Yet another fix.
* Make the matrixes contiguous.
* Also get the 7b model to work.
* And add to the readme.
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | candle-examples/examples/gemma/README.md | 27 | ||||
-rw-r--r-- | candle-examples/examples/gemma/main.rs | 254 | ||||
-rw-r--r-- | candle-transformers/src/models/gemma.rs | 380 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
5 files changed, 665 insertions, 0 deletions
@@ -63,6 +63,8 @@ We also provide a some command line based examples using state of the art models - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes the SOLAR-10.7B variant. - [Falcon](./candle-examples/examples/falcon/): general LLM. +- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google + Deepmind. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets. Also supports @@ -190,6 +192,7 @@ If you have an addition to this list, please submit a pull request. - StarCoder. - Phi 1, 1.5, and 2. - Mamba, Minimal Mamba + - Gemma 2b and 7b. - Mistral 7b v0.1. - Mixtral 8x7b v0.1. - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. diff --git a/candle-examples/examples/gemma/README.md b/candle-examples/examples/gemma/README.md new file mode 100644 index 00000000..8319cf44 --- /dev/null +++ b/candle-examples/examples/gemma/README.md @@ -0,0 +1,27 @@ +# candle-mistral: 2b and 7b LLMs from Google DeepMind + +[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open +models published by Google Deepmind with a 2b and a 7b variant. + +In order to use the example below, you have to accept the license on the +[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up +your access token via the [HuggingFace cli login +command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login). + +## Running the example + +```bash +$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)" +fn count_primes(max_n: usize) -> usize { + let mut primes = vec![true; max_n]; + for i in 2..=max_n { + if primes[i] { + for j in i * i..max_n { + primes[j] = false; + } + } + } + primes.len() +} +``` + diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs new file mode 100644 index 00000000..a4a1b3de --- /dev/null +++ b/candle-examples/examples/gemma/main.rs @@ -0,0 +1,254 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle_transformers::models::gemma::{Config, Model}; + +use candle::{DType, Device, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option<f64>, + top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + self.tokenizer.clear(); + let mut tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + for &t in tokens.iter() { + if let Some(t) = self.tokenizer.next_token(t)? { + print!("{t}") + } + } + std::io::stdout().flush()?; + + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("</s>") { + Some(token) => token, + None => anyhow::bail!("cannot find the </s> token"), + }; + let start_gen = std::time::Instant::now(); + for index in 0..sample_len { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let start_pos = tokens.len().saturating_sub(context_size); + let ctxt = &tokens[start_pos..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, start_pos)?; + let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + } + let dt = start_gen.elapsed(); + if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option<f64>, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 10000)] + sample_len: usize, + + #[arg(long)] + model_id: Option<String>, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + tokenizer_file: Option<String>, + + #[arg(long)] + config_file: Option<String>, + + #[arg(long)] + weight_files: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => match model_id.as_str() { + "7b" => "google/gemma-7b".to_string(), + "2b" => "google/gemma-2b".to_string(), + _ => model_id.to_string(), + }, + None => "google/gemma-2b".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::<Vec<_>>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?; + + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + let model = Model::new(&config, vb)?; + + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + &device, + ); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs new file mode 100644 index 00000000..e2be8406 --- /dev/null +++ b/candle-transformers/src/models/gemma.rs @@ -0,0 +1,380 @@ +use std::sync::Arc; + +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{linear_no_bias, Linear, VarBuilder}; + +fn default_max_position_embeddings() -> usize { + 4096 +} + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub attention_bias: bool, + pub head_dim: usize, + pub hidden_act: candle_nn::Activation, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub vocab_size: usize, + + #[serde(default = "default_max_position_embeddings")] + pub max_position_embeddings: usize, +} + +#[derive(Debug, Clone)] +struct RmsNorm { + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(dim, "weight")?; + Ok(Self { weight, eps }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&(&self.weight + 1.0)?) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +fn rotate_half(xs: &Tensor) -> Result<Tensor> { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { + let dim = cfg.head_dim; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; + let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; + let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + rotary_emb: Arc<RotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl Attention { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = cfg.head_dim; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + rotary_emb, + kv_cache: None, + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> { + let n_rep = self.num_kv_groups; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?.contiguous()?; + let value_states = self.repeat_kv(value_states)?.contiguous()?; + + let attn_output = { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, ()))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + lm_head: Linear, + device: Device, + dtype: DType, + hidden_size: usize, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = Linear::new(embed_tokens.embeddings().clone(), None); + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + device: vb.device().clone(), + dtype: vb.dtype(), + hidden_size: cfg.hidden_size, + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result<Tensor> { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let xs = self.embed_tokens.forward(input_ids)?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8eab4744..bb59a53f 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -9,6 +9,7 @@ pub mod dinov2; pub mod distilbert; pub mod efficientnet; pub mod falcon; +pub mod gemma; pub mod jina_bert; pub mod llama; pub mod llama2_c; |