summaryrefslogtreecommitdiff
path: root/candle-pyo3/test.py
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-06 10:17:43 +0100
committerGitHub <noreply@github.com>2023-08-06 10:17:43 +0100
commit93cfe5642f473889d1df62ccb8f1740f77523dd3 (patch)
tree41952f254185d73a10d56d54b7b96b9f08ae3a59 /candle-pyo3/test.py
parent88bd3b604af0151b0da792980be482d396867e42 (diff)
downloadcandle-93cfe5642f473889d1df62ccb8f1740f77523dd3.tar.gz
candle-93cfe5642f473889d1df62ccb8f1740f77523dd3.tar.bz2
candle-93cfe5642f473889d1df62ccb8f1740f77523dd3.zip
Pyo3 dtype (#327)
* Better handling of dtypes in pyo3. * More pyo3 dtype.
Diffstat (limited to 'candle-pyo3/test.py')
-rw-r--r--candle-pyo3/test.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index 160a099d..1711cdad 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -1,3 +1,18 @@
+import os
+import sys
+
+# The "import candle" statement below works if there is a "candle.so" file in sys.path.
+# Here we check for shared libraries that can be used in the build directory.
+BUILD_DIR = "./target/release-with-debug"
+so_file = BUILD_DIR + "/candle.so"
+if os.path.islink(so_file): os.remove(so_file)
+for lib_file in ["libcandle.dylib", "libcandle.so"]:
+ lib_file_ = BUILD_DIR + "/" + lib_file
+ if os.path.isfile(lib_file_):
+ os.symlink(lib_file, so_file)
+ sys.path.insert(0, BUILD_DIR)
+ break
+
import candle
t = candle.Tensor(42.0)
@@ -12,7 +27,9 @@ print(t+t)
t = t.reshape([2, 4])
print(t.matmul(t.t()))
+print(t.to_dtype(candle.u8))
print(t.to_dtype("u8"))
t = candle.randn((5, 3))
print(t)
+print(t.dtype)