diff --git a/api/README.md b/api/README.md index 200c2a164a..a9501e2542 100644 --- a/api/README.md +++ b/api/README.md @@ -45,12 +45,22 @@ should yield ```json [ { - "name": "invitation.png", - "boxes": [ - [0.50390625, 0.712890625, 0.5185546875, 0.720703125], - [0.4716796875, 0.712890625, 0.48828125, 0.720703125] + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "geometries": [ + [ + 0.724609375, + 0.1787109375, + 0.7900390625, + 0.2080078125 + ], + [ + 0.6748046875, + 0.1796875, + 0.7314453125, + 0.20703125 ] - }, + ] + } ] ``` @@ -73,9 +83,10 @@ should yield ```json [ { - "name": "invitation.png", - "value": "invite" - }, + "name": "117133599-c073fa00-ada4-11eb-831b-412de4d28341.jpeg", + "value": "invite", + "confidence": 1.0 + } ] ``` @@ -98,17 +109,61 @@ should yield ```json [ { - "name": "hello_world.jpg", - "items": [ + "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", + "orientation": { + "value": 0, + "confidence": null + }, + "language": { + "value": null, + "confidence": null + }, + "items": [ { - "value": "Hello", - "box": [0.005859375, 0.003312938981562763, 0.0205078125, 0.0332854340430202] - }, - { - "value": "world!", - "box": [0.005859375, 0.003312938981562763, 0.0205078125, 0.0332854340430202] - }, - ], + "blocks": [ + { + "geometry": [ + 0.7471996155154171, + 0.1787109375, + 0.9101580212741838, + 0.2080078125 + ], + "lines": [ + { + "geometry": [ + 0.7471996155154171, + 0.1787109375, + 0.9101580212741838, + 0.2080078125 + ], + "words": [ + { + "value": "Hello", + "geometry": [ + 0.7471996155154171, + 0.1796875, + 0.8272978149561669, + 0.20703125 + ], + "confidence": 1.0 + }, + { + "value": "world!", + "geometry": [ + 0.8176307908857315, + 0.1787109375, + 0.9101580212741838, + 0.2080078125 + ], + "confidence": 1.0 + } + ] + } + ] + } + ] + } + ] } ] ``` diff --git a/api/app/routes/detection.py b/api/app/routes/detection.py index 2e9216639e..e044d1f815 100644 --- a/api/app/routes/detection.py +++ b/api/app/routes/detection.py @@ -5,33 +5,31 @@ from typing import List -from fastapi import APIRouter, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import DetectionOut -from app.vision import det_predictor +from app.schemas import DetectionIn, DetectionOut +from app.utils import get_documents, resolve_geometry +from app.vision import init_predictor from doctr.file_utils import CLASS_NAME -from doctr.io import DocumentFile router = APIRouter() @router.post("/", response_model=List[DetectionOut], status_code=status.HTTP_200_OK, summary="Perform text detection") -async def text_detection(files: List[UploadFile] = [File(...)]): +async def text_detection(request: DetectionIn = Depends(), files: List[UploadFile] = [File(...)]): """Runs docTR text detection model to analyze the input image""" - boxes: List[DetectionOut] = [] - for file in files: - mime_type = file.content_type - if mime_type in ["image/jpeg", "image/png"]: - content = DocumentFile.from_images([await file.read()]) - elif mime_type == "application/pdf": - content = DocumentFile.from_pdf(await file.read()) - else: - raise HTTPException(status_code=400, detail=f"Unsupported file format for detection endpoint: {mime_type}") - - boxes.append( - DetectionOut( - name=file.filename or "", boxes=[box.tolist() for box in det_predictor(content)[0][CLASS_NAME][:, :-1]] - ) + try: + predictor = init_predictor(request) + content, filenames = await get_documents(files) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return [ + DetectionOut( + name=filename, + geometries=[ + geom[:-1].tolist() if len(geom) == 5 else resolve_geometry(geom.tolist()) for geom in doc[CLASS_NAME] + ], ) - - return boxes + for doc, filename in zip(predictor(content), filenames) + ] diff --git a/api/app/routes/kie.py b/api/app/routes/kie.py index 2d947cc49e..ece3e1a8cb 100644 --- a/api/app/routes/kie.py +++ b/api/app/routes/kie.py @@ -5,45 +5,47 @@ from typing import List -from fastapi import APIRouter, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import KIEElement, KIEOut -from app.vision import kie_predictor -from doctr.io import DocumentFile +from app.schemas import KIEElement, KIEIn, KIEOut +from app.utils import get_documents, resolve_geometry +from app.vision import init_predictor router = APIRouter() @router.post("/", response_model=List[KIEOut], status_code=status.HTTP_200_OK, summary="Perform KIE") -async def perform_kie(files: List[UploadFile] = [File(...)]): +async def perform_kie(request: KIEIn = Depends(), files: List[UploadFile] = [File(...)]): """Runs docTR KIE model to analyze the input image""" - results: List[KIEOut] = [] - for file in files: - mime_type = file.content_type - if mime_type in ["image/jpeg", "image/png"]: - content = DocumentFile.from_images([await file.read()]) - elif mime_type == "application/pdf": - content = DocumentFile.from_pdf(await file.read()) - else: - raise HTTPException(status_code=400, detail=f"Unsupported file format for KIE endpoint: {mime_type}") - - out = kie_predictor(content) - - for page in out.pages: - results.append( - KIEOut( - name=file.filename or "", - predictions=[ - KIEElement( - class_name=class_name, - items=[ - dict(value=prediction.value, box=(*prediction.geometry[0], *prediction.geometry[1])) - for prediction in page.predictions[class_name] - ], + try: + predictor = init_predictor(request) + content, filenames = await get_documents(files) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + out = predictor(content) + + results = [ + KIEOut( + name=filenames[i], + orientation=page.orientation, + language=page.language, + predictions=[ + KIEElement( + class_name=class_name, + items=[ + dict( + value=prediction.value, + geometry=resolve_geometry(prediction.geometry), + confidence=round(prediction.confidence, 2), ) - for class_name in page.predictions.keys() + for prediction in page.predictions[class_name] ], ) - ) + for class_name in page.predictions.keys() + ], + ) + for i, page in enumerate(out.pages) + ] return results diff --git a/api/app/routes/ocr.py b/api/app/routes/ocr.py index 484898daae..dc18af795c 100644 --- a/api/app/routes/ocr.py +++ b/api/app/routes/ocr.py @@ -5,40 +5,58 @@ from typing import List -from fastapi import APIRouter, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import OCROut -from app.vision import predictor -from doctr.io import DocumentFile +from app.schemas import OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord +from app.utils import get_documents, resolve_geometry +from app.vision import init_predictor router = APIRouter() @router.post("/", response_model=List[OCROut], status_code=status.HTTP_200_OK, summary="Perform OCR") -async def perform_ocr(files: List[UploadFile] = [File(...)]): +async def perform_ocr(request: OCRIn = Depends(), files: List[UploadFile] = [File(...)]): """Runs docTR OCR model to analyze the input image""" - results: List[OCROut] = [] - for file in files: - mime_type = file.content_type - if mime_type in ["image/jpeg", "image/png"]: - content = DocumentFile.from_images([await file.read()]) - elif mime_type == "application/pdf": - content = DocumentFile.from_pdf(await file.read()) - else: - raise HTTPException(status_code=400, detail=f"Unsupported file format for OCR endpoint: {mime_type}") - - out = predictor(content) - for page in out.pages: - results.append( - OCROut( - name=file.filename or "", - items=[ - dict(value=word.value, box=(*word.geometry[0], *word.geometry[1])) + try: + # generator object to list + content, filenames = await get_documents(files) + predictor = init_predictor(request) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + out = predictor(content) + + results = [ + OCROut( + name=filenames[i], + orientation=page.orientation, + language=page.language, + items=[ + OCRPage( + blocks=[ + OCRBlock( + geometry=resolve_geometry(block.geometry), + lines=[ + OCRLine( + geometry=resolve_geometry(line.geometry), + words=[ + OCRWord( + value=word.value, + geometry=resolve_geometry(word.geometry), + confidence=round(word.confidence, 2), + ) + for word in line.words + ], + ) + for line in block.lines + ], + ) for block in page.blocks - for line in block.lines - for word in line.words - ], + ] ) - ) + ], + ) + for i, page in enumerate(out.pages) + ] return results diff --git a/api/app/routes/recognition.py b/api/app/routes/recognition.py index e8bf4610e4..65de3e07ba 100644 --- a/api/app/routes/recognition.py +++ b/api/app/routes/recognition.py @@ -5,11 +5,11 @@ from typing import List -from fastapi import APIRouter, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import RecognitionOut -from app.vision import reco_predictor -from doctr.io import DocumentFile +from app.schemas import RecognitionIn, RecognitionOut +from app.utils import get_documents +from app.vision import init_predictor router = APIRouter() @@ -17,18 +17,14 @@ @router.post( "/", response_model=List[RecognitionOut], status_code=status.HTTP_200_OK, summary="Perform text recognition" ) -async def text_recognition(files: List[UploadFile] = [File(...)]): +async def text_recognition(request: RecognitionIn = Depends(), files: List[UploadFile] = [File(...)]): """Runs docTR text recognition model to analyze the input image""" - words: List[RecognitionOut] = [] - for file in files: - mime_type = file.content_type - if mime_type in ["image/jpeg", "image/png"]: - content = DocumentFile.from_images([await file.read()]) - else: - raise HTTPException( - status_code=400, detail=f"Unsupported file format for recognition endpoint: {mime_type}" - ) - - words.append(RecognitionOut(name=file.filename or "", value=reco_predictor(content)[0][0])) - - return words + try: + predictor = init_predictor(request) + content, filenames = await get_documents(files) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return [ + RecognitionOut(name=filename, value=res[0], confidence=round(res[1], 2)) + for res, filename in zip(predictor(content), filenames) + ] diff --git a/api/app/schemas.py b/api/app/schemas.py index ad9ea1dd35..46a9cb0ac5 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -3,35 +3,130 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Union from pydantic import BaseModel, Field +class KIEIn(BaseModel): + det_arch: str = Field(default="db_resnet50", examples=["db_resnet50"]) + reco_arch: str = Field(default="crnn_vgg16_bn", examples=["crnn_vgg16_bn"]) + assume_straight_pages: bool = Field(default=True, examples=[True]) + preserve_aspect_ratio: bool = Field(default=True, examples=[True]) + detect_orientation: bool = Field(default=False, examples=[False]) + detect_language: bool = Field(default=False, examples=[False]) + symmetric_pad: bool = Field(default=True, examples=[True]) + straighten_pages: bool = Field(default=False, examples=[False]) + det_bs: int = Field(default=2, examples=[2]) + reco_bs: int = Field(default=128, examples=[128]) + bin_thresh: float = Field(default=0.1, examples=[0.1]) + box_thresh: float = Field(default=0.1, examples=[0.1]) + + +class OCRIn(KIEIn): + resolve_lines: bool = Field(default=True, examples=[True]) + resolve_blocks: bool = Field(default=True, examples=[True]) + paragraph_break: float = Field(default=0.0035, examples=[0.0035]) + + +class RecognitionIn(BaseModel): + reco_arch: str = Field(default="crnn_vgg16_bn", examples=["crnn_vgg16_bn"]) + reco_bs: int = Field(default=128, examples=[128]) + + +class DetectionIn(BaseModel): + det_arch: str = Field(default="db_resnet50", examples=["db_resnet50"]) + assume_straight_pages: bool = Field(default=True, examples=[True]) + preserve_aspect_ratio: bool = Field(default=True, examples=[True]) + symmetric_pad: bool = Field(default=True, examples=[True]) + det_bs: int = Field(default=2, examples=[2]) + bin_thresh: float = Field(default=0.1, examples=[0.1]) + box_thresh: float = Field(default=0.1, examples=[0.1]) + + class RecognitionOut(BaseModel): name: str = Field(..., examples=["example.jpg"]) value: str = Field(..., examples=["Hello"]) + confidence: float = Field(..., examples=[0.99]) class DetectionOut(BaseModel): name: str = Field(..., examples=["example.jpg"]) - boxes: List[Tuple[float, float, float, float]] + geometries: List[List[float]] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + + +class OCRWord(BaseModel): + value: str = Field(..., examples=["example"]) + geometry: List[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + confidence: float = Field(..., examples=[0.99]) + + +class OCRLine(BaseModel): + geometry: List[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + words: List[OCRWord] = Field( + ..., examples=[{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}] + ) + + +class OCRBlock(BaseModel): + geometry: List[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + lines: List[OCRLine] = Field( + ..., + examples=[ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "words": [{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}], + } + ], + ) + + +class OCRPage(BaseModel): + blocks: List[OCRBlock] = Field( + ..., + examples=[ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "lines": [ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "words": [{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}], + } + ], + } + ], + ) class OCROut(BaseModel): name: str = Field(..., examples=["example.jpg"]) - items: List[Dict[str, Union[str, Tuple[float, float, float, float]]]] = Field( - ..., examples=[{"value": "example", "box": [0.0, 0.0, 0.0, 0.0]}] + orientation: Dict[str, Union[float, None]] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) + language: Dict[str, Union[str, float, None]] = Field(..., examples=[{"value": "en", "confidence": 0.99}]) + items: List[OCRPage] = Field( + ..., + examples=[ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "lines": [ + { + "geometry": [0.0, 0.0, 0.0, 0.0], + "words": [{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}], + } + ], + } + ], ) class KIEElement(BaseModel): class_name: str = Field(..., examples=["example"]) - items: List[Dict[str, Union[str, Tuple[float, float, float, float]]]] = Field( - ..., examples=[{"value": "example", "box": [0.0, 0.0, 0.0, 0.0]}] + items: List[Dict[str, Union[str, List[float], float]]] = Field( + ..., examples=[{"value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}] ) class KIEOut(BaseModel): name: str = Field(..., examples=["example.jpg"]) + orientation: Dict[str, Union[float, None]] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) + language: Dict[str, Union[str, float, None]] = Field(..., examples=[{"value": "en", "confidence": 0.99}]) predictions: List[KIEElement] diff --git a/api/app/utils.py b/api/app/utils.py new file mode 100644 index 0000000000..d1897f51b1 --- /dev/null +++ b/api/app/utils.py @@ -0,0 +1,49 @@ +# Copyright (C) 2021-2024, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any, List, Tuple, Union + +import numpy as np +from fastapi import UploadFile + +from doctr.io import DocumentFile + + +def resolve_geometry( + geom: Any, +) -> Union[Tuple[float, float, float, float], Tuple[float, float, float, float, float, float, float, float]]: + if len(geom) == 4: + return (*geom[0], *geom[1], *geom[2], *geom[3]) + return (*geom[0], *geom[1]) + + +async def get_documents(files: List[UploadFile]) -> Tuple[List[np.ndarray], List[str]]: # pragma: no cover + """Convert a list of UploadFile objects to lists of numpy arrays and their corresponding filenames + + Args: + ---- + files: list of UploadFile objects + + Returns: + ------- + Tuple[List[np.ndarray], List[str]]: list of numpy arrays and their corresponding filenames + + """ + filenames = [] + docs = [] + for file in files: + mime_type = file.content_type + if mime_type in ["image/jpeg", "image/png"]: + docs.extend(DocumentFile.from_images([await file.read()])) + filenames.append(file.filename or "") + elif mime_type == "application/pdf": + pdf_content = DocumentFile.from_pdf(await file.read()) + docs.extend(pdf_content) + filenames.append(file.filename or "" * len(pdf_content)) + else: + raise ValueError(f"Unsupported file format: {mime_type} for file {file.filename}") + + return docs, filenames diff --git a/api/app/vision.py b/api/app/vision.py index 0ec3f73d5e..005c8d1548 100644 --- a/api/app/vision.py +++ b/api/app/vision.py @@ -3,15 +3,45 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. + import tensorflow as tf gpu_devices = tf.config.experimental.list_physical_devices("GPU") if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) +from typing import Callable, Union + from doctr.models import kie_predictor, ocr_predictor -predictor = ocr_predictor(pretrained=True, assume_straight_pages=True) -det_predictor = predictor.det_predictor -reco_predictor = predictor.reco_predictor -kie_predictor = kie_predictor(pretrained=True, assume_straight_pages=True) +from .schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn + + +def init_predictor(request: Union[KIEIn, OCRIn, RecognitionIn, DetectionIn]) -> Callable: + """Initialize the predictor based on the request + + Args: + ---- + request: input request + + Returns: + ------- + Callable: the predictor + """ + params = request.model_dump() + bin_thresh = params.pop("bin_thresh", None) + box_thresh = params.pop("box_thresh", None) + if isinstance(request, (OCRIn, RecognitionIn, DetectionIn)): + predictor = ocr_predictor(pretrained=True, **params) + predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh + if isinstance(request, DetectionIn): + return predictor.det_predictor + elif isinstance(request, RecognitionIn): + return predictor.reco_predictor + return predictor + elif isinstance(request, KIEIn): + predictor = kie_predictor(pretrained=True, **params) + predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh + predictor.det_predictor.model.postprocessor.box_thresh = box_thresh + return predictor diff --git a/api/tests/routes/test_detection.py b/api/tests/routes/test_detection.py index 5c6852d1eb..05f54a11e9 100644 --- a/api/tests/routes/test_detection.py +++ b/api/tests/routes/test_detection.py @@ -17,9 +17,9 @@ async def test_text_detection(test_app_asyncio, mock_detection_image, mock_txt_f # Check that IoU with GT if reasonable assert isinstance(json_response, list) and len(json_response) == 2 - first_pred = json_response[0] - assert isinstance(first_pred, dict) and len(first_pred["boxes"]) == gt_boxes.shape[0] - pred_boxes = np.array(first_pred["boxes"]) + first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice + assert isinstance(first_pred, dict) and len(first_pred["geometries"]) == gt_boxes.shape[0] + pred_boxes = np.array(first_pred["geometries"]) iou_mat = box_iou(gt_boxes, pred_boxes) gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 diff --git a/api/tests/routes/test_kie.py b/api/tests/routes/test_kie.py index 60fcec7e0a..2b0c9b3b38 100644 --- a/api/tests/routes/test_kie.py +++ b/api/tests/routes/test_kie.py @@ -18,9 +18,13 @@ async def test_perform_kie(test_app_asyncio, mock_detection_image, mock_txt_file # Check that IoU with GT if reasonable assert isinstance(json_response, list) and len(json_response) == 2 - first_pred = json_response[0] - assert isinstance(first_pred, dict) and len(first_pred["predictions"]["items"]) == gt_boxes.shape[0] - pred_boxes = np.array([elt["box"] for elt in first_pred["predictions"]["items"]]) + first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice + assert ( + isinstance(first_pred, dict) + and len(first_pred["predictions"]["items"]) == gt_boxes.shape[0] + and isinstance(first_pred["predictions"]["class_name"], str) + ) + pred_boxes = np.array([elt["geometry"] for elt in first_pred["predictions"]["items"]]) pred_labels = np.array([elt["value"] for elt in first_pred["predictions"]["items"]]) iou_mat = box_iou(gt_boxes, pred_boxes) gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) diff --git a/api/tests/routes/test_ocr.py b/api/tests/routes/test_ocr.py index a896181948..aa678c27ee 100644 --- a/api/tests/routes/test_ocr.py +++ b/api/tests/routes/test_ocr.py @@ -18,7 +18,7 @@ async def test_perform_ocr(test_app_asyncio, mock_detection_image, mock_txt_file # Check that IoU with GT if reasonable assert isinstance(json_response, list) and len(json_response) == 2 - first_pred = json_response[0] + first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice assert isinstance(first_pred, dict) and len(first_pred["items"]) == gt_boxes.shape[0] pred_boxes = np.array([elt["box"] for elt in first_pred["items"]]) pred_labels = np.array([elt["value"] for elt in first_pred["items"]]) diff --git a/api/tests/utils/test_utils.py b/api/tests/utils/test_utils.py new file mode 100644 index 0000000000..b346565feb --- /dev/null +++ b/api/tests/utils/test_utils.py @@ -0,0 +1,26 @@ +from app.utils import resolve_geometry + + +def test_resolve_geometry(): + dummy_box = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)] + dummy_polygon = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)] + + assert resolve_geometry(dummy_box) == (0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0) + assert resolve_geometry(dummy_polygon) == [ + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 1.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 1.0, + 0.0, + 1.0, + ] diff --git a/api/tests/utils/test_vision.py b/api/tests/utils/test_vision.py new file mode 100644 index 0000000000..04050268f7 --- /dev/null +++ b/api/tests/utils/test_vision.py @@ -0,0 +1,13 @@ +from app.schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn +from app.vision import init_predictor +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.kie_predictor import KIEPredictor +from doctr.models.predictor import OCRPredictor +from doctr.models.recognition.predictor import RecognitionPredictor + + +def test_vision(): + assert isinstance(init_predictor(OCRIn), OCRPredictor) + assert isinstance(init_predictor(DetectionIn), DetectionPredictor) + assert isinstance(init_predictor(RecognitionIn), RecognitionPredictor) + assert isinstance(init_predictor(KIEIn), KIEPredictor)