summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-12-15 13:06:04 +0100
committerNicolas Patry <patry.nicolas@protonmail.com>2023-12-15 13:06:04 +0100
commit6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0 (patch)
tree41848e54f8d9542cbcb09cde31290906eaf5e8ca /candle-nn
parentaa040150985e78079bcc05df86266e447c23b4fc (diff)
downloadcandle-6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0.tar.gz
candle-6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0.tar.bz2
candle-6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0.zip
Addressing a lot of comments.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/ops.rs3
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 94380f12..816eff42 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -220,7 +220,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
let n = layout.stride().len();
- if !(layout.is_contiguous() && layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+ if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
candle::bail!("Non contiguous softmax-last-dim is not implemented");
}
@@ -235,6 +235,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
elem_count,
last_dim,
storage.buffer(),
+ layout.start_offset() * storage.dtype().size_in_bytes(),
&mut output,
)
.unwrap();