Skip to content

Commit

Permalink
Update to use note ids from .getNote() rather than a query
Browse files Browse the repository at this point in the history
Fixes #16  and #18
  • Loading branch information
cfculhane committed May 2, 2021
1 parent 479e1af commit d69fd85
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 30 deletions.
4 changes: 2 additions & 2 deletions anki_ocr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ def add_imgdata_to_note(self, method="tooltip"):
class NotesQuery:
""" Represents a collection of Notes from a query of the Collection db"""
col: Collection
query: str = ""
note_ids: List[int]
notes: List[OCRNote] = None
notes_to_process: List[OCRNote] = None

def __post_init__(self):
self.notes = [OCRNote(note_id=nid, col=self.col) for nid in self.col.findNotes(query=self.query)]
self.notes = [OCRNote(note_id=nid, col=self.col) for nid in self.note_ids]

def __len__(self):
return len(self.notes)
Expand Down
15 changes: 6 additions & 9 deletions anki_ocr/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from anki.storage import Collection

from .api import OCRNote, NotesQuery, OCRImage
from .utils import batch, format_note_id_query
from .utils import batch

ANKI_ENV = "python" not in Path(sys.executable).stem

Expand Down Expand Up @@ -260,12 +260,12 @@ def _ocr_img(img_pth: Union[Path, str, PathLike], num_threads: int, languages: L
return pytesseract.image_to_string(str(img_pth), lang="+".join(languages or ["eng"]),
config=tessdata_config)

def run_ocr_on_query(self, query: str) -> NotesQuery:
def run_ocr_on_query(self, note_ids: List[int]) -> NotesQuery:
""" Main method for the ocr class. Runs OCR on a sequence of notes returned from a collection query.
:param query: Query to collection, see https://docs.ankiweb.net/#/searching for more info.
:param note_ids: Note id's to process
"""
notes_query = NotesQuery(col=self.col, query=query)
notes_query = NotesQuery(col=self.col, note_ids=note_ids)
# self.col.modSchema(check=True)
if self.use_batching:
logger.info(f"Processing {len(notes_query)} notes with _ocr_batch_process() ...")
Expand Down Expand Up @@ -301,8 +301,7 @@ def run_ocr_on_notes(self, note_ids: List[int]) -> NotesQuery:
:param note_ids: List of note ids
"""
# self.col.modSchema(check=True)
query_str = format_note_id_query(note_ids=note_ids)
notes_query = self.run_ocr_on_query(query=query_str)
notes_query = self.run_ocr_on_query(note_ids=note_ids)
return notes_query

def remove_ocr_on_notes(self, note_ids: List[int]):
Expand All @@ -311,7 +310,7 @@ def remove_ocr_on_notes(self, note_ids: List[int]):
:param note_ids: List of note ids
"""
# self.col.modSchema(check=True)
query_notes = NotesQuery(col=self.col, query=format_note_id_query(note_ids))
query_notes = NotesQuery(col=self.col, note_ids=note_ids)
for note in query_notes:
note.remove_OCR_text()
self.col.reset()
Expand All @@ -325,5 +324,3 @@ def path_to_tesseract() -> str:

platform_name = platform.system() # E.g. 'Windows'
return exec_data[platform_name]


5 changes: 0 additions & 5 deletions anki_ocr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ def create_ocr_logger():
return ocr_logger


def format_note_id_query(note_ids: List[int]) -> str:
"""Generates an anki db query string from a list of note ids"""
return f"{' OR '.join([f'nid:{nid}' for nid in note_ids])}"


def batch(it: Iterable, batch_size: int):
"""Batches an Iterable into batches of at most batch_size in length"""
it = iter(it)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from anki_ocr.api import NotesQuery
from anki_ocr.ocr import OCR
from anki_ocr.utils import format_note_id_query

TESTDATA_DIR = Path(__file__).parent / "testdata"
TEMPLATE_COLLECTION_PTH = TESTDATA_DIR / "test_collection_template" / "collection.anki2"
Expand Down Expand Up @@ -65,16 +64,16 @@ def test_gen_queryimages(self, tmpdir):
col_dir = tmpdir.mkdir("collection")
test_col = gen_test_collection(col_dir)
ocr = OCR(col=test_col)
q_images = NotesQuery(col=test_col, query="")
all_note_ids = ocr.col.db.list("select * from notes")
q_images = NotesQuery(col=test_col, note_ids=all_note_ids)
print(q_images)

def test_query_noteids(self, tmpdir):
col_dir = tmpdir.mkdir("collection")
test_col = gen_test_collection(col_dir)
ocr = OCR(col=test_col)
note_ids = [1601851621708, 1601851571572]
query = format_note_id_query(note_ids)
q_images = NotesQuery(col=test_col, query=query)
q_images = NotesQuery(col=test_col, note_ids=note_ids)
assert len(q_images.notes) == 2
for note in q_images.notes:
assert note.note_id in note_ids
Expand All @@ -83,7 +82,8 @@ def test_run_ocr_on_collection(self, tmpdir):
col_dir = tmpdir.mkdir("collection")
test_col = gen_test_collection(col_dir)
ocr = OCR(col=test_col)
ocr.run_ocr_on_query(query="")
all_note_ids = ocr.col.db.list("select * from notes")
ocr.run_ocr_on_query(note_ids=all_note_ids)

def test_run_ocr_on_notes_batched_multithreaded(self, tmpdir):
col_dir = tmpdir.mkdir("collection")
Expand Down
9 changes: 0 additions & 9 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
import pytest
from anki_ocr.utils import format_note_id_query
NOTE_ID_QUERY_EXPECTED = [
([1601851621708, 1601851571572], "nid:1601851621708 OR nid:1601851571572"),
([1601851621708], "nid:1601851621708")
]
@pytest.mark.parametrize(["note_ids", "expected"], NOTE_ID_QUERY_EXPECTED)
def test_format_note_id_query(note_ids, expected):
output = format_note_id_query(note_ids=note_ids)
assert output == expected

0 comments on commit d69fd85

Please sign in to comment.