summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/__init__.py
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-06 20:01:07 +0200
committerGitHub <noreply@github.com>2023-10-06 19:01:07 +0100
commit904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/__init__.py
parentb0442eff8a696d1faba10e23ba645eb11e385116 (diff)
downloadcandle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.gz
candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.bz2
candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.zip
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations * Add `state_dict` and `load_state_dict` functionality * Move modules around and create `candle.nn.Linear` * Add `nn.Embedding` and `nn.LayerNorm` * Add BERT implementation * Batch q-matmul * Automatically dequantize `QTensors` if a `Tensor` is expected * Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality * Unittests for `Module`, `Tensor` and `candle.utils` * Add `pytorch` like slicing to `Tensor` * Cleanup and BERT fixes * `black` formatting + unit-test for `nn.Linear` * Refactor slicing implementation
Diffstat (limited to 'candle-pyo3/py_src/candle/__init__.py')
-rw-r--r--candle-pyo3/py_src/candle/__init__.py29
1 files changed, 27 insertions, 2 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
index 951609cc..dc97b775 100644
--- a/candle-pyo3/py_src/candle/__init__.py
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -1,5 +1,30 @@
-from .candle import *
+import logging
+
+try:
+ from .candle import *
+except ImportError as e:
+ # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
+ logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+ import os
+ import platform
+
+ # Try to locate CUDA_PATH environment variable
+ cuda_path = os.environ.get("CUDA_PATH", None)
+ if cuda_path:
+ logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+ if platform.system() == "Windows":
+ cuda_path = os.path.join(cuda_path, "bin")
+ else:
+ cuda_path = os.path.join(cuda_path, "lib64")
+
+ logging.warning(f"Adding {cuda_path} to DLL search path...")
+ os.add_dll_directory(cuda_path)
+
+ try:
+ from .candle import *
+ except ImportError as inner_e:
+ raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
__doc__ = candle.__doc__
if hasattr(candle, "__all__"):
- __all__ = candle.__all__ \ No newline at end of file
+ __all__ = candle.__all__