On the fly rescaling (GPU) #64
Conversation
```diff
@@ -18,6 +18,7 @@ def list_commands(self, ctx: Context):
     add_completion=False,
     no_args_is_help=True,
     rich_markup_mode="rich",
+    pretty_exceptions_show_locals=False
```
I don't know how you feel, but most of the time I don't like the default printing of local variables for debugging. It's a bit annoying because it prints the entire model weights.
Of course, it also provides more detailed information, but maybe this could be an advanced option?
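To make locals-printing an opt-in "advanced option" as suggested, one possibility is to gate the flag on an environment variable. This is only a sketch against the Typer call shown in the diff; the `MEMBRAIN_DEBUG` variable name is hypothetical, not part of the PR.

```python
import os

import typer

# Opt in to printing local variables in rich tracebacks. They are useful
# for debugging, but noisy when the locals include full model weights.
# The env-var name MEMBRAIN_DEBUG is a made-up example.
show_locals = os.environ.get("MEMBRAIN_DEBUG", "0") == "1"

cli = typer.Typer(
    add_completion=False,
    no_args_is_help=True,
    rich_markup_mode="rich",
    pretty_exceptions_show_locals=show_locals,
)
```

With this, `MEMBRAIN_DEBUG=1 membrain segment ...` would restore the verbose tracebacks while the default stays quiet.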
```python
)
```

```python
def rescale_tensor(
```
This is for re-scaling the prediction scores
Hi @LorenzLamm, great PR, this does improve performance for segmentation in some initial tests I did. I have one request: it would be nice to set the torch device for the rescaling functions to the model's device. It's a simple change (this is in our wrapper):
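The wrapper snippet itself was lost in the page scrape, but the requested change can be sketched as follows. Function and parameter names here are illustrative, not the PR's actual API; querying `next(model.parameters()).device` is a common PyTorch idiom for finding a model's device.

```python
import torch
import torch.nn.functional as F


def rescale_on_model_device(patch, model, target_shape=(160, 160, 160)):
    """Move a patch to the model's device before rescaling it there."""
    # Common idiom to query the device a model's weights live on.
    device = next(model.parameters()).device
    patch = patch.to(device)
    # F.interpolate with mode="trilinear" expects a 5D (N, C, D, H, W) tensor,
    # so add batch and channel dims, rescale, then strip them again.
    return F.interpolate(
        patch[None, None].float(),
        size=target_shape,
        mode="trilinear",
        align_corners=False,
    )[0, 0]
```

This way the rescaling runs on the GPU whenever the model does, without the caller having to pass a device explicitly.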
Codecov Report

Attention: none of the 220 added lines are covered by tests.

```
@@ Coverage Diff @@
##             main     #64    +/-  ##
======================================
  Coverage    0.00%   0.00%
======================================
  Files          40      46     +6
  Lines        1411    1631   +220
======================================
- Misses       1411    1631   +220
```

View full report in Codecov by Sentry.
Hey @uermel, thanks a lot for your feedback on this. Sorry for the late reply -- vacation kept me from working on this :) This functionality seems to work -- I think it's ready to merge into main.
Regarding the discussion in #55, I added the option to perform rescaling of the inference patches on the fly and on GPU.
With this, the user can simply input tomograms of any pixel size; the model performs sliding window inference and rescales each patch individually to e.g. 10Å. The output is again a segmentation in the original pixel size.
This is really fast (the overhead depends on the tomogram's pixel size, e.g. 150 vs. 155 s inference time) and did not result in big changes in segmentation quality.
How it's done internally:
Workflow: rescale patch to 160^3 --> model prediction --> rescale back to original shape
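The per-patch workflow above can be sketched like this. This is only an illustration of the three steps, assuming interpolation-based rescaling; the function names are hypothetical and the PR's actual implementation may differ (e.g. if Fourier cropping/padding is used instead).

```python
import torch
import torch.nn.functional as F


def predict_patch_rescaled(patch, model, model_shape=(160, 160, 160)):
    """Rescale a patch to the model's input shape, predict, rescale back.

    `patch` is a 3D tensor; the returned scores have `patch`'s shape.
    """
    orig_shape = patch.shape
    # Step 1: rescale the raw patch to the model's expected shape (160^3).
    # Trilinear interpolation needs a 5D (N, C, D, H, W) tensor.
    x = F.interpolate(
        patch[None, None].float(),
        size=model_shape,
        mode="trilinear",
        align_corners=False,
    )
    # Step 2: model prediction on the rescaled patch.
    with torch.no_grad():
        scores = model(x)
    # Step 3: rescale the prediction scores back to the original patch shape.
    return F.interpolate(
        scores,
        size=orig_shape,
        mode="trilinear",
        align_corners=False,
    )[0, 0]
```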
@alisterburt This is not using any libtilt functionality yet (I only found Fourier cropping / padding for 2D, but I guess this could easily be extended). The rescaling itself could probably be done in a more sophisticated way, but I'm not sure that's necessary for this task?
@rdrighetto
Happy for any feedback :)