diff --git a/src/custom_types.py b/src/custom_types.py new file mode 100644 index 0000000..c44799c --- /dev/null +++ b/src/custom_types.py @@ -0,0 +1,12 @@ +UsersWatchedStatus = dict[str, str | tuple[str] | dict[str, bool | int]] +UsersWatched = ( + dict[ + frozenset[UsersWatchedStatus], + dict[ + str, + list[UsersWatchedStatus] + | dict[frozenset[UsersWatchedStatus], list[UsersWatchedStatus]], + ], + ] + | list[UsersWatchedStatus] +) diff --git a/src/functions.py b/src/functions.py index 1e76a26..8b1f69b 100644 --- a/src/functions.py +++ b/src/functions.py @@ -1,5 +1,6 @@ import os -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Any, Callable from dotenv import load_dotenv load_dotenv(override=True) @@ -8,11 +9,11 @@ mark_file = os.getenv("MARK_FILE", os.getenv("MARKFILE", "mark.log")) -def logger(message: str, log_type=0): +def logger(message: str, log_type: int = 0): debug = str_to_bool(os.getenv("DEBUG", "False")) debug_level = os.getenv("DEBUG_LEVEL", "info").lower() - output = str(message) + output: str | None = str(message) if log_type == 0: pass elif log_type == 1 and (debug and debug_level in ("info", "debug")): @@ -42,12 +43,9 @@ def log_marked( username: str, library: str, movie_show: str, - episode: str = None, - duration=None, + episode: str | None = None, + duration: float | None = None, ): - if mark_file is None: - return - output = f"{server_type}/{server_name}/{username}/{library}/{movie_show}" if episode: @@ -69,7 +67,7 @@ def str_to_bool(value: str) -> bool: # Search for nested element in list -def contains_nested(element, lst): +def contains_nested(element: str, lst: list[tuple[str] | None] | tuple[str] | None): if lst is None: return None @@ -116,33 +114,39 @@ def match_list( def future_thread_executor( - args: list, threads: int = None, override_threads: bool = False -): - futures_list = [] - results = [] - - workers = min(int(os.getenv("MAX_THREADS", 32)), os.cpu_count() * 2) - if threads: + args: list[tuple[Callable[..., Any], ...]], + threads: int | None = None, + override_threads: bool = False, +) -> list[Any]: + results: list[Any] = [] + + # Determine the number of workers, defaulting to 1 if os.cpu_count() returns None + max_threads_env: int = int(os.getenv("MAX_THREADS", 32)) + cpu_threads: int = os.cpu_count() or 1 # Default to 1 if os.cpu_count() is None + workers: int = min(max_threads_env, cpu_threads * 2) + + # Adjust workers based on threads parameter and override_threads flag + if threads is not None: workers = min(threads, workers) - if override_threads: - workers = threads + workers = threads if threads is not None else workers # If only one worker, run in main thread to avoid overhead if workers == 1: - results = [] for arg in args: results.append(arg[0](*arg[1:])) return results with ThreadPoolExecutor(max_workers=workers) as executor: + futures_list: list[Future[Any]] = [] + for arg in args: # * arg unpacks the list into actual arguments futures_list.append(executor.submit(*arg)) - for future in futures_list: + for out in futures_list: try: - result = future.result() + result = out.result() results.append(result) except Exception as e: raise Exception(e) diff --git a/src/jellyfin_emby.py b/src/jellyfin_emby.py index d8f4387..f8afa71 100644 --- a/src/jellyfin_emby.py +++ b/src/jellyfin_emby.py @@ -2,7 +2,7 @@ import traceback, os from math import floor -from typing import Literal +from typing import Any, Literal from dotenv import load_dotenv import requests from packaging.version import parse, Version @@ -15,6 +15,7 @@ str_to_bool, ) from src.library import generate_library_guids_dict +from src.custom_types import UsersWatched load_dotenv(override=True) @@ -101,10 +102,12 @@ def query( query: str, query_type: Literal["get", "post"], identifiers: dict[str, str] | None = None, - json: dict | None = None, - ) -> dict | list[dict]: + json: dict[str, float] | None = None, + ) -> dict[str, Any] | list[dict[str, Any]] | None: try: - results = None + results: ( + dict[str, list[Any] | dict[str, str]] | list[dict[str, Any]] | None + ) = None if query_type == "get": response = self.session.get( @@ -140,7 +143,7 @@ def query( raise Exception("Query result is not of type list or dict") # append identifiers to results - if identifiers: + if identifiers and results: results["Identifiers"] = identifiers return results @@ -158,13 +161,13 @@ def info( try: query_string = "/System/Info/Public" - response: dict = self.query(query_string, "get") + response: dict[str, Any] = self.query(query_string, "get") if response: if name_only: - return response.get("ServerName") + return response["ServerName"] elif version_only: - return parse(response.get("Version")) + return parse(response["Version"]) return f"{self.server_type} {response.get('ServerName')}: {response.get('Version')}" else: @@ -176,15 +179,16 @@ def info( def get_users(self) -> dict[str, str]: try: - users = {} + users: dict[str, str] = {} query_string = "/Users" - response = self.query(query_string, "get") + response: list[dict[str, str | bool]] = self.query(query_string, "get") # If response is not empty if response: for user in response: - users[user["Name"]] = user["Id"] + if isinstance(user["Name"], str) and isinstance(user["Id"], str): + users[user["Name"]] = user["Id"] return users except Exception as e: @@ -421,9 +425,11 @@ def get_user_library_watched( logger(traceback.format_exc(), 2) return {} - def get_watched(self, users, sync_libraries): + def get_watched( + self, users: dict[str, str], sync_libraries: list[str] + ) -> UsersWatched: try: - users_watched = {} + users_watched: UsersWatched = {} watched = [] for user_name, user_id in users.items(): @@ -437,7 +443,7 @@ def get_watched(self, users, sync_libraries): if library_title not in sync_libraries: continue - identifiers = { + identifiers: dict[str, str] = { "library_id": library_id, "library_title": library_title, } @@ -454,8 +460,8 @@ def get_watched(self, users, sync_libraries): if len(library["Items"]) == 0: continue - library_id = library["Identifiers"]["library_id"] - library_title = library["Identifiers"]["library_title"] + library_id: str = library["Identifiers"]["library_id"] + library_title: str = library["Identifiers"]["library_title"] # Get all library types excluding "Folder" types = set( diff --git a/src/library.py b/src/library.py index 1f26d56..a4f9464 100644 --- a/src/library.py +++ b/src/library.py @@ -3,17 +3,18 @@ match_list, search_mapping, ) +from src.custom_types import UsersWatched def check_skip_logic( - library_title, - library_type, - blacklist_library, - whitelist_library, - blacklist_library_type, - whitelist_library_type, - library_mapping=None, -): + library_title: str, + library_type: str, + blacklist_library: list[str], + whitelist_library: list[str], + blacklist_library_type: list[str], + whitelist_library_type: list[str], + library_mapping: dict[str, str] | None = None, +) -> str | None: skip_reason = None library_other = None if library_mapping: @@ -48,11 +49,11 @@ def check_skip_logic( def check_blacklist_logic( - library_title, - library_type, - blacklist_library, - blacklist_library_type, - library_other=None, + library_title: str, + library_type: str, + blacklist_library: list[str], + blacklist_library_type: list[str], + library_other: str | None = None, ): skip_reason = None if isinstance(library_type, (list, tuple, set)): @@ -84,11 +85,11 @@ def check_blacklist_logic( def check_whitelist_logic( - library_title, - library_type, - whitelist_library, - whitelist_library_type, - library_other=None, + library_title: str, + library_type: str, + whitelist_library: list[str], + whitelist_library_type: list[str], + library_other: str | None = None, ): skip_reason = None if len(whitelist_library_type) > 0: @@ -131,14 +132,14 @@ def check_whitelist_logic( def filter_libaries( - server_libraries, - blacklist_library, - blacklist_library_type, - whitelist_library, - whitelist_library_type, - library_mapping=None, -): - filtered_libaries = [] + server_libraries: dict[str, str], + blacklist_library: list[str], + blacklist_library_type: list[str], + whitelist_library: list[str], + whitelist_library_type: list[str], + library_mapping: dict[str, str] | None = None, +) -> list[str]: + filtered_libaries: list[str] = [] for library in server_libraries: skip_reason = check_skip_logic( library, @@ -162,12 +163,12 @@ def filter_libaries( def setup_libraries( server_1, server_2, - blacklist_library, - blacklist_library_type, - whitelist_library, - whitelist_library_type, - library_mapping=None, -): + blacklist_library: list[str], + blacklist_library_type: list[str], + whitelist_library: list[str], + whitelist_library_type: list[str], + library_mapping: dict[str, str] | None = None, +) -> tuple[list[str], list[str]]: server_1_libraries = server_1.get_libraries() server_2_libraries = server_2.get_libraries() logger(f"Server 1 libraries: {server_1_libraries}", 1) @@ -201,14 +202,16 @@ def setup_libraries( return output_server_1_libaries, output_server_2_libaries -def show_title_dict(user_list: dict): +def show_title_dict(user_list: UsersWatched) -> dict[str, list[tuple[str] | None]]: try: - show_output_dict = {} + if not isinstance(user_list, dict): + return {} + + show_output_dict: dict[str, list[tuple[str] | None]] = {} show_output_dict["locations"] = [] show_counter = 0 # Initialize a counter for the current show position - show_output_keys = user_list.keys() - show_output_keys = [dict(x) for x in list(show_output_keys)] + show_output_keys = [dict(x) for x in list(user_list.keys())] for show_key in show_output_keys: for provider_key, provider_value in show_key.items(): # Skip title @@ -233,9 +236,19 @@ def show_title_dict(user_list: dict): return {} -def episode_title_dict(user_list: dict): +def episode_title_dict( + user_list: UsersWatched, +) -> dict[ + str, list[str | bool | int | tuple[str] | dict[str, str | tuple[str]] | None] +]: try: - episode_output_dict = {} + if not isinstance(user_list, dict): + return {} + + episode_output_dict: dict[ + str, + list[str | bool | int | tuple[str] | dict[str, str | tuple[str]] | None], + ] = {} episode_output_dict["completed"] = [] episode_output_dict["time"] = [] episode_output_dict["locations"] = [] @@ -293,12 +306,18 @@ def episode_title_dict(user_list: dict): return {} -def movies_title_dict(user_list: dict): +def movies_title_dict( + user_list: UsersWatched, +) -> dict[str, list[str | bool | int | tuple[str] | None]]: try: - movies_output_dict = {} - movies_output_dict["completed"] = [] - movies_output_dict["time"] = [] - movies_output_dict["locations"] = [] + if not isinstance(user_list, list): + return {} + + movies_output_dict: dict[str, list[str | bool | int | tuple[str] | None]] = { + "completed": [], + "time": [], + "locations": [], + } movie_counter = 0 # Initialize a counter for the current movie position for movie in user_list: @@ -325,7 +344,13 @@ def movies_title_dict(user_list: dict): return {} -def generate_library_guids_dict(user_list: dict): +def generate_library_guids_dict( + user_list: UsersWatched, +) -> tuple[ + dict[str, list[tuple[str] | None]], + dict[str, list[str | bool | int | tuple[str] | dict[str, str | tuple[str]] | None]], + dict[str, list[str | bool | int | tuple[str] | None]], +]: # Handle the case where user_list is empty or does not contain the expected keys and values if not user_list: return {}, {}, {} diff --git a/src/plex.py b/src/plex.py index 116fcd7..82039c5 100644 --- a/src/plex.py +++ b/src/plex.py @@ -488,7 +488,7 @@ def get_users(self): logger(f"Plex: Failed to get users, Error: {e}", 2) raise Exception(e) - def get_libraries(self): + def get_libraries(self) -> dict[str, str]: try: output = {}