
On the fly rescaling (GPU) #64

Merged 11 commits into teamtomo:main on May 23, 2024

Conversation

LorenzLamm (Collaborator)

Regarding the discussion in #55, I added the option to perform rescaling of the inference patches on the fly and on GPU.

With this, the user can simply input tomograms at any pixel size; the model performs sliding-window inference and rescales each patch individually to e.g. 10 Å. The output is again a segmentation at the original pixel size.

This adds very little overhead (the difference depends on the tomogram's pixel size, but e.g. 150 s vs. 155 s inference time) and did not result in big changes in segmentation quality.

How it's done internally:

  1. The sliding-window inferer's window size is adjusted such that, after rescaling this sliding window (e.g. to 10 Å), the rescaled window has the target size (default 160)
  2. Rescaling is performed within the model itself (as preprocessing / postprocessing steps). This way, we do not need to touch the SWInferer class, which seemed convenient to me.
    Workflow: rescale patch to 160^3 --> model prediction --> rescale back to original shape
  3. The SWInferer stitches the patches back together in the original dimensions
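The steps above can be sketched end to end. This is a minimal NumPy sketch, not the actual membrain-seg code: `input_window_size` and `fourier_rescale` are illustrative names, the "model" is a placeholder identity function, and the pixel sizes are assumed example values.

```python
import numpy as np

def input_window_size(target_size, model_angstrom, tomo_angstrom):
    # Window size in the input tomogram such that rescaling it to the
    # model's pixel size (e.g. 10 A) yields ~target_size voxels (default 160).
    return int(round(target_size * model_angstrom / tomo_angstrom))

def fourier_rescale(vol, target_shape):
    # Rescale a 3D patch by centered cropping / zero-padding in Fourier space.
    ft = np.fft.fftshift(np.fft.fftn(vol))
    src, dst = np.array(vol.shape), np.array(target_shape)
    out = np.zeros(dst, dtype=complex)
    n = np.minimum(src, dst)                # overlapping centered region
    s0, d0 = (src - n) // 2, (dst - n) // 2
    out[tuple(slice(a, a + k) for a, k in zip(d0, n))] = \
        ft[tuple(slice(a, a + k) for a, k in zip(s0, n))]
    scale = np.prod(dst) / np.prod(src)     # preserve mean intensity
    return np.real(np.fft.ifftn(np.fft.ifftshift(out))) * scale

# Workflow for one sliding-window patch (toy sizes; defaults are 160 and 10 A):
win = input_window_size(32, model_angstrom=10.0, tomo_angstrom=8.0)
patch = np.random.rand(win, win, win).astype(np.float32)
rescaled = fourier_rescale(patch, (32, 32, 32))  # to the model's pixel size
pred = rescaled                                  # model prediction would go here
back = fourier_rescale(pred, patch.shape)        # back to the original shape
```

The stitching step is then left to the unchanged SWInferer, which only ever sees patches in the tomogram's original dimensions.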

@alisterburt This is not using any libtilt functionality yet (I only found Fourier cropping / padding for 2D, but I guess this could easily be extended). I guess the rescaling itself could be done in a more sophisticated way, but I'm not sure if that's necessary for this task?
@rdrighetto

Happy for any feedback :)

@@ -18,6 +18,7 @@ def list_commands(self, ctx: Context):
add_completion=False,
no_args_is_help=True,
rich_markup_mode="rich",
pretty_exceptions_show_locals=False
@LorenzLamm (Collaborator, Author) commented Mar 25, 2024
I don't know how you feel about this, but most of the time I don't like the default printing of local variables for debugging. It's a bit annoying because it prints the entire model weights.

Of course, it also provides more detailed information, but maybe this could be an advanced option?

)


def rescale_tensor(
@LorenzLamm (Collaborator, Author):

This is for re-scaling the prediction scores.
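As a rough illustration of what rescaling prediction scores back to a patch's original shape involves, here is a nearest-neighbour stand-in. This is not the library's `rescale_tensor` (which presumably interpolates on GPU); the function name and shapes below are hypothetical.

```python
import numpy as np

def rescale_scores(scores, target_shape):
    # Resample a 3D score map to target_shape via nearest-neighbour indexing;
    # a simple CPU stand-in for GPU interpolation of prediction scores.
    idx = [np.round(np.linspace(0, s - 1, t)).astype(int)
           for s, t in zip(scores.shape, target_shape)]
    return scores[np.ix_(*idx)]

scores = np.random.rand(16, 16, 16).astype(np.float32)  # model-sized scores
restored = rescale_scores(scores, (20, 20, 20))         # original patch size
```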

@uermel

uermel commented May 9, 2024

Hi @LorenzLamm, great PR, this does improve performance for segmentation in some initial tests I did.

I have one request: it would be nice to set the torch device for the rescaling functions to the model's device. At the moment, fourier_cropping_torch and fourier_extend_torch set the device to an unspecified CUDA device, which causes exceptions when trying to run inference on a specific GPU. I've modified this in our inference wrapper, but it would be good to have this upstream.

It's a simple change (this is in our wrapper):
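The wrapper's actual snippet is not reproduced in this page. A minimal sketch of what a device-aware Fourier-cropping helper could look like follows; only the function name `fourier_cropping_torch` and the idea of passing the model's device come from the discussion above, the implementation itself is an assumption.

```python
import torch

def fourier_cropping_torch(data, target_shape, device=None):
    # Hypothetical device-aware variant: default to the input tensor's device
    # (e.g. the model's) instead of a hard-coded CUDA device.
    device = data.device if device is None else device
    ft = torch.fft.fftshift(torch.fft.fftn(data.to(device)))
    start = [(s - t) // 2 for s, t in zip(data.shape, target_shape)]
    cropped = ft[tuple(slice(a, a + t) for a, t in zip(start, target_shape))]
    scale = float(torch.tensor(target_shape).prod()) / data.numel()
    return torch.real(torch.fft.ifftn(torch.fft.ifftshift(cropped))) * scale
```

In use, the caller would pass e.g. `device=next(model.parameters()).device` so rescaling happens on the same GPU as the model.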

@codecov-commenter commented May 23, 2024

Codecov Report

Attention: Patch coverage is 0% with 92 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (bea5ae0) to head (ef62e08).
Report is 6 commits behind head on main.

Files Patch % Lines
...mbrain_seg/segmentation/networks/inference_unet.py 0.00% 41 Missing ⚠️
..._preprocessing/matching_utils/px_matching_utils.py 0.00% 27 Missing ⚠️
src/membrain_seg/segmentation/segment.py 0.00% 24 Missing ⚠️


Additional details and impacted files
@@          Coverage Diff           @@
##            main     #64    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files         40      46     +6     
  Lines       1411    1631   +220     
======================================
- Misses      1411    1631   +220     


@LorenzLamm (Collaborator, Author)

Hey @uermel ,

Thanks a lot for your feedback on this. Sorry for the late reply -- vacation kept me from working on this :)
I have incorporated your suggestions and the device is now the same as the model's device.

This functionality seems to work -- I think it's ready to merge into main.

@LorenzLamm LorenzLamm merged commit 92ec8ca into teamtomo:main May 23, 2024
5 checks passed