summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend.rs16
-rw-r--r--candle-examples/examples/bigcode/model.rs2
-rw-r--r--candle-examples/examples/falcon/model.rs2
-rw-r--r--candle-examples/examples/llama/model.rs3
-rw-r--r--candle-examples/examples/llama2-c/model.rs3
-rw-r--r--candle-examples/examples/simple-training/main.rs139
-rw-r--r--candle-kernels/src/ternary.cu23
-rw-r--r--candle-nn/src/var_builder.rs1
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs3
9 files changed, 161 insertions, 31 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 4050b595..a88d62c7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -940,16 +940,22 @@ impl<'a> Map2 for WhereCond<'a> {
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
- let ids = match &self.0.slice {
- CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+ let (ids, name) = match &self.0.slice {
+ CudaStorageSlice::U8(slice) => {
+ let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+ (ptr, "where_u8")
+ }
+ CudaStorageSlice::U32(slice) => {
+ let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+ (ptr, "where_u32")
+ }
_ => Err(CudaError::UnexpectedDType {
- msg: "where conditions should be u32",
+ msg: "where conditions should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
.w()?,
};
- let ids = &ids;
let shape = ids_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
@@ -959,7 +965,7 @@ impl<'a> Map2 for WhereCond<'a> {
.w()?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+ let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index ddee0d27..99f5bb5a 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -24,7 +24,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u32::from(j <= i)))
+ .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?;
Ok(mask)
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index cab0b314..1c77cbaf 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -424,7 +424,7 @@ pub struct Falcon {
fn make_causal_mask(t: usize) -> Result<Tensor> {
let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
Ok(mask)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index dba1d535..ae27afc1 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -91,9 +91,8 @@ impl Cache {
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
- // TODO: If we support bool or u8 tensors, this would be better.
let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 6d9e4bcd..9e1c3eda 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -47,9 +47,8 @@ impl Cache {
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
- // TODO: If we support bool or u8 tensors, this would be better.
let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index edec2e92..35b938e8 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -1,16 +1,130 @@
-// This should rearch 91.5% accuracy.
+// This should reach 91.5% accuracy.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-use anyhow::Result;
-use candle::{DType, Var, D};
-use candle_nn::{loss, ops};
+use candle::{DType, Device, Result, Shape, Tensor, Var, D};
+use candle_nn::{loss, ops, Linear};
+use std::sync::{Arc, Mutex};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
-pub fn main() -> Result<()> {
+struct TensorData {
+ tensors: std::collections::HashMap<String, Var>,
+ pub dtype: DType,
+ pub device: Device,
+}
+
+// A variant of candle_nn::VarBuilder for initializing variables before training.
+#[derive(Clone)]
+struct VarStore {
+ data: Arc<Mutex<TensorData>>,
+ path: Vec<String>,
+}
+
+impl VarStore {
+ fn new(dtype: DType, device: Device) -> Self {
+ let data = TensorData {
+ tensors: std::collections::HashMap::new(),
+ dtype,
+ device,
+ };
+ Self {
+ data: Arc::new(Mutex::new(data)),
+ path: vec![],
+ }
+ }
+
+ fn pp(&self, s: &str) -> Self {
+ let mut path = self.path.clone();
+ path.push(s.to_string());
+ Self {
+ data: self.data.clone(),
+ path,
+ }
+ }
+
+ fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
+ let shape = shape.into();
+ let path = if self.path.is_empty() {
+ tensor_name.to_string()
+ } else {
+ [&self.path.join("."), tensor_name].join(".")
+ };
+ let mut tensor_data = self.data.lock().unwrap();
+ if let Some(tensor) = tensor_data.tensors.get(&path) {
+ let tensor_shape = tensor.shape();
+ if &shape != tensor_shape {
+ candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+ }
+ return Ok(tensor.as_tensor().clone());
+ }
+ // TODO: Proper initialization using the `Init` enum.
+ let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
+ let tensor = var.as_tensor().clone();
+ tensor_data.tensors.insert(path, var);
+ Ok(tensor)
+ }
+
+ fn all_vars(&self) -> Vec<Var> {
+ let tensor_data = self.data.lock().unwrap();
+ #[allow(clippy::map_clone)]
+ tensor_data
+ .tensors
+ .values()
+ .map(|c| c.clone())
+ .collect::<Vec<_>>()
+ }
+}
+
+fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
+ let ws = vs.get((dim2, dim1), "weight")?;
+ let bs = vs.get(dim2, "bias")?;
+ Ok(Linear::new(ws, Some(bs)))
+}
+
+#[allow(unused)]
+struct LinearModel {
+ linear: Linear,
+}
+
+#[allow(unused)]
+impl LinearModel {
+ fn new(vs: VarStore) -> Result<Self> {
+ let linear = linear(IMAGE_DIM, LABELS, vs)?;
+ Ok(Self { linear })
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.linear.forward(xs)
+ }
+}
+
+#[allow(unused)]
+struct Mlp {
+ ln1: Linear,
+ ln2: Linear,
+}
+
+#[allow(unused)]
+impl Mlp {
+ fn new(vs: VarStore) -> Result<Self> {
+ let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
+ let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
+ Ok(Self { ln1, ln2 })
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.ln1.forward(xs)?;
+ let xs = xs.relu()?;
+ self.ln2.forward(&xs)
+ }
+}
+
+pub fn main() -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
+
+ // Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
@@ -19,18 +133,23 @@ pub fn main() -> Result<()> {
let train_labels = m.train_labels;
let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
- let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
- let bs = Var::zeros(LABELS, DType::F32, &dev)?;
- let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
+
+ let vs = VarStore::new(DType::F32, dev);
+ let model = LinearModel::new(vs.clone())?;
+ // let model = Mlp::new(vs)?;
+
+ let all_vars = vs.all_vars();
+ let all_vars = all_vars.iter().collect::<Vec<_>>();
+ let sgd = candle_nn::SGD::new(&all_vars, 1.0);
let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 {
- let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
+ let logits = model.forward(&train_images)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
sgd.backward_step(&loss)?;
- let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
+ let test_logits = model.forward(&test_images)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index c064f6e5..eceb45c8 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -1,12 +1,12 @@
#include "cuda_utils.cuh"
#include<stdint.h>
-#define WHERE_OP(TYPENAME, FN_NAME) \
+#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
- const uint32_t *ids, \
+ const ID_TYPENAME *ids, \
const TYPENAME *t, \
const TYPENAME *f, \
TYPENAME *out \
@@ -33,14 +33,21 @@ extern "C" __global__ void FN_NAME( \
} \
#if __CUDA_ARCH__ >= 800
-WHERE_OP(__nv_bfloat16, where_bf16)
+WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
+WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
-WHERE_OP(__half, where_f16)
+WHERE_OP(__half, uint32_t, where_u32_f16)
+WHERE_OP(__half, uint8_t, where_u8_f16)
#endif
-WHERE_OP(float, where_f32)
-WHERE_OP(double, where_f64)
-WHERE_OP(uint8_t, where_u8)
-WHERE_OP(uint32_t, where_u32)
+WHERE_OP(float, uint32_t, where_u32_f32)
+WHERE_OP(double, uint32_t, where_u32_f64)
+WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+
+WHERE_OP(float, uint8_t, where_u8_f32)
+WHERE_OP(double, uint8_t, where_u8_f64)
+WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+WHERE_OP(uint8_t, uint32_t, where_u8_u32)
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 5c222bf6..be1380b7 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -209,6 +209,7 @@ impl<'a> VarBuilder<'a> {
};
Ok(tensor)
}
+
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index d95672b9..8cf53c2a 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -47,9 +47,8 @@ impl Cache {
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
- // TODO: If we support bool or u8 tensors, this would be better.
let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());