Skip to content

Commit

Permalink
Fix typing error in SQIL implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Jul 18, 2023
1 parent 9c5b91c commit f8584c3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/imitation/algorithms/sqil.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def set_demonstrations(
next_obs=transition["next_obs"],
action=transition["acts"],
done=transition["dones"],
reward=1,
reward=np.array(1.0),
infos=[{}],
)

Expand All @@ -194,7 +194,7 @@ def add(
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
super().add(obs, next_obs, action, 0, done, infos)
super().add(obs, next_obs, action, np.array(0.0), done, infos)

def sample(
self,
Expand Down

0 comments on commit f8584c3

Please sign in to comment.