diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-09-16 18:23:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-16 17:23:38 +0100 |
commit | 8658df348527cabcd722bfe2e9e48aba3c7f8e96 (patch) | |
tree | 5e7a04ad7650b7872074fee2197eb7db7e248bbb /candle-pyo3/test.py | |
parent | 7cafca835a4bb9a21f3c8111e2f61b7a6b1270fd (diff) | |
download | candle-8658df348527cabcd722bfe2e9e48aba3c7f8e96.tar.gz candle-8658df348527cabcd722bfe2e9e48aba3c7f8e96.tar.bz2 candle-8658df348527cabcd722bfe2e9e48aba3c7f8e96.zip |
Generate `*.pyi` stubs for PyO3 wrapper (#870)
* Begin to generate typehints.
* generate correct stubs
* Correctly include stubs
* Add comments and typhints to static functions
* ensure candle-pyo3 directory
* Make `llama.rope.freq_base` optional
* `fmt`
Diffstat (limited to 'candle-pyo3/test.py')
-rw-r--r-- | candle-pyo3/test.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 7f24b49d..c78ffc41 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,4 +1,5 @@ import candle +from candle import Tensor, QTensor t = candle.Tensor(42.0) print(t) @@ -9,7 +10,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) print(t+t) -t = t.reshape([2, 4]) +t:Tensor = t.reshape([2, 4]) print(t.matmul(t.t())) print(t.to_dtype(candle.u8)) @@ -20,7 +21,7 @@ print(t) print(t.dtype) t = candle.randn((16, 256)) -quant_t = t.quantize("q6k") -dequant_t = quant_t.dequantize() -diff2 = (t - dequant_t).sqr() +quant_t:QTensor = t.quantize("q6k") +dequant_t:Tensor = quant_t.dequantize() +diff2:Tensor = (t - dequant_t).sqr() print(diff2.mean_all()) |