summaryrefslogtreecommitdiff
path: root/candle-pyo3/tests/native/test_tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/tests/native/test_tensor.py')
-rw-r--r--candle-pyo3/tests/native/test_tensor.py65
1 files changed, 65 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index 225a7469..659423e0 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -1,5 +1,6 @@
import candle
from candle import Tensor
+from candle.utils import cuda_is_available
import pytest
@@ -75,6 +76,70 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
+def test_tensor_can_be_cast_via_to():
+ t = Tensor(42.0)
+ assert str(t.dtype) == str(candle.f32)
+ t_new_args = t.to(candle.f64)
+ assert str(t_new_args.dtype) == str(candle.f64)
+ t_new_kwargs = t.to(dtype=candle.f64)
+ assert str(t_new_kwargs.dtype) == str(candle.f64)
+ pytest.raises(TypeError, lambda: t.to("not a dtype"))
+ pytest.raises(TypeError, lambda: t.to(dtype="not a dtype"))
+ pytest.raises(TypeError, lambda: t.to(candle.f64, "not a dtype"))
+ pytest.raises(TypeError, lambda: t.to())
+ pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64))
+ pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16))
+
+ other = Tensor(42.0).to(candle.f64)
+ t_new_other_args = t.to(other)
+ assert str(t_new_other_args.dtype) == str(candle.f64)
+ t_new_other_kwargs = t.to(other=other)
+ assert str(t_new_other_kwargs.dtype) == str(candle.f64)
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_tensor_can_be_moved_via_to():
+ t = Tensor(42.0)
+ assert t.device == "cpu"
+ t_new_args = t.to("cuda")
+ assert t_new_args.device == "cuda"
+ t_new_kwargs = t.to(device="cuda")
+ assert t_new_kwargs.device == "cuda"
+ pytest.raises(TypeError, lambda: t.to("not a device"))
+ pytest.raises(TypeError, lambda: t.to(device="not a device"))
+ pytest.raises(TypeError, lambda: t.to("cuda", "not a device"))
+ pytest.raises(TypeError, lambda: t.to())
+ pytest.raises(ValueError, lambda: t.to("cuda", device="cpu"))
+ pytest.raises(ValueError, lambda: t.to("cuda", "cuda"))
+
+ other = Tensor(42.0).to("cuda")
+ t_new_other_args = t.to(other)
+ assert t_new_other_args.device == "cuda"
+ t_new_other_kwargs = t.to(other=other)
+ assert t_new_other_kwargs.device == "cuda"
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_tensor_can_be_moved_and_cast_via_to():
+ t = Tensor(42.0)
+ assert t.device == "cpu"
+ assert str(t.dtype) == str(candle.f32)
+ t_new_args = t.to("cuda", candle.f64)
+ assert t_new_args.device == "cuda"
+ assert str(t_new_args.dtype) == str(candle.f64)
+ t_new_kwargs = t.to(device="cuda", dtype=candle.f64)
+ assert t_new_kwargs.device == "cuda"
+ assert str(t_new_kwargs.dtype) == str(candle.f64)
+
+ other = Tensor(42.0).to("cuda").to(candle.f64)
+ t_new_other_args = t.to(other)
+ assert t_new_other_args.device == "cuda"
+ assert str(t_new_other_args.dtype) == str(candle.f64)
+ t_new_other_kwargs = t.to(other=other)
+ assert t_new_other_kwargs.device == "cuda"
+ assert str(t_new_other_kwargs.dtype) == str(candle.f64)
+
+
def test_tensor_can_be_added():
t = Tensor(42.0)
result = t + t