author     Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>  2023-10-06 20:01:07 +0200
committer  GitHub <noreply@github.com>  2023-10-06 19:01:07 +0100
commit     904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree       8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5  /candle-pyo3/tests/bindings/test_module.py
parent     b0442eff8a696d1faba10e23ba645eb11e385116 (diff)
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unit tests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit test for `nn.Linear`
* Refactor slicing implementation
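For orientation, the sketch below shows how the Python-side Module API exercised by the added tests is meant to be used. It is illustrative only: it relies solely on calls that appear in the test file below (`Module`, `Linear`, `Tensor`, `state_dict`/`load_state_dict`, `.quantize("q4_0")`, `.cuda()`/`.cpu()`), and the class name `TinyNet` is a hypothetical example, not something introduced by this commit.

# Sketch only (not part of this commit): a hypothetical `TinyNet` built on the
# Module / state_dict API that the tests below exercise.
import candle
from candle import Tensor
from candle.nn import Module, Linear
from candle.utils import cuda_is_available


class TinyNet(Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(256, 16)  # registers "linear.weight" and "linear.bias"
        self.scale = Tensor(42.0)      # plain Tensor attributes become named buffers


net = TinyNet()

# Dump and reload parameters, PyTorch-style.
weights = net.state_dict()  # contains "linear.weight", "linear.bias", "scale"
net.load_state_dict(weights)

# A state dict may also contain quantized tensors; where a regular Tensor is
# expected, they are dequantized automatically on load.
net.load_state_dict(
    {
        "linear.weight": candle.ones((16, 256)).quantize("q4_0"),
        "linear.bias": candle.zeros((16,)),
        "scale": Tensor(42.0),
    }
)

# Modules can be moved between devices as a whole.
if cuda_is_available():
    net.cuda()
    net.cpu()

Loading follows PyTorch semantics: missing keys and shape mismatches raise RuntimeError, as the tests below verify.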
Diffstat (limited to 'candle-pyo3/tests/bindings/test_module.py')
-rw-r--r--  candle-pyo3/tests/bindings/test_module.py  161
1 file changed, 161 insertions(+), 0 deletions(-)
diff --git a/candle-pyo3/tests/bindings/test_module.py b/candle-pyo3/tests/bindings/test_module.py
new file mode 100644
index 00000000..819dae5b
--- /dev/null
+++ b/candle-pyo3/tests/bindings/test_module.py
@@ -0,0 +1,161 @@
+import candle
+from candle import Tensor, QTensor
+from candle.nn import Module, Linear
+from candle.utils import cuda_is_available
+
+import pytest
+
+
+def test_module_can_be_constructed():
+    class A(Module):
+        pass
+
+    a = A()
+    assert a is not None
+    assert len(list(a.buffers())) == 0
+
+
+def test_module_registers_tensors():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = Tensor(42.0)
+
+    a = A()
+    named_buffers = dict(a.named_buffers())
+    assert len(named_buffers) == 1
+    assert "t" in named_buffers
+
+
+def test_module_registers_submodules():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = Linear(10, 20)
+
+    a = A()
+    named_modules = dict(a.named_modules())
+    named_buffers = dict(a.named_buffers())
+    assert len(named_buffers) == 2
+    assert "linear" in named_modules
+    assert "linear.weight" in named_buffers
+    assert "linear.bias" in named_buffers
+
+
+def test_module_can_dump_statedict():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = Linear(10, 20)
+            self.t = Tensor(42.0)
+
+    a = A()
+    state_dict = a.state_dict()
+    assert hasattr(state_dict, "_metadata")
+    assert "t" in state_dict
+    assert "linear.weight" in state_dict
+    assert "linear.bias" in state_dict
+    assert len(state_dict) == 3
+
+
+def test_module_can_load_statedict():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = Linear(10, 20)
+            self.t = Tensor(42.0)
+
+    statedict = {
+        "linear.weight": candle.ones((20, 10)),
+        "linear.bias": candle.zeros((20,)),
+        "t": Tensor(42.0),
+    }
+    a = A()
+    a.load_state_dict(statedict)
+
+
+def test_module_throws_on_shape_mismatch():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = Tensor(42.0)
+
+    statedict = {
+        "t": candle.ones((20,)),
+    }
+    a = A()
+    with pytest.raises(RuntimeError) as excinfo:
+        a.load_state_dict(statedict)
+    assert "size mismatch" in str(excinfo.value)
+
+
+def test_module_throws_on_missing_key():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = Tensor(42.0)
+
+    statedict = {
+        "not_t": Tensor(42.0),
+    }
+
+    a = A()
+    with pytest.raises(RuntimeError) as excinfo:
+        a.load_state_dict(statedict)
+    assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value)
+
+
+def test_module_can_load_quantized_tensors():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+            self._quantizable_buffers.add("t")
+
+    statedict = {
+        "t": candle.ones((16, 256)).quantize("q4_0"),
+    }
+    a = A()
+    a.load_state_dict(statedict)
+    assert isinstance(a.t, QTensor)
+    assert a.t.ggml_dtype == "Q4_0"
+
+
+def test_module_dequantizes_tensors_automatically():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+
+    statedict = {
+        "t": candle.ones((16, 256)).quantize("q4_0"),
+    }
+    a = A()
+    a.load_state_dict(statedict)
+    assert isinstance(a.t, Tensor)
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_module_can_be_moved_to_cuda():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+
+    a = A()
+    a.cuda()
+    assert a.t.device == "cuda"
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_module_can_be_moved_from_cuda_to_cpu():
+    class A(Module):
+        def __init__(self):
+            super().__init__()
+            self.t = candle.randn((16, 256))
+
+    a = A()
+    a.cuda()
+    assert a.t.device == "cuda"
+    a.cpu()
+    assert a.t.device == "cpu"