
gh-7563: add config wrapper #7730

Open
wants to merge 13 commits into base: dev

Conversation

johnzielke (Contributor)

Fixes #7563.

Description

This PR introduces a special wrapper keyword that can be used to wrap/decorate an existing model key. It is intended for functions that follow the decorator API in Python (i.e. they expect the first positional argument to be the instance/function to be decorated).
This can be used to keep the hierarchy inside a config when using tools like torch.compile or when wrapping other datasets/models.
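
For illustration, a minimal sketch of what the proposed syntax could look like as a Python-dict config (a sketch only: the _wrapper_ spelling is just one of the options listed below, and the feature sits behind an experimental flag):

from monai.bundle import ConfigParser

# Sketch only: _wrapper_ is the draft keyword proposed in this PR (subject
# to rename) and requires the experimental feature flag to be enabled.
config = {
    "model": {
        "_target_": "monai.networks.nets.BasicUNet",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        # the wrapper is instantiated like any other component and then
        # applied to the surrounding one:
        # torch.compile(BasicUNet(...), dynamic=True)
        "_wrapper_": {
            "_target_": "torch.compile",
            "_mode_": "callable",
            "dynamic": True,
        },
    }
}
parser = ConfigParser(config=config)
parser.parse()
net = parser.get_parsed_content("model")  # a compiled BasicUNet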

This is still a draft, but I'd like to gather opinions on this feature and its implementation.
Open questions (non-exhaustive):

  • Name of the special keyword (e.g. _wrapper_, _wrap_, _decorator_)
  • How to enable/handle this experimental feature

This feature should be treated as experimental, i.e. subject to change or removal in the future. This way we can gain some experience using it and change it without worrying about backwards compatibility.

Breaking changes

The only breaking change I see is that configs that already used this special keyword would now have its value interpreted differently. In the current state, with the feature flag disabled (the default), the config will warn about usage of this special key, since its meaning could change in the future.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

ericspod (Member) commented May 7, 2024

Hi @johnzielke, thanks for the contribution! This looks like an interesting addition with some clearly useful use cases. When you are done, mark the PR as ready and we can consider it, though I'd want the feedback of a few other developers on this feature.

johnzielke marked this pull request as ready for review May 13, 2024 23:26
johnzielke (Contributor, Author)

Hey @ericspod, thank you. I fixed a few minor issues, and I think it's ready for review/discussion now.

KumoLiu (Contributor) left a comment

Thanks for the PR! It appears to be a useful enhancement. I've left a few comments inline and will also ask more people to review.

instantiate_kwargs.update(kwargs)
wrapper = self._get_wrapper()
if wrapper is not None:
    return wrapper(instantiate(modname, mode, **instantiate_kwargs))
Contributor:

Will the wrapper here accept kwargs, such as the mode argument that torch.compile accepts?
https://github.com/pytorch/pytorch/blob/4333e122d4b74cdf84351ed2907045c6a767b4cd/torch/compiler/__init__.py#L17

johnzielke (Contributor, Author) commented May 16, 2024:

Yes, the wrapper key should be "instantiated" like any other key, and will therefore accept kwargs. So by using _mode_: callable for torch.compile, you can bind any kwargs to it, e.g. pseudo-code:

_target_: nets.Unet
spatial_dim: 3
_wrapper_:
  _target_: torch.compile
  _mode_: callable
  dynamic: true
  fullgraph: true

So the kwargs are not accepted on this line; they should already be bound to the wrapper by this point.
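
Conceptually, _mode_: callable with extra kwargs behaves like functools.partial; a rough sketch (not MONAI's exact implementation):

from functools import partial

import torch

# the _wrapper_ block above roughly corresponds to:
wrapper = partial(torch.compile, dynamic=True, fullgraph=True)
# which the config machinery would then apply to the instantiated model:
# net = wrapper(model)  # == torch.compile(model, dynamic=True, fullgraph=True)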

"features": [16, 16, 32, 32, 64, 64],
"_wrapper_": {
"_target_": "torch::jit::compile",
"_mode_": "callable"
Contributor:

It looks like the wrapper can only be callable; do we need to add this mode here?

johnzielke (Contributor, Author):

There probably are use cases where you would instantiate a factory here (a class with a __call__ method) or have a function that returns another function (in fact torch.compile probably does this, as most decorators that support arguments do).
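
A toy sketch of the factory case (make_compiler is hypothetical, purely for illustration):

import torch

# Hypothetical decorator factory: calling it returns the actual wrapper,
# i.e. a function expecting the model as its first positional argument.
def make_compiler(dynamic: bool = False, fullgraph: bool = False):
    def wrap(model: torch.nn.Module) -> torch.nn.Module:
        return torch.compile(model, dynamic=dynamic, fullgraph=fullgraph)
    return wrap

Instantiating such a factory (rather than binding kwargs in callable mode) would also yield a valid wrapper.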

Contributor:

Hi @johnzielke, after discussing with @Nic-Ma offline, we found that such a wrapper can easily be achieved with code like this:

from monai.bundle import ConfigParser

config = {
    "model_base": {
        "_target_": "monai.networks.nets.BasicUNet",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 2,
        "features": [16, 16, 32, 32, 64, 64],
    },
    "model": "$torch.compile(@model_base)",
}
parser = ConfigParser(config=config)
parser.parse()
net = parser.get_parsed_content("model")
print(net)

Does this meet your needs? It's similar to how we use DDP here:
https://github.com/Project-MONAI/model-zoo/blob/dev/models/spleen_ct_segmentation/configs/multi_gpu_train.json#L3

cc @Nic-Ma @ericspod

Member:

> Hi @johnzielke, after discussing with @Nic-Ma offline, we found that such a wrapper can easily be achieved with code like this:

We have used this pattern before, and for my use cases I feel it is fine. I'd like to hear if there are other cases where this PR makes more sense.

johnzielke (Contributor, Author):

Yes, that pattern works in many cases, but it requires you to move definitions to other keys, which might be undesirable for multiple reasons:

  • A lot of references to this key in code and in overriding configs need to be adjusted if this was not planned from the beginning.
  • If the wrapped object is used as a parameter of a parent instantiation via _target_, you cannot simply add the "model_base" key at the same level if that parent class does not handle extra kwargs gracefully.
  • If your wrapping function has parameters (for example the dynamic shapes in torch.compile), you need to specify them in a single string, which makes them harder to modify via configs. Of course you could also introduce another config called model_decorator and then define model as @model_decorator(@model_base), but that adds a lot of visual noise in my opinion.

I'll try to compile some other use cases later, but one I can see would be wrapping existing Datasets in CacheDatasets or similar; see the sketch below. In that case you would often have a dictionary of train, val and test datasets. Of course the same pattern can be applied again, but in my opinion it makes the configs harder to read.
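
For example, a sketch of the dataset case with the proposed keyword (illustrative only: @train_files and @train_transforms are hypothetical references, and the experimental flag is assumed to be enabled):

config = {
    "datasets": {
        "train": {
            "_target_": "monai.data.Dataset",
            "data": "@train_files",
            "transform": "@train_transforms",
            # wraps the inner Dataset: CacheDataset(<dataset>, cache_rate=1.0)
            "_wrapper_": {
                "_target_": "monai.data.CacheDataset",
                "_mode_": "callable",
                "cache_rate": 1.0,
            },
        },
        # "val" and "test" would follow the same pattern
    }
}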

elif wrapper is not None:
    warnings.warn(
        f"ConfigComponent: {self.get_id()} has a key {ConfigComponentReservedKeys.WRAPPER}. "
        "Since the feature flag CONFIG_WRAPPER is not enabled, the key will be treated as a normal config key. "
Contributor:

It looks like in this case the wrapper will become None instead of a normal config key?

johnzielke (Contributor, Author):

Sorry, I'm not quite following; can you explain a bit further?

k: v
for k, v in self.get_config().items()
if (k not in self.non_arg_keys)
or (k == ConfigComponentReservedKeys.WRAPPER and not _wrapper_feature_flag.enabled)
Contributor:

I don't think we should return the key and value in the case where k == ConfigComponentReservedKeys.WRAPPER and not _wrapper_feature_flag.enabled; it will always be wrong, since k will never be an argument for the instance. Correct me if I misunderstand.

johnzielke (Contributor, Author) commented May 16, 2024:

This is meant for backwards compatibility. What if someone is already using the key _wrapper_ in one of their configs? Since it was not reserved before, it might have been used as a normal config key. It's unlikely, but I wanted to make sure to cover this case. That said, I'm happy to always remove it; in that case I would probably add a regex check for "_\w+_" and make all keys matching this pattern reserved for the future.
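
For illustration, such a check might look like this (the helper name is hypothetical, not part of this PR):

import re

# Reserve every key of the form _xxx_ so that future special keywords
# cannot collide with user-defined config keys.
_RESERVED_PATTERN = re.compile(r"_\w+_")

def is_reserved_key(key: str) -> bool:
    return _RESERVED_PATTERN.fullmatch(key) is not None

assert is_reserved_key("_wrapper_")
assert not is_reserved_key("spatial_dims")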

try:
    return instantiate(modname, mode, **instantiate_kwargs)
except Exception as e:
    if _wrapper_feature_flag.enabled and self.get_id().endswith(ConfigComponentReservedKeys.WRAPPER):
Contributor:

self.get_id() will only return a str; will this check work?

johnzielke (Contributor, Author):

>>> "test::_wrapper_".endswith(ConfigComponentReservedKeys.WRAPPER)
True

Yes this should work, this evaluates to true for me

Linked issue (may be closed by this PR): Add _decorator_ keyword to Monai Bundle Configuration (#7563)