summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-08-01 10:05:07 +0100
committerGitHub <noreply@github.com>2024-08-01 11:05:07 +0200
commit957d604a7888bbf0243dbbca83a438db5132b48f (patch)
tree8aa92225904d90f8f8c3adf460d1dc5d1caaf4bd /candle-examples/examples
parentce90287f45817fc22b53f576ad9bd27b1b0ebeb8 (diff)
downloadcandle-957d604a7888bbf0243dbbca83a438db5132b48f.tar.gz
candle-957d604a7888bbf0243dbbca83a438db5132b48f.tar.bz2
candle-957d604a7888bbf0243dbbca83a438db5132b48f.zip
Enable BF16 on metal. (#2380)
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/phi/main.rs6
1 files changed, 2 insertions, 4 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 1a0d9aca..ceddc35e 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -361,10 +361,8 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
- if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
- && device.is_cuda()
- {
- DType::BF16
+ if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
+ device.bf16_default_to_f32()
} else {
DType::F32
}