diff --git a/README.md b/README.md index 503efdb7..160588e0 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,39 @@ with fs.open(p.path) as f: data = f.read() ``` +### Register custom UPath implementations + +In case you develop a custom UPath implementation, feel free to open an issue to discuss integrating it +in `universal_pathlib`. You can dynamically register your implementation too! Here are your options: + +#### Dynamic registration from Python + +```python +# for example: mymodule/submodule.py +from upath import UPath +from upath.registry import register_implementation + +my_protocol = "myproto" +class MyPath(UPath): + ... # your custom implementation + +register_implementation(my_protocol, MyPath) +``` + +#### Registration via entry points + +```toml +# pyproject.toml +[project.entry-points."unversal_pathlib.implementations"] +myproto = "my_module.submodule:MyPath" +``` + +```ini +# setup.cfg +[options.entry_points] +universal_pathlib.implementations = + myproto = my_module.submodule:MyPath +``` ## Contributing diff --git a/upath/registry.py b/upath/registry.py index b30fd6d1..43bf3ccc 100644 --- a/upath/registry.py +++ b/upath/registry.py @@ -1,22 +1,63 @@ +"""upath.registry -- registry for file system specific implementations + +Retrieve UPath implementations via `get_upath_class`. +Register custom UPath subclasses in one of two ways: + +### directly from Python + +>>> from upath import UPath +>>> from upath.registry import register_implementation +>>> my_protocol = "myproto" +>>> class MyPath(UPath): +... pass +>>> register_implementation(my_protocol, MyPath) + +### via entry points + +```toml +# pyproject.toml +[project.entry-points."unversal_pathlib.implementations"] +myproto = "my_module.submodule:MyPath" +``` + +```ini +# setup.cfg +[options.entry_points] +universal_pathlib.implementations = + myproto = my_module.submodule:MyPath +``` +""" from __future__ import annotations -import importlib import os +import re +import sys import warnings +from collections import ChainMap from functools import lru_cache -from typing import TYPE_CHECKING +from importlib import import_module +from importlib.metadata import entry_points +from typing import Iterator +from typing import MutableMapping from fsspec.core import get_filesystem_class +from fsspec.registry import available_protocols -if TYPE_CHECKING: - from upath.core import UPath +import upath.core __all__ = [ "get_upath_class", + "available_implementations", + "register_implementation", ] -class _Registry: +_ENTRY_POINT_GROUP = "universal_pathlib.implementations" + + +class _Registry(MutableMapping[str, "type[upath.core.UPath]"]): + """internal registry for UPath subclasses""" + known_implementations: dict[str, str] = { "abfs": "upath.implementations.cloud.AzurePath", "abfss": "upath.implementations.cloud.AzurePath", @@ -35,26 +76,118 @@ class _Registry: "webdav+https": "upath.implementations.webdav.WebdavPath", } - def __getitem__(self, item: str) -> type[UPath] | None: - try: - fqn = self.known_implementations[item] - except KeyError: - return None - module_name, name = fqn.rsplit(".", 1) - mod = importlib.import_module(module_name) - return getattr(mod, name) # type: ignore + def __init__(self) -> None: + if sys.version_info >= (3, 10): + eps = entry_points(group=_ENTRY_POINT_GROUP) + else: + eps = entry_points().get(_ENTRY_POINT_GROUP, []) + self._entries = {ep.name: ep for ep in eps} + self._m = ChainMap({}, self.known_implementations) # type: ignore + + def __contains__(self, item: object) -> bool: + return item in set().union(self._m, self._entries) + + def __getitem__(self, item: str) -> type[upath.core.UPath]: + fqn = self._m.get(item) + if fqn is None: + if item in self._entries: + fqn = self._m[item] = self._entries[item].load() + if fqn is None: + raise KeyError(f"{item} not in registry") + if isinstance(fqn, str): + module_name, name = fqn.rsplit(".", 1) + mod = import_module(module_name) + cls = getattr(mod, name) # type: ignore + else: + cls = fqn + return cls + + def __setitem__(self, item: str, value: type[upath.core.UPath] | str) -> None: + if not ( + (isinstance(value, type) and issubclass(value, upath.core.UPath)) + or isinstance(value, str) + ): + raise ValueError( + f"expected UPath subclass or FQN-string, got: {type(value).__name__!r}" + ) + self._m[item] = value + + def __delitem__(self, __v: str) -> None: + raise NotImplementedError("removal is unsupported") + + def __len__(self) -> int: + return len(set().union(self._m, self._entries)) + + def __iter__(self) -> Iterator[str]: + return iter(set().union(self._m, self._entries)) _registry = _Registry() -@lru_cache -def get_upath_class(protocol: str) -> type[UPath] | None: - """Return the upath cls for the given protocol.""" - cls: type[UPath] | None = _registry[protocol] - if cls is not None: - return cls +def available_implementations(*, fallback: bool = False) -> list[str]: + """return a list of protocols for available implementations + + Parameters + ---------- + fallback: + If True, also return protocols for fsspec filesystems without + an implementation in universal_pathlib. + """ + impl = list(_registry) + if not fallback: + return impl else: + return list({*impl, *available_protocols()}) + + +def register_implementation( + protocol: str, + cls: type[upath.core.UPath] | str, + *, + clobber: bool = False, +) -> None: + """register a UPath implementation with a protocol + + Parameters + ---------- + protocol: + Protocol name to associate with the class + cls: + The UPath subclass for the protocol or a str representing the + full path to an implementation class like package.module.class. + clobber: + Whether to overwrite a protocol with the same name; if False, + will raise instead. + """ + if not re.match(r"^[a-z][a-z0-9+_.]+$", protocol): + raise ValueError(f"{protocol!r} is not a valid URI scheme") + if not clobber and protocol in _registry: + raise ValueError(f"{protocol!r} is already in registry and clobber is False!") + _registry[protocol] = cls + + +@lru_cache +def get_upath_class( + protocol: str, + *, + fallback: bool = True, +) -> type[upath.core.UPath] | None: + """Return the upath cls for the given protocol. + + Returns `None` if no matching protocol can be found. + + Parameters + ---------- + protocol: + The protocol string + fallback: + If fallback is False, don't return UPath instances for fsspec + filesystems that don't have an implementation registered. + """ + try: + return _registry[protocol] + except KeyError: if not protocol: if os.name == "nt": from upath.implementations.local import WindowsUPath @@ -64,6 +197,8 @@ def get_upath_class(protocol: str) -> type[UPath] | None: from upath.implementations.local import PosixUPath return PosixUPath + if not fallback: + return None try: _ = get_filesystem_class(protocol) except ValueError: @@ -76,5 +211,4 @@ def get_upath_class(protocol: str) -> type[UPath] | None: UserWarning, stacklevel=2, ) - mod = importlib.import_module("upath.core") - return mod.UPath # type: ignore + return upath.core.UPath diff --git a/upath/tests/test_registry.py b/upath/tests/test_registry.py new file mode 100644 index 00000000..148f0324 --- /dev/null +++ b/upath/tests/test_registry.py @@ -0,0 +1,126 @@ +import pytest +from fsspec.registry import available_protocols + +from upath import UPath +from upath.registry import available_implementations +from upath.registry import get_upath_class +from upath.registry import register_implementation + +IMPLEMENTATIONS = { + "abfs", + "abfss", + "adl", + "az", + "file", + "gcs", + "gs", + "hdfs", + "http", + "https", + "memory", + "s3", + "s3a", + "webdav+http", + "webdav+https", +} + + +@pytest.fixture(autouse=True) +def reset_registry(): + from upath.registry import _registry + + try: + yield + finally: + _registry._m.maps[0].clear() # type: ignore + + +@pytest.fixture() +def fake_entrypoint(): + from importlib.metadata import EntryPoint + + from upath.registry import _registry + + ep = EntryPoint( + name="myeps", + value="upath.core:UPath", + group="universal_pathlib.implementations", + ) + old_registry = _registry._entries.copy() + + try: + _registry._entries["myeps"] = ep + yield + finally: + _registry._entries.clear() + _registry._entries.update(old_registry) + + +def test_available_implementations(): + impl = available_implementations() + assert len(impl) == len(set(impl)) + assert set(impl) == IMPLEMENTATIONS + + +def test_available_implementations_with_fallback(): + impl = available_implementations(fallback=True) + assert set(impl) == IMPLEMENTATIONS.union(available_protocols()) + + +def test_available_implementations_with_entrypoint(fake_entrypoint): + impl = available_implementations() + assert set(impl) == IMPLEMENTATIONS.union({"myeps"}) + + +def test_register_implementation(): + class MyProtoPath(UPath): + pass + + register_implementation("myproto", MyProtoPath) + + assert get_upath_class("myproto") is MyProtoPath + + +def test_register_implementation_wrong_input(): + with pytest.raises(TypeError): + register_implementation(None, UPath) # type: ignore + with pytest.raises(ValueError): + register_implementation("incorrect**protocol", UPath) + with pytest.raises(ValueError): + register_implementation("myproto", object, clobber=True) # type: ignore + with pytest.raises(ValueError): + register_implementation("file", UPath, clobber=False) + assert set(available_implementations()) == IMPLEMENTATIONS + + +@pytest.mark.parametrize("protocol", IMPLEMENTATIONS) +def test_get_upath_class(protocol): + upath_cls = get_upath_class("file") + assert issubclass(upath_cls, UPath) + + +def test_get_upath_class_without_implementation(clear_registry): + with pytest.warns( + UserWarning, match="UPath 'mock' filesystem not explicitly implemented." + ): + upath_cls = get_upath_class("mock") + assert issubclass(upath_cls, UPath) + + +def test_get_upath_class_without_implementation_no_fallback(clear_registry): + assert get_upath_class("mock", fallback=False) is None + + +def test_get_upath_class_unknown_protocol(clear_registry): + assert get_upath_class("doesnotexist") is None + + +def test_get_upath_class_from_entrypoint(fake_entrypoint): + assert issubclass(get_upath_class("myeps"), UPath) + + +@pytest.mark.parametrize( + "protocol", [pytest.param("", id="empty-str"), pytest.param(None, id="none")] +) +def test_get_upath_class_falsey_protocol(protocol): + assert issubclass(get_upath_class(protocol), UPath)