summaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-06-21 21:37:54 +0100
committerGitHub <noreply@github.com>2023-06-21 21:37:54 +0100
commitdb35b310504ab97044b2c3826de72f9bccf86415 (patch)
tree710596156a4c026d4dd2ba804fab79b6cdafae3b /src/lib.rs
parent7c317f9611c263f10d661b44151d3655a2fa3b90 (diff)
parent7c46de9584fd4315b84d3bc4c28cf1b2bad7785d (diff)
downloadcandle-db35b310504ab97044b2c3826de72f9bccf86415.tar.gz
candle-db35b310504ab97044b2c3826de72f9bccf86415.tar.bz2
candle-db35b310504ab97044b2c3826de72f9bccf86415.zip
Merge pull request #3 from LaurentMazare/cuda
Add Cuda support.
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs11
1 files changed, 10 insertions, 1 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 175d36ad..3bae1a7e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,9 @@
mod cpu_backend;
+#[cfg(feature = "cuda")]
+mod cuda_backend;
mod device;
mod dtype;
+mod dummy_cuda_backend;
mod error;
mod op;
mod shape;
@@ -9,10 +12,16 @@ mod strided_index;
mod tensor;
pub use cpu_backend::CpuStorage;
-pub use device::Device;
+pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
+
+#[cfg(feature = "cuda")]
+pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
+
+#[cfg(not(feature = "cuda"))]
+pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};