Add MedNext implementation #8004

Open: wants to merge 7 commits into base: dev


surajpaib (Contributor)

Fixes #7786

Description

Added MedNeXt architecture implementations to MONAI.

Since a lot of the code is heavily sourced from the original MedNeXt repo, https://github.com/MIC-DKFZ/MedNeXt, I wanted to check whether there is an attribution policy with regard to borrowed source code. I've added a derivative notice below the MONAI copyright comment. Let me know if this needs to be changed.

The blocks have been taken almost as-is, but the network implementation has been largely rewritten to allow flexible blocks and to follow the MONAI SegResNet style.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

johnzielke (Contributor) left a comment

Thanks for the work! I think it's a great addition to MONAI!

return self.conv_out(x)


class LayerNorm(nn.Module):
Contributor

Do we need this LayerNorm implementation? Or can we just use the torch.nn.LayerNorm? I don't see the channels_last used anywhere.

This is a copy from the original codebase and can probably be replaced with torch.nn.LayerNorm as 'channels_first' is always assumed.
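For reference, a minimal sketch of what a channels-first LayerNorm computes and how a stock `torch.nn.LayerNorm` could stand in for it. One detail worth noting: `torch.nn.LayerNorm` normalizes over the trailing dimension(s), so for an `(N, C, *spatial)` tensor the channel axis has to be moved last and then back. `ChannelsFirstLayerNorm` and `layer_norm_channels_first` are hypothetical names for illustration, not code from this PR.

```python
import torch
import torch.nn as nn


class ChannelsFirstLayerNorm(nn.Module):
    """Hypothetical ConvNeXt-style LayerNorm over dim 1 of an (N, C, *spatial) tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Per-location statistics across channels only (biased variance, like nn.LayerNorm).
        mean = x.mean(dim=1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Reshape the affine parameters to (1, C, 1, ..., 1) so they broadcast over space.
        shape = (1, -1) + (1,) * (x.ndim - 2)
        return self.weight.view(shape) * x + self.bias.view(shape)


def layer_norm_channels_first(x: torch.Tensor, norm: nn.LayerNorm) -> torch.Tensor:
    """Apply a stock nn.LayerNorm(C) to a channels-first tensor by moving C last and back."""
    return norm(x.movedim(1, -1)).movedim(-1, 1)


x = torch.randn(2, 8, 4, 5, 6)
custom = ChannelsFirstLayerNorm(8)
stock = nn.LayerNorm(8, eps=1e-6)
print(torch.allclose(custom(x), layer_norm_channels_first(x, stock), atol=1e-5))
```

With default (untrained) affine parameters the two paths agree up to floating-point tolerance, so the custom class could likely be dropped in favor of the permute-and-normalize helper.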

self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)

def forward(self, x, dummy_tensor=None):
Contributor

Do we know what the dummy_tensor is used for? I don't see it being used anywhere. This applies to the other forward functions as well.

This also comes from the original codebase and should be removed, as it's never supposed to be used.

self,
spatial_dims: int = 3,
init_filters: int = 32,
in_channels: int = 1,
Contributor

How about setting this to None by default and using LazyConv if in_channels is None?
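A minimal sketch of the suggestion, assuming the stem is a 1x1 convolution: when `in_channels` is `None`, a `LazyConv2d`/`LazyConv3d` infers the input channel count on the first forward pass. `make_stem` is a hypothetical helper, not part of the PR.

```python
from typing import Optional

import torch
import torch.nn as nn


def make_stem(spatial_dims: int, in_channels: Optional[int], out_channels: int) -> nn.Module:
    """Hypothetical helper: 1x1 stem conv whose input channels may be inferred lazily."""
    if spatial_dims not in (2, 3):
        raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}")
    if in_channels is None:
        # LazyConv*d infers in_channels from the first input it receives.
        lazy_conv = nn.LazyConv2d if spatial_dims == 2 else nn.LazyConv3d
        return lazy_conv(out_channels, kernel_size=1)
    conv = nn.Conv2d if spatial_dims == 2 else nn.Conv3d
    return conv(in_channels, out_channels, kernel_size=1)


stem = make_stem(3, None, 32)
y = stem(torch.randn(1, 4, 8, 8, 8))  # the first call materializes the weights
print(y.shape, stem.in_channels)
```

One caveat of this design: a lazy module's parameters only exist after the first forward pass, so a dry run is needed before the parameters are handed to an optimizer.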

blocks_down: list = [2, 2, 2, 2],
blocks_bottleneck: int = 2,
blocks_up: list = [2, 2, 2, 2],
norm_type: str = "group",
Contributor

Should we make this a StrEnum? It makes the options easy to see. Of course, this also applies to the Block.

rcremese, Aug 19, 2024

@johnzielke Can you give me an example of its usage in one of the network implementations? It's not obvious to me how to use it.
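A minimal sketch of how a StrEnum-typed `norm_type` argument could work. MONAI defines a similar `StrEnum` in `monai.utils.enums` that could likely be reused; here a local copy keeps the example self-contained. `MedNeXtNormType` and `ToyNetwork` are hypothetical names.

```python
from enum import Enum
from typing import Union


class StrEnum(str, Enum):
    """Minimal str-backed Enum: members compare equal to their plain string values."""

    def __str__(self) -> str:
        return self.value


class MedNeXtNormType(StrEnum):  # hypothetical name
    GROUP = "group"
    LAYER = "layer"


class ToyNetwork:
    """Hypothetical stand-in for a network constructor that accepts norm_type."""

    def __init__(self, norm_type: Union[MedNeXtNormType, str] = MedNeXtNormType.GROUP) -> None:
        # Converting through the Enum accepts either a member or its string value
        # and raises ValueError for anything not listed.
        self.norm_type = MedNeXtNormType(norm_type)


print([m.value for m in MedNeXtNormType])  # the available options
print(ToyNetwork("layer").norm_type is MedNeXtNormType.LAYER)
```

Because the members are also `str` instances, existing comparisons such as `norm_type == "group"` inside the blocks keep working, while users get autocompletion and a clear error for typos.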

Comment on lines 53 to 55
enc_exp_r: int = 2,
dec_exp_r: int = 2,
bottlenec_exp_r: int = 2,
Contributor

IMO encoder_expansion_ratio (and likewise for the decoder and bottleneck) is better here. It's a public API, so the full name makes it easier to see what it refers to without having to look at the docs, and the few extra letters don't matter much.

Contributor Author

I've kept this convention consistent with the original implementation, which uses a joint exp_r argument. Happy to rename it; I agree with the comment.
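To make the naming discussion concrete, here is a simplified sketch of where an expansion ratio and a residual flag act in a ConvNeXt-style block, the design MedNeXt is based on: a large-kernel depthwise conv, a 1x1 conv expanding to `expansion_ratio * channels`, GELU, and a 1x1 conv back down. It omits GRN, the 2D path and differing input/output channels, and the GroupNorm choice is an assumption; `MedNeXtStyleBlock`, `expansion_ratio`, and `use_residual_connection` are illustrative names, not the PR's API.

```python
import torch
import torch.nn as nn


class MedNeXtStyleBlock(nn.Module):
    """Simplified, hypothetical sketch of a ConvNeXt/MedNeXt-style 3D block."""

    def __init__(self, channels: int, expansion_ratio: int = 2, kernel_size: int = 7,
                 use_residual_connection: bool = True) -> None:
        super().__init__()
        self.use_residual_connection = use_residual_connection
        # Depthwise conv: one large-kernel filter per channel; odd kernels keep the size.
        self.depthwise = nn.Conv3d(channels, channels, kernel_size,
                                   padding=kernel_size // 2, groups=channels)
        self.norm = nn.GroupNorm(num_groups=channels, num_channels=channels)
        # Pointwise expansion to expansion_ratio * channels, then compression back.
        self.expand = nn.Conv3d(channels, expansion_ratio * channels, kernel_size=1)
        self.act = nn.GELU()
        self.compress = nn.Conv3d(expansion_ratio * channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.compress(self.act(self.expand(self.norm(self.depthwise(x)))))
        return x + out if self.use_residual_connection else out


block = MedNeXtStyleBlock(channels=16, expansion_ratio=4, kernel_size=3)
y = block(torch.randn(1, 16, 8, 8, 8))
print(y.shape, block.expand.out_channels)
```

Read this way, the encoder, decoder, and bottleneck ratios only change the width of the hidden 1x1 layer, which is what a longer name like `encoder_expansion_ratio` would spell out.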

bottlenec_exp_r: int = 2,
kernel_size: int = 7,
deep_supervision: bool = False,
do_res: bool = False,
Contributor

Maybe use_residual_connections?

Contributor Author

I think use_res would probably suffice, similar to arguments in other APIs such as SegResNet's use_conv_final.

Contributor

Yeah, my only concern is that res could also be read as something like resolution or result.

kernel_size: int = 7,
deep_supervision: bool = False,
do_res: bool = False,
do_res_up_down: bool = False,
Contributor

Maybe use_residual_connections_up_down_blocks?

Contributor Author

This sounds very verbose to me and may not be necessary; I think the abbreviation res is easily understood. Users who experiment with these parameters are encouraged to read the code anyway to see where they are applied.

Contributor

How about use_residual_up_down ?
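A simplified sketch of why residuals in the up/down blocks may deserve their own flag: in a resolution-changing block the identity cannot be added directly because the shapes differ, so the shortcut needs its own strided 1x1 projection, which adds parameters. This is an illustrative downsampling block under that assumption, not the PR's implementation; `DownsampleWithResidual` is hypothetical and `use_residual_up_down` is the name proposed above.

```python
import torch
import torch.nn as nn


class DownsampleWithResidual(nn.Module):
    """Hypothetical sketch: stride-2 block whose optional shortcut is a strided 1x1 conv."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 7,
                 use_residual_up_down: bool = True) -> None:
        super().__init__()
        self.body = nn.Sequential(
            # Strided depthwise conv halves each spatial dimension.
            nn.Conv3d(in_channels, in_channels, kernel_size, stride=2,
                      padding=kernel_size // 2, groups=in_channels),
            nn.GroupNorm(in_channels, in_channels),
            nn.Conv3d(in_channels, out_channels, kernel_size=1),
            nn.GELU(),
        )
        # The identity has the wrong shape here, so the shortcut is projected instead.
        self.shortcut = (
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=2)
            if use_residual_up_down
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.body(x)
        if self.shortcut is not None:
            out = out + self.shortcut(x)
        return out


down = DownsampleWithResidual(8, 16)
print(down(torch.randn(1, 8, 16, 16, 16)).shape)
```

An upsampling counterpart would mirror this with transposed convolutions on both paths.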

monai/networks/nets/mednext.py: resolved
Comment on lines 78 to 84
if dim == "2d":
    self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
    self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
else:
    self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
    self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)

Contributor

I think this is easier to understand, but the current one is fine as well.

Suggested change
-if dim == "2d":
-    self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-    self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1), requires_grad=True)
-else:
-    self.grn_beta = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
-    self.grn_gamma = nn.Parameter(torch.zeros(1, exp_r * in_channels, 1, 1, 1), requires_grad=True)
+if dim == "2d":
+    grn_parameter_shape = (1, 1)
+elif dim == "3d":
+    grn_parameter_shape = (1, 1, 1)
+else:
+    raise ValueError(f"dim must be '2d' or '3d', got {dim!r}")
+grn_parameter_shape = (1, exp_r * in_channels) + grn_parameter_shape
+self.grn_beta = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
+self.grn_gamma = nn.Parameter(torch.zeros(grn_parameter_shape), requires_grad=True)
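The shape logic in the suggestion generalizes directly if the block receives an integer `spatial_dims` instead of the `"2d"`/`"3d"` string. Below is a sketch of a Global Response Normalization layer, as described for ConvNeXt V2, built that way; whether it matches the PR's GRN forward pass line for line is an assumption, and `GlobalResponseNorm` and `spatial_dims` are illustrative names. Because `gamma` and `beta` start at zero, the layer is an identity at initialization.

```python
import torch
import torch.nn as nn


class GlobalResponseNorm(nn.Module):
    """Hypothetical GRN sketch (ConvNeXt V2 style) for 2D or 3D channels-first inputs."""

    def __init__(self, channels: int, spatial_dims: int = 3, eps: float = 1e-6) -> None:
        super().__init__()
        if spatial_dims not in (2, 3):
            raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}")
        self.eps = eps
        self.spatial_axes = tuple(range(2, 2 + spatial_dims))
        # Same idea as the suggested change: (1, C) followed by one singleton per spatial dim.
        shape = (1, channels) + (1,) * spatial_dims
        self.gamma = nn.Parameter(torch.zeros(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2 norm of each channel over all spatial positions: shape (N, C, 1, ..., 1).
        gx = x.pow(2).sum(dim=self.spatial_axes, keepdim=True).sqrt()
        # Divisive normalization of those norms across channels.
        nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


grn = GlobalResponseNorm(8, spatial_dims=2)
x2d = torch.randn(2, 8, 5, 5)
print(grn.gamma.shape, torch.allclose(grn(x2d), x2d))
```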

rcremese commented Aug 14, 2024

Wouldn't it be useful to adapt the factory functions from the create_mednext_v1 script into your mednext.py, so that the basic model configurations can be instantiated and compared against the reference paper?
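One way such factory functions could be organized, sketched with a hypothetical frozen dataclass and a name-to-config registry. The entries below are placeholders for illustration only and are not the paper's S/B/M/L settings; those values would have to be taken from the reference paper or the `create_mednext_v1` script. `MedNeXtConfig`, `VARIANTS`, and `create_config` are hypothetical names.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple


@dataclass(frozen=True)
class MedNeXtConfig:
    """Hypothetical config; field names loosely mirror the constructor args in this PR."""

    init_filters: int = 32
    blocks_down: Tuple[int, ...] = (2, 2, 2, 2)
    blocks_bottleneck: int = 2
    blocks_up: Tuple[int, ...] = (2, 2, 2, 2)
    kernel_size: int = 7
    expansion_ratio: int = 2


# Placeholder registry: these entries are NOT the paper's S/B/M/L settings.
VARIANTS: Dict[str, MedNeXtConfig] = {
    "tiny-demo": MedNeXtConfig(init_filters=8, kernel_size=3),
    "default-demo": MedNeXtConfig(),
}


def create_config(variant: str, **overrides) -> MedNeXtConfig:
    """Look up a named variant and apply keyword overrides (e.g. kernel_size=5)."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; choose from {sorted(VARIANTS)}")
    return replace(VARIANTS[variant], **overrides)


cfg = create_config("tiny-demo", kernel_size=5)
print(cfg)
```

A real factory would pass these fields on to the network constructor. Tuples are used for the block counts so that the defaults stay immutable, and `replace` returns a new config, leaving the registered variants untouched.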

rcremese added a commit to rcremese/MONAI that referenced this pull request Aug 21, 2024
…edNext variants (S, B, M, L) + integration of remarks from @johnzilke (Project-MONAI#8004 (review)) for renaming class arguments - removal of self defined LayerNorm - linked residual connection for encoder and decoder

Signed-off-by: Robin CREMESE <[email protected]>
rcremese mentioned this pull request Aug 21, 2024
surajpaib (Contributor Author)

Thanks for the comments, @johnzielke, and thank you for the updates, @rcremese. I'll update my branch to dev and look through the changes in your PR ASAP.

rcremese added a commit to rcremese/MONAI that referenced this pull request Sep 2, 2024
…edNext variants (S, B, M, L) + integration of remarks from @johnzilke (Project-MONAI#8004 (review)) for renaming class arguments - removal of self defined LayerNorm - linked residual connection for encoder and decoder

Signed-off-by: Robin CREMESE <[email protected]>
rcremese and others added 2 commits September 3, 2024 15:21
…edNext variants (S, B, M, L) + integration of remarks from @johnzilke (Project-MONAI#8004 (review)) for renaming class arguments - removal of self defined LayerNorm - linked residual connection for encoder and decoder

Signed-off-by: Robin CREMESE <[email protected]>
@surajpaib surajpaib marked this pull request as ready for review September 27, 2024 23:34
Signed-off-by: Suraj Pai <[email protected]>
surajpaib (Contributor Author)

@KumoLiu This implementation should be ready now. Please let me know if you have any comments.

monai/networks/blocks/mednext_block.py: resolved
monai/networks/blocks/mednext_block.py (outdated): resolved
monai/networks/blocks/mednext_block.py: resolved
monai/networks/blocks/mednext_block.py: resolved
Signed-off-by: Suraj Pai <[email protected]>
surajpaib (Contributor Author)

@KumoLiu Added.

Do you think there would be interest to add this as a candidate for Auto3DSeg? I refer to this paper for its performance benchmarking: https://arxiv.org/abs/2404.09556

KumoLiu (Contributor) commented Oct 4, 2024

> Do you think there would be interest to add this as a candidate for Auto3DSeg? I refer to this paper for its performance benchmarking: https://arxiv.org/abs/2404.09556

Thank you for bringing this up; it's an interesting suggestion. I believe it could be worthwhile to consider this as a potential candidate for Auto3DSeg. However, before moving forward, I would appreciate hearing others' thoughts on whether this aligns with the current goals and roadmap for Auto3DSeg. cc @mingxin-zheng @dongyang0122 @Nic-Ma @myron

mrcolo commented Oct 10, 2024

Great work, @surajpaib! Any plans for when we'll be able to get this into main?

surajpaib (Contributor Author)

Hi @KumoLiu, any update on merging this?

Maybe we can open a separate issue to discuss Auto3DSeg integration, then.

Development: successfully merging this pull request may close the issue "Add MedNeXt model architectures within MONAI".
5 participants