 README.md                                 |   8 +-
 candle-examples/examples/rwkv/main.rs     | 290 ++++++++++++++++++++++++++++
 candle-transformers/src/models/llama.rs   |   3 +-
 candle-transformers/src/models/mod.rs     |   1 +
 candle-transformers/src/models/rwkv_v5.rs | 317 +++++++++++++++++++++++++++++
 5 files changed, 616 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 9bfa30d8..5c65ef68 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,9 @@ We also provide some command line based examples using state of the art models
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
+- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
+- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer-level
+ LLM performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
@@ -193,6 +196,8 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
+ - Qwen1.5.
+ - RWKV.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
@@ -210,7 +215,8 @@ If you have an addition to this list, please submit a pull request.
- BLIP.
- TrOCR.
- Computer Vision Models.
- - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
+ - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
+ ConvNeXTv2.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs
new file mode 100644
index 00000000..7fd1f76c
--- /dev/null
+++ b/candle-examples/examples/rwkv/main.rs
@@ -0,0 +1,290 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::{Parser, ValueEnum};
+
+use candle_transformers::models::rwkv_v5::{Config, Model, State};
+
+use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+ model: Model,
+ config: Config,
+ device: Device,
+ tokenizer: TokenOutputStream,
+ logits_processor: LogitsProcessor,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
+impl TextGeneration {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ model: Model,
+ config: Config,
+ tokenizer: Tokenizer,
+ seed: u64,
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+ device: &Device,
+ ) -> Self {
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ Self {
+ model,
+ config,
+ tokenizer: TokenOutputStream::new(tokenizer),
+ logits_processor,
+ repeat_penalty,
+ repeat_last_n,
+ device: device.clone(),
+ }
+ }
+
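+ // RWKV is an RNN: there is no growing KV cache, only the fixed-size
+ // `State`. The prompt is fed one token at a time below to warm the state
+ // up, after which sampling continues token by token from the last logits.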
+ fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+ use std::io::Write;
+ self.tokenizer.clear();
+ let mut tokens = self
+ .tokenizer
+ .tokenizer()
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ let mut generated_tokens = 0usize;
+ let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
+ Some(token) => token,
+ None => anyhow::bail!("cannot find the <|endoftext|> token"),
+ };
+ let mut state = State::new(1, &self.config, &self.device)?;
+ let mut next_logits = None;
+ for &t in tokens.iter() {
+ let input = Tensor::new(&[[t]], &self.device)?;
+ let logits = self.model.forward(&input, &mut state)?;
+ next_logits = Some(logits);
+ if let Some(t) = self.tokenizer.next_token(t)? {
+ print!("{t}")
+ }
+ }
+ std::io::stdout().flush()?;
+
+ let start_gen = std::time::Instant::now();
+ for _ in 0..sample_len {
+ let logits = match next_logits.as_ref() {
+ Some(logits) => logits,
+ None => anyhow::bail!("cannot work on an empty prompt"),
+ };
+ let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
+ let next_token = self.logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ generated_tokens += 1;
+ if next_token == eos_token {
+ break;
+ }
+ if let Some(t) = self.tokenizer.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+
+ let input = Tensor::new(&[[next_token]], &self.device)?;
+ next_logits = Some(self.model.forward(&input, &mut state)?)
+ }
+ let dt = start_gen.elapsed();
+ if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
+ std::io::stdout().flush()?;
+ println!(
+ "\n{generated_tokens} tokens generated ({:.2} token/s)",
+ generated_tokens as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+ }
+}
+
+#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
+enum Which {
+ Eagle7b,
+ World1b5,
+ World3b,
+}
+
+impl std::fmt::Display for Which {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+impl Which {
+ fn model_id(&self) -> &'static str {
+ match self {
+ Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
+ Self::World1b5 => "RWKV/rwkv-5-world-1b5",
+ Self::World3b => "RWKV/rwkv-5-world-3b",
+ }
+ }
+
+ fn revision(&self) -> &'static str {
+ match self {
+ Self::Eagle7b => "refs/pr/1",
+ Self::World1b5 | Self::World3b => "refs/pr/2",
+ }
+ }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ prompt: String,
+
+ /// The temperature used to generate samples.
+ #[arg(long)]
+ temperature: Option<f64>,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(long, short = 'n', default_value_t = 5000)]
+ sample_len: usize,
+
+ #[arg(long, default_value = "world1b5")]
+ which: Which,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long)]
+ revision: Option<String>,
+
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature.unwrap_or(0.),
+ args.repeat_penalty,
+ args.repeat_last_n
+ );
+
+ let start = std::time::Instant::now();
+ let api = Api::new()?;
+ let repo = api.repo(Repo::with_revision(
+ args.model_id
+ .unwrap_or_else(|| args.which.model_id().to_string()),
+ RepoType::Model,
+ args.revision
+ .unwrap_or_else(|| args.which.revision().to_string()),
+ ));
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => api
+ // TODO: Use the appropriate tokenizer here.
+ .model("EleutherAI/gpt-neox-20b".to_string())
+ .get("tokenizer.json")?,
+ };
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+ let filenames = match args.weight_files {
+ Some(files) => files
+ .split(',')
+ .map(std::path::PathBuf::from)
+ .collect::<Vec<_>>(),
+ None => {
+ vec![repo.get("model.safetensors")?]
+ }
+ };
+ println!("retrieved the files in {:?}", start.elapsed());
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let start = std::time::Instant::now();
+ let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+ let device = candle_examples::device(args.cpu)?;
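+ // Memory-mapping the weight files is unsafe as the mapping becomes
+ // invalid if the underlying files are modified while in use.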
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+ let model = Model::new(&config, vb)?;
+ println!("loaded the model in {:?}", start.elapsed());
+
+ let mut pipeline = TextGeneration::new(
+ model,
+ config,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ &device,
+ );
+ pipeline.run(&args.prompt, args.sample_len)?;
+ Ok(())
+}
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7a920cb8..f8126394 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,13 +1,12 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 769fd650..8eab4744 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -34,6 +34,7 @@ pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg;
pub mod resnet;
+pub mod rwkv_v5;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
new file mode 100644
index 00000000..d8ea7b20
--- /dev/null
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -0,0 +1,317 @@
+use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
+use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
+
+fn default_num_attention_heads() -> usize {
+ 64
+}
+
+// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub num_hidden_layers: usize,
+ pub attention_hidden_size: usize,
+ #[serde(default = "default_num_attention_heads")]
+ pub num_attention_heads: usize,
+ pub head_size: usize,
+ pub intermediate_size: Option<usize>,
+ pub layer_norm_epsilon: f64,
+ pub rescale_every: usize,
+}
+
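+// Recurrent state kept per layer: the previous token's hidden state for the
+// attention token-shift, the (heads, head_size, head_size) linear attention
+// state, and the previous hidden state for the feed-forward token-shift.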
+struct StatePerLayer {
+ extract_key_value: Tensor,
+ linear_attention: Tensor,
+ feed_forward: Tensor,
+}
+
+pub struct State {
+ per_layer: Vec<StatePerLayer>,
+ pos: usize,
+}
+
+impl State {
+ pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
+ let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
+ // Certainly a weird convention but taken from modeling_rwkv5.py: the
+ // config field named `num_attention_heads` actually stores the head size,
+ // so the number of heads is hidden_size / num_attention_heads.
+ let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
+ for _layer_idx in 0..cfg.num_hidden_layers {
+ let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+ let linear_attention = Tensor::zeros(
+ (
+ batch_size,
+ num_attention_heads,
+ cfg.hidden_size / num_attention_heads,
+ cfg.hidden_size / num_attention_heads,
+ ),
+ DType::F32,
+ dev,
+ )?;
+ let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+ per_layer.push(StatePerLayer {
+ extract_key_value,
+ linear_attention,
+ feed_forward,
+ });
+ }
+ Ok(Self { per_layer, pos: 0 })
+ }
+}
+
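+// Time-mixing block. The `time_mix_*` tensors interpolate each projection
+// input with the previous token's hidden state (token shift); `time_decay`
+// and `time_faaaa` are the per-head decay and bonus terms of the linear
+// attention recurrence below.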
+#[derive(Debug, Clone)]
+struct SelfAttention {
+ key: Linear,
+ receptance: Linear,
+ value: Linear,
+ gate: Linear,
+ output: Linear,
+ ln_x: candle_nn::GroupNorm,
+ time_mix_key: Tensor,
+ time_mix_value: Tensor,
+ time_mix_receptance: Tensor,
+ time_decay: Tensor,
+ time_faaaa: Tensor,
+ time_mix_gate: Tensor,
+ layer_id: usize,
+ n_attn_heads: usize,
+}
+
+impl SelfAttention {
+ pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_size = cfg.hidden_size;
+ let attn_hidden_size = cfg.attention_hidden_size;
+ let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+ let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+ let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+ let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+ let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+ let ln_x = candle_nn::group_norm(
+ hidden_size / cfg.head_size,
+ hidden_size,
+ 1e-5,
+ vb.pp("ln_x"),
+ )?;
+ let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+ let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
+ let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+ let n_attn_heads = cfg.hidden_size / cfg.head_size;
+ let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
+ let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
+ let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
+ Ok(Self {
+ key,
+ value,
+ receptance,
+ gate,
+ output,
+ ln_x,
+ time_mix_key,
+ time_mix_value,
+ time_mix_receptance,
+ time_decay,
+ time_faaaa,
+ time_mix_gate,
+ layer_id,
+ n_attn_heads,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let h = self.time_decay.dim(0)?;
+ let (b, t, s) = xs.dims3()?;
+ let s = s / h;
+ let (receptance, key, value, gate) = {
+ // extract key-value
+ let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+ let shifted = if shifted.rank() == 2 {
+ shifted.unsqueeze(1)?
+ } else {
+ shifted
+ };
+ let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
+ let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
+ let receptance = ((xs * &self.time_mix_receptance)?
+ + &shifted * (1.0 - &self.time_mix_receptance)?)?;
+ let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
+
+ let key = self.key.forward(&key)?;
+ let value = self.value.forward(&value)?;
+ let receptance = self.receptance.forward(&receptance)?;
+ let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
+ state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+ (receptance, key, value, gate)
+ };
+ // linear attention
+ let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+ let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+ let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+ let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
+ let time_decay = self
+ .time_decay
+ .exp()?
+ .neg()?
+ .exp()?
+ .reshape(((), 1, 1))?
+ .reshape((self.n_attn_heads, (), 1))?;
+ let time_faaaa =
+ self.time_faaaa
+ .reshape(((), 1, 1))?
+ .reshape((self.n_attn_heads, (), 1))?;
+
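+ // One step of the linear attention recurrence per token:
+ //   out_t = r_t (u * (k_t v_t^T) + S)
+ //   S    <- k_t v_t^T + w * S
+ // with w = exp(-exp(time_decay)) and u = time_faaaa, broadcast per head.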
+ let mut out: Vec<Tensor> = Vec::with_capacity(t);
+ for t_ in 0..t {
+ let rt = receptance.i((.., .., t_..t_ + 1))?;
+ let kt = key.i((.., .., .., t_..t_ + 1))?;
+ let vt = value.i((.., .., t_..t_ + 1))?;
+ let at = kt.matmul(&vt)?;
+ let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+ let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+ state_ = (&at + time_decay.broadcast_mul(&state_))?;
+ out.push(out_)
+ }
+ let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+ let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+ let out = (out * gate)?.apply(&self.output)?;
+ state.per_layer[self.layer_id].linear_attention = state_;
+ Ok(out)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct FeedForward {
+ time_mix_key: Tensor,
+ time_mix_receptance: Tensor,
+ key: Linear,
+ receptance: Linear,
+ value: Linear,
+ layer_id: usize,
+}
+
+impl FeedForward {
+ pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
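+ // When unset, the intermediate size defaults to 3.5 * hidden_size,
+ // rounded down to a multiple of 32.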
+ let int_size = cfg
+ .intermediate_size
+ .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+ let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+ let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+ let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+ let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+ let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+ Ok(Self {
+ key,
+ receptance,
+ value,
+ time_mix_key,
+ time_mix_receptance,
+ layer_id,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
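+ // Channel mixing: token-shift interpolation with the previous hidden
+ // state, a squared-relu key activation, and a sigmoid receptance gate.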
+ let shifted = &state.per_layer[self.layer_id].feed_forward;
+ let key = (xs.broadcast_mul(&self.time_mix_key)?
+ + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
+ let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+ + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
+ let key = key.apply(&self.key)?.relu()?.sqr()?;
+ let value = key.apply(&self.value)?;
+ let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+ state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+ let xs = (receptance * value)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+ pre_ln: Option<LayerNorm>,
+ ln1: LayerNorm,
+ ln2: LayerNorm,
+ attention: SelfAttention,
+ feed_forward: FeedForward,
+}
+
+impl Block {
+ pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+ let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
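+ // Only the first layer has a `pre_ln`, applied to the embedding output
+ // before anything else.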
+ let pre_ln = if layer_id == 0 {
+ let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+ Some(ln)
+ } else {
+ None
+ };
+ let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+ let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+ Ok(Self {
+ pre_ln,
+ ln1,
+ ln2,
+ attention,
+ feed_forward,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let xs = match self.pre_ln.as_ref() {
+ None => xs.clone(),
+ Some(pre_ln) => xs.apply(pre_ln)?,
+ };
+ let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+ let xs = (xs + attention)?;
+ let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+ let xs = (xs + feed_forward)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embeddings: Embedding,
+ blocks: Vec<Block>,
+ ln_out: LayerNorm,
+ head: Linear,
+ rescale_every: usize,
+ layers_are_rescaled: bool,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("rwkv");
+ let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+ let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_b = vb_m.pp("blocks");
+ for block_index in 0..cfg.num_hidden_layers {
+ let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+ blocks.push(block)
+ }
+ let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+ let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+ Ok(Self {
+ embeddings,
+ blocks,
+ ln_out,
+ head,
+ rescale_every: cfg.rescale_every,
+ layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let (_b_size, _seq_len) = xs.dims2()?;
+ let mut xs = xs.apply(&self.embeddings)?;
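+ // When the weights have been rescaled for f16/bf16 inference, the
+ // activations are halved every `rescale_every` layers to compensate.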
+ for (block_idx, block) in self.blocks.iter().enumerate() {
+ xs = block.forward(&xs, state)?;
+ if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+ xs = (xs / 2.)?
+ }
+ }
+ let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+ state.pos += 1;
+ Ok(xs)
+ }
+}