author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-23 10:42:19 +0100
committer  GitHub <noreply@github.com>  2023-08-23 10:42:19 +0100
commit     9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8 (patch)
tree       4c7fef2cdb78409ca30e14981c783d717cd49f97 /candle-core
parent     3743bed2d7bc02069770902e4a956aeabaef5453 (diff)
Add support for i64 (#563)
* Add the i64 dtype.
* Adapt the cuda kernels.
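The change threads a new `DType::I64` variant through every layer: the dtype enum and `WithDType` impls, the CPU storage and kernel maps, the CUDA storage and kernel dispatch, display, and the npy/safetensors I/O paths. In user terms it enables something like the following (a hypothetical sketch relying only on the existing `Tensor::new` / `to_dtype` entry points, not code from the patch):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // i64 data can now back a tensor directly.
    let t = Tensor::new(&[1i64, 2, 3], &Device::Cpu)?;
    assert_eq!(t.dtype(), DType::I64);
    // Casts to and from the other dtypes go through the new match arms below.
    let f = t.to_dtype(DType::F32)?;
    assert_eq!(f.to_vec1::<f32>()?, vec![1.0, 2.0, 3.0]);
    Ok(())
}
```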
Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/src/convert.rs       |   6
-rw-r--r--  candle-core/src/cpu/kernels.rs   |   1
-rw-r--r--  candle-core/src/cpu_backend.rs   | 101
-rw-r--r--  candle-core/src/cuda_backend.rs  |  93
-rw-r--r--  candle-core/src/display.rs       |   7
-rw-r--r--  candle-core/src/dtype.rs         |  14
-rw-r--r--  candle-core/src/npy.rs           |   8
-rw-r--r--  candle-core/src/op.rs            |  24
-rw-r--r--  candle-core/src/safetensors.rs   |  23
9 files changed, 242 insertions(+), 35 deletions(-)
diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs
index 744982fc..5ea5612a 100644
--- a/candle-core/src/convert.rs
+++ b/candle-core/src/convert.rs
@@ -92,6 +92,7 @@ from_tensor!(f64);
 from_tensor!(f32);
 from_tensor!(f16);
 from_tensor!(bf16);
+from_tensor!(i64);
 from_tensor!(u32);
 from_tensor!(u8);
@@ -129,6 +130,11 @@ impl Tensor {
                     f.write_u32::<LittleEndian>(v)?
                 }
             }
+            DType::I64 => {
+                for v in vs.to_vec1::<i64>()? {
+                    f.write_i64::<LittleEndian>(v)?
+                }
+            }
             DType::U8 => {
                 let vs = vs.to_vec1::<u8>()?;
                 f.write_all(&vs)?;
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 1184f8f3..cdbf5a21 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -53,6 +53,7 @@ impl VecOps for f64 {}
 impl VecOps for half::bf16 {}
 impl VecOps for u8 {}
 impl VecOps for u32 {}
+impl VecOps for i64 {}
 
 #[inline(always)]
 pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8d18a343..d7e0ec84 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -9,6 +9,7 @@ use half::{bf16, f16};
 pub enum CpuStorage {
     U8(Vec<u8>),
     U32(Vec<u32>),
+    I64(Vec<i64>),
     BF16(Vec<bf16>),
     F16(Vec<f16>),
     F32(Vec<f32>),
@@ -25,6 +26,7 @@ pub trait Map1 {
         match vs {
             CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
             CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
+            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
             CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
             CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
             CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
@@ -45,6 +47,7 @@ pub trait Map1Any {
         match vs {
             CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
             CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
+            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
             CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
             CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
             CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
@@ -68,6 +71,7 @@ pub trait Map2 {
         match (v1, v2) {
             (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
             (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
             (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
             (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
             (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
@@ -96,6 +100,7 @@ pub trait Map2U8 {
         match (v1, v2) {
             (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
             (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
             (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
             (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
             (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
@@ -1527,6 +1532,7 @@ impl BackendStorage for CpuStorage {
         match self {
             Self::U8(_) => DType::U8,
             Self::U32(_) => DType::U32,
+            Self::I64(_) => DType::I64,
             Self::BF16(_) => DType::BF16,
             Self::F16(_) => DType::F16,
             Self::F32(_) => DType::F32,
@@ -1545,6 +1551,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
             }
+            (Self::I64(storage), DType::BF16) => {
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
+                Ok(Self::BF16(data))
+            }
             (Self::BF16(storage), DType::BF16) => {
                 let data = unary_map(storage, layout, |v| v);
                 Ok(Self::BF16(data))
             }
@@ -1569,6 +1579,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
             }
+            (Self::I64(storage), DType::F16) => {
+                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
+                Ok(Self::F16(data))
+            }
             (Self::BF16(storage), DType::F16) => {
                 let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                 Ok(Self::F16(data))
             }
@@ -1593,6 +1607,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
+            (Self::I64(storage), DType::F32) => {
+                let data = unary_map(storage, layout, |v| v as f32);
+                Ok(Self::F32(data))
+            }
             (Self::BF16(storage), DType::F32) => {
                 let data = unary_map(storage, layout, |v| v.to_f32());
                 Ok(Self::F32(data))
             }
@@ -1629,18 +1647,26 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as u8);
                 Ok(Self::U8(data))
             }
-            (Self::U8(storage), DType::U32) => {
-                let data = unary_map(storage, layout, |v| v as u32);
-                Ok(Self::U32(data))
-            }
             (Self::U32(storage), DType::U8) => {
                 let data = unary_map(storage, layout, |v| v as u8);
                 Ok(Self::U8(data))
             }
+            (Self::I64(storage), DType::U8) => {
+                let data = unary_map(storage, layout, |v| v as u8);
+                Ok(Self::U8(data))
+            }
+            (Self::U8(storage), DType::U32) => {
+                let data = unary_map(storage, layout, |v| v as u32);
+                Ok(Self::U32(data))
+            }
             (Self::U32(storage), DType::U32) => {
                 let data = unary_map(storage, layout, |v| v);
                 Ok(Self::U32(data))
             }
+            (Self::I64(storage), DType::U32) => {
+                let data = unary_map(storage, layout, |v| v as u32);
+                Ok(Self::U32(data))
+            }
             (Self::BF16(storage), DType::U32) => {
                 let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                 Ok(Self::U32(data))
@@ -1657,6 +1683,34 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
+            (Self::U8(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v as i64);
+                Ok(Self::I64(data))
+            }
+            (Self::U32(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v as i64);
+                Ok(Self::I64(data))
+            }
+            (Self::I64(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v);
+                Ok(Self::I64(data))
+            }
+            (Self::BF16(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
+                Ok(Self::I64(data))
+            }
+            (Self::F16(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
+                Ok(Self::I64(data))
+            }
+            (Self::F32(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v as i64);
+                Ok(Self::I64(data))
+            }
+            (Self::F64(storage), DType::I64) => {
+                let data = unary_map(storage, layout, |v| v as i64);
+                Ok(Self::I64(data))
+            }
             (Self::U8(storage), DType::F64) => {
                 let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
@@ -1665,6 +1719,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
+            (Self::I64(storage), DType::F64) => {
+                let data = unary_map(storage, layout, |v| v as f64);
+                Ok(Self::F64(data))
+            }
             (Self::BF16(storage), DType::F64) => {
                 let data = unary_map(storage, layout, |v| v.to_f64());
                 Ok(Self::F64(data))
@@ -1791,6 +1849,7 @@ impl BackendStorage for CpuStorage {
             }
             Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
             Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
+            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
         }
     }
@@ -1840,6 +1899,10 @@ impl BackendStorage for CpuStorage {
                 let data = unary_map(storage, layout, B::u32);
                 Ok(Self::U32(data))
             }
+            Self::I64(storage) => {
+                let data = unary_map(storage, layout, B::i64);
+                Ok(Self::I64(data))
+            }
         }
     }
@@ -1890,6 +1953,14 @@ impl BackendStorage for CpuStorage {
                 };
                 Ok(Self::U32(data))
             }
+            (Self::I64(lhs), Self::I64(rhs)) => {
+                let data = if B::I64_VEC {
+                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
+                } else {
+                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
+                };
+                Ok(Self::I64(data))
+            }
             (Self::U8(lhs), Self::U8(rhs)) => {
                 let data = if B::U8_VEC {
                     binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
@@ -1914,6 +1985,7 @@ impl BackendStorage for CpuStorage {
         match (self, dst) {
             (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -1942,6 +2014,7 @@ impl BackendStorage for CpuStorage {
         match self {
             Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
             Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
         }
     }
@@ -1970,6 +2043,7 @@ impl BackendStorage for CpuStorage {
         match ids {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
         }
     }
@@ -1978,6 +2052,7 @@ impl BackendStorage for CpuStorage {
         match ids {
             Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
+            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
         }
     }
@@ -1994,6 +2069,7 @@ impl BackendStorage for CpuStorage {
         match ids {
             Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
         }
     }
@@ -2022,6 +2098,13 @@ impl BackendStorage for CpuStorage {
                 };
                 IndexAdd { ids, dim }.map(self, l, src, src_l)
             }
+            Self::I64(ids) => {
+                let ids = match ids_l.contiguous_offsets() {
+                    Some((a, b)) => &ids[a..b],
+                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                };
+                IndexAdd { ids, dim }.map(self, l, src, src_l)
+            }
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
         }
     }
@@ -2074,7 +2157,9 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
+            DType::U8 | DType::U32 | DType::I64 => {
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
+            }
             DType::BF16 => {
                 let mut data = Vec::with_capacity(elem_count);
                 let uniform =
@@ -2118,7 +2203,9 @@ impl BackendDevice for CpuDevice {
         let elem_count = shape.elem_count();
         let mut rng = rand::thread_rng();
         match dtype {
-            DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
+            DType::U8 | DType::U32 | DType::I64 => {
+                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+            }
             DType::BF16 => {
                 let mut data = Vec::with_capacity(elem_count);
                 let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
@@ -2162,6 +2249,7 @@ impl BackendDevice for CpuDevice {
         let storage = match dtype {
             DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
             DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
+            DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
             DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
             DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
             DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
@@ -2175,6 +2263,7 @@ impl BackendDevice for CpuDevice {
         let storage = match dtype {
             DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
             DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
+            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
             DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
             DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
             DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
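All the new CPU cast arms above funnel through `unary_map` with plain Rust `as` conversions, so narrowing casts such as i64 to u8 wrap instead of erroring. A small illustration of that behavior (hypothetical snippet, not part of the patch):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::new(&[300i64, -1], &Device::Cpu)?;
    // The `(Self::I64(storage), DType::U8)` arm applies `|v| v as u8` element-wise,
    // so 300 becomes 44 and -1 becomes 255, exactly like a Rust `as` cast.
    let u = t.to_dtype(DType::U8)?;
    assert_eq!(u.to_vec1::<u8>()?, vec![44u8, 255]);
    Ok(())
}
```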
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 44e65a85..9809bb4a 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -139,6 +139,14 @@ impl CudaDevice {
                 unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I64 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
+                let params = (&data, v as i64, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::I64(data)
+            }
             DType::BF16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@@ -236,6 +244,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.alloc_zeros::<u32>(elem_count).w()?;
                 CudaStorageSlice::U32(data)
             }
+            DType::I64 => {
+                let data = self.alloc_zeros::<i64>(elem_count).w()?;
+                CudaStorageSlice::I64(data)
+            }
             DType::BF16 => {
                 let data = self.alloc_zeros::<bf16>(elem_count).w()?;
                 CudaStorageSlice::BF16(data)
@@ -265,11 +277,13 @@ impl BackendDevice for CudaDevice {
         let slice = match dtype {
             // TODO: Add support for F16 and BF16 though this is likely to require some upstream
             // cudarc changes.
-            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
-                dtype,
-                op: "rand_uniform",
-            })
-            .w()?,
+            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+                Err(CudaError::UnsupportedDtype {
+                    dtype,
+                    op: "rand_uniform",
+                })
+                .w()?
+            }
             DType::F32 => {
                 let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                 curand.0.fill_with_uniform(&mut data).w()?;
@@ -297,11 +311,13 @@ impl BackendDevice for CudaDevice {
         let elem_count = shape.elem_count();
         let curand = self.curand.lock().unwrap();
         let slice = match dtype {
-            DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
-                dtype,
-                op: "rand_normal",
-            })
-            .w()?,
+            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+                Err(CudaError::UnsupportedDtype {
+                    dtype,
+                    op: "rand_normal",
+                })
+                .w()?
+            }
             DType::F32 => {
                 let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
                 curand
@@ -336,6 +352,10 @@ impl BackendDevice for CudaDevice {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::U32(data)
            }
+            CpuStorage::I64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I64(data)
+            }
             CpuStorage::BF16(storage) => {
                 let data = self.htod_sync_copy(storage).w()?;
                 CudaStorageSlice::BF16(data)
@@ -364,6 +384,7 @@ impl BackendDevice for CudaDevice {
 enum CudaStorageSlice {
     U8(CudaSlice<u8>),
     U32(CudaSlice<u32>),
+    I64(CudaSlice<i64>),
     BF16(CudaSlice<bf16>),
     F16(CudaSlice<f16>),
     F32(CudaSlice<f32>),
@@ -383,6 +404,7 @@ trait Map1 {
         let out = match s {
             S::U8(s) => S::U8(self.f(s, d, l)?),
             S::U32(s) => S::U32(self.f(s, d, l)?),
+            S::I64(s) => S::I64(self.f(s, d, l)?),
             S::BF16(s) => S::BF16(self.f(s, d, l)?),
             S::F16(s) => S::F16(self.f(s, d, l)?),
             S::F32(s) => S::F32(self.f(s, d, l)?),
@@ -406,6 +428,7 @@ trait Map2 {
         let out = match (s1, s2) {
             (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
             (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
+            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
             (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
             (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
             (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
@@ -437,6 +460,7 @@ trait Map2InPlace {
         match (dst, src) {
             (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
             (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
             (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
             (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
             (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
@@ -459,6 +483,7 @@ trait Map1Any {
         let out = match s {
             S::U8(s) => self.f(s, d, l, S::U8)?,
             S::U32(s) => self.f(s, d, l, S::U32)?,
+            S::I64(s) => self.f(s, d, l, S::I64)?,
             S::BF16(s) => self.f(s, d, l, S::BF16)?,
             S::F16(s) => self.f(s, d, l, S::F16)?,
             S::F32(s) => self.f(s, d, l, S::F32)?,
@@ -482,6 +507,7 @@ trait Map2Any {
         let out = match (s1, s2) {
             (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
             (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
+            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
             (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
             (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
             (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
@@ -714,6 +740,9 @@ impl<'a> Map1 for IndexSelect<'a> {
             CudaStorageSlice::U8(slice) => {
                 ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
             }
+            CudaStorageSlice::I64(slice) => {
+                ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
             _ => Err(CudaError::UnexpectedDType {
                 msg: "index_select ids should be u8 or u32",
                 expected: DType::U32,
                 got: ids.dtype(),
@@ -773,8 +802,11 @@ impl<'a> Map1 for Gather<'a> {
                 ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
             }
             CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I64(slice) => {
+                ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
+            }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "gather ids should be u8 or u32",
+                msg: "gather ids should be u8/u32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -820,9 +852,10 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
         };
         let (name, ids) = match &ids.slice {
             CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index-add ids should be u8 or u32",
+                msg: "index-add ids should be u8/u32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -867,9 +900,10 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
         };
         let (name, ids) = match &ids.slice {
             CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
             CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "scatter-add ids should be u8 or u32",
+                msg: "scatter-add ids should be u8/u32/i64",
                 expected: DType::U32,
                 got: ids.dtype(),
             })?,
@@ -1080,8 +1114,12 @@ impl<'a> Map2 for WhereCond<'a> {
                 let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
                 (ptr, "where_u32")
             }
+            CudaStorageSlice::I64(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_i64")
+            }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "where conditions should be u8 or u32",
+                msg: "where conditions should be u8/u32/i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
@@ -1225,6 +1263,7 @@ macro_rules! cuda_dtype {
 }
 cuda_dtype!(u8, U8);
 cuda_dtype!(u32, U32);
+cuda_dtype!(i64, I64);
 cuda_dtype!(f16, F16);
 cuda_dtype!(bf16, BF16);
 cuda_dtype!(f32, F32);
@@ -1338,6 +1377,7 @@ impl BackendStorage for CudaStorage {
         match self.slice {
             CudaStorageSlice::U8(_) => DType::U8,
             CudaStorageSlice::U32(_) => DType::U32,
+            CudaStorageSlice::I64(_) => DType::I64,
             CudaStorageSlice::BF16(_) => DType::BF16,
             CudaStorageSlice::F16(_) => DType::F16,
             CudaStorageSlice::F32(_) => DType::F32,
@@ -1363,6 +1403,7 @@ impl BackendStorage for CudaStorage {
         let inp = match &self.slice {
             CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
             CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
@@ -1385,6 +1426,12 @@ impl BackendStorage for CudaStorage {
                 unsafe { func.launch(cfg, params) }.w()?;
                 CudaStorageSlice::U32(out)
             }
+            DType::I64 => {
+                let out = unsafe { dev.alloc::<i64>(el) }.w()?;
+                let params = (el, dims.len(), &ds, *inp, &out);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::I64(out)
+            }
             DType::BF16 => {
                 let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
                 let params = (el, dims.len(), &ds, *inp, &out);
@@ -1469,6 +1516,11 @@ impl BackendStorage for CudaStorage {
                 let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
                 Ok(CpuStorage::U32(cpu_storage))
             }
+            CudaStorageSlice::I64(slice) => {
+                let dev = slice.device();
+                let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+                Ok(CpuStorage::I64(cpu_storage))
+            }
             CudaStorageSlice::BF16(slice) => {
                 let dev = slice.device();
                 let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@@ -1588,6 +1640,7 @@ impl BackendStorage for CudaStorage {
                 S::F64(out)
             }
             (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
+            (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
             _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
         };
         Ok(Self { slice, device })
@@ -1802,6 +1855,18 @@ impl BackendStorage for CudaStorage {
                     unsafe { func.launch(cfg, params) }.w()?
                 }
             }
+            (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
+                    dev.dtod_copy(&src, &mut dst).w()?
+                } else {
+                    let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
+                    // SAFETY: Set later by running the kernel.
+                    let params = (el_count, dims.len(), &ds, &src, &mut dst);
+                    // SAFETY: ffi.
+                    unsafe { func.launch(cfg, params) }.w()?
+                }
+            }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
                     dev.dtod_copy(&src, &mut dst).w()?
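With the `is_i64`, `gather_i64`, `ia_i64`, `sa_i64`, and `where_i64` kernels dispatched above, i64 tensors are accepted wherever u8/u32 index tensors were. A hedged sketch (shown on the CPU device; a `Device::new_cuda(0)` device would take the CUDA paths above):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
    // Index tensors may now be i64, which is what most exported models use.
    let ids = Tensor::new(&[0i64, 2], &dev)?;
    let rows = t.index_select(&ids, 0)?;
    assert_eq!(rows.to_vec2::<f32>()?, vec![vec![1., 2.], vec![5., 6.]]);
    Ok(())
}
```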
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index f5c77e4b..da2adf37 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -49,6 +49,7 @@ impl std::fmt::Debug for Tensor {
         match self.dtype() {
             DType::U8 => self.fmt_dt::<u8>(f),
             DType::U32 => self.fmt_dt::<u32>(f),
+            DType::I64 => self.fmt_dt::<i64>(f),
             DType::BF16 => self.fmt_dt::<bf16>(f),
             DType::F16 => self.fmt_dt::<f16>(f),
             DType::F32 => self.fmt_dt::<f32>(f),
@@ -431,6 +432,12 @@ impl std::fmt::Display for Tensor {
                 tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
                 writeln!(f)?;
             }
+            DType::I64 => {
+                let tf: IntFormatter<i64> = IntFormatter::new();
+                let max_w = tf.max_width(&to_display);
+                tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+                writeln!(f)?;
+            }
             DType::BF16 => {
                 if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
                     let max_w = tf.max_width(&to_display);
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 7f04a653..91922b11 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -5,6 +5,7 @@ use crate::{CpuStorage, Error, Result};
 pub enum DType {
     U8,
     U32,
+    I64,
     BF16,
     F16,
     F32,
@@ -20,6 +21,7 @@ impl std::str::FromStr for DType {
         match s {
             "u8" => Ok(Self::U8),
             "u32" => Ok(Self::U32),
+            "i64" => Ok(Self::I64),
             "bf16" => Ok(Self::BF16),
             "f16" => Ok(Self::F16),
             "f32" => Ok(Self::F32),
@@ -34,6 +36,7 @@ impl DType {
         match self {
             Self::U8 => "u8",
             Self::U32 => "u32",
+            Self::I64 => "i64",
             Self::BF16 => "bf16",
             Self::F16 => "f16",
             Self::F32 => "f32",
@@ -45,6 +48,7 @@ impl DType {
         match self {
             Self::U8 => 1,
             Self::U32 => 4,
+            Self::I64 => 8,
             Self::BF16 => 2,
             Self::F16 => 2,
             Self::F32 => 4,
@@ -125,6 +129,7 @@ use half::{bf16, f16};
 
 with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
 with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
+with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
 with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
 with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
 with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
@@ -135,6 +140,15 @@ pub trait IntDType: WithDType {
     fn as_usize(&self) -> usize;
 }
 
+impl IntDType for i64 {
+    fn is_true(&self) -> bool {
+        *self != 0
+    }
+    fn as_usize(&self) -> usize {
+        *self as usize
+    }
+}
+
 impl IntDType for u32 {
     fn is_true(&self) -> bool {
         *self != 0
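The `dtype.rs` hunk registers the `"i64"` name, the 8-byte width, and an `IntDType` impl so i64 can drive the integer-only ops. A quick check of those additions (hypothetical snippet, assuming `as_str` and `size_in_bytes` are public like the rest of the file):

```rust
use candle_core::DType;
use std::str::FromStr;

fn main() {
    // `FromStr` now accepts "i64", and `size_in_bytes` reports the 8-byte width.
    let dt = DType::from_str("i64").unwrap();
    assert_eq!(dt.as_str(), "i64");
    assert_eq!(dt.size_in_bytes(), 8);
}
```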
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 62c62f9e..f3a75965 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -85,6 +85,7 @@ impl Header {
             DType::F16 => "f2",
             DType::F32 => "f4",
             DType::F64 => "f8",
+            DType::I64 => "i8",
             DType::U32 => "u4",
             DType::U8 => "u1",
         };
@@ -160,7 +161,7 @@ impl Header {
             "f" | "f4" => DType::F32,
             "d" | "f8" => DType::F64,
             // "i" | "i4" => DType::S32,
-            // "q" | "i8" => DType::S64,
+            "q" | "i8" => DType::I64,
             // "h" | "i2" => DType::S16,
             // "b" | "i1" => DType::S8,
             "B" | "u1" => DType::U8,
@@ -233,6 +234,11 @@ impl Tensor {
                 reader.read_u32_into::<LittleEndian>(&mut data_t)?;
                 Tensor::from_vec(data_t, shape, &Device::Cpu)
             }
+            DType::I64 => {
+                let mut data_t = vec![0i64; elem_count];
+                reader.read_i64_into::<LittleEndian>(&mut data_t)?;
+                Tensor::from_vec(data_t, shape, &Device::Cpu)
+            }
         }
     }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b8d4e34f..b0528494 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -251,6 +251,7 @@ pub trait UnaryOpT {
     fn f64(v1: f64) -> f64;
     fn u8(v1: u8) -> u8;
     fn u32(v1: u32) -> u32;
+    fn i64(v1: i64) -> i64;
 
     // There is no very good way to represent optional function in traits so we go for an explicit
     // boolean flag to mark the function as existing.
@@ -274,6 +275,7 @@ pub trait BinaryOpT {
     fn f64(v1: f64, v2: f64) -> f64;
     fn u8(v1: u8, v2: u8) -> u8;
     fn u32(v1: u32, v2: u32) -> u32;
+    fn i64(v1: i64, v2: i64) -> i64;
 
     const BF16_VEC: bool = false;
     fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@@ -287,6 +289,8 @@ pub trait BinaryOpT {
     fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
     const U32_VEC: bool = false;
     fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
+    const I64_VEC: bool = false;
+    fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
 }
 
 pub(crate) struct Add;
@@ -337,6 +341,10 @@ macro_rules! bin_op {
             fn u32(v1: u32, v2: u32) -> u32 {
                 $e(v1, v2)
             }
+            #[inline(always)]
+            fn i64(v1: i64, v2: i64) -> i64 {
+                $e(v1, v2)
+            }
 
             #[cfg(feature = "mkl")]
             const F32_VEC: bool = true;
@@ -420,6 +428,10 @@ macro_rules! unary_op {
             fn u32(_: u32) -> u32 {
                 todo!("no unary function for u32")
             }
+            #[inline(always)]
+            fn i64(_: i64) -> i64 {
+                todo!("no unary function for i64")
+            }
         }
     };
@@ -452,6 +464,10 @@ macro_rules! unary_op {
             fn u32(_: u32) -> u32 {
                 todo!("no unary function for u32")
             }
+            #[inline(always)]
+            fn i64(_: i64) -> i64 {
+                todo!("no unary function for i64")
+            }
 
             #[cfg(feature = "mkl")]
             const F32_VEC: bool = true;
@@ -543,6 +559,10 @@ impl UnaryOpT for Gelu {
     fn u32(_: u32) -> u32 {
         0
     }
+    #[inline(always)]
+    fn i64(_: i64) -> i64 {
+        0
+    }
     const KERNEL: &'static str = "ugelu";
 
     #[cfg(feature = "mkl")]
@@ -592,6 +612,10 @@ impl UnaryOpT for Relu {
     fn u32(v: u32) -> u32 {
         v
     }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v
+    }
 }
 
 /// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index ec0dc766..f37bb8ef 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -10,6 +10,7 @@ impl From<DType> for st::Dtype {
         match value {
             DType::U8 => st::Dtype::U8,
             DType::U32 => st::Dtype::U32,
+            DType::I64 => st::Dtype::I64,
             DType::BF16 => st::Dtype::BF16,
             DType::F16 => st::Dtype::F16,
             DType::F32 => st::Dtype::F32,
@@ -24,6 +25,7 @@ impl TryFrom<st::Dtype> for DType {
         match value {
             st::Dtype::U8 => Ok(DType::U8),
             st::Dtype::U32 => Ok(DType::U32),
+            st::Dtype::I64 => Ok(DType::I64),
             st::Dtype::BF16 => Ok(DType::BF16),
             st::Dtype::F16 => Ok(DType::F16),
             st::Dtype::F32 => Ok(DType::F32),
@@ -189,6 +191,7 @@ impl Tensor {
         match dtype {
             DType::U8 => convert_slice::<u8>(data, shape, device),
             DType::U32 => convert_slice::<u32>(data, shape, device),
+            DType::I64 => convert_slice::<i64>(data, shape, device),
             DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
             DType::F16 => convert_slice::<half::f16>(data, shape, device),
             DType::F32 => convert_slice::<f32>(data, shape, device),
@@ -205,24 +208,15 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
             convert_with_cast_::<u16, u32, _>(view, device, conv)
         }
         st::Dtype::U32 => convert_::<u32>(view, device),
+        st::Dtype::I32 => {
+            let conv = |x| Ok(i64::from(x));
+            convert_with_cast_::<i32, i64, _>(view, device, conv)
+        }
+        st::Dtype::I64 => convert_::<i64>(view, device),
         st::Dtype::BF16 => convert_::<half::bf16>(view, device),
         st::Dtype::F16 => convert_::<half::f16>(view, device),
         st::Dtype::F32 => convert_::<f32>(view, device),
         st::Dtype::F64 => convert_::<f64>(view, device),
-        st::Dtype::I32 => {
-            let conv = |x| {
-                u32::try_from(x)
-                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
-            };
-            convert_with_cast_::<i32, u32, _>(view, device, conv)
-        }
-        st::Dtype::I64 => {
-            let conv = |x| {
-                u32::try_from(x)
-                    .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
-            };
-            convert_with_cast_::<i64, u32, _>(view, device, conv)
-        }
         dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
     }
 }
@@ -233,6 +227,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     match tensor.dtype() {
         DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
         DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
+        DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
         DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
         DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
         DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
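Finally, the safetensors conversion maps `st::Dtype::I64` directly to `DType::I64` (and widens I32 files to i64) instead of the old lossy down-cast to u32. The resulting round trip looks roughly like this (a sketch assuming the crate's `safetensors::save` / `safetensors::load` helpers):

```rust
use candle_core::{safetensors, DType, Device, Tensor};
use std::collections::HashMap;

fn main() -> candle_core::Result<()> {
    let ids = Tensor::new(&[1i64, -2, 3], &Device::Cpu)?;
    let tensors = HashMap::from([("ids".to_string(), ids)]);
    safetensors::save(&tensors, "ids.safetensors")?;
    // The dtype survives the round trip; before this patch it came back as u32.
    let loaded = safetensors::load("ids.safetensors", &Device::Cpu)?;
    assert_eq!(loaded["ids"].dtype(), DType::I64);
    Ok(())
}
```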