summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/models/bert.py
blob: 0a773f939d2f8a395d9857816cd998a837c973b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from dataclasses import dataclass
from typing import Optional
from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
from candle import Tensor
import candle
import candle.functional as F
from typing import Tuple, Optional


@dataclass
class Config:
    """Hyper-parameters read by the BERT modules defined in this file.

    Every field has a default, so ``Config()`` can be constructed without
    arguments.  Fields marked "not read here" are not referenced by any
    module in this file.
    """

    # Number of rows in BertEmbeddings.word_embeddings.
    vocab_size: int = 30522
    # Width of every hidden representation (embeddings, attention, pooler).
    hidden_size: int = 768
    # Number of BertLayer instances stacked by BertEncoder.
    num_hidden_layers: int = 12
    # Number of heads BertSelfAttention splits hidden_size into.
    num_attention_heads: int = 12
    # Output width of BertIntermediate.dense / input width of BertOutput.dense.
    intermediate_size: int = 3072
    # "gelu" selects F.gelu in BertIntermediate; any other value selects F.relu.
    hidden_act: str = "gelu"
    # Not read here.
    hidden_dropout_prob: float = 0.1
    # Number of rows in BertEmbeddings.position_embeddings.
    max_position_embeddings: int = 512
    # Number of rows in BertEmbeddings.token_type_embeddings.
    type_vocab_size: int = 2
    # Not read here.
    initializer_range: float = 0.02
    # Passed as ``eps`` to every LayerNorm built in this file.
    layer_norm_eps: float = 1e-12
    # Not read here.
    pad_token_id: int = 0
    # Not read here.
    position_embedding_type: str = "absolute"
    # Not read here.
    use_cache: bool = True
    # Not read here.
    classifier_dropout: Optional[float] = None
    # Not read here.
    model_type: Optional[str] = "bert"


class BertSelfAttention(Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to query/key/value, split into heads, combined
    via softmax-normalised scores, and the heads are merged again.  No
    attention mask and no dropout are applied.
    """

    def __init__(self, config: Config) -> None:
        """Build the query/key/value projections from ``config``."""
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        # Per-head width; the int() truncates when hidden_size is not a
        # multiple of the head count.
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        # Combined width of all heads (output width of each projection).
        projected_width = int(config.num_attention_heads * self.attention_head_size)
        input_width = config.hidden_size
        self.query = Linear(input_width, projected_width)
        self.key = Linear(input_width, projected_width)
        self.value = Linear(input_width, projected_width)

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        """Split the last axis into (heads, head_size) and swap axes 1 and 2.

        Returns the result of ``.contiguous()`` on the transposed tensor.
        """
        head_layout = (self.num_attention_heads, self.attention_head_size)
        target_shape = x.shape[:-1] + head_layout
        per_head = x.reshape(target_shape)
        return per_head.transpose(1, 2).contiguous()

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Apply self-attention to ``hidden_states`` and return the merged heads."""
        q = self.transpose_for_scores(self.query.forward(hidden_states))
        k = self.transpose_for_scores(self.key.forward(hidden_states))
        v = self.transpose_for_scores(self.value.forward(hidden_states))

        # Scores are divided by sqrt(head_size) before the softmax.
        scale = float(self.attention_head_size) ** 0.5
        scores = q.matmul(k.t()) / scale
        weights = F.softmax(scores, dim=-1)

        # Undo the head/sequence swap, then collapse the last two axes.
        merged = weights.matmul(v).transpose(1, 2).contiguous()
        return merged.flatten_from(-2)


class BertSelfOutput(Module):
    """Dense projection, residual addition, then layer normalisation."""

    def __init__(self, config: Config) -> None:
        """Create a hidden_size -> hidden_size Linear and a LayerNorm."""
        super().__init__()
        width = config.hidden_size
        self.dense = Linear(width, width)
        self.LayerNorm = LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        projected = self.dense.forward(hidden_states)
        with_residual = projected + input_tensor
        return self.LayerNorm.forward(with_residual)


class BertAttention(Module):
    """Self-attention followed by its output projection block."""

    def __init__(self, config: Config) -> None:
        """Instantiate the attention (``self``) and ``output`` sub-modules."""
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Attend over ``hidden_states``; the input doubles as the residual."""
        attended = self.self.forward(hidden_states)
        return self.output.forward(attended, hidden_states)


class BertIntermediate(Module):
    """Linear expansion to ``intermediate_size`` followed by an activation."""

    def __init__(self, config: Config) -> None:
        """Build the dense layer and choose the activation function."""
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        # Only "gelu" is recognised by name; every other value uses relu.
        if config.hidden_act == "gelu":
            chosen_activation = F.gelu
        else:
            chosen_activation = F.relu
        self.act = chosen_activation

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Return ``act(dense(hidden_states))``."""
        expanded = self.dense.forward(hidden_states)
        return self.act(expanded)


class BertOutput(Module):
    """Projects from ``intermediate_size`` back to ``hidden_size``.

    The projection is added to a residual input and layer-normalised.
    """

    def __init__(self, config: Config) -> None:
        """Create the intermediate_size -> hidden_size Linear and a LayerNorm."""
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        projected = self.dense.forward(hidden_states)
        with_residual = projected + input_tensor
        return self.LayerNorm.forward(with_residual)


class BertLayer(Module):
    """One encoder layer: attention, intermediate expansion, output block."""

    def __init__(self, config: Config) -> None:
        """Instantiate the three sub-modules of the layer."""
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Run the layer; the attention output is the feed-forward residual."""
        attended = self.attention.forward(hidden_states)
        # TODO: Support cross-attention?
        # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        # TODO: Support something similar to `apply_chunking_to_forward`?
        expanded = self.intermediate.forward(attended)
        return self.output.forward(expanded, attended)


class BertEncoder(Module):
    """A stack of ``config.num_hidden_layers`` BertLayer modules."""

    def __init__(self, config: Config) -> None:
        """Create the layer list and fill it with fresh BertLayer instances."""
        super().__init__()
        self.layer = ModuleList()
        remaining = config.num_hidden_layers
        while remaining > 0:
            self.layer.append(BertLayer(config))
            remaining -= 1

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Feed ``hidden_states`` through every layer in order."""
        current = hidden_states
        for bert_layer in self.layer:
            current = bert_layer.forward(current)
        return current


class BertEmbeddings(Module):
    """Sum of word, token-type and position embeddings, layer-normalised."""

    def __init__(self, config: Config) -> None:
        """Create the three embedding tables, the LayerNorm and ``position_ids``."""
        super().__init__()
        width = config.hidden_size
        max_positions = config.max_position_embeddings
        self.word_embeddings = Embedding(config.vocab_size, width)
        self.position_embeddings = Embedding(max_positions, width)
        self.token_type_embeddings = Embedding(config.type_vocab_size, width)
        self.LayerNorm = LayerNorm(width, eps=config.layer_norm_eps)
        # A (1, max_position_embeddings) tensor holding 0..max-1.  forward()
        # below does not read it.
        # NOTE(review): presumably kept so the attribute exists when loading
        # checkpoints — confirm before removing.
        every_position = candle.Tensor(list(range(max_positions)))
        self.position_ids = every_position.reshape((1, max_positions))

    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
        """Embed ``input_ids`` and ``token_type_ids`` and normalise the sum.

        ``input_ids`` must have exactly two axes; the second one is used as
        the sequence length for the position ids.
        """
        # Two-element unpacking enforces a rank-2 input.
        _batch_size, seq_len = input_ids.shape
        word_part = self.word_embeddings.forward(input_ids)
        type_part = self.token_type_embeddings.forward(token_type_ids)
        combined: Tensor = word_part + type_part

        # Position ids 0..seq_len-1, matched to the dtype/device of input_ids.
        positions = (
            Tensor(list(range(seq_len)))
            .to_dtype(input_ids.dtype)
            .to_device(input_ids.device)
        )
        position_part = self.position_embeddings.forward(positions)

        combined = combined.broadcast_add(position_part)
        return self.LayerNorm(combined)


class BertPooler(Module):
    """Dense layer plus tanh applied to index 0 along axis 1 of the input."""

    def __init__(self, config: Config) -> None:
        """Create the hidden_size -> hidden_size Linear and bind ``F.tanh``."""
        super().__init__()
        width = config.hidden_size
        self.dense = Linear(width, width)
        self.activation = F.tanh

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        # assumes axis 1 is the sequence axis, so this is the first token — TODO confirm
        leading = hidden_states[:, 0]
        return self.activation(self.dense.forward(leading))


# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
class BertModel(Module):
    """Embeddings, encoder stack and an optional pooler wired together."""

    def __init__(self, config: Config, add_pooling_layer=True) -> None:
        """Build the sub-modules.

        When ``add_pooling_layer`` is falsy, ``self.pooler`` is ``None`` and
        ``forward`` returns ``None`` as its second element.
        """
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        if add_pooling_layer:
            self.pooler = BertPooler(config)
        else:
            self.pooler = None

    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Return ``(encoder_output, pooled_output_or_None)``."""
        embedded = self.embeddings.forward(input_ids, token_type_ids)
        encoder_out = self.encoder.forward(embedded)
        if self.pooler is None:
            return encoder_out, None
        return encoder_out, self.pooler(encoder_out)