diff --git a/anki_ocr/api.py b/anki_ocr/api.py index 3f15a6a..920fa9f 100644 --- a/anki_ocr/api.py +++ b/anki_ocr/api.py @@ -218,12 +218,12 @@ def add_imgdata_to_note(self, method="tooltip"): class NotesQuery: """ Represents a collection of Notes from a query of the Collection db""" col: Collection - query: str = "" + note_ids: List[int] notes: List[OCRNote] = None notes_to_process: List[OCRNote] = None def __post_init__(self): - self.notes = [OCRNote(note_id=nid, col=self.col) for nid in self.col.findNotes(query=self.query)] + self.notes = [OCRNote(note_id=nid, col=self.col) for nid in self.note_ids] def __len__(self): return len(self.notes) diff --git a/anki_ocr/ocr.py b/anki_ocr/ocr.py index 73eea43..724a3d6 100644 --- a/anki_ocr/ocr.py +++ b/anki_ocr/ocr.py @@ -18,7 +18,7 @@ from anki.storage import Collection from .api import OCRNote, NotesQuery, OCRImage -from .utils import batch, format_note_id_query +from .utils import batch ANKI_ENV = "python" not in Path(sys.executable).stem @@ -260,12 +260,12 @@ def _ocr_img(img_pth: Union[Path, str, PathLike], num_threads: int, languages: L return pytesseract.image_to_string(str(img_pth), lang="+".join(languages or ["eng"]), config=tessdata_config) - def run_ocr_on_query(self, query: str) -> NotesQuery: + def run_ocr_on_query(self, note_ids: List[int]) -> NotesQuery: """ Main method for the ocr class. Runs OCR on a sequence of notes returned from a collection query. - :param query: Query to collection, see https://docs.ankiweb.net/#/searching for more info. 
+ :param note_ids: Note ids to process """ - notes_query = NotesQuery(col=self.col, query=query) + notes_query = NotesQuery(col=self.col, note_ids=note_ids) # self.col.modSchema(check=True) if self.use_batching: logger.info(f"Processing {len(notes_query)} notes with _ocr_batch_process() ...") @@ -301,8 +301,7 @@ def run_ocr_on_notes(self, note_ids: List[int]) -> NotesQuery: :param note_ids: List of note ids """ # self.col.modSchema(check=True) - query_str = format_note_id_query(note_ids=note_ids) - notes_query = self.run_ocr_on_query(query=query_str) + notes_query = self.run_ocr_on_query(note_ids=note_ids) return notes_query def remove_ocr_on_notes(self, note_ids: List[int]): @@ -311,7 +310,7 @@ def remove_ocr_on_notes(self, note_ids: List[int]): :param note_ids: List of note ids """ # self.col.modSchema(check=True) - query_notes = NotesQuery(col=self.col, query=format_note_id_query(note_ids)) + query_notes = NotesQuery(col=self.col, note_ids=note_ids) for note in query_notes: note.remove_OCR_text() self.col.reset() @@ -325,5 +324,3 @@ def path_to_tesseract() -> str: platform_name = platform.system() # E.g. 
'Windows' return exec_data[platform_name] - - diff --git a/anki_ocr/utils.py b/anki_ocr/utils.py index 779cb4d..033669e 100644 --- a/anki_ocr/utils.py +++ b/anki_ocr/utils.py @@ -32,11 +32,6 @@ def create_ocr_logger(): return ocr_logger -def format_note_id_query(note_ids: List[int]) -> str: - """Generates an anki db query string from a list of note ids""" - return f"{' OR '.join([f'nid:{nid}' for nid in note_ids])}" - - def batch(it: Iterable, batch_size: int): """Batches an Iterable into batches of at most batch_size in length""" it = iter(it) diff --git a/tests/test_ocr.py b/tests/test_ocr.py index b99d036..cce6835 100644 --- a/tests/test_ocr.py +++ b/tests/test_ocr.py @@ -8,7 +8,6 @@ from anki_ocr.api import NotesQuery from anki_ocr.ocr import OCR -from anki_ocr.utils import format_note_id_query TESTDATA_DIR = Path(__file__).parent / "testdata" TEMPLATE_COLLECTION_PTH = TESTDATA_DIR / "test_collection_template" / "collection.anki2" @@ -65,7 +64,8 @@ def test_gen_queryimages(self, tmpdir): col_dir = tmpdir.mkdir("collection") test_col = gen_test_collection(col_dir) ocr = OCR(col=test_col) - q_images = NotesQuery(col=test_col, query="") + all_note_ids = ocr.col.db.list("select * from notes") + q_images = NotesQuery(col=test_col, note_ids=all_note_ids) print(q_images) def test_query_noteids(self, tmpdir): @@ -73,8 +73,7 @@ def test_query_noteids(self, tmpdir): test_col = gen_test_collection(col_dir) ocr = OCR(col=test_col) note_ids = [1601851621708, 1601851571572] - query = format_note_id_query(note_ids) - q_images = NotesQuery(col=test_col, query=query) + q_images = NotesQuery(col=test_col, note_ids=note_ids) assert len(q_images.notes) == 2 for note in q_images.notes: assert note.note_id in note_ids @@ -83,7 +82,8 @@ def test_run_ocr_on_collection(self, tmpdir): col_dir = tmpdir.mkdir("collection") test_col = gen_test_collection(col_dir) ocr = OCR(col=test_col) - ocr.run_ocr_on_query(query="") + all_note_ids = ocr.col.db.list("select * from notes") + 
ocr.run_ocr_on_query(note_ids=all_note_ids) def test_run_ocr_on_notes_batched_multithreaded(self, tmpdir): col_dir = tmpdir.mkdir("collection") diff --git a/tests/test_util.py b/tests/test_util.py index 61d6bd1..5871ed8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,10 +1 @@ import pytest -from anki_ocr.utils import format_note_id_query -NOTE_ID_QUERY_EXPECTED = [ - ([1601851621708, 1601851571572], "nid:1601851621708 OR nid:1601851571572"), - ([1601851621708], "nid:1601851621708") -] -@pytest.mark.parametrize(["note_ids", "expected"], NOTE_ID_QUERY_EXPECTED) -def test_format_note_id_query(note_ids, expected): - output = format_note_id_query(note_ids=note_ids) - assert output == expected