summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-23 21:10:59 +0200
committerGitHub <noreply@github.com>2023-10-23 20:10:59 +0100
commiteae94a451b3c6b3ef5975639e98dfbc587a2ac27 (patch)
tree3e3241174fdbcdba145e388d2242b810363ee333 /candle-pyo3/py_src
parent86e1803191e2ed44c57ad01807b29a886c0263bb (diff)
downloadcandle-eae94a451b3c6b3ef5975639e98dfbc587a2ac27.tar.gz
candle-eae94a451b3c6b3ef5975639e98dfbc587a2ac27.tar.bz2
candle-eae94a451b3c6b3ef5975639e98dfbc587a2ac27.zip
PyO3: Add `mkl` support (#1159)
* Add `mkl` support * Set `mkl` path on linux
Diffstat (limited to 'candle-pyo3/py_src')
-rw-r--r--candle-pyo3/py_src/candle/__init__.py48
1 files changed, 36 insertions, 12 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
index dc97b775..38718a46 100644
--- a/candle-pyo3/py_src/candle/__init__.py
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -3,27 +3,51 @@ import logging
try:
from .candle import *
except ImportError as e:
- # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
- logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+ # If we are in development mode, or we did not bundle the DLLs, we try to locate them here
+ # PyO3 wont give us any infomration about what DLLs are missing, so we can only try to load the DLLs and re-import the module
+ logging.warning("DLLs were not bundled with this package. Trying to locate them...")
import os
import platform
- # Try to locate CUDA_PATH environment variable
- cuda_path = os.environ.get("CUDA_PATH", None)
- if cuda_path:
- logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
- if platform.system() == "Windows":
- cuda_path = os.path.join(cuda_path, "bin")
+ def locate_cuda_dlls():
+ logging.warning("Locating CUDA DLLs...")
+ # Try to locate CUDA_PATH environment variable
+ cuda_path = os.environ.get("CUDA_PATH", None)
+ if cuda_path:
+ logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+ if platform.system() == "Windows":
+ cuda_path = os.path.join(cuda_path, "bin")
+ else:
+ cuda_path = os.path.join(cuda_path, "lib64")
+
+ logging.warning(f"Adding {cuda_path} to DLL search path...")
+ os.add_dll_directory(cuda_path)
+ else:
+ logging.warning("CUDA_PATH environment variable not found!")
+
+ def locate_mkl_dlls():
+ # Try to locate ONEAPI_ROOT environment variable
+ oneapi_root = os.environ.get("ONEAPI_ROOT", None)
+ if oneapi_root:
+ if platform.system() == "Windows":
+ mkl_path = os.path.join(
+ oneapi_root, "compiler", "latest", "windows", "redist", "intel64_win", "compiler"
+ )
+ else:
+ mkl_path = os.path.join(oneapi_root, "mkl", "latest", "lib", "intel64")
+
+ logging.warning(f"Adding {mkl_path} to DLL search path...")
+ os.add_dll_directory(mkl_path)
else:
- cuda_path = os.path.join(cuda_path, "lib64")
+ logging.warning("ONEAPI_ROOT environment variable not found!")
- logging.warning(f"Adding {cuda_path} to DLL search path...")
- os.add_dll_directory(cuda_path)
+ locate_cuda_dlls()
+ locate_mkl_dlls()
try:
from .candle import *
except ImportError as inner_e:
- raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
+ raise ImportError("Could not locate DLLs. Please check the documentation for more information.")
__doc__ = candle.__doc__
if hasattr(candle, "__all__"):