author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-06 20:01:07 +0200
committer | GitHub <noreply@github.com> | 2023-10-06 19:01:07 +0100
commit | 904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree | 8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/tests/bindings/test_module.py
parent | b0442eff8a696d1faba10e23ba645eb11e385116 (diff)
download | candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.gz, candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.bz2, candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.zip
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality (a usage sketch follows this list)
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor` (a rough sketch follows the diff below)
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
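The bullets above summarize the new API surface. As a minimal, non-authoritative sketch of how it is used, based on the tests added in this commit (the `TinyModel` class and its attribute names are illustrative only, not part of the diff):

```python
# Illustrative sketch only: class and attribute names are hypothetical;
# the API calls mirror those exercised by the tests added below.
import candle
from candle import Tensor
from candle.nn import Module, Linear


class TinyModel(Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(10, 20)  # registers "linear.weight" and "linear.bias"
        self.scale = Tensor(1.0)      # registers a plain buffer named "scale"


model = TinyModel()
print(list(model.state_dict()))  # ['linear.weight', 'linear.bias', 'scale'] (order may differ)

# load_state_dict behaves like PyTorch's: shape mismatches and missing keys
# raise RuntimeError (see the tests below).
model.load_state_dict(
    {
        "linear.weight": candle.ones((20, 10)),
        "linear.bias": candle.zeros((20,)),
        "scale": Tensor(1.0),
    }
)
```

The tests below additionally show that a `QTensor` in the state dict is dequantized automatically when the target buffer is a plain `Tensor`, and kept quantized only when the buffer name has been added to `_quantizable_buffers`.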
Diffstat (limited to 'candle-pyo3/tests/bindings/test_module.py')
-rw-r--r-- | candle-pyo3/tests/bindings/test_module.py | 161
1 file changed, 161 insertions, 0 deletions
```diff
diff --git a/candle-pyo3/tests/bindings/test_module.py b/candle-pyo3/tests/bindings/test_module.py
new file mode 100644
index 00000000..819dae5b
--- /dev/null
+++ b/candle-pyo3/tests/bindings/test_module.py
@@ -0,0 +1,161 @@
+import candle
+from candle import Tensor, QTensor
+from candle.nn import Module, Linear
+from candle.utils import cuda_is_available
+
+import pytest
+
+
+def test_module_can_be_constructed():
+    class A(Module):
+        pass
+
+    a = A()
+    assert a is not None
+    assert len(list(a.buffers())) == 0
+
+
+def test_module_registers_tensors():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = Tensor(42.0)
+
+    a = A()
+    named_buffers = dict(a.named_buffers())
+    assert len(named_buffers) == 1
+    assert "t" in named_buffers
+
+
+def test_module_registers_submodules():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = Linear(10, 20)
+
+    a = A()
+    named_modules = dict(a.named_modules())
+    named_buffers = dict(a.named_buffers())
+    assert len(named_buffers) == 2
+    assert "linear" in named_modules
+    assert "linear.weight" in named_buffers
+    assert "linear.bias" in named_buffers
+
+
+def test_module_can_dump_statedict():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = Linear(10, 20)
+            self.t = Tensor(42.0)
+
+    a = A()
+    state_dict = a.state_dict()
+    assert hasattr(state_dict, "_metadata")
+    assert "t" in state_dict
+    assert "linear.weight" in state_dict
+    assert "linear.bias" in state_dict
+    assert len(state_dict) == 3
+
+
+def test_module_can_load_statedict():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = Linear(10, 20)
+            self.t = Tensor(42.0)
+
+    statedict = {
+        "linear.weight": candle.ones((20, 10)),
+        "linear.bias": candle.zeros((20,)),
+        "t": Tensor(42.0),
+    }
+    a = A()
+    a.load_state_dict(statedict)
+
+
+def test_module_throws_on_shape_missmatch():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = Tensor(42.0)
+
+    statedict = {
+        "t": candle.ones((20,)),
+    }
+    a = A()
+    with pytest.raises(RuntimeError) as excinfo:
+        a.load_state_dict(statedict)
+    assert "size mismatch" in str(excinfo.value)
+
+
+def test_module_throws_on_missing_key():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = Tensor(42.0)
+
+    statedict = {
+        "not_t": Tensor(42.0),
+    }
+
+    a = A()
+    with pytest.raises(RuntimeError) as excinfo:
+        a.load_state_dict(statedict)
+    assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value)
+
+
+def test_module_can_load_quantized_tensors():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+            self._quantizable_buffers.add("t")
+
+    statedict = {
+        "t": candle.ones((16, 256)).quantize("q4_0"),
+    }
+    a = A()
+    a.load_state_dict(statedict)
+    assert isinstance(a.t, QTensor)
+    assert a.t.ggml_dtype == "Q4_0"
+
+
+def test_module_dequantizes_tensors_automaticaly():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+
+    statedict = {
+        "t": candle.ones((16, 256)).quantize("q4_0"),
+    }
+    a = A()
+    a.load_state_dict(statedict)
+    assert isinstance(a.t, Tensor)
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_module_can_be_moved_to_cuda():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+
+    a = A()
+    a.cuda()
+    assert a.t.device == "cuda"
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_module_can_be_moved_from_cuda_to_cpu():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+
+    a = A()
+    a.cuda()
+    assert a.t.device == "cuda"
+    a.cpu()
+    assert a.t.device == "cpu"
```
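The tests above do not exercise the `pytorch`-like `Tensor` slicing mentioned in the commit message. Assuming it mirrors PyTorch indexing semantics (an assumption, not confirmed by this diff), usage would look roughly like this:

```python
# Assumed behaviour only: the exact indexing forms supported are not shown in this diff.
import candle

t = candle.randn((4, 8))
row = t[0]          # first row
block = t[:, 0:4]   # first four columns of every row
print(row.shape, block.shape)
```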