Save frames for first rollout
This commit is contained in:
@@ -21,6 +21,10 @@ The task is for the quadrotor to stabilize to hover at the origin when starting
|
||||
# First we'll set up some directories for saving the policy and logs.
|
||||
models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies")
|
||||
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs")
|
||||
output_dir = os.path.join(os.path.dirname(__file__), "..", "rotorpy", "data_out", "ppo_hover")
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Next import Stable Baselines.
|
||||
try:
|
||||
@@ -68,9 +72,13 @@ num_episodes = 10
|
||||
for i in range(num_episodes):
|
||||
obs, info = env.reset()
|
||||
terminated = False
|
||||
j = 0
|
||||
while not terminated:
|
||||
env.render()
|
||||
action, _ = model.predict(obs)
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
if i == 0: # Save frames from the first rollout to make a gif.
|
||||
env.fig.savefig(os.path.join(output_dir, 'PPO_hover_'+str(j)+'.png'))
|
||||
j += 1
|
||||
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user