Implemented Similarity index generation in gen_cfeatures.py.
ryogrid committed Nov 7, 2024
1 parent 3953fc1 commit eec193a
Showing 2 changed files with 92 additions and 61 deletions.
153 changes: 92 additions & 61 deletions gen_cfeatures.py
@@ -9,6 +9,7 @@
import json
import os.path
from functools import lru_cache
from io import TextIOWrapper
from typing import Union, List, Optional

import numpy as np
@@ -18,6 +19,7 @@
from imgutils.data import MultiImagesTyping, load_images, ImageTyping
from imgutils.utils import open_onnx_model
from onnxruntime import InferenceSession
from gensim.similarities import Similarity

try:
    from typing import Literal
@@ -57,6 +59,8 @@ class Predictor:
    def __init__(self) -> None:
        self.embed_model: Optional[InferenceSession] = None
        self.threshold: float = -1.0
        self.f: Optional[TextIOWrapper] = None
        self.cindex: Optional[Similarity] = None
        # self.tagger_model: Optional[nn.Module] = None


@@ -91,8 +95,8 @@ def list_files_recursive(self, dir_path: str) -> List[str]:
    #
    # return padded_image

    # def write_to_file(self, csv_line: str) -> None:
    #     self.f.write(csv_line + '\n')
    def write_to_file(self, csv_line: str) -> None:
        self.f.write(csv_line + '\n')

    def filter_files_by_date(self, file_list: List[str], added_date: datetime.date) -> List[str]:
        filtered_list: List[str] = []
@@ -262,13 +266,21 @@ def gen_image_ndarray(self, file_path) -> np.ndarray | None:
            print(err_msg)
            return None

    def write_vecs_to_index(self, vecs: np.ndarray) -> None:
        for vec in vecs:
            if self.cindex is None:
                self.cindex = Similarity('charactor-featues-idx', [vec], num_features=768)
            else:
                self.cindex.add_documents([vec])

    def process_directory(self, dir_path: str, added_date: datetime.date | None = None) -> None:
        file_list: List[str] = self.list_files_recursive(dir_path)
        print(f'{len(file_list)} files found')

        # self.load_model()
        self.embed_model = self._open_feat_model(_DEFAULT_MODEL_NAMES)
        self.threshold = self.ccip_default_threshold(_DEFAULT_MODEL_NAMES)
        self.f = open('charactor-featues-idx.csv', 'a', encoding='utf-8')

        ndarrs: List[np.ndarray] = []
        fpathes: List[str] = []
@@ -277,65 +289,84 @@ def process_directory(self, dir_path: str, added_date: datetime.date | None = None) -> None:
        cnt: int = 0
        failed_cnt: int = 0
        passed_idx: int = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=WORKER_NUM) as executor:
            # dispatch get Tensor task to processes
            future_to_path = {executor.submit(self.gen_image_ndarray, file_path): file_path for file_path in
                              file_list[0: BATCH_SIZE]}
            passed_idx += BATCH_SIZE
            while passed_idx < len(file_list):
                for future in concurrent.futures.as_completed(future_to_path):
                    path = future_to_path[future]
                    try:
                        ndarr = future.result()
                        if ndarr is None:
                            failed_cnt += 1
                            cnt -= 1
                            # continue

                        if ndarr is not None:
                            ndarrs.append(ndarr)
                            fpathes.append(path)

                        if len(ndarrs) >= BATCH_SIZE - failed_cnt:
                            # submit load Tensor tasks for next batch
                            end_idx = passed_idx + BATCH_SIZE
                            if end_idx > len(file_list):
                                end_idx = len(file_list)
                            future_to_path = {executor.submit(self.gen_image_ndarray, file_path): file_path for file_path
                                              in file_list[passed_idx: end_idx]}
                            passed_idx = end_idx

                            # run inference
                            # dimension of results: (batch_size, 768)
                            results: np.ndarray = self.predict(ndarrs)
                            # for idx, line in enumerate(results_in_csv_format):
                            #     self.write_to_file(fpathes[idx] + ',' + line)
                            # for arr in results:
                            #     print(arr.astype(float))
                            ndarrs = []
                            fpathes = []
                            failed_cnt = 0

                        cnt += 1

                        if cnt - last_cnt >= PROGRESS_INTERVAL:
                            now: float = time.perf_counter()
                            print(f'{cnt} files processed')
                            diff: float = now - start
                            print('{:.2f} seconds elapsed'.format(diff))
                            if cnt > 0:
                                time_per_file: float = diff / cnt
                                print('{:.4f} seconds per file'.format(time_per_file))
                            print("", flush=True)
                            last_cnt = cnt

                    except Exception as e:
                        error_class: type = type(e)
                        error_description: str = str(e)
                        err_msg: str = '%s: %s' % (error_class, error_description)
                        print(err_msg)
                        print_traceback()
                        continue
        future_to_vec: dict[concurrent.futures.Future[None], bool] = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor_vec_write:
            with concurrent.futures.ThreadPoolExecutor(max_workers=WORKER_NUM) as executor:
                # dispatch get Tensor task to processes
                future_to_path = {executor.submit(self.gen_image_ndarray, file_path): file_path for file_path in
                                  file_list[0: BATCH_SIZE]}
                passed_idx += BATCH_SIZE
                while passed_idx < len(file_list):
                    for future in concurrent.futures.as_completed(future_to_path):
                        path = future_to_path[future]
                        try:
                            ndarr = future.result()
                            if ndarr is None:
                                failed_cnt += 1
                                cnt -= 1
                                # continue

                            if ndarr is not None:
                                ndarrs.append(ndarr)
                                fpathes.append(path)

                            if len(ndarrs) >= BATCH_SIZE - failed_cnt:
                                # submit load Tensor tasks for next batch
                                end_idx = passed_idx + BATCH_SIZE
                                if end_idx > len(file_list):
                                    end_idx = len(file_list)
                                future_to_path = {executor.submit(self.gen_image_ndarray, file_path): file_path for file_path
                                                  in file_list[passed_idx: end_idx]}
                                passed_idx = end_idx

                                # run inference
                                # dimension of results: (batch_size, 768)
                                results: np.ndarray = self.predict(ndarrs)
                                for idx in range(0, len(results)):
                                    self.write_to_file(fpathes[idx])
                                # submit write to index tasks to another thread
                                future_to_vec[executor_vec_write.submit(self.write_vecs_to_index, results)] = True
                                # for idx, line in enumerate(results_in_csv_format):
                                #     self.write_to_file(fpathes[idx] + ',' + line)
                                # for arr in results:
                                #     print(arr.astype(float))
                                ndarrs = []
                                fpathes = []
                                failed_cnt = 0

                            cnt += 1

                            if cnt - last_cnt >= PROGRESS_INTERVAL:
                                now: float = time.perf_counter()
                                print(f'{cnt} files processed')
                                diff: float = now - start
                                print('{:.2f} seconds elapsed'.format(diff))
                                if cnt > 0:
                                    time_per_file: float = diff / cnt
                                    print('{:.4f} seconds per file'.format(time_per_file))
                                print("", flush=True)
                                last_cnt = cnt

                        except Exception as e:
                            error_class: type = type(e)
                            error_description: str = str(e)
                            err_msg: str = '%s: %s' % (error_class, error_description)
                            print(err_msg)
                            print_traceback()
                            continue

            # wait for all tasks to be finished
            for future in concurrent.futures.as_completed(future_to_vec):
                try:
                    future.result()
                except Exception as e:
                    error_class: type = type(e)
                    error_description: str = str(e)
                    err_msg: str = '%s: %s' % (error_class, error_description)
                    print(err_msg)
                    print_traceback()
                    continue
        self.cindex.save('charactor-featues-idx')

def main(arg_str: list[str]) -> None:
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
Binary file modified requirements_cu121.txt
Binary file not shown.
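For reference, a minimal query-side sketch (not part of this commit) of how the generated artifacts might be consumed, assuming the 'charactor-featues-idx' index prefix and the line-aligned 'charactor-featues-idx.csv' path list written above; the stand-in query vector and variable names are illustrative:

import numpy as np
from gensim.similarities import Similarity

# load the shard-backed index saved by Predictor.process_directory()
index = Similarity.load('charactor-featues-idx')

# paths were written in the same order the vectors were indexed,
# so CSV line i should correspond to index row i
with open('charactor-featues-idx.csv', encoding='utf-8') as f:
    paths = [line.rstrip('\n') for line in f]

# a real 768-dim embedding from Predictor.predict() would go here
query_vec = np.random.rand(768).astype(np.float32)
sims = index[query_vec]  # similarities against every indexed vector
best = int(np.argmax(sims))
print(f'{paths[best]}: {sims[best]:.4f}')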
