diff options
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/container.py')
-rw-r--r-- | candle-pyo3/py_src/candle/nn/container.py | 483 |
1 files changed, 483 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py new file mode 100644 index 00000000..15ed8dd2 --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/container.py @@ -0,0 +1,483 @@ +# see https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py +from .module import Module +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + overload, + Tuple, + TypeVar, + Union, +) +from collections import OrderedDict, abc as container_abcs +import operator +from itertools import chain, islice + +__all__ = ["Sequential", "ModuleList", "ModuleDict"] + +T = TypeVar("T", bound=Module) + + +def _addindent(s_: str, numSpaces: int): + s = s_.split("\n") + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s + + +class Sequential(Module): + r"""A sequential container. + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``OrderedDict`` of modules can be + passed in. The ``forward()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. + + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :class:`candle.nn.ModuleList`? A ``ModuleList`` is exactly what it + sounds like--a list for storing ``Module`` s! On the other hand, + the layers in a ``Sequential`` are connected in a cascading way. + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + @overload + def __init__(self, *args: Module) -> None: + ... + + @overload + def __init__(self, arg: "OrderedDict[str, Module]") -> None: + ... + + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.add_module(key, module) + else: + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def _get_item_by_idx(self, iterator, idx) -> T: + """Get the idx-th item of the iterator""" + size = len(self) + idx = operator.index(idx) + if not -size <= idx < size: + raise IndexError("index {} is out of range".format(idx)) + idx %= size + return next(islice(iterator, idx, None)) + + def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]: + if isinstance(idx, slice): + return self.__class__(OrderedDict(list(self._modules.items())[idx])) + else: + return self._get_item_by_idx(self._modules.values(), idx) + + def __setitem__(self, idx: int, module: Module) -> None: + key: str = self._get_item_by_idx(self._modules.keys(), idx) + return setattr(self, key, module) + + def __delitem__(self, idx: Union[slice, int]) -> None: + if isinstance(idx, slice): + for key in list(self._modules.keys())[idx]: + delattr(self, key) + else: + key = self._get_item_by_idx(self._modules.keys(), idx) + delattr(self, key) + # To preserve numbering + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __add__(self, other) -> "Sequential": + if isinstance(other, Sequential): + ret = Sequential() + for layer in self: + ret.append(layer) + for layer in other: + ret.append(layer) + return ret + else: + raise ValueError( + "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other))) + ) + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def __iadd__(self, other) -> "Sequential": + if isinstance(other, Sequential): + offset = len(self) + for i, module in enumerate(other): + self.add_module(str(i + offset), module) + return self + else: + raise ValueError( + "add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other))) + ) + + def __mul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif other <= 0: + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + combined = Sequential() + offset = 0 + for _ in range(other): + for module in self: + combined.add_module(str(offset), module) + offset += 1 + return combined + + def __rmul__(self, other: int) -> "Sequential": + return self.__mul__(other) + + def __imul__(self, other: int) -> "Sequential": + if not isinstance(other, int): + raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") + elif other <= 0: + raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + else: + len_original = len(self) + offset = len(self) + for _ in range(other - 1): + for i in range(len_original): + self.add_module(str(i + offset), self._modules[str(i)]) + offset += len_original + return self + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + # NB: We can't really type check this function as the type of input + # may change dynamically (as is tested in + # TestScript.test_sequential_intermediary_types). Cannot annotate + # with Any as TorchScript expects a more precise type + def forward(self, input): + for module in self: + input = module(input) + return input + + def append(self, module: Module) -> "Sequential": + r"""Appends a given module to the end. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def insert(self, index: int, module: Module) -> "Sequential": + if not isinstance(module, Module): + raise AssertionError("module should be of type: {}".format(Module)) + n = len(self._modules) + if not (-n <= index <= n): + raise IndexError("Index out of range: {}".format(index)) + if index < 0: + index += n + for i in range(n, index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + return self + + def extend(self, sequential) -> "Sequential": + for layer in sequential: + self.append(layer) + return self + + +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~candle.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~candle.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Iterable[Module]] = None) -> None: + super().__init__() + if modules is not None: + self += modules + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError("index {} is out of range".format(idx)) + if idx < 0: + idx += len(self) + return str(idx) + + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]: + if isinstance(idx, slice): + return self.__class__(list(self._modules.values())[idx]) + else: + return self._modules[self._get_abs_string_index(idx)] + + def __setitem__(self, idx: int, module: Module) -> None: + idx = self._get_abs_string_index(idx) + return setattr(self, str(idx), module) + + def __delitem__(self, idx: Union[int, slice]) -> None: + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + delattr(self, str(k)) + else: + delattr(self, self._get_abs_string_index(idx)) + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Module]: + return iter(self._modules.values()) + + def __iadd__(self, modules: Iterable[Module]) -> "ModuleList": + return self.extend(modules) + + def __add__(self, other: Iterable[Module]) -> "ModuleList": + combined = ModuleList() + for i, module in enumerate(chain(self, other)): + combined.add_module(str(i), module) + return combined + + def __repr__(self): + """A custom repr for ModuleList that compresses repeated module representations""" + list_of_reprs = [repr(item) for item in self] + if len(list_of_reprs) == 0: + return self._get_name() + "()" + + start_end_indices = [[0, 0]] + repeated_blocks = [list_of_reprs[0]] + for i, r in enumerate(list_of_reprs[1:], 1): + if r == repeated_blocks[-1]: + start_end_indices[-1][1] += 1 + continue + + start_end_indices.append([i, i]) + repeated_blocks.append(r) + + lines = [] + main_str = self._get_name() + "(" + for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + local_repr = f"({start_id}): {b}" # default repr + + if start_id != end_id: + n = end_id - start_id + 1 + local_repr = f"({start_id}-{end_id}): {n} x {b}" + + local_repr = _addindent(local_repr, 2) + lines.append(local_repr) + + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def insert(self, index: int, module: Module) -> None: + r"""Insert a given module before a given index in the list. + + Args: + index (int): index to insert. + module (nn.Module): module to insert + """ + for i in range(len(self._modules), index, -1): + self._modules[str(i)] = self._modules[str(i - 1)] + self._modules[str(index)] = module + + def append(self, module: Module) -> "ModuleList": + r"""Appends a given module to the end of the list. + + Args: + module (nn.Module): module to append + """ + self.add_module(str(len(self)), module) + return self + + def pop(self, key: Union[int, slice]) -> Module: + v = self[key] + del self[key] + return v + + def extend(self, modules: Iterable[Module]) -> "ModuleList": + r"""Appends modules from a Python iterable to the end of the list. + + Args: + modules (iterable): iterable of modules to append + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleList.extend should be called with an " "iterable, but got " + type(modules).__name__ + ) + offset = len(self) + for i, module in enumerate(modules): + self.add_module(str(offset + i), module) + return self + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ModuleDict(Module): + r"""Holds submodules in a dictionary. + + :class:`~candle.nn.ModuleDict` can be indexed like a regular Python dictionary, + but modules it contains are properly registered, and will be visible by all + :class:`~candle.nn.Module` methods. + + :class:`~candle.nn.ModuleDict` is an **ordered** dictionary that respects + + * the order of insertion, and + + * in :meth:`~candle.nn.ModuleDict.update`, the order of the merged + ``OrderedDict``, ``dict`` (started from Python 3.6) or another + :class:`~candle.nn.ModuleDict` (the argument to + :meth:`~candle.nn.ModuleDict.update`). + + Note that :meth:`~candle.nn.ModuleDict.update` with other unordered mapping + types (e.g., Python's plain ``dict`` before Python version 3.6) does not + preserve the order of the merged mapping. + + Args: + modules (iterable, optional): a mapping (dictionary) of (string: module) + or an iterable of key-value pairs of type (string, module) + """ + + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + def items(self) -> Iterable[Tuple[str, Module]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + def values(self) -> Iterable[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~candle.nn.ModuleDict` with the key-value pairs from a + mapping or an iterable, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~candle.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~candle.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~candle.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) + if not len(m) == 2: + raise ValueError( + "ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented |