summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-11-10 01:24:49 +0100
committerNicolas Patry <nicolas@Nicolass-MacBook-Pro.local>2023-11-20 14:12:57 +0100
commitdf6814f34ef8cbbf0b5e9e98fc8a71690cf8e8a4 (patch)
tree4c63cc6ce80fb80a85401f77d74bf7a5e0bccf5f /candle-core/src/tensor.rs
parent39406a67214b01f85d5f3e2095ee36eb13d3cbf3 (diff)
downloadcandle-df6814f34ef8cbbf0b5e9e98fc8a71690cf8e8a4.tar.gz
candle-df6814f34ef8cbbf0b5e9e98fc8a71690cf8e8a4.tar.bz2
candle-df6814f34ef8cbbf0b5e9e98fc8a71690cf8e8a4.zip
Refactor to simplify our lives for settings the params in the encoder.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 2a0924b6..f7f66668 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1859,7 +1859,11 @@ impl Tensor {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
+ (Storage::Cpu(storage), Device::Metal(metal)) => {
+ Storage::Metal(metal.storage_from_cpu_storage(storage)?)
+ }
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
+ (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.