summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-02 10:52:02 +0200
committerGitHub <noreply@github.com>2024-10-02 10:52:02 +0200
commitf479840ce6d2222bd004b6f275494297f1f0ae91 (patch)
treebae564a332edd86fab1bbb9f9ce4fa1a65e3b38f /candle-examples/examples
parentfd08d3d0a40872f207284b008de23ef875d54f74 (diff)
downloadcandle-f479840ce6d2222bd004b6f275494297f1f0ae91.tar.gz
candle-f479840ce6d2222bd004b6f275494297f1f0ae91.tar.bz2
candle-f479840ce6d2222bd004b6f275494297f1f0ae91.zip
Add a seed to the flux example. (#2529)
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/flux/main.rs13
1 files changed, 10 insertions, 3 deletions
diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs
index 24b1fa2b..943db112 100644
--- a/candle-examples/examples/flux/main.rs
+++ b/candle-examples/examples/flux/main.rs
@@ -45,9 +45,13 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
- /// Use the faster kernels which are buggy at the moment.
+ /// Use the slower kernels.
#[arg(long)]
- no_dmmv: bool,
+ use_dmmv: bool,
+
+ /// The seed to use when generating random samples.
+ #[arg(long)]
+ seed: Option<u64>,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@@ -91,6 +95,9 @@ fn run(args: Args) -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
+ if let Some(seed) = args.seed {
+ device.set_seed(seed)?;
+ }
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
@@ -250,6 +257,6 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
- candle::quantized::cuda::set_force_dmmv(!args.no_dmmv);
+ candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
run(args)
}