diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-01 10:05:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 11:05:07 +0200 |
commit | 957d604a7888bbf0243dbbca83a438db5132b48f (patch) | |
tree | 8aa92225904d90f8f8c3adf460d1dc5d1caaf4bd /candle-examples/examples | |
parent | ce90287f45817fc22b53f576ad9bd27b1b0ebeb8 (diff) | |
download | candle-957d604a7888bbf0243dbbca83a438db5132b48f.tar.gz candle-957d604a7888bbf0243dbbca83a438db5132b48f.tar.bz2 candle-957d604a7888bbf0243dbbca83a438db5132b48f.zip |
Enable BF16 on metal. (#2380)
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 1a0d9aca..ceddc35e 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -361,10 +361,8 @@ fn main() -> Result<()> { let dtype = match args.dtype { Some(dtype) => std::str::FromStr::from_str(&dtype)?, None => { - if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium) - && device.is_cuda() - { - DType::BF16 + if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium { + device.bf16_default_to_f32() } else { DType::F32 } |