import jax.numpy as jnp
from brax.math import quat_to_euler

####################################################
##### COMMON FUNCTIONS USED TO COMPUTE REWARDS #####
####################################################
def forward_vel_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    # finite-difference x velocity; dt is 0.05 HARDCODED FOR NOW
    vel_rwd = (obs.at[0].get() - prev_obs.at[0].get()) / 0.05
    return vel_rwd


def action_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    reward for custom ant omni which penalises large actions
    """
    rwd = -0.5 * jnp.sum(jnp.square(actions))
    return rwd


def forward_vel_reward_humanoid(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    # finite-difference x velocity; dt is 0.015 HARDCODED FOR NOW
    vel_rwd = (obs.at[0].get() - prev_obs.at[0].get()) / 0.015
    return vel_rwd


def action_reward_humanoid(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    reward for custom humanoid which penalises large actions
    """
    rwd = -0.1 * jnp.sum(jnp.square(actions))  # 0.1 coeff is HARDCODED FOR NOW
    return rwd


def _angle_dist(a, b):
    """
    Signed angular difference b - a, wrapped into [-pi, pi].
    """
    theta = b - a
    theta = jnp.where(theta < -jnp.pi, theta + 2 * jnp.pi, theta)
    theta = jnp.where(theta > jnp.pi, theta - 2 * jnp.pi, theta)
    theta = jnp.where(theta < -jnp.pi, theta + 2 * jnp.pi, theta)
    theta = jnp.where(theta > jnp.pi, theta - 2 * jnp.pi, theta)
    return theta
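
# Worked example (illustrative, not part of the original file): the wrap keeps
# the result in [-pi, pi], so _angle_dist(jnp.deg2rad(170.0), jnp.deg2rad(-170.0))
# is roughly +0.349 rad (+20 degrees) rather than the naive -340 degree difference.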


def angle_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    angle-error reward for custom ant omni: the (x, y) location defines a desired yaw angle,
    and the reward penalises the deviation of the current yaw from it
    """
    com_rot = quat_to_euler(obs.at[3:7].get())
    z_rot = com_rot.at[2].get()
    ang_x, ang_y = obs.at[0].get(), obs.at[1].get()
    B = jnp.sqrt((ang_x / 2.0) * (ang_x / 2.0) + (ang_y / 2.0) * (ang_y / 2.0))
    alpha = jnp.arctan2(ang_y, ang_x)
    A = B / jnp.cos(alpha)
    beta = jnp.arctan2(ang_y, ang_x - A)
    beta = jnp.where(ang_x < 0, beta - jnp.pi, beta)
    beta = jnp.where(beta < -jnp.pi, beta + 2 * jnp.pi, beta)
    beta = jnp.where(beta > jnp.pi, beta - 2 * jnp.pi, beta)
    beta = jnp.where(beta < -jnp.pi, beta + 2 * jnp.pi, beta)
    beta = jnp.where(beta > jnp.pi, beta - 2 * jnp.pi, beta)
    beta = jnp.where(beta < -jnp.pi, beta + 2 * jnp.pi, beta)
    beta = jnp.where(beta > jnp.pi, beta - 2 * jnp.pi, beta)
    angle_diff = jnp.abs(_angle_dist(beta, z_rot))
    rwd = -angle_diff
    return rwd


#####################################################################################
################## DYNAMICS MODEL REWARDS FOR RESPECTIVE ENVS #######################
#####################################################################################
def antmaze_target_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    xy_target = jnp.array([35.0, 0.0])  # CAREFUL HARDCODED FOR NOW
    xy_pos = obs.at[0:2].get()
    rwd = -jnp.linalg.norm(xy_pos - xy_target)
    return rwd


def pointmaze_target_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    xy_target = jnp.array([-0.5, 0.8])  # CAREFUL HARDCODED FOR NOW
    xy_pos = obs.at[0:2].get()
    rwd = -jnp.linalg.norm(xy_pos - xy_target)
    return rwd


def anttrap_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    reward for custom anttrap which combines forward velocity with a penalty on large actions
    """
    vel_rwd = forward_vel_reward(obs, prev_obs, actions)
    action_rwd = action_reward(obs, prev_obs, actions)
    return vel_rwd + action_rwd


def ant_xy_forward_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    reward for custom ant xy forward which combines forward velocity with a penalty on large actions
    """
    vel_rwd = forward_vel_reward(obs, prev_obs, actions)
    action_rwd = action_reward(obs, prev_obs, actions)
    return vel_rwd + action_rwd


def ant_omni_action_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    reward for custom ant omni which penalises large actions
    """
    action_rwd = action_reward(obs, prev_obs, actions)
    return action_rwd


def ant_omni_angle_reward(obs: jnp.ndarray, prev_obs: jnp.ndarray, actions: jnp.ndarray) -> float:
    """
    reward for custom ant omni which penalises angle error
    """
    angle_rwd = angle_reward(obs, prev_obs, actions)
    return angle_rwd


fitness_extractor_imagination = {
    "pointmaze": pointmaze_target_reward,
    "anttrap": anttrap_reward,
    "antmaze": antmaze_target_reward,
    "ant_omni_action": ant_omni_action_reward,
    "ant_omni_angle": ant_omni_angle_reward,
    "ant_xy_forward": ant_xy_forward_reward,
}
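

# A minimal usage sketch (not part of the original module): scoring an imagined
# rollout with one of the per-step extractors above. The (T + 1, obs_dim) /
# (T, act_dim) trajectory layout and the task key "anttrap" are assumptions
# about the caller, not something this file defines.
def example_score_imagined_trajectory(
    obs_traj: jnp.ndarray, actions_traj: jnp.ndarray, task: str = "anttrap"
) -> jnp.ndarray:
    import jax  # local import so the sketch stays self-contained

    reward_fn = fitness_extractor_imagination[task]
    # vectorise the per-step reward over time, pairing each obs with its predecessor
    per_step = jax.vmap(reward_fn)(obs_traj[1:], obs_traj[:-1], actions_traj)
    return jnp.sum(per_step)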


###################################
########## BD EXTRACTORS ##########
###################################
def get_final_xy_position(obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
    """
    Returns the final (x, y) position of each trajectory.
    obs has shape (B, T, D); the output has shape (B, 2).
    """
    xy_pos = obs.at[:, -1, 0:2].get()
    return xy_pos


bd_extractor_imagination = {
    "pointmaze": get_final_xy_position,
    "anttrap": get_final_xy_position,
    "antmaze": get_final_xy_position,
    "ant_omni_action": get_final_xy_position,
    "ant_omni_angle": get_final_xy_position,
    "ant_xy_forward": get_final_xy_position,
}
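

# A minimal usage sketch (not part of the original module): extracting behaviour
# descriptors from a batch of imagined rollouts. The (B, T, obs_dim) layout is the
# one the get_final_xy_position docstring assumes; the task key is illustrative.
def example_extract_descriptors(
    obs_batch: jnp.ndarray, actions_batch: jnp.ndarray, task: str = "ant_omni_action"
) -> jnp.ndarray:
    bd_fn = bd_extractor_imagination[task]
    descriptors = bd_fn(obs_batch, actions_batch)  # shape (B, 2): final x-y of each rollout
    return descriptors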