Changed epoch to num_timesteps

Epoch was the wrong term to use.
This commit is contained in:
spencerfolk
2024-01-04 15:19:12 -05:00
committed by GitHub
parent 73b49de96c
commit 2e03548f85

View File

@@ -42,7 +42,7 @@ models_available = os.listdir(models_dir)
for i, name in enumerate(models_available):
print(f"{i}: {name}")
model_idx = int(input("Enter the model index: "))
epoch_dir = os.path.join(models_dir, models_available[model_idx])
num_timesteps_dir = os.path.join(models_dir, models_available[model_idx])
# Next import Stable Baselines.
try:
@@ -95,33 +95,33 @@ envs.append(gym.make("Quadrotor-v0",
# Print out policies for the user to select.
def extract_number(filename):
return int(filename.split('_')[1].split('.')[0])
epochs_list = [fname for fname in os.listdir(epoch_dir) if fname.startswith('hover_')]
epochs_list_sorted = sorted(epochs_list, key=extract_number)
num_timesteps_list = [fname for fname in os.listdir(num_timesteps_dir) if fname.startswith('hover_')]
num_timesteps_list_sorted = sorted(num_timesteps_list, key=extract_number)
print("Select one of the epochs:")
for i, name in enumerate(epochs_list_sorted):
for i, name in enumerate(num_timesteps_list_sorted):
print(f"{i}: {name}")
epoch_idxs = [int(input("Enter the epoch index: "))]
num_timesteps_idxs = [int(input("Enter the epoch index: "))]
# You can optionally just hard code a series of epochs you'd like to evaluate all at once.
# e.g. epoch_idxs = [0, 1, 2, ...]
# e.g. num_timesteps_idxs = [0, 1, 2, ...]
# Evaluation...
for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
for (k, num_timesteps_idx) in enumerate(num_timesteps_idxs): # For each num_timesteps index...
print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(epoch_idxs)}.")
# Load the model for the appropriate epoch.
model_path = os.path.join(epoch_dir, epochs_list_sorted[epoch_idx])
model_path = os.path.join(num_timesteps_dir, num_timesteps_list_sorted[num_timesteps_idx])
print(f"Loading model from the path {model_path}")
model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir)
# Set figure title for 3D plot.
fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
# Visualization is slow, so we'll also save frames to make a GIF later.
# Set the path for these frames here.
frame_path = os.path.join(output_dir, epochs_list_sorted[epoch_idx][:-4])
frame_path = os.path.join(output_dir, num_timesteps_list_sorted[num_timesteps_idx][:-4])
if not os.path.exists(frame_path):
os.makedirs(frame_path)
@@ -192,7 +192,7 @@ for (k, epoch_idx) in enumerate(epoch_idxs): # For each epoch index...
# Plot position vs time.
fig_pos, ax_pos = plt.subplots(nrows=3, ncols=1, num="Position vs Time")
fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Num Timesteps: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
ax_pos[0].plot(T, x[:, 0], 'b-', linewidth=1, label="RL")
ax_pos[0].plot(T, x[:, 1:-1], 'b-', linewidth=1)
ax_pos[0].plot(T, x[:, -1], 'k-', linewidth=2, label="GC")