2023-03-15 15:38:14 -04:00
"""
TODO : Set up figure for appropriate target video size ( eg . 720 p ) .
TODO : Decide which additional user options should be available .
"""
from datetime import datetime
from pathlib import Path
import numpy as np
from matplotlib . animation import FuncAnimation
import matplotlib . pyplot as plt
from scipy . spatial . transform import Rotation
from rotorpy . utils . shapes import Quadrotor
import os
class ClosingFuncAnimation(FuncAnimation):
    """
    FuncAnimation that can optionally close its figure once the animation ends.

    Accepts every FuncAnimation argument plus the keyword option
    ``close_on_finish`` (default False). When True, the figure is closed with
    ``plt.close`` as soon as a step reports that there are no more frames.
    """

    def __init__(self, fig, func, *args, close_on_finish=False, **kwargs):
        # Consumed here so it is not forwarded to FuncAnimation, which does not
        # know this option.
        self._close_on_finish = close_on_finish
        FuncAnimation.__init__(self, fig, func, *args, **kwargs)

    def _step(self, *args):
        """Advance one frame and close the figure if the animation just finished."""
        still_going = FuncAnimation._step(self, *args)
        if self._close_on_finish and not still_going:
            plt.close(self._fig)
        # Propagate the parent's continue/stop flag so whatever drives _step
        # sees the same result it would from a plain FuncAnimation.
        return still_going
def _decimate_index ( time , sample_time ) :
"""
Given sorted lists of source times and sample times , return indices of
source time closest to each sample time .
"""
index = np . arange ( time . size )
sample_index = np . round ( np . interp ( sample_time , time , index ) ) . astype ( int )
return sample_index
2023-04-05 15:37:52 -04:00
def animate(time, position, rotation, wind, animate_wind, world, filename=None, blit=False, show_axes=True, close_on_finish=False):
    """
    Animate a completed simulation result based on the time, position, and
    rotation history. The animation may be viewed live or saved to a .mp4 video
    (slower, requires additional libraries).

    For a live view, it is absolutely critical to retain a reference to the
    returned object in order to prevent garbage collection before the animation
    has completed displaying.

    Below, N is the number of time steps in the simulation and M is the number
    of drones being animated. If the arrays have no drone axis (position (N,3),
    rotation (N,3,3), wind (N,3)), a single drone is assumed and a drone axis
    of length 1 is inserted. Otherwise all M drones are animated on the same
    axes.

    Parameters
        time, (N,) with uniform intervals
        position, (N,M,3) or (N,3)
        rotation, (N,M,3,3) or (N,3,3)
        wind, (N,M,3) or (N,3) world wind velocity
        animate_wind, if True animate wind vector
        world, a World object
        filename, str or Path for a saved video, or None for a live view.
            '.mp4' is appended when the name does not already end with it.
            The name is joined onto '<this file's directory>/../data_out'
            (an absolute path is used as-is).
        blit, if True use blit for faster animation, default is False
        show_axes, if True plot axes, default is True
        close_on_finish, if True close figure at end of live animation or save, default is False

    Returns
        The ClosingFuncAnimation object. None is returned instead when a video
        was saved and close_on_finish is True.
    """
    # Check if there is only one drone; if so, insert a drone axis of length 1.
    if len(position.shape) == 2:
        position = np.expand_dims(position, axis=1)
        rotation = np.expand_dims(rotation, axis=1)
        wind = np.expand_dims(wind, axis=1)
    M = position.shape[1]

    # Temporal style.
    rtf = 1.0  # real time factor > 1.0 is faster than real time playback
    render_fps = 30

    # Normalize the wind by the largest wind magnitude over all drones and time
    # steps, so that the maximum length of the arrow is decided by the scale factor.
    wind_mag = np.max(np.linalg.norm(wind, axis=-1), axis=1)  # (N,) max magnitude over drones
    max_wind = np.max(wind_mag)  # Maximum wind magnitude in the time series
    if max_wind != 0:
        wind_arrow_scale_factor = 1  # Scale factor for the wind arrow
        wind = wind_arrow_scale_factor * wind / max_wind

    # Decimate data to render interval; always include t=0.
    if time[-1] != 0:
        sample_time = np.arange(0, time[-1], 1 / render_fps * rtf)
    else:
        sample_time = np.zeros((1,))
    index = _decimate_index(time, sample_time)
    time = time[index]
    position = position[index, :]
    rotation = rotation[index, :]
    wind = wind[index, :]

    # Set up axes.
    if filename is not None:
        if isinstance(filename, Path):
            fig = plt.figure(filename.name)
        else:
            fig = plt.figure(filename)
    else:
        fig = plt.figure('Animation')
    fig.clear()
    ax = fig.add_subplot(projection='3d')
    if not show_axes:
        ax.set_axis_off()

    quads = [Quadrotor(ax, wind=animate_wind, wind_scale_factor=1) for _ in range(M)]

    world_artists = world.draw(ax)

    title_artist = ax.set_title('t = {}'.format(time[0]))

    def _changed_artists():
        # With blit=True matplotlib requires a flat sequence of Artist objects,
        # so each quad's artist collection is unpacked instead of nested.
        return list(world_artists) + [title_artist] + [artist for q in quads for artist in q.artists]

    def init():
        ax.draw(fig.canvas.get_renderer())
        return _changed_artists()

    def update(frame):
        title_artist.set_text('t = {:.2f}'.format(time[frame]))
        for i, quad in enumerate(quads):
            quad.transform(position=position[frame, i, :], rotation=rotation[frame, i, :, :], wind=wind[frame, i, :])
        return _changed_artists()

    ani = ClosingFuncAnimation(fig=fig,
                               func=update,
                               frames=time.size,
                               init_func=init,
                               interval=1000.0 / render_fps,
                               repeat=False,
                               blit=blit,
                               close_on_finish=close_on_finish)

    if filename is not None:
        print('Saving Animation')
        # Accept str or Path. Only append the extension when the name does not
        # already end with it.
        filename = os.fspath(filename)
        if not filename.endswith('.mp4'):
            filename = filename + '.mp4'
        path = os.path.join(os.path.dirname(__file__), '..', 'data_out', filename)
        # Create the output directory up front so saving does not fail when it
        # does not exist yet.
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        ani.save(path,
                 writer='ffmpeg',
                 fps=render_fps,
                 dpi=100)
        if close_on_finish:
            plt.close(fig)
            ani = None
    return ani
2024-03-06 14:25:06 -05:00
if __name__ == "__main__":
    from rotorpy.vehicles.crazyflie_params import quad_params  # Import quad params for the quadrotor environment.
    # NOTE(review): QuadrotorEnv is not referenced below; presumably imported so
    # that "Quadrotor-v0" is registered with gymnasium. Confirm before removing.
    from rotorpy.learning.quadrotor_environments import QuadrotorEnv
    from rotorpy.controllers.quadrotor_control import SE3Control
    from rotorpy.trajectories.circular_traj import CircularTraj
    import gymnasium as gym

    output_video_path = os.path.join(os.path.dirname(__file__), "..", "data_out", "test_animation.mp4")

    # Create M SE3 drones.
    M = 3
    baseline_controller = SE3Control(quad_params)

    def make_env():
        return gym.make("Quadrotor-v0",
                        control_mode='cmd_motor_speeds',
                        quad_params=quad_params,
                        max_time=5,
                        world=None,
                        sim_rate=100,
                        render_mode='3D',
                        render_fps=60,
                        color='b')

    envs = [make_env() for _ in range(M)]

    # For each environment command it to do a circle with random radius and location.
    trajs = []
    for env in envs:
        center = np.random.uniform(low=-2, high=2, size=3)
        radius = np.random.uniform(low=0.5, high=1.5)
        freq = np.random.uniform(low=0.1, high=0.3)
        plane = np.random.choice(['XY', 'YZ', 'XZ'])
        traj = CircularTraj(center=center, radius=radius, freq=freq, plane=plane, direction=np.random.choice(['CW', 'CCW']))
        trajs.append(traj)

    # Collect observations for each environment.
    observations = [env.reset(initial_state='random')[0] for env in envs]

    # done[i] becomes True once env i reports termination or truncation; the
    # loop only ends when every environment is done.
    done = [False] * len(observations)

    # Arrays for animating; entry k of each list describes the same instant.
    T = [0]
    position = [[obs[0:3] for obs in observations]]
    quat = [[obs[6:10] for obs in observations]]

    while not all(done):
        for i, env in enumerate(envs):  # For each environment...
            if done[i]:
                # Do not step an environment that has already ended; its last
                # observation is re-recorded below so the arrays stay aligned.
                continue
            # Unpack the observation from the environment
            state = {'x': observations[i][0:3], 'v': observations[i][3:6], 'q': observations[i][6:10], 'w': observations[i][10:13]}
            # Command the quad to do circles.
            flat = trajs[i].update(env.unwrapped.t)
            control_dict = baseline_controller.update(0, state, flat)
            # Extract the commanded motor speeds.
            cmd_motor_speeds = control_dict['cmd_motor_speeds']
            # The environment expects the control inputs to all be within the range [-1,1]
            action = np.interp(cmd_motor_speeds, [env.unwrapped.rotor_speed_min, env.unwrapped.rotor_speed_max], [-1, 1])
            # Step the environment forward.
            observations[i], _, terminated, truncated, _ = env.step(action)
            done[i] = bool(terminated or truncated)

        # Read the clock after stepping so T pairs with the post-step
        # observations appended below. Envs that have ended are no longer
        # stepped, so take the latest clock among all envs.
        T.append(max(env.unwrapped.t for env in envs))
        # Append arrays for plotting.
        position.append([obs[0:3] for obs in observations])
        quat.append([obs[6:10] for obs in observations])

    # Convert to numpy arrays.
    T = np.array(T)
    position = np.array(position)
    quat = np.array(quat)

    # Convert the quaternions to rotation matrices, one (M,3,3) stack per time step.
    rotation = np.array([Rotation.from_quat(q).as_matrix() for q in quat])

    # Animate the results.
    ani = animate(T, position, rotation, wind=np.zeros((T.size, M, 3)), animate_wind=False,
                  world=envs[0].unwrapped.world, filename=output_video_path, blit=False,
                  show_axes=True, close_on_finish=True)