summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-06-18 23:46:58 +0200
committerGitHub <noreply@github.com>2024-06-18 23:46:58 +0200
commit36cf54525d93660f62c3601ba0988653f3567e0e (patch)
tree25e9fe0fc75782186f2dc810e8381e553a4ac138 /candle-transformers
parent2b10aaa05d3752186899bd5b5364d92164edc7ef (diff)
downloadcandle-36cf54525d93660f62c3601ba0988653f3567e0e.tar.gz
candle-36cf54525d93660f62c3601ba0988653f3567e0e.tar.bz2
candle-36cf54525d93660f62c3601ba0988653f3567e0e.zip
Fix the fast bf16 gemm cublas kernels. (#2274)
* Use flash-attn in gemma. * Fix for the fast bf16 cublas gemm. * Fix some clippy lints. * Fix another lint. * Proper clippy fix.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/vgg.rs3
1 files changed, 1 insertions, 2 deletions
diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs
index a20b5e37..010643c8 100644
--- a/candle-transformers/src/models/vgg.rs
+++ b/candle-transformers/src/models/vgg.rs
@@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> {
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
let layers = convs
.iter()
- .enumerate()
- .map(|(_, &(in_c, out_c, name))| {
+ .map(|&(in_c, out_c, name)| {
candle_nn::conv2d(
in_c,
out_c,