summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-06 20:01:07 +0200
committerGitHub <noreply@github.com>2023-10-06 19:01:07 +0100
commit904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle
parentb0442eff8a696d1faba10e23ba645eb11e385116 (diff)
downloadcandle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.gz
candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.bz2
candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.zip
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation
Diffstat (limited to 'candle-pyo3/py_src/candle')
-rw-r--r--candle-pyo3/py_src/candle/__init__.py29
-rw-r--r--candle-pyo3/py_src/candle/functional/__init__.py8
-rw-r--r--candle-pyo3/py_src/candle/functional/__init__.pyi (renamed from candle-pyo3/py_src/candle/nn/__init__.pyi)21
-rw-r--r--candle-pyo3/py_src/candle/models/bert.py194
-rw-r--r--candle-pyo3/py_src/candle/models/llama.py150
-rw-r--r--candle-pyo3/py_src/candle/nn/__init__.py10
-rw-r--r--candle-pyo3/py_src/candle/nn/container.py483
-rw-r--r--candle-pyo3/py_src/candle/nn/linear.py119
-rw-r--r--candle-pyo3/py_src/candle/nn/module.py702
-rw-r--r--candle-pyo3/py_src/candle/nn/normalization.py54
-rw-r--r--candle-pyo3/py_src/candle/nn/sparse.py39
-rw-r--r--candle-pyo3/py_src/candle/typing/__init__.py8
12 files changed, 1806 insertions, 11 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
index 951609cc..dc97b775 100644
--- a/candle-pyo3/py_src/candle/__init__.py
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -1,5 +1,30 @@
-from .candle import *
+import logging
+
+try:
+ from .candle import *
+except ImportError as e:
+ # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
+ logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+ import os
+ import platform
+
+ # Try to locate CUDA_PATH environment variable
+ cuda_path = os.environ.get("CUDA_PATH", None)
+ if cuda_path:
+ logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+ if platform.system() == "Windows":
+ cuda_path = os.path.join(cuda_path, "bin")
+ else:
+ cuda_path = os.path.join(cuda_path, "lib64")
+
+ logging.warning(f"Adding {cuda_path} to DLL search path...")
+ os.add_dll_directory(cuda_path)
+
+ try:
+ from .candle import *
+ except ImportError as inner_e:
+ raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
__doc__ = candle.__doc__
if hasattr(candle, "__all__"):
- __all__ = candle.__all__ \ No newline at end of file
+ __all__ = candle.__all__
diff --git a/candle-pyo3/py_src/candle/functional/__init__.py b/candle-pyo3/py_src/candle/functional/__init__.py
new file mode 100644
index 00000000..efb246f0
--- /dev/null
+++ b/candle-pyo3/py_src/candle/functional/__init__.py
@@ -0,0 +1,8 @@
+# Generated content DO NOT EDIT
+from .. import functional
+
+gelu = functional.gelu
+relu = functional.relu
+silu = functional.silu
+softmax = functional.softmax
+tanh = functional.tanh
diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index 01b30fce..a46b6137 100644
--- a/candle-pyo3/py_src/candle/nn/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -5,6 +5,20 @@ from candle.typing import _ArrayLike, Device
from candle import Tensor, DType, QTensor
@staticmethod
+def gelu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def relu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Rectified Linear Unit (ReLU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
def silu(tensor: Tensor) -> Tensor:
"""
Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
@@ -17,3 +31,10 @@ def softmax(tensor: Tensor, dim: int) -> Tensor:
Applies the Softmax function to a given tensor.#
"""
pass
+
+@staticmethod
+def tanh(tensor: Tensor) -> Tensor:
+ """
+ Applies the tanh function to a given tensor.
+ """
+ pass
diff --git a/candle-pyo3/py_src/candle/models/bert.py b/candle-pyo3/py_src/candle/models/bert.py
new file mode 100644
index 00000000..0a773f93
--- /dev/null
+++ b/candle-pyo3/py_src/candle/models/bert.py
@@ -0,0 +1,194 @@
+from dataclasses import dataclass
+from typing import Optional
+from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
+from candle import Tensor
+import candle
+import candle.functional as F
+from typing import Tuple, Optional
+
+
+@dataclass
+class Config:
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_hidden_layers: int = 12
+ num_attention_heads: int = 12
+ intermediate_size: int = 3072
+ hidden_act: str = "gelu"
+ hidden_dropout_prob: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ layer_norm_eps: float = 1e-12
+ pad_token_id: int = 0
+ position_embedding_type: str = "absolute"
+ use_cache: bool = True
+ classifier_dropout: Optional[float] = None
+ model_type: Optional[str] = "bert"
+
+
+class BertSelfAttention(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+ all_head_size = int(config.num_attention_heads * self.attention_head_size)
+ hidden_size = config.hidden_size
+ self.query = Linear(hidden_size, all_head_size)
+ self.key = Linear(hidden_size, all_head_size)
+ self.value = Linear(hidden_size, all_head_size)
+
+ def transpose_for_scores(self, x: Tensor) -> Tensor:
+ new_x_shape = x.shape[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.reshape(new_x_shape).transpose(1, 2)
+ return x.contiguous()
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ query = self.query.forward(hidden_states)
+ key = self.key.forward(hidden_states)
+ value = self.value.forward(hidden_states)
+
+ query = self.transpose_for_scores(query)
+ key = self.transpose_for_scores(key)
+ value = self.transpose_for_scores(value)
+
+ attention_scores = query.matmul(key.t())
+ attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
+ attention_probs = F.softmax(attention_scores, dim=-1)
+
+ context_layer = attention_probs.matmul(value)
+ context_layer = context_layer.transpose(1, 2).contiguous()
+ context_layer = context_layer.flatten_from(-2)
+ return context_layer
+
+
+class BertSelfOutput(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
+ hidden_states = self.dense.forward(hidden_states)
+ return self.LayerNorm.forward(hidden_states + input_tensor)
+
+
+class BertAttention(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ self_outputs = self.self.forward(hidden_states)
+ attention_output = self.output.forward(self_outputs, hidden_states)
+ return attention_output
+
+
+class BertIntermediate(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.hidden_size, config.intermediate_size)
+ self.act = F.gelu if config.hidden_act == "gelu" else F.relu
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ hidden_states = self.dense.forward(hidden_states)
+ return self.act(hidden_states)
+
+
+class BertOutput(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
+ hidden_states = self.dense.forward(hidden_states)
+ return self.LayerNorm.forward(hidden_states + input_tensor)
+
+
+class BertLayer(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.attention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ attention_output = self.attention.forward(hidden_states)
+ # TODO: Support cross-attention?
+ # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+ # TODO: Support something similar to `apply_chunking_to_forward`?
+ intermediate_output = self.intermediate.forward(attention_output)
+ layer_output = self.output.forward(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.layer = ModuleList()
+ for _ in range(config.num_hidden_layers):
+ self.layer.append(BertLayer(config))
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ for l in self.layer:
+ hidden_states = l.forward(hidden_states)
+ return hidden_states
+
+
+class BertEmbeddings(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
+ self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape(
+ (1, config.max_position_embeddings)
+ )
+
+ def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
+ (_batch_size, seq_len) = input_ids.shape
+ input_embeddings = self.word_embeddings.forward(input_ids)
+ token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)
+ embeddings: Tensor = input_embeddings + token_type_embeddings
+
+ position_ids = list(range(seq_len))
+ position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device)
+
+ embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids))
+ embeddings = self.LayerNorm(embeddings)
+ return embeddings
+
+
+class BertPooler(Module):
+ def __init__(self, config: Config) -> None:
+ super().__init__()
+ self.dense = Linear(config.hidden_size, config.hidden_size)
+ self.activation = F.tanh
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense.forward(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+class BertModel(Module):
+ def __init__(self, config: Config, add_pooling_layer=True) -> None:
+ super().__init__()
+ self.config = config
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = BertEncoder(config)
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+ embeddings = self.embeddings.forward(input_ids, token_type_ids)
+ encoder_out = self.encoder.forward(embeddings)
+ pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
+ return encoder_out, pooled_output
diff --git a/candle-pyo3/py_src/candle/models/llama.py b/candle-pyo3/py_src/candle/models/llama.py
new file mode 100644
index 00000000..fd9b30af
--- /dev/null
+++ b/candle-pyo3/py_src/candle/models/llama.py
@@ -0,0 +1,150 @@
+import candle
+from typing import Dict, Tuple, Any
+from candle import Tensor, QTensor, utils, nn
+from candle.nn import Module, ModuleList
+
+
+def masked_fill(on_false: Tensor, mask: Tensor, on_true: Tensor):
+ shape = mask.shape
+ on_true = candle.tensor(on_true).broadcast_as(shape)
+ return mask.where_cond(on_true, on_false)
+
+
+def precompute_freqs_cis(hparams: Dict[str, Any], freq_base: float, max_seq_len: int):
+ head_dim = hparams["n_embd"] // hparams["n_head"]
+ theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
+ theta = candle.tensor(theta)
+ idx_theta = [float(i) for i in range(max_seq_len)]
+ idx_theta = candle.tensor(idx_theta).reshape((max_seq_len, 1))
+ m = idx_theta.matmul(theta.unsqueeze(0))
+ return (m.cos(), m.sin())
+
+
+class RmsNorm(Module):
+ def __init__(self, qtensor: QTensor):
+ super().__init__()
+ self.weight = qtensor.dequantize()
+
+ def forward(self, x: Tensor) -> Tensor:
+ b_size, seq_len, hidden_size = x.shape
+ norm_x = x.sqr().sum_keepdim(2) / hidden_size
+ x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
+ return x_normed.broadcast_mul(self.weight)
+
+
+class QuantizedLayer(Module):
+ def __init__(
+ self,
+ layer_idx: int,
+ hparams: Dict[str, Any],
+ all_tensors: Dict[str, QTensor],
+ cos_sin: Tuple[Tensor, Tensor],
+ ):
+ super().__init__()
+ p = f"layers.{layer_idx}"
+ self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
+ self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
+ self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
+ self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
+ self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
+ self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
+ self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
+ self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
+ self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
+
+ self.n_head = hparams["n_head"]
+ self.n_kv_head = self.n_head
+ self.head_dim = hparams["n_embd"] // self.n_head
+
+ self.kv_cache = None
+ self.cos = cos_sin[0]
+ self.sin = cos_sin[1]
+ self._non_persistent_buffers_set.add("cos")
+ self._non_persistent_buffers_set.add("sin")
+
+ def forward(self, x: Tensor, mask: Tensor, index_pos: int) -> Tensor:
+ residual = x
+ x = self.attn_norm(x)
+ attn = self.forward_attn(x, mask, index_pos)
+ x = attn + residual
+
+ residual = x
+ x = self.ffn_norm(x)
+ w1 = self.ffw1.matmul_t(x)
+ w3 = self.ffw3.matmul_t(x)
+ mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
+
+ return mlp + residual
+
+ def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int):
+ b_size, seq_len, n_embd = x.shape
+ q = self.attention_wq.matmul_t(x)
+ k = self.attention_wk.matmul_t(x)
+ v = self.attention_wv.matmul_t(x)
+
+ q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
+ k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+ v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
+
+ q = self.apply_rotary_emb(q, index_pos)
+ k = self.apply_rotary_emb(k, index_pos)
+
+ if self.kv_cache is not None and index_pos > 0:
+ prev_k, prev_v = self.kv_cache
+ k = candle.cat([prev_k, k], 2).contiguous()
+ v = candle.cat([prev_v, v], 2).contiguous()
+
+ self.kv_cache = (k, v)
+
+ # TODO: maybe repeat k/v here if we start supporting MQA.
+
+ att = q.matmul(k.t()) / self.head_dim**0.5
+ mask = mask.broadcast_as(att.shape)
+ att = masked_fill(att, mask, float("-inf"))
+ att = nn.softmax(att, -1)
+ y = att.matmul(v.contiguous())
+ y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
+ return self.attention_wo.matmul_t(y)
+
+ def apply_rotary_emb(self, x: Tensor, index_pos: int):
+ b_size, n_head, seq_len, n_embd = x.shape
+ cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
+ sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
+ x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2))
+ x0 = x.narrow(-1, 0, 1)
+ x1 = x.narrow(-1, 1, 1)
+ y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
+ y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
+ rope = candle.cat([y0, y1], -1)
+ return rope.flatten_from(-2)
+
+
+class QuantizedLlama(Module):
+ def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]):
+ super().__init__()
+ self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
+ self.norm = RmsNorm(all_tensors["norm.weight"])
+ self.output = all_tensors["output.weight"]
+ self.layers = ModuleList()
+ rope_freq = hparams.get("rope_freq", 10000.0)
+ cos_sin = precompute_freqs_cis(hparams, rope_freq, hparams["context_length"])
+ for layer_idx in range(hparams["n_layer"]):
+ layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
+ self.layers.append(layer)
+
+ def forward(self, token: Tensor, index_pos: int) -> Tensor:
+ b_size, seq_len = token.shape
+ vocab_size, hidden_size = self.tok_embeddings.shape
+ token = token.reshape((b_size * seq_len,))
+ x = self.tok_embeddings.index_select(token, 0)
+ x = x.reshape((b_size, seq_len, hidden_size))
+
+ mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
+ mask = candle.tensor(mask).reshape((seq_len, seq_len))
+
+ for layer in self.layers:
+ x = layer(x, mask, index_pos)
+ x = self.norm(x)
+ x = x.narrow(1, -1, 1).squeeze(1)
+ x = self.output.matmul_t(x)
+ return x
diff --git a/candle-pyo3/py_src/candle/nn/__init__.py b/candle-pyo3/py_src/candle/nn/__init__.py
index b8c5cfb7..8da0e8aa 100644
--- a/candle-pyo3/py_src/candle/nn/__init__.py
+++ b/candle-pyo3/py_src/candle/nn/__init__.py
@@ -1,5 +1,5 @@
-# Generated content DO NOT EDIT
-from .. import nn
-
-silu = nn.silu
-softmax = nn.softmax
+from .module import Module
+from .container import Sequential, ModuleList, ModuleDict
+from .sparse import Embedding
+from .normalization import LayerNorm
+from .linear import Linear
diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py
new file mode 100644
index 00000000..15ed8dd2
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/container.py
@@ -0,0 +1,483 @@
+# see https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
+from .module import Module
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ Mapping,
+ Optional,
+ overload,
+ Tuple,
+ TypeVar,
+ Union,
+)
+from collections import OrderedDict, abc as container_abcs
+import operator
+from itertools import chain, islice
+
+__all__ = ["Sequential", "ModuleList", "ModuleDict"]
+
+T = TypeVar("T", bound=Module)
+
+
+def _addindent(s_: str, numSpaces: int):
+ s = s_.split("\n")
+ # don't do anything for single-line stuff
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(numSpaces * " ") + line for line in s]
+ s = "\n".join(s)
+ s = first + "\n" + s
+ return s
+
+
+class Sequential(Module):
+ r"""A sequential container.
+ Modules will be added to it in the order they are passed in the
+ constructor. Alternatively, an ``OrderedDict`` of modules can be
+ passed in. The ``forward()`` method of ``Sequential`` accepts any
+ input and forwards it to the first module it contains. It then
+ "chains" outputs to inputs sequentially for each subsequent module,
+ finally returning the output of the last module.
+
+ The value a ``Sequential`` provides over manually calling a sequence
+ of modules is that it allows treating the whole container as a
+ single module, such that performing a transformation on the
+ ``Sequential`` applies to each of the modules it stores (which are
+ each a registered submodule of the ``Sequential``).
+
+ What's the difference between a ``Sequential`` and a
+ :class:`candle.nn.ModuleList`? A ``ModuleList`` is exactly what it
+ sounds like--a list for storing ``Module`` s! On the other hand,
+ the layers in a ``Sequential`` are connected in a cascading way.
+ """
+
+ _modules: Dict[str, Module] # type: ignore[assignment]
+
+ @overload
+ def __init__(self, *args: Module) -> None:
+ ...
+
+ @overload
+ def __init__(self, arg: "OrderedDict[str, Module]") -> None:
+ ...
+
+ def __init__(self, *args):
+ super().__init__()
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
+ for key, module in args[0].items():
+ self.add_module(key, module)
+ else:
+ for idx, module in enumerate(args):
+ self.add_module(str(idx), module)
+
+ def _get_item_by_idx(self, iterator, idx) -> T:
+ """Get the idx-th item of the iterator"""
+ size = len(self)
+ idx = operator.index(idx)
+ if not -size <= idx < size:
+ raise IndexError("index {} is out of range".format(idx))
+ idx %= size
+ return next(islice(iterator, idx, None))
+
+ def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
+ if isinstance(idx, slice):
+ return self.__class__(OrderedDict(list(self._modules.items())[idx]))
+ else:
+ return self._get_item_by_idx(self._modules.values(), idx)
+
+ def __setitem__(self, idx: int, module: Module) -> None:
+ key: str = self._get_item_by_idx(self._modules.keys(), idx)
+ return setattr(self, key, module)
+
+ def __delitem__(self, idx: Union[slice, int]) -> None:
+ if isinstance(idx, slice):
+ for key in list(self._modules.keys())[idx]:
+ delattr(self, key)
+ else:
+ key = self._get_item_by_idx(self._modules.keys(), idx)
+ delattr(self, key)
+ # To preserve numbering
+ str_indices = [str(i) for i in range(len(self._modules))]
+ self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __add__(self, other) -> "Sequential":
+ if isinstance(other, Sequential):
+ ret = Sequential()
+ for layer in self:
+ ret.append(layer)
+ for layer in other:
+ ret.append(layer)
+ return ret
+ else:
+ raise ValueError(
+ "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other)))
+ )
+
+ def pop(self, key: Union[int, slice]) -> Module:
+ v = self[key]
+ del self[key]
+ return v
+
+ def __iadd__(self, other) -> "Sequential":
+ if isinstance(other, Sequential):
+ offset = len(self)
+ for i, module in enumerate(other):
+ self.add_module(str(i + offset), module)
+ return self
+ else:
+ raise ValueError(
+ "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other)))
+ )
+
+ def __mul__(self, other: int) -> "Sequential":
+ if not isinstance(other, int):
+ raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
+ elif other <= 0:
+ raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
+ else:
+ combined = Sequential()
+ offset = 0
+ for _ in range(other):
+ for module in self:
+ combined.add_module(str(offset), module)
+ offset += 1
+ return combined
+
+ def __rmul__(self, other: int) -> "Sequential":
+ return self.__mul__(other)
+
+ def __imul__(self, other: int) -> "Sequential":
+ if not isinstance(other, int):
+ raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
+ elif other <= 0:
+ raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
+ else:
+ len_original = len(self)
+ offset = len(self)
+ for _ in range(other - 1):
+ for i in range(len_original):
+ self.add_module(str(i + offset), self._modules[str(i)])
+ offset += len_original
+ return self
+
+ def __dir__(self):
+ keys = super().__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+ def __iter__(self) -> Iterator[Module]:
+ return iter(self._modules.values())
+
+ # NB: We can't really type check this function as the type of input
+ # may change dynamically (as is tested in
+ # TestScript.test_sequential_intermediary_types). Cannot annotate
+ # with Any as TorchScript expects a more precise type
+ def forward(self, input):
+ for module in self:
+ input = module(input)
+ return input
+
+ def append(self, module: Module) -> "Sequential":
+ r"""Appends a given module to the end.
+
+ Args:
+ module (nn.Module): module to append
+ """
+ self.add_module(str(len(self)), module)
+ return self
+
+ def insert(self, index: int, module: Module) -> "Sequential":
+ if not isinstance(module, Module):
+ raise AssertionError("module should be of type: {}".format(Module))
+ n = len(self._modules)
+ if not (-n <= index <= n):
+ raise IndexError("Index out of range: {}".format(index))
+ if index < 0:
+ index += n
+ for i in range(n, index, -1):
+ self._modules[str(i)] = self._modules[str(i - 1)]
+ self._modules[str(index)] = module
+ return self
+
+ def extend(self, sequential) -> "Sequential":
+ for layer in sequential:
+ self.append(layer)
+ return self
+
+
+class ModuleList(Module):
+ r"""Holds submodules in a list.
+
+ :class:`~candle.nn.ModuleList` can be indexed like a regular Python list, but
+ modules it contains are properly registered, and will be visible by all
+ :class:`~candle.nn.Module` methods.
+
+ Args:
+ modules (iterable, optional): an iterable of modules to add
+
+ Example::
+
+ class MyModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
+
+ def forward(self, x):
+ # ModuleList can act as an iterable, or be indexed using ints
+ for i, l in enumerate(self.linears):
+ x = self.linears[i // 2](x) + l(x)
+ return x
+ """
+
+ _modules: Dict[str, Module] # type: ignore[assignment]
+
+ def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
+ super().__init__()
+ if modules is not None:
+ self += modules
+
+ def _get_abs_string_index(self, idx):
+ """Get the absolute index for the list of modules"""
+ idx = operator.index(idx)
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError("index {} is out of range".format(idx))
+ if idx < 0:
+ idx += len(self)
+ return str(idx)
+
+ def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
+ if isinstance(idx, slice):
+ return self.__class__(list(self._modules.values())[idx])
+ else:
+ return self._modules[self._get_abs_string_index(idx)]
+
+ def __setitem__(self, idx: int, module: Module) -> None:
+ idx = self._get_abs_string_index(idx)
+ return setattr(self, str(idx), module)
+
+ def __delitem__(self, idx: Union[int, slice]) -> None:
+ if isinstance(idx, slice):
+ for k in range(len(self._modules))[idx]:
+ delattr(self, str(k))
+ else:
+ delattr(self, self._get_abs_string_index(idx))
+ # To preserve numbering, self._modules is being reconstructed with modules after deletion
+ str_indices = [str(i) for i in range(len(self._modules))]
+ self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[Module]:
+ return iter(self._modules.values())
+
+ def __iadd__(self, modules: Iterable[Module]) -> "ModuleList":
+ return self.extend(modules)
+
+ def __add__(self, other: Iterable[Module]) -> "ModuleList":
+ combined = ModuleList()
+ for i, module in enumerate(chain(self, other)):
+ combined.add_module(str(i), module)
+ return combined
+
+ def __repr__(self):
+ """A custom repr for ModuleList that compresses repeated module representations"""
+ list_of_reprs = [repr(item) for item in self]
+ if len(list_of_reprs) == 0:
+ return self._get_name() + "()"
+
+ start_end_indices = [[0, 0]]
+ repeated_blocks = [list_of_reprs[0]]
+ for i, r in enumerate(list_of_reprs[1:], 1):
+ if r == repeated_blocks[-1]:
+ start_end_indices[-1][1] += 1
+ continue
+
+ start_end_indices.append([i, i])
+ repeated_blocks.append(r)
+
+ lines = []
+ main_str = self._get_name() + "("
+ for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+ local_repr = f"({start_id}): {b}" # default repr
+
+ if start_id != end_id:
+ n = end_id - start_id + 1
+ local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+ local_repr = _addindent(local_repr, 2)
+ lines.append(local_repr)
+
+ main_str += "\n " + "\n ".join(lines) + "\n"
+ main_str += ")"
+ return main_str
+
+ def __dir__(self):
+ keys = super().__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+ def insert(self, index: int, module: Module) -> None:
+ r"""Insert a given module before a given index in the list.
+
+ Args:
+ index (int): index to insert.
+ module (nn.Module): module to insert
+ """
+ for i in range(len(self._modules), index, -1):
+ self._modules[str(i)] = self._modules[str(i - 1)]
+ self._modules[str(index)] = module
+
+ def append(self, module: Module) -> "ModuleList":
+ r"""Appends a given module to the end of the list.
+
+ Args:
+ module (nn.Module): module to append
+ """
+ self.add_module(str(len(self)), module)
+ return self
+
+ def pop(self, key: Union[int, slice]) -> Module:
+ v = self[key]
+ del self[key]
+ return v
+
+ def extend(self, modules: Iterable[Module]) -> "ModuleList":
+ r"""Appends modules from a Python iterable to the end of the list.
+
+ Args:
+ modules (iterable): iterable of modules to append
+ """
+ if not isinstance(modules, container_abcs.Iterable):
+ raise TypeError(
+ "ModuleList.extend should be called with an " "iterable, but got " + type(modules).__name__
+ )
+ offset = len(self)
+ for i, module in enumerate(modules):
+ self.add_module(str(offset + i), module)
+ return self
+
+ # remove forward alltogether to fallback on Module's _forward_unimplemented
+
+
+class ModuleDict(Module):
+ r"""Holds submodules in a dictionary.
+
+ :class:`~candle.nn.ModuleDict` can be indexed like a regular Python dictionary,
+ but modules it contains are properly registered, and will be visible by all
+ :class:`~candle.nn.Module` methods.
+
+ :class:`~candle.nn.ModuleDict` is an **ordered** dictionary that respects
+
+ * the order of insertion, and
+
+ * in :meth:`~candle.nn.ModuleDict.update`, the order of the merged
+ ``OrderedDict``, ``dict`` (started from Python 3.6) or another
+ :class:`~candle.nn.ModuleDict` (the argument to
+ :meth:`~candle.nn.ModuleDict.update`).
+
+ Note that :meth:`~candle.nn.ModuleDict.update` with other unordered mapping
+ types (e.g., Python's plain ``dict`` before Python version 3.6) does not
+ preserve the order of the merged mapping.
+
+ Args:
+ modules (iterable, optional): a mapping (dictionary) of (string: module)
+ or an iterable of key-value pairs of type (string, module)
+ """
+
+ _modules: Dict[str, Module] # type: ignore[assignment]
+
+ def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
+ super().__init__()
+ if modules is not None:
+ self.update(modules)
+
+ def __getitem__(self, key: str) -> Module:
+ return self._modules[key]
+
+ def __setitem__(self, key: str, module: Module) -> None:
+ self.add_module(key, module)
+
+ def __delitem__(self, key: str) -> None:
+ del self._modules[key]
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self._modules)
+
+ def __contains__(self, key: str) -> bool:
+ return key in self._modules
+
+ def clear(self) -> None:
+ """Remove all items from the ModuleDict."""
+ self._modules.clear()
+
+ def pop(self, key: str) -> Module:
+ r"""Remove key from the ModuleDict and return its module.
+
+ Args:
+ key (str): key to pop from the ModuleDict
+ """
+ v = self[key]
+ del self[key]
+ return v
+
+ def keys(self) -> Iterable[str]:
+ r"""Return an iterable of the ModuleDict keys."""
+ return self._modules.keys()
+
+ def items(self) -> Iterable[Tuple[str, Module]]:
+ r"""Return an iterable of the ModuleDict key/value pairs."""
+ return self._modules.items()
+
+ def values(self) -> Iterable[Module]:
+ r"""Return an iterable of the ModuleDict values."""
+ return self._modules.values()
+
+ def update(self, modules: Mapping[str, Module]) -> None:
+ r"""Update the :class:`~candle.nn.ModuleDict` with the key-value pairs from a
+ mapping or an iterable, overwriting existing keys.
+
+ .. note::
+ If :attr:`modules` is an ``OrderedDict``, a :class:`~candle.nn.ModuleDict`, or
+ an iterable of key-value pairs, the order of new elements in it is preserved.
+
+ Args:
+ modules (iterable): a mapping (dictionary) from string to :class:`~candle.nn.Module`,
+ or an iterable of key-value pairs of type (string, :class:`~candle.nn.Module`)
+ """
+ if not isinstance(modules, container_abcs.Iterable):
+ raise TypeError(
+ "ModuleDict.update should be called with an "
+ "iterable of key/value pairs, but got " + type(modules).__name__
+ )
+
+ if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
+ for key, module in modules.items():
+ self[key] = module
+ else:
+ # modules here can be a list with two items
+ for j, m in enumerate(modules):
+ if not isinstance(m, container_abcs.Iterable):
+ raise TypeError(
+ "ModuleDict update sequence element "
+ "#" + str(j) + " should be Iterable; is" + type(m).__name__
+ )
+ if not len(m) == 2:
+ raise ValueError(
+ "ModuleDict update sequence element "
+ "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
+ )
+ # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
+ # that's too cumbersome to type correctly with overloads, so we add an ignore here
+ self[m[0]] = m[1] # type: ignore[assignment]
+
+ # remove forward alltogether to fallback on Module's _forward_unimplemented
diff --git a/candle-pyo3/py_src/candle/nn/linear.py b/candle-pyo3/py_src/candle/nn/linear.py
new file mode 100644
index 00000000..d275eb1e
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/linear.py
@@ -0,0 +1,119 @@
+import math
+from typing import Any
+
+import candle
+from candle import Tensor
+from .module import Module
+
+# See https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py
+
+
+class Identity(Module):
+ r"""A placeholder identity operator that is argument-insensitive.
+
+ Args:
+ args: any argument (unused)
+ kwargs: any keyword argument (unused)
+
+ Shape:
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.
+
+ Examples::
+
+ >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+ >>> input = candle.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.shape)
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__()
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input
+
+
+class Linear(Module):
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+ Args:
+ in_features: size of each input sample
+ out_features: size of each output sample
+ bias: If set to ``False``, the layer will not learn an additive bias.
+ Default: ``True``
+
+ Shape:
+ - Input: :math:`(*, H_{in})` where :math:`*` means any number of
+ dimensions including none and :math:`H_{in} = \text{in\_features}`.
+ - Output: :math:`(*, H_{out})` where all but the last dimension
+ are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
+
+ Attributes:
+ weight: the learnable weights of the module of shape
+ :math:`(\text{out\_features}, \text{in\_features})`. The values are
+ initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
+ :math:`k = \frac{1}{\text{in\_features}}`
+ bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
+ If :attr:`bias` is ``True``, the values are initialized from
+ :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+ :math:`k = \frac{1}{\text{in\_features}}`
+ """
+
+ __constants__ = ["in_features", "out_features"]
+ in_features: int
+ out_features: int
+ weight: Tensor
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ # Allow 'weight' to be quantized
+ self._quantizable_buffers.add("weight")
+
+ self.in_features = in_features
+ self.out_features = out_features
+ # TODO: Do actual initialization here: e.g. kaiming_uniform or xavier_uniform
+ self.weight = candle.ones((out_features, in_features), **factory_kwargs)
+ if bias:
+ self.bias = candle.zeros((out_features,), **factory_kwargs)
+ else:
+ self.bias = None
+
+ def forward(self, x: Tensor) -> Tensor:
+ dims = x.shape
+ last_dim = dims[-1]
+
+ if isinstance(self.weight, candle.QTensor):
+ if len(dims) < 3:
+ matmul_result = self.weight.matmul_t(x).broadcast_add(self.bias)
+ elif len(dims) == 3:
+ b, n, m = dims
+ output_shape = (b, n, self.out_features)
+ re = x.reshape((b * n, m))
+ matmul_result = self.weight.matmul_t(re).reshape((output_shape))
+ else:
+ raise NotImplementedError("'QTensor.matmul_t' is not implemented for more than 3 dimensions")
+
+ if self.bias:
+ return matmul_result.broadcast_add(self.bias)
+ else:
+ if self.weight.shape[-1] == last_dim and len(dims) < 3:
+ w = self.weight.t()
+ else:
+ batch_size = dims[0]
+ w = self.weight.broadcast_left((batch_size,)).t()
+
+ x = x.matmul(w)
+ if self.bias is not None:
+ x = x.broadcast_add(self.bias)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
diff --git a/candle-pyo3/py_src/candle/nn/module.py b/candle-pyo3/py_src/candle/nn/module.py
new file mode 100644
index 00000000..514d92b8
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/module.py
@@ -0,0 +1,702 @@
+from candle import Tensor, QTensor, DType
+from typing import (
+ Dict,
+ Tuple,
+ Any,
+ Optional,
+ Union,
+ Iterator,
+ Set,
+ overload,
+ Mapping,
+ TypeVar,
+ List,
+)
+from collections import OrderedDict, namedtuple
+
+TensorLike = Union[Tensor, QTensor]
+T = TypeVar("T", bound="Module")
+
+
+class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
+ def __repr__(self):
+ if not self.missing_keys and not self.unexpected_keys:
+ return "<All keys matched successfully>"
+ return super().__repr__()
+
+ __str__ = __repr__
+
+
+# see: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
+class Module:
+ """
+ Pytorch like Module.
+
+ Base class for all neural network modules.
+
+ Your models should also subclass this class.
+ """
+
+ _modules: Dict[str, Optional["Module"]]
+ _buffers: Dict[str, Optional[TensorLike]]
+ _non_persistent_buffers_set: Set[str]
+ _quantizable_buffers: Set[str]
+ _version: int = 1
+
+ def __init__(self, *args, **kwargs) -> None:
+ """
+ Initializes internal Module state
+ """
+ super().__setattr__("_modules", OrderedDict())
+ super().__setattr__("_buffers", OrderedDict())
+ super().__setattr__("_non_persistent_buffers_set", set())
+ super().__setattr__("_quantizable_buffers", set())
+
+ def __call__(self, *input):
+ """
+ Call self as a function.
+ """
+ return self.forward(*input)
+
+ def forward(self, *input):
+ """
+ Defines the computation performed at every call.
+ Should be overridden by all subclasses.
+ """
+ pass
+
+ def children(self) -> Iterator["Module"]:
+ r"""Returns an iterator over immediate children modules.
+
+ Yields:
+ Module: a child module
+ """
+ for name, module in self.named_children():
+ yield module
+
+ def named_children(self) -> Iterator[Tuple[str, "Module"]]:
+ r"""Returns an iterator over immediate children modules, yielding both
+ the name of the module as well as the module itself.
+
+ Yields:
+ (str, Module): Tuple containing a name and child module
+
+ Example::
+
+ >>> for name, module in model.named_children():
+ >>> if name in ['conv4', 'conv5']:
+ >>> print(module)
+
+ """
+ memo = set()
+ for name, module in self._modules.items():
+ if module is not None and module not in memo:
+ memo.add(module)
+ yield name, module
+
+ def add_module(self, name: str, module: Optional["Module"]) -> None:
+ r"""Adds a child module to the current module.
+
+ The module can be accessed as an attribute using the given name.
+
+ Args:
+ name (str): name of the child module. The child module can be
+ accessed from this module using the given name
+ module (Module): child module to be added to the module.
+ """
+ if not isinstance(module, Module) and module is not None:
+ raise TypeError(f"{str(module)} is not a Module subclass")
+ elif not isinstance(name, str):
+ raise TypeError(f"module name should be a string. Got {name}")
+ elif hasattr(self, name) and name not in self._modules:
+ raise KeyError(f"attribute '{name}' already exists")
+ elif "." in name:
+ raise KeyError(f'module name can\'t contain ".", got: {name}')
+ elif name == "":
+ raise KeyError('module name can\'t be empty string ""')
+ self._modules[name] = module
+
+ def register_module(self, name: str, module: Optional["Module"]) -> None:
+ r"""Alias for :func:`add_module`."""
+ self.add_module(name, module)
+
+ def modules(self) -> Iterator["Module"]:
+ r"""Returns an iterator over all modules in the network."""
+ for _, module in self.named_modules():
+ yield module
+
+ def named_modules(
+ self,
+ memo: Optional[Set["Module"]] = None,
+ prefix: str = "",
+ remove_duplicate: bool = True,
+ ):
+ r"""Returns an iterator over all modules in the network, yielding
+ both the name of the module as well as the module itself.
+
+ Args:
+ memo: a memo to store the set of modules already added to the result
+ prefix: a prefix that will be added to the name of the module
+ remove_duplicate: whether to remove the duplicated module instances in the result
+ or not
+
+ Yields:
+ (str, Module): Tuple of name and module
+
+ Note:
+ Duplicate modules are returned only once. In the following
+ example, ``l`` will be returned only once.
+ """
+
+ if memo is None:
+ memo = set()
+ if self not in memo:
+ if remove_duplicate:
+ memo.add(self)
+ yield prefix, self
+ for name, module in self._modules.items():
+ if module is None:
+ continue
+ submodule_prefix = prefix + ("." if prefix else "") + name
+ for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
+ yield m
+
+ def buffers(self, recurse: bool = True) -> Iterator[TensorLike]:
+ """
+ Returns an iterator over module buffers.
+ """
+ for name, buf in self.named_buffers(recurse=recurse):
+ yield buf
+
+ def named_buffers(
+ self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+ ) -> Iterator[Tuple[str, TensorLike]]:
+ r"""Returns an iterator over module buffers, yielding both the
+ name of the buffer as well as the buffer itself.
+
+ Args:
+ prefix (str): prefix to prepend to all buffer names.
+ recurse (bool, optional): if True, then yields buffers of this module
+ and all submodules. Otherwise, yields only buffers that
+ are direct members of this module. Defaults to True.
+ remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
+
+ Yields:
+ (str, Tensor): Tuple containing the name and buffer
+
+ Example::
+
+ >>> for name, buf in self.named_buffers():
+ >>> if name in ['running_var']:
+ >>> print(buf.size())
+
+ """
+ gen = self._named_members(
+ lambda module: module._buffers.items(),
+ prefix=prefix,
+ recurse=recurse,
+ remove_duplicate=remove_duplicate,
+ )
+ yield from gen
+
+ # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
+ # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
+ T_destination = TypeVar("T_destination", bound=Dict[str, Any])
+
+ @overload
+ def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
+ ...
+
+ @overload
+ def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
+ ...
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ r"""Returns a dictionary containing references to the whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+ Parameters and buffers set to ``None`` are not included.
+
+ .. note::
+ The returned object is a shallow copy. It contains references
+ to the module's parameters and buffers.
+
+ .. warning::
+ Currently ``state_dict()`` also accepts positional arguments for
+ ``destination``, ``prefix`` and ``keep_vars`` in order. However,
+ this is being deprecated and keyword arguments will be enforced in
+ future releases.
+
+ .. warning::
+ Please avoid the use of argument ``destination`` as it is not
+ designed for end-users.
+
+ Args:
+ destination (dict, optional): If provided, the state of module will
+ be updated into the dict and the same object is returned.
+ Otherwise, an ``OrderedDict`` will be created and returned.
+ Default: ``None``.
+ prefix (str, optional): a prefix added to parameter and buffer
+ names to compose the keys in state_dict. Default: ``''``.
+ keep_vars (bool, optional): by default the :class:`~candle.Tensor` s
+ returned in the state dict are detached from autograd. If it's
+ set to ``True``, detaching will not be performed.
+ Default: ``False``.
+
+ Returns:
+ dict:
+ a dictionary containing a whole state of the module
+
+ Example::
+
+ >>> # xdoctest: +SKIP("undefined vars")
+ >>> module.state_dict().keys()
+ ['bias', 'weight']
+
+ """
+
+ # TODO: Remove `args` and the parsing logic when BC allows.
+ if len(args) > 0:
+ if destination is None:
+ destination = args[0]
+ if len(args) > 1 and prefix == "":
+ prefix = args[1]
+ if len(args) > 2 and keep_vars is False:
+ keep_vars = args[2]
+
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+
+ local_metadata = dict(version=self._version)
+ if hasattr(destination, "_metadata"):
+ destination._metadata[prefix[:-1]] = local_metadata
+ self._save_to_state_dict(destination, prefix, keep_vars)
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(
+ destination=destination,
+ prefix=prefix + name + ".",
+ keep_vars=keep_vars,
+ )
+ return destination
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ r"""Saves module state to `destination` dictionary, containing a state
+ of the module, but not its descendants. This is called on every
+ submodule in :meth:`~candle.nn.Module.state_dict`.
+
+ In rare cases, subclasses can achieve class-specific behavior by
+ overriding this method with custom logic.
+
+ Args:
+ destination (dict): a dict where state will be stored
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ """
+ for name, buf in self._buffers.items():
+ if buf is not None and name not in self._non_persistent_buffers_set:
+ if isinstance(buf, Tensor):
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+ else:
+ destination[prefix + name] = buf
+
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
+ r"""Copies parameters and buffers from :attr:`state_dict` into
+ this module and its descendants. If :attr:`strict` is ``True``, then
+ the keys of :attr:`state_dict` must exactly match the keys returned
+ by this module's :meth:`~candle.nn.Module.state_dict` function.
+
+ .. warning::
+ If :attr:`assign` is ``True`` the optimizer must be created after
+ the call to :attr:`load_state_dict`.
+
+ Args:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ strict (bool, optional): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~candle.nn.Module.state_dict` function. Default: ``True``
+ assign (bool, optional): whether to assign items in the state
+ dictionary to their corresponding keys in the module instead
+ of copying them inplace into the module's current parameters and buffers.
+ When ``False``, the properties of the tensors in the current
+ module are preserved while when ``True``, the properties of the
+ Tensors in the state dict are preserved.
+ Default: ``False``
+
+ Returns:
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+ * **missing_keys** is a list of str containing the missing keys
+ * **unexpected_keys** is a list of str containing the unexpected keys
+
+ Note:
+ If a parameter or buffer is registered as ``None`` and its corresponding key
+ exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
+ ``RuntimeError``.
+ """
+ if not isinstance(state_dict, Mapping):
+ raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
+
+ missing_keys: List[str] = []
+ unexpected_keys: List[str] = []
+ error_msgs: List[str] = []
+
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = OrderedDict(state_dict)
+ if metadata is not None:
+ # mypy isn't aware that "_metadata" exists in state_dict
+ state_dict._metadata = metadata # type: ignore[attr-defined]
+
+ def load(module, local_state_dict, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ if assign:
+ local_metadata["assign_to_params_buffers"] = assign
+ module._load_from_state_dict(
+ local_state_dict,
+ prefix,
+ local_metadata,
+ True,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ child_prefix = prefix + name + "."
+ child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
+ load(child, child_state_dict, child_prefix)
+
+ load(self, state_dict)
+ del load
+
+ if strict:
+ if len(unexpected_keys) > 0:
+ error_msgs.insert(
+ 0,
+ "Unexpected key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in unexpected_keys)),
+ )
+ if len(missing_keys) > 0:
+ error_msgs.insert(
+ 0,
+ "Missing key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in missing_keys)),
+ )
+
+ if len(error_msgs) > 0:
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs))
+ )
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
+ this module, but not its descendants. This is called on every submodule
+ in :meth:`~candle.nn.Module.load_state_dict`. Metadata saved for this
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+ For state dicts without metadata, :attr:`local_metadata` is empty.
+ Subclasses can achieve class-specific backward compatible loading using
+ the version number at `local_metadata.get("version", None)`.
+ Additionally, :attr:`local_metadata` can also contain the key
+ `assign_to_params_buffers` that indicates whether keys should be
+ assigned their corresponding tensor in the state_dict.
+
+ .. note::
+ :attr:`state_dict` is not the same object as the input
+ :attr:`state_dict` to :meth:`~candle.nn.Module.load_state_dict`. So
+ it can be modified.
+
+ Args:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ local_metadata (dict): a dict containing the metadata for this module.
+ See
+ strict (bool): whether to strictly enforce that the keys in
+ :attr:`state_dict` with :attr:`prefix` match the names of
+ parameters and buffers in this module
+ missing_keys (list of str): if ``strict=True``, add missing keys to
+ this list
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
+ keys to this list
+ error_msgs (list of str): error messages should be added to this
+ list, and will be reported together in
+ :meth:`~candle.nn.Module.load_state_dict`
+ """
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+ local_name_params = persistent_buffers.items()
+ local_state = {k: v for k, v in local_name_params if v is not None}
+
+ for name, param in local_state.items():
+ key = prefix + name
+ if key in state_dict:
+ input_param = state_dict[key]
+ if not isinstance(input_param, (Tensor, QTensor)):
+ error_msgs.append(
+ f'While copying the parameter named "{key}", '
+ "expected Tensor-like object from checkpoint but "
+ f"received {type(input_param)}"
+ )
+ continue
+
+ if input_param.shape != param.shape:
+ # local shape should match the one in checkpoint
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(key, input_param.shape, param.shape)
+ )
+ continue
+
+ try:
+ # Shape checks are already done above -> Just assign tensor
+ setattr(self, name, input_param)
+ except Exception as ex:
+ error_msgs.append(
+ f'While copying the parameter named "{key}", '
+ f"whose dimensions in the model are {param.shape} and "
+ f"whose dimensions in the checkpoint are {input_param.shape}, "
+ f"an exception occurred : {ex.args}."
+ )
+ elif strict:
+ missing_keys.append(key)
+
+ if strict:
+ for key in state_dict.keys():
+ if key.startswith(prefix):
+ input_name = key[len(prefix) :]
+ input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
+ if input_name not in self._modules and input_name not in local_state:
+ unexpected_keys.append(key)
+
+ def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
+ r"""Helper method for yielding various names + members of modules."""
+ memo = set()
+ modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
+ for module_prefix, module in modules:
+ members = get_members_fn(module)
+ for k, v in members:
+ if v is None or v in memo:
+ continue
+ if remove_duplicate:
+ memo.add(v)
+ name = module_prefix + ("." if module_prefix else "") + k
+ yield name, v
+
+ def _get_name(self):
+ return self.__class__.__name__
+
+ def _apply(self, fn):
+ for module in self.children():
+ module._apply(fn)
+
+ for key, buf in self._buffers.items():
+ if buf is not None:
+ self._buffers[key] = fn(buf)
+
+ return self
+
+ def __move_tensor_to_device(self, tensor: TensorLike, device: str):
+ if isinstance(tensor, Tensor):
+ return tensor.to_device(device)
+ else:
+ raise NotImplementedError("Cannot offload QTensor to cuda, yet!")
+
+ def device(self) -> str:
+ """
+ Gets the device of the module, by inspecting its tensors.
+ """
+ tensor = next(self.buffers())
+ if isinstance(tensor, Tensor):
+ return tensor.device
+ else:
+ # QTensors can only be on the CPU
+ return "cpu"
+
+ def cuda(self: T) -> T:
+ r"""Moves all model parameters and buffers to the GPU.
+
+ This also makes associated parameters and buffers different objects. So
+ it should be called before constructing optimizer if the module will
+ live on GPU while being optimized.
+
+ .. note::
+ This method modifies the module in-place.
+
+ Returns:
+ Module: self
+ """
+
+ def to_cuda(t: TensorLike):
+ return self.__move_tensor_to_device(t, "cuda")
+
+ return self._apply(to_cuda)
+
+ def cpu(self: T) -> T:
+ r"""Moves all model parameters and buffers to the CPU.
+
+ .. note::
+ This method modifies the module in-place.
+
+ Returns:
+ Module: self
+ """
+
+ def to_cpu(t: TensorLike):
+ return self.__move_tensor_to_device(t, "cpu")
+
+ return self._apply(to_cpu)
+
+ def __cast_tensor(self, tensor: TensorLike, dtype: Union[DType, str]):
+ if isinstance(tensor, Tensor):
+ return tensor.to_dtype(dtype)
+ else:
+ raise TypeError("candle.Module.to only accepts Tensor dtypes, but got desired dtype={}".format(dtype))
+
+ def type(self: T, dst_type: Union[DType, str]) -> T:
+ r"""Casts all parameters and buffers to :attr:`dst_type`.
+
+ .. note::
+ This method modifies the module in-place.
+
+ Args:
+ dst_type (type or string): the desired type
+
+ Returns:
+ Module: self
+ """
+
+ def cast(t: TensorLike):
+ return self.__cast_tensor(t, dst_type)
+
+ return self._apply(cast)
+
+ @overload
+ def to(
+ self: T,
+ device: str = ...,
+ dtype: Optional[Union[DType, str]] = ...,
+ ) -> T:
+ ...
+
+ @overload
+ def to(self: T, dtype: Union[DType, str]) -> T:
+ ...
+
+ def to(self, *args, **kwargs):
+ r"""Moves and/or casts the parameters and buffers.
+
+ This can be called as
+
+ .. function:: to(device=None, dtype=None)
+ :noindex:
+
+ .. function:: to(dtype)
+ :noindex:
+
+ See below for examples.
+
+ .. note::
+ This method modifies the module in-place.
+
+ Args:
+ device (:class:`candle.device`): the desired device of the parameters
+ and buffers in this module
+ dtype (:class:`candle.dtype`): the desired floating point dtype of
+ the parameters and buffers in this module
+
+ Returns:
+ Module: self
+ """
+
+ device = None
+ dtype = None
+
+ if args:
+ for arg in args:
+ # Assuming arg can be a string representing a device or a dtype
+
+ if isinstance(arg, str):
+ lower_arg = str(arg).lower()
+ if lower_arg.startswith("cuda") or lower_arg == "cpu":
+ device = lower_arg
+ else:
+ dtype = arg
+ elif isinstance(arg, DType):
+ dtype = str(arg)
+ else:
+ raise TypeError("Module.to() received an invalid combination of arguments. Got: {}".format(args))
+
+ if kwargs:
+ device = kwargs.get("device", device)
+ dtype = str(kwargs.get("dtype", dtype))
+
+ if device:
+ device = device.lower()
+
+ if dtype:
+ dtype = dtype.lower()
+ if dtype not in ["f32", "f16", "f64"]:
+ raise TypeError(
+ "candle.Module.to only accepts floating point" "dtypes, but got desired dtype={}".format(dtype)
+ )
+
+ def convert(t):
+ if dtype:
+ t = self.__cast_tensor(t, dtype)
+ if device:
+ t = self.__move_tensor_to_device(t, device)
+ return t
+
+ return self._apply(convert)
+
+ def __setattr__(self, __name: str, __value: Any) -> None:
+ if isinstance(__value, Module):
+ self._modules[__name] = __value
+ elif isinstance(__value, QTensor):
+ if __name in self._quantizable_buffers:
+ type = __value.ggml_dtype.lower()
+ if type in ["f32", "f16"]:
+ # It is faster to just dequantize the tensor here and use the normal tensor operations
+ dequant = __value.dequantize()
+ if type == "f16":
+ dequant = dequant.to_dtype("f16")
+ self._buffers[__name] = dequant
+ else:
+ self._buffers[__name] = __value
+ else:
+ # We expect a normal tensor here => dequantize it
+ self._buffers[__name] = __value.dequantize()
+ elif isinstance(__value, Tensor):
+ self._buffers[__name] = __value
+ else:
+ super().__setattr__(__name, __value)
+
+ def __getattr__(self, __name: str) -> Any:
+ if "_modules" in self.__dict__:
+ modules = self.__dict__["_modules"]
+ if __name in modules:
+ return modules[__name]
+ if "_buffers" in self.__dict__:
+ tensors = self.__dict__["_buffers"]
+ if __name in tensors:
+ return tensors[__name]
+ return super().__getattribute__(__name)
+
+ def __delattr__(self, name):
+ if name in self._buffers:
+ del self._buffers[name]
+ elif name in self._modules:
+ del self._modules[name]
+ else:
+ super().__delattr__(name)
diff --git a/candle-pyo3/py_src/candle/nn/normalization.py b/candle-pyo3/py_src/candle/nn/normalization.py
new file mode 100644
index 00000000..67510a24
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/normalization.py
@@ -0,0 +1,54 @@
+import candle
+from candle import Tensor
+from .module import Module
+from typing import Union, List, Tuple, Optional, Any
+
+_shape_t = Union[int, List[int]]
+import numbers
+
+
+class LayerNorm(Module):
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
+ the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`
+
+ math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+ """
+ __constants__ = ["normalized_shape", "eps"]
+ normalized_shape: Tuple[int, ...]
+ eps: float
+
+ def __init__(
+ self,
+ normalized_shape: _shape_t,
+ eps: float = 1e-5,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ self.normalized_shape = tuple(normalized_shape)
+ self.eps = eps
+
+ self.weight = candle.ones(normalized_shape, **factory_kwargs)
+ if bias:
+ self.bias = candle.zeros(normalized_shape, **factory_kwargs)
+ else:
+ self.bias = None
+
+ def forward(self, input: Tensor) -> Tensor:
+ mean_x = input.sum_keepdim(2) / float(self.normalized_shape[-1])
+ x = input.broadcast_sub(mean_x)
+ norm_x = x.sqr().sum_keepdim(2) / float(self.normalized_shape[-1])
+ x_normed = x.broadcast_div((norm_x + self.eps).sqrt())
+ x = x_normed.broadcast_mul(self.weight)
+
+ if self.bias:
+ x = x.broadcast_add(self.bias)
+ return x
+
+ def extra_repr(self) -> str:
+ return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
diff --git a/candle-pyo3/py_src/candle/nn/sparse.py b/candle-pyo3/py_src/candle/nn/sparse.py
new file mode 100644
index 00000000..386f8081
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/sparse.py
@@ -0,0 +1,39 @@
+from .module import Module
+from typing import Optional, Tuple, Any
+from candle import Tensor
+import candle
+
+
+class Embedding(Module):
+ """A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+ This module is often used to store word embeddings and retrieve them using indices.
+ The input to the module is a list of indices, and the output is the corresponding
+ word embeddings.
+
+ Args:
+ num_embeddings (int): size of the dictionary of embeddings
+ embedding_dim (int): the size of each embedding vector
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
+ initialized from :math:`\mathcal{N}(0, 1)`
+
+ Shape:
+ - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
+ - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None:
+ factory_kwargs = {"device": device}
+ super().__init__()
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs)
+
+ def forward(self, indexes: Tensor) -> Tensor:
+ final_dims = list(indexes.shape)
+ final_dims.append(self.embedding_dim)
+ indexes = indexes.flatten_all()
+ values = self.weight.index_select(indexes, 0)
+ return values.reshape(final_dims)
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
index ea85d2a3..ccdb6238 100644
--- a/candle-pyo3/py_src/candle/typing/__init__.py
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -2,7 +2,7 @@ from typing import TypeVar, Union, Sequence
_T = TypeVar("_T")
-_ArrayLike = Union[
+_ArrayLike = Union[
_T,
Sequence[_T],
Sequence[Sequence[_T]],
@@ -10,7 +10,7 @@ _ArrayLike = Union[
Sequence[Sequence[Sequence[Sequence[_T]]]],
]
-CPU:str = "cpu"
-CUDA:str = "cuda"
+CPU: str = "cpu"
+CUDA: str = "cuda"
-Device = TypeVar("Device", CPU, CUDA) \ No newline at end of file
+Device = TypeVar("Device", CPU, CUDA)