summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author Nicolas Patry <patry.nicolas@protonmail.com> 2023-12-15 13:06:04 +0100
committer Nicolas Patry <patry.nicolas@protonmail.com> 2023-12-15 13:06:04 +0100
commit 6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0 (patch)
tree 41848e54f8d9542cbcb09cde31290906eaf5e8ca /candle-metal-kernels
parent aa040150985e78079bcc05df86266e447c23b4fc (diff)
download candle-6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0.tar.gz
candle-6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0.tar.bz2
candle-6bc92e63cb4d1a3bb4910348861b9f7e80dfa4f0.zip
Addressing a lot of comments.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/lib.rs 6
-rw-r--r-- candle-metal-kernels/src/tests.rs 21
2 files changed, 16 insertions, 11 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a23aa47c..f2db171e 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -597,6 +597,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
+ input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -604,7 +605,10 @@ pub fn call_last_softmax(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
let out_length = length / elements_to_sum;
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 75c2f013..9c9475a2 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -312,7 +312,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
- "affine_float",
+ "affine_f32",
size,
&input,
&output,
@@ -346,7 +346,7 @@ fn run_affine_strided<T: Clone>(
&device,
command_buffer,
&kernels,
- "affine_float_strided",
+ "affine_f32_strided",
shape,
&input,
strides,
@@ -608,6 +608,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(),
last_dim,
&input,
+ 0,
&output,
)
.unwrap();
@@ -622,7 +623,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
- let results = run_reduce(&v, out_length, "fast_sum_float");
+ let results = run_reduce(&v, out_length, "fast_sum_f32");
assert_eq!(approx(results, 4), vec![21.0]);
}
@@ -631,7 +632,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
- let results = run_reduce(&v, out_length, "fast_sum_float");
+ let results = run_reduce(&v, out_length, "fast_sum_f32");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
@@ -639,7 +640,7 @@ fn reduce_sum2() {
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -651,7 +652,7 @@ fn softmax() {
for i in 0..n {
v[i * last_dim] = 20.0;
}
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4);
println!("{results:?}");
assert_eq!(
@@ -665,7 +666,7 @@ fn softmax() {
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -673,7 +674,7 @@ fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
@@ -684,7 +685,7 @@ fn softmax() {
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_half");
+ let results = run_softmax(&v, last_dim, "softmax_f16");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
@@ -695,7 +696,7 @@ fn softmax() {
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_bfloat");
+ let results = run_softmax(&v, last_dim, "softmax_bf16");
assert_eq!(
approx_bf16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]