Skip to content

Commit

Permalink
fix fromfileinterrogator
Browse files Browse the repository at this point in the history
  • Loading branch information
Roel Kluin committed Sep 19, 2023
1 parent 6d68c96 commit 486edca
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 26 deletions.
7 changes: 7 additions & 0 deletions default/interrogators.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@
"zip": "https://github.com/KichangKim/DeepDanbooru/releases/download/v4-20200814-sgd-e30/deepdanbooru-v4-20200814-sgd-e30.zip"
}
},
"FromFileInterrogator": {
"[name].[hash:sha1].[output_extension]": {
"format": "[name].[hash:sha1].[output_extension]",
"path": "",
"value": 1.0
}
},
"MLDanbooruInterrogator": {
"mld-caformer.dec-5-97527" : {
"model_path" : "ml_caformer_m36_dec-5-97527.onnx",
Expand Down
3 changes: 2 additions & 1 deletion json_schema/interrogators_v1_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
"type": "object",
"properties": {
"path": { "type": "string" },
"val": { "type": "number", "default": 1.0 }
"format": { "type": "string", "default": "[name].[output_extension]" },
"value": { "type": "number", "default": 1.0 }
},
"required": [
"path"
Expand Down
46 changes: 30 additions & 16 deletions tagger/interrogator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

from modules import shared
from tagger import settings # pylint: disable=import-error
from tagger.uiset import QData, IOData # pylint: disable=import-error
from tagger.uiset import QData, IOData, supported_extensions, \
format_output_filename
from preload import root_dir
from . import dbimutils # pylint: disable=import-error # noqa

Expand Down Expand Up @@ -487,33 +488,37 @@ def load_model(self, model_path) -> None:

class FromFileInterrogator(Interrogator):
""" Pseudo Interrogator reading preinterrogated tags files """
def __init__(self, name: str, path: os.PathLike, val=1.0) -> None:
def __init__(
self, name: str, path, format='[name].[output_extension]', value=1.0
) -> None:
super().__init__(name)
self.path = path
self.val = val
self.tags = None
self.path = Path(self.path)
self.val = value
self.format = format
if format == Its.output_filename_format and path == '':
raise ValueError(f"tagsfiles will ne overwritten with {format}")
self.tags = {}

def load(self) -> None:
print(f'Loading {self.name} from {str(self.path)}')
if self.path == '':
return

# self.path is a directory
if not os.path.isdir(self.path):
if not os.path.isdir(Path(self.path)):
raise ValueError(f'{self.path} is not a directory')
else:
self.tags = {}
for f in os.listdir(self.path):
self.tags[f] = {}
self.load_file(f)

def load_file(self, tags_file: str) -> None:
image_name = str(tags_file).split('/')[-1].split('.')[0]
basename = '.'.join(str(tags_file).split('/')[-1].split('.')[:-1])
self.tags[basename] = {}
with open(tags_file, 'r') as f:
for line in f:
for x in map(str.split, line.split(',')):
for x in map(str.strip, line.split(',')):
if x[0] == '(' and x[-1] == ')' and ':' in x:
tag, val = x[1:-1].split(':')
self.tags[image_name][tag] = float(val)
self.tags[basename][tag] = float(val)
else:
self.tags[image_name][x] = self.val
self.tags[basename][x] = self.val

def unload(self) -> None:
self.tags = {}
Expand All @@ -525,7 +530,16 @@ def interrogate(
Dict[str, float], # rating confidences
Dict[str, float] # tag confidences
]:
return {}, self.tags[image.filename]
basename = '.'.join(image.filename.split('/')[-1].split('.')[:-1])
path = Path(image.filename)
tags_filename = format_output_filename(path, self.format)
if self.path == '':
dir = path.parent
else:
dir = self.path
tags_filename = os.path.join(dir, tags_filename)
self.load_file(tags_filename)
return {}, self.tags[basename]


class WaifuDiffusionInterrogator(HFInterrogator):
Expand Down
18 changes: 9 additions & 9 deletions tagger/uiset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
]


def format_output_filename(path: Path, format='[name].[output_extension]') -> str:
info = tags_format.Info(path, 'txt')
fmt = partial(lambda info, m: tags_format.parse(m, info), info)
return tags_format.pattern.sub(fmt, format)


class IOData:
""" data class for input and output paths """
last_path_mtimes = None
Expand Down Expand Up @@ -119,7 +125,7 @@ def update_input_glob(cls, input_glob: str) -> None:
path_mtimes = []
for filename in glob(input_glob, recursive=recursive):
if not os.path.isdir(filename):
ext = os.path.splitext(filename)[1].lower()
ext = os.path.splitext(filename)[-1].lower()
if ext in supported_extensions:
path_mtimes.append(os.path.getmtime(filename))
paths.append(filename)
Expand Down Expand Up @@ -163,16 +169,11 @@ def set_batch_io(cls, paths: List[str]) -> None:
base_dir_last_idx = path.parts.index(cls.base_dir_last)
# format output filename

info = tags_format.Info(path, 'txt')
fmt = partial(lambda info, m: tags_format.parse(m, info), info)

msg = 'Invalid output format'
cls.err.discard(msg)
try:
formatted_output_filename = tags_format.pattern.sub(
fmt,
Its.output_filename_format
)
formatted_output_filename = format_output_filename(
path, format=Its.output_filename_format)
except (TypeError, ValueError):
cls.err.add(msg)

Expand Down Expand Up @@ -483,7 +484,6 @@ def correct_tag(cls, tag: str) -> str:
if re_match(regex, tag):
tag = re_sub(regex, cls.replace_tags[i], tag)
break

return tag

@classmethod
Expand Down

0 comments on commit 486edca

Please sign in to comment.