diff options
Diffstat (limited to 'candle-pyo3/test.py')
-rw-r--r-- | candle-pyo3/test.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 7f24b49d..c78ffc41 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,4 +1,5 @@ import candle +from candle import Tensor, QTensor t = candle.Tensor(42.0) print(t) @@ -9,7 +10,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) print(t+t) -t = t.reshape([2, 4]) +t:Tensor = t.reshape([2, 4]) print(t.matmul(t.t())) print(t.to_dtype(candle.u8)) @@ -20,7 +21,7 @@ print(t) print(t.dtype) t = candle.randn((16, 256)) -quant_t = t.quantize("q6k") -dequant_t = quant_t.dequantize() -diff2 = (t - dequant_t).sqr() +quant_t:QTensor = t.quantize("q6k") +dequant_t:Tensor = quant_t.dequantize() +diff2:Tensor = (t - dequant_t).sqr() print(diff2.mean_all()) |