author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-23 10:42:19 +0100
committer GitHub <noreply@github.com>                2023-08-23 10:42:19 +0100
commit    9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8 (patch)
tree      4c7fef2cdb78409ca30e14981c783d717cd49f97 /candle-core
parent    3743bed2d7bc02069770902e4a956aeabaef5453 (diff)
Add support for i64 (#563)
* Add the i64 dtype.
* Adapt the cuda kernels.
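
For context, a minimal sketch of what this change enables, written against candle's public Tensor API (Tensor::new, to_dtype, index_select); this is an illustration, not part of the patch:

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // i64 tensors can now be built directly from host data.
        let ids = Tensor::new(&[0i64, 2, 1], &Device::Cpu)?;
        assert_eq!(ids.dtype(), DType::I64);
        // Casting to/from other dtypes goes through the new cpu_backend arms.
        let roundtrip = ids.to_dtype(DType::F32)?.to_dtype(DType::I64)?;
        assert_eq!(roundtrip.to_vec1::<i64>()?, vec![0, 2, 1]);
        // i64 is also accepted as an index dtype, e.g. for index_select.
        let table = Tensor::new(&[10f32, 20., 30.], &Device::Cpu)?;
        let picked = table.index_select(&ids, 0)?;
        assert_eq!(picked.to_vec1::<f32>()?, vec![10., 30., 20.]);
        Ok(())
    }
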
Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/src/convert.rs         6
-rw-r--r--  candle-core/src/cpu/kernels.rs     1
-rw-r--r--  candle-core/src/cpu_backend.rs   101
-rw-r--r--  candle-core/src/cuda_backend.rs   93
-rw-r--r--  candle-core/src/display.rs         7
-rw-r--r--  candle-core/src/dtype.rs          14
-rw-r--r--  candle-core/src/npy.rs             8
-rw-r--r--  candle-core/src/op.rs             24
-rw-r--r--  candle-core/src/safetensors.rs    23
9 files changed, 242 insertions(+), 35 deletions(-)
diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs
index 744982fc..5ea5612a 100644
--- a/candle-core/src/convert.rs
+++ b/candle-core/src/convert.rs
@@ -92,6 +92,7 @@ from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
+from_tensor!(i64);
from_tensor!(u32);
from_tensor!(u8);
@@ -129,6 +130,11 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
+ DType::I64 => {
+ for v in vs.to_vec1::<i64>()? {
+ f.write_i64::<LittleEndian>(v)?
+ }
+ }
DType::U8 => {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 1184f8f3..cdbf5a21 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -53,6 +53,7 @@ impl VecOps for f64 {}
impl VecOps for half::bf16 {}
impl VecOps for u8 {}
impl VecOps for u32 {}
+impl VecOps for i64 {}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8d18a343..d7e0ec84 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -9,6 +9,7 @@ use half::{bf16, f16};
pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>),
+ I64(Vec<i64>),
BF16(Vec<bf16>),
F16(Vec<f16>),
F32(Vec<f32>),
@@ -25,6 +26,7 @@ pub trait Map1 {
match vs {
CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
+ CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
@@ -45,6 +47,7 @@ pub trait Map1Any {
match vs {
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
+ CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
@@ -68,6 +71,7 @@ pub trait Map2 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+ (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
@@ -96,6 +100,7 @@ pub trait Map2U8 {
match (v1, v2) {
(C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
@@ -1527,6 +1532,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
+ Self::I64(_) => DType::I64,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
@@ -1545,6 +1551,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
+ (Self::I64(storage), DType::BF16) => {
+ let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
+ Ok(Self::BF16(data))
+ }
(Self::BF16(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::BF16(data))
@@ -1569,6 +1579,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
+ (Self::I64(storage), DType::F16) => {
+ let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
+ Ok(Self::F16(data))
+ }
(Self::BF16(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
@@ -1593,6 +1607,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
+ (Self::I64(storage), DType::F32) => {
+ let data = unary_map(storage, layout, |v| v as f32);
+ Ok(Self::F32(data))
+ }
(Self::BF16(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
@@ -1629,18 +1647,26 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
- (Self::U8(storage), DType::U32) => {
- let data = unary_map(storage, layout, |v| v as u32);
- Ok(Self::U32(data))
- }
(Self::U32(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
+ (Self::I64(storage), DType::U8) => {
+ let data = unary_map(storage, layout, |v| v as u8);
+ Ok(Self::U8(data))
+ }
+ (Self::U8(storage), DType::U32) => {
+ let data = unary_map(storage, layout, |v| v as u32);
+ Ok(Self::U32(data))
+ }
(Self::U32(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U32(data))
}
+ (Self::I64(storage), DType::U32) => {
+ let data = unary_map(storage, layout, |v| v as u32);
+ Ok(Self::U32(data))
+ }
(Self::BF16(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
@@ -1657,6 +1683,34 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
+ (Self::U8(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::U32(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::I64(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v);
+ Ok(Self::I64(data))
+ }
+ (Self::BF16(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::F16(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v.to_f32() as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::F32(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v as i64);
+ Ok(Self::I64(data))
+ }
+ (Self::F64(storage), DType::I64) => {
+ let data = unary_map(storage, layout, |v| v as i64);
+ Ok(Self::I64(data))
+ }
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
@@ -1665,6 +1719,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
+ (Self::I64(storage), DType::F64) => {
+ let data = unary_map(storage, layout, |v| v as f64);
+ Ok(Self::F64(data))
+ }
(Self::BF16(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
@@ -1791,6 +1849,7 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
+ Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
@@ -1840,6 +1899,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
}
+ Self::I64(storage) => {
+ let data = unary_map(storage, layout, B::i64);
+ Ok(Self::I64(data))
+ }
}
}
@@ -1890,6 +1953,14 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U32(data))
}
+ (Self::I64(lhs), Self::I64(rhs)) => {
+ let data = if B::I64_VEC {
+ binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
+ } else {
+ binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
+ };
+ Ok(Self::I64(data))
+ }
(Self::U8(lhs), Self::U8(rhs)) => {
let data = if B::U8_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
@@ -1914,6 +1985,7 @@ impl BackendStorage for CpuStorage {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+ (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -1942,6 +2014,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+ Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
}
@@ -1970,6 +2043,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+ Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
}
}
@@ -1978,6 +2052,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
+ Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
}
}
@@ -1994,6 +2069,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+ Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
}
}
@@ -2022,6 +2098,13 @@ impl BackendStorage for CpuStorage {
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
+ Self::I64(ids) => {
+ let ids = match ids_l.contiguous_offsets() {
+ Some((a, b)) => &ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "index-add" })?,
+ };
+ IndexAdd { ids, dim }.map(self, l, src, src_l)
+ }
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
}
}
@@ -2074,7 +2157,9 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
- DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
+ DType::U8 | DType::U32 | DType::I64 => {
+ Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
+ }
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
@@ -2118,7 +2203,9 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
- DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
+ DType::U8 | DType::U32 | DType::I64 => {
+ Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+ }
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
@@ -2162,6 +2249,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
+ DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
@@ -2175,6 +2263,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
+ DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 44e65a85..9809bb4a 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -139,6 +139,14 @@ impl CudaDevice {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(data)
}
+ DType::I64 => {
+ // SAFETY: Set later by running the fill kernel.
+ let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
+ let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
+ let params = (&data, v as i64, elem_count);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::I64(data)
+ }
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@@ -236,6 +244,10 @@ impl BackendDevice for CudaDevice {
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
+ DType::I64 => {
+ let data = self.alloc_zeros::<i64>(elem_count).w()?;
+ CudaStorageSlice::I64(data)
+ }
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
@@ -265,11 +277,13 @@ impl BackendDevice for CudaDevice {
let slice = match dtype {
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
// cudarc changes.
- DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
- dtype,
- op: "rand_uniform",
- })
- .w()?,
+ DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+ Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_uniform",
+ })
+ .w()?
+ }
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
@@ -297,11 +311,13 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
let slice = match dtype {
- DType::U8 | DType::U32 | DType::F16 | DType::BF16 => Err(CudaError::UnsupportedDtype {
- dtype,
- op: "rand_normal",
- })
- .w()?,
+ DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+ Err(CudaError::UnsupportedDtype {
+ dtype,
+ op: "rand_normal",
+ })
+ .w()?
+ }
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand
@@ -336,6 +352,10 @@ impl BackendDevice for CudaDevice {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::U32(data)
}
+ CpuStorage::I64(storage) => {
+ let data = self.htod_sync_copy(storage).w()?;
+ CudaStorageSlice::I64(data)
+ }
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
CudaStorageSlice::BF16(data)
@@ -364,6 +384,7 @@ impl BackendDevice for CudaDevice {
enum CudaStorageSlice {
U8(CudaSlice<u8>),
U32(CudaSlice<u32>),
+ I64(CudaSlice<i64>),
BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>),
F32(CudaSlice<f32>),
@@ -383,6 +404,7 @@ trait Map1 {
let out = match s {
S::U8(s) => S::U8(self.f(s, d, l)?),
S::U32(s) => S::U32(self.f(s, d, l)?),
+ S::I64(s) => S::I64(self.f(s, d, l)?),
S::BF16(s) => S::BF16(self.f(s, d, l)?),
S::F16(s) => S::F16(self.f(s, d, l)?),
S::F32(s) => S::F32(self.f(s, d, l)?),
@@ -406,6 +428,7 @@ trait Map2 {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
+ (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
(S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
@@ -437,6 +460,7 @@ trait Map2InPlace {
match (dst, src) {
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
+ (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
@@ -459,6 +483,7 @@ trait Map1Any {
let out = match s {
S::U8(s) => self.f(s, d, l, S::U8)?,
S::U32(s) => self.f(s, d, l, S::U32)?,
+ S::I64(s) => self.f(s, d, l, S::I64)?,
S::BF16(s) => self.f(s, d, l, S::BF16)?,
S::F16(s) => self.f(s, d, l, S::F16)?,
S::F32(s) => self.f(s, d, l, S::F32)?,
@@ -482,6 +507,7 @@ trait Map2Any {
let out = match (s1, s2) {
(S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
+ (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
(S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
@@ -714,6 +740,9 @@ impl<'a> Map1 for IndexSelect<'a> {
CudaStorageSlice::U8(slice) => {
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
}
+ CudaStorageSlice::I64(slice) => {
+ ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
+ }
_ => Err(CudaError::UnexpectedDType {
msg: "index_select ids should be u8 or u32",
expected: DType::U32,
@@ -773,8 +802,11 @@ impl<'a> Map1 for Gather<'a> {
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
}
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+ CudaStorageSlice::I64(slice) => {
+ ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
+ }
_ => Err(CudaError::UnexpectedDType {
- msg: "gather ids should be u8 or u32",
+ msg: "gather ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -820,9 +852,10 @@ impl<'a> Map2InPlace for IndexAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+ CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
- msg: "index-add ids should be u8 or u32",
+ msg: "index-add ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -867,9 +900,10 @@ impl<'a> Map2InPlace for ScatterAdd<'a> {
};
let (name, ids) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+ CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
_ => Err(CudaError::UnexpectedDType {
- msg: "scatter-add ids should be u8 or u32",
+ msg: "scatter-add ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
@@ -1080,8 +1114,12 @@ impl<'a> Map2 for WhereCond<'a> {
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
(ptr, "where_u32")
}
+ CudaStorageSlice::I64(slice) => {
+ let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+ (ptr, "where_i64")
+ }
_ => Err(CudaError::UnexpectedDType {
- msg: "where conditions should be u8 or u32",
+ msg: "where conditions should be u8/u32/i64",
expected: DType::U32,
got: self.0.dtype(),
})
@@ -1225,6 +1263,7 @@ macro_rules! cuda_dtype {
}
cuda_dtype!(u8, U8);
cuda_dtype!(u32, U32);
+cuda_dtype!(i64, I64);
cuda_dtype!(f16, F16);
cuda_dtype!(bf16, BF16);
cuda_dtype!(f32, F32);
@@ -1338,6 +1377,7 @@ impl BackendStorage for CudaStorage {
match self.slice {
CudaStorageSlice::U8(_) => DType::U8,
CudaStorageSlice::U32(_) => DType::U32,
+ CudaStorageSlice::I64(_) => DType::I64,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
@@ -1363,6 +1403,7 @@ impl BackendStorage for CudaStorage {
let inp = match &self.slice {
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
+ CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
@@ -1385,6 +1426,12 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U32(out)
}
+ DType::I64 => {
+ let out = unsafe { dev.alloc::<i64>(el) }.w()?;
+ let params = (el, dims.len(), &ds, *inp, &out);
+ unsafe { func.launch(cfg, params) }.w()?;
+ CudaStorageSlice::I64(out)
+ }
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let params = (el, dims.len(), &ds, *inp, &out);
@@ -1469,6 +1516,11 @@ impl BackendStorage for CudaStorage {
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
Ok(CpuStorage::U32(cpu_storage))
}
+ CudaStorageSlice::I64(slice) => {
+ let dev = slice.device();
+ let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
+ Ok(CpuStorage::I64(cpu_storage))
+ }
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
@@ -1588,6 +1640,7 @@ impl BackendStorage for CudaStorage {
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
+ (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
};
Ok(Self { slice, device })
@@ -1802,6 +1855,18 @@ impl BackendStorage for CudaStorage {
unsafe { func.launch(cfg, params) }.w()?
}
}
+ (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
+ let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+ if src_l.is_contiguous() {
+ dev.dtod_copy(&src, &mut dst).w()?
+ } else {
+ let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
+ // SAFETY: Set later by running the kernel.
+ let params = (el_count, dims.len(), &ds, &src, &mut dst);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?
+ }
+ }
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index f5c77e4b..da2adf37 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -49,6 +49,7 @@ impl std::fmt::Debug for Tensor {
match self.dtype() {
DType::U8 => self.fmt_dt::<u8>(f),
DType::U32 => self.fmt_dt::<u32>(f),
+ DType::I64 => self.fmt_dt::<i64>(f),
DType::BF16 => self.fmt_dt::<bf16>(f),
DType::F16 => self.fmt_dt::<f16>(f),
DType::F32 => self.fmt_dt::<f32>(f),
@@ -431,6 +432,12 @@ impl std::fmt::Display for Tensor {
tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
writeln!(f)?;
}
+ DType::I64 => {
+ let tf: IntFormatter<i64> = IntFormatter::new();
+ let max_w = tf.max_width(&to_display);
+ tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?;
+ writeln!(f)?;
+ }
DType::BF16 => {
if let Ok(tf) = FloatFormatter::<bf16>::new(&to_display, &po) {
let max_w = tf.max_width(&to_display);
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 7f04a653..91922b11 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -5,6 +5,7 @@ use crate::{CpuStorage, Error, Result};
pub enum DType {
U8,
U32,
+ I64,
BF16,
F16,
F32,
@@ -20,6 +21,7 @@ impl std::str::FromStr for DType {
match s {
"u8" => Ok(Self::U8),
"u32" => Ok(Self::U32),
+ "i64" => Ok(Self::I64),
"bf16" => Ok(Self::BF16),
"f16" => Ok(Self::F16),
"f32" => Ok(Self::F32),
@@ -34,6 +36,7 @@ impl DType {
match self {
Self::U8 => "u8",
Self::U32 => "u32",
+ Self::I64 => "i64",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
@@ -45,6 +48,7 @@ impl DType {
match self {
Self::U8 => 1,
Self::U32 => 4,
+ Self::I64 => 8,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
@@ -125,6 +129,7 @@ use half::{bf16, f16};
with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
+with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
@@ -135,6 +140,15 @@ pub trait IntDType: WithDType {
fn as_usize(&self) -> usize;
}
+impl IntDType for i64 {
+ fn is_true(&self) -> bool {
+ *self != 0
+ }
+ fn as_usize(&self) -> usize {
+ *self as usize
+ }
+}
+
impl IntDType for u32 {
fn is_true(&self) -> bool {
*self != 0
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index 62c62f9e..f3a75965 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -85,6 +85,7 @@ impl Header {
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
+ DType::I64 => "i8",
DType::U32 => "u4",
DType::U8 => "u1",
};
@@ -160,7 +161,7 @@ impl Header {
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
- // "q" | "i8" => DType::S64,
+ "q" | "i8" => DType::I64,
// "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8,
"B" | "u1" => DType::U8,
@@ -233,6 +234,11 @@ impl Tensor {
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
+ DType::I64 => {
+ let mut data_t = vec![0i64; elem_count];
+ reader.read_i64_into::<LittleEndian>(&mut data_t)?;
+ Tensor::from_vec(data_t, shape, &Device::Cpu)
+ }
}
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b8d4e34f..b0528494 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -251,6 +251,7 @@ pub trait UnaryOpT {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
+ fn i64(v1: i64) -> i64;
// There is no very good way to represent optional function in traits so we go for an explicit
// boolean flag to mark the function as existing.
@@ -274,6 +275,7 @@ pub trait BinaryOpT {
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
+ fn i64(v1: i64, v2: i64) -> i64;
const BF16_VEC: bool = false;
fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
@@ -287,6 +289,8 @@ pub trait BinaryOpT {
fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
const U32_VEC: bool = false;
fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
+ const I64_VEC: bool = false;
+ fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
}
pub(crate) struct Add;
@@ -337,6 +341,10 @@ macro_rules! bin_op {
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
+ #[inline(always)]
+ fn i64(v1: i64, v2: i64) -> i64 {
+ $e(v1, v2)
+ }
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@@ -420,6 +428,10 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ todo!("no unary function for i64")
+ }
}
};
@@ -452,6 +464,10 @@ macro_rules! unary_op {
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ todo!("no unary function for i64")
+ }
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
@@ -543,6 +559,10 @@ impl UnaryOpT for Gelu {
fn u32(_: u32) -> u32 {
0
}
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
const KERNEL: &'static str = "ugelu";
#[cfg(feature = "mkl")]
@@ -592,6 +612,10 @@ impl UnaryOpT for Relu {
fn u32(v: u32) -> u32 {
v
}
+ #[inline(always)]
+ fn i64(v: i64) -> i64 {
+ v
+ }
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index ec0dc766..f37bb8ef 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -10,6 +10,7 @@ impl From<DType> for st::Dtype {
match value {
DType::U8 => st::Dtype::U8,
DType::U32 => st::Dtype::U32,
+ DType::I64 => st::Dtype::I64,
DType::BF16 => st::Dtype::BF16,
DType::F16 => st::Dtype::F16,
DType::F32 => st::Dtype::F32,
@@ -24,6 +25,7 @@ impl TryFrom<st::Dtype> for DType {
match value {
st::Dtype::U8 => Ok(DType::U8),
st::Dtype::U32 => Ok(DType::U32),
+ st::Dtype::I64 => Ok(DType::I64),
st::Dtype::BF16 => Ok(DType::BF16),
st::Dtype::F16 => Ok(DType::F16),
st::Dtype::F32 => Ok(DType::F32),
@@ -189,6 +191,7 @@ impl Tensor {
match dtype {
DType::U8 => convert_slice::<u8>(data, shape, device),
DType::U32 => convert_slice::<u32>(data, shape, device),
+ DType::I64 => convert_slice::<i64>(data, shape, device),
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
DType::F16 => convert_slice::<half::f16>(data, shape, device),
DType::F32 => convert_slice::<f32>(data, shape, device),
@@ -205,24 +208,15 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_with_cast_::<u16, u32, _>(view, device, conv)
}
st::Dtype::U32 => convert_::<u32>(view, device),
+ st::Dtype::I32 => {
+ let conv = |x| Ok(i64::from(x));
+ convert_with_cast_::<i32, i64, _>(view, device, conv)
+ }
+ st::Dtype::I64 => convert_::<i64>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
- st::Dtype::I32 => {
- let conv = |x| {
- u32::try_from(x)
- .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
- };
- convert_with_cast_::<i32, u32, _>(view, device, conv)
- }
- st::Dtype::I64 => {
- let conv = |x| {
- u32::try_from(x)
- .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
- };
- convert_with_cast_::<i64, u32, _>(view, device, conv)
- }
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
@@ -233,6 +227,7 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
match tensor.dtype() {
DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
+ DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
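
A closing sketch (not part of the commit): with the safetensors changes above, i64 tensors now round-trip natively instead of being loaded through an i64 -> u32 cast that failed on negative or out-of-range values. This assumes the save/load helpers exposed by candle_core::safetensors:

    use candle_core::{safetensors, DType, Device, Tensor};
    use std::collections::HashMap;

    fn main() -> candle_core::Result<()> {
        let t = Tensor::new(&[-3i64, 0, i64::MAX], &Device::Cpu)?;
        let mut tensors = HashMap::new();
        tensors.insert("ids".to_string(), t);
        safetensors::save(&tensors, "ids.safetensors")?;
        // Before this commit, -3 and i64::MAX here would have errored out
        // ("out of bounds value for u32") when loading the file back.
        let loaded = safetensors::load("ids.safetensors", &Device::Cpu)?;
        assert_eq!(loaded["ids"].dtype(), DType::I64);
        assert_eq!(loaded["ids"].to_vec1::<i64>()?, vec![-3, 0, i64::MAX]);
        Ok(())
    }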