diff --git a/examples/gymnasium_basic_usage.py b/examples/gymnasium_basic_usage.py new file mode 100644 index 0000000..b46c4a0 --- /dev/null +++ b/examples/gymnasium_basic_usage.py @@ -0,0 +1,105 @@ +import gymnasium as gym +import numpy as np +import matplotlib.pyplot as plt + +# For this demonstration, we'll just use the SE3 controller. +from rotorpy.controllers.quadrotor_control import SE3Control +from rotorpy.vehicles.crazyflie_params import quad_params + +controller = SE3Control(quad_params) + +# Import the QuadrotorEnv gymnasium environment using the following command. +from rotorpy.learning.quadrotor_environments import QuadrotorEnv + +# Reward functions can be specified by the user, or we can import from existing reward functions. +from rotorpy.learning.quadrotor_reward_functions import hover_reward + + +# First, we need to make the gym environment. The inputs to the model are as follows... +""" +Inputs: + initial_state: the initial state of the quadrotor. The default is hover. + control_mode: the appropriate control abstraction that is used by the controller, options are... + 'cmd_motor_speeds': the controller directly commands motor speeds. + 'cmd_motor_thrusts': the controller commands forces for each rotor. + 'cmd_ctbr': the controller commands a collective thrsut and body rates. + 'cmd_ctbm': the controller commands a collective thrust and moments on the x/y/z body axes + 'cmd_vel': the controller commands a velocity vector in the body frame. + reward_fn: the reward function, default to hover, but the user can pass in any function that is used as a reward. + quad_params: the parameters for the quadrotor. + max_time: the maximum time of the session. + world: the world for the quadrotor to operate within. + sim_rate: the simulation rate (in Hz), i.e. the timestep. + render_mode: render the quadrotor. + 'None': no rendering + 'console': output text describing the environment. + '3D': will render the quadrotor in 3D. WARNING: THIS IS SLOW. + +""" +env = gym.make("Quadrotor-v0", + control_mode ='cmd_motor_speeds', + reward_fn = hover_reward, + quad_params = quad_params, + max_time = 5, + world = None, + sim_rate = 100, + render_mode='3D') + +# Now reset the quadrotor. +# Setting initial_state to 'random' will randomly place the vehicle in the map near the origin. +# But you can also set the environment resetting to be deterministic. +observation, info = env.reset(initial_state='random') + +# Number of timesteps +T = 300 +time = np.arange(T)*(1/100) # Just for plotting purposes. +position = np.zeros((T, 3)) # Just for plotting purposes. +velocity = np.zeros((T, 3)) + +for i in range(T): + + # Unpack the observation from the environment + state = {'x': observation[0:3], 'v': observation[3:6], 'q': observation[6:10], 'w': observation[10:13], 'wind': observation[13:16], 'rotor_speeds': observation[16:20]} + + # For illustrative purposes, just command the quad to hover. + flat = {'x': [0, 0, 0], + 'x_dot': [0, 0, 0], + 'x_ddot': [0, 0, 0], + 'x_dddot': [0, 0, 0], + 'yaw': 0, + 'yaw_dot': 0, + 'yaw_ddot': 0} + control_dict = controller.update(0, state, flat) + + # Extract the commanded motor speeds. + cmd_motor_speeds = control_dict['cmd_motor_speeds'] + + # The environment expects the control inputs to all be within the range [-1,1] + action = np.interp(cmd_motor_speeds, [env.unwrapped.rotor_speed_min, env.unwrapped.rotor_speed_max], [-1,1]) + + # Step forward in the environment + observation, reward, terminated, truncated, info = env.step(action) + + # For plotting, save the relevant information + position[i, :] = observation[0:3] + velocity[i, :] = observation[3:6] + +env.close() + +# Plotting + +(fig, axes) = plt.subplots(nrows=2, ncols=1, num='Position vs Time') +ax = axes[0] +ax.plot(time, position[:, 0], 'r', label='X') +ax.plot(time, position[:, 1], 'g', label='Y') +ax.plot(time, position[:, 2], 'b', label='Z') +ax.set_ylabel("Position, m") +ax.legend() + +ax = axes[1] +ax.plot(time, velocity[:, 0], 'r', label='X') +ax.plot(time, velocity[:, 1], 'g', label='Y') +ax.plot(time, velocity[:, 2], 'b', label='Z') +ax.set_xlabel("Time, s") + +plt.show() \ No newline at end of file