diff options
Diffstat (limited to 'candle-wasm-examples/t5/src/bin/m-quantized.rs')
-rw-r--r-- | candle-wasm-examples/t5/src/bin/m-quantized.rs | 9 |
1 file changed, 5 insertions, 4 deletions
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
index 2f490b84..3b99a275 100644
--- a/candle-wasm-examples/t5/src/bin/m-quantized.rs
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
 use candle_wasm_example_t5::console_log;
 use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
+const DEVICE: Device = Device::Cpu;
 
 #[wasm_bindgen]
 pub struct ModelEncoder {
@@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
     ) -> Result<ModelConditionalGeneration, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
     pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
         let input: ConditionalGenerationParams =
             serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
-        let device = &Device::Cpu;
+        let device = &DEVICE;
         self.model.clear_kv_cache();
         let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
         let prompt = input.prompt;
@@ -128,7 +129,7 @@ impl ModelEncoder {
     ) -> Result<ModelEncoder, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         config.use_cache = false;
         let tokenizer =
@@ -138,7 +139,7 @@ impl ModelEncoder {
     }
 
     pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
-        let device = &Device::Cpu;
+        let device = &DEVICE;
         let input: DecoderParams =
             serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
 