summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/t5/src/bin/m-quantized.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/t5/src/bin/m-quantized.rs')
-rw-r--r--candle-wasm-examples/t5/src/bin/m-quantized.rs9
1 files changed, 5 insertions, 4 deletions
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
index 2f490b84..3b99a275 100644
--- a/candle-wasm-examples/t5/src/bin/m-quantized.rs
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
+const DEVICE: Device = Device::Cpu;
#[wasm_bindgen]
pub struct ModelEncoder {
@@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
) -> Result<ModelConditionalGeneration, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
- let vb = VarBuilder::from_gguf_buffer(&weights)?;
+ let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
let mut config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: ConditionalGenerationParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
- let device = &Device::Cpu;
+ let device = &DEVICE;
self.model.clear_kv_cache();
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
let prompt = input.prompt;
@@ -128,7 +129,7 @@ impl ModelEncoder {
) -> Result<ModelEncoder, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
- let vb = VarBuilder::from_gguf_buffer(&weights)?;
+ let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
let mut config: Config = serde_json::from_slice(&config)?;
config.use_cache = false;
let tokenizer =
@@ -138,7 +139,7 @@ impl ModelEncoder {
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
- let device = &Device::Cpu;
+ let device = &DEVICE;
let input: DecoderParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;