Save frames for first rollout
This commit is contained in:
@@ -21,6 +21,10 @@ The task is for the quadrotor to stabilize to hover at the origin when starting
|
|||||||
# First we'll set up some directories for saving the policy and logs.
|
# First we'll set up some directories for saving the policy and logs.
|
||||||
models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies")
|
models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies")
|
||||||
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs")
|
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs")
|
||||||
|
output_dir = os.path.join(os.path.dirname(__file__), "..", "rotorpy", "data_out", "ppo_hover")
|
||||||
|
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
# Next import Stable Baselines.
|
# Next import Stable Baselines.
|
||||||
try:
|
try:
|
||||||
@@ -68,9 +72,13 @@ num_episodes = 10
|
|||||||
for i in range(num_episodes):
|
for i in range(num_episodes):
|
||||||
obs, info = env.reset()
|
obs, info = env.reset()
|
||||||
terminated = False
|
terminated = False
|
||||||
|
j = 0
|
||||||
while not terminated:
|
while not terminated:
|
||||||
env.render()
|
env.render()
|
||||||
action, _ = model.predict(obs)
|
action, _ = model.predict(obs)
|
||||||
obs, reward, terminated, truncated, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
if i == 0: # Save frames from the first rollout to make a gif.
|
||||||
|
env.fig.savefig(os.path.join(output_dir, 'PPO_hover_'+str(j)+'.png'))
|
||||||
|
j += 1
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
Reference in New Issue
Block a user