Changed color scheme, avoid redrawing world if fig is shared.

This commit is contained in:
spencerfolk
2024-01-03 23:41:06 -05:00
parent 173465dbec
commit 1769323890

View File

@@ -1,6 +1,7 @@
import numpy as np import numpy as np
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from enum import Enum from enum import Enum
from rotorpy.world import World from rotorpy.world import World
@@ -60,7 +61,8 @@ class QuadrotorEnv(gym.Env):
aero = True, # Whether or not aerodynamic wrenches are computed. aero = True, # Whether or not aerodynamic wrenches are computed.
render_mode = "None", # The rendering mode render_mode = "None", # The rendering mode
render_fps = 30, # The rendering frames per second. Lower this for faster visualization. render_fps = 30, # The rendering frames per second. Lower this for faster visualization.
ax = None, fig = None, # Figure for rendering. Optional.
ax = None, # Axis for rendering. Optional.
): ):
super(QuadrotorEnv, self).__init__() super(QuadrotorEnv, self).__init__()
@@ -151,13 +153,19 @@ class QuadrotorEnv(gym.Env):
self.wind_profile = wind_profile self.wind_profile = wind_profile
if self.render_mode == '3D': if self.render_mode == '3D':
if fig is None and ax is None:
self.fig = plt.figure('Visualization') self.fig = plt.figure('Visualization')
# self.fig.clear()
self.ax = self.fig.add_subplot(projection='3d') self.ax = self.fig.add_subplot(projection='3d')
self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.rand(3)) else:
self.fig = fig
self.ax = ax
colors = list(mcolors.CSS4_COLORS)
self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.choice(colors))
self.world_artists = None self.world_artists = None
self.title_artist = self.ax.set_title('t = {}'.format(self.t)) self.title_artist = self.ax.set_title('t = {}'.format(self.t))
self.rendering = False # Bool for tracking when the renderer is actually rendering a frame.
return return
def render(self): def render(self):
@@ -402,14 +410,17 @@ class QuadrotorEnv(gym.Env):
def _plot_quad(self): def _plot_quad(self):
if abs(self.t / (1/self.metadata['render_fps']) - round(self.t / (1/self.metadata['render_fps']))) > 1e-2: if abs(self.t / (1/self.metadata['render_fps']) - round(self.t / (1/self.metadata['render_fps']))) > 5e-2:
self.rendering = False # Set rendering bool to false.
return return
self.rendering = True # Set rendering bool to true.
plot_position = deepcopy(self.vehicle_state['x']) plot_position = deepcopy(self.vehicle_state['x'])
plot_rotation = Rotation.from_quat(self.vehicle_state['q']).as_matrix() plot_rotation = Rotation.from_quat(self.vehicle_state['q']).as_matrix()
plot_wind = deepcopy(self.vehicle_state['wind']) plot_wind = deepcopy(self.vehicle_state['wind'])
if self.world_artists is None: if self.world_artists is None and not ('x' in self.ax.get_xlabel()):
self.world_artists = self.world.draw(self.ax) self.world_artists = self.world.draw(self.ax)
self.ax.plot(0, 0, 0, 'go') self.ax.plot(0, 0, 0, 'go')