summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/phi/src/bin/m.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/phi/src/bin/m.rs')
-rw-r--r--candle-wasm-examples/phi/src/bin/m.rs21
1 files changed, 19 insertions, 2 deletions
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs
index c18e6c38..999f276d 100644
--- a/candle-wasm-examples/phi/src/bin/m.rs
+++ b/candle-wasm-examples/phi/src/bin/m.rs
@@ -5,6 +5,7 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle_wasm_example_phi::console_log;
use js_sys::Date;
+use serde::Deserialize;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@@ -23,6 +24,12 @@ pub struct Model {
repeat_last_n: usize,
}
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+
+pub struct ModelName {
+ pub _name_or_path: String,
+}
+
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
@@ -34,15 +41,25 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
+ let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
+
+ console_log!("config loaded {:?}", name);
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let start = Date::now();
+ console_log!("weights len: {:?}", weights.len());
let model = if quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
- let model = QMixFormer::new(&config, vb)?;
- SelectedModel::Quantized(model)
+ console_log!("weights loaded");
+ if name._name_or_path == "microsoft/phi-2" {
+ let model = QMixFormer::new_v2(&config, vb)?;
+ SelectedModel::Quantized(model)
+ } else {
+ let model = QMixFormer::new(&config, vb)?;
+ SelectedModel::Quantized(model)
+ }
} else {
let device = &Device::Cpu;
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;