diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/func.rs | 12 | ||||
-rw-r--r-- | candle-nn/src/sequential.rs | 2 |
2 files changed, 8 insertions, 6 deletions
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index e7fd73ae..39311d45 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -1,10 +1,12 @@ //! Layers defined by closures. use candle::{Result, Tensor}; +use std::sync::Arc; /// A layer defined by a simple closure. +#[derive(Clone)] pub struct Func<'a> { #[allow(clippy::type_complexity)] - f: Box<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send>, + f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>, } impl<'a> std::fmt::Debug for Func<'a> { @@ -15,9 +17,9 @@ impl<'a> std::fmt::Debug for Func<'a> { pub fn func<'a, F>(f: F) -> Func<'a> where - F: 'a + Fn(&Tensor) -> Result<Tensor> + Send, + F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync, { - Func { f: Box::new(f) } + Func { f: Arc::new(f) } } impl<'a> super::Module for Func<'a> { @@ -29,8 +31,8 @@ impl<'a> super::Module for Func<'a> { impl<'a> Func<'a> { pub fn new<F>(f: F) -> Self where - F: 'a + Fn(&Tensor) -> Result<Tensor> + Send, + F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync, { - Self { f: Box::new(f) } + Self { f: Arc::new(f) } } } diff --git a/candle-nn/src/sequential.rs b/candle-nn/src/sequential.rs index 2fef7742..bef99752 100644 --- a/candle-nn/src/sequential.rs +++ b/candle-nn/src/sequential.rs @@ -44,7 +44,7 @@ impl Sequential { /// Appends a closure after all the current layers. pub fn add_fn<F>(self, f: F) -> Self where - F: 'static + Fn(&Tensor) -> Result<Tensor> + Send, + F: 'static + Fn(&Tensor) -> Result<Tensor> + Send + Sync, { self.add(super::func(f)) } |