diff options
Diffstat (limited to 'candle-pyo3/tests/native/test_tensor.py')
-rw-r--r-- | candle-pyo3/tests/native/test_tensor.py | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index 225a7469..659423e0 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -1,5 +1,6 @@ import candle from candle import Tensor +from candle.utils import cuda_is_available import pytest @@ -75,6 +76,70 @@ def test_tensor_can_be_scliced_3d(): assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] +def test_tensor_can_be_cast_via_to(): + t = Tensor(42.0) + assert str(t.dtype) == str(candle.f32) + t_new_args = t.to(candle.f64) + assert str(t_new_args.dtype) == str(candle.f64) + t_new_kwargs = t.to(dtype=candle.f64) + assert str(t_new_kwargs.dtype) == str(candle.f64) + pytest.raises(TypeError, lambda: t.to("not a dtype")) + pytest.raises(TypeError, lambda: t.to(dtype="not a dtype")) + pytest.raises(TypeError, lambda: t.to(candle.f64, "not a dtype")) + pytest.raises(TypeError, lambda: t.to()) + pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64)) + pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16)) + + other = Tensor(42.0).to(candle.f64) + t_new_other_args = t.to(other) + assert str(t_new_other_args.dtype) == str(candle.f64) + t_new_other_kwargs = t.to(other=other) + assert str(t_new_other_kwargs.dtype) == str(candle.f64) + + +@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available") +def test_tensor_can_be_moved_via_to(): + t = Tensor(42.0) + assert t.device == "cpu" + t_new_args = t.to("cuda") + assert t_new_args.device == "cuda" + t_new_kwargs = t.to(device="cuda") + assert t_new_kwargs.device == "cuda" + pytest.raises(TypeError, lambda: t.to("not a device")) + pytest.raises(TypeError, lambda: t.to(device="not a device")) + pytest.raises(TypeError, lambda: t.to("cuda", "not a device")) + pytest.raises(TypeError, lambda: t.to()) + pytest.raises(ValueError, lambda: t.to("cuda", device="cpu")) + pytest.raises(ValueError, lambda: t.to("cuda", "cuda")) + + other = Tensor(42.0).to("cuda") + t_new_other_args = t.to(other) + assert t_new_other_args.device == "cuda" + t_new_other_kwargs = t.to(other=other) + assert t_new_other_kwargs.device == "cuda" + + +@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available") +def test_tensor_can_be_moved_and_cast_via_to(): + t = Tensor(42.0) + assert t.device == "cpu" + assert str(t.dtype) == str(candle.f32) + t_new_args = t.to("cuda", candle.f64) + assert t_new_args.device == "cuda" + assert str(t_new_args.dtype) == str(candle.f64) + t_new_kwargs = t.to(device="cuda", dtype=candle.f64) + assert t_new_kwargs.device == "cuda" + assert str(t_new_kwargs.dtype) == str(candle.f64) + + other = Tensor(42.0).to("cuda").to(candle.f64) + t_new_other_args = t.to(other) + assert t_new_other_args.device == "cuda" + assert str(t_new_other_args.dtype) == str(candle.f64) + t_new_other_kwargs = t.to(other=other) + assert t_new_other_kwargs.device == "cuda" + assert str(t_new_other_kwargs.dtype) == str(candle.f64) + + def test_tensor_can_be_added(): t = Tensor(42.0) result = t + t |