def hover_reward(observation, action, weights=None):
    """
    Rewards hovering at (0, 0, 0).

    The reward is a weighted combination of position error, velocity error,
    body-rate magnitude, and control effort. All four terms are penalties
    (negative), so the maximum reward of 0 is achieved at a perfect hover
    with zero action.

    Parameters
    ----------
    observation : array-like
        Flat state vector. Indices used here:
          [0:3]   position error from the goal (x, y, z)
          [3:6]   linear velocity
          [10:13] body angular rates
        Indices 6:10 are unused by this function — presumably the
        orientation quaternion; confirm against the environment's
        observation layout.
    action : array-like
        Control input vector; penalized by its Euclidean norm.
    weights : dict, optional
        Penalty weights with keys 'x' (position), 'v' (velocity),
        'w' (body rates), and 'u' (action). Defaults to
        {'x': 1, 'v': 0.1, 'w': 0, 'u': 1e-5}.

    Returns
    -------
    float
        Scalar reward (<= 0).
    """
    # NOTE: the default is built here rather than in the signature so that
    # no mutable dict is shared across calls (mutable-default pitfall).
    if weights is None:
        weights = {'x': 1, 'v': 0.1, 'w': 0, 'u': 1e-5}

    # Distance-to-goal penalty.
    dist_reward = -weights['x']*np.linalg.norm(observation[0:3])

    # Velocity penalty.
    vel_reward = -weights['v']*np.linalg.norm(observation[3:6])

    # Body angular-rate penalty (weight defaults to 0, i.e. disabled).
    ang_rate_reward = -weights['w']*np.linalg.norm(observation[10:13])

    # Control-effort penalty.
    action_reward = -weights['u']*np.linalg.norm(action)

    return dist_reward + vel_reward + action_reward + ang_rate_reward