summary refs log tree commit diff
path: root/src/device.rs
diff options
context:
space:
mode:
author laurent <laurent.mazare@gmail.com> 2023-06-22 08:33:32 +0100
committer laurent <laurent.mazare@gmail.com> 2023-06-22 08:33:32 +0100
commit 0a758ffa0523629336e7224fa181dd1e76d8919c (patch)
tree 8338a53820ea5a124f97ea9b1fca91788f0f4e4f /src/device.rs
parent fc26bab3ede511c3c4d2f1afb15f58eb6c588c94 (diff)
download candle-0a758ffa0523629336e7224fa181dd1e76d8919c.tar.gz
candle-0a758ffa0523629336e7224fa181dd1e76d8919c.tar.bz2
candle-0a758ffa0523629336e7224fa181dd1e76d8919c.zip
Add the fill kernel and use it for 'ones'.
Diffstat (limited to 'src/device.rs')
-rw-r--r-- src/device.rs 5
1 files changed, 1 insertions, 4 deletions
diff --git a/src/device.rs b/src/device.rs
index e522cd42..ab7bad26 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -82,10 +82,7 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
- // TODO: Instead of allocating memory on the host and transfering it,
- // allocate some zeros on the device and use a shader to set them to 1.
- let storage = CpuStorage::ones_impl(shape, dtype);
- let storage = device.cuda_from_cpu_storage(&storage)?;
+ let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
}