summaryrefslogtreecommitdiff
path: root/candle-transformers/tests/generation_tests.rs
diff options
context:
space:
mode:
authorJuarez Bochi <jbochi@gmail.com>2023-09-12 09:10:16 -0700
committerGitHub <noreply@github.com>2023-09-12 18:10:16 +0200
commit805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch)
tree0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-transformers/tests/generation_tests.rs
parent42da17694a4214a3e39e0d64afc22635ce83f557 (diff)
downloadcandle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
Diffstat (limited to 'candle-transformers/tests/generation_tests.rs')
-rw-r--r--candle-transformers/tests/generation_tests.rs29
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs
new file mode 100644
index 00000000..76f994d0
--- /dev/null
+++ b/candle-transformers/tests/generation_tests.rs
@@ -0,0 +1,29 @@
+use candle::{Device, Result, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+
+#[test]
+fn sample_with_zero_temperature() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(1337, None, None);
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 3);
+ Ok(())
+}
+
+#[test]
+fn sample_with_temperature() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(42, Some(0.9), None);
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 0);
+ Ok(())
+}
+
+#[test]
+fn sample_with_top_p() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5));
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 2);
+ Ok(())
+}