summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs19
1 files changed, 4 insertions, 15 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 9f0c8602..f1547b3c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -101,14 +101,9 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
}
}
-fn take_impl1<T: Copy>(
- vs: &[T],
- ids: &[u32],
- layout: &Layout,
- vocab_size: usize,
- hidden_size: usize,
-) -> Result<Vec<T>> {
+fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
// TODO: Optimize for the case where ids are contiguous.
+ let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
for index in layout.strided_index() {
let index = ids[index].try_into()?;
@@ -610,15 +605,9 @@ impl CpuStorage {
}
}
- pub(crate) fn embedding(
- &self,
- layout: &Layout,
- vs: &Self,
- hidden_size: usize,
- vocab_size: usize,
- ) -> Result<Self> {
+ pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
let ids = self.as_slice::<u32>()?;
- map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
+ map1!(rhs, take_impl1, ids, layout, rhs_l)
}
pub(crate) fn matmul(