-rw-r--r--  candle-transformers/src/models/mixformer.rs  217
-rw-r--r--  candle-transformers/src/models/mod.rs           1
2 files changed, 218 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
new file mode 100644
index 00000000..2674d34f
--- /dev/null
+++ b/candle-transformers/src/models/mixformer.rs
@@ -0,0 +1,217 @@
+#![allow(unused)]
+//! MixFormer model.
+//!
+//! https://huggingface.co/microsoft/phi-1_5
+//! https://arxiv.org/abs/2309.05463
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+
+// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
+#[derive(Debug, Clone, PartialEq)]
+pub struct Config {
+ vocab_size: usize,
+ n_positions: usize,
+ n_embd: usize,
+ n_layer: usize,
+ n_inner: Option<usize>,
+ n_head: usize,
+ rotary_dim: usize,
+ activation_function: Activation,
+ layer_norm_epsilon: f64,
+ tie_word_embeddings: bool,
+ pad_vocab_size_multiple: usize,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 50304,
+ n_positions: 2048,
+ n_embd: 1024,
+ n_layer: 20,
+ n_inner: None,
+ n_head: 16,
+            rotary_dim: usize::min(32, 1024 / 16), // min(32, n_embd / n_head)
+ activation_function: Activation::Gelu,
+ layer_norm_epsilon: 1e-5,
+ tie_word_embeddings: false,
+ pad_vocab_size_multiple: 64,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Embedding {
+ wte: candle_nn::Embedding,
+}
+
+impl Embedding {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+ Ok(Self { wte })
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.wte.forward(xs)
+ }
+}
+
+// Placeholder for the rotary position embedding applied to the query/key projections;
+// the rotation itself is not implemented yet (see the cos/sin table sketch below).
+#[derive(Debug)]
+struct RotaryEmbedding {}
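+
+// Illustrative sketch of the cos/sin tables such a rotary embedding would precompute
+// over `rotary_dim` dimensions; the helper name and the usual 10_000 base are
+// assumptions taken from the standard RoPE formulation, not from this crate.
+fn rotary_cos_sin(rotary_dim: usize, max_seq_len: usize, dev: &Device) -> Result<(Tensor, Tensor)> {
+    let half = rotary_dim / 2;
+    let inv_freq: Vec<f32> = (0..half)
+        .map(|i| 1f32 / 10_000f32.powf(2. * i as f32 / rotary_dim as f32))
+        .collect();
+    let inv_freq = Tensor::from_slice(&inv_freq, (1, half), dev)?;
+    let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+        .to_dtype(DType::F32)?
+        .reshape((max_seq_len, 1))?;
+    // Outer product of positions and inverse frequencies: (max_seq_len, rotary_dim / 2).
+    let freqs = t.broadcast_mul(&inv_freq)?;
+    Ok((freqs.cos()?, freqs.sin()?))
+}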
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+ fc1: candle_nn::Linear,
+ fc2: candle_nn::Linear,
+ act: Activation,
+}
+
+impl MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
+ let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
+ let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
+ Ok(Self {
+ fc1,
+ fc2,
+ act: cfg.activation_function,
+ })
+ }
+}
+
+impl Module for MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
+ }
+}
+
+// Attention scaffolding; neither of these structs is wired into the forward pass yet.
+#[derive(Debug)]
+struct SelfAttention {
+ causal: bool,
+ softmax_scale: f64,
+}
+
+#[derive(Debug)]
+struct CrossAttention {
+ causal: bool,
+ softmax_scale: f64,
+}
+
+#[derive(Debug)]
+struct CausalLMHead {
+ ln: candle_nn::LayerNorm,
+ linear: candle_nn::Linear,
+}
+
+impl CausalLMHead {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
+ let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
+ Ok(Self { ln, linear })
+ }
+}
+
+impl Module for CausalLMHead {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // Final layer norm followed by the projection to the vocabulary; the logits are
+        // returned in f32 regardless of the model's compute dtype.
+        xs.apply(&self.ln)?
+            .apply(&self.linear)?
+            .to_dtype(DType::F32)
+ }
+}
+
+#[derive(Debug)]
+#[allow(clippy::upper_case_acronyms)]
+struct MHA {
+ wqkv: candle_nn::Linear,
+ out_proj: candle_nn::Linear,
+ head_dim: usize,
+}
+
+impl MHA {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let head_dim = cfg.n_embd / cfg.n_head;
+ let op_size = cfg.n_embd;
+ let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
+ let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+ Ok(Self {
+ wqkv,
+ out_proj,
+ head_dim,
+ })
+ }
+}
+
+impl Module for MHA {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, _n_embd) = xs.dims3()?;
+        // Packed query/key/value projection, split into heads:
+        // (b_size, seq_len, 3, n_head, head_dim).
+        let qkv = self
+            .wqkv
+            .forward(xs)?
+            .reshape((b_size, seq_len, 3, (), self.head_dim))?;
+        // TODO: split qkv into q/k/v, apply the rotary embedding and a causal
+        // scaled dot-product attention (see the sketch after this impl block).
+        // For now the packed projection is passed through as a placeholder.
+        let context: Tensor = qkv;
+ context.flatten_from(D::Minus2)?.apply(&self.out_proj)
+ }
+}
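+
+// Illustrative sketch of the causal scaled dot-product attention the TODO in
+// `MHA::forward` still needs; the helper name, the (b_size, n_head, seq_len, head_dim)
+// layout for q/k/v and the f32 mask are assumptions, not part of the crate's API.
+fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
+    let (_b_size, _n_head, seq_len, _head_dim) = q.dims4()?;
+    // Attention scores: (b_size, n_head, seq_len, seq_len).
+    let attn_weights = (q.matmul(&k.t()?)? * scale)?;
+    // Causal mask: position i may only attend to positions j <= i.
+    let mask: Vec<f32> = (0..seq_len)
+        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+        .collect();
+    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), q.device())?.to_dtype(q.dtype())?;
+    let attn_weights = attn_weights.broadcast_add(&mask)?;
+    let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+    // Weighted sum over the values: (b_size, n_head, seq_len, head_dim).
+    attn_weights.matmul(v)
+}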
+
+#[derive(Debug)]
+struct ParallelBlock {
+ ln: candle_nn::LayerNorm,
+ mixer: MHA,
+ mlp: MLP,
+}
+
+impl ParallelBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
+ let mixer = MHA::new(cfg, vb.pp("mixer"))?;
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ Ok(Self { ln, mixer, mlp })
+ }
+}
+
+impl Module for ParallelBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // Parallel block: attention and MLP both consume the same layer-normed input,
+        // and their outputs are summed with the residual (GPT-J / phi style).
+        let residual = xs;
+ let xs = xs.apply(&self.ln)?;
+ let attn_outputs = self.mixer.forward(&xs)?;
+ let feed_forward_hidden_states = self.mlp.forward(&xs)?;
+ attn_outputs + feed_forward_hidden_states + residual
+ }
+}
+
+#[derive(Debug)]
+pub struct MixFormerSequentialForCausalLM {
+ embedding: Embedding,
+ blocks: Vec<ParallelBlock>,
+ head: CausalLMHead,
+}
+
+impl MixFormerSequentialForCausalLM {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        // The checkpoint stores everything as a flat "layers" sequence: index 0 is the
+        // token embedding, 1..=n_layer are the parallel blocks, n_layer + 1 is the LM head.
+        let vb = vb.pp("layers");
+ let embedding = Embedding::new(cfg, vb.pp(0))?;
+ let mut blocks = Vec::new();
+ for i in 0..cfg.n_layer {
+ let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
+ blocks.push(block)
+ }
+ let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
+ Ok(Self {
+ embedding,
+ blocks,
+ head,
+ })
+ }
+}
+
+impl Module for MixFormerSequentialForCausalLM {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.apply(&self.embedding)?;
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ xs.apply(&self.head)
+ }
+}
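+
+// Illustrative usage sketch: how the model might be instantiated from safetensors
+// weights and run on a batch of token ids. The "model.safetensors" file name, the f32
+// dtype and the example token ids are assumptions for illustration only.
+fn usage_sketch(device: &Device) -> Result<Tensor> {
+    let vb = unsafe {
+        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, device)?
+    };
+    let model = MixFormerSequentialForCausalLM::new(&Config::default(), vb)?;
+    // A single batch of three token ids; the forward pass returns the per-token logits.
+    let input_ids = Tensor::new(&[[1u32, 2, 3]], device)?;
+    model.forward(&input_ids)
+}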
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index d783a2c6..991ee201 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -4,6 +4,7 @@ pub mod dinov2;
pub mod efficientnet;
pub mod falcon;
pub mod llama;
+pub mod mixformer;
pub mod quantized_llama;
pub mod quantized_t5;
pub mod segment_anything;