-rw-r--r-- | candle-core/src/cuda_backend.rs | 16
-rw-r--r-- | candle-examples/examples/bigcode/model.rs | 2
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 2
-rw-r--r-- | candle-examples/examples/llama/model.rs | 3
-rw-r--r-- | candle-examples/examples/llama2-c/model.rs | 3
-rw-r--r-- | candle-examples/examples/simple-training/main.rs | 139
-rw-r--r-- | candle-kernels/src/ternary.cu | 23
-rw-r--r-- | candle-nn/src/var_builder.rs | 1
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 3 |
9 files changed, 161 insertions, 31 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 4050b595..a88d62c7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -940,16 +940,22 @@ impl<'a> Map2 for WhereCond<'a> {
         dev: &CudaDevice,
     ) -> Result<CudaSlice<T>> {
         let ids_l = &self.1;
-        let ids = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+        let (ids, name) = match &self.0.slice {
+            CudaStorageSlice::U8(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_u8")
+            }
+            CudaStorageSlice::U32(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_u32")
+            }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "where conditions should be u32",
+                msg: "where conditions should be u8 or u32",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
             .w()?,
         };
-        let ids = &ids;
         let shape = ids_l.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
@@ -959,7 +965,7 @@ impl<'a> Map2 for WhereCond<'a> {
             .w()?;
         let t = &t.slice(layout_t.start_offset()..);
         let f = &f.slice(layout_f.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let params = (el, dims.len(), &ds, ids, t, f, &out);
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index ddee0d27..99f5bb5a 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -24,7 +24,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {

 fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
-        .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
+        .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
         .collect();
     let mask = Tensor::from_slice(&mask, (t, t), device)?;
     Ok(mask)
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index cab0b314..1c77cbaf 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -424,7 +424,7 @@ pub struct Falcon {

 fn make_causal_mask(t: usize) -> Result<Tensor> {
     let mask: Vec<_> = (0..t)
-        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
         .collect();
     let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
     Ok(mask)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index dba1d535..ae27afc1 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -91,9 +91,8 @@ impl Cache {
         if let Some(mask) = masks.get(&t) {
             Ok(mask.clone())
         } else {
-            // TODO: If we support bool or u8 tensors, this would be better.
             let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 6d9e4bcd..9e1c3eda 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -47,9 +47,8 @@ impl Cache {
         if let Some(mask) = masks.get(&t) {
             Ok(mask.clone())
         } else {
-            // TODO: If we support bool or u8 tensors, this would be better.
             let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index edec2e92..35b938e8 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -1,16 +1,130 @@
-// This should rearch 91.5% accuracy.
+// This should reach 91.5% accuracy.
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;

-use anyhow::Result;
-use candle::{DType, Var, D};
-use candle_nn::{loss, ops};
+use candle::{DType, Device, Result, Shape, Tensor, Var, D};
+use candle_nn::{loss, ops, Linear};
+use std::sync::{Arc, Mutex};

 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;

-pub fn main() -> Result<()> {
+struct TensorData {
+    tensors: std::collections::HashMap<String, Var>,
+    pub dtype: DType,
+    pub device: Device,
+}
+
+// A variant of candle_nn::VarBuilder for initializing variables before training.
+#[derive(Clone)]
+struct VarStore {
+    data: Arc<Mutex<TensorData>>,
+    path: Vec<String>,
+}
+
+impl VarStore {
+    fn new(dtype: DType, device: Device) -> Self {
+        let data = TensorData {
+            tensors: std::collections::HashMap::new(),
+            dtype,
+            device,
+        };
+        Self {
+            data: Arc::new(Mutex::new(data)),
+            path: vec![],
+        }
+    }
+
+    fn pp(&self, s: &str) -> Self {
+        let mut path = self.path.clone();
+        path.push(s.to_string());
+        Self {
+            data: self.data.clone(),
+            path,
+        }
+    }
+
+    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+        let shape = shape.into();
+        let path = if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        };
+        let mut tensor_data = self.data.lock().unwrap();
+        if let Some(tensor) = tensor_data.tensors.get(&path) {
+            let tensor_shape = tensor.shape();
+            if &shape != tensor_shape {
+                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+            }
+            return Ok(tensor.as_tensor().clone());
+        }
+        // TODO: Proper initialization using the `Init` enum.
+        let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+        let tensor = var.as_tensor().clone();
+        tensor_data.tensors.insert(path, var);
+        Ok(tensor)
+    }
+
+    fn all_vars(&self) -> Vec<Var> {
+        let tensor_data = self.data.lock().unwrap();
+        #[allow(clippy::map_clone)]
+        tensor_data
+            .tensors
+            .values()
+            .map(|c| c.clone())
+            .collect::<Vec<_>>()
+    }
+}
+
+fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
+    let ws = vs.get((dim2, dim1), "weight")?;
+    let bs = vs.get(dim2, "bias")?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+#[allow(unused)]
+struct LinearModel {
+    linear: Linear,
+}
+
+#[allow(unused)]
+impl LinearModel {
+    fn new(vs: VarStore) -> Result<Self> {
+        let linear = linear(IMAGE_DIM, LABELS, vs)?;
+        Ok(Self { linear })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.linear.forward(xs)
+    }
+}
+
+#[allow(unused)]
+struct Mlp {
+    ln1: Linear,
+    ln2: Linear,
+}
+
+#[allow(unused)]
+impl Mlp {
+    fn new(vs: VarStore) -> Result<Self> {
+        let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
+        let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
+        Ok(Self { ln1, ln2 })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.ln1.forward(xs)?;
+        let xs = xs.relu()?;
+        self.ln2.forward(&xs)
+    }
+}
+
+pub fn main() -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
+
+    // Load the dataset
     let m = candle_nn::vision::mnist::load_dir("data")?;
     println!("train-images: {:?}", m.train_images.shape());
     println!("train-labels: {:?}", m.train_labels.shape());
@@ -19,18 +133,23 @@ pub fn main() -> Result<()> {
     let train_labels = m.train_labels;
     let train_images = m.train_images;
     let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
-    let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
-    let bs = Var::zeros(LABELS, DType::F32, &dev)?;
-    let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
+
+    let vs = VarStore::new(DType::F32, dev);
+    let model = LinearModel::new(vs.clone())?;
+    // let model = Mlp::new(vs)?;
+
+    let all_vars = vs.all_vars();
+    let all_vars = all_vars.iter().collect::<Vec<_>>();
+    let sgd = candle_nn::SGD::new(&all_vars, 1.0);
     let test_images = m.test_images;
     let test_labels = m.test_labels.to_dtype(DType::U32)?;
     for epoch in 1..200 {
-        let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
+        let logits = model.forward(&train_images)?;
         let log_sm = ops::log_softmax(&logits, D::Minus1)?;
         let loss = loss::nll(&log_sm, &train_labels)?;
         sgd.backward_step(&loss)?;
-        let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
+        let test_logits = model.forward(&test_images)?;
         let sum_ok = test_logits
             .argmax(D::Minus1)?
             .eq(&test_labels)?
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index c064f6e5..eceb45c8 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -1,12 +1,12 @@
 #include "cuda_utils.cuh"
 #include<stdint.h>

-#define WHERE_OP(TYPENAME, FN_NAME) \
+#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
     const size_t num_dims, \
     const size_t *info, \
-    const uint32_t *ids, \
+    const ID_TYPENAME *ids, \
     const TYPENAME *t, \
     const TYPENAME *f, \
     TYPENAME *out \
@@ -33,14 +33,21 @@ extern "C" __global__ void FN_NAME( \
 } \

 #if __CUDA_ARCH__ >= 800
-WHERE_OP(__nv_bfloat16, where_bf16)
+WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
+WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
 #endif

 #if __CUDA_ARCH__ >= 530
-WHERE_OP(__half, where_f16)
+WHERE_OP(__half, uint32_t, where_u32_f16)
+WHERE_OP(__half, uint8_t, where_u8_f16)
 #endif

-WHERE_OP(float, where_f32)
-WHERE_OP(double, where_f64)
-WHERE_OP(uint8_t, where_u8)
-WHERE_OP(uint32_t, where_u32)
+WHERE_OP(float, uint32_t, where_u32_f32)
+WHERE_OP(double, uint32_t, where_u32_f64)
+WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+
+WHERE_OP(float, uint8_t, where_u8_f32)
+WHERE_OP(double, uint8_t, where_u8_f64)
+WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 5c222bf6..be1380b7 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -209,6 +209,7 @@ impl<'a> VarBuilder<'a> {
         };
         Ok(tensor)
     }
+
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index d95672b9..8cf53c2a 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -47,9 +47,8 @@ impl Cache {
         if let Some(mask) = masks.get(&t) {
             Ok(mask.clone())
         } else {
-            // TODO: If we support bool or u8 tensors, this would be better.
             let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
             masks.insert(t, mask.clone());
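
Not part of the patch: a minimal sketch of how the u8 masks built in the updated examples are consumed once this change lands, assuming the candle Tensor::where_cond API at this revision. The masked_fill helper mirrors the one used in the llama example; the tensor shapes and values below are illustrative only.

use candle::{DType, Device, Result, Tensor};

// Replace entries of `on_false` with `on_true` wherever the u8 mask is non-zero.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    mask.where_cond(&on_true, on_false)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let t = 4usize;
    // Causal mask built as u8, exactly like the updated examples above.
    let mask: Vec<u8> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (t, t), &device)?;
    // Dummy attention scores; the mask knocks out the upper triangle.
    let scores = Tensor::zeros((t, t), DType::F32, &device)?;
    let masked = masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
    println!("{masked}");
    Ok(())
}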