Skip to content

Commit

Permalink
More typing
Browse files Browse the repository at this point in the history
Signed-off-by: Luis Garcia <[email protected]>
  • Loading branch information
luigi311 committed Nov 13, 2024
1 parent 2e8261f commit c3a91e2
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 82 deletions.
12 changes: 12 additions & 0 deletions src/custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
UsersWatchedStatus = dict[str, str | tuple[str] | dict[str, bool | int]]
UsersWatched = (
dict[
frozenset[UsersWatchedStatus],
dict[
str,
list[UsersWatchedStatus]
| dict[frozenset[UsersWatchedStatus], list[UsersWatchedStatus]],
],
]
| list[UsersWatchedStatus]
)
46 changes: 25 additions & 21 deletions src/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable
from dotenv import load_dotenv

load_dotenv(override=True)
Expand All @@ -8,11 +9,11 @@
mark_file = os.getenv("MARK_FILE", os.getenv("MARKFILE", "mark.log"))


def logger(message: str, log_type=0):
def logger(message: str, log_type: int = 0):
debug = str_to_bool(os.getenv("DEBUG", "False"))
debug_level = os.getenv("DEBUG_LEVEL", "info").lower()

output = str(message)
output: str | None = str(message)
if log_type == 0:
pass
elif log_type == 1 and (debug and debug_level in ("info", "debug")):
Expand Down Expand Up @@ -42,12 +43,9 @@ def log_marked(
username: str,
library: str,
movie_show: str,
episode: str = None,
duration=None,
episode: str | None = None,
duration: float | None = None,
):
if mark_file is None:
return

output = f"{server_type}/{server_name}/{username}/{library}/{movie_show}"

if episode:
Expand All @@ -69,7 +67,7 @@ def str_to_bool(value: str) -> bool:


# Search for nested element in list
def contains_nested(element, lst):
def contains_nested(element: str, lst: list[tuple[str] | None] | tuple[str] | None):
if lst is None:
return None

Expand Down Expand Up @@ -116,33 +114,39 @@ def match_list(


def future_thread_executor(
args: list, threads: int = None, override_threads: bool = False
):
futures_list = []
results = []

workers = min(int(os.getenv("MAX_THREADS", 32)), os.cpu_count() * 2)
if threads:
args: list[tuple[Callable[..., Any], ...]],
threads: int | None = None,
override_threads: bool = False,
) -> list[Any]:
results: list[Any] = []

# Determine the number of workers, defaulting to 1 if os.cpu_count() returns None
max_threads_env: int = int(os.getenv("MAX_THREADS", 32))
cpu_threads: int = os.cpu_count() or 1 # Default to 1 if os.cpu_count() is None
workers: int = min(max_threads_env, cpu_threads * 2)

# Adjust workers based on threads parameter and override_threads flag
if threads is not None:
workers = min(threads, workers)

if override_threads:
workers = threads
workers = threads if threads is not None else workers

# If only one worker, run in main thread to avoid overhead
if workers == 1:
results = []
for arg in args:
results.append(arg[0](*arg[1:]))
return results

with ThreadPoolExecutor(max_workers=workers) as executor:
futures_list: list[Future[Any]] = []

for arg in args:
# * arg unpacks the list into actual arguments
futures_list.append(executor.submit(*arg))

for future in futures_list:
for out in futures_list:
try:
result = future.result()
result = out.result()
results.append(result)
except Exception as e:
raise Exception(e)
Expand Down
38 changes: 22 additions & 16 deletions src/jellyfin_emby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import traceback, os
from math import floor
from typing import Literal
from typing import Any, Literal
from dotenv import load_dotenv
import requests
from packaging.version import parse, Version
Expand All @@ -15,6 +15,7 @@
str_to_bool,
)
from src.library import generate_library_guids_dict
from src.custom_types import UsersWatched

load_dotenv(override=True)

Expand Down Expand Up @@ -101,10 +102,12 @@ def query(
query: str,
query_type: Literal["get", "post"],
identifiers: dict[str, str] | None = None,
json: dict | None = None,
) -> dict | list[dict]:
json: dict[str, float] | None = None,
) -> dict[str, Any] | list[dict[str, Any]] | None:
try:
results = None
results: (
dict[str, list[Any] | dict[str, str]] | list[dict[str, Any]] | None
) = None

if query_type == "get":
response = self.session.get(
Expand Down Expand Up @@ -140,7 +143,7 @@ def query(
raise Exception("Query result is not of type list or dict")

# append identifiers to results
if identifiers:
if identifiers and results:
results["Identifiers"] = identifiers

return results
Expand All @@ -158,13 +161,13 @@ def info(
try:
query_string = "/System/Info/Public"

response: dict = self.query(query_string, "get")
response: dict[str, Any] = self.query(query_string, "get")

if response:
if name_only:
return response.get("ServerName")
return response["ServerName"]
elif version_only:
return parse(response.get("Version"))
return parse(response["Version"])

return f"{self.server_type} {response.get('ServerName')}: {response.get('Version')}"
else:
Expand All @@ -176,15 +179,16 @@ def info(

def get_users(self) -> dict[str, str]:
try:
users = {}
users: dict[str, str] = {}

query_string = "/Users"
response = self.query(query_string, "get")
response: list[dict[str, str | bool]] = self.query(query_string, "get")

# If response is not empty
if response:
for user in response:
users[user["Name"]] = user["Id"]
if isinstance(user["Name"], str) and isinstance(user["Id"], str):
users[user["Name"]] = user["Id"]

return users
except Exception as e:
Expand Down Expand Up @@ -421,9 +425,11 @@ def get_user_library_watched(
logger(traceback.format_exc(), 2)
return {}

def get_watched(self, users, sync_libraries):
def get_watched(
self, users: dict[str, str], sync_libraries: list[str]
) -> UsersWatched:
try:
users_watched = {}
users_watched: UsersWatched = {}
watched = []

for user_name, user_id in users.items():
Expand All @@ -437,7 +443,7 @@ def get_watched(self, users, sync_libraries):
if library_title not in sync_libraries:
continue

identifiers = {
identifiers: dict[str, str] = {
"library_id": library_id,
"library_title": library_title,
}
Expand All @@ -454,8 +460,8 @@ def get_watched(self, users, sync_libraries):
if len(library["Items"]) == 0:
continue

library_id = library["Identifiers"]["library_id"]
library_title = library["Identifiers"]["library_title"]
library_id: str = library["Identifiers"]["library_id"]
library_title: str = library["Identifiers"]["library_title"]

# Get all library types excluding "Folder"
types = set(
Expand Down
Loading

0 comments on commit c3a91e2

Please sign in to comment.