summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs7
1 files changed, 4 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ecc018f9..5d4e106f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+use crate::backend::{BackendDevice, BackendStorage};
use crate::shape::Dim;
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::Arc;
@@ -963,19 +964,19 @@ impl Tensor {
/// If the target device is the same as the tensor device, only a shallow copy is performed.
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
- if self.device().same_id(device) {
+ if self.device().same_device(device) {
Ok(self.clone())
} else {
let storage = match (self.storage.as_ref(), device) {
(Storage::Cpu(storage), Device::Cuda(cuda)) => {
- Storage::Cuda(cuda.cuda_from_cpu_storage(storage)?)
+ Storage::Cuda(cuda.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
let cpu_storage = storage.to_cpu_storage()?;
- Storage::Cuda(cuda.cuda_from_cpu_storage(&cpu_storage)?)
+ Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
};