Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix connect_async not yielded. #118

Merged
merged 8 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion magicclass/_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from qtpy.QtWidgets import QApplication

APPLICATION: QApplication = None
APPLICATION: "QApplication | None" = None


def get_shell():
Expand Down
20 changes: 17 additions & 3 deletions magicclass/_gui/_macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from macrokit import Symbol, Expr, Head, BaseMacro, parse, symbol
from macrokit.utils import check_call_args, check_attributes
from magicgui.widgets import FileEdit, LineEdit, EmptyWidget, PushButton
from magicclass.utils.qthreading import thread_worker
from magicclass.utils.qthreading import thread_worker, run_async

from magicclass.widgets import CodeEdit, TabbedContainer, ScrollableContainer, Dialog
from magicclass.utils import move_to_screen_center
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(self, is_main: bool = True, **kwargs):
self._signature_check = False
self._name_check = False
self._syntax_highlight = False
self._run_async = False

def _add_code_edit(self, name: str = "script", native: bool = False) -> CodeEdit:
"""Add a new code edit widget as a new tab."""
Expand Down Expand Up @@ -295,7 +296,10 @@ def _execute(self, code: Expr):
raise ValueError("No code selected")
if (viewer := parent.parent_viewer) is not None:
ns.setdefault(Symbol.var("viewer"), viewer)
code.eval(ns)
if self._run_async:
run_async(code, parent, ns=ns)
else:
code.eval(ns)

def _execute_selected(self, e=None):
"""Run selected line of macro."""
Expand Down Expand Up @@ -450,7 +454,8 @@ def __repr__(self) -> str:
f"syntax_highlight={self.syntax_highlight}, "
f"attribute_check={self.attribute_check}, "
f"signature_check={self.signature_check}, "
f"name_check={self.name_check})"
f"name_check={self.name_check}, "
f"run_async={self.run_async})"
)

@property
Expand Down Expand Up @@ -520,6 +525,15 @@ def name_check(self) -> bool:
def name_check(self, value: bool):
self.macro.widget._name_check = bool(value)

@property
def run_async(self) -> bool:
"""Whether to execute macro asynchronously."""
return self.macro.widget._run_async

@run_async.setter
def run_async(self, value: bool):
self.macro.widget._run_async = bool(value)


class GuiMacro(BaseMacro):
"""Macro object with GUI-specific functions."""
Expand Down
4 changes: 3 additions & 1 deletion magicclass/fields/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def _func(self: MagicTemplate, *args, **kwargs):
if isinstance(out, GeneratorType):
while True:
try:
next(out)
next_value = next(out)
except StopIteration as exc:
out = exc.value
break
Expand All @@ -556,6 +556,8 @@ def _func(self: MagicTemplate, *args, **kwargs):
or _this_id < _last_run_id
):
return thread_worker.callback()
else:
yield next_value
except Exception as exc:
if _running is not None:
_running.quit()
Expand Down
2 changes: 1 addition & 1 deletion magicclass/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
from ._click import click
from ._recent import call_recent_menu
from .qtsignal import QtSignal
from .qthreading import thread_worker, Timer, Callback, to_callback
from .qthreading import thread_worker, Timer, Callback
6 changes: 4 additions & 2 deletions magicclass/utils/qthreading/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .thread_worker import thread_worker, to_callback
from .thread_worker import thread_worker
from ._callback import Callback, CallbackList
from ._progressbar import Timer, ProgressDict, DefaultProgressBar
from ._to_async import to_async_code, run_async

__all__ = [
"thread_worker",
"Callback",
"to_callback",
"Timer",
"CallbackList",
"ProgressDict",
"DefaultProgressBar",
"to_async_code",
"run_async",
]
79 changes: 66 additions & 13 deletions magicclass/utils/qthreading/_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import time
from functools import partial, wraps
from typing import (
Any,
Expand Down Expand Up @@ -77,7 +78,9 @@ def _iter_as_nested_cb(
self, gui: BaseGui, *args, filtered: bool = False
) -> Iterable[NestedCallback]:
for c in self._iter_as_method(gui, filtered=filtered):
yield NestedCallback(c, *args)
ncb = NestedCallback(c).with_args(*args)
yield ncb
ncb.await_call()


def _make_method(func, obj: BaseGui):
Expand All @@ -102,22 +105,57 @@ def f(yielded):
return f


class Callback(Generic[_P, _R1]):
"""Callback object that can be recognized by thread_worker."""

class _AwaitableCallback(Generic[_P, _R1]):
def __init__(self, f: Callable[_P, _R1]):
if not callable(f):
raise TypeError(f"{f} is not callable.")
self._func = f
wraps(f)(self)
self._called = False

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R1:
return self._func(*args, **kwargs)
self._called = False
out = self._func(*args, **kwargs)
self._called = True
return out

def with_args(self, *args: _P.args, **kwargs: _P.kwargs) -> Callback[[], _R1]:
def with_args(
self, *args: _P.args, **kwargs: _P.kwargs
) -> _AwaitableCallback[[], _R1]:
"""Return a partial callback."""
return self.__class__(partial(self._func, *args, **kwargs))

def copy(self) -> _AwaitableCallback[_P, _R1]:
"""Return a copy of the callback."""
return self.__class__(self._func)

def await_call(self, timeout: float = -1) -> None:
"""
Await the callback to be called.

Usage
-----
>>> cb = thread_worker.callback(func)
>>> yield cb
>>> cb.await_call() # stop here until callback is called
"""
if timeout <= 0:
while not self._called:
time.sleep(0.01)
return None
t0 = time.time()
while not self._called:
time.sleep(0.01)
if time.time() - t0 > timeout:
raise TimeoutError(
f"Callback {self} did not finish within {timeout} seconds."
)
return None


class Callback(_AwaitableCallback[_P, _R1]):
"""Callback object that can be recognized by thread_worker."""

@overload
def __get__(self, obj: Any, type=None) -> Callback[..., _R1]:
...
Expand All @@ -131,12 +169,27 @@ def __get__(self, obj, type=None):
return self
return self.__class__(partial(self._func, obj))

def with_args(self, *args: _P.args, **kwargs: _P.kwargs) -> Callback[[], _R1]:
"""Return a partial callback."""
return self.__class__(partial(self._func, *args, **kwargs))

def arun(self, *args: _P.args, **kwargs: _P.kwargs) -> CallbackTask[_R1]:
"""Run the callback in a thread."""
return CallbackTask(self.with_args(*args, **kwargs))


class NestedCallback(_AwaitableCallback[_P, _R1]):
def with_args(self, *args: _P.args, **kwargs: _P.kwargs) -> NestedCallback[_P, _R1]:
"""Return a partial callback."""
return self.__class__(partial(self._func, *args, **kwargs))


class CallbackTask(Generic[_R1]):
"""A class to make the syntax of thread_worker and Callback similar."""

class NestedCallback:
def __init__(self, cb: Callable[..., Any], *args):
self._cb = cb
self._args = args
def __init__(self, callback: Callback[[], _R1]):
self._callback = callback

def call(self):
"""Call the callback function."""
return self._cb(*self._args)
def __iter__(self):
yield self._callback
self._callback.await_call()
116 changes: 116 additions & 0 deletions magicclass/utils/qthreading/_to_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

from typing import Any, TYPE_CHECKING
from macrokit import Expr, Head, Symbol, parse

if TYPE_CHECKING:
from magicclass._gui import BaseGui


def _is_thread_worker_call(line: Expr, ns: dict[str, Any]) -> bool:
if line.head is not Head.call:
return False
_f = line.args[0]
if not (isinstance(_f, Expr) and _f.head is Head.getattr):
return False
func_obj = _f.eval(ns)
return hasattr(func_obj, "__thread_worker__")


def _rewrite_thread_worker_call(line: Expr) -> Expr:
assert line.head is Head.call
a0 = line.args[0]
assert a0.head is Head.getattr
_with_arun = Expr(Head.getattr, [a0, "arun"])
expr = Expr(Head.call, [_with_arun] + line.args[1:])
return parse(f"yield from {expr}")


def _rewrite_callback(lines: list[Expr]) -> Expr:
func_body = Expr(Head.block, lines)
cb_expr = Symbol("__magicclass_temp_callback")
funcdef = Expr(Head.function, [Expr(Head.call, [cb_expr]), func_body])
funcdef_dec = Expr(Head.decorator, [DEC_CB, funcdef])
cb_yield = Expr(Head.yield_, [cb_expr])
cb_await = Expr(Head.call, [Expr(Head.getattr, [cb_expr, "await_call"])])
return Expr(Head.block, [funcdef_dec, cb_yield, cb_await])


CAN_PASS = frozenset(
[Head.yield_, Head.comment, Head.class_, Head.function, Head.decorator]
)
FORBIDDEN = frozenset([Head.import_, Head.from_])
DECORATOR = parse("thread_worker(force_async=True)")
DEC_CB = parse("thread_worker.callback")


def _to_async_code_list(code_lines: list[Symbol | Expr], ui: BaseGui) -> Expr:
lines: list[Expr] = []
stack: list[Expr] = []

def _flush_stack():
if stack:
lines.append(_rewrite_callback(stack))
stack.clear()
return None

ns = {str(ui._my_symbol): ui}
for line in code_lines:
if isinstance(line, Symbol):
lines.append(line)
if _is_thread_worker_call(line, ns):
_flush_stack()
lines.append(_rewrite_thread_worker_call(line))
elif line.head in CAN_PASS:
lines.append(line)
elif line.head in FORBIDDEN:
raise ValueError(f"Cannot use {line.head} in async code: {line}")
elif line.head is Head.if_:
lines.append(_wrap_blocks(line, [1, 2], ui))
elif line.head in (Head.for_, Head.while_, Head.with_):
lines.append(_wrap_blocks(line, [1], ui))
elif line.head is Head.try_:
lines.append(_wrap_blocks(line, [0, 2, 3, 4], ui))
else:
stack.append(line)
if stack:
lines.append(_rewrite_callback(stack))
return lines


def _wrap_blocks(line: Expr, idx: list[int], ui: BaseGui) -> Expr:
out = []
for i, arg in enumerate(line.args):
if i in idx:
blk = _to_async_code_list(arg.args, ui)
out.append(Expr(Head.block, blk))
else:
out.append(arg)
return Expr(line.head, out)


def to_async_code(code: Expr, ui: BaseGui) -> Expr:
"""Convert the code to async-compatible code."""
assert code.head is Head.block
lines = _to_async_code_list(code.args, ui)
_fn = Symbol("_")
func_body = Expr(Head.block, lines)
funcdef = Expr(Head.function, [Expr(Head.call, [_fn, ui._my_symbol]), func_body])
funcdef_dec = Expr(Head.decorator, [DECORATOR, funcdef])
descriptor = Expr(Head.call, [Expr(Head.getattr, [_fn, "__get__"]), ui._my_symbol])
funccall = Expr(Head.call, [descriptor]) # -> _.__get__(ui)()
return Expr(Head.block, [funcdef_dec, funccall])


def run_async(code: Expr, ui: BaseGui, ns: dict[str, Any] = {}) -> Any | None:
"""Run the code in a thread worker."""
from .thread_worker import thread_worker

_ns = dict(
**{
str(ui._my_symbol): ui,
"thread_worker": thread_worker,
},
)
_ns.update(ns)
return to_async_code(code, ui).eval(_ns)
Loading