Changed color scheme, avoid redrawing world if fig is shared.

This commit is contained in:
spencerfolk
2024-01-03 23:41:06 -05:00
parent 173465dbec
commit 1769323890

View File

@@ -1,6 +1,7 @@
import numpy as np
from scipy.spatial.transform import Rotation
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from enum import Enum
from rotorpy.world import World
@@ -60,7 +61,8 @@ class QuadrotorEnv(gym.Env):
aero = True, # Whether or not aerodynamic wrenches are computed.
render_mode = "None", # The rendering mode
render_fps = 30, # The rendering frames per second. Lower this for faster visualization.
ax = None,
fig = None, # Figure for rendering. Optional.
ax = None, # Axis for rendering. Optional.
):
super(QuadrotorEnv, self).__init__()
@@ -151,13 +153,19 @@ class QuadrotorEnv(gym.Env):
self.wind_profile = wind_profile
if self.render_mode == '3D':
if fig is None and ax is None:
self.fig = plt.figure('Visualization')
# self.fig.clear()
self.ax = self.fig.add_subplot(projection='3d')
self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.rand(3))
else:
self.fig = fig
self.ax = ax
colors = list(mcolors.CSS4_COLORS)
self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.choice(colors))
self.world_artists = None
self.title_artist = self.ax.set_title('t = {}'.format(self.t))
self.rendering = False # Bool for tracking when the renderer is actually rendering a frame.
return
def render(self):
@@ -402,14 +410,17 @@ class QuadrotorEnv(gym.Env):
def _plot_quad(self):
if abs(self.t / (1/self.metadata['render_fps']) - round(self.t / (1/self.metadata['render_fps']))) > 1e-2:
if abs(self.t / (1/self.metadata['render_fps']) - round(self.t / (1/self.metadata['render_fps']))) > 5e-2:
self.rendering = False # Set rendering bool to false.
return
self.rendering = True # Set rendering bool to true.
plot_position = deepcopy(self.vehicle_state['x'])
plot_rotation = Rotation.from_quat(self.vehicle_state['q']).as_matrix()
plot_wind = deepcopy(self.vehicle_state['wind'])
if self.world_artists is None:
if self.world_artists is None and not ('x' in self.ax.get_xlabel()):
self.world_artists = self.world.draw(self.ax)
self.ax.plot(0, 0, 0, 'go')