diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-06 10:17:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-06 10:17:43 +0100 |
commit | 93cfe5642f473889d1df62ccb8f1740f77523dd3 (patch) | |
tree | 41952f254185d73a10d56d54b7b96b9f08ae3a59 /candle-pyo3/test.py | |
parent | 88bd3b604af0151b0da792980be482d396867e42 (diff) | |
download | candle-93cfe5642f473889d1df62ccb8f1740f77523dd3.tar.gz candle-93cfe5642f473889d1df62ccb8f1740f77523dd3.tar.bz2 candle-93cfe5642f473889d1df62ccb8f1740f77523dd3.zip |
Pyo3 dtype (#327)
* Better handling of dtypes in pyo3.
* More pyo3 dtype.
Diffstat (limited to 'candle-pyo3/test.py')
-rw-r--r-- | candle-pyo3/test.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 160a099d..1711cdad 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,3 +1,18 @@ +import os +import sys + +# The "import candle" statement below works if there is a "candle.so" file in sys.path. +# Here we check for shared libraries that can be used in the build directory. +BUILD_DIR = "./target/release-with-debug" +so_file = BUILD_DIR + "/candle.so" +if os.path.islink(so_file): os.remove(so_file) +for lib_file in ["libcandle.dylib", "libcandle.so"]: + lib_file_ = BUILD_DIR + "/" + lib_file + if os.path.isfile(lib_file_): + os.symlink(lib_file, so_file) + sys.path.insert(0, BUILD_DIR) + break + import candle t = candle.Tensor(42.0) @@ -12,7 +27,9 @@ print(t+t) t = t.reshape([2, 4]) print(t.matmul(t.t())) +print(t.to_dtype(candle.u8)) print(t.to_dtype("u8")) t = candle.randn((5, 3)) print(t) +print(t.dtype) |