diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-06 20:01:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-06 19:01:07 +0100 |
commit | 904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch) | |
tree | 8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/__init__.py | |
parent | b0442eff8a696d1faba10e23ba645eb11e385116 (diff) | |
download | candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.gz candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.bz2 candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.zip |
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch` like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
Diffstat (limited to 'candle-pyo3/py_src/candle/__init__.py')
-rw-r--r-- | candle-pyo3/py_src/candle/__init__.py | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py index 951609cc..dc97b775 100644 --- a/candle-pyo3/py_src/candle/__init__.py +++ b/candle-pyo3/py_src/candle/__init__.py @@ -1,5 +1,30 @@ -from .candle import * +import logging + +try: + from .candle import * +except ImportError as e: + # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here + logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...") + import os + import platform + + # Try to locate CUDA_PATH environment variable + cuda_path = os.environ.get("CUDA_PATH", None) + if cuda_path: + logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}") + if platform.system() == "Windows": + cuda_path = os.path.join(cuda_path, "bin") + else: + cuda_path = os.path.join(cuda_path, "lib64") + + logging.warning(f"Adding {cuda_path} to DLL search path...") + os.add_dll_directory(cuda_path) + + try: + from .candle import * + except ImportError as inner_e: + raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.") __doc__ = candle.__doc__ if hasattr(candle, "__all__"): - __all__ = candle.__all__
\ No newline at end of file + __all__ = candle.__all__ |