Skip to content

Commit

Permalink
[orientation] Allow disable of page and crop orientation (#1735)
Browse files Browse the repository at this point in the history
  • Loading branch information
milosacimovic authored Sep 27, 2024
1 parent 9045dcf commit 420ab32
Show file tree
Hide file tree
Showing 20 changed files with 328 additions and 74 deletions.
2 changes: 2 additions & 0 deletions api/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class KIEIn(BaseModel):
straighten_pages: bool = Field(default=False, examples=[False])
det_bs: int = Field(default=2, examples=[2])
reco_bs: int = Field(default=128, examples=[128])
disable_page_orientation: bool = Field(default=False, examples=[False])
disable_crop_orientation: bool = Field(default=False, examples=[False])
bin_thresh: float = Field(default=0.1, examples=[0.1])
box_thresh: float = Field(default=0.1, examples=[0.1])

Expand Down
16 changes: 15 additions & 1 deletion demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def main(det_archs, reco_archs):
st.sidebar.title("Parameters")
assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
st.sidebar.write("\n")
# Disable page orientation detection
disable_page_orientation = st.sidebar.checkbox("Disable page orientation detection", value=False)
st.sidebar.write("\n")
# Disable crop orientation detection
disable_crop_orientation = st.sidebar.checkbox("Disable crop orientation detection", value=False)
st.sidebar.write("\n")
# Straighten pages
straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
st.sidebar.write("\n")
Expand All @@ -89,7 +95,15 @@ def main(det_archs, reco_archs):
else:
with st.spinner("Loading model..."):
predictor = load_predictor(
det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device
det_arch,
reco_arch,
assume_straight_pages,
straighten_pages,
disable_page_orientation,
disable_crop_orientation,
bin_thresh,
box_thresh,
forward_device,
)

with st.spinner("Analyzing..."):
Expand Down
6 changes: 6 additions & 0 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def load_predictor(
reco_arch: str,
assume_straight_pages: bool,
straighten_pages: bool,
disable_page_orientation: bool,
disable_crop_orientation: bool,
bin_thresh: float,
box_thresh: float,
device: torch.device,
Expand All @@ -49,6 +51,8 @@ def load_predictor(
reco_arch: recognition architecture
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
disable_page_orientation: whether to disable page orientation or not
disable_crop_orientation: whether to disable crop orientation or not
bin_thresh: binarization threshold for the segmentation map
box_thresh: minimal objectness score to consider a box
device: torch.device, the device to load the predictor on
Expand All @@ -65,6 +69,8 @@ def load_predictor(
straighten_pages=straighten_pages,
export_as_straight_boxes=straighten_pages,
detect_orientation=not assume_straight_pages,
disable_page_orientation=disable_page_orientation,
disable_crop_orientation=disable_crop_orientation,
).to(device)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
Expand Down
6 changes: 6 additions & 0 deletions demo/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def load_predictor(
reco_arch: str,
assume_straight_pages: bool,
straighten_pages: bool,
disable_page_orientation: bool,
disable_crop_orientation: bool,
bin_thresh: float,
box_thresh: float,
device: tf.device,
Expand All @@ -48,6 +50,8 @@ def load_predictor(
reco_arch: recognition architecture
assume_straight_pages: whether to assume straight pages or not
straighten_pages: whether to straighten rotated pages or not
disable_page_orientation: whether to disable page orientation or not
disable_crop_orientation: whether to disable crop orientation or not
bin_thresh: binarization threshold for the segmentation map
box_thresh: threshold for the detection boxes
device: tf.device, the device to load the predictor on
Expand All @@ -65,6 +69,8 @@ def load_predictor(
straighten_pages=straighten_pages,
export_as_straight_boxes=straighten_pages,
detect_orientation=not assume_straight_pages,
disable_page_orientation=disable_page_orientation,
disable_crop_orientation=disable_crop_orientation,
)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
Expand Down
34 changes: 31 additions & 3 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,16 @@ Those architectures involve one stage of text detection, and one stage of text r
You can pass specific boolean arguments to the predictor:

* `assume_straight_pages`
* `preserve_aspect_ratio`
* `symmetric_pad`
* `assume_straight_pages`: if you work with straight documents only, it will fit straight bounding boxes to the text areas.
* `preserve_aspect_ratio`: if you want to preserve the aspect ratio of your documents while resizing before sending them to the model.
* `symmetric_pad`: if you choose to preserve the aspect ratio, it will pad the image symmetrically and not from the bottom-right.

Those 3 are going straight to the detection predictor, as mentioned above (in the detection part).

Additional arguments which can be passed to the `ocr_predictor` are:

* `export_as_straight_boxes`: If you work with rotated and skewed documents but you still want to export straight bounding boxes and not polygons, set it to True.
* `straighten_pages`: If you want to straighten the pages before sending them to the detection model, set it to True.

For instance, this snippet instantiates an end-to-end ocr_predictor working with rotated documents, which preserves the aspect ratio of the documents, and returns polygons:

Expand All @@ -298,6 +301,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with
from doctr.model import ocr_predictor
model = ocr_predictor('linknet_resnet18', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True)
Additionally, you can change the batch size of the underlying detection and recognition predictors to optimize the performance depending on your hardware:

* `det_bs`: batch size for the detection model (default: 2)
Expand Down Expand Up @@ -465,6 +469,30 @@ This is useful to detect (possible less) text regions more accurately with a hig
out = predictor([input_page])
* Disable page orientation classification

If you deal with documents which contains only small rotations (~ -45 to 45 degrees), you can disable the page orientation classification to speed up the inference.

This will only have an effect with `assume_straight_pages=False` and/or `straighten_pages=True` and/or `detect_orientation=True`.

.. code:: python3
from doctr.model import ocr_predictor
model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_page_orientation=True)
* Disable crop orientation classification

If you deal with documents which contains only horizontal text, you can disable the crop orientation classification to speed up the inference.

This will only have an effect with `assume_straight_pages=False` and/or `straighten_pages=True`.

.. code:: python3
from doctr.model import ocr_predictor
model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_crop_orientation=True)
* Add a hook to the `ocr_predictor` to manipulate the location predictions before the crops are passed to the recognition model.

.. code:: python3
Expand Down
18 changes: 11 additions & 7 deletions doctr/models/classification/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Union
from typing import List, Optional, Union

import numpy as np
import torch
Expand All @@ -27,12 +27,12 @@ class OrientationPredictor(nn.Module):

def __init__(
self,
pre_processor: PreProcessor,
model: nn.Module,
pre_processor: Optional[PreProcessor],
model: Optional[nn.Module],
) -> None:
super().__init__()
self.pre_processor = pre_processor
self.model = model.eval()
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
self.model = model.eval() if isinstance(model, nn.Module) else None

@torch.inference_mode()
def forward(
Expand All @@ -43,12 +43,16 @@ def forward(
if any(input.ndim != 3 for input in inputs):
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

if self.model is None or self.pre_processor is None:
# predictor is disabled
return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]

processed_batches = self.pre_processor(inputs)
_params = next(self.model.parameters())
self.model, processed_batches = set_device_and_dtype(
self.model, processed_batches, _params.device, _params.dtype
)
predicted_batches = [self.model(batch) for batch in processed_batches]
predicted_batches = [self.model(batch) for batch in processed_batches] # type: ignore[misc]
# confidence
probs = [
torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches
Expand All @@ -57,7 +61,7 @@ def forward(
predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] # type: ignore[union-attr]
confs = [round(float(p), 2) for prob in probs for p in prob]

return [class_idxs, classes, confs]
14 changes: 9 additions & 5 deletions doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Union
from typing import List, Optional, Union

import numpy as np
import tensorflow as tf
Expand All @@ -29,11 +29,11 @@ class OrientationPredictor(NestedObject):

def __init__(
self,
pre_processor: PreProcessor,
model: keras.Model,
pre_processor: Optional[PreProcessor],
model: Optional[keras.Model],
) -> None:
self.pre_processor = pre_processor
self.model = model
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
self.model = model if isinstance(model, keras.Model) else None

def __call__(
self,
Expand All @@ -43,6 +43,10 @@ def __call__(
if any(input.ndim != 3 for input in inputs):
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

if self.model is None or self.pre_processor is None:
# predictor is disabled
return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)]

processed_batches = self.pre_processor(inputs)
predicted_batches = [self.model(batch, training=False) for batch in processed_batches]

Expand Down
8 changes: 7 additions & 1 deletion doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


def _orientation_predictor(arch: Any, pretrained: bool, model_type: str, **kwargs: Any) -> OrientationPredictor:
def _orientation_predictor(
arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
) -> OrientationPredictor:
if disabled:
# Case where the orientation predictor is disabled
return OrientationPredictor(None, None)

if isinstance(arch, str):
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/kie_predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,8 @@ def __init__(
assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
)

# Remove the following arguments from kwargs after initialization of the parent class
kwargs.pop("disable_page_orientation", None)
kwargs.pop("disable_crop_orientation", None)

self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
1 change: 1 addition & 0 deletions doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def forward(
dict_loc_preds[class_name],
channels_last=channels_last,
assume_straight_pages=self.assume_straight_pages,
assume_horizontal=self._page_orientation_disabled,
)
# Rectify crop orientation
crop_orientations: Any = {}
Expand Down
6 changes: 5 additions & 1 deletion doctr/models/kie_predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def __call__(
crops = {}
for class_name in dict_loc_preds.keys():
crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
pages,
dict_loc_preds[class_name],
channels_last=True,
assume_straight_pages=self.assume_straight_pages,
assume_horizontal=self._page_orientation_disabled,
)

# Rectify crop orientation
Expand Down
30 changes: 21 additions & 9 deletions doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,15 @@ def __init__(
) -> None:
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
self.crop_orientation_predictor = (
None
if assume_straight_pages
else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
)
self.page_orientation_predictor = (
page_orientation_predictor(pretrained=True)
page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
if detect_orientation or straighten_pages or not assume_straight_pages
else None
)
Expand Down Expand Up @@ -112,13 +118,18 @@ def _generate_crops(
loc_preds: List[np.ndarray],
channels_last: bool,
assume_straight_pages: bool = False,
assume_horizontal: bool = False,
) -> List[List[np.ndarray]]:
extraction_fn = extract_crops if assume_straight_pages else extract_rcrops

crops = [
extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator]
for page, _boxes in zip(pages, loc_preds)
]
if assume_straight_pages:
crops = [
extract_crops(page, _boxes[:, :4], channels_last=channels_last)
for page, _boxes in zip(pages, loc_preds)
]
else:
crops = [
extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
for page, _boxes in zip(pages, loc_preds)
]
return crops

@staticmethod
Expand All @@ -127,8 +138,9 @@ def _prepare_crops(
loc_preds: List[np.ndarray],
channels_last: bool,
assume_straight_pages: bool = False,
assume_horizontal: bool = False,
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)

# Avoid sending zero-sized crops
is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
Expand Down
1 change: 1 addition & 0 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def forward(
loc_preds,
channels_last=channels_last,
assume_straight_pages=self.assume_straight_pages,
assume_horizontal=self._page_orientation_disabled,
)
# Rectify crop orientation and get crop orientation predictions
crop_orientations: Any = []
Expand Down
6 changes: 5 additions & 1 deletion doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ def __call__(

# Crop images
crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
pages,
loc_preds,
channels_last=True,
assume_straight_pages=self.assume_straight_pages,
assume_horizontal=self._page_orientation_disabled,
)
# Rectify crop orientation and get crop orientation predictions
crop_orientations: Any = []
Expand Down
Loading

0 comments on commit 420ab32

Please sign in to comment.