diff --git a/newsfragments/1244.bugfix.rst b/newsfragments/1244.bugfix.rst new file mode 100644 index 0000000000..6245199a2b --- /dev/null +++ b/newsfragments/1244.bugfix.rst @@ -0,0 +1 @@ +Added a helpful error message if an async function is passed to `trio.from_thread.run_sync` or a sync function to `trio.from_thread.run`. diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 5904e682fd..f77860d2f2 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -17,7 +17,6 @@ from sniffio import current_async_library_cvar import attr -from async_generator import isasyncgen from sortedcontainers import SortedDict from outcome import Error, Value, capture @@ -36,7 +35,7 @@ ) from .. import _core from .._deprecate import deprecated -from .._util import Final, NoPublicConstructor +from .._util import Final, NoPublicConstructor, coroutine_or_error _NO_SEND = object() @@ -1247,86 +1246,7 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. ###### - - def _return_value_looks_like_wrong_library(value): - # Returned by legacy @asyncio.coroutine functions, which includes - # a surprising proportion of asyncio builtins. - if isinstance(value, collections.abc.Generator): - return True - # The protocol for detecting an asyncio Future-like object - if getattr(value, "_asyncio_future_blocking", None) is not None: - return True - # This janky check catches tornado Futures and twisted Deferreds. - # By the time we're calling this function, we already know - # something has gone wrong, so a heuristic is pretty safe. - if value.__class__.__name__ in ("Future", "Deferred"): - return True - return False - - try: - coro = async_fn(*args) - except TypeError: - # Give good error for: nursery.start_soon(trio.sleep(1)) - if isinstance(async_fn, collections.abc.Coroutine): - raise TypeError( - "Trio was expecting an async function, but instead it got " - "a coroutine object {async_fn!r}\n" - "\n" - "Probably you did something like:\n" - "\n" - " trio.run({async_fn.__name__}(...)) # incorrect!\n" - " nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n" - "\n" - "Instead, you want (notice the parentheses!):\n" - "\n" - " trio.run({async_fn.__name__}, ...) # correct!\n" - " nursery.start_soon({async_fn.__name__}, ...) # correct!" - .format(async_fn=async_fn) - ) from None - - # Give good error for: nursery.start_soon(future) - if _return_value_looks_like_wrong_library(async_fn): - raise TypeError( - "Trio was expecting an async function, but instead it got " - "{!r} – are you trying to use a library written for " - "asyncio/twisted/tornado or similar? That won't work " - "without some sort of compatibility shim." - .format(async_fn) - ) from None - - raise - - # We can't check iscoroutinefunction(async_fn), because that will fail - # for things like functools.partial objects wrapping an async - # function. So we have to just call it and then check whether the - # return value is a coroutine object. - if not isinstance(coro, collections.abc.Coroutine): - # Give good error for: nursery.start_soon(func_returning_future) - if _return_value_looks_like_wrong_library(coro): - raise TypeError( - "start_soon got unexpected {!r} – are you trying to use a " - "library written for asyncio/twisted/tornado or similar? " - "That won't work without some sort of compatibility shim." - .format(coro) - ) - - if isasyncgen(coro): - raise TypeError( - "start_soon expected an async function but got an async " - "generator {!r}".format(coro) - ) - - # Give good error for: nursery.start_soon(some_sync_fn) - raise TypeError( - "Trio expected an async function, but {!r} appears to be " - "synchronous".format( - getattr(async_fn, "__qualname__", async_fn) - ) - ) - - ###### - # Set up the Task object - ###### + coro = coroutine_or_error(async_fn, *args) if name is None: name = async_fn @@ -1353,6 +1273,9 @@ async def python_wrapper(orig_coro): LOCALS_KEY_KI_PROTECTION_ENABLED, system_task ) + ###### + # Set up the Task object + ###### task = Task._create( coro=coro, parent_nursery=nursery, diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index e705af5c22..53446a601f 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -15,7 +15,11 @@ import sniffio import pytest -from .tutil import slow, check_sequence_matches, gc_collect_harder +from .tutil import ( + slow, check_sequence_matches, gc_collect_harder, + ignore_coroutine_never_awaited_warnings +) + from ... import _core from ..._threads import to_thread_run_sync from ..._timeouts import sleep, fail_after @@ -33,24 +37,6 @@ async def sleep_forever(): return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -# Some of our tests need to leak coroutines, and thus trigger the -# "RuntimeWarning: coroutine '...' was never awaited" message. This context -# manager should be used anywhere this happens to hide those messages, because -# when expected they're clutter. -@contextmanager -def ignore_coroutine_never_awaited_warnings(): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="coroutine '.*' was never awaited" - ) - try: - yield - finally: - # Make sure to trigger any coroutine __del__ methods now, before - # we leave the context manager. - gc_collect_harder() - - def test_basic(): async def trivial(x): return x @@ -1696,8 +1682,6 @@ async def test_current_effective_deadline(mock_clock): assert _core.current_effective_deadline() == inf -# @coroutine is deprecated since python 3.8, which is fine with us. -@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") def test_nice_error_on_bad_calls_to_run_or_spawn(): def bad_call_run(*args): _core.run(*args) @@ -1709,59 +1693,22 @@ async def main(): _core.run(main) - class Deferred: - "Just kidding" - - with ignore_coroutine_never_awaited_warnings(): - for bad_call in bad_call_run, bad_call_spawn: - - async def f(): # pragma: no cover - pass - - with pytest.raises(TypeError) as excinfo: - bad_call(f()) - assert "expecting an async function" in str(excinfo.value) - - import asyncio - - @asyncio.coroutine - def generator_based_coro(): # pragma: no cover - yield from asyncio.sleep(1) - - with pytest.raises(TypeError) as excinfo: - bad_call(generator_based_coro()) - assert "asyncio" in str(excinfo.value) + for bad_call in bad_call_run, bad_call_spawn: - with pytest.raises(TypeError) as excinfo: - bad_call(asyncio.Future()) - assert "asyncio" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(lambda: asyncio.Future()) - assert "asyncio" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(Deferred()) - assert "twisted" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(lambda: Deferred()) - assert "twisted" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(len, [1, 2, 3]) - assert "appears to be synchronous" in str(excinfo.value) + async def f(): # pragma: no cover + pass - async def async_gen(arg): # pragma: no cover - yield + with pytest.raises(TypeError, match="expecting an async function"): + bad_call(f()) - with pytest.raises(TypeError) as excinfo: - bad_call(async_gen, 0) - msg = "expected an async function but got an async generator" - assert msg in str(excinfo.value) + async def async_gen(arg): # pragma: no cover + yield arg - # Make sure no references are kept around to keep anything alive - del excinfo + with pytest.raises( + TypeError, + match="expected an async function but got an async generator" + ): + bad_call(async_gen, 0) def test_calling_asyncio_function_gives_nice_error(): diff --git a/trio/_core/tests/tutil.py b/trio/_core/tests/tutil.py index dac53b81fd..ac090cb8de 100644 --- a/trio/_core/tests/tutil.py +++ b/trio/_core/tests/tutil.py @@ -3,6 +3,8 @@ import os import pytest +import warnings +from contextlib import contextmanager import gc @@ -52,6 +54,24 @@ def gc_collect_harder(): gc.collect() +# Some of our tests need to leak coroutines, and thus trigger the +# "RuntimeWarning: coroutine '...' was never awaited" message. This context +# manager should be used anywhere this happens to hide those messages, because +# when expected they're clutter. +@contextmanager +def ignore_coroutine_never_awaited_warnings(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="coroutine '.*' was never awaited" + ) + try: + yield + finally: + # Make sure to trigger any coroutine __del__ methods now, before + # we leave the context manager. + gc_collect_harder() + + # template is like: # [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3] def check_sequence_matches(seq, template): diff --git a/trio/_threads.py b/trio/_threads.py index 811bc526a0..c03a353789 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -3,12 +3,14 @@ from itertools import count import attr +import inspect import outcome import trio from ._sync import CapacityLimiter from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken +from ._util import coroutine_or_error # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() @@ -365,6 +367,7 @@ def from_thread_run(afn, *args, trio_token=None): which would otherwise cause a deadlock. AttributeError: if no ``trio_token`` was provided, and we can't infer one from context. + TypeError: if ``afn`` is not an asynchronous function. **Locating a Trio Token**: There are two ways to specify which `trio.run` loop to reenter: @@ -380,7 +383,8 @@ def from_thread_run(afn, *args, trio_token=None): def callback(q, afn, args): @disable_ki_protection async def unprotected_afn(): - return await afn(*args) + coro = coroutine_or_error(afn, *args) + return await coro async def await_in_trio_thread_task(): q.put_nowait(await outcome.acapture(unprotected_afn)) @@ -403,13 +407,11 @@ def from_thread_run_sync(fn, *args, trio_token=None): Raises: RunFinishedError: if the corresponding call to `trio.run` has already completed. - Cancelled: if the corresponding call to `trio.run` completes - while ``afn(*args)`` is running, then ``afn`` is likely to raise - :exc:`trio.Cancelled`, and this will propagate out into RuntimeError: if you try calling this from inside the Trio thread, which would otherwise cause a deadlock. AttributeError: if no ``trio_token`` was provided, and we can't infer one from context. + TypeError: if ``fn`` is an async function. **Locating a Trio Token**: There are two ways to specify which `trio.run` loop to reenter: @@ -425,7 +427,17 @@ def from_thread_run_sync(fn, *args, trio_token=None): def callback(q, fn, args): @disable_ki_protection def unprotected_fn(): - return fn(*args) + ret = fn(*args) + + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a sync function, but {!r} appears to be " + "asynchronous".format(getattr(fn, "__qualname__", fn)) + ) + + return ret res = outcome.capture(unprotected_fn) q.put_nowait(res) diff --git a/trio/_util.py b/trio/_util.py index b331c8b48f..06c3a29a19 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -8,6 +8,9 @@ from functools import wraps, update_wrapper import typing as t import threading +import collections + +from async_generator import isasyncgen from ._deprecate import warn_deprecated @@ -85,6 +88,91 @@ def is_main_thread(): return False +###### +# Call the function and get the coroutine object, while giving helpful +# errors for common mistakes. Returns coroutine object. +###### +def coroutine_or_error(async_fn, *args): + def _return_value_looks_like_wrong_library(value): + # Returned by legacy @asyncio.coroutine functions, which includes + # a surprising proportion of asyncio builtins. + if isinstance(value, collections.abc.Generator): + return True + # The protocol for detecting an asyncio Future-like object + if getattr(value, "_asyncio_future_blocking", None) is not None: + return True + # This janky check catches tornado Futures and twisted Deferreds. + # By the time we're calling this function, we already know + # something has gone wrong, so a heuristic is pretty safe. + if value.__class__.__name__ in ("Future", "Deferred"): + return True + return False + + try: + coro = async_fn(*args) + + except TypeError: + # Give good error for: nursery.start_soon(trio.sleep(1)) + if isinstance(async_fn, collections.abc.Coroutine): + # explicitly close coroutine to avoid RuntimeWarning + async_fn.close() + + raise TypeError( + "Trio was expecting an async function, but instead it got " + "a coroutine object {async_fn!r}\n" + "\n" + "Probably you did something like:\n" + "\n" + " trio.run({async_fn.__name__}(...)) # incorrect!\n" + " nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n" + "\n" + "Instead, you want (notice the parentheses!):\n" + "\n" + " trio.run({async_fn.__name__}, ...) # correct!\n" + " nursery.start_soon({async_fn.__name__}, ...) # correct!" + .format(async_fn=async_fn) + ) from None + + # Give good error for: nursery.start_soon(future) + if _return_value_looks_like_wrong_library(async_fn): + raise TypeError( + "Trio was expecting an async function, but instead it got " + "{!r} – are you trying to use a library written for " + "asyncio/twisted/tornado or similar? That won't work " + "without some sort of compatibility shim.".format(async_fn) + ) from None + + raise + + # We can't check iscoroutinefunction(async_fn), because that will fail + # for things like functools.partial objects wrapping an async + # function. So we have to just call it and then check whether the + # return value is a coroutine object. + if not isinstance(coro, collections.abc.Coroutine): + # Give good error for: nursery.start_soon(func_returning_future) + if _return_value_looks_like_wrong_library(coro): + raise TypeError( + "Trio got unexpected {!r} – are you trying to use a " + "library written for asyncio/twisted/tornado or similar? " + "That won't work without some sort of compatibility shim." + .format(coro) + ) + + if isasyncgen(coro): + raise TypeError( + "start_soon expected an async function but got an async " + "generator {!r}".format(coro) + ) + + # Give good error for: nursery.start_soon(some_sync_fn) + raise TypeError( + "Trio expected an async function, but {!r} appears to be " + "synchronous".format(getattr(async_fn, "__qualname__", async_fn)) + ) + + return coro + + class ConflictDetector: """Detect when two tasks are about to perform operations that would conflict. diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index 29d44adc4a..6f5d2b6229 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -13,7 +13,6 @@ ) from .._core.tests.test_ki import ki_self -from .._core.tests.tutil import slow async def test_do_in_trio_thread(): @@ -471,6 +470,16 @@ def thread_fn(): trio_time = await to_thread_run_sync(thread_fn) assert isinstance(trio_time, float) + # Test correct error when passed async function + async def async_fn(): # pragma: no cover + pass + + def thread_fn(): + from_thread_run_sync(async_fn) + + with pytest.raises(TypeError, match="expected a sync function"): + await to_thread_run_sync(thread_fn) + async def test_trio_from_thread_run(): # Test that to_thread_run_sync correctly "hands off" the trio token to @@ -488,6 +497,13 @@ def thread_fn(): await to_thread_run_sync(thread_fn) assert record == ["in thread", "back in trio"] + # Test correct error when passed sync function + def sync_fn(): # pragma: no cover + pass + + with pytest.raises(TypeError, match="appears to be synchronous"): + await to_thread_run_sync(from_thread_run, sync_fn) + async def test_trio_from_thread_token(): # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index d57b1997df..009f9fa8f7 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -1,12 +1,13 @@ import signal - import pytest import trio from .. import _core +from .._core.tests.tutil import ignore_coroutine_never_awaited_warnings from .._util import ( - signal_raise, ConflictDetector, is_main_thread, generic_function, Final, - NoPublicConstructor, SubclassingDeprecatedIn_v0_15_0 + signal_raise, ConflictDetector, is_main_thread, coroutine_or_error, + generic_function, Final, NoPublicConstructor, + SubclassingDeprecatedIn_v0_15_0 ) from ..testing import wait_all_tasks_blocked @@ -82,6 +83,64 @@ def not_main_thread(): await trio.to_thread.run_sync(not_main_thread) +# @coroutine is deprecated since python 3.8, which is fine with us. +@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") +def test_coroutine_or_error(): + class Deferred: + "Just kidding" + + with ignore_coroutine_never_awaited_warnings(): + + async def f(): # pragma: no cover + pass + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(f()) + assert "expecting an async function" in str(excinfo.value) + + import asyncio + + @asyncio.coroutine + def generator_based_coro(): # pragma: no cover + yield from asyncio.sleep(1) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(generator_based_coro()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(asyncio.Future()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(lambda: asyncio.Future()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(Deferred()) + assert "twisted" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(lambda: Deferred()) + assert "twisted" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(len, [[1, 2, 3]]) + + assert "appears to be synchronous" in str(excinfo.value) + + async def async_gen(arg): # pragma: no cover + yield + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(async_gen, [0]) + msg = "expected an async function but got an async generator" + assert msg in str(excinfo.value) + + # Make sure no references are kept around to keep anything alive + del excinfo + + def test_generic_function(): @generic_function def test_func(arg):