From e13f7037fa28bef3a9d389880f612d57d8aca126 Mon Sep 17 00:00:00 2001 From: spencerfolk Date: Wed, 10 Jan 2024 10:31:20 -0500 Subject: [PATCH] Added wind scale factor for visualization --- rotorpy/learning/quadrotor_environments.py | 2 +- rotorpy/utils/animate.py | 2 +- rotorpy/utils/shapes.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/rotorpy/learning/quadrotor_environments.py b/rotorpy/learning/quadrotor_environments.py index 3772e1f..8a02989 100644 --- a/rotorpy/learning/quadrotor_environments.py +++ b/rotorpy/learning/quadrotor_environments.py @@ -165,7 +165,7 @@ class QuadrotorEnv(gym.Env): colors = list(mcolors.CSS4_COLORS) else: colors = [color] - self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.choice(colors)) + self.quad_obj = Quadrotor(self.ax, wind=True, color=np.random.choice(colors), wind_scale_factor=5) self.world_artists = None self.title_artist = self.ax.set_title('t = {}'.format(self.t)) diff --git a/rotorpy/utils/animate.py b/rotorpy/utils/animate.py index 769690c..53c942f 100644 --- a/rotorpy/utils/animate.py +++ b/rotorpy/utils/animate.py @@ -97,7 +97,7 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, if not show_axes: ax.set_axis_off() - quad = Quadrotor(ax, wind=animate_wind) + quad = Quadrotor(ax, wind=animate_wind, wind_scale_factor=1) world_artists = world.draw(ax) diff --git a/rotorpy/utils/shapes.py b/rotorpy/utils/shapes.py index 397b7ab..036bcae 100644 --- a/rotorpy/utils/shapes.py +++ b/rotorpy/utils/shapes.py @@ -321,10 +321,11 @@ class Quadrotor(): def __init__(self, ax, arm_length=0.125, rotor_radius=0.08, n_rotors=4, - shade=True, color=None, wind=True): + shade=True, color=None, wind=True, wind_scale_factor=5): self.ax = ax self.wind_bool = wind + self.wind_scale_factor = wind_scale_factor # Apply same color to all rotor objects. if color is None: @@ -360,7 +361,7 @@ class Quadrotor(): r.transform(np.matmul(rotation,pos)+position, rotation) if self.wind_bool: self.wind_vector[0].remove() - self.wind_vector = [self.ax.quiver(position[0], position[1], position[2], wind[0], wind[1], wind[2], color='r', linewidth=1.5)] + self.wind_vector = [self.ax.quiver(position[0], position[1], position[2], wind[0]/self.wind_scale_factor, wind[1]/self.wind_scale_factor, wind[2]/self.wind_scale_factor, color='r', linewidth=1.5)] if __name__ == '__main__': import matplotlib.pyplot as plt