summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/ci_cuda.yaml87
-rw-r--r--Cargo.toml2
-rw-r--r--README.md4
-rw-r--r--candle-core/Cargo.toml2
-rw-r--r--candle-core/src/cpu_backend.rs66
-rw-r--r--candle-core/src/cpu_kernels.rs28
-rw-r--r--candle-core/src/dtype.rs8
-rw-r--r--candle-core/src/error.rs14
-rw-r--r--candle-core/src/lib.rs1
-rw-r--r--candle-core/src/npy.rs18
-rw-r--r--candle-core/src/safetensors.rs20
-rw-r--r--candle-core/tests/tensor_tests.rs11
-rw-r--r--candle-examples/examples/bert/main.rs14
-rw-r--r--candle-kernels/src/compatibility.cuh18
14 files changed, 227 insertions, 66 deletions
diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml
new file mode 100644
index 00000000..7c6cfa9b
--- /dev/null
+++ b/.github/workflows/ci_cuda.yaml
@@ -0,0 +1,87 @@
+name: CI / cuda
+
+on:
+ workflow_dispatch:
+ pull_request:
+
+jobs:
+ start-runner:
+ name: Start self-hosted EC2 runner
+ runs-on: ubuntu-latest
+ env:
+ AWS_REGION: us-east-1
+ EC2_AMI_ID: ami-03cfed9ea28f4b002
+ EC2_INSTANCE_TYPE: g5.xlarge
+ EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
+ EC2_SECURITY_GROUP: sg-030175c435ac141d6
+ outputs:
+ label: ${{ steps.start-ec2-runner.outputs.label }}
+ ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
+ steps:
+ - name: Configure AWS credentials
+ uses: aws-actions/configure-aws-credentials@v1
+ with:
+ aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ aws-region: ${{ env.AWS_REGION }}
+ - name: Start EC2 runner
+ id: start-ec2-runner
+ uses: philschmid/philschmid-ec2-github-runner@main
+ with:
+ mode: start
+ github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+ ec2-image-id: ${{ env.EC2_AMI_ID }}
+ ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
+ subnet-id: ${{ env.EC2_SUBNET_ID }}
+ security-group-id: ${{ env.EC2_SECURITY_GROUP }}
+ aws-resource-tags: > # optional, requires additional permissions
+ [
+ {"Key": "Name", "Value": "ec2-tgi-github-runner"},
+ {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
+ ]
+
+ test-cuda:
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+ needs: start-runner # required to start the main job when the runner is ready
+ runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+ permissions:
+ contents: write
+ packages: write
+ # This is used to complete the identity challenge
+ # with sigstore/fulcio when running outside of PRs.
+ id-token: write
+ security-events: write
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+ - name: Install Rust Stable
+ run: curl https://sh.rustup.rs -sSf | sh -s -- -y
+ - uses: Swatinem/rust-cache@v2
+ - run: apt update -y && apt install libssl-dev -y
+ - name: Test (cuda)
+ run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
+ stop-runner:
+ name: Stop self-hosted EC2 runner
+ needs:
+ - start-runner
+ - test-cuda
+ runs-on: ubuntu-latest
+ env:
+ AWS_REGION: us-east-1
+ if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+ steps:
+ - name: Configure AWS credentials
+ uses: aws-actions/configure-aws-credentials@v1
+ with:
+ aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ aws-region: ${{ env.AWS_REGION }}
+ - name: Stop EC2 runner
+ uses: philschmid/philschmid-ec2-github-runner@main
+ with:
+ mode: stop
+ github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+ label: ${{ needs.start-runner.outputs.label }}
+ ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
diff --git a/Cargo.toml b/Cargo.toml
index 4bc0058b..c0d87680 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,6 +31,7 @@ clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
+ggblas = "0.1.2"
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@@ -41,6 +42,7 @@ memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
+rand_distr = "0.4.3"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
diff --git a/README.md b/README.md
index 3b71927b..67ab5678 100644
--- a/README.md
+++ b/README.md
@@ -48,8 +48,8 @@ For llama2, run the following command to retrieve the weight files and start a
test server:
```bash
cd candle-wasm-examples/llama2-c
-wget https://karpathy.ai/llama2c/model.bin
-wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
+wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
+wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --public-url /candle-llama2/ --port 8081
```
And then head over to
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index af77a0e0..bf57a91c 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -15,6 +15,7 @@ byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
+ggblas = { workspace = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
@@ -22,6 +23,7 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
+rand_distr = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 238a9a69..250e2721 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
- fn f<T: 'static + num_traits::NumAssign + Copy>(
- &self,
- inp: &[T],
- inp_l: &Layout,
- k: &[T],
- k_l: &Layout,
- ) -> Result<Vec<T>> {
- // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
@@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
for b_idx in 0..p.b_size {
- let inp_idx = b_idx * inp_s0;
- let dst_idx = b_idx * p.c_out * l_out;
+ for src_l in 0..p.l_in {
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
+ inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
+ }
+ }
+ }
+ for offset in 0..p.k_size {
for dst_c_idx in 0..p.c_out {
- let dst_idx = dst_idx + dst_c_idx * l_out;
- for dst_l in 0..l_out {
- let dst_idx = dst_idx + dst_l;
- let mut d = T::zero();
- for offset in 0..p.k_size {
+ let dst_idx = dst_c_idx * l_out;
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ let dst_idx = dst_idx + b_idx * p.c_out * l_out;
+ for dst_l in 0..l_out {
+ let dst_idx = dst_idx + dst_l;
let src_l = (p.stride * dst_l + offset)
.saturating_sub(p.padding)
.min(p.l_in - 1);
- for src_c_idx in 0..p.c_in {
- let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
- let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
- d += inp[inp_idx] * k[k_idx]
- }
+ let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
+ assert!(inp_cont.len() >= p.c_in);
+ assert!(k_cont.len() >= p.c_in);
+ let mut d = T::zero();
+ unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
+ dst[dst_idx] += d
}
- dst[dst_idx] = d
}
}
}
@@ -2070,35 +2073,36 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = bf16::from_f64(std);
- let mean = bf16::from_f64(mean);
+ let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = f16::from_f64(std);
- let mean = f16::from_f64(mean);
+ let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let std = std as f32;
- let mean = mean as f32;
+ let normal =
+ rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
+ let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
new file mode 100644
index 00000000..187dc16b
--- /dev/null
+++ b/candle-core/src/cpu_kernels.rs
@@ -0,0 +1,28 @@
+pub trait VecDot: num_traits::NumAssign + Copy {
+ /// Dot-product of two vectors.
+ ///
+ /// # Safety
+ ///
+ /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ *res = Self::zero();
+ for i in 0..len {
+ *res += *lhs.add(i) * *rhs.add(i)
+ }
+ }
+}
+
+impl VecDot for f32 {
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
+ }
+}
+
+impl VecDot for f64 {}
+impl VecDot for half::bf16 {}
+impl VecDot for half::f16 {}
+impl VecDot for u8 {}
+impl VecDot for u32 {}
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 92929748..5d24b08f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -54,7 +54,13 @@ impl DType {
}
pub trait WithDType:
- Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static
+ Sized
+ + Copy
+ + num_traits::NumAssign
+ + std::cmp::PartialOrd
+ + std::fmt::Display
+ + 'static
+ + crate::cpu_kernels::VecDot
{
const DTYPE: DType;
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 35a33032..c18b43c6 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -185,6 +185,13 @@ pub enum Error {
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
+ /// Adding path information to an error.
+ #[error("path: {path:?} {inner}")]
+ WithPath {
+ inner: Box<Self>,
+ path: std::path::PathBuf,
+ },
+
#[error("{inner}\n{backtrace}")]
WithBacktrace {
inner: Box<Self>,
@@ -214,6 +221,13 @@ impl Error {
},
}
}
+
+ pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
+ Self::WithPath {
+ inner: Box::new(self),
+ path: p.as_ref().to_path_buf(),
+ }
+ }
}
#[macro_export]
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 016d3806..aba88135 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -40,6 +40,7 @@ pub mod backprop;
mod conv;
mod convert;
pub mod cpu_backend;
+pub mod cpu_kernels;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
mod device;
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 6302cf71..e17ba02a 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -307,39 +307,39 @@ impl Tensor {
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
- let elem_count = self.elem_count();
+ let vs = self.flatten_all()?;
match self.dtype() {
DType::BF16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
+ let vs = vs.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
- let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
+ let vs = vs.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
- for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
+ for v in vs.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
- for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
+ for v in vs.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
- for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
+ for v in vs.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
DType::U8 => {
- let data = self.reshape(elem_count)?.to_vec1::<u8>()?;
- f.write_all(&data)?;
+ let vs = vs.to_vec1::<u8>()?;
+ f.write_all(&vs)?;
}
}
Ok(())
@@ -373,7 +373,7 @@ pub struct NpzTensors {
index_per_name: HashMap<String, usize>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
- // re-create a zip reader each time.
+ // re-create a zip reader for each tensor.
}
impl NpzTensors {
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 132fb914..914e5101 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -257,7 +257,10 @@ pub fn save<P: AsRef<Path>>(tensors: &HashMap<&str, Tensor>, filename: P) -> Res
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
-pub struct MmapedFile(memmap2::Mmap);
+pub struct MmapedFile {
+ path: std::path::PathBuf,
+ inner: memmap2::Mmap,
+}
impl MmapedFile {
/// Creates a wrapper around a memory mapped file from which you can retrieve
@@ -267,13 +270,20 @@ impl MmapedFile {
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
- let file = std::fs::File::open(p)?;
- let mmap = memmap2::MmapOptions::new().map(&file)?;
- Ok(Self(mmap))
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let inner = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok(Self {
+ inner,
+ path: p.to_path_buf(),
+ })
}
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
- let st = safetensors::SafeTensors::deserialize(&self.0)?;
+ let st = safetensors::SafeTensors::deserialize(&self.inner)
+ .map_err(|e| Error::from(e).with_path(&self.path))?;
Ok(st)
}
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 599c2665..0b77f1a5 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -869,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+
+// There was originally a bug on the CPU implementation for randn
+// https://github.com/huggingface/candle/issues/381
+#[test]
+fn randn_hasneg() -> Result<()> {
+ let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
+ if t.iter().all(|&v| v >= 0.) {
+ candle_core::bail!("all values in tensors are non-negative")
+ }
+ Ok(())
+}
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 79c78968..574755ed 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -39,6 +39,10 @@ struct Args {
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
+
+ /// L2 normalization for embeddings.
+ #[arg(long, default_value = "true")]
+ normalize_embeddings: bool,
}
impl Args {
@@ -164,7 +168,13 @@ fn main() -> Result<()> {
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+ let embeddings = if args.normalize_embeddings {
+ normalize_l2(&embeddings)?
+ } else {
+ embeddings
+ };
println!("pooled embeddings {:?}", embeddings.shape());
+
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
@@ -184,3 +194,7 @@ fn main() -> Result<()> {
}
Ok(())
}
+
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+ Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}
diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh
index 2df8e921..5a22f4bc 100644
--- a/candle-kernels/src/compatibility.cuh
+++ b/candle-kernels/src/compatibility.cuh
@@ -6,24 +6,6 @@
// FIXME: the minimum compute capabilities are just guesses since the table is not specific enough
-// #if __CUDA_ARCH__ < 600
-// __device__ __forceinline__ __half __hmax(__half a, __half b) {
-// return __float2half(fmaxf(__half2float(a), __half2float(b)));
-// }
-// __device__ __forceinline__ __half __hmin(__half a, __half b) {
-// return __float2half(fminf(__half2float(a), __half2float(b)));
-// }
-// #endif
-
-#if __CUDA_ARCH__ < 800
-__device__ __forceinline__ __half __hmax_nan(__half a, __half b) {
- // return __hisnan(a) ? a : (__hisnan(b) ? b : __hmax(a, b));
-}
-__device__ __forceinline__ __half __hmin_nan(__half a, __half b) {
- // return __hisnan(a) ? a : (__hisnan(b) ? b : __hmin(a, b));
-}
-#endif
-
#if __CUDA_ARCH__ < 600
// Copied from https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomic-functions
__device__ double atomicAdd(double* address, double val) {