diff options
Diffstat (limited to 'candle-wasm-examples/phi/src/bin/m.rs')
-rw-r--r-- | candle-wasm-examples/phi/src/bin/m.rs | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs index c18e6c38..999f276d 100644 --- a/candle-wasm-examples/phi/src/bin/m.rs +++ b/candle-wasm-examples/phi/src/bin/m.rs @@ -5,6 +5,7 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer; use candle_wasm_example_phi::console_log; use js_sys::Date; +use serde::Deserialize; use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; @@ -23,6 +24,12 @@ pub struct Model { repeat_last_n: usize, } +#[derive(Debug, Clone, PartialEq, Deserialize)] + +pub struct ModelName { + pub _name_or_path: String, +} + #[wasm_bindgen] impl Model { #[wasm_bindgen(constructor)] @@ -34,15 +41,25 @@ impl Model { ) -> Result<Model, JsError> { console_error_panic_hook::set_once(); console_log!("loading model"); + let name: ModelName = serde_json::from_slice(&config)?; let config: Config = serde_json::from_slice(&config)?; + + console_log!("config loaded {:?}", name); let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; let start = Date::now(); + console_log!("weights len: {:?}", weights.len()); let model = if quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?; - let model = QMixFormer::new(&config, vb)?; - SelectedModel::Quantized(model) + console_log!("weights loaded"); + if name._name_or_path == "microsoft/phi-2" { + let model = QMixFormer::new_v2(&config, vb)?; + SelectedModel::Quantized(model) + } else { + let model = QMixFormer::new(&config, vb)?; + SelectedModel::Quantized(model) + } } else { let device = &Device::Cpu; let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?; |