diff options
author | Juarez Bochi <jbochi@gmail.com> | 2023-09-12 09:10:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-12 18:10:16 +0200 |
commit | 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch) | |
tree | 0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-transformers/tests | |
parent | 42da17694a4214a3e39e0d64afc22635ce83f557 (diff) | |
download | candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2 candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip |
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
Diffstat (limited to 'candle-transformers/tests')
-rw-r--r-- | candle-transformers/tests/generation_tests.rs | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs new file mode 100644 index 00000000..76f994d0 --- /dev/null +++ b/candle-transformers/tests/generation_tests.rs @@ -0,0 +1,29 @@ +use candle::{Device, Result, Tensor}; +use candle_transformers::generation::LogitsProcessor; + +#[test] +fn sample_with_zero_temperature() -> Result<()> { + let mut logits_process = LogitsProcessor::new(1337, None, None); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 3); + Ok(()) +} + +#[test] +fn sample_with_temperature() -> Result<()> { + let mut logits_process = LogitsProcessor::new(42, Some(0.9), None); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 0); + Ok(()) +} + +#[test] +fn sample_with_top_p() -> Result<()> { + let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5)); + let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?; + let token = logits_process.sample(&logits)?; + assert_eq!(token, 2); + Ok(()) +} |