Save frames for first rollout

This commit is contained in:
spencerfolk
2024-01-02 17:03:27 -05:00
parent c93f2bb603
commit 128d500953

View File

@@ -21,6 +21,10 @@ The task is for the quadrotor to stabilize to hover at the origin when starting
# First we'll set up some directories for saving the policy and logs.
models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies")
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs")
output_dir = os.path.join(os.path.dirname(__file__), "..", "rotorpy", "data_out", "ppo_hover")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Next import Stable Baselines.
try:
@@ -68,9 +72,13 @@ num_episodes = 10
for i in range(num_episodes):
obs, info = env.reset()
terminated = False
j = 0
while not terminated:
env.render()
action, _ = model.predict(obs)
obs, reward, terminated, truncated, info = env.step(action)
if i == 0: # Save frames from the first rollout to make a gif.
env.fig.savefig(os.path.join(output_dir, 'PPO_hover_'+str(j)+'.png'))
j += 1
plt.show()