Documentation for ppo eval example.
from rotorpy.learning.quadrotor_environments import QuadrotorEnv

# Reward functions can be specified by the user, or we can import from existing reward functions.
from rotorpy.learning.quadrotor_reward_functions import hover_reward

# For the baseline, we'll use the stock SE3 controller.
from rotorpy.controllers.quadrotor_control import SE3Control
baseline_controller = SE3Control(quad_params)

"""
|
||||
In this script, we evaluate the policy trained in ppo_hover_train.py.
|
||||
In this script, we evaluate the policy trained in ppo_hover_train.py. It's meant to complement the output of ppo_hover_train.py.
|
||||
|
||||
The task is for the quadrotor to stabilize to hover at the origin when starting at a random position nearby.
|
||||
|
||||
This script will ask the user which model they'd like to use, and then ask which specific epoch(s) they would like to evaluate.
|
||||
Then, for each model epoch selected, 10 agents will be spawned alongside the baseline SE3 controller at random positions.
|
||||
|
||||
Visualization is slow for this!! To speed things up, we save the figures as individual frames in data_out/ppo_hover/. If you
|
||||
close out of the matplotlib figure things should run faster. You can also speed it up by only visualizing 1 or 2 RL agents.
|
||||
|
||||
"""

# First we'll set up some directories for saving the policy and logs.
models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "policies", "PPO")
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "rotorpy", "learning", "logs")
output_dir = os.path.join(os.path.dirname(__file__), "..", "rotorpy", "data_out", "ppo_hover")

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# List the models here and let the user select which one.
print("Select one of the models:")
models_available = os.listdir(models_dir)
for i, name in enumerate(models_available):
    print(f"{i}: {name}")
model_idx = int(input("Enter the model index: "))
epoch_dir = os.path.join(models_dir, models_available[model_idx])

# Next import Stable Baselines.
try:
    import stable_baselines3
except:
    ...

from stable_baselines3 import PPO  # We'll use PPO for training.
from stable_baselines3.ppo.policies import MlpPolicy  # The policy will be represented by an MLP

num_cpu = 4  # for parallelization

# Choose the weights for our reward function. Here we are creating a lambda function over hover_reward.
reward_function = lambda obs, act: hover_reward(obs, act, weights={'x': 1, 'v': 0.1, 'w': 0, 'u': 1e-5})
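# (Presumably, per hover_reward's weights: 'x' penalizes position error, 'v' velocity,
# 'w' body rates, and 'u' control effort; with w=0, body rates go unpenalized here.)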

# Set up the figure for plotting all the agents.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# Make the environments for the RL agents.
num_quads = 10
def make_env():
    return gym.make("Quadrotor-v0",
                    control_mode='cmd_motor_speeds',
                    reward_fn=reward_function,
                    quad_params=quad_params,
                    max_time=5,
                    world=None,
                    sim_rate=100,
                    render_mode='3D',
                    render_fps=60,
                    fig=fig,
                    ax=ax,
                    color='b')

# from stable_baselines3.common.env_checker import check_env
# check_env(env, warn=True)  # you can check the environment using built-in tools
envs = [make_env() for _ in range(num_quads)]
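# Note that every env shares the same fig/ax, so all agents are drawn in one 3D plot.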

# Lastly, add in the baseline (SE3 controller) environment.
envs.append(gym.make("Quadrotor-v0",
                     control_mode='cmd_motor_speeds',
                     reward_fn=reward_function,
                     quad_params=quad_params,
                     max_time=5,
                     world=None,
                     sim_rate=100,
                     render_mode='3D',
                     render_fps=60,
                     fig=fig,
                     ax=ax,
                     color='k'))  # Geometric controller

def extract_number(filename):
    return int(filename.split('_')[1].split('.')[0])
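# e.g. extract_number('hover_1500000.zip') -> 1500000
# (assuming checkpoints are named hover_<num_timesteps>.zip, consistent with the 'hover_' filter and [:-4] slicing below).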
epochs_list = [fname for fname in os.listdir(epoch_dir) if fname.startswith('hover_')]
epochs_list_sorted = sorted(epochs_list, key=extract_number)

print("Select one of the epochs:")
for i, name in enumerate(epochs_list_sorted):
    print(f"{i}: {name}")
epoch_idxs = [int(input("Enter the epoch index: "))]
# You can optionally just hard code a series of epochs you'd like to evaluate all at once,
# e.g. epoch_idxs = [0, 1, 2, ...]

# Evaluation...
for (k, epoch_idx) in enumerate(epoch_idxs):  # For each epoch index...

    print(f"[ppo_hover_eval.py]: Starting epoch {k+1} out of {len(epoch_idxs)}.")

    # Load the model for the appropriate epoch.
    model_path = os.path.join(epoch_dir, epochs_list_sorted[epoch_idx])
    print(f"Loading model from the path {model_path}")
    model = PPO.load(model_path, env=envs[0], tensorboard_log=log_dir)
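    # (env=envs[0] here presumably just gives the loaded policy a valid env with matching
    # observation/action spaces; each of the envs is stepped manually below.)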

    # Set figure title for 3D plot.
    fig.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")

    # Visualization is slow, so we'll also save frames to make a GIF later.
    # Set the path for these frames here.
    frame_path = os.path.join(output_dir, epochs_list_sorted[epoch_idx][:-4])
    if not os.path.exists(frame_path):
        os.makedirs(frame_path)

    # Collect observations for each environment.
    observations = [env.reset()[0] for env in envs]

    # This is a list of env termination conditions so that the loop only ends when the final env is terminated.
    terminated = [False]*len(observations)

    # Arrays for plotting position vs time.
    T = [0]
    x = [[obs[0] for obs in observations]]
    y = [[obs[1] for obs in observations]]
    z = [[obs[2] for obs in observations]]

    j = 0  # Index for frames. Only updated when the last environment runs its update for the time step.
    while not all(terminated):
        frames = []  # Reset frames.
        for (i, env) in enumerate(envs):  # For each environment...
            env.render()

            if i == len(envs)-1:  # If it's the last environment, run the SE3 controller for the baseline.

                # Unpack the observation from the environment.
                state = {'x': observations[i][0:3], 'v': observations[i][3:6], 'q': observations[i][6:10], 'w': observations[i][10:13]}

                # Command the quad to hover.
                flat = {'x': [0, 0, 0],
                        'x_dot': [0, 0, 0],
                        'x_ddot': [0, 0, 0],
                        'x_dddot': [0, 0, 0],
                        'yaw': 0,
                        'yaw_dot': 0,
                        'yaw_ddot': 0}
                control_dict = baseline_controller.update(0, state, flat)
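                # (Passing t=0 at every step is presumably fine here since the hover
                # reference is constant in time.)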

                # Extract the commanded motor speeds.
                cmd_motor_speeds = control_dict['cmd_motor_speeds']

                # The environment expects the control inputs to all be within the range [-1,1].
                action = np.interp(cmd_motor_speeds, [env.unwrapped.rotor_speed_min, env.unwrapped.rotor_speed_max], [-1,1])
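                # (np.interp maps each speed linearly: rotor_speed_min -> -1 and
                # rotor_speed_max -> +1, so a speed midway between the bounds maps to 0.)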

                # For the last environment, append the current timestep.
                T.append(env.unwrapped.t)

            else:  # For all other environments, get the action from the RL control policy.
                action, _ = model.predict(observations[i], deterministic=True)

            # Step the environment forward.
            observations[i], reward, terminated[i], truncated, info = env.step(action)

            if i == len(envs)-1:  # Save the current fig after the last agent.
                if env.unwrapped.rendering:
                    frame = os.path.join(frame_path, 'frame_'+str(j)+'.png')
                    fig.savefig(frame)
                j += 1

        # Append arrays for plotting.
        x.append([obs[0] for obs in observations])
        y.append([obs[1] for obs in observations])
        z.append([obs[2] for obs in observations])

    # Convert to numpy arrays.
    x = np.array(x)
    y = np.array(y)
    z = np.array(z)
    T = np.array(T)
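    # Each of x, y, z now has one row per time step and one column per vehicle;
    # the last column is the SE3 baseline, which the plots below style separately as "GC".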

    # Plot position vs time.
    fig_pos, ax_pos = plt.subplots(nrows=3, ncols=1, num="Position vs Time")
    fig_pos.suptitle(f"Model: PPO/{models_available[model_idx]}, Epoch: {extract_number(epochs_list_sorted[epoch_idx]):,}")
    ax_pos[0].plot(T, x[:, 0], 'b-', linewidth=1, label="RL")
    ax_pos[0].plot(T, x[:, 1:-1], 'b-', linewidth=1)
    ax_pos[0].plot(T, x[:, -1], 'k-', linewidth=2, label="GC")
    ax_pos[0].legend()
    ax_pos[0].set_ylabel("X, m")
    ax_pos[0].set_ylim([-2.5, 2.5])
    ax_pos[1].plot(T, y[:, 0], 'b-', linewidth=1, label="RL")
    ax_pos[1].plot(T, y[:, 1:-1], 'b-', linewidth=1)
    ax_pos[1].plot(T, y[:, -1], 'k-', linewidth=2, label="GC")
    ax_pos[1].set_ylabel("Y, m")
    ax_pos[1].set_ylim([-2.5, 2.5])
    ax_pos[2].plot(T, z[:, 0], 'b-', linewidth=1, label="RL")
    ax_pos[2].plot(T, z[:, 1:-1], 'b-', linewidth=1)
    ax_pos[2].plot(T, z[:, -1], 'k-', linewidth=2, label="GC")
    ax_pos[2].set_ylabel("Z, m")
    ax_pos[2].set_ylim([-2.5, 2.5])
    ax_pos[2].set_xlabel("Time, s")

    # Save fig.
    fig_pos.savefig(os.path.join(frame_path, 'position_vs_time.png'))

plt.show()
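
# The frames saved in frame_path are meant to be stitched into a GIF afterwards. A minimal
# sketch of that step (assuming the imageio package is installed; the output filename is
# just an example):
#
# import imageio.v2 as imageio
# frame_files = [f for f in os.listdir(frame_path) if f.startswith('frame_')]
# frame_files.sort(key=lambda f: int(f.split('_')[1].split('.')[0]))  # sort by frame index
# images = [imageio.imread(os.path.join(frame_path, f)) for f in frame_files]
# imageio.mimsave(os.path.join(frame_path, 'ppo_hover.gif'), images)  # optionally pass a frame duration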