summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--candle-core/src/cuda_backend.rs41
-rw-r--r--candle-core/tests/pool_tests.rs10
-rw-r--r--candle-examples/examples/bert/main.rs5
-rw-r--r--candle-examples/examples/bigcode/main.rs5
-rw-r--r--candle-examples/examples/falcon/main.rs6
-rw-r--r--candle-examples/examples/llama/main.rs4
-rw-r--r--candle-examples/examples/whisper/main.rs5
-rw-r--r--candle-kernels/src/conv.cu62
-rw-r--r--candle-wasm-examples/whisper/src/worker.rs5
10 files changed, 119 insertions, 26 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 915e6314..cef5c011 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -48,7 +48,7 @@ safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
thiserror = "1"
-tokenizers = { version = "0.13.3", default-features = false }
+tokenizers = { version = "0.13.4", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 90d3ee6d..0a73c023 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -980,7 +980,7 @@ impl Map1 for Pool2D {
dev: &CudaDevice,
inp_l: &Layout,
) -> Result<CudaSlice<T>> {
- // Kernel shape: (c_out, c_in_k, w_k, h_k)
+ // Input shape: (b_size, c, h, w)
let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
@@ -1018,6 +1018,39 @@ impl Map1 for Pool2D {
}
}
+struct UpsampleNearest2D(usize, usize);
+impl Map1 for UpsampleNearest2D {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ dev: &CudaDevice,
+ inp_l: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ // Input shape: (b_size, c, h, w)
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let ds = if dims.len() == 4 {
+ [dims, inp_l.stride()].concat()
+ } else {
+ panic!("unexpected input shape for conv1d {dims:?}")
+ };
+ let (out_w, out_h) = (self.0, self.1);
+ let dst_el = out_w * out_h * dims[0] * dims[1];
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let ds = dev.htod_copy(ds).w()?;
+ let scale_w = dims[2] as f64 / out_w as f64;
+ let scale_h = dims[3] as f64 / out_h as f64;
+ let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1513,8 +1546,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
- todo!()
+ fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
+ let device = self.device().clone();
+ let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
+ Ok(Self { slice, device })
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index 009564fa..d2eb8f3f 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -56,9 +56,8 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
Ok(())
}
-#[test]
-fn upsample_nearest2d() -> Result<()> {
- let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
+fn upsample_nearest2d(dev: &Device) -> Result<()> {
+ let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
assert_eq!(
t.i(0)?.i(0)?.to_vec2::<f32>()?,
@@ -83,3 +82,8 @@ test_device!(
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
+test_device!(
+ upsample_nearest2d,
+ upsample_nearest2d_cpu,
+ upsample_nearest2d_gpu
+);
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 574755ed..7f0ae7b1 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -111,7 +111,10 @@ fn main() -> Result<()> {
let device = &model.device;
if let Some(prompt) = args.prompt {
- let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index ac9c63c7..39b1de27 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -65,10 +65,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
- let token = self
- .tokenizer
- .decode(vec![next_token], true)
- .map_err(E::msg)?;
+ let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c37d9a96..0df3a001 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -72,16 +72,14 @@ impl TextGeneration {
"{} token: {} '{}'",
index + 1,
next_token,
- self.tokenizer
- .decode(vec![next_token], true)
- .map_err(E::msg)?
+ self.tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({} token/s)\n----\n{}\n----",
sample_len as f64 / dt.as_secs_f64(),
- self.tokenizer.decode(new_tokens, true).map_err(E::msg)?
+ self.tokenizer.decode(&new_tokens, true).map_err(E::msg)?
);
Ok(())
}
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 9a62eba5..b1e112fd 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -223,7 +223,7 @@ fn main() -> Result<()> {
"{} token: {} '{}'",
index + 1,
next_token,
- tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+ tokenizer.decode(&[next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
@@ -231,7 +231,7 @@ fn main() -> Result<()> {
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
- tokenizer.decode(new_tokens, true).map_err(E::msg)?
+ tokenizer.decode(&new_tokens, true).map_err(E::msg)?
);
Ok(())
}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 99919f8d..5c58c002 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -169,10 +169,7 @@ impl Decoder {
}
sum_logprob += prob.ln();
}
- let text = self
- .tokenizer
- .decode(tokens.clone(), true)
- .map_err(E::msg)?;
+ let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 2da4d401..afda7d1d 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -220,6 +220,48 @@ __device__ void max_pool2d(
dst[dst_i] = d;
}
+template <typename T>
+__device__ void upsample_nearest2d(
+ const size_t w_out,
+ const size_t h_out,
+ const double w_scale,
+ const double h_scale,
+ const size_t *info,
+ const T *src,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, w_in, h_in)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+
+ const size_t c = src_dims[1];
+ const size_t w_in = src_dims[2];
+ const size_t h_in = src_dims[3];
+
+ if (dst_i >= src_dims[0] * c * w_out * h_out) {
+ return;
+ }
+
+ // TODO: Improve this.
+ const size_t b_idx = dst_i / (w_out * h_out * c);
+ const size_t c_idx = (dst_i / (w_out * h_out)) % c;
+ const size_t dst_w = (dst_i / h_out) % w_out;
+ const size_t dst_h = dst_i % h_out;
+
+ size_t src_w = static_cast<size_t>(dst_w * w_scale);
+ size_t src_h = static_cast<size_t>(dst_h * h_scale);
+ if (src_w >= w_in) {
+ src_w = w_in - 1;
+ }
+ if (src_h >= h_in) {
+ src_h = h_in - 1;
+ }
+
+ const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
+ dst[dst_i] = src[src_i];
+}
+
#define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@@ -278,11 +320,25 @@ extern "C" __global__ void FN_NAME( \
max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
} \
+#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t w_out, \
+ const size_t h_out, \
+ const double w_scale, \
+ const double h_scale, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ TYPENAME *dst \
+) { \
+ upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, info, src, dst); \
+} \
+
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
+UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@@ -290,6 +346,7 @@ CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
+UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@@ -311,3 +368,8 @@ MAX_POOL2D_OP(float, max_pool2d_f32)
MAX_POOL2D_OP(double, max_pool2d_f64)
MAX_POOL2D_OP(uint8_t, max_pool2d_u8)
MAX_POOL2D_OP(uint32_t, max_pool2d_u32)
+
+UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
+UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64)
+UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
+UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index 139755cb..d77d3e32 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -159,10 +159,7 @@ impl Decoder {
}
sum_logprob += prob.ln();
}
- let text = self
- .tokenizer
- .decode(tokens.clone(), true)
- .map_err(E::msg)?;
+ let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {