-rw-r--r--  candle-examples/examples/replit-code/README.md | 69
-rw-r--r--  candle-examples/examples/replit-code/main.rs   |  2
-rw-r--r--  candle-transformers/src/models/mpt.rs          | 30
3 files changed, 88 insertions, 13 deletions
diff --git a/candle-examples/examples/replit-code/README.md b/candle-examples/examples/replit-code/README.md
new file mode 100644
index 00000000..84ed4c1c
--- /dev/null
+++ b/candle-examples/examples/replit-code/README.md
@@ -0,0 +1,69 @@
+# candle-replit-code: a model specialized for code completion.
+
+[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
+language model specialized for code completion. It has 3.3B parameters, stored in
+`bfloat16`, so the GPU version only runs on recent NVIDIA cards with bf16 support.
+
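+A minimal sketch (not the actual code in `main.rs`) of how this device and
+dtype choice might look with candle's API; `pick_device_and_dtype` is a
+hypothetical helper:
+
+```rust
+use candle::{DType, Device};
+
+fn pick_device_and_dtype() -> candle::Result<(Device, DType)> {
+    // Prefer CUDA when available, otherwise fall back to the CPU.
+    let device = Device::cuda_if_available(0)?;
+    // The weights are stored in bf16, which only recent NVIDIA GPUs
+    // support natively; use f32 on the CPU instead.
+    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
+    Ok((device, dtype))
+}
+```
+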
+## Running an example
+
+```bash
+cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
+```
+This produces the following output; note that the generated code does not
+actually compute the Fibonacci series correctly.
+
+```
+def fibonacci(n): # write Fibonacci series up to n
+ """Print a Fibonacci series up to n."""
+
+ assert type(n) == int, "n must be an integer"
+
+ if (type(fib_list)==None or len==0 ):
+ fib_list = [1]
+
+ for i in range((len-2)): # start at 2nd element of list and go until end.
+ n += 1
+
+ print("Fibonacci number",n,"is:",i)
+
+def main():
+ """Call the functions."""
+
+ userInput=input('Enter a positive integer: ')
+
+ fibonacci(userInput)
+
+
+
+
+
+
+
+if __name__ == '__main__': # only run if this file is called directly.
+ print("This program prints out Fibonacci numbers.")
+ main()
+```
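+
+The `-n` flag (the `sample_len` argument in `main.rs`) caps the number of
+generated tokens, e.g. for a shorter completion:
+
+```bash
+cargo run --example replit-code --release -- --prompt 'def fibonacci(n): ' -n 100
+```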
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 97429b7b..87b7d216 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -139,7 +139,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, short = 'n', default_value_t = 100)]
+ #[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index f382a4bb..c1efe16f 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -103,23 +103,25 @@ impl GroupedQueryAttention {
(k, v)
}
};
- let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?;
- let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?;
+ self.kv_cache = Some((key.clone(), value.clone()));
+ let query = query.contiguous()?;
+ let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
let attn_bias = {
let s_q = query.dim(D::Minus2)?;
let s_k = key.dim(D::Minus1)?;
let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
- self.attn_bias
- .narrow(2, a_q - s_q, s_q)?
- .narrow(3, a_k - s_k, s_k)?
+ let start_q = a_q.saturating_sub(s_q);
+ let start_k = a_k.saturating_sub(s_k);
+ self.attn_bias.i((.., .., start_q.., start_k..))?
};
- let attn_weights = (attn_weights + attn_bias)?;
+ let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
&attn_weights,
- &mask.broadcast_left(b_size * self.n_heads)?,
+ &mask.broadcast_as(attn_weights.shape())?,
f32::NEG_INFINITY,
)?,
};
@@ -128,7 +130,8 @@ impl GroupedQueryAttention {
.matmul(&value)?
.transpose(1, 2)?
.flatten_from(D::Minus2)?;
- attn_output.apply(&self.out_proj)
+ let out = attn_output.apply(&self.out_proj)?;
+ Ok(out)
}
}
@@ -199,7 +202,7 @@ impl MPTBlock {
let xs = self.attn.forward(&xs, mask)?;
let xs = (xs + residual)?;
let residual = &xs;
- let xs = xs.apply(&self.norm2)?.apply(&self.ffn);
+ let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
xs + residual
}
}
@@ -275,12 +278,15 @@ impl Model {
Some(get_mask(seq_len, xs.device())?)
};
for block in self.blocks.iter_mut() {
- xs = block.forward(&xs, mask.as_ref())?
+ xs = block.forward(&xs, mask.as_ref())?;
}
- xs.narrow(1, seq_len - 1, 1)?
+ let xs = xs.apply(&self.norm_f)?;
+ let logits = xs
+ .narrow(1, seq_len - 1, 1)?
.squeeze(1)?
.matmul(&self.wte.embeddings().t()?)?
- .squeeze(1)
+ .squeeze(1)?;
+ Ok(logits)
}
}