diff --git a/examples/ppo_hover_eval.py b/examples/ppo_hover_eval.py index 18db0d9..7a629ee 100644 --- a/examples/ppo_hover_eval.py +++ b/examples/ppo_hover_eval.py @@ -21,6 +21,10 @@ The task is for the quadrotor to stabilize to hover at the origin when starting # First we'll set up some directories for saving the policy and logs. models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies") log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs") +output_dir = os.path.join(os.path.dirname(__file__), "..", "rotorpy", "data_out", "ppo_hover") + +if not os.path.exists(output_dir): + os.makedirs(output_dir) # Next import Stable Baselines. try: @@ -68,9 +72,13 @@ num_episodes = 10 for i in range(num_episodes): obs, info = env.reset() terminated = False + j = 0 while not terminated: env.render() action, _ = model.predict(obs) obs, reward, terminated, truncated, info = env.step(action) + if i == 0: # Save frames from the first rollout to make a gif. + env.fig.savefig(os.path.join(output_dir, 'PPO_hover_'+str(j)+'.png')) + j += 1 plt.show() \ No newline at end of file