diff options
author | laurent <laurent.mazare@gmail.com> | 2023-06-28 15:43:03 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-06-28 15:43:03 +0100 |
commit | 3f0d9fbb257baf94acde184de76eb9667e0fa025 (patch) | |
tree | 9bd3217971362a991faac24968f9bf77bf663476 /candle-core/src/cpu_backend.rs | |
parent | cca699be6c8167f565067ceb3c940dd3c1d87503 (diff) | |
download | candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.gz candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.tar.bz2 candle-3f0d9fbb257baf94acde184de76eb9667e0fa025.zip |
Adapt the cuda bits.
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 19 |
1 files changed, 4 insertions, 15 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 9f0c8602..f1547b3c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -101,14 +101,9 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
-fn take_impl1<T: Copy>(
-    vs: &[T],
-    ids: &[u32],
-    layout: &Layout,
-    vocab_size: usize,
-    hidden_size: usize,
-) -> Result<Vec<T>> {
+fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
     // TODO: Optimize for the case where ids are contiguous.
+    let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
     let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
     for index in layout.strided_index() {
         let index = ids[index].try_into()?;
@@ -610,15 +605,9 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn embedding(
-        &self,
-        layout: &Layout,
-        vs: &Self,
-        hidden_size: usize,
-        vocab_size: usize,
-    ) -> Result<Self> {
+    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
+        map1!(rhs, take_impl1, ids, layout, rhs_l)
     }
 
     pub(crate) fn matmul(