summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/shape.rs34
-rw-r--r--candle-examples/examples/musicgen/musicgen_model.rs2
-rw-r--r--candle-transformers/src/models/segment_anything/sam.rs2
-rw-r--r--candle-wasm-examples/llama2-c/src/app.rs12
-rw-r--r--candle-wasm-examples/whisper/src/app.rs16
-rw-r--r--candle-wasm-examples/yolo/src/app.rs12
6 files changed, 38 insertions, 40 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index beaa9455..32ebb23f 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -478,23 +478,6 @@ extract_dims!(
(usize, usize, usize, usize, usize)
);
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn stride() {
- let shape = Shape::from(());
- assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
- let shape = Shape::from(42);
- assert_eq!(shape.stride_contiguous(), [1]);
- let shape = Shape::from((42, 1337));
- assert_eq!(shape.stride_contiguous(), [1337, 1]);
- let shape = Shape::from((299, 792, 458));
- assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
- }
-}
-
pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>;
}
@@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
Ok((d1, d2, d3, d4, d).into())
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn stride() {
+ let shape = Shape::from(());
+ assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
+ let shape = Shape::from(42);
+ assert_eq!(shape.stride_contiguous(), [1]);
+ let shape = Shape::from((42, 1337));
+ assert_eq!(shape.stride_contiguous(), [1337, 1]);
+ let shape = Shape::from((299, 792, 458));
+ assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
+ }
+}
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index dc8c3667..c6b52fde 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -321,7 +321,7 @@ impl MusicgenDecoder {
let positions = self.embed_positions.forward(&input)?.to_device(dev)?;
let mut xs = inputs_embeds.broadcast_add(&positions)?;
let attention_mask = self.prepare_decoder_attention_mask(b_sz, seq_len)?;
- for (_layer_idx, decoder_layer) in self.layers.iter_mut().enumerate() {
+ for decoder_layer in self.layers.iter_mut() {
xs = decoder_layer.forward(&xs, &attention_mask, None)?;
}
let xs = self.layer_norm.forward(&xs)?;
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index ccf9ca7a..a2156a75 100644
--- a/candle-transformers/src/models/segment_anything/sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -184,7 +184,7 @@ impl Sam {
let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
Some((points, labels))
};
- let points = points.as_ref().map(|(x, y)| (x, y));
+ let points = points.as_ref().map(|xy| (&xy.0, &xy.1));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
self.mask_decoder.forward(
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index f254a5ae..1e40b77e 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -34,8 +34,8 @@ pub enum Msg {
Run,
UpdateStatus(String),
SetModel(ModelData),
- WorkerInMsg(WorkerInput),
- WorkerOutMsg(Result<WorkerOutput, String>),
+ WorkerIn(WorkerInput),
+ WorkerOut(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
@@ -75,7 +75,7 @@ impl Component for App {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
- move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ move |e| link.send_message(Self::Message::WorkerOut(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
@@ -128,11 +128,11 @@ impl Component for App {
let prompt = self.prompt.borrow().clone();
console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
ctx.link()
- .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
+ .send_message(Msg::WorkerIn(WorkerInput::Run(temp, top_p, prompt)))
}
true
}
- Msg::WorkerOutMsg(output) => {
+ Msg::WorkerOut(output) => {
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::GenerationDone(Err(err))) => {
@@ -165,7 +165,7 @@ impl Component for App {
}
true
}
- Msg::WorkerInMsg(inp) => {
+ Msg::WorkerIn(inp) => {
self.worker.send(inp);
true
}
diff --git a/candle-wasm-examples/whisper/src/app.rs b/candle-wasm-examples/whisper/src/app.rs
index 1cb31193..e344096c 100644
--- a/candle-wasm-examples/whisper/src/app.rs
+++ b/candle-wasm-examples/whisper/src/app.rs
@@ -42,8 +42,8 @@ pub enum Msg {
Run(usize),
UpdateStatus(String),
SetDecoder(ModelData),
- WorkerInMsg(WorkerInput),
- WorkerOutMsg(Result<WorkerOutput, String>),
+ WorkerIn(WorkerInput),
+ WorkerOut(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
@@ -116,7 +116,7 @@ impl Component for App {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
- move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ move |e| link.send_message(Self::Message::WorkerOut(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
@@ -165,18 +165,16 @@ impl Component for App {
Err(err) => {
let output = Err(format!("decoding error: {err:?}"));
// Mimic a worker output to so as to release current_decode
- Msg::WorkerOutMsg(output)
- }
- Ok(wav_bytes) => {
- Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
+ Msg::WorkerOut(output)
}
+ Ok(wav_bytes) => Msg::WorkerIn(WorkerInput::DecodeTask { wav_bytes }),
}
})
}
//
true
}
- Msg::WorkerOutMsg(output) => {
+ Msg::WorkerOut(output) => {
let dt = self.current_decode.as_ref().and_then(|current_decode| {
current_decode.start_time.and_then(|start_time| {
performance_now().map(|stop_time| stop_time - start_time)
@@ -198,7 +196,7 @@ impl Component for App {
}
true
}
- Msg::WorkerInMsg(inp) => {
+ Msg::WorkerIn(inp) => {
self.worker.send(inp);
true
}
diff --git a/candle-wasm-examples/yolo/src/app.rs b/candle-wasm-examples/yolo/src/app.rs
index 3a88a5f1..a68284fa 100644
--- a/candle-wasm-examples/yolo/src/app.rs
+++ b/candle-wasm-examples/yolo/src/app.rs
@@ -33,8 +33,8 @@ pub enum Msg {
Run,
UpdateStatus(String),
SetModel(ModelData),
- WorkerInMsg(WorkerInput),
- WorkerOutMsg(Result<WorkerOutput, String>),
+ WorkerIn(WorkerInput),
+ WorkerOut(Result<WorkerOutput, String>),
}
pub struct CurrentDecode {
@@ -117,7 +117,7 @@ impl Component for App {
let status = "loading weights".to_string();
let cb = {
let link = ctx.link().clone();
- move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ move |e| link.send_message(Self::Message::WorkerOut(e))
};
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
@@ -166,7 +166,7 @@ impl Component for App {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
- Ok(image_data) => Msg::WorkerInMsg(WorkerInput::RunData(RunData {
+ Ok(image_data) => Msg::WorkerIn(WorkerInput::RunData(RunData {
image_data,
conf_threshold: 0.5,
iou_threshold: 0.5,
@@ -176,7 +176,7 @@ impl Component for App {
}
true
}
- Msg::WorkerOutMsg(output) => {
+ Msg::WorkerOut(output) => {
match output {
Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
Ok(WorkerOutput::ProcessingDone(Err(err))) => {
@@ -218,7 +218,7 @@ impl Component for App {
}
true
}
- Msg::WorkerInMsg(inp) => {
+ Msg::WorkerIn(inp) => {
self.worker.send(inp);
true
}