Randomize quad color for each instance. Misc fixes.

This commit is contained in:
spencerfolk
2023-12-30 00:08:44 -05:00
parent 0cbc806fa0
commit 3e181d7646

View File

@@ -37,6 +37,7 @@ class QuadrotorEnv(gym.Env):
""" """
metadata = {"render_modes": ["None", "3D", "console"], metadata = {"render_modes": ["None", "3D", "console"],
"render_fps": 30,
"control_modes": ['cmd_motor_speeds', 'cmd_motor_thrusts', 'cmd_ctbr', 'cmd_ctbm', 'cmd_vel']} "control_modes": ['cmd_motor_speeds', 'cmd_motor_thrusts', 'cmd_ctbr', 'cmd_ctbm', 'cmd_vel']}
def __init__(self, def __init__(self,
@@ -53,7 +54,8 @@ class QuadrotorEnv(gym.Env):
wind_profile = None, # wind profile object, if none is supplied it will choose no wind. wind_profile = None, # wind profile object, if none is supplied it will choose no wind.
world = None, # The world object world = None, # The world object
sim_rate = 100, # The update frequency of the simulator in Hz sim_rate = 100, # The update frequency of the simulator in Hz
render_mode = None, # The rendering mode render_mode = "None", # The rendering mode
ax = None,
): ):
super(QuadrotorEnv, self).__init__() super(QuadrotorEnv, self).__init__()
@@ -66,7 +68,6 @@ class QuadrotorEnv(gym.Env):
self.sim_rate = sim_rate self.sim_rate = sim_rate
self.t_step = 1/self.sim_rate self.t_step = 1/self.sim_rate
self.world = world
self.reward_fn = reward_fn self.reward_fn = reward_fn
# Create quadrotor from quad params and control abstraction. # Create quadrotor from quad params and control abstraction.
@@ -151,7 +152,7 @@ class QuadrotorEnv(gym.Env):
self.fig = plt.figure('Visualization') self.fig = plt.figure('Visualization')
# self.fig.clear() # self.fig.clear()
self.ax = self.fig.add_subplot(projection='3d') self.ax = self.fig.add_subplot(projection='3d')
self.quad_obj = Quadrotor(self.ax, wind=True) self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.rand(3))
self.world_artists = None self.world_artists = None
self.title_artist = self.ax.set_title('t = {}'.format(self.t)) self.title_artist = self.ax.set_title('t = {}'.format(self.t))
@@ -168,7 +169,7 @@ class QuadrotorEnv(gym.Env):
# Close the plots # Close the plots
plt.close('all') plt.close('all')
def reset(self, seed=None, initial_state='random', options={'pos_bound': -2, 'vel_bound': 0}): def reset(self, seed=None, initial_state='random', options={'pos_bound': 2, 'vel_bound': 0}):
""" """
Reset the environment Reset the environment
Inputs: Inputs:
@@ -191,6 +192,7 @@ class QuadrotorEnv(gym.Env):
'vel_bound': the min/max velocity region for random placement 'vel_bound': the min/max velocity region for random placement
""" """
assert options['pos_bound'] >= 0 and options['vel_bound'] >= 0 , "Bounds must be greater than or equal to 0."
super().reset(seed=seed) super().reset(seed=seed)