diff --git a/rotorpy/utils/animate.py b/rotorpy/utils/animate.py index 53c942f..4da06b1 100644 --- a/rotorpy/utils/animate.py +++ b/rotorpy/utils/animate.py @@ -49,11 +49,16 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, returned object in order to prevent garbage collection before the animation has completed displaying. + Below, M corresponds to the number of drones you're animating. If M is None, i.e. the arrays are (N,3) and (N,3,3), then it is assumed that there is only one drone. + Otherwise, we iterate over the M drones and animate them on the same axes. + + N is the number of time steps in the simulation. + Parameters time, (N,) with uniform intervals - position, (N,3) - rotation, (N,3,3) - wind, (N,3) world wind velocity + position, (N,M,3) + rotation, (N,M,3,3) + wind, (N,M,3) world wind velocity animate_wind, if True animate wind vector world, a World object filename, for saved video, or live view if None @@ -62,16 +67,24 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, close_on_finish, if True close figure at end of live animation or save, default is False """ + # Check if there is only one drone. + if len(position.shape) == 2: + position = np.expand_dims(position, axis=1) + rotation = np.expand_dims(rotation, axis=1) + wind = np.expand_dims(wind, axis=1) + M = position.shape[1] + # Temporal style. rtf = 1.0 # real time factor > 1.0 is faster than real time playback render_fps = 30 # Normalize the wind by the max of the wind magnitude on each axis, so that the maximum length of the arrow is decided by the scale factor - wind_mag = np.linalg.norm(wind, axis=1) # Get the wind magnitude time series + wind_mag = np.max(np.linalg.norm(wind, axis=-1), axis=1) # Get the wind magnitude time series max_wind = np.max(wind_mag) # Find the maximum wind magnitude in the time series + if max_wind != 0: wind_arrow_scale_factor = 1 # Scale factor for the wind arrow - wind = wind_arrow_scale_factor*wind / max_wind # Apply scaling on wind. + wind = wind_arrow_scale_factor*wind / max_wind # Decimate data to render interval; always include t=0. if time[-1] != 0: @@ -81,7 +94,7 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, index = _decimate_index(time, sample_time) time = time[index] position = position[index,:] - rotation = rotation[index,:,:] + rotation = rotation[index,:] wind = wind[index,:] # Set up axes. @@ -97,7 +110,7 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, if not show_axes: ax.set_axis_off() - quad = Quadrotor(ax, wind=animate_wind, wind_scale_factor=1) + quads = [Quadrotor(ax, wind=animate_wind, wind_scale_factor=1) for _ in range(M)] world_artists = world.draw(ax) @@ -105,13 +118,16 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, def init(): ax.draw(fig.canvas.get_renderer()) - return world_artists + list(quad.artists) + [title_artist] + # return world_artists + list(cquad.artists) + [title_artist] + return world_artists + [title_artist] + [q.artists for q in quads] def update(frame): title_artist.set_text('t = {:.2f}'.format(time[frame])) - quad.transform(position=position[frame,:], rotation=rotation[frame,:,:], wind=wind[frame,:]) + for i, quad in enumerate(quads): + quad.transform(position=position[frame,i,:], rotation=rotation[frame,i,:,:], wind=wind[frame,i,:]) # [a.do_3d_projection(fig.canvas.get_renderer()) for a in quad.artists] # No longer necessary in newer matplotlib? - return world_artists + list(quad.artists) + [title_artist] + # return world_artists + list(quad.artists) + [title_artist] + return world_artists + [title_artist] + [q.artists for q in quads] ani = ClosingFuncAnimation(fig=fig, func=update, @@ -136,3 +152,89 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None, ani = None return ani + +if __name__ == "__main__": + + from rotorpy.vehicles.crazyflie_params import quad_params # Import quad params for the quadrotor environment. + from rotorpy.learning.quadrotor_environments import QuadrotorEnv + from rotorpy.controllers.quadrotor_control import SE3Control + from rotorpy.trajectories.circular_traj import CircularTraj + + import gymnasium as gym + + output_video_dir = os.path.join(os.path.dirname(__file__), "..", "data_out", "test_animation.mp4") + + # Create M SE3 drones. + M = 3 + baseline_controller = SE3Control(quad_params) + + def make_env(): + return gym.make("Quadrotor-v0", + control_mode ='cmd_motor_speeds', + quad_params = quad_params, + max_time = 5, + world = None, + sim_rate = 100, + render_mode='3D', + render_fps = 60, + color='b') + + envs = [make_env() for _ in range(M)] + + # For each environment command it to do a circle with random radius and location. + trajs = [] + for env in envs: + center = np.random.uniform(low=-2, high=2, size=3) + radius = np.random.uniform(low=0.5, high=1.5) + freq = np.random.uniform(low=0.1, high=0.3) + plane = np.random.choice(['XY', 'YZ', 'XZ']) + traj = CircularTraj(center=center, radius=radius, freq=freq, plane=plane, direction=np.random.choice(['CW', 'CCW'])) + trajs.append(traj) + + # Collect observations for each environment. + observations = [env.reset(initial_state='random')[0] for env in envs] + + # This is a list of env termination conditions so that the loop only ends when the final env is terminated. + terminated = [False]*len(observations) + + # Arrays for animating. + T = [0] + position = [[obs[0:3] for obs in observations]] + quat = [[obs[6:10] for obs in observations]] + + while not all(terminated): + for (i, env) in enumerate(envs): # For each environment... + # Unpack the observation from the environment + state = {'x': observations[i][0:3], 'v': observations[i][3:6], 'q': observations[i][6:10], 'w': observations[i][10:13]} + + # Command the quad to do circles. + flat = trajs[i].update(env.unwrapped.t) + control_dict = baseline_controller.update(0, state, flat) + + # Extract the commanded motor speeds. + cmd_motor_speeds = control_dict['cmd_motor_speeds'] + + # The environment expects the control inputs to all be within the range [-1,1] + action = np.interp(cmd_motor_speeds, [env.unwrapped.rotor_speed_min, env.unwrapped.rotor_speed_max], [-1,1]) + + # For the last environment, append the current timestep. + if i == 0: + T.append(env.unwrapped.t) + + # Step the environment forward. + observations[i], reward, terminated[i], truncated, info = env.step(action) + + # Append arrays for plotting. + position.append([obs[0:3] for obs in observations]) + quat.append([obs[6:10] for obs in observations]) + + # Convert to numpy arrays. + T = np.array(T) + position = np.array(position) + quat = np.array(quat) + + # Convert the quaternion to rotation matrix. + rotation = np.array([Rotation.from_quat(quat[i]).as_matrix() for i in range(T.size)]) + + # Animate the results. + ani = animate(T, position, rotation, wind=np.zeros((T.size,M,3)), animate_wind=False, world=envs[0].world, filename=output_video_dir, blit=False, show_axes=True, close_on_finish=True) \ No newline at end of file