diff --git a/magicclass/_gui/_base.py b/magicclass/_gui/_base.py index 015e72ec..5647bf1c 100644 --- a/magicclass/_gui/_base.py +++ b/magicclass/_gui/_base.py @@ -4,6 +4,7 @@ from typing import ( Any, ContextManager, + Literal, Union, Callable, TYPE_CHECKING, @@ -162,9 +163,39 @@ class construction. ) +def count_callback_levels(cls: type): + if cls.__mro__[1] is not MagicTemplate: + return + for name, attr in cls.__dict__.items(): + if isinstance(attr, MagicField): + for cb in attr.callbacks: + cb_ns = cb.__qualname__.rsplit(".", maxsplit=1)[0] + if not hasattr(cb, "__qualname__"): + continue + if cls.__qualname__.startswith(cb_ns): + level = cls.__qualname__[len(cb_ns) :].count(".") + cb.__magicclass_callback_level__ = level + + +def init_sub_magicclass(cls: type): + check_override(cls) + if cls.__mro__[1] is not MagicTemplate: + return + for attr in cls.__dict__.values(): + if isinstance(attr, MagicField): + for cb in attr.callbacks: + cb_ns = cb.__qualname__.rsplit(".", maxsplit=1)[0] + if not hasattr(cb, "__qualname__"): + continue + if cls.__qualname__.startswith(cb_ns): + level = cls.__qualname__[len(cb_ns) :].count(".") + cb.__magicclass_callback_level__ = level + + _ANCESTORS: dict[tuple[int, int], MagicTemplate] = {} _T = TypeVar("_T", bound="MagicTemplate") +_T1 = TypeVar("_T1") _F = TypeVar("_F", bound=Callable) @@ -176,7 +207,7 @@ def __get__(self: type[_T], obj: Any, objtype=None) -> _T: ... @overload - def __get__(self: type[_T], obj: None, objtype=None) -> type[_T]: + def __get__(self: _T1, obj: Literal[None], objtype=None) -> _T1: ... def __get__(self, obj, objtype=None): @@ -224,7 +255,8 @@ class MagicTemplate(MutableSequence[_Comp], metaclass=_MagicTemplateMeta): widget_type: str width: int - __init_subclass__ = check_override + def __init_subclass__(cls, **kwargs): + init_sub_magicclass(cls) @overload def __getitem__(self, key: int | str) -> _Comp: @@ -886,6 +918,9 @@ def __init__( self._my_symbol = Symbol.var("ui") self._icon = None + def __init_subclass__(cls, **kwargs): + pass + @property def icon(self): """Icon of this GUI.""" @@ -1307,7 +1342,7 @@ def convert_attributes( for name, obj in subcls.__dict__.items(): _isfunc = callable(obj) if isinstance(obj, _MagicTemplateMeta): - new_attr = copy_class(obj, cls, name=name) + new_attr = copy_class(obj, cls.__qualname__, name=name) elif name.startswith("_") or isinstance(obj, _pass) or not _isfunc: # private method, non-action-like object, not-callable object are passed. new_attr = obj @@ -1323,6 +1358,8 @@ def convert_attributes( return _dict +# def _find_callback_level() + _dummy_macro = DummyMacro() diff --git a/magicclass/_gui/utils.py b/magicclass/_gui/utils.py index 22795bbd..f4ae0afd 100644 --- a/magicclass/_gui/utils.py +++ b/magicclass/_gui/utils.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING, Callable, TypeVar from types import FunctionType - from magicgui.widgets import FunctionGui, Widget from magicgui.types import Undefined from magicgui.type_map import get_widget_class @@ -22,7 +21,7 @@ def get_parameters(fgui: FunctionGui): _C = TypeVar("_C", bound=type) -def copy_class(cls: _C, ns: type, name: str | None = None) -> _C: +def copy_class(cls: _C, ns: str, name: str) -> _C: """ Copy a class in a new namespace. @@ -34,32 +33,30 @@ def copy_class(cls: _C, ns: type, name: str | None = None) -> _C: cls : type Class to be copied. ns : type - New namespace of ``cls``. - name : str, optional - New name of ``cls``. If not given, the original name will be used. + New namespace (the qualname of parent class) of ``cls``. + name : str + New name of ``cls``. Returns ------- type Copied class object. """ - out = type(cls.__name__, cls.__bases__, dict(cls.__dict__)) - if name is None: - name = out.__name__ - _update_qualnames(out, f"{ns.__qualname__}.{name}") - return out - - -def _update_qualnames(cls: type, cls_qualname: str) -> None: - cls.__qualname__ = cls_qualname - # NOTE: updating cls.__name__ will make `wraps` incompatible. + namespace = {} + qualname = f"{ns}.{name}" for key, attr in cls.__dict__.items(): if isinstance(attr, FunctionType): - attr.__qualname__ = f"{cls_qualname}.{key}" + if attr.__qualname__.split(".")[-1].count(".") == 0: + pass + else: + attr.__qualname__ = f"{qualname}.{key}" elif isinstance(attr, type): - _update_qualnames(attr, f"{cls_qualname}.{key}") + attr = copy_class(attr, qualname, attr.__name__) + namespace[key] = attr - return None + out = type(cls.__name__, cls.__bases__, namespace) + out.__qualname__ = qualname + return out class MagicClassConstructionError(Exception): diff --git a/magicclass/core.py b/magicclass/core.py index 208cfdcf..eab90b69 100644 --- a/magicclass/core.py +++ b/magicclass/core.py @@ -31,7 +31,7 @@ ErrorMode, defaults, MagicTemplate, - check_override, + init_sub_magicclass, convert_attributes, ) from magicclass._gui import ContextMenuGui, MenuGui, ToolBarGui @@ -156,7 +156,7 @@ def wrapper(cls) -> type[ClassGui]: class_gui = _TYPE_MAP[widget_type] if not issubclass(cls, MagicTemplate): - check_override(cls) + init_sub_magicclass(cls) # get class attributes first doc = cls.__doc__ @@ -282,7 +282,7 @@ def wrapper(cls) -> type[ContextMenuGui]: raise TypeError(f"magicclass can only wrap classes, not {type(cls)}") if not issubclass(cls, MagicTemplate): - check_override(cls) + init_sub_magicclass(cls) # get class attributes first doc = cls.__doc__ @@ -407,7 +407,7 @@ def wrapper(cls) -> type[menugui_class]: raise TypeError(f"magicclass can only wrap classes, not {type(cls)}") if not issubclass(cls, MagicTemplate): - check_override(cls) + init_sub_magicclass(cls) # get class attributes first doc = cls.__doc__ diff --git a/magicclass/fields/_define.py b/magicclass/fields/_define.py index a03994ce..b0750cc6 100644 --- a/magicclass/fields/_define.py +++ b/magicclass/fields/_define.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING, Callable +import inspect from magicclass.utils import argcount @@ -15,9 +16,27 @@ def define_callback(self: Any, callback: Callable): def define_callback_gui(self: MagicTemplate, callback: Callable): """Define a callback function from a method of a magic-class.""" + if callback.__qualname__.split(".")[-1].count(".") == 0: + # not defined in a class + params = list(inspect.signature(callback).parameters.values()) + if len(params) > 0 and params[0].name == "self": + callback: Callable = callback.__get__(self) + _func = _normalize_argcount(callback) + + def _callback(v): + with self.macro.blocked(): + _func(v) + return None + + return _callback + *_, clsname, funcname = callback.__qualname__.split(".") mro = self.__class__.__mro__ for base in mro: + if base.__module__ in ("collections.abc", "abc", "typing", "builtins"): + continue + if base.__module__.startswith(("magicclass.widgets", "magicgui.widgets")): + continue if base.__name__ == clsname: _func: Callable = getattr(base, funcname).__get__(self) _func = _normalize_argcount(_func) @@ -27,7 +46,23 @@ def _callback(v): _func(v) return None - break + return _callback + + if hasattr(callback, "__magicclass_callback_level__"): + level = callback.__magicclass_callback_level__ + assert isinstance(level, int) and level >= 0 + + def _callback(v): + # search for parent instances that have the same name. + current_self = self + for _ in range(level): + current_self = current_self.__magicclass_parent__ + _func = _normalize_argcount(getattr(current_self, funcname)) + + with self.macro.blocked(): + _func(v) + return None + else: def _callback(v): diff --git a/magicclass/fields/_fields.py b/magicclass/fields/_fields.py index 5dd8e9ea..17706e0e 100644 --- a/magicclass/fields/_fields.py +++ b/magicclass/fields/_fields.py @@ -275,20 +275,14 @@ def get_widget(self, obj: Any) -> _W: This function will be called every time MagicField is referred by ``obj.field``. """ - from magicclass._gui import MagicTemplate - obj_id = id(obj) if (widget := self._guis.get(obj_id, None)) is None: self._guis[obj_id] = widget = self.construct(obj) widget.name = self.name or "" if isinstance(widget, (ValueWidget, ContainerWidget)): - if isinstance(obj, MagicTemplate): - _def = define_callback_gui - else: - _def = define_callback + _def = self._get_define_callback(obj) for callback in self._callbacks: - # funcname = callback.__name__ widget.changed.connect(_def(obj, callback)) return widget @@ -298,8 +292,6 @@ def get_action(self, obj: Any) -> AbstractAction: Get an action from ``obj``. This function will be called every time MagicField is referred by ``obj.field``. """ - from magicclass._gui import MagicTemplate - obj_id = id(obj) if obj_id in self._guis.keys(): action = self._guis[obj_id] @@ -324,16 +316,25 @@ def get_action(self, obj: Any) -> AbstractAction: self._guis[obj_id] = action if action.support_value: - if isinstance(obj, MagicTemplate): - _def = define_callback_gui - else: - _def = define_callback + _def = self._get_define_callback(obj) for callback in self._callbacks: # funcname = callback.__name__ action.changed.connect(_def(obj, callback)) return action + def _get_define_callback( + self, + obj: Any, + ) -> Callable[[MagicTemplate, Any], Callable[[Any], None]]: + from magicclass._gui import MagicTemplate + + if isinstance(obj, MagicTemplate): + _def = define_callback_gui + else: + _def = define_callback + return _def + def as_getter(self, obj: Any) -> Callable[[Any], Any]: """Make a function that get the value of Widget or Action.""" return lambda w: self._guis[id(obj)].value diff --git a/magicclass/utils/qthreading/_callback.py b/magicclass/utils/qthreading/_callback.py index 6fba0208..d024b337 100644 --- a/magicclass/utils/qthreading/_callback.py +++ b/magicclass/utils/qthreading/_callback.py @@ -6,14 +6,15 @@ Callable, TYPE_CHECKING, Iterable, + Literal, TypeVar, Generic, + overload, ) from typing_extensions import ParamSpec if TYPE_CHECKING: from magicclass._gui import BaseGui - from .thread_worker import thread_worker _P = ParamSpec("_P") _R1 = TypeVar("_R1") @@ -112,10 +113,18 @@ def with_args(self, *args: _P.args, **kwargs: _P.kwargs) -> Callback[[], _R1]: """Return a partial callback.""" return self.__class__(partial(self._func, *args, **kwargs)) - def __get__(self, obj, type=None) -> Callback[_P, _R1]: + @overload + def __get__(self, obj, type=None) -> Callback[..., _R1]: + ... + + @overload + def __get__(self, obj: Literal[None], type=None) -> Callback[_P, _R1]: + ... + + def __get__(self, obj, type=None): if obj is None: return self - return self.with_args(obj) + return self.__class__(partial(self._func, obj)) class NestedCallback: diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 2eb67aec..758dc768 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -175,3 +175,24 @@ def _callback(self, v): ui = A() ui.a.value = 1 assert ui._v == 1 + +def test_callback_outside_class(): + def _cb0(): + _cb0.x = 1 + def _cb1(val): + _cb1.x = val + def _cb2(self, val): + _cb2.x = val + + @magicclass + class A: + a = field(int) + a.connect(_cb0) + a.connect(_cb1) + a.connect(_cb2) + + ui = A() + ui.a.value = 4 + assert _cb0.x == 1 + assert _cb1.x == 4 + assert _cb2.x == 4 diff --git a/tests/test_copy_class.py b/tests/test_copy_class.py index 5bfd7f4f..331d058f 100644 --- a/tests/test_copy_class.py +++ b/tests/test_copy_class.py @@ -1,5 +1,7 @@ -from magicclass import magicclass, field -from magicclass.types import Bound, OneOf +from typing_extensions import Annotated +from unittest.mock import MagicMock +from magicclass import magicclass, field, MagicTemplate, abstractapi +from magicclass.types import OneOf @magicclass class B: @@ -12,11 +14,11 @@ def test_getter_of_same_name(): class A: B = B out = None - def f(self, x: Bound[B.f]): + def f(self, x: Annotated[int, {"bind": B.f}]): x.as_integer_ratio() self.out = x - def g(self, x: Bound[B._get_value]): + def g(self, x: Annotated[int, {"bind": B._get_value}]): x.capitalize() self.out = x @@ -32,11 +34,11 @@ def test_getter_of_different_name(): class A: b = B out = None - def f(self, x: Bound[b.f]): + def f(self, x: Annotated[int, {"bind": b.f}]): x.as_integer_ratio() self.out = x - def g(self, x: Bound[b._get_value]): + def g(self, x: Annotated[str, {"bind": b._get_value}]): x.capitalize() self.out = x @@ -51,11 +53,11 @@ def test_getter_of_private_name(): class A: _b = B out = None - def f(self, x: Bound[_b.f]): + def f(self, x: Annotated[int, {"bind": _b.f}]): x.as_integer_ratio() self.out = x - def g(self, x: Bound[_b._get_value]): + def g(self, x: Annotated[str, {"bind": _b._get_value}]): x.capitalize() self.out = x @@ -81,7 +83,7 @@ class A: _c = C out = None @_c.wraps - def run(self, x: Bound[_c.f]): + def run(self, x: Annotated[int, {"bind": _c.f}]): x.as_integer_ratio() self.out = x @@ -107,7 +109,7 @@ class A: _c = C out = None - def f(self, x: Bound[_c.f]): + def f(self, x: Annotated[int, {"bind": _c.f}]): x.as_integer_ratio() self.out = x @@ -119,3 +121,43 @@ def f(self, x: Bound[_c.f]): # check value-changed event ui._c.f.value = 2 assert ui._c._value == 2 + +def test_reuse_class(): + mock = MagicMock() + + @magicclass + class Parent(MagicTemplate): + top = field(int) + + @top.connect + def _top_changed(self, v) -> None: + mock(self, v) + + bottom = abstractapi() + + @magicclass + class A(MagicTemplate): + p = Parent + @p.wraps + def bottom(self): + self.p.top.value = 100 + + @magicclass + class B(MagicTemplate): + p = Parent + @p.wraps + def bottom(self): + self.p.top.value = 200 + + a = A() + b = B() + + a.p.top.value = 1 + mock.assert_called_once_with(a.p, 1) + assert b.p.top.value == 0 + a["bottom"].changed() + assert a.p.top.value == 100 + assert b.p.top.value == 0 + b["bottom"].changed() + assert a.p.top.value == 100 + assert b.p.top.value == 200 diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py index bd78434a..e914f519 100644 --- a/tests/test_inheritance.py +++ b/tests/test_inheritance.py @@ -108,3 +108,46 @@ class B(Base): b.X["func"].changed() assert a.result.value == "A" assert b.result.value == "B" + +def test_same_callback(): + class Parent(MagicTemplate): + def __init__(self): + self._value = None + foo = field(False) + + @foo.connect + def _foo_changed(self, v) -> None: + self._value = v + + @magicclass + class Container(MagicTemplate): + def __init__(self) -> None: + self._value = None + + @magicclass + class ChildA(Parent): + bar = field(int) + + @magicclass + class ChildB(Parent): + bar = field(int) + + @ChildA.bar.connect + @ChildB.bar.connect + def _bar_changed(self, v): + self._value = v + + c = Container() + assert c._value is None + c.ChildA.foo.value = True + assert c.ChildA._value + assert c.ChildB._value is None + c.ChildB.foo.value = True + c.ChildB.foo.value = False + assert c.ChildA._value + assert not c.ChildB._value + + c.ChildA.bar.value = 1 + assert c._value == 1 + c.ChildB.bar.value = 10 + assert c._value == 10