Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix field callback search #116

Merged
merged 4 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions magicclass/_gui/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
Any,
ContextManager,
Literal,
Union,
Callable,
TYPE_CHECKING,
Expand Down Expand Up @@ -162,9 +163,39 @@ class construction.
)


def count_callback_levels(cls: type):
if cls.__mro__[1] is not MagicTemplate:
return
for name, attr in cls.__dict__.items():
if isinstance(attr, MagicField):
for cb in attr.callbacks:
cb_ns = cb.__qualname__.rsplit(".", maxsplit=1)[0]
if not hasattr(cb, "__qualname__"):
continue
if cls.__qualname__.startswith(cb_ns):
level = cls.__qualname__[len(cb_ns) :].count(".")
cb.__magicclass_callback_level__ = level


def init_sub_magicclass(cls: type):
check_override(cls)
if cls.__mro__[1] is not MagicTemplate:
return
for attr in cls.__dict__.values():
if isinstance(attr, MagicField):
for cb in attr.callbacks:
cb_ns = cb.__qualname__.rsplit(".", maxsplit=1)[0]
if not hasattr(cb, "__qualname__"):
continue
if cls.__qualname__.startswith(cb_ns):
level = cls.__qualname__[len(cb_ns) :].count(".")
cb.__magicclass_callback_level__ = level


_ANCESTORS: dict[tuple[int, int], MagicTemplate] = {}

_T = TypeVar("_T", bound="MagicTemplate")
_T1 = TypeVar("_T1")
_F = TypeVar("_F", bound=Callable)


Expand All @@ -176,7 +207,7 @@ def __get__(self: type[_T], obj: Any, objtype=None) -> _T:
...

@overload
def __get__(self: type[_T], obj: None, objtype=None) -> type[_T]:
def __get__(self: _T1, obj: Literal[None], objtype=None) -> _T1:
...

def __get__(self, obj, objtype=None):
Expand Down Expand Up @@ -224,7 +255,8 @@ class MagicTemplate(MutableSequence[_Comp], metaclass=_MagicTemplateMeta):
widget_type: str
width: int

__init_subclass__ = check_override
def __init_subclass__(cls, **kwargs):
init_sub_magicclass(cls)

@overload
def __getitem__(self, key: int | str) -> _Comp:
Expand Down Expand Up @@ -886,6 +918,9 @@ def __init__(
self._my_symbol = Symbol.var("ui")
self._icon = None

def __init_subclass__(cls, **kwargs):
pass

@property
def icon(self):
"""Icon of this GUI."""
Expand Down Expand Up @@ -1307,7 +1342,7 @@ def convert_attributes(
for name, obj in subcls.__dict__.items():
_isfunc = callable(obj)
if isinstance(obj, _MagicTemplateMeta):
new_attr = copy_class(obj, cls, name=name)
new_attr = copy_class(obj, cls.__qualname__, name=name)
elif name.startswith("_") or isinstance(obj, _pass) or not _isfunc:
# private method, non-action-like object, not-callable object are passed.
new_attr = obj
Expand All @@ -1323,6 +1358,8 @@ def convert_attributes(
return _dict


# def _find_callback_level()

_dummy_macro = DummyMacro()


Expand Down
33 changes: 15 additions & 18 deletions magicclass/_gui/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Callable, TypeVar
from types import FunctionType

from magicgui.widgets import FunctionGui, Widget
from magicgui.types import Undefined
from magicgui.type_map import get_widget_class
Expand All @@ -22,7 +21,7 @@ def get_parameters(fgui: FunctionGui):
_C = TypeVar("_C", bound=type)


def copy_class(cls: _C, ns: type, name: str | None = None) -> _C:
def copy_class(cls: _C, ns: str, name: str) -> _C:
"""
Copy a class in a new namespace.

Expand All @@ -34,32 +33,30 @@ def copy_class(cls: _C, ns: type, name: str | None = None) -> _C:
cls : type
Class to be copied.
ns : type
New namespace of ``cls``.
name : str, optional
New name of ``cls``. If not given, the original name will be used.
New namespace (the qualname of parent class) of ``cls``.
name : str
New name of ``cls``.

Returns
-------
type
Copied class object.
"""
out = type(cls.__name__, cls.__bases__, dict(cls.__dict__))
if name is None:
name = out.__name__
_update_qualnames(out, f"{ns.__qualname__}.{name}")
return out


def _update_qualnames(cls: type, cls_qualname: str) -> None:
cls.__qualname__ = cls_qualname
# NOTE: updating cls.__name__ will make `wraps` incompatible.
namespace = {}
qualname = f"{ns}.{name}"
for key, attr in cls.__dict__.items():
if isinstance(attr, FunctionType):
attr.__qualname__ = f"{cls_qualname}.{key}"
if attr.__qualname__.split("<locals>.")[-1].count(".") == 0:
pass
else:
attr.__qualname__ = f"{qualname}.{key}"
elif isinstance(attr, type):
_update_qualnames(attr, f"{cls_qualname}.{key}")
attr = copy_class(attr, qualname, attr.__name__)
namespace[key] = attr

return None
out = type(cls.__name__, cls.__bases__, namespace)
out.__qualname__ = qualname
return out


class MagicClassConstructionError(Exception):
Expand Down
8 changes: 4 additions & 4 deletions magicclass/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ErrorMode,
defaults,
MagicTemplate,
check_override,
init_sub_magicclass,
convert_attributes,
)
from magicclass._gui import ContextMenuGui, MenuGui, ToolBarGui
Expand Down Expand Up @@ -156,7 +156,7 @@ def wrapper(cls) -> type[ClassGui]:
class_gui = _TYPE_MAP[widget_type]

if not issubclass(cls, MagicTemplate):
check_override(cls)
init_sub_magicclass(cls)

# get class attributes first
doc = cls.__doc__
Expand Down Expand Up @@ -282,7 +282,7 @@ def wrapper(cls) -> type[ContextMenuGui]:
raise TypeError(f"magicclass can only wrap classes, not {type(cls)}")

if not issubclass(cls, MagicTemplate):
check_override(cls)
init_sub_magicclass(cls)

# get class attributes first
doc = cls.__doc__
Expand Down Expand Up @@ -407,7 +407,7 @@ def wrapper(cls) -> type[menugui_class]:
raise TypeError(f"magicclass can only wrap classes, not {type(cls)}")

if not issubclass(cls, MagicTemplate):
check_override(cls)
init_sub_magicclass(cls)

# get class attributes first
doc = cls.__doc__
Expand Down
37 changes: 36 additions & 1 deletion magicclass/fields/_define.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Callable
import inspect

from magicclass.utils import argcount

Expand All @@ -15,9 +16,27 @@ def define_callback(self: Any, callback: Callable):
def define_callback_gui(self: MagicTemplate, callback: Callable):
"""Define a callback function from a method of a magic-class."""

if callback.__qualname__.split("<locals>.")[-1].count(".") == 0:
# not defined in a class
params = list(inspect.signature(callback).parameters.values())
if len(params) > 0 and params[0].name == "self":
callback: Callable = callback.__get__(self)
_func = _normalize_argcount(callback)

def _callback(v):
with self.macro.blocked():
_func(v)
return None

return _callback

*_, clsname, funcname = callback.__qualname__.split(".")
mro = self.__class__.__mro__
for base in mro:
if base.__module__ in ("collections.abc", "abc", "typing", "builtins"):
continue
if base.__module__.startswith(("magicclass.widgets", "magicgui.widgets")):
continue
if base.__name__ == clsname:
_func: Callable = getattr(base, funcname).__get__(self)
_func = _normalize_argcount(_func)
Expand All @@ -27,7 +46,23 @@ def _callback(v):
_func(v)
return None

break
return _callback

if hasattr(callback, "__magicclass_callback_level__"):
level = callback.__magicclass_callback_level__
assert isinstance(level, int) and level >= 0

def _callback(v):
# search for parent instances that have the same name.
current_self = self
for _ in range(level):
current_self = current_self.__magicclass_parent__
_func = _normalize_argcount(getattr(current_self, funcname))

with self.macro.blocked():
_func(v)
return None

else:

def _callback(v):
Expand Down
27 changes: 14 additions & 13 deletions magicclass/fields/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,20 +275,14 @@ def get_widget(self, obj: Any) -> _W:
This function will be called every time MagicField is referred by
``obj.field``.
"""
from magicclass._gui import MagicTemplate

obj_id = id(obj)
if (widget := self._guis.get(obj_id, None)) is None:
self._guis[obj_id] = widget = self.construct(obj)
widget.name = self.name or ""

if isinstance(widget, (ValueWidget, ContainerWidget)):
if isinstance(obj, MagicTemplate):
_def = define_callback_gui
else:
_def = define_callback
_def = self._get_define_callback(obj)
for callback in self._callbacks:
# funcname = callback.__name__
widget.changed.connect(_def(obj, callback))

return widget
Expand All @@ -298,8 +292,6 @@ def get_action(self, obj: Any) -> AbstractAction:
Get an action from ``obj``. This function will be called every time MagicField is referred
by ``obj.field``.
"""
from magicclass._gui import MagicTemplate

obj_id = id(obj)
if obj_id in self._guis.keys():
action = self._guis[obj_id]
Expand All @@ -324,16 +316,25 @@ def get_action(self, obj: Any) -> AbstractAction:
self._guis[obj_id] = action

if action.support_value:
if isinstance(obj, MagicTemplate):
_def = define_callback_gui
else:
_def = define_callback
_def = self._get_define_callback(obj)
for callback in self._callbacks:
# funcname = callback.__name__
action.changed.connect(_def(obj, callback))

return action

def _get_define_callback(
self,
obj: Any,
) -> Callable[[MagicTemplate, Any], Callable[[Any], None]]:
from magicclass._gui import MagicTemplate

if isinstance(obj, MagicTemplate):
_def = define_callback_gui
else:
_def = define_callback
return _def

def as_getter(self, obj: Any) -> Callable[[Any], Any]:
"""Make a function that get the value of Widget or Action."""
return lambda w: self._guis[id(obj)].value
Expand Down
15 changes: 12 additions & 3 deletions magicclass/utils/qthreading/_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
Callable,
TYPE_CHECKING,
Iterable,
Literal,
TypeVar,
Generic,
overload,
)
from typing_extensions import ParamSpec

if TYPE_CHECKING:
from magicclass._gui import BaseGui
from .thread_worker import thread_worker

_P = ParamSpec("_P")
_R1 = TypeVar("_R1")
Expand Down Expand Up @@ -112,10 +113,18 @@ def with_args(self, *args: _P.args, **kwargs: _P.kwargs) -> Callback[[], _R1]:
"""Return a partial callback."""
return self.__class__(partial(self._func, *args, **kwargs))

def __get__(self, obj, type=None) -> Callback[_P, _R1]:
@overload
def __get__(self, obj, type=None) -> Callback[..., _R1]:
...

@overload
def __get__(self, obj: Literal[None], type=None) -> Callback[_P, _R1]:
...

def __get__(self, obj, type=None):
if obj is None:
return self
return self.with_args(obj)
return self.__class__(partial(self._func, obj))


class NestedCallback:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,24 @@ def _callback(self, v):
ui = A()
ui.a.value = 1
assert ui._v == 1

def test_callback_outside_class():
def _cb0():
_cb0.x = 1
def _cb1(val):
_cb1.x = val
def _cb2(self, val):
_cb2.x = val

@magicclass
class A:
a = field(int)
a.connect(_cb0)
a.connect(_cb1)
a.connect(_cb2)

ui = A()
ui.a.value = 4
assert _cb0.x == 1
assert _cb1.x == 4
assert _cb2.x == 4
Loading