Lift the requirement of human fingering with RP1M. #23

Open
zhaoyi11 opened this issue Sep 3, 2024 · 2 comments
Comments

@zhaoyi11 (Contributor) commented Sep 3, 2024

Hello,

Thanks for the great work! Recently, we released a paper named RP1M (https://arxiv.org/abs/2408.11048, cc @clthegoat) which includes a reward term based on optimal transport, enabling the agent to play MIDI files without human fingering labels. We would like to know whether it is possible to integrate this method into this repo, so that people can conveniently use RoboPianist to play more songs beyond the PIG dataset.
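
Concretely, the user-facing workflow I have in mind would look roughly like the sketch below. It is only hypothetical: the environment name is an illustrative example, and I am assuming suite.load forwards extra task kwargs to the task constructor; if it does not, the existing disable_fingering_reward flag can be passed to PianoWithShadowHands directly.

from robopianist import suite

# Hypothetical usage: play a song without human fingering labels and fall back
# to the OT reward via the existing `disable_fingering_reward` flag.
env = suite.load(
    environment_name="RoboPianist-debug-TwinkleTwinkleLittleStar-v0",  # example name
    task_kwargs=dict(disable_fingering_reward=True),  # assumes kwargs are forwarded
)
timestep = env.reset()
action = env.action_spec().generate_value()
timestep = env.step(action)
print(timestep.reward)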

Here are some comparison results from the paper as well as a short plan for the modification of the code. Please let me know your thoughts.

Results:
[comparison results figure from the paper]

Modifications:
I plan to change these lines

if not self._disable_fingering_reward:
    self._reward_fn.add("fingering_reward", self._compute_fingering_reward)

as:

if not self._disable_fingering_reward:
    # Human fingering is available: keep the original fingering reward.
    self._reward_fn.add("fingering_reward", self._compute_fingering_reward)
else:
    # No fingering labels: use the OT-based reward from RP1M.
    self._reward_fn.add("ot_reward", self._compute_ot_reward)

where the _compute_ot_reward is defined as:

# numpy (np), mjcf, tolerance, and _FINGER_CLOSE_ENOUGH_TO_KEY are already
# available in the existing task module.
from scipy.optimize import linear_sum_assignment

def _compute_ot_reward(self, physics: mjcf.Physics) -> float:
    """OT reward calculation from RP1M (https://arxiv.org/abs/2408.11048)."""
    # Collect the fingertip positions of both hands.
    fingertip_pos = [physics.bind(finger).xpos.copy() for finger in self.left_hand.fingertip_sites]
    fingertip_pos += [physics.bind(finger).xpos.copy() for finger in self.right_hand.fingertip_sites]

    # Keys to press at the current timestep (the last goal entry is the sustain pedal).
    keys_to_press = np.flatnonzero(self._goal_current[:-1])
    # If no key needs to be pressed, return full reward.
    if keys_to_press.shape[0] == 0:
        return 1.0

    # Key target positions, same as the existing fingering reward in RoboPianist.
    key_pos = []
    for key in keys_to_press:
        key_geom = self.piano.keys[key].geom[0]
        key_geom_pos = physics.bind(key_geom).xpos.copy()
        key_geom_pos[-1] += 0.5 * physics.bind(key_geom).size[2]
        key_geom_pos[0] += 0.35 * physics.bind(key_geom).size[0]
        key_pos.append(key_geom_pos.copy())

    # Pairwise distances between fingertips and the target keys.
    dist = np.full((len(fingertip_pos), len(key_pos)), 100.0)
    for i, finger in enumerate(fingertip_pos):
        for j, key in enumerate(key_pos):
            dist[i, j] = np.linalg.norm(key - finger)

    # Optimal assignment of fingertips to keys (minimum total distance).
    row_ind, col_ind = linear_sum_assignment(dist)
    dist = dist[row_ind, col_ind]
    rews = tolerance(
        dist,
        bounds=(0, _FINGER_CLOSE_ENOUGH_TO_KEY),
        margin=(_FINGER_CLOSE_ENOUGH_TO_KEY * 10),
        sigmoid="gaussian",
    )
    return float(np.mean(rews))
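
To make the matching step concrete, here is a minimal standalone sketch of the assignment logic on toy data. It uses only numpy and scipy, the fingertip and key coordinates are made up, and the Gaussian shaping is re-implemented inline to mimic tolerance(..., sigmoid="gaussian") rather than calling the repo's helper, so the numbers are purely illustrative.

import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy data: 10 fingertip positions and 3 target key positions (made up).
rng = np.random.default_rng(0)
fingertips = rng.uniform(-0.2, 0.2, size=(10, 3))
keys = np.array([[0.05, 0.0, 0.10], [-0.03, 0.0, 0.10], [0.12, 0.0, 0.10]])

# Pairwise Euclidean distances, shape (num_fingertips, num_keys).
dist = np.linalg.norm(fingertips[:, None, :] - keys[None, :, :], axis=-1)

# Hungarian algorithm: each key gets a distinct fingertip so that the total
# matched distance is minimized.
row_ind, col_ind = linear_sum_assignment(dist)
matched = dist[row_ind, col_ind]

# Gaussian shaping that mimics tolerance(..., sigmoid="gaussian"): reward is 1
# inside the bound and decays to value_at_margin at distance bound + margin.
# CLOSE_ENOUGH is a stand-in for _FINGER_CLOSE_ENOUGH_TO_KEY.
CLOSE_ENOUGH = 0.01
margin = CLOSE_ENOUGH * 10
value_at_margin = 0.1
excess = np.maximum(matched - CLOSE_ENOUGH, 0.0) / margin
scale = np.sqrt(-2 * np.log(value_at_margin))
rews = np.exp(-0.5 * (excess * scale) ** 2)

print("matched (fingertip, key) pairs:", list(zip(row_ind, col_ind)))
print("ot reward:", float(np.mean(rews)))
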
@kevinzakka (Collaborator) commented Sep 3, 2024

@zhaoyi11 I was going to email you all and ask if you would be down to integrate your labeling pipeline into the repo 😂 So to answer your question: yes! I would be more than happy to help you; please feel free to submit a PR! Love the paper btw!

@zhaoyi11 (Contributor, Author) commented Sep 3, 2024

Thanks @kevinzakka! Great, I will prepare the PR asap.
