Skip to content

Commit

Permalink
[demo] remove limitation and update demo (mindee#1390)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Nov 24, 2023
1 parent 8e0609d commit 3b99d46
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 108 deletions.
18 changes: 15 additions & 3 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,26 @@ def main(det_archs, reco_archs):

# For newline
st.sidebar.write("\n")
# Only straight pages or possible rotation
st.sidebar.title("Parameters")
assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
st.sidebar.write("\n")
# Straighten pages
straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
st.sidebar.write("\n")
# Binarization threshold
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
st.sidebar.write("\n")

if st.sidebar.button("Analyze page"):
if uploaded_file is None:
st.sidebar.write("Please upload a document")

else:
with st.spinner("Loading model..."):
predictor = load_predictor(det_arch, reco_arch, forward_device)
predictor = load_predictor(
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
)

with st.spinner("Analyzing..."):
# Forward the image to the model
Expand All @@ -93,12 +105,12 @@ def main(det_archs, reco_archs):

# Plot OCR output
out = predictor([page])
fig = visualize_page(out.pages[0].export(), page, interactive=False)
fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=False)
cols[2].pyplot(fig)

# Page reconsitution under input page
page_export = out.pages[0].export()
if "rotation" not in det_arch:
if assume_straight_pages or (not assume_straight_pages and straighten_pages):
img = out.pages[0].synthesize()
cols[3].image(img, clamp=True)

Expand Down
22 changes: 19 additions & 3 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"db_resnet50",
"db_resnet34",
"db_mobilenet_v3_large",
"db_resnet50_rotation",
"linknet_resnet18",
"linknet_resnet34",
"linknet_resnet50",
Expand All @@ -30,22 +29,39 @@
]


def load_predictor(det_arch: str, reco_arch: str, device: torch.device) -> OCRPredictor:
def load_predictor(
det_arch: str,
reco_arch: str,
assume_straight_pages: bool,
straighten_pages: bool,
bin_thresh: float,
device: torch.device,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Args:
----
det_arch: detection architecture
reco_arch: recognition architecture
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
bin_thresh: binarization threshold for the segmentation map
device: torch.device, the device to load the predictor on
Returns:
-------
instance of OCRPredictor
"""
predictor = ocr_predictor(
det_arch, reco_arch, pretrained=True, assume_straight_pages=("rotation" not in det_arch)
det_arch,
reco_arch,
pretrained=True,
assume_straight_pages=assume_straight_pages,
straighten_pages=straighten_pages,
export_as_straight_boxes=straighten_pages,
detect_orientation=not assume_straight_pages,
).to(device)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
return predictor


Expand Down
22 changes: 19 additions & 3 deletions demo/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"db_resnet50",
"db_mobilenet_v3_large",
"linknet_resnet18",
"linknet_resnet18_rotation",
"linknet_resnet34",
"linknet_resnet50",
]
Expand All @@ -29,13 +28,23 @@
]


def load_predictor(det_arch: str, reco_arch: str, device: tf.device) -> OCRPredictor:
def load_predictor(
det_arch: str,
reco_arch: str,
assume_straight_pages: bool,
straighten_pages: bool,
bin_thresh: float,
device: tf.device,
) -> OCRPredictor:
"""Load a predictor from doctr.models
Args:
----
det_arch: detection architecture
reco_arch: recognition architecture
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
bin_thresh: binarization threshold for the segmentation map
device: tf.device, the device to load the predictor on
Returns:
Expand All @@ -44,8 +53,15 @@ def load_predictor(det_arch: str, reco_arch: str, device: tf.device) -> OCRPredi
"""
with device:
predictor = ocr_predictor(
det_arch, reco_arch, pretrained=True, assume_straight_pages=("rotation" not in det_arch)
det_arch,
reco_arch,
pretrained=True,
assume_straight_pages=assume_straight_pages,
straighten_pages=straighten_pages,
export_as_straight_boxes=straighten_pages,
detect_orientation=not assume_straight_pages,
)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
return predictor


Expand Down
4 changes: 0 additions & 4 deletions docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,12 @@ doctr.models.detection

.. autofunction:: doctr.models.detection.linknet_resnet18

.. autofunction:: doctr.models.detection.linknet_resnet18_rotation

.. autofunction:: doctr.models.detection.linknet_resnet34

.. autofunction:: doctr.models.detection.linknet_resnet50

.. autofunction:: doctr.models.detection.db_resnet50

.. autofunction:: doctr.models.detection.differentiable_binarization.pytorch.db_resnet50_rotation

.. autofunction:: doctr.models.detection.db_mobilenet_v3_large

.. autofunction:: doctr.models.detection.detection_predictor
Expand Down
44 changes: 1 addition & 43 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ...utils import _bf16_to_float32, load_pretrained_params
from .base import DBPostProcessor, _DBNet

__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "db_resnet50_rotation"]
__all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]


default_cfgs: Dict[str, Dict[str, Any]] = {
Expand All @@ -41,12 +41,6 @@
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_mobilenet_v3_large-fd62154b.pt&src=0",
},
"db_resnet50_rotation": {
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/db_resnet50-1138863a.pt&src=0",
},
}


Expand Down Expand Up @@ -431,39 +425,3 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
],
**kwargs,
)


def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
This model is trained with rotated documents
>>> import torch
>>> from doctr.models import db_resnet50_rotation
>>> model = db_resnet50_rotation(pretrained=True)
>>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
----
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet(
"db_resnet50_rotation",
pretrained,
resnet50,
["layer1", "layer2", "layer3", "layer4"],
None,
ignore_keys=[
"prob_head.6.weight",
"prob_head.6.bias",
"thresh_head.6.weight",
"thresh_head.6.bias",
],
**kwargs,
)
36 changes: 1 addition & 35 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,10 @@

from .base import LinkNetPostProcessor, _LinkNet

__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", "linknet_resnet18_rotation"]
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]

default_cfgs: Dict[str, Dict[str, Any]] = {
"linknet_resnet18": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet18-611b50f2.zip&src=0",
},
"linknet_resnet18_rotation": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
Expand Down Expand Up @@ -313,34 +307,6 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
)


def linknet_resnet18_rotation(pretrained: bool = False, **kwargs: Any) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
<https://arxiv.org/pdf/1707.03718.pdf>`_.
>>> import tensorflow as tf
>>> from doctr.models import linknet_resnet18_rotation
>>> model = linknet_resnet18_rotation(pretrained=True)
>>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Args:
----
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
**kwargs: keyword arguments of the LinkNet architecture
Returns:
-------
text detection architecture
"""
return _linknet(
"linknet_resnet18_rotation",
pretrained,
resnet18,
["resnet_block_1", "resnet_block_3", "resnet_block_5", "resnet_block_7"],
**kwargs,
)


def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
<https://arxiv.org/pdf/1707.03718.pdf>`_.
Expand Down
13 changes: 1 addition & 12 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
__all__ = ["detection_predictor"]

ARCHS: List[str]
ROT_ARCHS: List[str]


if is_tf_available():
ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
ROT_ARCHS = ["linknet_resnet18_rotation"]
elif is_torch_available():
ARCHS = [
"db_resnet34",
Expand All @@ -29,22 +27,13 @@
"linknet_resnet34",
"linknet_resnet50",
]
ROT_ARCHS = ["db_resnet50_rotation"]


def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
if isinstance(arch, str):
if arch not in ARCHS + ROT_ARCHS:
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

if arch not in ROT_ARCHS and not assume_straight_pages:
raise AssertionError(
"You are trying to use a model trained on straight pages while not assuming"
" your pages are straight. If you have only straight documents, don't pass"
" assume_straight_pages=False, otherwise you should use one of these archs:"
f"{ROT_ARCHS}"
)

_model = detection.__dict__[arch](
pretrained=pretrained,
pretrained_backbone=kwargs.get("pretrained_backbone", True),
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/factory/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

AVAILABLE_ARCHS = {
"classification": models.classification.zoo.ARCHS,
"detection": models.detection.zoo.ARCHS + models.detection.zoo.ROT_ARCHS,
"detection": models.detection.zoo.ARCHS,
"recognition": models.recognition.zoo.ARCHS,
"obj_detection": ["fasterrcnn_mobilenet_v3_large_fpn"] if is_torch_available() else None,
}
Expand Down
5 changes: 4 additions & 1 deletion doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def forward(
loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seg_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
seg_maps = [
np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
for out_map in out_maps
]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
orientations = [
Expand Down
5 changes: 4 additions & 1 deletion doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def __call__(
loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seg_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
seg_maps = [
np.where(out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"), 255, 0).astype(np.uint8)
for out_map in out_maps
]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
orientations = [
Expand Down
1 change: 0 additions & 1 deletion tests/pytorch/test_models_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test_push_to_hf_hub():
["db_resnet34", "detection", "Felix92/doctr-dummy-torch-db-resnet34"],
["db_resnet50", "detection", "Felix92/doctr-dummy-torch-db-resnet50"],
["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-torch-db-mobilenet-v3-large"],
["db_resnet50_rotation", "detection", "Felix92/doctr-dummy-torch-db-resnet50-rotation"],
["linknet_resnet18", "detection", "Felix92/doctr-dummy-torch-linknet-resnet18"],
["linknet_resnet34", "detection", "Felix92/doctr-dummy-torch-linknet-resnet34"],
["linknet_resnet50", "detection", "Felix92/doctr-dummy-torch-linknet-resnet50"],
Expand Down
1 change: 0 additions & 1 deletion tests/tensorflow/test_models_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test_push_to_hf_hub():
["db_resnet50", "detection", "Felix92/doctr-dummy-tf-db-resnet50"],
["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-tf-db-mobilenet-v3-large"],
["linknet_resnet18", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18"],
["linknet_resnet18_rotation", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18-rotation"],
["linknet_resnet34", "detection", "Felix92/doctr-dummy-tf-linknet-resnet34"],
["linknet_resnet50", "detection", "Felix92/doctr-dummy-tf-linknet-resnet50"],
["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-tf-crnn-vgg16-bn"],
Expand Down

0 comments on commit 3b99d46

Please sign in to comment.