summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2025-01-13 17:39:49 +0100
committerGitHub <noreply@github.com>2025-01-13 17:39:49 +0100
commit309cd0f7c7d2035f3f43da8a4cd7e6a7a897c515 (patch)
tree73e20422edfb2c3211b2759af15c56d565874efa
parentab7ff7081eab36958b82b98b89cee3eacf877111 (diff)
downloadcandle-309cd0f7c7d2035f3f43da8a4cd7e6a7a897c515.tar.gz
candle-309cd0f7c7d2035f3f43da8a4cd7e6a7a897c515.tar.bz2
candle-309cd0f7c7d2035f3f43da8a4cd7e6a7a897c515.zip
Add the helium model. (#2715)
-rw-r--r--candle-examples/examples/helium/README.md11
-rw-r--r--candle-examples/examples/helium/main.rs292
-rw-r--r--candle-transformers/src/models/helium.rs395
-rw-r--r--candle-transformers/src/models/mod.rs1
4 files changed, 699 insertions, 0 deletions
diff --git a/candle-examples/examples/helium/README.md b/candle-examples/examples/helium/README.md
new file mode 100644
index 00000000..9d1f2009
--- /dev/null
+++ b/candle-examples/examples/helium/README.md
@@ -0,0 +1,11 @@
+# candle-helium: 2b LLM with CC-BY licensed weights
+
+- [Model card](https://huggingface.co/kyutai/helium-1-preview) on the HuggingFace Hub.
+
+## Running the example
+
+```bash
+$ cargo run --example helium --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150
+```
+
+
diff --git a/candle-examples/examples/helium/main.rs b/candle-examples/examples/helium/main.rs
new file mode 100644
index 00000000..d427f104
--- /dev/null
+++ b/candle-examples/examples/helium/main.rs
@@ -0,0 +1,292 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::helium::{Config, Model};
+
+use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+ model: Model,
+ device: Device,
+ tokenizer: TokenOutputStream,
+ logits_processor: LogitsProcessor,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+ config: Config,
+}
+
+impl TextGeneration {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ model: Model,
+ tokenizer: Tokenizer,
+ seed: u64,
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ top_k: Option<usize>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+ config: Config,
+ device: &Device,
+ ) -> Self {
+ let logits_processor = {
+ let temperature = temp.unwrap_or(0.);
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (top_k, top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(seed, sampling)
+ };
+
+ Self {
+ model,
+ tokenizer: TokenOutputStream::new(tokenizer),
+ logits_processor,
+ repeat_penalty,
+ repeat_last_n,
+ device: device.clone(),
+ config,
+ }
+ }
+
+ fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+ use std::io::Write;
+ self.tokenizer.clear();
+ let mut tokens = self
+ .tokenizer
+ .tokenizer()
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ for &t in tokens.iter() {
+ if let Some(t) = self.tokenizer.next_token(t)? {
+ print!("{t}")
+ }
+ }
+ std::io::stdout().flush()?;
+
+ let mut generated_tokens = 0usize;
+ let start_gen = std::time::Instant::now();
+ for index in 0..sample_len {
+ let context_size = if index > 0 { 1 } else { tokens.len() };
+ let start_pos = tokens.len().saturating_sub(context_size);
+ let ctxt = &tokens[start_pos..];
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+ let logits = self.model.forward(&input, start_pos)?;
+ let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
+
+ let next_token = self.logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+ if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
+ break;
+ }
+ if let Some(t) = self.tokenizer.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+ }
+ let dt = start_gen.elapsed();
+ if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
+ std::io::stdout().flush()?;
+ println!(
+ "\n{generated_tokens} tokens generated ({:.2} token/s)",
+ generated_tokens as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+ }
+}
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+ #[value(name = "v1-preview")]
+ V1Preview,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ use_flash_attn: bool,
+
+ #[arg(long)]
+ prompt: String,
+
+ /// The temperature used to generate samples.
+ #[arg(long, default_value_t = 0.7)]
+ temperature: f64,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(long, short = 'n', default_value_t = 10000)]
+ sample_len: usize,
+
+ /// The model size to use.
+ #[arg(long, default_value = "v1-preview")]
+ which: Which,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long)]
+ tokenizer: Option<String>,
+
+ #[arg(long)]
+ config: Option<String>,
+
+ #[arg(long)]
+ weights: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature, args.repeat_penalty, args.repeat_last_n
+ );
+
+ let start = std::time::Instant::now();
+ let api = Api::new()?;
+ let model_id = match args.model_id {
+ Some(model_id) => model_id,
+ None => {
+ let name = match args.which {
+ Which::V1Preview => "kyutai/helium-1-preview",
+ };
+ name.to_string()
+ }
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+ let tokenizer_filename = match args.tokenizer {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+ let filenames = match args.weights {
+ Some(files) => files
+ .split(',')
+ .map(std::path::PathBuf::from)
+ .collect::<Vec<_>>(),
+ None => candle_examples::hub_load_safetensors(&repo, "model.safetensors")?,
+ };
+ println!("retrieved the files in {:?}", start.elapsed());
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let start = std::time::Instant::now();
+ let config: Config = match args.config {
+ Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
+ None => {
+ let config_file = repo.get("config.json")?;
+ serde_json::from_slice(&std::fs::read(config_file)?)?
+ }
+ };
+ let device = candle_examples::device(args.cpu)?;
+ let (model, device) = {
+ let dtype = if device.is_cuda() {
+ DType::BF16
+ } else {
+ DType::F32
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+ let model = Model::new(&config, vb)?;
+ (model, device)
+ };
+
+ println!("loaded the model in {:?}", start.elapsed());
+
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ Some(args.temperature),
+ args.top_p,
+ args.top_k,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ config,
+ &device,
+ );
+ pipeline.run(&args.prompt, args.sample_len)?;
+ Ok(())
+}
diff --git a/candle-transformers/src/models/helium.rs b/candle-transformers/src/models/helium.rs
new file mode 100644
index 00000000..40cff396
--- /dev/null
+++ b/candle-transformers/src/models/helium.rs
@@ -0,0 +1,395 @@
+//! Helium inference implementation.
+//!
+//! See the model card on Hugging Face's [hub](https://huggingface.co/kmhf/helium-2b).
+
+use super::with_tracing::{linear_b as linear, Linear, RmsNorm};
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{Module, VarBuilder};
+use std::sync::Arc;
+
+fn default_use_flash_attn() -> bool {
+ false
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ pub attention_bias: bool,
+ pub bos_token_id: u32,
+ pub eos_token_id: u32,
+ pub head_dim: usize,
+ pub hidden_act: candle_nn::Activation,
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub max_position_embeddings: usize,
+ pub mlp_bias: bool,
+ pub num_attention_heads: usize,
+ pub num_hidden_layers: usize,
+ pub num_key_value_heads: usize,
+ pub rms_norm_eps: f64,
+ pub rope_theta: f64,
+ pub tie_word_embeddings: bool,
+ pub vocab_size: usize,
+ #[serde(default = "default_use_flash_attn")]
+ pub use_flash_attn: bool,
+}
+
+impl Config {
+ pub fn config_2b(use_flash_attn: bool) -> Self {
+ Self {
+ attention_bias: false,
+ bos_token_id: 1,
+ eos_token_id: 2,
+ head_dim: 128,
+ hidden_act: candle_nn::Activation::Silu,
+ hidden_size: 2560,
+ intermediate_size: 7040,
+ max_position_embeddings: 4096,
+ mlp_bias: false,
+ num_attention_heads: 20,
+ num_hidden_layers: 24,
+ num_key_value_heads: 20,
+ rms_norm_eps: 1e-08,
+ rope_theta: 100000.0,
+ tie_word_embeddings: false,
+ vocab_size: 48000,
+ use_flash_attn,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
+impl RotaryEmbedding {
+ fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+ let rope_theta = cfg.rope_theta as f32;
+ let dim = cfg.head_dim;
+ let max_seq_len = cfg.max_position_embeddings;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(DType::F32)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ Ok(Self {
+ sin: freqs.sin()?.to_dtype(dtype)?,
+ cos: freqs.cos()?.to_dtype(dtype)?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ seqlen_offset: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+ let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+ act_fn: candle_nn::Activation,
+}
+
+impl MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let intermediate_sz = cfg.intermediate_size;
+ let bias = cfg.mlp_bias;
+ let gate_proj = linear(hidden_sz, intermediate_sz, bias, vb.pp("gate_proj"))?;
+ let up_proj = linear(hidden_sz, intermediate_sz, bias, vb.pp("up_proj"))?;
+ let down_proj = linear(intermediate_sz, hidden_sz, bias, vb.pp("down_proj"))?;
+ Ok(Self {
+ gate_proj,
+ up_proj,
+ down_proj,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+ let rhs = xs.apply(&self.up_proj)?;
+ (lhs * rhs)?.apply(&self.down_proj)
+ }
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ num_heads: usize,
+ num_kv_heads: usize,
+ num_kv_groups: usize,
+ head_dim: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+ kv_cache: Option<(Tensor, Tensor)>,
+ use_flash_attn: bool,
+}
+
+impl Attention {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let num_kv_heads = cfg.num_key_value_heads;
+ let num_kv_groups = num_heads / num_kv_heads;
+ let head_dim = cfg.head_dim;
+ let bias = cfg.attention_bias;
+ let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
+ let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
+ let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
+ let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ num_heads,
+ num_kv_heads,
+ num_kv_groups,
+ head_dim,
+ rotary_emb,
+ kv_cache: None,
+ use_flash_attn: cfg.use_flash_attn,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (b_sz, q_len, _) = xs.dims3()?;
+
+ let query_states = self.q_proj.forward(xs)?;
+ let key_states = self.k_proj.forward(xs)?;
+ let value_states = self.v_proj.forward(xs)?;
+
+ let query_states = query_states
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let key_states = key_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let value_states = value_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ let (query_states, key_states) =
+ self.rotary_emb
+ .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
+
+ let (key_states, value_states) = match &self.kv_cache {
+ None => (key_states, value_states),
+ Some((prev_k, prev_v)) => {
+ let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
+ let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
+ (key_states, value_states)
+ }
+ };
+ self.kv_cache = Some((key_states.clone(), value_states.clone()));
+
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
+
+ let attn_output = if self.use_flash_attn {
+ // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+ let q = query_states.transpose(1, 2)?;
+ let k = key_states.transpose(1, 2)?;
+ let v = value_states.transpose(1, 2)?;
+ let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+ flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+ } else {
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ attn_weights.matmul(&value_states)?
+ };
+ attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
+ .apply(&self.o_proj)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ self_attn: Attention,
+ mlp: MLP,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl DecoderLayer {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ let input_layernorm =
+ RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.input_layernorm.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
+ residual + xs
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache()
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embed_tokens: candle_nn::Embedding,
+ layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ device: Device,
+ dtype: DType,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("model");
+ let embed_tokens =
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb_m.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+ let lm_head = if cfg.tie_word_embeddings {
+ Linear::from_weights(embed_tokens.embeddings().clone(), None)
+ } else {
+ linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?
+ };
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+ }
+
+ fn prepare_decoder_attention_mask(
+ &self,
+ tgt_len: usize,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let mask: Vec<_> = (0..tgt_len)
+ .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
+ let mask = if seqlen_offset > 0 {
+ let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+ Tensor::cat(&[&mask0, &mask], D::Minus1)?
+ } else {
+ mask
+ };
+ mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
+ .to_dtype(self.dtype)
+ }
+
+ pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+ &self.embed_tokens
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let (_b_size, seq_len) = input_ids.dims2()?;
+ let attention_mask = if seq_len <= 1 {
+ None
+ } else {
+ let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
+ Some(mask)
+ };
+ let mut xs = self.embed_tokens.forward(input_ids)?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ for layer in self.layers.iter_mut() {
+ layer.clear_kv_cache()
+ }
+ }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 473a276f..df1de0b2 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -43,6 +43,7 @@ pub mod gemma;
pub mod gemma2;
pub mod glm4;
pub mod granite;
+pub mod helium;
pub mod hiera;
pub mod jina_bert;
pub mod llama;