diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py index 53d807898e..67d6afbf87 100644 --- a/doctr/models/kie_predictor/base.py +++ b/doctr/models/kie_predictor/base.py @@ -36,6 +36,7 @@ class _KIEPredictor(_OCRPredictor): def __init__( self, assume_straight_pages: bool = True, + assume_straight_text: bool = False, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, @@ -43,7 +44,13 @@ def __init__( **kwargs: Any, ) -> None: super().__init__( - assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs + assume_straight_pages, + assume_straight_text, + straighten_pages, + preserve_aspect_ratio, + symmetric_pad, + detect_orientation, + **kwargs, ) self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs) diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 4bcedc7064..c0c50b4a2b 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -29,6 +29,8 @@ class KIEPredictor(nn.Module, _KIEPredictor): reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. + assume_straight_text: if True, speeds up the inference by assuming you only pass straight text + without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped accordingly. Doing so will improve performances for documents with page-uniform rotations. @@ -44,6 +46,7 @@ def __init__( det_predictor: DetectionPredictor, reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, + assume_straight_text: bool = False, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, @@ -57,6 +60,7 @@ def __init__( _KIEPredictor.__init__( self, assume_straight_pages, + assume_straight_text, straighten_pages, preserve_aspect_ratio, symmetric_pad, @@ -129,10 +133,11 @@ def forward( dict_loc_preds[class_name], channels_last=channels_last, assume_straight_pages=self.assume_straight_pages, + assume_straight_text=self.assume_straight_text, ) # Rectify crop orientation crop_orientations: Any = {} - if not self.assume_straight_pages: + if not self.assume_straight_pages and not self.assume_straight_text: for class_name in dict_loc_preds.keys(): crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops( crops[class_name], dict_loc_preds[class_name] diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index d9d765bbe6..7a39e11e82 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -29,6 +29,8 @@ class KIEPredictor(NestedObject, _KIEPredictor): reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. + assume_straight_text: if True, speeds up the inference by assuming you only pass straight text + without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped accordingly. Doing so will improve performances for documents with page-uniform rotations. @@ -46,6 +48,7 @@ def __init__( det_predictor: DetectionPredictor, reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, + assume_straight_text: bool = False, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, @@ -58,6 +61,7 @@ def __init__( _KIEPredictor.__init__( self, assume_straight_pages, + assume_straight_text, straighten_pages, preserve_aspect_ratio, symmetric_pad, @@ -122,12 +126,16 @@ def __call__( crops = {} for class_name in dict_loc_preds.keys(): crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( - pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages + pages, + dict_loc_preds[class_name], + channels_last=True, + assume_straight_pages=self.assume_straight_pages, + assume_straight_text=self.assume_straight_text, ) # Rectify crop orientation crop_orientations: Any = {} - if not self.assume_straight_pages: + if not self.assume_straight_pages and not self.assume_straight_text: for class_name in dict_loc_preds.keys(): crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops( crops[class_name], dict_loc_preds[class_name] diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 0469b32ea3..a804de272e 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -8,7 +8,7 @@ import numpy as np from doctr.models.builder import DocumentBuilder -from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image +from doctr.utils.geometry import extract_crops, extract_dewarped_crops, extract_rcrops, rotate_image from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds from ..classification import crop_orientation_predictor, page_orientation_predictor @@ -24,6 +24,8 @@ class _OCRPredictor: ---- assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. + assume_straight_text: if True, speeds up the inference by assuming you only pass straight text + without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped accordingly. Doing so will improve performances for documents with page-uniform rotations. @@ -40,6 +42,7 @@ class _OCRPredictor: def __init__( self, assume_straight_pages: bool = True, + assume_straight_text: bool = False, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, @@ -47,8 +50,13 @@ def __init__( **kwargs: Any, ) -> None: self.assume_straight_pages = assume_straight_pages + self.assume_straight_text = assume_straight_text self.straighten_pages = straighten_pages - self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True) + self.crop_orientation_predictor = ( + None + if assume_straight_pages or (not assume_straight_pages and assume_straight_text) + else crop_orientation_predictor(pretrained=True) + ) self.page_orientation_predictor = ( page_orientation_predictor(pretrained=True) if detect_orientation or straighten_pages or not assume_straight_pages @@ -112,8 +120,15 @@ def _generate_crops( loc_preds: List[np.ndarray], channels_last: bool, assume_straight_pages: bool = False, + assume_straight_text: bool = False, ) -> List[List[np.ndarray]]: - extraction_fn = extract_crops if assume_straight_pages else extract_rcrops + if assume_straight_pages: + extraction_fn = extract_crops + else: + if assume_straight_text: + extraction_fn = extract_dewarped_crops + else: + extraction_fn = extract_rcrops crops = [ extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] @@ -127,8 +142,11 @@ def _prepare_crops( loc_preds: List[np.ndarray], channels_last: bool, assume_straight_pages: bool = False, + assume_straight_text: bool = False, ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: - crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) + crops = _OCRPredictor._generate_crops( + pages, loc_preds, channels_last, assume_straight_pages, assume_straight_text + ) # Avoid sending zero-sized crops is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 7cbf383a06..c74ea70e41 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -29,6 +29,8 @@ class OCRPredictor(nn.Module, _OCRPredictor): reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. + assume_straight_text: if True, speeds up the inference by assuming you only pass straight text + without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped accordingly. Doing so will improve performances for documents with page-uniform rotations. @@ -44,6 +46,7 @@ def __init__( det_predictor: DetectionPredictor, reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, + assume_straight_text: bool = False, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, @@ -57,6 +60,7 @@ def __init__( _OCRPredictor.__init__( self, assume_straight_pages, + assume_straight_text, straighten_pages, preserve_aspect_ratio, symmetric_pad, @@ -123,10 +127,11 @@ def forward( loc_preds, channels_last=channels_last, assume_straight_pages=self.assume_straight_pages, + assume_straight_text=self.assume_straight_text, ) # Rectify crop orientation and get crop orientation predictions crop_orientations: Any = [] - if not self.assume_straight_pages: + if not self.assume_straight_pages and not self.assume_straight_text: crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds) crop_orientations = [ {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index f736614879..7aba80474f 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -29,6 +29,7 @@ class OCRPredictor(NestedObject, _OCRPredictor): reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. + assume_straight_text: if True, speeds up the inference by assuming you only pass straight text straighten_pages: if True, estimates the page general orientation based on the median line orientation. Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped accordingly. Doing so will improve performances for documents with page-uniform rotations. @@ -46,6 +47,7 @@ def __init__( det_predictor: DetectionPredictor, reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, + assume_straight_text: bool = False, straighten_pages: bool = False, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index aceae8ca43..d222c20f50 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -458,6 +458,8 @@ def extract_rcrops( _boxes[:, :, 0] *= width _boxes[:, :, 1] *= height + src_img = img if channels_last else img.transpose(1, 2, 0) + src_pts = _boxes[:, :3].astype(np.float32) # Preserve size d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1) @@ -469,7 +471,7 @@ def extract_rcrops( # Use a warp transformation to extract the crop crops = [ cv2.warpAffine( - img if channels_last else img.transpose(1, 2, 0), + src_img, # Transformation matrix cv2.getAffineTransform(src_pts[idx], dst_pts[idx]), (int(d1[idx]), int(d2[idx])), @@ -477,3 +479,85 @@ def extract_rcrops( for idx in range(_boxes.shape[0]) ] return crops # type: ignore[return-value] + + +def extract_dewarped_crops( + img: np.ndarray, polys: np.ndarray, dtype=np.float32, channels_last: bool = True +) -> List[np.ndarray]: + """Created cropped images from list of skewed/warped bounding boxes, + but containing straight text + + Args: + ---- + img: input image + polys: bounding boxes of shape (N, 4, 2) + dtype: target data type of bounding boxes + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + ------- + list of cropped images + """ + if polys.shape[0] == 0: + return [] + if polys.shape[1:] != (4, 2): + raise AssertionError("polys are expected to be quadrilateral, of shape (N, 4, 2)") + + # Project relative coordinates + _boxes = polys.copy() + height, width = img.shape[:2] if channels_last else img.shape[-2:] + if not np.issubdtype(_boxes.dtype, np.integer): + _boxes[:, :, 0] *= width + _boxes[:, :, 1] *= height + + src_img = img if channels_last else img.transpose(1, 2, 0) + + crops = [] + + for box in _boxes: + # Sort the points according to the x-axis + box_points = box[np.argsort(box[:, 0])] + + # Divide the points into left and right + left_points = box_points[:2] + right_points = box_points[2:] + + # Sort the left points according to the y-axis + left_points = left_points[np.argsort(left_points[:, 1])] + # Sort the right points according to the y-axis + right_points = right_points[np.argsort(right_points[:, 1])] + box_points = np.concatenate([left_points, right_points]) + + # Get the width and height of the rectangle that will contain the warped quadrilateral + # Designate the width and height based on maximum side of the quadrilateral + width_upper = np.linalg.norm(box_points[0] - box_points[2]) + width_lower = np.linalg.norm(box_points[1] - box_points[3]) + height_left = np.linalg.norm(box_points[0] - box_points[1]) + height_right = np.linalg.norm(box_points[2] - box_points[3]) + + # Get the maximum width and height + rect_width = int(max(width_upper, width_lower)) + rect_height = int(max(height_left, height_right)) + + dst_pts = np.array( + [ + [0, 0], # top-left + # bottom-left + [0, rect_height - 1], + # top-right + [rect_width - 1, 0], + # bottom-right + [rect_width - 1, rect_height - 1], + ], + dtype=dtype, + ) + + # Get the perspective transform matrix using the box points + affine_mat = cv2.getPerspectiveTransform(box_points.astype(np.float32), dst_pts) + + # Perform the perspective warp to get the rectified crop + crop = cv2.warpPerspective(src_img, affine_mat, (rect_width, rect_height)) + + # Add the crop to the list of crops + crops.append(crop) + return crops # type: ignore[return-value] diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py index 984019e06c..d1216161fc 100644 --- a/tests/common/test_utils_geometry.py +++ b/tests/common/test_utils_geometry.py @@ -266,3 +266,37 @@ def test_extract_rcrops(mock_pdf): # No box assert geometry.extract_rcrops(doc_img, np.zeros((0, 4, 2))) == [] + + +def test_extract_dewarped_crops(mock_pdf): + doc_img = DocumentFile.from_pdf(mock_pdf)[0] + num_crops = 2 + rel_boxes = np.array( + [ + [ + [idx / num_crops, idx / num_crops], + [idx / num_crops + 0.1, idx / num_crops], + [idx / num_crops + 0.1, idx / num_crops + 0.1], + [idx / num_crops, idx / num_crops], + ] + for idx in range(num_crops) + ], + dtype=np.float32, + ) + abs_boxes = deepcopy(rel_boxes) + abs_boxes[:, :, 0] *= doc_img.shape[1] + abs_boxes[:, :, 1] *= doc_img.shape[0] + abs_boxes = abs_boxes.astype(np.int64) + + with pytest.raises(AssertionError): + geometry.extract_dewarped_crops(doc_img, np.zeros((1, 8))) + for boxes in (rel_boxes, abs_boxes): + croped_imgs = geometry.extract_dewarped_crops(doc_img, boxes) + # Number of crops + assert len(croped_imgs) == num_crops + # Data type and shape + assert all(isinstance(crop, np.ndarray) for crop in croped_imgs) + assert all(crop.ndim == 3 for crop in croped_imgs) + + # No box + assert geometry.extract_dewarped_crops(doc_img, np.zeros((0, 4, 2))) == [] diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 9be66edd7b..311b367ee1 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -25,14 +25,17 @@ def __call__(self, loc_preds): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, assume_straight_text", [ - [True, False], - [False, False], - [True, True], + [True, False, False], + [False, False, False], + [True, True, False], + [True, False, True], + [False, False, True], + [True, True, True], ], ) -def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, assume_straight_text): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -59,6 +62,7 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa det_predictor, reco_predictor, assume_straight_pages=assume_straight_pages, + assume_straight_text=assume_straight_text, straighten_pages=straighten_pages, detect_orientation=True, detect_language=True, @@ -73,7 +77,10 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa else: assert predictor.page_orientation_predictor is None else: - assert isinstance(predictor.crop_orientation_predictor, nn.Module) + if not assume_straight_text: + assert isinstance(predictor.crop_orientation_predictor, nn.Module) + else: + assert predictor.crop_orientation_predictor is None assert isinstance(predictor.page_orientation_predictor, nn.Module) out = predictor(doc) @@ -97,8 +104,9 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) else: - # Overwrite the default orientation models - predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + if not assume_straight_text: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) out = predictor(doc) @@ -123,6 +131,7 @@ def test_trained_ocr_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=False, resolve_blocks=True, @@ -152,6 +161,7 @@ def test_trained_ocr_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True, @@ -167,14 +177,17 @@ def test_trained_ocr_predictor(mock_payslip): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, assume_straight_text", [ - [True, False], - [False, False], - [True, True], + [True, False, False], + [False, False, False], + [True, True, False], + [True, False, True], + [False, False, True], + [True, True, True], ], ) -def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, assume_straight_text): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -201,6 +214,7 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa det_predictor, reco_predictor, assume_straight_pages=assume_straight_pages, + assume_straight_text=assume_straight_text, straighten_pages=straighten_pages, detect_orientation=True, detect_language=True, @@ -215,7 +229,10 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa else: assert predictor.page_orientation_predictor is None else: - assert isinstance(predictor.crop_orientation_predictor, nn.Module) + if not assume_straight_text: + assert isinstance(predictor.crop_orientation_predictor, nn.Module) + else: + assert predictor.crop_orientation_predictor is None assert isinstance(predictor.page_orientation_predictor, nn.Module) out = predictor(doc) @@ -239,8 +256,9 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) else: - # Overwrite the default orientation models - predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + if not assume_straight_text: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) out = predictor(doc) @@ -265,6 +283,7 @@ def test_trained_kie_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=False, resolve_blocks=True, @@ -297,6 +316,7 @@ def test_trained_kie_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True, diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 4b7e606563..c198f3bdb6 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -25,14 +25,17 @@ def __call__(self, loc_preds): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, assume_straight_text", [ - [True, False], - [False, False], - [True, True], + [True, False, False], + [False, False, False], + [True, True, False], + [True, False, True], + [False, False, True], + [True, True, True], ], ) -def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, assume_straight_text): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -56,6 +59,7 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa det_predictor, reco_predictor, assume_straight_pages=assume_straight_pages, + assume_straight_text=assume_straight_text, straighten_pages=straighten_pages, detect_orientation=True, detect_language=True, @@ -70,7 +74,8 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa else: assert predictor.page_orientation_predictor is None else: - assert isinstance(predictor.crop_orientation_predictor, NestedObject) + if not assume_straight_text: + assert isinstance(predictor.crop_orientation_predictor, NestedObject) assert isinstance(predictor.page_orientation_predictor, NestedObject) out = predictor(doc) @@ -122,6 +127,7 @@ def test_trained_ocr_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=False, resolve_blocks=True, @@ -153,6 +159,7 @@ def test_trained_ocr_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True, @@ -166,14 +173,17 @@ def test_trained_ocr_predictor(mock_payslip): @pytest.mark.parametrize( - "assume_straight_pages, straighten_pages", + "assume_straight_pages, straighten_pages, assume_straight_text", [ - [True, False], - [False, False], - [True, True], + [True, False, False], + [False, False, False], + [True, True, False], + [True, False, True], + [False, False, True], + [True, True, True], ], ) -def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): +def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages, assume_straight_text): det_bsize = 4 det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=det_bsize), @@ -197,6 +207,7 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa det_predictor, reco_predictor, assume_straight_pages=assume_straight_pages, + assume_straight_text=assume_straight_text, straighten_pages=straighten_pages, detect_orientation=True, detect_language=True, @@ -211,7 +222,8 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa else: assert predictor.page_orientation_predictor is None else: - assert isinstance(predictor.crop_orientation_predictor, NestedObject) + if not assume_straight_text: + assert isinstance(predictor.crop_orientation_predictor, NestedObject) assert isinstance(predictor.page_orientation_predictor, NestedObject) out = predictor(doc) @@ -237,8 +249,9 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) else: - # Overwrite the default orientation models - predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) + if not assume_straight_text: + # Overwrite the default orientation models + predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) out = predictor(doc) @@ -263,6 +276,7 @@ def test_trained_kie_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=False, resolve_blocks=True, @@ -295,6 +309,7 @@ def test_trained_kie_predictor(mock_payslip): det_predictor, reco_predictor, assume_straight_pages=True, + assume_straight_text=False, straighten_pages=True, preserve_aspect_ratio=True, symmetric_pad=True,