From b0397976ec0d5f75365803b9db8164e41e860c9a Mon Sep 17 00:00:00 2001
From: spencerfolk
Date: Thu, 4 Jan 2024 12:30:35 -0500
Subject: [PATCH] Documentation for ppo eval example.

---
 examples/ppo_hover_eval.py | 197 +++++++++++++++++++++++++++++++------
 1 file changed, 165 insertions(+), 32 deletions(-)

diff --git a/examples/ppo_hover_eval.py b/examples/ppo_hover_eval.py
index 7a629ee..9874939 100644
--- a/examples/ppo_hover_eval.py
+++ b/examples/ppo_hover_eval.py
@@ -11,21 +11,39 @@ from rotorpy.learning.quadrotor_environments import QuadrotorEnv
 # Reward functions can be specified by the user, or we can import from existing reward functions.
 from rotorpy.learning.quadrotor_reward_functions import hover_reward

+# For the baseline, we'll use the stock SE3 controller.
+from rotorpy.controllers.quadrotor_control import SE3Control
+baseline_controller = SE3Control(quad_params)
+
 """
-In this script, we evaluate the policy trained in ppo_hover_train.py.
+In this script, we evaluate the policy trained in ppo_hover_train.py; it's meant to complement that script's output.

 The task is for the quadrotor to stabilize to hover at the origin when starting at a random position nearby.

+This script asks the user which model they'd like to use and which specific epoch(s) of that model to evaluate.
+Then, for each selected epoch, 10 RL agents are spawned at random positions alongside the baseline SE3 controller.
+
+Visualization is slow for this script. To speed things up, we save the figures as individual frames in data_out/ppo_hover/. If you
+close the matplotlib figure, things will run faster; you can also speed things up by visualizing only 1 or 2 RL agents.
+
 """

 # First we'll set up some directories for saving the policy and logs.
-models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies")
+models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies", "PPO")
 log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs")
 output_dir = os.path.join(os.path.dirname(__file__), "..", "rotorpy", "data_out", "ppo_hover")
 if not os.path.exists(output_dir):
     os.makedirs(output_dir)

+# List the available models and let the user select one.
+print("Select one of the models:")
+models_available = os.listdir(models_dir)
+for i, name in enumerate(models_available):
+    print(f"{i}: {name}")
+model_idx = int(input("Enter the model index: "))
+epoch_dir = os.path.join(models_dir, models_available[model_idx])
+
 # Next import Stable Baselines.
 try:
     import stable_baselines3
@@ -35,50 +53,165 @@ except:
 from stable_baselines3 import PPO                          # We'll use PPO for training.
 from stable_baselines3.ppo.policies import MlpPolicy       # The policy will be represented by an MLP

-num_cpu = 4   # for parallelization
-
 # Choose the weights for our reward function. Here we are creating a lambda function over hover_reward.
 reward_function = lambda obs, act: hover_reward(obs, act, weights={'x': 1, 'v': 0.1, 'w': 0, 'u': 1e-5})

-# Make the environment. For this demo we'll train a policy in cmd_vel. Higher abstractions lead to easier tasks.
-env = gym.make("Quadrotor-v0",
-                control_mode ='cmd_ctbr',
+# Set up the figure for plotting all the agents.
+fig = plt.figure()
+ax = fig.add_subplot(projection='3d')
+
+# Make the environments for the RL agents.
+num_quads = 10
+def make_env():
+    return gym.make("Quadrotor-v0",
+                control_mode ='cmd_motor_speeds',
                 reward_fn = reward_function,
                 quad_params = quad_params,
                 max_time = 5,
                 world = None,
                 sim_rate = 100,
-                render_mode='3D')
+                render_mode='3D',
+                render_fps = 60,
+                fig=fig,
+                ax=ax,
+                color='b')

-# from stable_baselines3.common.env_checker import check_env
-# check_env(env, warn=True) # you can check the environment using built-in tools
+envs = [make_env() for _ in range(num_quads)]

-# Reset the environment
-observation, info = env.reset(initial_state='random', options={'pos_bound': 2, 'vel_bound': 0})
+# Lastly, add in the baseline (SE3 controller) environment.
+envs.append(gym.make("Quadrotor-v0",
+                control_mode ='cmd_motor_speeds',
+                reward_fn = reward_function,
+                quad_params = quad_params,
+                max_time = 5,
+                world = None,
+                sim_rate = 100,
+                render_mode='3D',
+                render_fps = 60,
+                fig=fig,
+                ax=ax,
+                color='k'))    # Geometric controller

 # Print out policies for the user to select.
-print("Select one of the models:")
-models_available = os.listdir(models_dir)
-for i, name in enumerate(models_available):
+def extract_number(filename):
+    return int(filename.split('_')[1].split('.')[0])
+epochs_list = [fname for fname in os.listdir(epoch_dir) if fname.startswith('hover_')]
+epochs_list_sorted = sorted(epochs_list, key=extract_number)
+
+print("Select one of the epochs:")
+for i, name in enumerate(epochs_list_sorted):
     print(f"{i}: {name}")
-model_idx = int(input("Enter the model index: "))
+epoch_idxs = [int(input("Enter the epoch index: "))]

-# Load the model
-model_path = os.path.join(models_dir, models_available[model_idx])
-print(f"Loading model from the path {model_path}")
-model = PPO.load(model_path, env=env, tensorboard_log=log_dir)
+# You can optionally just hard-code a series of epochs you'd like to evaluate all at once,
+# e.g. epoch_idxs = [0, 1, 2, ...]

-num_episodes = 10
-for i in range(num_episodes):
-    obs, info = env.reset()
-    terminated = False
-    j = 0
-    while not terminated:
-        env.render()
-        action, _ = model.predict(obs)
-        obs, reward, terminated, truncated, info = env.step(action)
-        if i == 0: # Save frames from the first rollout to make a gif.
-            env.fig.savefig(os.path.join(output_dir, 'PPO_hover_'+str(j)+'.png'))
-            j += 1
+# Evaluation...
+for (k, epoch_idx) in enumerate(epoch_idxs):  # For each epoch index...
+
+    print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(epoch_idxs)}.")
+
+    # Load the model for the appropriate epoch.
+    model_path = os.path.join(epoch_dir, epochs_list_sorted[epoch_idx])
+    print(f"Loading model from the path {model_path}")
+    model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir)
+
+    # Set the figure title for the 3D plot.
+    fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
+
+    # Visualization is slow, so we'll also save frames to make a GIF later.
+    # Set the path for these frames here.
+    frame_path = os.path.join(output_dir, epochs_list_sorted[epoch_idx][:-4])
+    if not os.path.exists(frame_path):
+        os.makedirs(frame_path)
+
+    # Collect initial observations for each environment.
+    observations = [env.reset()[0] for env in envs]
+
+    # Track each environment's termination flag so the loop below only ends once every environment has terminated.
+    terminated = [False]*len(observations)
+
+    # Arrays for plotting position vs time.
+    T = [0]
+    x = [[obs[0] for obs in observations]]
+    y = [[obs[1] for obs in observations]]
+    z = [[obs[2] for obs in observations]]
+
+    j = 0  # Index for frames. Only updated when the last environment runs its update for the time step.
+    while not all(terminated):
+        frames = []  # Reset frames.
+        for (i, env) in enumerate(envs):  # For each environment...
+            env.render()
+
+            if i == len(envs)-1:   # If it's the last environment, run the SE3 controller for the baseline.
+
+                # Unpack the observation from the environment
+                state = {'x': observations[i][0:3], 'v': observations[i][3:6], 'q': observations[i][6:10], 'w': observations[i][10:13]}
+
+                # Command the quad to hover.
+                flat = {'x': [0, 0, 0],
+                        'x_dot': [0, 0, 0],
+                        'x_ddot': [0, 0, 0],
+                        'x_dddot': [0, 0, 0],
+                        'yaw': 0,
+                        'yaw_dot': 0,
+                        'yaw_ddot': 0}
+                control_dict = baseline_controller.update(0, state, flat)
+
+                # Extract the commanded motor speeds.
+                cmd_motor_speeds = control_dict['cmd_motor_speeds']
+
+                # The environment expects the control inputs to all be within the range [-1,1]
+                action = np.interp(cmd_motor_speeds, [env.unwrapped.rotor_speed_min, env.unwrapped.rotor_speed_max], [-1,1])
+
+                # For the last environment, append the current timestep.
+                T.append(env.unwrapped.t)
+
+            else: # For all other environments, get the action from the RL control policy.
+                action, _ = model.predict(observations[i], deterministic=True)
+
+            # Step the environment forward.
+            observations[i], reward, terminated[i], truncated, info = env.step(action)
+
+            if i == len(envs)-1:  # Save the current fig after the last agent.
+                if env.unwrapped.rendering:
+                    frame = os.path.join(frame_path, 'frame_'+str(j)+'.png')
+                    fig.savefig(frame)
+                    j += 1
+
+        # Append arrays for plotting.
+        x.append([obs[0] for obs in observations])
+        y.append([obs[1] for obs in observations])
+        z.append([obs[2] for obs in observations])
+
+    # Convert to numpy arrays.
+    x = np.array(x)
+    y = np.array(y)
+    z = np.array(z)
+    T = np.array(T)
+
+    # Plot position vs time.
+    fig_pos, ax_pos = plt.subplots(nrows=3, ncols=1, num="Position vs Time")
+    fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
+    ax_pos[0].plot(T, x[:, 0], 'b-', linewidth=1, label="RL")
+    ax_pos[0].plot(T, x[:, 1:-1], 'b-', linewidth=1)
+    ax_pos[0].plot(T, x[:, -1], 'k-', linewidth=2, label="GC")
+    ax_pos[0].legend()
+    ax_pos[0].set_ylabel("X, m")
+    ax_pos[0].set_ylim([-2.5, 2.5])
+    ax_pos[1].plot(T, y[:, 0], 'b-', linewidth=1, label="RL")
+    ax_pos[1].plot(T, y[:, 1:-1], 'b-', linewidth=1)
+    ax_pos[1].plot(T, y[:, -1], 'k-', linewidth=2, label="GC")
+    ax_pos[1].set_ylabel("Y, m")
+    ax_pos[1].set_ylim([-2.5, 2.5])
+    ax_pos[2].plot(T, z[:, 0], 'b-', linewidth=1, label="RL")
+    ax_pos[2].plot(T, z[:, 1:-1], 'b-', linewidth=1)
+    ax_pos[2].plot(T, z[:, -1], 'k-', linewidth=2, label="GC")
+    ax_pos[2].set_ylabel("Z, m")
+    ax_pos[2].set_ylim([-2.5, 2.5])
+    ax_pos[2].set_xlabel("Time, s")
+
+    # Save fig.
+    fig_pos.savefig(os.path.join(frame_path, 'position_vs_time.png'))

 plt.show()
\ No newline at end of file
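
A note on the motor-speed normalization in the baseline branch above: the SE3 controller's commanded rotor speeds are mapped into the environment's [-1, 1] action range with np.interp, and because the mapping is linear it can be inverted the same way. The standalone check below is only an illustration and is not part of the patch; the rotor speed limits are assumed placeholders, whereas the script reads the real values from env.unwrapped.rotor_speed_min and env.unwrapped.rotor_speed_max.

    import numpy as np

    # Assumed placeholder limits (rad/s); the script uses env.unwrapped.rotor_speed_min/max.
    rotor_speed_min, rotor_speed_max = 0.0, 838.0
    cmd_motor_speeds = np.array([0.0, 209.5, 419.0, 838.0])   # example commands, rad/s

    # Map physical rotor speeds into the [-1, 1] action range the environment expects.
    action = np.interp(cmd_motor_speeds, [rotor_speed_min, rotor_speed_max], [-1, 1])
    # action is now [-1., -0.5, 0., 1.]

    # Flipping the endpoint lists recovers the original speeds, confirming the mapping is invertible.
    recovered = np.interp(action, [-1, 1], [rotor_speed_min, rotor_speed_max])
    assert np.allclose(recovered, cmd_motor_speeds)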
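
The docstring and the frame-saving block both mention turning the saved frames into a GIF later, but the patch stops at writing numbered PNGs into frame_path. A minimal post-processing sketch, assuming Pillow is installed and the frame_<idx>.png naming used above, could look like the following; the helper name and fps default are illustrative, not part of RotorPy.

    import os
    import re
    from PIL import Image

    def frames_to_gif(frame_dir, gif_path, fps=30):
        """Stitch frame_<idx>.png files (in numeric order) into a single GIF."""
        pattern = re.compile(r'frame_(\d+)\.png$')
        frame_files = sorted((f for f in os.listdir(frame_dir) if pattern.match(f)),
                             key=lambda f: int(pattern.match(f).group(1)))
        if not frame_files:
            raise FileNotFoundError(f"No frame_*.png files found in {frame_dir}")
        images = [Image.open(os.path.join(frame_dir, f)) for f in frame_files]
        # duration is the per-frame display time in milliseconds; loop=0 repeats forever.
        images[0].save(gif_path, save_all=True, append_images=images[1:],
                       duration=int(1000 / fps), loop=0)

    # Example usage after the evaluation loop:
    # frames_to_gif(frame_path, os.path.join(frame_path, 'ppo_hover.gif'))

Since a frame is written once per simulation step of the last environment, the fps argument here only affects playback speed, not which frames are included.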