summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/mmdit/projections.rs
blob: dc1e8ec941010cd89040af571b0805ec52c455d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use candle::{Module, Result, Tensor};
use candle_nn as nn;

/// Query, key, and value tensors produced by a fused qkv projection,
/// after being split apart by [`split_qkv`].
pub struct Qkv {
    pub q: Tensor,
    pub k: Tensor,
    pub v: Tensor,
}

/// Two-layer feed-forward block (`fc1 -> act -> fc2`) used inside MMDiT
/// transformer blocks.
pub struct Mlp {
    fc1: nn::Linear,
    // GELU (PyTorch tanh approximation) — set unconditionally in `new`.
    act: nn::Activation,
    fc2: nn::Linear,
}

impl Mlp {
    /// Builds the feed-forward block: `fc1` expands from `in_features` to
    /// `hidden_features`, `fc2` projects back down to `in_features`.
    ///
    /// Weights are loaded from `vb` under the `"fc1"` / `"fc2"` prefixes;
    /// the activation is fixed to the PyTorch-tanh GELU approximation.
    pub fn new(
        in_features: usize,
        hidden_features: usize,
        vb: candle_nn::VarBuilder,
    ) -> Result<Self> {
        Ok(Self {
            fc1: nn::linear(in_features, hidden_features, vb.pp("fc1"))?,
            act: nn::Activation::GeluPytorchTanh,
            fc2: nn::linear(hidden_features, in_features, vb.pp("fc2"))?,
        })
    }
}

impl Module for Mlp {
    /// Runs the block: linear expansion, GELU activation, linear projection.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let hidden = self.act.forward(&self.fc1.forward(x)?)?;
        self.fc2.forward(&hidden)
    }
}

/// Attention input projections without an output projection: only the fused
/// qkv linear layer plus the per-head dimension used when splitting.
pub struct QkvOnlyAttnProjections {
    qkv: nn::Linear,
    // dim / num_heads, computed in `new`.
    head_dim: usize,
}

impl QkvOnlyAttnProjections {
    /// Creates a fused qkv projection (`dim -> 3 * dim`) for inputs of width
    /// `dim` split across `num_heads` attention heads. Weights come from `vb`
    /// under the `"qkv"` prefix. Note: `dim / num_heads` truncates if `dim`
    /// is not a multiple of `num_heads`.
    pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
        Ok(Self {
            qkv: nn::linear(dim, dim * 3, vb.pp("qkv"))?,
            head_dim: dim / num_heads,
        })
    }

    /// Applies the fused projection to `x` and splits the result into
    /// separate q/k/v tensors.
    pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
        let projected = self.qkv.forward(x)?;
        split_qkv(&projected, self.head_dim)
    }
}

/// Full attention projections: a fused qkv input projection and an output
/// projection applied after attention.
pub struct AttnProjections {
    // dim / num_heads, computed in `new`.
    head_dim: usize,
    qkv: nn::Linear,
    proj: nn::Linear,
}

impl AttnProjections {
    /// Creates the attention projections for inputs of width `dim` with
    /// `num_heads` heads: a fused qkv layer (`dim -> 3 * dim`) and an output
    /// projection (`dim -> dim`), loaded from `vb` under the `"qkv"` /
    /// `"proj"` prefixes. Note: `dim / num_heads` truncates if `dim` is not
    /// a multiple of `num_heads`.
    pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
        Ok(Self {
            head_dim: dim / num_heads,
            qkv: nn::linear(dim, dim * 3, vb.pp("qkv"))?,
            proj: nn::linear(dim, dim, vb.pp("proj"))?,
        })
    }

    /// Applies the fused projection to `x` and splits the result into
    /// separate q/k/v tensors.
    pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
        let projected = self.qkv.forward(x)?;
        split_qkv(&projected, self.head_dim)
    }

    /// Applies the output projection to the attention result.
    pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
        self.proj.forward(x)
    }
}

/// Splits a fused projection output of shape `(batch, seq, 3 * dim)` into
/// separate q/k/v tensors.
///
/// `q` and `k` are reshaped back to `(batch, seq, dim)`; `v` is returned in
/// the post-split layout produced by `get_on_dim` without the final reshape.
// NOTE(review): the q/k-vs-v asymmetry (v is not reshaped to rank 3) looks
// deliberate for the downstream attention kernel, but confirm against callers.
fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
    let (batch_size, seq_len, _) = qkv.dims3()?;
    // (batch, seq, 3 * dim) -> (batch, seq, 3, num_heads, head_dim);
    // `()` lets candle infer num_heads from the remaining size.
    let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
    // Index 0/1/2 along the size-3 axis selects q, k, v respectively.
    let q = qkv.get_on_dim(2, 0)?;
    let q = q.reshape((batch_size, seq_len, ()))?;
    let k = qkv.get_on_dim(2, 1)?;
    let k = k.reshape((batch_size, seq_len, ()))?;
    let v = qkv.get_on_dim(2, 2)?;
    Ok(Qkv { q, k, v })
}