Changed epoch to num_timesteps
Epoch was the wrong term to use.
This commit is contained in:
@@ -42,7 +42,7 @@ models_available = os.listdir(models_dir)
|
||||
for i, name in enumerate(models_available):
|
||||
print(f"{i}: {name}")
|
||||
model_idx = int(input("Enter the model index: "))
|
||||
epoch_dir = os.path.join(models_dir, models_available[model_idx])
|
||||
num_timesteps_dir = os.path.join(models_dir, models_available[model_idx])
|
||||
|
||||
# Next import Stable Baselines.
|
||||
try:
|
||||
@@ -95,33 +95,33 @@ envs.append(gym.make("Quadrotor-v0",
|
||||
# Print out policies for the user to select.
|
||||
def extract_number(filename):
    """Return the integer embedded in a saved-model filename.

    The number is read from the second underscore-separated field of
    ``filename``, cut off at the first ``.`` in that field, so
    ``"hover_1000.zip"`` yields ``1000``. It is used as the sort key for
    the ``hover_*`` files and for the number shown in the plot titles.

    Raises:
        IndexError: if ``filename`` contains no underscore.
        ValueError: if the extracted field is not a valid integer.
    """
    # Field 0 is the name prefix (e.g. "hover"); field 1 carries the number.
    number_field = filename.split('_')[1]
    # Drop a file extension if it is attached directly to the number.
    return int(number_field.split('.')[0])
|
||||
epochs_list = [fname for fname in os.listdir(epoch_dir) if fname.startswith('hover_')]
|
||||
epochs_list_sorted = sorted(epochs_list, key=extract_number)
|
||||
num_timesteps_list = [fname for fname in os.listdir(num_timesteps_dir) if fname.startswith('hover_')]
|
||||
num_timesteps_list_sorted = sorted(num_timesteps_list, key=extract_number)
|
||||
|
||||
print("Select one of the epochs:")
|
||||
for i, name in enumerate(epochs_list_sorted):
|
||||
for i, name in enumerate(num_timesteps_list_sorted):
|
||||
print(f"{i}: {name}")
|
||||
epoch_idxs = [int(input("Enter the epoch index: "))]
|
||||
num_timesteps_idxs = [int(input("Enter the epoch index: "))]
|
||||
|
||||
# You can optionally hard-code a list of indices to evaluate several saved models all at once.
|
||||
# e.g. epoch_idxs = [0, 1, 2, ...]
|
||||
# e.g. num_timesteps_idxs = [0, 1, 2, ...]
|
||||
|
||||
# Evaluation...
|
||||
for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
|
||||
for (k, num_timesteps_idx) in enumerate(num_timesteps_idxs): # For each num_timesteps index...
|
||||
|
||||
print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(epoch_idxs)}.")
|
||||
|
||||
# Load the model saved at the selected number of timesteps.
|
||||
model_path = os.path.join(epoch_dir, epochs_list_sorted[epoch_idx])
|
||||
model_path = os.path.join(num_timesteps_dir, num_timesteps_list_sorted[num_timesteps_idx])
|
||||
print(f"Loading model from the path {model_path}")
|
||||
model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir)
|
||||
|
||||
# Set figure title for 3D plot.
|
||||
fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
|
||||
fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
|
||||
|
||||
# Visualization is slow, so we'll also save frames to make a GIF later.
|
||||
# Set the path for these frames here.
|
||||
frame_path = os.path.join(output_dir, epochs_list_sorted[epoch_idx][:-4])
|
||||
frame_path = os.path.join(output_dir, num_timesteps_list_sorted[num_timesteps_idx][:-4])
|
||||
if not os.path.exists(frame_path):
|
||||
os.makedirs(frame_path)
|
||||
|
||||
@@ -192,7 +192,7 @@ for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
|
||||
|
||||
# Plot position vs time.
|
||||
fig_pos, ax_pos = plt.subplots(nrows=3, ncols=1, num="Position vs Time")
|
||||
fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
|
||||
fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Num Timesteps: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
|
||||
ax_pos[0].plot(T, x[:, 0], 'b-', linewidth=1, label="RL")
|
||||
ax_pos[0].plot(T, x[:, 1:-1], 'b-', linewidth=1)
|
||||
ax_pos[0].plot(T, x[:, -1], 'k-', linewidth=2, label="GC")
|
||||
|
||||
Reference in New Issue
Block a user