diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index 0403b1247b..792e870347 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -9,8 +9,25 @@ from doctr.models import ocr_predictor from doctr.models.predictor import OCRPredictor -DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet50_rotation"] -RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"] +DET_ARCHS = [ + "db_resnet50", + "db_resnet34", + "db_mobilenet_v3_large", + "db_resnet50_rotation", + "linknet_resnet18", + "linknet_resnet34", + "linknet_resnet50", +] +RECO_ARCHS = [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "master", + "sar_resnet31", + "vitstr_small", + "vitstr_base", + "parseq", +] def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor: diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index c787a2db45..8eaccca364 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -9,8 +9,24 @@ from doctr.models import ocr_predictor from doctr.models.predictor import OCRPredictor -DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18_rotation"] -RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"] +DET_ARCHS = [ + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "linknet_resnet18_rotation", + "linknet_resnet34", + "linknet_resnet50", +] +RECO_ARCHS = [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "master", + "sar_resnet31", + "vitstr_small", + "vitstr_base", + "parseq", +] def load_predictor(det_arch: str, reco_arch: str, device) -> OCRPredictor: