summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/nn/container.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/container.py')
-rw-r--r--candle-pyo3/py_src/candle/nn/container.py483
1 files changed, 483 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py
new file mode 100644
index 00000000..15ed8dd2
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/container.py
@@ -0,0 +1,483 @@
+# see https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
+from .module import Module
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ Iterator,
+ Mapping,
+ Optional,
+ overload,
+ Tuple,
+ TypeVar,
+ Union,
+)
+from collections import OrderedDict, abc as container_abcs
+import operator
+from itertools import chain, islice
+
+__all__ = ["Sequential", "ModuleList", "ModuleDict"]
+
+T = TypeVar("T", bound=Module)
+
+
+def _addindent(s_: str, numSpaces: int):
+ s = s_.split("\n")
+ # don't do anything for single-line stuff
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(numSpaces * " ") + line for line in s]
+ s = "\n".join(s)
+ s = first + "\n" + s
+ return s
+
+
+class Sequential(Module):
+ r"""A sequential container.
+ Modules will be added to it in the order they are passed in the
+ constructor. Alternatively, an ``OrderedDict`` of modules can be
+ passed in. The ``forward()`` method of ``Sequential`` accepts any
+ input and forwards it to the first module it contains. It then
+ "chains" outputs to inputs sequentially for each subsequent module,
+ finally returning the output of the last module.
+
+ The value a ``Sequential`` provides over manually calling a sequence
+ of modules is that it allows treating the whole container as a
+ single module, such that performing a transformation on the
+ ``Sequential`` applies to each of the modules it stores (which are
+ each a registered submodule of the ``Sequential``).
+
+ What's the difference between a ``Sequential`` and a
+ :class:`candle.nn.ModuleList`? A ``ModuleList`` is exactly what it
+ sounds like--a list for storing ``Module`` s! On the other hand,
+ the layers in a ``Sequential`` are connected in a cascading way.
+ """
+
+ _modules: Dict[str, Module] # type: ignore[assignment]
+
+ @overload
+ def __init__(self, *args: Module) -> None:
+ ...
+
+ @overload
+ def __init__(self, arg: "OrderedDict[str, Module]") -> None:
+ ...
+
+ def __init__(self, *args):
+ super().__init__()
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
+ for key, module in args[0].items():
+ self.add_module(key, module)
+ else:
+ for idx, module in enumerate(args):
+ self.add_module(str(idx), module)
+
+ def _get_item_by_idx(self, iterator, idx) -> T:
+ """Get the idx-th item of the iterator"""
+ size = len(self)
+ idx = operator.index(idx)
+ if not -size <= idx < size:
+ raise IndexError("index {} is out of range".format(idx))
+ idx %= size
+ return next(islice(iterator, idx, None))
+
+ def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
+ if isinstance(idx, slice):
+ return self.__class__(OrderedDict(list(self._modules.items())[idx]))
+ else:
+ return self._get_item_by_idx(self._modules.values(), idx)
+
+ def __setitem__(self, idx: int, module: Module) -> None:
+ key: str = self._get_item_by_idx(self._modules.keys(), idx)
+ return setattr(self, key, module)
+
+ def __delitem__(self, idx: Union[slice, int]) -> None:
+ if isinstance(idx, slice):
+ for key in list(self._modules.keys())[idx]:
+ delattr(self, key)
+ else:
+ key = self._get_item_by_idx(self._modules.keys(), idx)
+ delattr(self, key)
+ # To preserve numbering
+ str_indices = [str(i) for i in range(len(self._modules))]
+ self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __add__(self, other) -> "Sequential":
+ if isinstance(other, Sequential):
+ ret = Sequential()
+ for layer in self:
+ ret.append(layer)
+ for layer in other:
+ ret.append(layer)
+ return ret
+ else:
+ raise ValueError(
+ "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other)))
+ )
+
+ def pop(self, key: Union[int, slice]) -> Module:
+ v = self[key]
+ del self[key]
+ return v
+
+ def __iadd__(self, other) -> "Sequential":
+ if isinstance(other, Sequential):
+ offset = len(self)
+ for i, module in enumerate(other):
+ self.add_module(str(i + offset), module)
+ return self
+ else:
+ raise ValueError(
+ "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other)))
+ )
+
+ def __mul__(self, other: int) -> "Sequential":
+ if not isinstance(other, int):
+ raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
+ elif other <= 0:
+ raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
+ else:
+ combined = Sequential()
+ offset = 0
+ for _ in range(other):
+ for module in self:
+ combined.add_module(str(offset), module)
+ offset += 1
+ return combined
+
+ def __rmul__(self, other: int) -> "Sequential":
+ return self.__mul__(other)
+
+ def __imul__(self, other: int) -> "Sequential":
+ if not isinstance(other, int):
+ raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
+ elif other <= 0:
+ raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
+ else:
+ len_original = len(self)
+ offset = len(self)
+ for _ in range(other - 1):
+ for i in range(len_original):
+ self.add_module(str(i + offset), self._modules[str(i)])
+ offset += len_original
+ return self
+
+ def __dir__(self):
+ keys = super().__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+ def __iter__(self) -> Iterator[Module]:
+ return iter(self._modules.values())
+
+ # NB: We can't really type check this function as the type of input
+ # may change dynamically (as is tested in
+ # TestScript.test_sequential_intermediary_types). Cannot annotate
+ # with Any as TorchScript expects a more precise type
+ def forward(self, input):
+ for module in self:
+ input = module(input)
+ return input
+
+ def append(self, module: Module) -> "Sequential":
+ r"""Appends a given module to the end.
+
+ Args:
+ module (nn.Module): module to append
+ """
+ self.add_module(str(len(self)), module)
+ return self
+
+ def insert(self, index: int, module: Module) -> "Sequential":
+ if not isinstance(module, Module):
+ raise AssertionError("module should be of type: {}".format(Module))
+ n = len(self._modules)
+ if not (-n <= index <= n):
+ raise IndexError("Index out of range: {}".format(index))
+ if index < 0:
+ index += n
+ for i in range(n, index, -1):
+ self._modules[str(i)] = self._modules[str(i - 1)]
+ self._modules[str(index)] = module
+ return self
+
+ def extend(self, sequential) -> "Sequential":
+ for layer in sequential:
+ self.append(layer)
+ return self
+
+
+class ModuleList(Module):
+ r"""Holds submodules in a list.
+
+ :class:`~candle.nn.ModuleList` can be indexed like a regular Python list, but
+ modules it contains are properly registered, and will be visible by all
+ :class:`~candle.nn.Module` methods.
+
+ Args:
+ modules (iterable, optional): an iterable of modules to add
+
+ Example::
+
+ class MyModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
+
+ def forward(self, x):
+ # ModuleList can act as an iterable, or be indexed using ints
+ for i, l in enumerate(self.linears):
+ x = self.linears[i // 2](x) + l(x)
+ return x
+ """
+
+ _modules: Dict[str, Module] # type: ignore[assignment]
+
+ def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
+ super().__init__()
+ if modules is not None:
+ self += modules
+
+ def _get_abs_string_index(self, idx):
+ """Get the absolute index for the list of modules"""
+ idx = operator.index(idx)
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError("index {} is out of range".format(idx))
+ if idx < 0:
+ idx += len(self)
+ return str(idx)
+
+ def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
+ if isinstance(idx, slice):
+ return self.__class__(list(self._modules.values())[idx])
+ else:
+ return self._modules[self._get_abs_string_index(idx)]
+
+ def __setitem__(self, idx: int, module: Module) -> None:
+ idx = self._get_abs_string_index(idx)
+ return setattr(self, str(idx), module)
+
+ def __delitem__(self, idx: Union[int, slice]) -> None:
+ if isinstance(idx, slice):
+ for k in range(len(self._modules))[idx]:
+ delattr(self, str(k))
+ else:
+ delattr(self, self._get_abs_string_index(idx))
+ # To preserve numbering, self._modules is being reconstructed with modules after deletion
+ str_indices = [str(i) for i in range(len(self._modules))]
+ self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[Module]:
+ return iter(self._modules.values())
+
+ def __iadd__(self, modules: Iterable[Module]) -> "ModuleList":
+ return self.extend(modules)
+
+ def __add__(self, other: Iterable[Module]) -> "ModuleList":
+ combined = ModuleList()
+ for i, module in enumerate(chain(self, other)):
+ combined.add_module(str(i), module)
+ return combined
+
+ def __repr__(self):
+ """A custom repr for ModuleList that compresses repeated module representations"""
+ list_of_reprs = [repr(item) for item in self]
+ if len(list_of_reprs) == 0:
+ return self._get_name() + "()"
+
+ start_end_indices = [[0, 0]]
+ repeated_blocks = [list_of_reprs[0]]
+ for i, r in enumerate(list_of_reprs[1:], 1):
+ if r == repeated_blocks[-1]:
+ start_end_indices[-1][1] += 1
+ continue
+
+ start_end_indices.append([i, i])
+ repeated_blocks.append(r)
+
+ lines = []
+ main_str = self._get_name() + "("
+ for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+ local_repr = f"({start_id}): {b}" # default repr
+
+ if start_id != end_id:
+ n = end_id - start_id + 1
+ local_repr = f"({start_id}-{end_id}): {n} x {b}"
+
+ local_repr = _addindent(local_repr, 2)
+ lines.append(local_repr)
+
+ main_str += "\n " + "\n ".join(lines) + "\n"
+ main_str += ")"
+ return main_str
+
+ def __dir__(self):
+ keys = super().__dir__()
+ keys = [key for key in keys if not key.isdigit()]
+ return keys
+
+ def insert(self, index: int, module: Module) -> None:
+ r"""Insert a given module before a given index in the list.
+
+ Args:
+ index (int): index to insert.
+ module (nn.Module): module to insert
+ """
+ for i in range(len(self._modules), index, -1):
+ self._modules[str(i)] = self._modules[str(i - 1)]
+ self._modules[str(index)] = module
+
+ def append(self, module: Module) -> "ModuleList":
+ r"""Appends a given module to the end of the list.
+
+ Args:
+ module (nn.Module): module to append
+ """
+ self.add_module(str(len(self)), module)
+ return self
+
+ def pop(self, key: Union[int, slice]) -> Module:
+ v = self[key]
+ del self[key]
+ return v
+
+ def extend(self, modules: Iterable[Module]) -> "ModuleList":
+ r"""Appends modules from a Python iterable to the end of the list.
+
+ Args:
+ modules (iterable): iterable of modules to append
+ """
+ if not isinstance(modules, container_abcs.Iterable):
+ raise TypeError(
+ "ModuleList.extend should be called with an " "iterable, but got " + type(modules).__name__
+ )
+ offset = len(self)
+ for i, module in enumerate(modules):
+ self.add_module(str(offset + i), module)
+ return self
+
+ # remove forward alltogether to fallback on Module's _forward_unimplemented
+
+
+class ModuleDict(Module):
+ r"""Holds submodules in a dictionary.
+
+ :class:`~candle.nn.ModuleDict` can be indexed like a regular Python dictionary,
+ but modules it contains are properly registered, and will be visible by all
+ :class:`~candle.nn.Module` methods.
+
+ :class:`~candle.nn.ModuleDict` is an **ordered** dictionary that respects
+
+ * the order of insertion, and
+
+ * in :meth:`~candle.nn.ModuleDict.update`, the order of the merged
+ ``OrderedDict``, ``dict`` (started from Python 3.6) or another
+ :class:`~candle.nn.ModuleDict` (the argument to
+ :meth:`~candle.nn.ModuleDict.update`).
+
+ Note that :meth:`~candle.nn.ModuleDict.update` with other unordered mapping
+ types (e.g., Python's plain ``dict`` before Python version 3.6) does not
+ preserve the order of the merged mapping.
+
+ Args:
+ modules (iterable, optional): a mapping (dictionary) of (string: module)
+ or an iterable of key-value pairs of type (string, module)
+ """
+
+ _modules: Dict[str, Module] # type: ignore[assignment]
+
+ def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
+ super().__init__()
+ if modules is not None:
+ self.update(modules)
+
+ def __getitem__(self, key: str) -> Module:
+ return self._modules[key]
+
+ def __setitem__(self, key: str, module: Module) -> None:
+ self.add_module(key, module)
+
+ def __delitem__(self, key: str) -> None:
+ del self._modules[key]
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self._modules)
+
+ def __contains__(self, key: str) -> bool:
+ return key in self._modules
+
+ def clear(self) -> None:
+ """Remove all items from the ModuleDict."""
+ self._modules.clear()
+
+ def pop(self, key: str) -> Module:
+ r"""Remove key from the ModuleDict and return its module.
+
+ Args:
+ key (str): key to pop from the ModuleDict
+ """
+ v = self[key]
+ del self[key]
+ return v
+
+ def keys(self) -> Iterable[str]:
+ r"""Return an iterable of the ModuleDict keys."""
+ return self._modules.keys()
+
+ def items(self) -> Iterable[Tuple[str, Module]]:
+ r"""Return an iterable of the ModuleDict key/value pairs."""
+ return self._modules.items()
+
+ def values(self) -> Iterable[Module]:
+ r"""Return an iterable of the ModuleDict values."""
+ return self._modules.values()
+
+ def update(self, modules: Mapping[str, Module]) -> None:
+ r"""Update the :class:`~candle.nn.ModuleDict` with the key-value pairs from a
+ mapping or an iterable, overwriting existing keys.
+
+ .. note::
+ If :attr:`modules` is an ``OrderedDict``, a :class:`~candle.nn.ModuleDict`, or
+ an iterable of key-value pairs, the order of new elements in it is preserved.
+
+ Args:
+ modules (iterable): a mapping (dictionary) from string to :class:`~candle.nn.Module`,
+ or an iterable of key-value pairs of type (string, :class:`~candle.nn.Module`)
+ """
+ if not isinstance(modules, container_abcs.Iterable):
+ raise TypeError(
+ "ModuleDict.update should be called with an "
+ "iterable of key/value pairs, but got " + type(modules).__name__
+ )
+
+ if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
+ for key, module in modules.items():
+ self[key] = module
+ else:
+ # modules here can be a list with two items
+ for j, m in enumerate(modules):
+ if not isinstance(m, container_abcs.Iterable):
+ raise TypeError(
+ "ModuleDict update sequence element "
+ "#" + str(j) + " should be Iterable; is" + type(m).__name__
+ )
+ if not len(m) == 2:
+ raise ValueError(
+ "ModuleDict update sequence element "
+ "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
+ )
+ # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
+ # that's too cumbersome to type correctly with overloads, so we add an ignore here
+ self[m[0]] = m[1] # type: ignore[assignment]
+
+ # remove forward alltogether to fallback on Module's _forward_unimplemented