diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/flat_env_cfg.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/flat_env_cfg.py index e29abc2c62..1b6f98d276 100644 --- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/flat_env_cfg.py +++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/flat_env_cfg.py @@ -212,9 +212,14 @@ class SpotRewardsCfg: }, ) gait = RewardTermCfg( - func=spot_mdp.gait_reward, + func=spot_mdp.GaitReward, weight=10.0, - params={"std": 0.1, "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_foot")}, + params={ + "std": 0.1, + "max_err": 0.2, + "synced_feet_pair_names": (("fl_foot", "hr_foot"), ("fr_foot", "hl_foot")), + "sensor_cfg": SceneEntityCfg("contact_forces"), + }, ) # -- penalties diff --git a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/mdp/rewards.py b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/mdp/rewards.py index 32992a3ef7..1ce4dc67d8 100644 --- a/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/mdp/rewards.py +++ b/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/config/spot/mdp/rewards.py @@ -10,11 +10,12 @@ from typing import TYPE_CHECKING from omni.isaac.lab.assets import Articulation, RigidObject -from omni.isaac.lab.managers import SceneEntityCfg +from omni.isaac.lab.managers import ManagerTermBase, SceneEntityCfg from omni.isaac.lab.sensors import ContactSensor if TYPE_CHECKING: from omni.isaac.lab.envs import ManagerBasedRLEnv + from omni.isaac.lab.managers import RewardTermCfg # -- Task Rewards @@ -62,41 +63,79 @@ def base_linear_velocity_reward( return torch.exp(-lin_vel_error / std) * velocity_scaling_multiple -# ! need to finalize logic, params, and docstring -def gait_reward(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg, std: float) -> torch.Tensor: - """Penalize ...""" - # extract the used quantities (to enable type-hinting) - contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] - if contact_sensor.cfg.track_air_time is False: - raise RuntimeError("Activate ContactSensor's track_air_time!") - # compute the reward - air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids] - contact_time = contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] - - max_err = 0.2 - indices_0 = [0, 1] - indices_1 = [2, 3] - cmd = torch.norm(env.command_manager.get_command("base_velocity"), dim=1) - asym_err_0 = torch.clip( - torch.square(air_time[:, indices_0[0]] - contact_time[:, indices_0[1]]), max=max_err**2 - ) + torch.clip(torch.square(contact_time[:, indices_0[0]] - air_time[:, indices_0[1]]), max=max_err**2) - asym_err_1 = torch.clip( - torch.square(air_time[:, indices_1[0]] - contact_time[:, indices_1[1]]), max=max_err**2 - ) + torch.clip(torch.square(contact_time[:, indices_1[0]] - air_time[:, indices_1[1]]), max=max_err**2) - asym_err_2 = torch.clip( - torch.square(air_time[:, indices_0[0]] - contact_time[:, indices_1[0]]), max=max_err**2 - ) + torch.clip(torch.square(contact_time[:, indices_0[0]] - air_time[:, indices_1[0]]), max=max_err**2) - asym_err_3 = torch.clip( - torch.square(air_time[:, indices_0[1]] - contact_time[:, indices_1[1]]), max=max_err**2 - ) + torch.clip(torch.square(contact_time[:, indices_0[1]] - air_time[:, indices_1[1]]), max=max_err**2) - sym_err_0 = torch.clip( - torch.square(air_time[:, indices_0[0]] - air_time[:, indices_1[1]]), max=max_err**2 - ) + torch.clip(torch.square(contact_time[:, indices_0[0]] - contact_time[:, indices_1[1]]), max=max_err**2) - sym_err_1 = torch.clip( - torch.square(air_time[:, indices_0[1]] - air_time[:, indices_1[0]]), max=max_err**2 - ) + torch.clip(torch.square(contact_time[:, indices_0[1]] - contact_time[:, indices_1[0]]), max=max_err**2) - gait_err = asym_err_0 + asym_err_1 + sym_err_0 + sym_err_1 + asym_err_2 + asym_err_3 - return torch.where(cmd > 0.0, torch.exp(-gait_err / std), 0.0) +class GaitReward(ManagerTermBase): + """Gait enforcing reward term for quadrupeds. + + This reward penalizes contact timing differences between selected foot pairs defined in :attr:`synced_feet_pair_names` + to bias the policy towards a desired gait, i.e trotting, bounding, or pacing. Note that this reward is only for + quadrupedal gaits with two pairs of synchronized feet. + """ + + def __init__(self, cfg: RewardTermCfg, env: RLTaskEnv): + """Initialize the term. + + Args: + cfg: The configuration of the reward. + env: The RL environment instance. + """ + super().__init__(cfg, env) + self.std: float = cfg.params["std"] + self.max_err: float = cfg.params["max_err"] + self.contact_sensor: ContactSensor = env.scene.sensors[cfg.params["sensor_cfg"].name] + # match foot body names with corresponding foot body ids + synced_feet_pair_names = cfg.params["synced_feet_pair_names"] + if ( + len(synced_feet_pair_names) != 2 + or len(synced_feet_pair_names[0]) != 2 + or len(synced_feet_pair_names[1]) != 2 + ): + raise ValueError("This reward only supports gaits with two pairs of synchronized feet, like trotting.") + synced_feet_pair_0 = self.contact_sensor.find_bodies(synced_feet_pair_names[0])[0] + synced_feet_pair_1 = self.contact_sensor.find_bodies(synced_feet_pair_names[1])[0] + self.synced_feet_pairs = [synced_feet_pair_0, synced_feet_pair_1] + + def __call__(self, env: RLTaskEnv, std, max_err, synced_feet_pair_names, sensor_cfg) -> torch.Tensor: + """Compute the reward. + + This reward is defined as a multiplication between six terms where two of them enforce pair feet + being in sync and the other four rewards if all the other remaining pairs are out of sync + + Args: + env: The RL environment instance. + Returns: + The reward value. + """ + # for synchronous feet, the contact (air) times of two feet should match + syc_reward_0 = self._syc_reward_func(self.synced_feet_pairs[0][0], self.synced_feet_pairs[0][1]) + syc_reward_1 = self._syc_reward_func(self.synced_feet_pairs[1][0], self.synced_feet_pairs[1][1]) + syc_reward = syc_reward_0 * syc_reward_1 + # for asynchronous feet, the contact time of one foot should match the air time of the other one + asyc_reward_0 = self._asyc_reward_func(self.synced_feet_pairs[0][0], self.synced_feet_pairs[1][0]) + asyc_reward_1 = self._asyc_reward_func(self.synced_feet_pairs[0][1], self.synced_feet_pairs[1][1]) + asyc_reward_2 = self._asyc_reward_func(self.synced_feet_pairs[0][0], self.synced_feet_pairs[1][1]) + asyc_reward_3 = self._asyc_reward_func(self.synced_feet_pairs[1][0], self.synced_feet_pairs[0][1]) + asyc_reward = asyc_reward_0 * asyc_reward_1 * asyc_reward_2 * asyc_reward_3 + # only enforce gait if cmd > 0 + cmd = torch.norm(env.command_manager.get_command("base_velocity"), dim=1) + return torch.where(cmd > 0.0, syc_reward * asyc_reward, 0.0) + + def _syc_reward_func(self, foot_0: int, foot_1: int) -> torch.Tensor: + """Reward two feet being in sync.""" + air_time = self.contact_sensor.data.current_air_time + contact_time = self.contact_sensor.data.current_contact_time + # squared error between air times plus squared error between contact time + se_air = torch.clip(torch.square(air_time[:, foot_0] - air_time[:, foot_1]), max=self.max_err**2) + se_contact = torch.clip(torch.square(contact_time[:, foot_0] - contact_time[:, foot_1]), max=self.max_err**2) + return torch.exp(-(se_air + se_contact) / self.std) + + def _asyc_reward_func(self, foot_0: int, foot_1: int) -> torch.Tensor: + """Reward two feet being out of sync.""" + air_time = self.contact_sensor.data.current_air_time + contact_time = self.contact_sensor.data.current_contact_time + # squared error between air time and contact time + se_act_0 = torch.clip(torch.square(air_time[:, foot_0] - contact_time[:, foot_1]), max=self.max_err**2) + se_act_1 = torch.clip(torch.square(contact_time[:, foot_0] - air_time[:, foot_1]), max=self.max_err**2) + return torch.exp(-(se_act_0 + se_act_1) / self.std) def foot_clearance_reward(