gh-7563: add config wrapper #7730
base: dev
Conversation
Hi @johnzielke, thanks for the contribution! This looks like an interesting addition with some clearly useful use cases. When you are done, mark the PR as ready and we can consider it, though I'd want the feedback of a few other developers on this feature.
Hey @ericspod, thank you. I fixed a few minor issues, but I think it's ready for review/discussion now.
Thanks for the PR! It looks like a useful and welcome enhancement.
I've left a few comments inline and will also ask more people to review.
```python
instantiate_kwargs.update(kwargs)
wrapper = self._get_wrapper()
if wrapper is not None:
    return wrapper(instantiate(modname, mode, **instantiate_kwargs))
```
Will the wrapper here accept kwargs, such as the `mode` argument that `torch.compile` accepts?
https://github.com/pytorch/pytorch/blob/4333e122d4b74cdf84351ed2907045c6a767b4cd/torch/compiler/__init__.py#L17
Yes, the wrapper key should be "instantiated" like any other key, and will therefore accept kwargs. So by using `_mode_: callable` for `torch.compile`, you can bind any kwargs to it, e.g. this pseudo-code:
```yaml
_target_: nets.Unet
spatial_dim: 3
_wrapper_:
  _target_: torch.compile
  _mode_: callable
  dynamic: true
  fullgraph: true
```
So it doesn't accept them on this line, but it should already have them bound by this point.
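Roughly, in plain Python, my understanding is that this reduces to something like the following (a sketch, assuming `_mode_: callable` with extra kwargs yields a `functools.partial`, not MONAI's exact code):

```python
import functools

import torch

net = torch.nn.Linear(4, 2)  # stand-in for the instantiated component

# with `_mode_: callable` and extra kwargs, instantiating the `_wrapper_` key
# yields (roughly) a partial with those kwargs already bound
wrapper = functools.partial(torch.compile, dynamic=True, fullgraph=True)

compiled_net = wrapper(net)  # i.e. torch.compile(net, dynamic=True, fullgraph=True)
```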
"features": [16, 16, 32, 32, 64, 64], | ||
"_wrapper_": { | ||
"_target_": "torch::jit::compile", | ||
"_mode_": "callable" |
It looks like the wrapper can only be a callable; do we need to add this mode here?
There might be (and probably are) use cases where you would instantiate a factory here (a class with a `__call__` method) or have a function that returns another function (in fact, `torch.compile` probably does this, as most decorators that support arguments do).
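For illustration, a hypothetical wrapper of the "function that returns a function" kind (`make_compile_wrapper` is a made-up name, not part of this PR):

```python
import torch

def make_compile_wrapper(**compile_kwargs):
    # returns another function, as most decorators that accept arguments do
    def wrapper(model):
        return torch.compile(model, **compile_kwargs)
    return wrapper

net = torch.nn.Linear(4, 2)  # stand-in for the instantiated component
compiled_net = make_compile_wrapper(dynamic=True)(net)
```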
Hi @johnzielke, after discussing with @Nic-Ma offline, we found that such a wrapper can easily be achieved with code like this:
```python
from monai.bundle import ConfigParser

config = {
    "model_base": {
        "_target_": "monai.networks.nets.BasicUNet",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        "features": [16, 16, 32, 32, 64, 64],
    },
    "model": "$torch.compile(@model_base)",
}

parser = ConfigParser(config=config)
parser.parse()
net = parser.get_parsed_content("model")
print(net)
```
Does this meet your needs? It's similar to how we use DDP here:
https://github.com/Project-MONAI/model-zoo/blob/dev/models/spleen_ct_segmentation/configs/multi_gpu_train.json#L3
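Roughly, that config wraps the network with an expression along these lines (paraphrased; see the linked file for the exact form):

```json
"network": "$torch.nn.parallel.DistributedDataParallel(@network_def.to(@device), device_ids=[@device])"
```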
> Hi @johnzielke, after discussing with @Nic-Ma offline, we found that such a wrapper can easily be achieved with code like this:
We have used this pattern before, and for my use cases I feel it's fine. I'd like to hear if there are other cases that this PR makes more sense for.
Yes, that pattern works in many cases, but it requires you to move definitions to other keys, which might be undesirable for multiple reasons:

- A lot of references to this key in code and in overriding configs need to be adjusted if this has not been planned from the beginning.
- If the thing you are wrapping is used as a parameter to a parent instantiation via `_target_`, you cannot just add the "model_base" key at the same level if that parent class does not handle extra kwargs gracefully.
- If your wrapping function has parameters (for example the dynamic shapes in `torch.compile`), you need to specify those in a single string, which makes it harder to modify them through configs. Of course, you could also introduce another config entry called model_decorator and then have model be `@model_decorator(@model_base)`, but that adds a lot of visual noise in my opinion.

I'll try to compile some other use cases later, but one I could see would be wrapping existing Datasets in CacheDatasets or similar. In that case you would often have a dictionary of train, val, and test datasets. Of course, the same pattern can be applied again, but it makes the configs harder to read in my opinion.
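Here is a hypothetical sketch of that dataset use case with the proposed `_wrapper_` key (key names follow this draft; that `CacheDataset` accepts an existing dataset as its `data` argument is an assumption here, as are the `@train_files`/`@train_transforms` references):

```yaml
train_dataset:
  _target_: monai.data.Dataset
  data: "@train_files"            # assumed to be defined elsewhere in the config
  transform: "@train_transforms"  # assumed to be defined elsewhere in the config
  _wrapper_:
    _target_: monai.data.CacheDataset
    _mode_: callable
    cache_rate: 1.0
```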
```python
elif wrapper is not None:
    warnings.warn(
        f"ConfigComponent: {self.get_id()} has a key {ConfigComponentReservedKeys.WRAPPER}. "
        "Since the feature flag CONFIG_WRAPPER is not enabled, the key will be treated as a normal config key. "
```
It looks like in this case the wrapper will become `None` instead of a normal config key?
Sorry, I'm not quite following; can you explain a bit further?
```python
k: v
for k, v in self.get_config().items()
if (k not in self.non_arg_keys)
or (k == ConfigComponentReservedKeys.WRAPPER and not _wrapper_feature_flag.enabled)
```
I don't think we should return the key and value in the `k == ConfigComponentReservedKeys.WRAPPER and not _wrapper_feature_flag.enabled` case; it will always be wrong, since `k` will never be one of the arguments for the instance? Correct me if I misunderstand.
This is meant for backwards compatibility. What if someone is already using the key `_wrapper_` in one of their configs? Since it was not reserved before, it might have been used as a normal config key. It's unlikely, but I wanted to make sure to cover this case. That said, I'm happy to always remove it; in that case I would probably add a regex check for `"_\w+_"` and make all keys of this pattern reserved for the future.
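Something like this minimal sketch (names hypothetical):

```python
import re

# hypothetical: any key of the form `_name_` is treated as reserved
RESERVED_KEY_PATTERN = re.compile(r"_\w+_")

def is_reserved_key(key: str) -> bool:
    return RESERVED_KEY_PATTERN.fullmatch(key) is not None

assert is_reserved_key("_wrapper_")
assert not is_reserved_key("spatial_dims")
```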
```python
try:
    return instantiate(modname, mode, **instantiate_kwargs)
except Exception as e:
    if _wrapper_feature_flag.enabled and self.get_id().endswith(ConfigComponentReservedKeys.WRAPPER):
```
`self.get_id()` will only return a `str`; will this check work?
Yes, this should work; it evaluates to `True` for me:

```python
>>> "test::_wrapper_".endswith(ConfigComponentReservedKeys.WRAPPER)
True
```
Fixes #7563.
Description
This PR introduces a `_wrapper_` special keyword that can be used to wrap/decorate an existing model key. It is intended to be used with functions that follow the decorator API in Python (i.e. they expect the first positional argument to be the instance/function to be decorated).
This can be used to keep the hierarchy inside a config when using tools like `torch.compile` or when wrapping other datasets/models. This is still a draft, but I'd like to gather opinions on this feature and its implementation.
Open questions (non-exhaustive):
- This feature should be treated as experimental, meaning it is subject to change and removal in the future. This way we can gain some experience using it and change it without worrying about backwards compatibility.
Breaking changes
The only breaking change I see is that configs that already used this special keyword would now have its value interpreted differently. In the current state, with the feature flag disabled (the default), the config will warn about usage of this special key, since its meaning could change in the future.
Types of changes
./runtests.sh -f -u --net --coverage
../runtests.sh --quick --unittests --disttests
.make html
command in thedocs/
folder.