diff --git a/rlbench/demo.py b/rlbench/demo.py index bf3df6a0d..18d32f0fc 100644 --- a/rlbench/demo.py +++ b/rlbench/demo.py @@ -1,9 +1,11 @@ import numpy as np +from typing import List +from rlbench.backend.observation import Observation class Demo(object): - def __init__(self, observations, random_seed=None, num_reset_attempts = None): + def __init__(self, observations: List[Observation], random_seed=None, num_reset_attempts=None): self._observations = observations self.random_seed = random_seed self.num_reset_attempts = num_reset_attempts @@ -11,7 +13,7 @@ def __init__(self, observations, random_seed=None, num_reset_attempts = None): def __len__(self): return len(self._observations) - def __getitem__(self, i): + def __getitem__(self, i) -> Observation: return self._observations[i] def restore_state(self):