Diffstat (limited to 'candle-examples/examples/reinforcement-learning/gym_env.rs')
-rw-r--r--  candle-examples/examples/reinforcement-learning/gym_env.rs  114
1 file changed, 114 insertions, 0 deletions
diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs
new file mode 100644
index 00000000..b98be6bc
--- /dev/null
+++ b/candle-examples/examples/reinforcement-learning/gym_env.rs
@@ -0,0 +1,114 @@
+#![allow(unused)]
+//! Wrappers around the Python API of Gymnasium (the maintained fork of OpenAI Gym).
+use candle::{Device, Result, Tensor};
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+
+/// The return value for a step.
+#[derive(Debug)]
+pub struct Step<A> {
+    pub obs: Tensor,
+    pub action: A,
+    pub reward: f64,
+    pub is_done: bool,
+}
+
+impl<A: Copy> Step<A> {
+    /// Returns a copy of this step with the observation tensor replaced.
+    pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+        Step {
+            obs: obs.clone(),
+            action: self.action,
+            reward: self.reward,
+            is_done: self.is_done,
+        }
+    }
+}
+
+/// A Gymnasium environment session.
+pub struct GymEnv {
+    env: PyObject,
+    action_space: usize,
+    observation_space: Vec<usize>,
+}
+
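+/// Converts a Python error into a candle error.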
+fn w(res: PyErr) -> candle::Error {
+    candle::Error::wrap(res)
+}
+
+impl GymEnv {
+    /// Creates a new session of the specified Gymnasium environment.
+    pub fn new(name: &str) -> Result<GymEnv> {
+        Python::with_gil(|py| {
+            let gym = py.import("gymnasium")?;
+            let make = gym.getattr("make")?;
+            let env = make.call1((name,))?;
+            let action_space = env.getattr("action_space")?;
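+            // Discrete action spaces expose `n`; continuous (`Box`) spaces expose `shape`.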
+            let action_space = if let Ok(val) = action_space.getattr("n") {
+                val.extract()?
+            } else {
+                let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
+                action_space[0]
+            };
+            let observation_space = env.getattr("observation_space")?;
+            let observation_space = observation_space.getattr("shape")?.extract()?;
+            Ok(GymEnv {
+                env: env.into(),
+                action_space,
+                observation_space,
+            })
+        })
+        .map_err(w)
+    }
+
+    /// Resets the environment, returning the observation tensor.
+    pub fn reset(&self, seed: u64) -> Result<Tensor> {
+        let obs: Vec<f32> = Python::with_gil(|py| {
+            let kwargs = PyDict::new(py);
+            kwargs.set_item("seed", seed)?;
+            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
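+            // `reset` returns an `(observation, info)` tuple; keep only the observation.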
+            obs.as_ref(py).get_item(0)?.extract()
+        })
+        .map_err(w)?;
+        Tensor::new(obs, &Device::Cpu)
+    }
+
+    /// Applies an environment step using the specified action.
+    pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
+        &self,
+        action: A,
+    ) -> Result<Step<A>> {
+        let (obs, reward, is_done) = Python::with_gil(|py| {
+            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
+            let step = step.as_ref(py);
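+            // `step` returns an `(obs, reward, terminated, truncated, info)` tuple.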
+            let obs: Vec<f32> = step.get_item(0)?.extract()?;
+            let reward: f64 = step.get_item(1)?.extract()?;
+            // Done if the episode terminated naturally or was truncated (e.g. time limit).
+            let terminated: bool = step.get_item(2)?.extract()?;
+            let truncated: bool = step.get_item(3)?.extract()?;
+            Ok((obs, reward, terminated || truncated))
+        })
+        .map_err(w)?;
+        let obs = Tensor::new(obs, &Device::Cpu)?;
+        Ok(Step {
+            obs,
+            reward,
+            is_done,
+            action,
+        })
+    }
+
+    /// Returns the number of allowed actions for this environment.
+    pub fn action_space(&self) -> usize {
+        self.action_space
+    }
+
+    /// Returns the shape of the observation tensors.
+    pub fn observation_space(&self) -> &[usize] {
+        &self.observation_space
+    }
+}
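
For reference, a minimal episode loop over this wrapper might look like the sketch below. It is not part of the commit above: it assumes `gymnasium` is importable by the Python interpreter that pyo3 embeds and an environment with discrete actions such as `CartPole-v1`, and the sign-based action rule is only a stand-in for a learned policy.

fn run_episode() -> candle::Result<f64> {
    let env = GymEnv::new("CartPole-v1")?;
    let mut obs = env.reset(42)?;
    let mut total_reward = 0.0;
    loop {
        // Stand-in policy: act on the sign of the summed observation.
        // A real agent would run `obs` through a model instead.
        let action: usize = if obs.sum_all()?.to_scalar::<f32>()? > 0.0 {
            1
        } else {
            0
        };
        let step = env.step(action)?;
        total_reward += step.reward;
        obs = step.obs;
        if step.is_done {
            return Ok(total_reward);
        }
    }
}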