Diffstat (limited to 'candle-pyo3/py_src/candle/models')
-rw-r--r-- | candle-pyo3/py_src/candle/models/bert.py  | 194
-rw-r--r-- | candle-pyo3/py_src/candle/models/llama.py | 150
2 files changed, 344 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
new file mode 100644
index 00000000..0a773f93
--- /dev/null
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -0,0 +1,194 @@
+from dataclasses import dataclass
+from typing import Optional
+from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
+from candle import Tensor
+import candle
+import candle.functional as F
+from typing import Tuple, Optional
+
+
+@dataclass
+class Config:
+    vocab_size: int = 30522
+    hidden_size: int = 768
+    num_hidden_layers: int = 12
+    num_attention_heads: int = 12
+    intermediate_size: int = 3072
+    hidden_act: str = "gelu"
+    hidden_dropout_prob: float = 0.1
+    max_position_embeddings: int = 512
+    type_vocab_size: int = 2
+    initializer_range: float = 0.02
+    layer_norm_eps: float = 1e-12
+    pad_token_id: int = 0
+    position_embedding_type: str = "absolute"
+    use_cache: bool = True
+    classifier_dropout: Optional[float] = None
+    model_type: Optional[str] = "bert"
+
+
+class BertSelfAttention(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+        all_head_size = int(config.num_attention_heads * self.attention_head_size)
+        hidden_size = config.hidden_size
+        self.query = Linear(hidden_size, all_head_size)
+        self.key = Linear(hidden_size, all_head_size)
+        self.value = Linear(hidden_size, all_head_size)
+
+    def transpose_for_scores(self, x: Tensor) -> Tensor:
+        new_x_shape = x.shape[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
+        x = x.reshape(new_x_shape).transpose(1, 2)
+        return x.contiguous()
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        query = self.query.forward(hidden_states)
+        key = self.key.forward(hidden_states)
+        value = self.value.forward(hidden_states)
+
+        query = self.transpose_for_scores(query)
+        key = self.transpose_for_scores(key)
+        value = self.transpose_for_scores(value)
+
+        attention_scores = query.matmul(key.t())
+        attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+        attention_probs = F.softmax(attention_scores, dim=-1)
+
+        context_layer = attention_probs.matmul(value)
+        context_layer = context_layer.transpose(1, 2).contiguous()
+        context_layer = context_layer.flatten_from(-2)
+        return context_layer
+
+
+class BertSelfOutput(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.dense = Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
+        hidden_states = self.dense.forward(hidden_states)
+        return self.LayerNorm.forward(hidden_states + input_tensor)
+
+
+class BertAttention(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.self = BertSelfAttention(config)
+        self.output = BertSelfOutput(config)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        self_outputs = self.self.forward(hidden_states)
+        attention_output = self.output.forward(self_outputs, hidden_states)
+        return attention_output
+
+
+class BertIntermediate(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.dense = Linear(config.hidden_size, config.intermediate_size)
+        self.act = F.gelu if config.hidden_act == "gelu" else F.relu
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        hidden_states = self.dense.forward(hidden_states)
+        return self.act(hidden_states)
+
+
+class BertOutput(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.dense = Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
+        hidden_states = self.dense.forward(hidden_states)
+        return self.LayerNorm.forward(hidden_states + input_tensor)
+
+
+class BertLayer(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.attention = BertAttention(config)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        attention_output = self.attention.forward(hidden_states)
+        # TODO: Support cross-attention?
+        # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        # TODO: Support something similar to `apply_chunking_to_forward`?
+        intermediate_output = self.intermediate.forward(attention_output)
+        layer_output = self.output.forward(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.layer = ModuleList()
+        for _ in range(config.num_hidden_layers):
+            self.layer.append(BertLayer(config))
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        for l in self.layer:
+            hidden_states = l.forward(hidden_states)
+        return hidden_states
+
+
+class BertEmbeddings(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
+        self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size)
+        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape(
+            (1, config.max_position_embeddings)
+        )
+
+    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
+        (_batch_size, seq_len) = input_ids.shape
+        input_embeddings = self.word_embeddings.forward(input_ids)
+        token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)
+        embeddings: Tensor = input_embeddings + token_type_embeddings
+
+        position_ids = list(range(seq_len))
+        position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device)
+
+        embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids))
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+
+class BertPooler(Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.dense = Linear(config.hidden_size, config.hidden_size)
+        self.activation = F.tanh
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense.forward(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+class BertModel(Module):
+    def __init__(self, config: Config, add_pooling_layer=True) -> None:
+        super().__init__()
+        self.config = config
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = BertEncoder(config)
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+        embeddings = self.embeddings.forward(input_ids, token_type_ids)
+        encoder_out = self.encoder.forward(embeddings)
+        pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
+        return encoder_out, pooled_output
diff --git a/candle-pyo3/py_src/candle/models/llama.py b/candle-pyo3/py_src/candle/models/llama.py
new file mode 100644
index 00000000..fd9b30af
--- /dev/null
+++ b/candle-pyo3/py_src/candle/models/llama.py
@@ -0,0 +1,150 @@
+import candle
+from typing import Dict, Tuple, Any
+from candle import Tensor, QTensor, utils, nn
+from candle.nn import Module, ModuleList
+
+
+def masked_fill(on_false: Tensor, mask: Tensor, on_true: Tensor):
+    shape = mask.shape
+    on_true = candle.tensor(on_true).broadcast_as(shape)
+    return mask.where_cond(on_true, on_false)
+
+
+def precompute_freqs_cis(hparams: Dict[str, Any], freq_base: float, max_seq_len: int):
+    head_dim = hparams["n_embd"] // hparams["n_head"]
+    theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
+    theta = candle.tensor(theta)
+    idx_theta = [float(i) for i in range(max_seq_len)]
+    idx_theta = candle.tensor(idx_theta).reshape((max_seq_len, 1))
+    m = idx_theta.matmul(theta.unsqueeze(0))
+    return (m.cos(), m.sin())
+
+
+class RmsNorm(Module):
+    def __init__(self, qtensor: QTensor):
+        super().__init__()
+        self.weight = qtensor.dequantize()
+
+    def forward(self, x: Tensor) -> Tensor:
+        b_size, seq_len, hidden_size = x.shape
+        norm_x = x.sqr().sum_keepdim(2) / hidden_size
+        x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
+        return x_normed.broadcast_mul(self.weight)
+
+
+class QuantizedLayer(Module):
+    def __init__(
+        self,
+        layer_idx: int,
+        hparams: Dict[str, Any],
+        all_tensors: Dict[str, QTensor],
+        cos_sin: Tuple[Tensor, Tensor],
+    ):
+        super().__init__()
+        p = f"layers.{layer_idx}"
+        self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
+        self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
+        self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
+        self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
+        self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
+        self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
+        self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
+        self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
+        self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
+
+        self.n_head = hparams["n_head"]
+        self.n_kv_head = self.n_head
+        self.head_dim = hparams["n_embd"] // self.n_head
+
+        self.kv_cache = None
+        self.cos = cos_sin[0]
+        self.sin = cos_sin[1]
+        self._non_persistent_buffers_set.add("cos")
+        self._non_persistent_buffers_set.add("sin")
+
+    def forward(self, x: Tensor, mask: Tensor, index_pos: int) -> Tensor:
+        residual = x
+        x = self.attn_norm(x)
+        attn = self.forward_attn(x, mask, index_pos)
+        x = attn + residual
+
+        residual = x
+        x = self.ffn_norm(x)
+        w1 = self.ffw1.matmul_t(x)
+        w3 = self.ffw3.matmul_t(x)
+        mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
+
+        return mlp + residual
+
+    def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int):
+        b_size, seq_len, n_embd = x.shape
+        q = self.attention_wq.matmul_t(x)
+        k = self.attention_wk.matmul_t(x)
+        v = self.attention_wv.matmul_t(x)
+
+        q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
+        k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+        v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)

+        q = self.apply_rotary_emb(q, index_pos)
+        k = self.apply_rotary_emb(k, index_pos)
+
+        if self.kv_cache is not None and index_pos > 0:
+            prev_k, prev_v = self.kv_cache
+            k = candle.cat([prev_k, k], 2).contiguous()
+            v = candle.cat([prev_v, v], 2).contiguous()
+
+        self.kv_cache = (k, v)
+
+        # TODO: maybe repeat k/v here if we start supporting MQA.
+
+        att = q.matmul(k.t()) / self.head_dim**0.5
+        mask = mask.broadcast_as(att.shape)
+        att = masked_fill(att, mask, float("-inf"))
+        att = nn.softmax(att, -1)
+        y = att.matmul(v.contiguous())
+        y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
+        return self.attention_wo.matmul_t(y)
+
+    def apply_rotary_emb(self, x: Tensor, index_pos: int):
+        b_size, n_head, seq_len, n_embd = x.shape
+        cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
+        sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
+        x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2))
+        x0 = x.narrow(-1, 0, 1)
+        x1 = x.narrow(-1, 1, 1)
+        y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
+        y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
+        rope = candle.cat([y0, y1], -1)
+        return rope.flatten_from(-2)
+
+
+class QuantizedLlama(Module):
+    def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]):
+        super().__init__()
+        self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
+        self.norm = RmsNorm(all_tensors["norm.weight"])
+        self.output = all_tensors["output.weight"]
+        self.layers = ModuleList()
+        rope_freq = hparams.get("rope_freq", 10000.0)
+        cos_sin = precompute_freqs_cis(hparams, rope_freq, hparams["context_length"])
+        for layer_idx in range(hparams["n_layer"]):
+            layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
+            self.layers.append(layer)
+
+    def forward(self, token: Tensor, index_pos: int) -> Tensor:
+        b_size, seq_len = token.shape
+        vocab_size, hidden_size = self.tok_embeddings.shape
+        token = token.reshape((b_size * seq_len,))
+        x = self.tok_embeddings.index_select(token, 0)
+        x = x.reshape((b_size, seq_len, hidden_size))
+
+        mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
+        mask = candle.tensor(mask).reshape((seq_len, seq_len))
+
+        for layer in self.layers:
+            x = layer(x, mask, index_pos)
+        x = self.norm(x)
+        x = x.narrow(1, -1, 1).squeeze(1)
+        x = self.output.matmul_t(x)
+        return x
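For context, a minimal usage sketch for the two new modules, not part of the commit above. It assumes the candle Python bindings used in the files, that integer dtypes are exposed as module attributes such as candle.u32, and that BertModel weights stay at their random initialization unless loaded separately; for QuantizedLlama, hparams and all_tensors are assumed to come from a GGML/GGUF checkpoint loaded elsewhere, so that part is left commented out.

# Hypothetical usage sketch; dtype attribute and example token ids are assumptions.
import candle
from candle.models.bert import Config, BertModel
from candle.models.llama import QuantizedLlama

# BERT: default config, randomly initialized weights, dummy token ids.
config = Config()
model = BertModel(config)
input_ids = candle.Tensor([101, 7592, 2088, 102]).to_dtype(candle.u32).reshape((1, 4))
token_type_ids = candle.Tensor([0, 0, 0, 0]).to_dtype(candle.u32).reshape((1, 4))
sequence_output, pooled_output = model.forward(input_ids, token_type_ids)
print(sequence_output.shape)  # (1, 4, config.hidden_size)

# Quantized Llama: requires `hparams` (n_embd, n_head, n_layer, context_length, ...)
# and `all_tensors` (name -> QTensor) loaded from a quantized checkpoint elsewhere.
# llama = QuantizedLlama(hparams, all_tensors)
# token = candle.Tensor([1]).to_dtype(candle.u32).reshape((1, 1))
# logits = llama.forward(token, index_pos=0)  # logits for the last position only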