summaryrefslogtreecommitdiff
path: root/candle-pyo3/tests/native/test_tensor.py
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-17 11:57:12 +0200
committerGitHub <noreply@github.com>2023-10-17 10:57:12 +0100
commitb355ab4e2e52b077e71aac46c286fbce033f36d6 (patch)
tree27f32bbcb5e0aa16ed14790bd3f5b37ae26ddf26 /candle-pyo3/tests/native/test_tensor.py
parent2fe24ac5b172526c25b07674b38075f8da20815f (diff)
downloadcandle-b355ab4e2e52b077e71aac46c286fbce033f36d6.tar.gz
candle-b355ab4e2e52b077e71aac46c286fbce033f36d6.tar.bz2
candle-b355ab4e2e52b077e71aac46c286fbce033f36d6.zip
Always broadcast magic methods (#1101)
Diffstat (limited to 'candle-pyo3/tests/native/test_tensor.py')
-rw-r--r--candle-pyo3/tests/native/test_tensor.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index 1f5b74f6..225a7469 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -1,5 +1,6 @@
import candle
from candle import Tensor
+import pytest
def test_tensor_can_be_constructed():
@@ -72,3 +73,75 @@ def test_tensor_can_be_scliced_3d():
assert t[:, 0, 0].values() == [1, 9]
assert t[..., 0].values() == [[1, 5], [9, 13]]
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
+
+
+def test_tensor_can_be_added():
+ t = Tensor(42.0)
+ result = t + t
+ assert result.values() == 84.0
+ result = t + 2.0
+ assert result.values() == 44.0
+ a = candle.rand((3, 1, 4))
+ b = candle.rand((2, 1))
+ c_native = a.broadcast_add(b)
+ c = a + b
+ assert c.shape == (3, 2, 4)
+ assert c.values() == c_native.values()
+ with pytest.raises(ValueError):
+ d = candle.rand((3, 4, 5))
+ e = candle.rand((4, 6))
+ f = d + e
+
+
+def test_tensor_can_be_subtracted():
+ t = Tensor(42.0)
+ result = t - t
+ assert result.values() == 0
+ result = t - 2.0
+ assert result.values() == 40.0
+ a = candle.rand((3, 1, 4))
+ b = candle.rand((2, 1))
+ c_native = a.broadcast_sub(b)
+ c = a - b
+ assert c.shape == (3, 2, 4)
+ assert c.values() == c_native.values()
+ with pytest.raises(ValueError):
+ d = candle.rand((3, 4, 5))
+ e = candle.rand((4, 6))
+ f = d - e
+
+
+def test_tensor_can_be_multiplied():
+ t = Tensor(42.0)
+ result = t * t
+ assert result.values() == 1764.0
+ result = t * 2.0
+ assert result.values() == 84.0
+ a = candle.rand((3, 1, 4))
+ b = candle.rand((2, 1))
+ c_native = a.broadcast_mul(b)
+ c = a * b
+ assert c.shape == (3, 2, 4)
+ assert c.values() == c_native.values()
+ with pytest.raises(ValueError):
+ d = candle.rand((3, 4, 5))
+ e = candle.rand((4, 6))
+ f = d * e
+
+
+def test_tensor_can_be_divided():
+ t = Tensor(42.0)
+ result = t / t
+ assert result.values() == 1.0
+ result = t / 2.0
+ assert result.values() == 21.0
+ a = candle.rand((3, 1, 4))
+ b = candle.rand((2, 1))
+ c_native = a.broadcast_div(b)
+ c = a / b
+ assert c.shape == (3, 2, 4)
+ assert c.values() == c_native.values()
+ with pytest.raises(ValueError):
+ d = candle.rand((3, 4, 5))
+ e = candle.rand((4, 6))
+ f = d / e