diff --git a/examples/ppo_hover_eval.py b/examples/ppo_hover_eval.py
index 9874939..19a70dc 100644
--- a/examples/ppo_hover_eval.py
+++ b/examples/ppo_hover_eval.py
@@ -42,7 +42,7 @@ models_available = os.listdir(models_dir)
 for i, name in enumerate(models_available):
     print(f"{i}: {name}")
 model_idx = int(input("Enter the model index: "))
-epoch_dir = os.path.join(models_dir, models_available[model_idx])
+num_timesteps_dir = os.path.join(models_dir, models_available[model_idx])
 
 # Next import Stable Baselines.
 try:
@@ -95,33 +95,33 @@ envs.append(gym.make("Quadrotor-v0",
 # Print out policies for the user to select.
 def extract_number(filename):
     return int(filename.split('_')[1].split('.')[0])
-epochs_list = [fname for fname in os.listdir(epoch_dir) if fname.startswith('hover_')]
-epochs_list_sorted = sorted(epochs_list, key=extract_number)
+num_timesteps_list = [fname for fname in os.listdir(num_timesteps_dir) if fname.startswith('hover_')]
+num_timesteps_list_sorted = sorted(num_timesteps_list, key=extract_number)
 
 print("Select one of the epochs:")
-for i, name in enumerate(epochs_list_sorted):
+for i, name in enumerate(num_timesteps_list_sorted):
     print(f"{i}: {name}")
-epoch_idxs = [int(input("Enter the epoch index: "))]
+num_timesteps_idxs = [int(input("Enter the epoch index: "))]
 
 # You can optionally just hard code a series of epochs you'd like to evaluate all at once.
-# e.g. epoch_idxs = [0, 1, 2, ...]
+# e.g. num_timesteps_idxs = [0, 1, 2, ...]
 
 # Evaluation...
-for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
+for (k, num_timesteps_idx) in enumerate(num_timesteps_idxs): # For each num_timesteps index...
 
-    print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(epoch_idxs)}.")
+    print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(num_timesteps_idxs)}.")
 
     # Load the model for the appropriate epoch.
-    model_path = os.path.join(epoch_dir, epochs_list_sorted[epoch_idx])
+    model_path = os.path.join(num_timesteps_dir, num_timesteps_list_sorted[num_timesteps_idx])
     print(f"Loading model from the path {model_path}")
     model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir)
 
     # Set figure title for 3D plot.
-    fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
+    fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
 
     # Visualization is slow, so we'll also save frames to make a GIF later.
     # Set the path for these frames here.
-    frame_path = os.path.join(output_dir, epochs_list_sorted[epoch_idx][:-4])
+    frame_path = os.path.join(output_dir, num_timesteps_list_sorted[num_timesteps_idx][:-4])
     if not os.path.exists(frame_path):
         os.makedirs(frame_path)
 
@@ -192,7 +192,7 @@ for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
 
     # Plot position vs time.
     fig_pos, ax_pos = plt.subplots(nrows=3, ncols=1, num="Position vs Time")
-    fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
+    fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Num Timesteps: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
     ax_pos[0].plot(T, x[:, 0], 'b-', linewidth=1, label="RL")
     ax_pos[0].plot(T, x[:, 1:-1], 'b-', linewidth=1)
     ax_pos[0].plot(T, x[:, -1], 'k-', linewidth=2, label="GC")
@@ -214,4 +214,4 @@ for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
     # Save fig.
     fig_pos.savefig(os.path.join(frame_path, 'position_vs_time.png'))
 
-plt.show()
\ No newline at end of file
+plt.show()