-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rgb observation to dagger #802
base: master
Are you sure you want to change the base?
Conversation
Co-authored-by: Adam Gleave <[email protected]>
In the files changed part, some of them are from master branch. This is because this branch is initialized from a certain commit of support-dict-obs-space branch rather than the latest condition. This would be more clear once support-dict-obs-space get merged and I can change its base to master. |
Codecov Report
@@ Coverage Diff @@
## master #802 +/- ##
==========================================
- Coverage 96.40% 96.35% -0.06%
==========================================
Files 98 100 +2
Lines 9441 9582 +141
==========================================
+ Hits 9102 9233 +131
- Misses 339 349 +10
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR! I like the implementation of HumanReadableWrapper
and the test cases, very clean.
However I don't understand the design decision of how this was integrated with DAgger. As I see it right now it seems like we're special-casing some logic into BC and DAgger to catch (and strip out) RGB observations from the learnable policy. Generally it seems preferable to have each part of the code responsible for a small, discrete chunk. One natural way of splitting it is the algorithms are responsible for the training process, the policies as responsible for how the observations are processed and converted to actions, and other parts of the code as responsible for data collection etc. Putting logic to strip RGB observations out into the algorithms is breaking this abstraction hierarchy.
Sometimes it is necessary to break abstractions but I'd like to better understand what we're gaining from this and what the alternatives look like before committing to this. To sketch one alternative (which I suspect will need some modification): what if we made policies (or the policies' feature extractors) responsible for removing RGB observations instead? We might be able to do this with a wrapper for a policy, or just specifying a custom feature extractor (we might need some small changes to the algorithms to let callee specify feature extractor, but it'd be a more generic change). There may well be pitfalls to this approach as well -- would love to hear your thoughts.
Thank you for the detailed comments! I agree with your idea and I think the original design is flawed. I update the design to use a policy wrapper. I choose it instead of feature extractor because feature_extractor works on the tensor level and the wrapper can work on the np.ndarray and dict[str, np.ndarray] level, which is closer to the input. I don't have a strong preference about it though. In this new design, like you said, we don't need to modify any algorithm level code to fix the data level issue. Please take another look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ZiyueWang25 please don't make any further changes to this PR now your period with us is up -- safe travels back to Seattle and have a good weekend :) We'll find someone else to finish off the PR. Including this for informational value for you and for the next developer to pick up on.
It looks like HumanReadableWrapper
ends up with an observation space that is inconsistent with observations returned (does not include the human-readable component that is added). This violates the Gynmasium API and is likely to cause problems somewhere down the road: e.g. if algorithms allocate buffers to store observations based on the declared observation space. This should be fixed.
I suspect having the observation space be unchanged made the policy wrapper easier. It's workable to do it without, but might need to mangle the observation space being fed into the underlying policy. This along with the need for policy wrapper to be specialized to particular kinds (ActorCriticPolicy, OffPolicy, maybe even SACPolicy) makes me think feature extractors are likely the cleaner solution here.
|
||
def lr_schedule(_: float): | ||
# Set lr_schedule to max value to force error if policy.optimizer | ||
# is used by mistake (should use self.optimizer instead). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy-and-pasted comment doesn't make sense out of context (what is self
here?)
from imitation.data import wrappers as data_wrappers | ||
|
||
|
||
class Base(ActorCriticPolicy, abc.ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the need to inherit from ActorCriticPolicy
, I suggest we instead implement a feature extractor sub-classing CombinedExtractor
, having it skip the key (I think this is as simple as modifying observation_space
passed through to the constructor of CombinedExtractor
). Then set that feature extractor in the policy, and it should work with any kind of policy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Base
is a bit vague, base of what? Base policy wraper? Base actor critic policy wrapper?
else: | ||
full_std = True | ||
use_expln = False | ||
super().__init__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit hacky although I don't see a better way of doing this. It's a limitation/gap in the SB3 API that there isn't a PolicyWrapper
class -- but this is perhaps also a sign that wrapping policies will fit awkwardly into the API (sorry for putting you on the wrong trail).
raise ValueError( | ||
"Only human readable observation exists, can't remove it", | ||
) | ||
# keeps the original observation unchanged in case it is used elsewhere. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 yes good to avoid side effects where possible, and copying a dict should be cheap as it's just a shallow-copy
) | ||
assert isinstance(env.observation_space, gym.spaces.Dict) | ||
_check_obs_or_space_equal(env.observation_space, expected_obs_space) | ||
assert hr_env.observation_space == ori_env.observation_space |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How can this be true? Does the observation returned by step()
not belong to the observation space (that'd violate the Gymnasium API I think)? Or are we encoding the human readable information somewhere other than the observation?
@@ -235,30 +234,8 @@ def __init__(self, env: Env, original_obs_key: str = "ORI_OBS"): | |||
) | |||
self._original_obs_key = original_obs_key | |||
super().__init__(env) | |||
self._update_obs_space() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we still need the observation space update. From the Gynmasium docs:
The transformation defined in that method must be reflected by the env observation space. Otherwise, you need to specify the new observation space of the wrapper by setting self.observation_space in the init() method of your wrapper.
Description
Testing
Unit test + test by examples.