Bug fix in example

This commit is contained in:
spencerfolk
2024-01-04 18:17:43 -05:00
parent b17fe2a437
commit ea9d8dc009

View File

@@ -109,7 +109,7 @@ num_timesteps_idxs = [int(input("Enter the epoch index: "))]
# Evaluation... # Evaluation...
for (k, num_timesteps_idx) in enumerate(num_timesteps_idxs): # For each num_timesteps index... for (k, num_timesteps_idx) in enumerate(num_timesteps_idxs): # For each num_timesteps index...
print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(epoch_idxs)}.") print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(num_timesteps_idxs)}.")
# Load the model for the appropriate epoch. # Load the model for the appropriate epoch.
model_path = os.path.join(num_timesteps_dir, num_timesteps_list_sorted[num_timesteps_idx]) model_path = os.path.join(num_timesteps_dir, num_timesteps_list_sorted[num_timesteps_idx])
@@ -117,7 +117,7 @@ for (k, num_timesteps_idx) in enumerate(num_timesteps_idxs): # For each num_tim
model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir) model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir)
# Set figure title for 3D plot. # Set figure title for 3D plot.
fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}") fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Num Timesteps: {extract_number(num_timesteps_list_sorted[num_timesteps_idx]):,}")
# Visualization is slow, so we'll also save frames to make a GIF later. # Visualization is slow, so we'll also save frames to make a GIF later.
# Set the path for these frames here. # Set the path for these frames here.