path: root/candle-examples/examples/reinforcement-learning/policy_gradient.rs
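//! Vanilla policy gradient (REINFORCE-style) training on the Gym `CartPole-v1`
//! environment, accessed through the `gym_env` wrapper: a small MLP policy is
//! updated by maximizing the return-weighted log-probabilities of the actions
//! it sampled.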
use super::gym_env::{GymEnv, Step};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::{
    linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
    ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};

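/// Builds the policy network: a two-layer MLP (input -> 32 -> num_actions)
/// producing action logits, together with the `VarMap` that owns its
/// trainable parameters.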
fn new_model(
    input_shape: &[usize],
    num_actions: usize,
    dtype: DType,
    device: &Device,
) -> Result<(impl Module, VarMap)> {
    let input_size = input_shape.iter().product();

    let varmap = VarMap::new();
    let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);

    let model = seq()
        .add(linear(input_size, 32, var_builder.pp("lin1"))?)
        .add(Activation::Relu)
        .add(linear(32, num_actions, var_builder.pp("lin2"))?);

    Ok((model, varmap))
}

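/// Computes the undiscounted return-to-go for every step by summing rewards
/// backwards over the batch, resetting the running total at episode
/// boundaries (terminated steps).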
fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
    let mut rewards: Vec<f64> = steps.iter().map(|s| s.reward).collect();
    let mut acc_reward = 0f64;
    for (i, reward) in rewards.iter_mut().enumerate().rev() {
        if steps[i].terminated {
            acc_reward = 0.0;
        }
        acc_reward += *reward;
        *reward = acc_reward;
    }
    rewards
}

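/// Samples an index according to the given probability weights.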
fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
    Ok(distribution.sample(rng))
}

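/// Runs the training loop: for each epoch, roll out full episodes until more
/// than 5000 steps have been collected, then take a single AdamW step on the
/// policy-gradient loss.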
pub fn run() -> Result<()> {
    let env = GymEnv::new("CartPole-v1")?;

    println!("action space: {:?}", env.action_space());
    println!("observation space: {:?}", env.observation_space());

    let (model, varmap) = new_model(
        env.observation_space(),
        env.action_space(),
        DType::F32,
        &Device::Cpu,
    )?;

    let optimizer_params = ParamsAdamW {
        lr: 0.01,
        weight_decay: 0.01,
        ..Default::default()
    };

    let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

    let mut rng = rand::thread_rng();

    for epoch_idx in 0..100 {
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut steps: Vec<Step<i64>> = vec![];

        loop {
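            // Sample an action from the policy's softmax distribution over the
            // current observation.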
            let action = {
                let action_probs: Vec<f32> =
                    softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                        .squeeze(0)?
                        .to_vec1()?;
                weighted_sample(action_probs, &mut rng)? as i64
            };

            let step = env.step(action)?;
            steps.push(step.copy_with_obs(&state));

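            // On episode end, reset the environment; stop collecting once the
            // batch holds more than 5000 steps.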
            if step.terminated || step.truncated {
                state = env.reset(rng.gen::<u64>())?;
                if steps.len() > 5000 {
                    break;
                }
            } else {
                state = step.state;
            }
        }

        let total_reward: f64 = steps.iter().map(|s| s.reward).sum();
        let episodes: i64 = steps
            .iter()
            .map(|s| (s.terminated || s.truncated) as i64)
            .sum();
        println!(
            "epoch: {:<3} episodes: {:<5} avg reward per episode: {:.2}",
            epoch_idx,
            episodes,
            total_reward / episodes as f64
        );

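        // Assemble the training batch: returns-to-go, one-hot action masks and
        // stacked observations, all detached from the autograd graph.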
        let batch_size = steps.len();

        let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
            .to_dtype(DType::F32)?
            .detach();

        let actions_mask = {
            let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
            let actions_mask: Vec<Tensor> = actions
                .iter()
                .map(|&action| {
                    // One-hot encoding
                    let mut action_mask = vec![0.0; env.action_space()];
                    action_mask[action as usize] = 1.0;

                    Tensor::from_vec(action_mask, env.action_space(), &Device::Cpu)
                        .unwrap()
                        .to_dtype(DType::F32)
                        .unwrap()
                })
                .collect();
            Tensor::stack(&actions_mask, 0)?.detach()
        };

        let states = {
            let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
            Tensor::stack(&states, 0)?.detach()
        };

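        // Log-probability of the action actually taken at each step, extracted
        // by masking the log-softmax of the policy logits with the one-hot actions.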
        let log_probs = actions_mask
            .mul(&log_softmax(&model.forward(&states)?, 1)?)?
            .sum(1)?;

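        // REINFORCE objective: negative mean of return-weighted log-probabilities;
        // one optimizer step per epoch.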
        let loss = rewards.mul(&log_probs)?.neg()?.mean_all()?;
        optimizer.backward_step(&loss)?;
    }

    Ok(())
}