Diffstat (limited to 'candle-pyo3/py_src/candle/models')
-rw-r--r--  candle-pyo3/py_src/candle/models/bert.py   194
-rw-r--r--  candle-pyo3/py_src/candle/models/llama.py  150
2 files changed, 344 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
new file mode 100644
index 00000000..0a773f93
--- /dev/null
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -0,0 +1,194 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
+from candle import Tensor
+import candle
+import candle.functional as F
+
+
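+# Defaults follow the standard bert-base-uncased configuration.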
+@dataclass
+class Config:
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_hidden_layers: int = 12
+ num_attention_heads: int = 12
+ intermediate_size: int = 3072
+ hidden_act: str = "gelu"
+ hidden_dropout_prob: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ layer_norm_eps: float = 1e-12
+ pad_token_id: int = 0
+ position_embedding_type: str = "absolute"
+ use_cache: bool = True
+ classifier_dropout: Optional[float] = None
+ model_type: Optional[str] = "bert"
+
+
+class BertSelfAttention(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+ all_head_size = int(config.num_attention_heads * self.attention_head_size)
+ hidden_size = config.hidden_size
+ self.query = Linear(hidden_size, all_head_size)
+ self.key = Linear(hidden_size, all_head_size)
+ self.value = Linear(hidden_size, all_head_size)
+
+ def transpose_for_scores(self, x: Tensor) -> Tensor:
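+ # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_dim)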
+ new_x_shape = x.shape[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.reshape(new_x_shape).transpose(1, 2)
+ return x.contiguous()
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ query = self.query.forward(hidden_states)
+ key = self.key.forward(hidden_states)
+ value = self.value.forward(hidden_states)
+
+ query = self.transpose_for_scores(query)
+ key = self.transpose_for_scores(key)
+ value = self.transpose_for_scores(value)
+
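+ # Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V.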
+ attention_scores = query.matmul(key.t())
+ attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+ attention_probs = F.softmax(attention_scores, dim=-1)
+
+ context_layer = attention_probs.matmul(value)
+ context_layer = context_layer.transpose(1, 2).contiguous()
+ context_layer = context_layer.flatten_from(-2)
+ return context_layer
+
+
+class BertSelfOutput(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
+ hidden_states = self.dense.forward(hidden_states)
+ return self.LayerNorm.forward(hidden_states + input_tensor)
+
+
+class BertAttention(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ self_outputs = self.self.forward(hidden_states)
+ attention_output = self.output.forward(self_outputs, hidden_states)
+ return attention_output
+
+
+class BertIntermediate(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
+ self.act = F.gelu if config.hidden_act == "gelu" else F.relu
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ hidden_states = self.dense.forward(hidden_states)
+ return self.act(hidden_states)
+
+
+class BertOutput(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
+ hidden_states = self.dense.forward(hidden_states)
+ return self.LayerNorm.forward(hidden_states + input_tensor)
+
+
+class BertLayer(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.attention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ attention_output = self.attention.forward(hidden_states)
+ # TODO: Support cross-attention?
+ # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+ # TODO: Support something similar to `apply_chunking_to_forward`?
+ intermediate_output = self.intermediate.forward(attention_output)
+ layer_output = self.output.forward(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.layer = ModuleList()
+ for _ in range(config.num_hidden_layers):
+ self.layer.append(BertLayer(config))
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ for layer in self.layer:
+ hidden_states = layer.forward(hidden_states)
+ return hidden_states
+
+
+class BertEmbeddings(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
+ self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape(
+ (1, config.max_position_embeddings)
+ )
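+ # Note: forward() recomputes position ids for the actual sequence length, so this buffer is not used below.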
+
+ def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
+ (_batch_size, seq_len) = input_ids.shape
+ input_embeddings = self.word_embeddings.forward(input_ids)
+ token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)
+ embeddings: Tensor = input_embeddings + token_type_embeddings
+
+ position_ids = list(range(seq_len))
+ position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device)
+
+ embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids))
+ embeddings = self.LayerNorm(embeddings)
+ return embeddings
+
+
+class BertPooler(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.hidden_size, config.hidden_size)
+ self.activation = F.tanh
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
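+ # Pool by taking the hidden state of the first ([CLS]) token.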
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense.forward(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+class BertModel(Module):
+ def __init__(self, config: Config, add_pooling_layer: bool = True) -> None:
+ super().__init__()
+ self.config = config
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = BertEncoder(config)
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+ embeddings = self.embeddings.forward(input_ids, token_type_ids)
+ encoder_out = self.encoder.forward(embeddings)
+ pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
+ return encoder_out, pooled_output
diff --git a/candle-pyo3/py_src/candle/models/llama.py b/candle-pyo3/py_src/candle/models/llama.py
new file mode 100644
index 00000000..fd9b30af
--- /dev/null
+++ b/candle-pyo3/py_src/candle/models/llama.py
@@ -0,0 +1,150 @@
+import candle
+from typing import Dict, Tuple, Any
+from candle import Tensor, QTensor, utils, nn
+from candle.nn import Module, ModuleList
+
+
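+# Return on_false with the scalar on_true substituted wherever mask is non-zero.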
+def masked_fill(on_false: Tensor, mask: Tensor, on_true: float) -> Tensor:
+ shape = mask.shape
+ on_true = candle.tensor(on_true).broadcast_as(shape)
+ return mask.where_cond(on_true, on_false)
+
+
+def precompute_freqs_cis(hparams: Dict[str, Any], freq_base: float, max_seq_len: int):
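+ # Rotary embedding tables: theta_k = 1 / freq_base**(2k / head_dim), paired with every position up to max_seq_len.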
+ head_dim = hparams["n_embd"] // hparams["n_head"]
+ theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
+ theta = candle.tensor(theta)
+ idx_theta = [float(i) for i in range(max_seq_len)]
+ idx_theta = candle.tensor(idx_theta).reshape((max_seq_len, 1))
+ m = idx_theta.matmul(theta.unsqueeze(0))
+ return (m.cos(), m.sin())
+
+
+class RmsNorm(Module):
+ def __init__(self, qtensor: QTensor):
+ super().__init__()
+ self.weight = qtensor.dequantize()
+
+ def forward(self, x: Tensor) -> Tensor:
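+ # RMSNorm: x / sqrt(mean(x^2) + 1e-5), scaled by the dequantized weight.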
+ b_size, seq_len, hidden_size = x.shape
+ norm_x = x.sqr().sum_keepdim(2) / hidden_size
+ x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
+ return x_normed.broadcast_mul(self.weight)
+
+
+class QuantizedLayer(Module):
+ def __init__(
+ self,
+ layer_idx: int,
+ hparams: Dict[str, Any],
+ all_tensors: Dict[str, QTensor],
+ cos_sin: Tuple[Tensor, Tensor],
+ ):
+ super().__init__()
+ p = f"layers.{layer_idx}"
+ self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
+ self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
+ self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
+ self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
+ self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
+ self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
+ self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
+ self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
+ self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
+
+ self.n_head = hparams["n_head"]
+ self.n_kv_head = self.n_head
+ self.head_dim = hparams["n_embd"] // self.n_head
+
+ self.kv_cache = None
+ self.cos = cos_sin[0]
+ self.sin = cos_sin[1]
+ self._non_persistent_buffers_set.add("cos")
+ self._non_persistent_buffers_set.add("sin")
+
+ def forward(self, x: Tensor, mask: Tensor, index_pos: int) -> Tensor:
+ residual = x
+ x = self.attn_norm(x)
+ attn = self.forward_attn(x, mask, index_pos)
+ x = attn + residual
+
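+ # Pre-norm feed-forward with SwiGLU: w2(silu(w1 x) * (w3 x)).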
+ residual = x
+ x = self.ffn_norm(x)
+ w1 = self.ffw1.matmul_t(x)
+ w3 = self.ffw3.matmul_t(x)
+ mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
+
+ return mlp + residual
+
+ def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int):
+ b_size, seq_len, n_embd = x.shape
+ q = self.attention_wq.matmul_t(x)
+ k = self.attention_wk.matmul_t(x)
+ v = self.attention_wv.matmul_t(x)
+
+ q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
+ k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+ v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+
+ q = self.apply_rotary_emb(q, index_pos)
+ k = self.apply_rotary_emb(k, index_pos)
+
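+ # Append the new keys/values to the cache so later decoding steps attend over all previous positions.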
+ if self.kv_cache is not None and index_pos > 0:
+ prev_k, prev_v = self.kv_cache
+ k = candle.cat([prev_k, k], 2).contiguous()
+ v = candle.cat([prev_v, v], 2).contiguous()
+
+ self.kv_cache = (k, v)
+
+ # TODO: maybe repeat k/v here if we start supporting MQA.
+
+ att = q.matmul(k.t()) / self.head_dim**0.5
+ mask = mask.broadcast_as(att.shape)
+ att = masked_fill(att, mask, float("-inf"))
+ att = nn.softmax(att, -1)
+ y = att.matmul(v.contiguous())
+ y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
+ return self.attention_wo.matmul_t(y)
+
+ def apply_rotary_emb(self, x: Tensor, index_pos: int):
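+ # Rotate consecutive channel pairs by the cached cos/sin values for positions [index_pos, index_pos + seq_len).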
+ b_size, n_head, seq_len, n_embd = x.shape
+ cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
+ sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
+ x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2))
+ x0 = x.narrow(-1, 0, 1)
+ x1 = x.narrow(-1, 1, 1)
+ y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
+ y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
+ rope = candle.cat([y0, y1], -1)
+ return rope.flatten_from(-2)
+
+
+class QuantizedLlama(Module):
+ def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]):
+ super().__init__()
+ self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
+ self.norm = RmsNorm(all_tensors["norm.weight"])
+ self.output = all_tensors["output.weight"]
+ self.layers = ModuleList()
+ rope_freq = hparams.get("rope_freq", 10000.0)
+ cos_sin = precompute_freqs_cis(hparams, rope_freq, hparams["context_length"])
+ for layer_idx in range(hparams["n_layer"]):
+ layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
+ self.layers.append(layer)
+
+ def forward(self, token: Tensor, index_pos: int) -> Tensor:
+ b_size, seq_len = token.shape
+ vocab_size, hidden_size = self.tok_embeddings.shape
+ token = token.reshape((b_size * seq_len,))
+ x = self.tok_embeddings.index_select(token, 0)
+ x = x.reshape((b_size, seq_len, hidden_size))
+
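+ # Causal mask: 1 marks future positions (key index > query index), later replaced with -inf.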
+ mask = [int(j > i) for i in range(seq_len) for j in range(seq_len)]
+ mask = candle.tensor(mask).reshape((seq_len, seq_len))
+
+ for layer in self.layers:
+ x = layer(x, mask, index_pos)
+ x = self.norm(x)
+ x = x.narrow(1, -1, 1).squeeze(1)
+ x = self.output.matmul_t(x)
+ return x