Support for animating multiple vehicles.
This commit is contained in:
@@ -49,11 +49,16 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None,
|
|||||||
returned object in order to prevent garbage collection before the animation
|
returned object in order to prevent garbage collection before the animation
|
||||||
has completed displaying.
|
has completed displaying.
|
||||||
|
|
||||||
|
Below, M corresponds to the number of drones you're animating. If M is None, i.e. the arrays are (N,3) and (N,3,3), then it is assumed that there is only one drone.
|
||||||
|
Otherwise, we iterate over the M drones and animate them on the same axes.
|
||||||
|
|
||||||
|
N is the number of time steps in the simulation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
time, (N,) with uniform intervals
|
time, (N,) with uniform intervals
|
||||||
position, (N,3)
|
position, (N,M,3)
|
||||||
rotation, (N,3,3)
|
rotation, (N,M,3,3)
|
||||||
wind, (N,3) world wind velocity
|
wind, (N,M,3) world wind velocity
|
||||||
animate_wind, if True animate wind vector
|
animate_wind, if True animate wind vector
|
||||||
world, a World object
|
world, a World object
|
||||||
filename, for saved video, or live view if None
|
filename, for saved video, or live view if None
|
||||||
@@ -62,16 +67,24 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None,
|
|||||||
close_on_finish, if True close figure at end of live animation or save, default is False
|
close_on_finish, if True close figure at end of live animation or save, default is False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Check if there is only one drone.
|
||||||
|
if len(position.shape) == 2:
|
||||||
|
position = np.expand_dims(position, axis=1)
|
||||||
|
rotation = np.expand_dims(rotation, axis=1)
|
||||||
|
wind = np.expand_dims(wind, axis=1)
|
||||||
|
M = position.shape[1]
|
||||||
|
|
||||||
# Temporal style.
|
# Temporal style.
|
||||||
rtf = 1.0 # real time factor > 1.0 is faster than real time playback
|
rtf = 1.0 # real time factor > 1.0 is faster than real time playback
|
||||||
render_fps = 30
|
render_fps = 30
|
||||||
|
|
||||||
# Normalize the wind by the max of the wind magnitude on each axis, so that the maximum length of the arrow is decided by the scale factor
|
# Normalize the wind by the max of the wind magnitude on each axis, so that the maximum length of the arrow is decided by the scale factor
|
||||||
wind_mag = np.linalg.norm(wind, axis=1) # Get the wind magnitude time series
|
wind_mag = np.max(np.linalg.norm(wind, axis=-1), axis=1) # Get the wind magnitude time series
|
||||||
max_wind = np.max(wind_mag) # Find the maximum wind magnitude in the time series
|
max_wind = np.max(wind_mag) # Find the maximum wind magnitude in the time series
|
||||||
|
|
||||||
if max_wind != 0:
|
if max_wind != 0:
|
||||||
wind_arrow_scale_factor = 1 # Scale factor for the wind arrow
|
wind_arrow_scale_factor = 1 # Scale factor for the wind arrow
|
||||||
wind = wind_arrow_scale_factor*wind / max_wind # Apply scaling on wind.
|
wind = wind_arrow_scale_factor*wind / max_wind
|
||||||
|
|
||||||
# Decimate data to render interval; always include t=0.
|
# Decimate data to render interval; always include t=0.
|
||||||
if time[-1] != 0:
|
if time[-1] != 0:
|
||||||
@@ -81,7 +94,7 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None,
|
|||||||
index = _decimate_index(time, sample_time)
|
index = _decimate_index(time, sample_time)
|
||||||
time = time[index]
|
time = time[index]
|
||||||
position = position[index,:]
|
position = position[index,:]
|
||||||
rotation = rotation[index,:,:]
|
rotation = rotation[index,:]
|
||||||
wind = wind[index,:]
|
wind = wind[index,:]
|
||||||
|
|
||||||
# Set up axes.
|
# Set up axes.
|
||||||
@@ -97,7 +110,7 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None,
|
|||||||
if not show_axes:
|
if not show_axes:
|
||||||
ax.set_axis_off()
|
ax.set_axis_off()
|
||||||
|
|
||||||
quad = Quadrotor(ax, wind=animate_wind, wind_scale_factor=1)
|
quads = [Quadrotor(ax, wind=animate_wind, wind_scale_factor=1) for _ in range(M)]
|
||||||
|
|
||||||
world_artists = world.draw(ax)
|
world_artists = world.draw(ax)
|
||||||
|
|
||||||
@@ -105,13 +118,16 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None,
|
|||||||
|
|
||||||
def init():
|
def init():
|
||||||
ax.draw(fig.canvas.get_renderer())
|
ax.draw(fig.canvas.get_renderer())
|
||||||
return world_artists + list(quad.artists) + [title_artist]
|
# return world_artists + list(cquad.artists) + [title_artist]
|
||||||
|
return world_artists + [title_artist] + [q.artists for q in quads]
|
||||||
|
|
||||||
def update(frame):
|
def update(frame):
|
||||||
title_artist.set_text('t = {:.2f}'.format(time[frame]))
|
title_artist.set_text('t = {:.2f}'.format(time[frame]))
|
||||||
quad.transform(position=position[frame,:], rotation=rotation[frame,:,:], wind=wind[frame,:])
|
for i, quad in enumerate(quads):
|
||||||
|
quad.transform(position=position[frame,i,:], rotation=rotation[frame,i,:,:], wind=wind[frame,i,:])
|
||||||
# [a.do_3d_projection(fig.canvas.get_renderer()) for a in quad.artists] # No longer necessary in newer matplotlib?
|
# [a.do_3d_projection(fig.canvas.get_renderer()) for a in quad.artists] # No longer necessary in newer matplotlib?
|
||||||
return world_artists + list(quad.artists) + [title_artist]
|
# return world_artists + list(quad.artists) + [title_artist]
|
||||||
|
return world_artists + [title_artist] + [q.artists for q in quads]
|
||||||
|
|
||||||
ani = ClosingFuncAnimation(fig=fig,
|
ani = ClosingFuncAnimation(fig=fig,
|
||||||
func=update,
|
func=update,
|
||||||
@@ -136,3 +152,89 @@ def animate(time, position, rotation, wind, animate_wind, world, filename=None,
|
|||||||
ani = None
|
ani = None
|
||||||
|
|
||||||
return ani
|
return ani
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
from rotorpy.vehicles.crazyflie_params import quad_params # Import quad params for the quadrotor environment.
|
||||||
|
from rotorpy.learning.quadrotor_environments import QuadrotorEnv
|
||||||
|
from rotorpy.controllers.quadrotor_control import SE3Control
|
||||||
|
from rotorpy.trajectories.circular_traj import CircularTraj
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
output_video_dir = os.path.join(os.path.dirname(__file__), "..", "data_out", "test_animation.mp4")
|
||||||
|
|
||||||
|
# Create M SE3 drones.
|
||||||
|
M = 3
|
||||||
|
baseline_controller = SE3Control(quad_params)
|
||||||
|
|
||||||
|
def make_env():
|
||||||
|
return gym.make("Quadrotor-v0",
|
||||||
|
control_mode ='cmd_motor_speeds',
|
||||||
|
quad_params = quad_params,
|
||||||
|
max_time = 5,
|
||||||
|
world = None,
|
||||||
|
sim_rate = 100,
|
||||||
|
render_mode='3D',
|
||||||
|
render_fps = 60,
|
||||||
|
color='b')
|
||||||
|
|
||||||
|
envs = [make_env() for _ in range(M)]
|
||||||
|
|
||||||
|
# For each environment command it to do a circle with random radius and location.
|
||||||
|
trajs = []
|
||||||
|
for env in envs:
|
||||||
|
center = np.random.uniform(low=-2, high=2, size=3)
|
||||||
|
radius = np.random.uniform(low=0.5, high=1.5)
|
||||||
|
freq = np.random.uniform(low=0.1, high=0.3)
|
||||||
|
plane = np.random.choice(['XY', 'YZ', 'XZ'])
|
||||||
|
traj = CircularTraj(center=center, radius=radius, freq=freq, plane=plane, direction=np.random.choice(['CW', 'CCW']))
|
||||||
|
trajs.append(traj)
|
||||||
|
|
||||||
|
# Collect observations for each environment.
|
||||||
|
observations = [env.reset(initial_state='random')[0] for env in envs]
|
||||||
|
|
||||||
|
# This is a list of env termination conditions so that the loop only ends when the final env is terminated.
|
||||||
|
terminated = [False]*len(observations)
|
||||||
|
|
||||||
|
# Arrays for animating.
|
||||||
|
T = [0]
|
||||||
|
position = [[obs[0:3] for obs in observations]]
|
||||||
|
quat = [[obs[6:10] for obs in observations]]
|
||||||
|
|
||||||
|
while not all(terminated):
|
||||||
|
for (i, env) in enumerate(envs): # For each environment...
|
||||||
|
# Unpack the observation from the environment
|
||||||
|
state = {'x': observations[i][0:3], 'v': observations[i][3:6], 'q': observations[i][6:10], 'w': observations[i][10:13]}
|
||||||
|
|
||||||
|
# Command the quad to do circles.
|
||||||
|
flat = trajs[i].update(env.unwrapped.t)
|
||||||
|
control_dict = baseline_controller.update(0, state, flat)
|
||||||
|
|
||||||
|
# Extract the commanded motor speeds.
|
||||||
|
cmd_motor_speeds = control_dict['cmd_motor_speeds']
|
||||||
|
|
||||||
|
# The environment expects the control inputs to all be within the range [-1,1]
|
||||||
|
action = np.interp(cmd_motor_speeds, [env.unwrapped.rotor_speed_min, env.unwrapped.rotor_speed_max], [-1,1])
|
||||||
|
|
||||||
|
# For the last environment, append the current timestep.
|
||||||
|
if i == 0:
|
||||||
|
T.append(env.unwrapped.t)
|
||||||
|
|
||||||
|
# Step the environment forward.
|
||||||
|
observations[i], reward, terminated[i], truncated, info = env.step(action)
|
||||||
|
|
||||||
|
# Append arrays for plotting.
|
||||||
|
position.append([obs[0:3] for obs in observations])
|
||||||
|
quat.append([obs[6:10] for obs in observations])
|
||||||
|
|
||||||
|
# Convert to numpy arrays.
|
||||||
|
T = np.array(T)
|
||||||
|
position = np.array(position)
|
||||||
|
quat = np.array(quat)
|
||||||
|
|
||||||
|
# Convert the quaternion to rotation matrix.
|
||||||
|
rotation = np.array([Rotation.from_quat(quat[i]).as_matrix() for i in range(T.size)])
|
||||||
|
|
||||||
|
# Animate the results.
|
||||||
|
ani = animate(T, position, rotation, wind=np.zeros((T.size,M,3)), animate_wind=False, world=envs[0].world, filename=output_video_dir, blit=False, show_axes=True, close_on_finish=True)
|
||||||
Reference in New Issue
Block a user