
Add support for torch.compile #1684

Open · 23 tasks
Fabioomega opened this issue Aug 5, 2024 · 13 comments

@Fabioomega
Contributor

Fabioomega commented Aug 5, 2024

🚀 The feature

torch.compile is a PyTorch feature that JIT-compiles the model into optimized kernels. More information: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
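
A minimal, generic usage sketch (not doctr-specific):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
compiled = torch.compile(model)  # the first call triggers compilation, later calls reuse the generated kernels

x = torch.rand(1, 3, 32, 32)
with torch.inference_mode():
    out = compiled(x)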

Motivation, pitch

I've been using docTR for some of my projects, but some of the more robust models take a long time on old hardware. Using some of those nifty zero-overhead features of PyTorch seemed like a good way to improve performance.

Alternatives

I've also considered PyTorch JIT scripting or tracing, but torch.compile seemed like the easiest approach to implement while still being 'zero-overhead'.

Additional context

I've already implemented a version of doctr that supports the feature: https://github.com/Fabioomega/doctr/tree/pytorch-compile-support. But because of the significant code changes and incomplete fullgraph support for some models, I was unsure whether it was of any interest, so I thought it would be prudent to open an issue first. As of right now I should be finished implementing all that I could for the detection models, the crop and orientation models, and fullgraph support for parseq.

Model tasks

Recognition

  • crnn_vgg16_bn
    • fullgraph=True
  • crnn_mobilenet_v3_
    • fullgraph=True
  • sar_resnet31
    • fullgraph=True
  • master
    • fullgraph=True
  • vitstr_
    • fullgraph=True
  • parseq
    • fullgraph=True

Detection

  • db_
    • fullgraph=True
  • fast_
    • fullgraph=True
  • linknet_
    • fullgraph=True

Classification

  • mobilenet_v3_small_crop_orientation
    • fullgraph=True
  • mobilenet_v3_small_page_orientation
    • fullgraph=True
  • ... all other classification models (backbones)
@Fabioomega added the type: enhancement label Aug 5, 2024
@felixdittrich92
Contributor

Hi @Fabioomega 👋,

Thanks for opening the feature request, that's a really cool idea.

The optimal way would look like:

  1. Ensure that all models can be compiled with torch.compile
  2. Add corresponding unittests, similar to the ONNX ones
  3. Update the documentation (https://github.com/mindee/doctr/blob/main/docs/source/using_doctr/using_model_export.rst) and mark it as a PyTorch-only feature

This way we would avoid adding backend-specific parameters to the interface.
For example:

import torch

from doctr.models import ocr_predictor, db_resnet50, crnn_vgg16_bn

det_model = torch.compile(db_resnet50(pretrained=True))
reco_model = torch.compile(crnn_vgg16_bn(pretrained=True))
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False)

For running on older hardware you could also check out https://github.com/felixdittrich92/OnnxTR, which is optimized for inference and more lightweight.

Wdyt ? :)

PS: I haven't had the time to check your branch, but I remember that I already tried this with PyTorch 2.0 and most models either failed in the postprocessing or were slower than before!?

@Fabioomega
Contributor Author

Hey! Thank you for the reply. I've already tried OnnxTR, but for some reason it actually performed worse, and I'm unsure whether that was because of the detection or the recognition models. Maybe I did something wrong.

Also, as far as my tests go, the first time the model runs, it takes significantly longer because of the compiling step. But after that, it becomes somewhat faster, with a 20-30% improvement on my machine. It did have some problems in the post-processing step with some OpenCV functions, which I fixed by wrapping them into custom operators; I haven’t had an issue since.

For the interface, I didn't change it at all. I only added an environment variable that controls whether to run torch.compile on the models; it disables itself if the PyTorch version is older than 2.4 or there's no Triton support. Though I did need to add some PyTorch-specific functions in a few files, I believe it shouldn't cause any problems because they're guarded.
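
Roughly, the gating I describe looks like the sketch below (a simplified illustration; the variable name DOCTR_TORCH_COMPILE and the exact checks are placeholders, not necessarily what my branch uses):

import os
import torch

def maybe_compile(model):
    # Hypothetical env-var gate controlling whether compilation is attempted at all.
    if os.environ.get("DOCTR_TORCH_COMPILE", "0") != "1":
        return model
    # Require PyTorch >= 2.4 and an importable Triton before compiling, otherwise fall back silently.
    version_ok = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 4)
    try:
        import triton  # noqa: F401
        triton_ok = True
    except ImportError:
        triton_ok = False
    return torch.compile(model) if version_ok and triton_ok else model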

I will write some documentation and add unittests as soon as I have some free time. Also, should I write a unittest for each model? As far as I checked, some models may need specific tweaks to react better to the compilation step, but they should work fine without them.

Thank you for your amazing work and for making it public!

@felixdittrich92
Contributor

Hey @Fabioomega 👋,

  1. Do you remember the version you used? I released a significant bug fix in the last release (https://github.com/felixdittrich92/OnnxTR/releases/tag/v0.3.2).

  2. We need to think about this (best in a first draft PR): doctr supports both TF and PT as backends, and the postprocessors are "shared" between them, so wrapping something backend-specific in the postprocessors sounds tricky (in short, we can't use any backend-specific conditions there). We need to think carefully about this and maybe about alternatives.

  3. The optimal case would be to let the user do the torch.compile call on their own and describe it well in the documentation. We can't pin PyTorch to >=2.4 yet, so this would require an additional version check beforehand -> better to keep it on the user side.

  4. Yep, we should test it like the mentioned ONNX tests (check that we can compile each model and that the outputs of the base model and the compiled model are the same or close, with a maximal tolerance of 1e-4 if required).
    Best would be to open a draft PR we can iterate on :)

CC @frgfm @odulcy-mindee

@Fabioomega
Contributor Author

I'm not sure which version I used, but I believe it was before 0.3.2; it's been a couple of months, so I'm definitely trying the new one!

Also, I believe I've already handled the issue with the shared postprocessors, because I added a cv2 fallback for when the new wrapped cv2 functions are not in use. But again, the way I structured it may be kind of annoying.

Either way, I'm going to open the PR. Thank you very much for letting me contribute to such an amazing project!

@felixdittrich92
Contributor

👍 Happy to get your feedback.

Have you tested how big the impact is if we compile only the model itself, without the postprocessing step? That would make it much easier to integrate.

Happy about every new contributor 🤗

@Fabioomega
Contributor Author

Fabioomega commented Aug 9, 2024

I haven't tested that. In fact, I'm unsure whether it is even possible. I may be mistaken, but I think the postprocessing is called in the forward function of the model, which makes torch.compile try to compile it as well. It would be easier to integrate if it were not part of the forward function, though I imagine that would require some annoying refactoring that I'm not sure is worth the effort.

Do you have any idea how I could test that? The only way I see to force torch.compile not to compile those functions would be to wrap them in a custom operator, which, as you said before, is not ideal.

One solution that may be better is to do something similar to the .cuda() call the lib currently has: internally it could replace some functions, or even the postprocessor, with a compatible version. But that could be hard to maintain.

Anyway, based on the performance metrics I gathered while wrapping the cv2 functions into custom ops, I think compiling the postprocessing step shouldn't affect anything, or the improvement is negligible. Mostly I did the wrapping because some functions were causing crashes on my machine while being compiled. I should probably investigate that further, shouldn't I?
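
For reference, the kind of wrapping I mean looks roughly like the sketch below (a minimal illustration, not the exact code from my branch; the op name and helper are made up, and torch.library.custom_op requires PyTorch >= 2.4):

import cv2
import numpy as np
import torch

# Wrapping a cv2 call as a custom op makes torch.compile treat it as an opaque
# kernel instead of trying to trace into OpenCV (which would break compilation).
@torch.library.custom_op("doctr_demo::cv2_dilate", mutates_args=())
def cv2_dilate(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    out = cv2.dilate(x.detach().cpu().numpy(), kernel, iterations=1)
    return torch.from_numpy(out).to(x.device)

# Shape/dtype propagation so the op also works under FakeTensor tracing.
@cv2_dilate.register_fake
def _(x, kernel_size):
    return torch.empty_like(x)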

Either way, I will try to remove the custom operators and see if there's a significant performance loss, or if I can narrow down the crashing function. Who knows, maybe I was just hallucinating. Also, I'll try to think of solutions that require fewer changes or keep things on the user side, as you said before.

@Fabioomega
Contributor Author

Fabioomega commented Aug 13, 2024

I've just confirmed: it seems that not compiling the postprocessor is not possible with the current architecture, at least to my knowledge. It also seems I was wrong when I said it couldn't compile without any changes to the postprocessor; I've just tested it, and while there's a little performance degradation, it can be used regardless.

Also, the performance gain is almost insignificant on my machine with assume_straight_pages=True, for some reason, while assume_straight_pages=False gets around the number I mentioned before (a 20-30% improvement), though results may vary.

@felixdittrich92
Contributor

felixdittrich92 commented Aug 14, 2024

Hey @Fabioomega 👋,

Excuse the late response.
So if I understand your last message correctly, we can compile the models without needing to wrap the postprocessing functions?
That would definitely be the preferred way :)

In this case the scope of the PR would be more about writing test cases and a well-defined documentation section, plus maybe small model-related changes.

Pseudo-code example test case (same for the classification and detection models):

import numpy as np
import psutil
import pytest
import torch

from doctr.models import recognition

# Available system RAM in GB, used to skip the memory-hungry models.
system_available_memory = int(psutil.virtual_memory().available / 1024**3)


@pytest.mark.parametrize("fullgraph", [True, False])
@pytest.mark.parametrize(
    "arch_name, input_shape",
    [
        ["crnn_vgg16_bn", (3, 32, 128)],
        ["crnn_mobilenet_v3_small", (3, 32, 128)],
        ["crnn_mobilenet_v3_large", (3, 32, 128)],
        pytest.param(
            "sar_resnet31",
            (3, 32, 128),
            marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory"),
        ),
        pytest.param(
            "master", (3, 32, 128), marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory")
        ),
        ["vitstr_small", (3, 32, 128)],  # testing one vitstr version is enough
        ["parseq", (3, 32, 128)],
    ],
)
def test_models_torch_compiled(arch_name, input_shape, fullgraph):
    # General check that the model can be compiled
    assert torch.compile(recognition.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph)
    # TODO: Should we check the model outputs including the post proc step or are the logits enough?
    # Model
    batch_size = 2
    model = recognition.__dict__[arch_name](pretrained=True, exportable=True).eval()
    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
    pt_logits = model(dummy_input)["logits"].detach().cpu().numpy()

    compiled_model = torch.compile(model, fullgraph=fullgraph)
    pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy()

    assert pt_logits_compiled.shape == pt_logits.shape
    # Check that the output is close to the "original" output
    assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4)

@felixdittrich92
Contributor

> I've just confirmed: it seems that not compiling the postprocessor is not possible with the current architecture, at least to my knowledge. It also seems I was wrong when I said it couldn't compile without any changes to the postprocessor; I've just tested it, and while there's a little performance degradation, it can be used regardless.
>
> Also, the performance gain is almost insignificant on my machine with assume_straight_pages=True, for some reason, while assume_straight_pages=False gets around the number I mentioned before (a 20-30% improvement), though results may vary.

Yeah, with assume_straight_pages=False there is much more computation under the hood (performed on polygons instead of boxes); additionally there are two more models for crop orientation and page orientation classification, so the difference you see makes sense :)

@felixdittrich92
Contributor

felixdittrich92 commented Aug 14, 2024

All recognition models seem to work out of the box with torch.compile without fullgraph (except master);
with fullgraph=True only vitstr_ and parseq seem to work.

So i would suggest the following for the current PR:

  • Implement the tests for recognition / detection / classification (skip models that fail during compilation), the same way it's done for the ONNX tests -> that way we only get a warning and the tests don't fail hard
  • Add the documentation section (a table would be good, showing which models can be compiled + fullgraph support)
  • We can pin the lower bound for PyTorch to 2.0 in the pyproject.toml; this should be fine

Once that's done we can take care of the failing models and fix them one by one (so 1 fix == 1 PR).

That way we have a clear structure and can keep track of the progress. Wdyt?

@Fabioomega
Contributor Author

Hey! Sorry for the delay; I’ve been busy with university stuff!

Your idea sounds good! I should be able to run the tests this week, hopefully. Should I add a new commit to remove the changes in the post-processing, or start a new branch and create a new PR?

I also like the idea of separating the models into different PRs. Are there any failing models right now? Or by “taking care of the failing models” did you mean enabling fullgraph support?

@felixdittrich92
Contributor

Hey @Fabioomega 🤗,

No stress :)

Correct, I would say let's update your current PR 👍

I have only quickly tested the recognition models:

compile with fullgraph=False: all models work except master
compile with fullgraph=True: only the vitstr_ and parseq models work

For the documentation we could provide a table in the section for the first iteration, for example:

[Screenshot: example documentation table listing each model with columns for torch.compile and fullgraph support]
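
A rough sketch of what such a table could contain, filled in only with the recognition results mentioned above (everything else still to be verified):

Model            | torch.compile | fullgraph=True
-----------------|---------------|---------------
crnn_vgg16_bn    | yes           | no
crnn_mobilenet_* | yes           | no
sar_resnet31     | yes           | no
master           | no            | no
vitstr_*         | yes           | yes
parseq           | yes           | yes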

So in the first PR we provide the tests for all models and skip the failing ones (they should run but only raise a warning, same as we have for the ONNX tests), plus the documentation section with an example usage and an overview of which models can be compiled.

In follow-up PRs we pick a model which does not work, fix it (if possible), update the table -> done -> next model :)

@felixdittrich92
Contributor

felixdittrich92 commented Aug 19, 2024

PS: The last column should be fullgraph=True 😅

PS 2: I updated the issue to keep track of the progress.

@felixdittrich92 added this to the 2.0.0 milestone Oct 10, 2024