From 1769323890a15f683f8b86ae02bda1daeec88a08 Mon Sep 17 00:00:00 2001 From: spencerfolk Date: Wed, 3 Jan 2024 23:41:06 -0500 Subject: [PATCH] Changed color scheme, avoid redrawing world if fig is shared. --- rotorpy/learning/quadrotor_environments.py | 25 ++++++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/rotorpy/learning/quadrotor_environments.py b/rotorpy/learning/quadrotor_environments.py index 354bff8..7c76b14 100644 --- a/rotorpy/learning/quadrotor_environments.py +++ b/rotorpy/learning/quadrotor_environments.py @@ -1,6 +1,7 @@ import numpy as np from scipy.spatial.transform import Rotation import matplotlib.pyplot as plt +import matplotlib.colors as mcolors from enum import Enum from rotorpy.world import World @@ -60,7 +61,8 @@ class QuadrotorEnv(gym.Env): aero = True, # Whether or not aerodynamic wrenches are computed. render_mode = "None", # The rendering mode render_fps = 30, # The rendering frames per second. Lower this for faster visualization. - ax = None, + fig = None, # Figure for rendering. Optional. + ax = None, # Axis for rendering. Optional. ): super(QuadrotorEnv, self).__init__() @@ -151,13 +153,19 @@ class QuadrotorEnv(gym.Env): self.wind_profile = wind_profile if self.render_mode == '3D': - self.fig = plt.figure('Visualization') - # self.fig.clear() - self.ax = self.fig.add_subplot(projection='3d') - self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.rand(3)) + if fig is None and ax is None: + self.fig = plt.figure('Visualization') + self.ax = self.fig.add_subplot(projection='3d') + else: + self.fig = fig + self.ax = ax + colors = list(mcolors.CSS4_COLORS) + self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.choice(colors)) self.world_artists = None self.title_artist = self.ax.set_title('t = {}'.format(self.t)) + self.rendering = False # Bool for tracking when the renderer is actually rendering a frame. + return def render(self): @@ -402,14 +410,17 @@ class QuadrotorEnv(gym.Env): def _plot_quad(self): - if abs(self.t / (1/self.metadata['render_fps']) - round(self.t / (1/self.metadata['render_fps']))) > 1e-2: + if abs(self.t / (1/self.metadata['render_fps']) - round(self.t / (1/self.metadata['render_fps']))) > 5e-2: + self.rendering = False # Set rendering bool to false. return + self.rendering = True # Set rendering bool to true. + plot_position = deepcopy(self.vehicle_state['x']) plot_rotation = Rotation.from_quat(self.vehicle_state['q']).as_matrix() plot_wind = deepcopy(self.vehicle_state['wind']) - if self.world_artists is None: + if self.world_artists is None and not ('x' in self.ax.get_xlabel()): self.world_artists = self.world.draw(self.ax) self.ax.plot(0, 0, 0, 'go')