Skip to content

Commit

Permalink
Refactor mode/config system
Browse files Browse the repository at this point in the history
  • Loading branch information
zwimer committed Oct 12, 2024
1 parent 170214a commit 56b8cd4
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 387 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py312-plus]
args: [--py310-plus]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
hooks:
Expand Down
58 changes: 16 additions & 42 deletions rpipe/client/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass
from datetime import datetime
from collections import deque
from functools import partial
from logging import getLogger
from json import loads, dumps
from base64 import b85encode
from functools import cache
from pathlib import Path
import zlib

Expand All @@ -15,7 +15,7 @@
from requests import Session

from ..shared import QueryResponse, AdminMessage, AdminEC, version
from .config import ConfigFile, UsageError, Option
from .client import Config, UsageError

if TYPE_CHECKING:
from collections.abc import Callable
Expand All @@ -38,12 +38,6 @@ class AccessDenied(RuntimeError):
"""


class IllegalVersion(UsageError):
"""
Raised when the server is running a version that is not supported
"""


#
# Helper Classes
#
Expand Down Expand Up @@ -113,7 +107,7 @@ def _request(self, path: str, body: str = "") -> Response:
self._log.critical("Admin access denied")
raise AccessDenied()
case AdminEC.illegal_version:
raise IllegalVersion(ret.text, log=self._log.critical)
raise UsageError(ret.text)
if not ret.ok:
what = f"Error {ret.status_code}: {ret.text}"
self._log.critical(what)
Expand Down Expand Up @@ -178,55 +172,35 @@ class Admin:
A class used to ask the server to run admin functions
"""

def __init__(self):
def __init__(self, conf: Config):
self._log = getLogger(_LOG)
self._methods = _Methods()
if not conf.url or not conf.key_file:
raise UsageError("Admin mode requires a URL and key-file to be set")
self._conf = Conf(sign=self._load_ssh_key_file(conf.key_file), url=conf.url, session=Session())

def _load_ssh_key_file(self, key_file: Path) -> Callable[[bytes], bytes]:
"""
Load a private key from a file
:return: A function that can sign data using the key file
"""
self._log.info("Extracting private key from %s", key_file)
if not key_file.exists():
raise UsageError(f"Key file {key_file} does not exist", log=self._log.critical)
raise UsageError(f"Key file {key_file} does not exist")
try:
key = load_ssh_private_key(key_file.read_bytes(), None)
except UnsupportedAlgorithm as e:
raise UsageError(f"Key file {key_file} is not a supported ssh key", log=self._log.critical) from e
raise UsageError(f"Key file {key_file} is not a supported ssh key") from e
if not hasattr(key, "sign"):
raise UsageError(f"Key file {key_file} does not support signing", log=self._log.critical)
raise UsageError(f"Key file {key_file} does not support signing")
if TYPE_CHECKING:
return cast(Callable[[bytes], bytes], key.sign)
return key.sign

@cache # pylint: disable=method-cache-max-size-none
def _get_conf(self, raw_url: str | None, raw_key_file: Path | None):
"""
Extract a Conf from the given arguments and existing config file
Saves the config internally
"""
self._log.info("Determining Conf; defaults: url=%s, key_file=%s", raw_url, raw_key_file)
# Load data from config file
try:
path = ConfigFile().path
self._log.info("Querying config file %s if it exists", path)
raw = loads(path.read_text(encoding="utf-8")) if path.exists() else {}
key_file = Path(Option(raw_key_file).opt(raw.get("key_file", None)).value)
url: str = Option(raw_url).opt(raw.get("url", None)).value
except Exception as e:
msg = "Admin mode requires a URL and key file to be set or provided via the CLI"
raise UsageError(msg, log=self._log.critical) from e
self._log.info("Found key file: %s, extracting private key", key_file)
# Load ssh key
return Conf(sign=self._load_ssh_key_file(key_file), url=url, session=Session())

def _give_conf(self, func):
def wrapper(*args, **kwargs):
conf = self._get_conf(kwargs.pop("url"), kwargs.pop("key_file"))
return func(*args, conf=conf, **kwargs)

return wrapper

def __getattribute__(self, item: str) -> Any:
"""
Override the getattribute method to expose all methods of Methods
"""
if item.startswith("_"):
return object.__getattribute__(self, item)
return self._give_conf(self._methods.get(item, require_ssl=item != "debug"))
return partial(self._methods.get(item, require_ssl=item != "debug"), conf=self._conf)
4 changes: 3 additions & 1 deletion rpipe/client/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .client import Mode, rpipe
from .errors import UsageError
from .data import Config, Mode
from .client import rpipe
147 changes: 42 additions & 105 deletions rpipe/client/client/client.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,42 @@
from dataclasses import dataclass
from __future__ import annotations
from typing import TYPE_CHECKING
from logging import getLogger
from json import dumps

from zstandard import ZstdCompressor
from human_readable import listing


from ...shared import TRACE, QueryEC, Version, version
from ..config import ConfigFile, Option, PartialConfig
from .util import REQUEST_TIMEOUT, request
from .errors import UsageError, VersionError
from .data import Config, Mode
from .delete import delete
from .recv import recv
from .send import send

if TYPE_CHECKING:
from pathlib import Path


_LOG: str = "client"
_DEFAULT_LVL: int = 3


# pylint: disable=too-many-instance-attributes
@dataclass(kw_only=True, frozen=True)
class Mode:
"""
Arguments used to decide how rpipe should operate
"""

# Priority modes (in order)
print_config: bool
save_config: bool
outdated: bool
server_version: bool
query: bool
# Read/Write/Delete modes
read: bool
delete: bool
write: bool
# Read options
block: bool
peek: bool
force: bool
# Write options
ttl: int | None
zstd: int | None
threads: int
# Read / Write options
encrypt: Option[bool]
progress: bool | int


def _n_priority(mode: Mode) -> int:
return (mode.print_config, mode.save_config, mode.outdated, mode.server_version, mode.query).count(True)


def _check_mode_flags(mode: Mode) -> None:
def tru(x) -> bool:
rv = getattr(mode, x)
return (False if rv.is_none() else rv.value) if isinstance(rv, Option) else bool(rv)

# Flag specific checks
if mode.ttl is not None and mode.ttl <= 0:
raise UsageError("--ttl must be positive")
if mode.progress is not False and mode.progress <= 0:
raise UsageError("--progress argument must be positive if passed")
# Sanity check
if (n_pri := _n_priority(mode)) > 1:
raise UsageError("Only one priority mode may be used at a time")
if (mode.read, mode.write, mode.delete).count(True) != 1:
raise UsageError("Can only read, write, or delete at a time")
# Mode flags
read_bad = {"ttl"}
write_bad = {"block", "peek", "force"}
delete_bad = read_bad | write_bad | {"progress", "encrypt"}
bad = lambda x: [f"--{i}" for i in x if tru(i)]
fmt = lambda x: f"argument{'' if len(x) == 1 else 's'} {listing(x, ',', 'and') }: may not be used "
if n_pri > 0 and (args := bad(delete_bad)):
raise UsageError(fmt(args) + "with priority modes")
# Mode specific flags
if mode.read and (args := bad(read_bad)):
raise UsageError(fmt(args) + "when reading data from the pipe")
if mode.write and (args := bad(write_bad)):
raise UsageError(fmt(args) + "when writing data to the pipe")
if mode.delete and (args := bad(delete_bad)):
raise UsageError(fmt(args) + "when deleting data from the pipe")
def _print_config(conf: Config, config_file: Path) -> None:
log = getLogger(_LOG)
log.info("Mode: print-config")
print(f"Path: {config_file}")
print(conf)
try:
conf.validate()
log.info("Config validated")
except UsageError as e:
log.warning("Config invalid %s", e)


def _check_outdated(conf: PartialConfig) -> None:
def _check_outdated(conf: Config) -> None:
log = getLogger(_LOG)
log.info("Mode: Outdated")
r = request("GET", f"{conf.url.value}/supported")
r = request("GET", f"{conf.url}/supported")
if not r.ok:
raise RuntimeError(f"Failed to get server minimum version: {r}")
info = r.json()
Expand All @@ -97,86 +45,75 @@ def _check_outdated(conf: PartialConfig) -> None:
print(f"{'' if ok else 'NOT '}SUPPORTED")


def _query(conf: PartialConfig) -> None:
def _query(conf: Config) -> None:
log = getLogger(_LOG)
log.info("Mode: Query")
if conf.channel is None:
if not conf.channel:
raise UsageError("Channel unknown; try again with --channel")
log.info("Querying channel %s ...", conf.channel)
r = request("GET", f"{conf.url.value}/q/{conf.channel.value}")
r = request("GET", f"{conf.url}/q/{conf.channel}")
log.debug("Got response %s", r)
log.log(TRACE, "Data: %s", r.content)
match r.status_code:
case QueryEC.illegal_version:
raise VersionError(r.text)
case QueryEC.no_data:
print(f"No data on channel: {conf.channel.value}")
print(f"No data on channel: {conf.channel}")
return
if not r.ok:
raise RuntimeError(f"Query failed. Error {r.status_code}: {r.text}")
print(f"{conf.channel.value}: {dumps(r.json(), indent=4)}")
print(f"{conf.channel}: {dumps(r.json(), indent=4)}")


def _priority_actions(conf: PartialConfig, mode: Mode, config_file) -> bool:
if not (np := _n_priority(mode)):
return False
assert np == 1, "Sanity check on priority mode count failed"
def _priority_actions(conf: Config, mode: Mode, config_file: Path) -> None:
log = getLogger(_LOG)
if mode.print_config:
_print_config(conf, config_file)
return
if mode.save_config:
log.info("Mode: Save Config")
config_file.save(conf, mode.encrypt.is_true())
return True
conf.save(config_file)
return
# Everything after this requires the URL
if conf.url is None:
raise UsageError("Missing: --url")
# Check if supported
raise UsageError("Missing: URL")
# Remaining priority modes
if mode.outdated:
_check_outdated(conf)
# Print server version if requested
if mode.server_version:
log.info("Mode: Server Version")
r = request("GET", f"{conf.url.value}/version")
r = request("GET", f"{conf.url}/version")
if not r.ok:
raise RuntimeError(f"Failed to get version: {r}")
print(f"rpipe_server {r.text}")
if mode.query:
_query(conf)
return True


def rpipe(conf: PartialConfig, mode: Mode) -> None:
def rpipe(conf: Config, mode: Mode, config_file: Path) -> None:
"""
rpipe: A remote piping tool
Assumes no UsageError's in mode that argparse would catch
:param conf: Configuration for the remote pipe, may be invalid at this point
:param config_file: Path to the configuration file, may not exist
:param mode: Mode to operate in, assumes flags are valid and within expected ranges
"""
_check_mode_flags(mode)
log = getLogger(_LOG)
config_file = ConfigFile()
log.info("Config file: %s", config_file.path)
# Print config if requested, else load it
if mode.print_config:
log.info("Mode: Print Config")
config_file.print()
return
conf = config_file.load_onto(conf, mode.encrypt.is_false())
# Remaining priority actions + finish creating config
if _priority_actions(conf, mode, config_file):
if mode.priority():
_priority_actions(conf, mode, config_file)
return
if mode.write and not mode.encrypt.is_none():
log.info("Write mode: No password found, falling back to plaintext mode")
full_conf = config_file.verify(conf, mode.encrypt.is_true())
conf.validate()
if (mode.read or mode.write) and not conf.password:
log.warning("Encryption disabled: plaintext mode")
if mode.zstd is not None:
raise UsageError("Cannot compress data in plaintext mode")
# Invoke mode
log.info("HTTP timeout set to %d seconds", REQUEST_TIMEOUT)
if mode.read:
recv(full_conf, mode.block, mode.peek, mode.force, mode.progress)
recv(conf, mode.block, mode.peek, mode.force, mode.progress)
elif mode.write:
lvl = _DEFAULT_LVL if mode.zstd is None else mode.zstd
log.debug("Using compression level %d and %d threads", lvl, mode.threads)
compress = ZstdCompressor(write_checksum=True, level=lvl, threads=mode.threads).compress
send(full_conf, mode.ttl, mode.progress, compress)
send(conf, mode.ttl, mode.progress, compress)
else:
delete(full_conf)
delete(conf)
Loading

0 comments on commit 56b8cd4

Please sign in to comment.