diff --git a/Cargo.lock b/Cargo.lock index 12e04864f7..004d1532c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -623,26 +623,53 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-lite" @@ -659,6 +686,47 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.28.0" @@ -1661,6 +1729,7 @@ name = "uniffi-fixture-futures" version = "0.21.0" dependencies = [ "async-trait", + "futures", "once_cell", "thiserror", "tokio", diff --git a/docs/manual/src/futures.md b/docs/manual/src/futures.md index 0e1d22ea85..167de96925 100644 --- a/docs/manual/src/futures.md +++ b/docs/manual/src/futures.md @@ -71,3 +71,63 @@ pub trait SayAfterTrait: Send + Sync { async fn say_after(&self, ms: u16, who: String) -> String; } ``` + +## Blocking tasks + +Rust executors are designed around an assumption that the `Future::poll` function will return quickly. +This assumption, combined with cooperative scheduling, allows for a large number of futures to be handled by a small number of threads. +Foreign executors make similar assumptions and sometimes more extreme ones. +For example, the Python eventloop is single threaded -- if any task spends a long time between `await` points, then it will block all other tasks from progressing. + +This raises the question of how async code can interact with blocking code that performs blocking IO, long-running computations without `await` breaks, etc. +UniFFI defines the `BlockingTaskQueue` type, which is a foreign object that schedules work on a thread where it's okay to block. + +On Rust, `BlockingTaskQueue` is a UniFFI type that can safely run blocking code. +It's `execute` method works like tokio's [block_in_place](https://docs.rs/tokio/latest/tokio/task/fn.block_in_place.html) function. +It inputs a closure and runs it in the `BlockingTaskQueue`. +This closure can reference the outside scope (i.e. it does not need to be `'static`). +For example: + +```rust +#[derive(uniffi::Object)] +struct DataStore { + // Used to run blocking tasks + queue: uniffi::BlockingTaskQueue, + // Low-level DB object with blocking methods + db: Mutex, +} + +#[uniffi::export] +impl DataStore { + #[uniffi::constructor] + fn new(queue: uniffi::BlockingTaskQueue) -> Self { + Self { + queue, + db: Mutex::new(Database::new()) + } + } + + async fn fetch_all_items(&self) -> Vec { + self.queue.execute(|| self.db.lock().fetch_all_items()).await + } +} +``` + +On the foreign side `BlockingTaskQueue` corresponds to a language-dependent class. + +### Kotlin +Kotlin uses `CoroutineContext` for its `BlockingTaskQueue`. +Any `CoroutineContext` will work, but `Dispatchers.IO` is usually a good choice. +A DataStore from the example above can be created with `DataStore(Dispatchers.IO)`. + +### Swift +Swift uses `DispatchQueue` for its `BlockingTaskQueue`. +The user-initiated global queue is normally a good choice. +A DataStore from the example above can be created with `DataStore(queue: DispatchQueue.global(qos: .userInitiated)`. +The `DispatchQueue` should be concurrent. + +### Python + +Python uses a `futures.Executor` for its `BlockingTaskQueue`. +`ThreadPoolExecutor` is typically a good choice. +A DataStore from the example above can be created with `DataStore(ThreadPoolExecutor())`. diff --git a/fixtures/futures/Cargo.toml b/fixtures/futures/Cargo.toml index f386c6d85c..0bb6f1c04d 100644 --- a/fixtures/futures/Cargo.toml +++ b/fixtures/futures/Cargo.toml @@ -17,6 +17,7 @@ path = "src/bin.rs" [dependencies] uniffi = { workspace = true, features = ["tokio", "cli"] } async-trait = "0.1" +futures = "0.3.29" thiserror = "1.0" tokio = { version = "1.24.1", features = ["time", "sync"] } once_cell = "1.18.0" diff --git a/fixtures/futures/src/lib.rs b/fixtures/futures/src/lib.rs index 4b4ed1cca9..8e58639f66 100644 --- a/fixtures/futures/src/lib.rs +++ b/fixtures/futures/src/lib.rs @@ -11,6 +11,8 @@ use std::{ time::Duration, }; +use futures::stream::{FuturesUnordered, StreamExt}; + /// Non-blocking timer future. pub struct TimerFuture { shared_state: Arc>, @@ -385,6 +387,59 @@ impl SayAfterUdlTrait for SayAfterImpl2 { #[uniffi::export] fn get_say_after_udl_traits() -> Vec> { vec![Arc::new(SayAfterImpl1), Arc::new(SayAfterImpl2)] + +/// Async function that uses a blocking task queue to do its work +#[uniffi::export] +pub async fn calc_square(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.execute(|| value * value).await +} + +/// Same as before, but this one runs multiple tasks +#[uniffi::export] +pub async fn calc_squares(queue: uniffi::BlockingTaskQueue, items: Vec) -> Vec { + // Use `FuturesUnordered` to test our blocking task queue code which is known to be a tricky API to work with. + // In particular, if we don't notify the waker then FuturesUnordered will not poll again. + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queue.execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// ...and this one uses multiple BlockingTaskQueues +#[uniffi::export] +pub async fn calc_squares_multi_queue( + queues: Vec, + items: Vec, +) -> Vec { + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queues[i].execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// Like calc_square, but it clones the BlockingTaskQueue first then drops both copies. Used to +/// test that a) the clone works and b) we correctly drop the references. +#[uniffi::export] +pub async fn calc_square_with_clone(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.clone().execute(|| value * value).await } uniffi::include_scaffolding!("futures"); diff --git a/fixtures/futures/tests/bindings/test_futures.kts b/fixtures/futures/tests/bindings/test_futures.kts index 175f4a619a..f8cf0cb4e3 100644 --- a/fixtures/futures/tests/bindings/test_futures.kts +++ b/fixtures/futures/tests/bindings/test_futures.kts @@ -1,9 +1,22 @@ import uniffi.fixture.futures.* +import java.util.concurrent.Executors import kotlinx.coroutines.* import kotlin.system.* +fun runAsyncTest(test: suspend CoroutineScope.() -> Unit) { + val initialBlockingTaskQueueHandleCount = uniffiBlockingTaskQueueHandleCount() + val initialPollHandleCount = uniffiPollHandleCount() + val time = runBlocking { + measureTimeMillis { + test() + } + } + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueHandleCount) + assert(uniffiPollHandleCount() == initialPollHandleCount) +} + // init UniFFI to get good measurements after that -runBlocking { +runAsyncTest { val time = measureTimeMillis { alwaysReady() } @@ -24,7 +37,7 @@ fun assertApproximateTime(actualTime: Long, expectedTime: Int, testName: String } // Test `always_ready`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = alwaysReady() @@ -35,7 +48,7 @@ runBlocking { } // Test `void`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = void() @@ -46,7 +59,7 @@ runBlocking { } // Test `sleep`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { sleep(200U) } @@ -55,7 +68,7 @@ runBlocking { } // Test sequential futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfter(100U, "Alice") val resultBob = sayAfter(200U, "Bob") @@ -68,7 +81,7 @@ runBlocking { } // Test concurrent futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = async { sayAfter(100U, "Alice") } val resultBob = async { sayAfter(200U, "Bob") } @@ -81,7 +94,7 @@ runBlocking { } // Test async methods. -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = megaphone.sayAfter(200U, "Alice") @@ -92,7 +105,7 @@ runBlocking { assertApproximateTime(time, 200, "async methods") } -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = sayAfterWithMegaphone(megaphone, 200U, "Alice") @@ -104,7 +117,7 @@ runBlocking { } // Test async method returning optional object -runBlocking { +runAsyncTest { val megaphone = asyncMaybeNewMegaphone(true) assert(megaphone != null) @@ -141,7 +154,7 @@ runBlocking { } // Test with the Tokio runtime. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfterWithTokio(200U, "Alice") @@ -152,7 +165,7 @@ runBlocking { } // Test fallible function/method. -runBlocking { +runAsyncTest { val time1 = measureTimeMillis { try { fallibleMe(false) @@ -217,7 +230,7 @@ runBlocking { } // Test record. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = newMyRecord("foo", 42U) @@ -231,7 +244,7 @@ runBlocking { } // Test a broken sleep. -runBlocking { +runAsyncTest { val time = measureTimeMillis { brokenSleep(100U, 0U) // calls the waker twice immediately sleep(100U) // wait for possible failure @@ -245,7 +258,7 @@ runBlocking { // Test a future that uses a lock and that is cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val job = launch { useSharedResource(SharedResourceOptions(releaseAfterMs=5000U, timeoutMs=100U)) @@ -264,7 +277,7 @@ runBlocking { } // Test a future that uses a lock and that is not cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { useSharedResource(SharedResourceOptions(releaseAfterMs=100U, timeoutMs=1000U)) @@ -272,3 +285,33 @@ runBlocking { } println("useSharedResource (not canceled): ${time}ms") } + +// Test blocking task queues +runAsyncTest { + withTimeout(1000) { + assert(calcSquare(Dispatchers.IO, 20) == 400) + } + + withTimeout(1000) { + assert(calcSquares(Dispatchers.IO, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + + val executors = listOf( + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + ) + withTimeout(1000) { + assert(calcSquaresMultiQueue(executors.map { it.asCoroutineDispatcher() }, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + for (executor in executors) { + executor.shutdown() + } +} + +// Test blocking task queue cloning +runAsyncTest { + withTimeout(1000) { + assert(calcSquareWithClone(Dispatchers.IO, 20) == 400) + } +} diff --git a/fixtures/futures/tests/bindings/test_futures.py b/fixtures/futures/tests/bindings/test_futures.py index 7b2e12324f..4365689aab 100644 --- a/fixtures/futures/tests/bindings/test_futures.py +++ b/fixtures/futures/tests/bindings/test_futures.py @@ -1,26 +1,32 @@ +import futures from futures import * +import contextlib import unittest from datetime import datetime import asyncio import typing +from concurrent.futures import ThreadPoolExecutor def now(): return datetime.now() class TestFutures(unittest.TestCase): def test_always_ready(self): + @self.check_handle_counts() async def test(): self.assertEqual(await always_ready(), True) asyncio.run(test()) def test_void(self): + @self.check_handle_counts() async def test(): self.assertEqual(await void(), None) asyncio.run(test()) def test_sleep(self): + @self.check_handle_counts() async def test(): t0 = now() await sleep(2000) @@ -32,6 +38,7 @@ async def test(): asyncio.run(test()) def test_sequential_futures(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after(100, 'Alice') @@ -46,6 +53,7 @@ async def test(): asyncio.run(test()) def test_concurrent_tasks(self): + @self.check_handle_counts() async def test(): alice = asyncio.create_task(say_after(100, 'Alice')) bob = asyncio.create_task(say_after(200, 'Bob')) @@ -63,6 +71,7 @@ async def test(): asyncio.run(test()) def test_async_methods(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -106,6 +115,7 @@ async def test(): asyncio.run(test()) def test_async_object_param(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -119,6 +129,7 @@ async def test(): asyncio.run(test()) def test_with_tokio_runtime(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after_with_tokio(200, 'Alice') @@ -131,6 +142,7 @@ async def test(): asyncio.run(test()) def test_fallible(self): + @self.check_handle_counts() async def test(): result = await fallible_me(False) self.assertEqual(result, 42) @@ -155,6 +167,7 @@ async def test(): asyncio.run(test()) def test_fallible_struct(self): + @self.check_handle_counts() async def test(): megaphone = await fallible_struct(False) self.assertEqual(await megaphone.fallible_me(False), 42) @@ -168,6 +181,7 @@ async def test(): asyncio.run(test()) def test_record(self): + @self.check_handle_counts() async def test(): result = await new_my_record("foo", 42) self.assertEqual(result.__class__, MyRecord) @@ -177,6 +191,7 @@ async def test(): asyncio.run(test()) def test_cancel(self): + @self.check_handle_counts() async def test(): # Create a task task = asyncio.create_task(say_after(200, 'Alice')) @@ -194,6 +209,7 @@ async def test(): # Test a future that uses a lock and that is cancelled. def test_shared_resource_cancellation(self): + @self.check_handle_counts() async def test(): task = asyncio.create_task(use_shared_resource( SharedResourceOptions(release_after_ms=5000, timeout_ms=100))) @@ -204,6 +220,7 @@ async def test(): asyncio.run(test()) def test_shared_resource_no_cancellation(self): + @self.check_handle_counts() async def test(): await use_shared_resource(SharedResourceOptions(release_after_ms=100, timeout_ms=1000)) await use_shared_resource(SharedResourceOptions(release_after_ms=0, timeout_ms=1000)) @@ -215,5 +232,47 @@ async def test(): self.assertEqual(typing.get_type_hints(sleep_no_return), {"ms": int, "return": type(None)}) asyncio.run(test()) + # blocking task queue tests + + def test_calc_square(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square(executor, 20), 400) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_square_with_clone(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square_with_clone(executor, 20), 400) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_squares(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_squares(executor, [1, -2, 3]), [1, 4, 9]) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_squares_multi_queue(self): + @self.check_handle_counts() + async def test(): + executors = [ + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ] + self.assertEqual(await calc_squares_multi_queue(executors, [1, -2, 3]), [1, 4, 9]) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + @contextlib.asynccontextmanager + async def check_handle_counts(self): + initial_poll_handle_count = len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER) + initial_blocking_task_queue_handle_count = len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP) + yield + self.assertEqual(len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER), initial_poll_handle_count) + self.assertEqual(len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP), initial_blocking_task_queue_handle_count) + if __name__ == '__main__': unittest.main() diff --git a/fixtures/futures/tests/bindings/test_futures.swift b/fixtures/futures/tests/bindings/test_futures.swift index 2fd413a8ab..e70cdbc5b7 100644 --- a/fixtures/futures/tests/bindings/test_futures.swift +++ b/fixtures/futures/tests/bindings/test_futures.swift @@ -3,10 +3,21 @@ import Foundation // To get `DispatchGroup` and `Date` types. var counter = DispatchGroup() -// Test `alwaysReady` -counter.enter() +func asyncTest(test: @escaping () async throws -> ()) { + let initialBlockingTaskQueueCount = uniffiBlockingTaskQueueHandleCountFutures() + let initialPollDataHandleCount = uniffiPollDataHandleCountFutures() + counter.enter() + Task { + try! await test() + counter.leave() + } + counter.wait() + assert(uniffiBlockingTaskQueueHandleCountFutures() == initialBlockingTaskQueueCount) + assert(uniffiPollDataHandleCountFutures() == initialPollDataHandleCount) +} -Task { +// Test `alwaysReady` +asyncTest { let t0 = Date() let result = await alwaysReady() let t1 = Date() @@ -14,40 +25,28 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) assert(result == true) - - counter.leave() } // Test record. -counter.enter() - -Task { +asyncTest { let result = await newMyRecord(a: "foo", b: 42) assert(result.a == "foo") assert(result.b == 42) - - counter.leave() } // Test `void` -counter.enter() - -Task { +asyncTest { let t0 = Date() await void() let t1 = Date() let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) - - counter.leave() } // Test `Sleep` -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = await sleep(ms: 2000) let t1 = Date() @@ -55,14 +54,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result == true) - - counter.leave() } // Test sequential futures. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfter(ms: 1000, who: "Alice") let result_bob = await sayAfter(ms: 2000, who: "Bob") @@ -72,14 +67,10 @@ Task { assert(tDelta.duration > 3 && tDelta.duration < 3.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test concurrent futures. -counter.enter() - -Task { +asyncTest { async let alice = sayAfter(ms: 1000, who: "Alice") async let bob = sayAfter(ms: 2000, who: "Bob") @@ -91,14 +82,10 @@ Task { assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test async methods -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -108,8 +95,6 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "HELLO, ALICE!") - - counter.leave() } // Test async trait interface methods @@ -151,21 +136,15 @@ Task { } // Test async function returning an object -counter.enter() - -Task { +asyncTest { let megaphone = await asyncNewMegaphone() let result = try await megaphone.fallibleMe(doFail: false) assert(result == 42) - - counter.leave() } // Test with the Tokio runtime. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfterWithTokio(ms: 2000, who: "Alice") let t1 = Date() @@ -173,15 +152,11 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice (with Tokio)!") - - counter.leave() } // Test fallible function/method… // … which doesn't throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = try await fallibleMe(doFail: false) let t1 = Date() @@ -189,19 +164,15 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } -Task { +asyncTest { let m = try await fallibleStruct(doFail: false) let result = try await m.fallibleMe(doFail: false) assert(result == 42) } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -211,14 +182,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } // … which does throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() do { @@ -233,11 +200,9 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } -Task { +asyncTest { do { let _ = try await fallibleStruct(doFail: true) } catch MyError.Foo { @@ -247,9 +212,7 @@ Task { } } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -266,13 +229,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } // Test a future that uses a lock and that is cancelled. -counter.enter() -Task { +asyncTest { let task = Task { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) } @@ -288,15 +248,36 @@ Task { // Try accessing the shared resource again. The initial task should release the shared resource // before the timeout expires. try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } // Test a future that uses a lock and that is not cancelled. -counter.enter() -Task { +asyncTest { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } -counter.wait() +// Test blocking task queues +asyncTest { + let calcSquareResult = await calcSquare(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) + + let calcSquaresResult = await calcSquares(queue: DispatchQueue.global(qos: .userInitiated), items: [1, -2, 3]) + assert(calcSquaresResult == [1, 4, 9]) + + let calcSquaresMultiQueueResult = await calcSquaresMultiQueue( + queues: [ + DispatchQueue(label: "test-queue1", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue2", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue3", attributes: DispatchQueue.Attributes.concurrent) + ], + items: [1, -2, 3] + ) + assert(calcSquaresMultiQueueResult == [1, 4, 9]) +} + +// Test blocking task queue cloning +asyncTest { + let calcSquareResult = await calcSquareWithClone(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) +} + diff --git a/fixtures/metadata/src/tests.rs b/fixtures/metadata/src/tests.rs index 2a280597ea..1a54cbe067 100644 --- a/fixtures/metadata/src/tests.rs +++ b/fixtures/metadata/src/tests.rs @@ -129,6 +129,7 @@ mod test_type_ids { check_type_id::(Type::Float64); check_type_id::(Type::Boolean); check_type_id::(Type::String); + check_type_id::(Type::BlockingTaskQueue); } #[test] diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs new file mode 100644 index 0000000000..a664f1fcd3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self, _ci: &crate::ComponentInterface) -> String { + // Kotlin uses CoroutineContext for BlockingTaskQueue + "CoroutineContext".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs index fecf303913..f3e7f30d11 100644 --- a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -466,6 +467,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), @@ -637,7 +640,7 @@ mod filters { ) -> Result { let ffi_func = callable.ffi_rust_future_poll(ci); Ok(format!( - "{{ future, callback, continuation -> UniffiLib.INSTANCE.{ffi_func}(future, callback, continuation) }}" + "{{ future, callback, continuation, blockingTaskQueueHandle -> UniffiLib.INSTANCE.{ffi_func}(future, callback, continuation, blockingTaskQueueHandle) }}" )) } diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt index dc547d4ddf..e02462b9b6 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt @@ -3,18 +3,37 @@ internal const val UNIFFI_RUST_FUTURE_POLL_READY = 0.toByte() internal const val UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1.toByte() -internal val uniffiContinuationHandleMap = UniffiHandleMap>() +/** + * Data for an in-progress poll of a RustFuture + */ +internal data class UniffiPollData( + val continuation: CancellableContinuation, + val rustFuture: Long, + val pollFunc: (Long, UniffiRustFutureContinuationCallback, USize, Long) -> Unit, +) + +internal val uniffiPollDataHandleMap = UniffiHandleMap() // FFI type for Rust future continuations internal object uniffiRustFutureContinuationCallbackImpl: UniffiRustFutureContinuationCallback { - override fun callback(data: Long, pollResult: Byte) { - uniffiContinuationHandleMap.remove(data).resume(pollResult) + override fun callback(data: USize, pollResult: Byte, blockingTaskQueueHandle: Long) { + if (blockingTaskQueueHandle == 0L) { + // Complete the Kotlin continuation + uniffiPollDataHandleMap.remove(data)!!.continuation.resume(pollResult) + } else { + // Call the poll function again, but inside the BlockingTaskQueue coroutine context + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(blockingTaskQueueHandle) + val pollData = uniffiPollDataHandleMap.get(data)!! + CoroutineScope(coroutineContext).launch { + pollData.pollFunc(pollData.rustFuture, uniffiRustFutureContinuationCallbackImpl, data, blockingTaskQueueHandle) + } + } } } internal suspend fun uniffiRustCallAsync( rustFuture: Long, - pollFunc: (Long, UniffiRustFutureContinuationCallback, Long) -> Unit, + pollFunc: (Long, UniffiRustFutureContinuationCallback, Long, Long) -> Unit, completeFunc: (Long, UniffiRustCallStatus) -> F, freeFunc: (Long) -> Unit, liftFunc: (F) -> T, @@ -23,10 +42,12 @@ internal suspend fun uniffiRustCallAsync( try { do { val pollResult = suspendCancellableCoroutine { continuation -> + val pollData = UniffiPollData(continuation, rustFuture, pollFunc) pollFunc( rustFuture, uniffiRustFutureContinuationCallbackImpl, - uniffiContinuationHandleMap.insert(continuation) + uniffiPollDataHandleMap.insert(pollData), + 0L ) } } while (pollResult != UNIFFI_RUST_FUTURE_POLL_READY); @@ -39,3 +60,6 @@ internal suspend fun uniffiRustCallAsync( } } +// For testing +public fun uniffiPollHandleCount() = uniffiPollDataHandleMap.size + diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt new file mode 100644 index 0000000000..717ff62df3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt @@ -0,0 +1,44 @@ +{{ self.add_import("kotlin.coroutines.CoroutineContext") }} + +object uniffiBlockingTaskQueueClone : UniffiBlockingTaskQueueClone { + override fun callback(handle: Long): Long { + val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(handle) + return uniffiBlockingTaskQueueHandleMap.insert(coroutineContext) + } +} + +object uniffiBlockingTaskQueueFree : UniffiBlockingTaskQueueFree { + override fun callback(handle: Long) { + uniffiBlockingTaskQueueHandleMap.remove(handle) + } +} + +internal val uniffiBlockingTaskQueueVTable = UniffiBlockingTaskQueueVTable( + uniffiBlockingTaskQueueClone, + uniffiBlockingTaskQueueFree, +) +internal val uniffiBlockingTaskQueueHandleMap = ConcurrentHandleMap() + +public object {{ ffi_converter_name }}: FfiConverterRustBuffer { + override fun allocationSize(value: {{ type_name }}) = 16 + + override fun write(value: CoroutineContext, buf: ByteBuffer) { + // Call `write()` to make sure the data is written to the JNA backing data + uniffiBlockingTaskQueueVTable.write() + val handle = uniffiBlockingTaskQueueHandleMap.insert(value) + buf.putLong(handle) + buf.putLong(Pointer.nativeValue(uniffiBlockingTaskQueueVTable.getPointer())) + } + + override fun read(buf: ByteBuffer): CoroutineContext { + val handle = buf.getLong() + val coroutineContext = uniffiBlockingTaskQueueHandleMap.remove(handle)!! + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + buf.getLong() + return coroutineContext + } +} + +// For testing +public fun uniffiBlockingTaskQueueHandleCount() = uniffiBlockingTaskQueueHandleMap.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt index ba56716401..552e269ab9 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt @@ -89,6 +89,9 @@ object NoPointer {%- when Type::Bytes %} {%- include "ByteArrayHelper.kt" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.kt" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {%- if !ci.is_name_used_as_error(name) %} @@ -134,6 +137,8 @@ object NoPointer {%- if ci.has_async_fns() %} {# Import types needed for async support #} {{ self.add_import("kotlin.coroutines.resume") }} +{{ self.add_import("kotlinx.coroutines.launch") }} {{ self.add_import("kotlinx.coroutines.suspendCancellableCoroutine") }} {{ self.add_import("kotlinx.coroutines.CancellableContinuation") }} +{{ self.add_import("kotlinx.coroutines.CoroutineScope") }} {%- endif %} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt index 2cdc72a5e2..ea326bb242 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt @@ -32,6 +32,7 @@ import java.nio.CharBuffer import java.nio.charset.CodingErrorAction import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.withLock {%- for req in self.imports() %} {{ req.render() }} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs new file mode 100644 index 0000000000..2f5a74fda0 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + // On python we use an concurrent.futures.Executor for a BlockingTaskQueue + fn type_label(&self) -> String { + "concurrent.futures.Executor".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs index 6c1996eb5c..cc10d0cae5 100644 --- a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs +++ b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -426,6 +427,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, .. } => Box::new(object::ObjectCodeType::new(name)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/python/templates/Async.py b/uniffi_bindgen/src/bindings/python/templates/Async.py index 4a230112ea..57ed68ddeb 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Async.py +++ b/uniffi_bindgen/src/bindings/python/templates/Async.py @@ -2,15 +2,38 @@ _UNIFFI_RUST_FUTURE_POLL_READY = 0 _UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1 -# Stores futures for _uniffi_continuation_callback -_UniffiContinuationHandleMap = _UniffiHandleMap() +""" +Data for an in-progress poll of a RustFuture +""" +class UniffiPoll(typing.NamedTuple): + eventloop: asyncio.AbstractEventLoop + future: asyncio.Future + rust_future: int + # Must be UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK, but it's not clear how to specify as valid + # type for mypy and our current Python version + ffi_poll: object + +_UniffiPollDataHandleMap = _UniffiHandleMap() # Continuation callback for async functions # lift the return value or error and resolve the future, causing the async function to resume. @UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK -def _uniffi_continuation_callback(future_ptr, poll_code): - (eventloop, future) = _UniffiContinuationHandleMap.remove(future_ptr) - eventloop.call_soon_threadsafe(_uniffi_set_future_result, future, poll_code) +def _uniffi_continuation_callback(poll_data_handle, poll_code, blocking_task_queue_handle): + if blocking_task_queue_handle == 0: + # Complete the Python Future + poll_data = _UniffiPollDataHandleMap.remove(poll_data_handle) + poll_data.eventloop.call_soon_threadsafe(_uniffi_set_future_result, poll_data.future, poll_code) + else: + # Call the poll function again, but inside the executor + poll_data = _UniffiPollDataHandleMap.get(poll_data_handle) + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(blocking_task_queue_handle) + executor.submit( + poll_data.ffi_poll, + poll_data.rust_future, + _uniffi_continuation_callback, + poll_data_handle, + blocking_task_queue_handle + ) def _uniffi_set_future_result(future, poll_code): if not future.cancelled(): @@ -23,10 +46,17 @@ async def _uniffi_rust_call_async(rust_future, ffi_poll, ffi_complete, ffi_free, # Loop and poll until we see a _UNIFFI_RUST_FUTURE_POLL_READY value while True: future = eventloop.create_future() + poll_data = UniffiPoll( + eventloop=eventloop, + future=future, + rust_future=rust_future, + ffi_poll=ffi_poll, + ) ffi_poll( rust_future, _uniffi_continuation_callback, - _UniffiContinuationHandleMap.insert((eventloop, future)), + _UniffiPollDataHandleMap.insert(poll_data), + 0, ) poll_code = await future if poll_code == _UNIFFI_RUST_FUTURE_POLL_READY: diff --git a/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py new file mode 100644 index 0000000000..329ff37906 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py @@ -0,0 +1,39 @@ +{{ self.add_import("concurrent.futures") }} + +@UNIFFI_BLOCKING_TASK_QUEUE_CLONE +def uniffi_blocking_task_queue_clone(handle): + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle) + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(executor) + +@UNIFFI_BLOCKING_TASK_QUEUE_FREE +def uniffi_blocking_task_queue_free(handle): + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + +UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + uniffi_blocking_task_queue_clone, + uniffi_blocking_task_queue_free, +) + +UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP = ConcurrentHandleMap() + +class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): + @staticmethod + def check_lower(value): + if not isinstance(value, concurrent.futures.Executor): + raise TypeError("Expected concurrent.futures.Executor instance, {} found".format(type(value).__name__)) + + @staticmethod + def write(value, buf): + handle = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(value) + buf.write_u64(handle) + buf.write_u64(ctypes.addressof(UNIFFI_BLOCKING_TASK_QUEUE_VTABLE)) + + @staticmethod + def read(buf): + handle = buf.read_u64() + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + # Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + # same value. + buf.read_u64() + return executor + diff --git a/uniffi_bindgen/src/bindings/python/templates/Types.py b/uniffi_bindgen/src/bindings/python/templates/Types.py index 4aaed253e0..9f2840fefb 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Types.py +++ b/uniffi_bindgen/src/bindings/python/templates/Types.py @@ -55,6 +55,9 @@ {%- when Type::Bytes %} {%- include "BytesHelper.py" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.py" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {# For enums, there are either an error *or* an enum, they can't be both. #} diff --git a/uniffi_bindgen/src/bindings/python/templates/wrapper.py b/uniffi_bindgen/src/bindings/python/templates/wrapper.py index 2050b8d589..0c2a7a3cc3 100644 --- a/uniffi_bindgen/src/bindings/python/templates/wrapper.py +++ b/uniffi_bindgen/src/bindings/python/templates/wrapper.py @@ -26,6 +26,7 @@ import threading import itertools import typing +import threading {%- if ci.has_async_fns() %} import asyncio {%- endif %} diff --git a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs index 07da3882b6..cb60300d65 100644 --- a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs +++ b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs @@ -57,6 +57,7 @@ pub fn canonical_name(t: &Type) -> String { Type::CallbackInterface { name, .. } => format!("CallbackInterface{name}"), Type::Timestamp => "Timestamp".into(), Type::Duration => "Duration".into(), + Type::BlockingTaskQueue => "BlockingTaskQueue".into(), // Recursive types. // These add a prefix to the name of the underlying type. // The component API definition cannot give names to recursive types, so as long as the @@ -262,6 +263,7 @@ mod filters { } Type::External { .. } => panic!("No support for external types, yet"), Type::Custom { .. } => panic!("No support for custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -315,6 +317,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lowering external types, yet"), Type::Custom { .. } => panic!("No support for lowering custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -355,6 +358,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lifting external types, yet"), Type::Custom { .. } => panic!("No support for lifting custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } } diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs new file mode 100644 index 0000000000..ab4df07f9b --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self) -> String { + // On Swift, we use a DispatchQueue for BlockingTaskQueue + "DispatchQueue".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs index 5717041b32..51b9a26461 100644 --- a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs @@ -18,6 +18,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -462,6 +463,8 @@ impl SwiftCodeOracle { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/uniffi_bindgen/src/bindings/swift/templates/Async.swift index 761e0dd70e..b8675bc653 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Async.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Async.swift @@ -1,13 +1,30 @@ private let UNIFFI_RUST_FUTURE_POLL_READY: Int8 = 0 private let UNIFFI_RUST_FUTURE_POLL_MAYBE_READY: Int8 = 1 -fileprivate let uniffiContinuationHandleMap = UniffiHandleMap>() +// Data for an in-progress poll of a RustFuture +fileprivate class UniffiPollData { + let continuation: UnsafeContinuation + let rustFuture: UInt64 + let pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> () + + init( + continuation: UnsafeContinuation, + rustFuture: UInt64, + pollFunc: @escaping (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> () + ) { + self.continuation = continuation + self.rustFuture = rustFuture + self.pollFunc = pollFunc + } +} + +fileprivate let uniffiPollDataHandleMap = UniffiHandleMap() fileprivate func uniffiRustCallAsync( - rustFutureFunc: () -> UInt64, - pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64) -> (), - completeFunc: (UInt64, UnsafeMutablePointer) -> F, - freeFunc: (UInt64) -> (), + rustFutureFunc: () -> Int64, + pollFunc: @escaping (Int64, @escaping UniffiRustFutureContinuationCallback, Int64, UInt64) -> (), + completeFunc: (Int64, UnsafeMutablePointer) -> F, + freeFunc: (Int64) -> (), liftFunc: (F) throws -> T, errorHandler: ((RustBuffer) throws -> Error)? ) async throws -> T { @@ -21,11 +38,20 @@ fileprivate func uniffiRustCallAsync( var pollResult: Int8; repeat { pollResult = await withUnsafeContinuation { + let pollData = UniffiPollData( + continuation: $0, + rustFuture: rustFuture, + pollFunc: pollFunc + ) pollFunc( rustFuture, uniffiFutureContinuationCallback, - uniffiContinuationHandleMap.insert(obj: $0) + uniffiPollDataHandleMap.insert(obj: pollData) ) + + UNIFFI_POLL_DATA_HANDLE_COUNT += 1 + pollFunc(rustFuture, uniffiFutureContinuationCallback, pollDataPtr, 0) + } } while pollResult != UNIFFI_RUST_FUTURE_POLL_READY @@ -37,10 +63,30 @@ fileprivate func uniffiRustCallAsync( // Callback handlers for an async calls. These are invoked by Rust when the future is ready. They // lift the return value or error and resume the suspended function. -fileprivate func uniffiFutureContinuationCallback(handle: UInt64, pollResult: Int8) { - if let continuation = try? uniffiContinuationHandleMap.remove(handle: handle) { - continuation.resume(returning: pollResult) +fileprivate func uniffiFutureContinuationCallback( + pollDataHandle: UInt64, + pollResult: Int8, + blockingTaskQueueHandle: UInt64 +) { + if let pollData = try? uniffiPollDataHandleMap.remove(handle: pollDataHandle) { + pollData.continuation.resume(returning: pollResult) + if (blockingTaskQueueHandle == 0) { + // Complete the Swift continutation + pollData.continuation.resume(returning: pollResult) + } else { + // Call the poll function again, but inside the DispatchQuee + let queue = uniffiBlockingTaskQueueHandleMap.get(handle: blockingTaskQueueHandle)! + queue.async { + pollData.pollFunc(pollData.rustFuture, uniffiFutureContinuationCallback, pollDataHandle, blockingTaskQueueHandle) + } + } } else { print("uniffiFutureContinuationCallback invalid handle") } } + + +// For testing +public func uniffiPollDataHandleCount{{ ci.namespace()|class_name }}() -> Int { + return uniffiPollDataHandleMap.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift new file mode 100644 index 0000000000..3b02a7c403 --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift @@ -0,0 +1,37 @@ +fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + clone: { (handle: UInt64) -> UInt64 in + let dispatchQueue = uniffiBlockingTaskQueueHandleMap.get(handle: handle)! + return uniffiBlockingTaskQueueHandleMap.insert(obj: dispatchQueue) + }, + free: { (handle: UInt64) in + uniffiBlockingTaskQueueHandleMap.remove(handle: handle) + } +) +fileprivate var uniffiBlockingTaskQueueHandleMap = UniffiHandleMap() + +fileprivate struct {{ ffi_converter_name }}: FfiConverterRustBuffer { + typealias SwiftType = DispatchQueue + + public static func write(_ value: DispatchQueue, into buf: inout [UInt8]) { + let handle = uniffiBlockingTaskQueueHandleMap.insert(obj: value) + writeInt(&buf, handle) + // From Apple: "You can safely use the address of a global variable as a persistent unique + // pointer value" (https://developer.apple.com/swift/blog/?id=6) + let vtablePointer = UnsafeMutablePointer(&UNIFFI_BLOCKING_TASK_QUEUE_VTABLE) + // Convert the pointer to a word-sized Int then to a 64-bit int then write it out. + writeInt(&buf, Int64(Int(bitPattern: vtablePointer))) + } + + public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> DispatchQueue { + let handle: UInt64 = try readInt(&buf) + // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + // same value. + let _: UInt64 = try readInt(&buf) + return uniffiBlockingTaskQueueHandleMap.remove(handle: handle)! + } +} + +// For testing +public func uniffiBlockingTaskQueueHandleCount{{ ci.namespace()|class_name }}() -> Int { + uniffiBlockingTaskQueueHandleMap.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift index 7aa1cca9b2..7c6c8c4750 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift @@ -21,7 +21,7 @@ extension {{ ffi_converter_name }} : FfiConverter { typealias FfiType = UInt64 public static func lift(_ handle: UInt64) throws -> SwiftType { - try handleMap.get(handle: handle) + return try handleMap.get(handle: handle) } public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> SwiftType { diff --git a/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift b/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift index af0305872b..85dfca540f 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift @@ -30,5 +30,11 @@ fileprivate class UniffiHandleMap { return obj } } + + var count: Int { + get { + return map.count + } + } } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift index cfddf7b313..5cd149464d 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift @@ -52,6 +52,69 @@ fileprivate extension RustCallStatus { } } +fileprivate class UniffiHandleMap { + private var leftMap: [UInt64: T] = [:] + private var counter: [UInt64: UInt64] = [:] + private var rightMap: [ObjectIdentifier: UInt64] = [:] + + private let lock = NSLock() + // Start with 1 so that 0 can be special-cased as the null value. + private var currentHandle: UInt64 = 1 + private let stride: UInt64 = 1 + + func insert(obj: T) -> UInt64 { + lock.withLock { + let id = ObjectIdentifier(obj as AnyObject) + let handle = rightMap[id] ?? { + currentHandle += stride + let handle = currentHandle + leftMap[handle] = obj + rightMap[id] = handle + return handle + }() + counter[handle] = (counter[handle] ?? 0) + 1 + return handle + } + } + + func get(handle: UInt64) -> T? { + lock.withLock { + leftMap[handle] + } + } + + func delete(handle: UInt64) { + remove(handle: handle) + } + + @discardableResult + func remove(handle: UInt64) -> T? { + lock.withLock { + defer { counter[handle] = (counter[handle] ?? 1) - 1 } + guard counter[handle] == 1 else { return leftMap[handle] } + let obj = leftMap.removeValue(forKey: handle) + if let obj = obj { + rightMap.removeValue(forKey: ObjectIdentifier(obj as AnyObject)) + } + return obj + } + } + + var count: Int { + get { + leftMap.count + } + } +} + +fileprivate extension NSLock { + func withLock(f: () throws -> T) rethrows -> T { + self.lock() + defer { self.unlock() } + return try f() + } +} + private func rustCall(_ callback: (UnsafeMutablePointer) -> T) throws -> T { try makeRustCall(callback, errorHandler: nil) } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Types.swift b/uniffi_bindgen/src/bindings/swift/templates/Types.swift index 5e26758f3c..ba4d2059c8 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Types.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Types.swift @@ -64,6 +64,9 @@ {%- when Type::CallbackInterface { name, module_path } %} {%- include "CallbackInterfaceTemplate.swift" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.swift" %} + {%- when Type::Custom { name, module_path, builtin } %} {%- include "CustomType.swift" %} diff --git a/uniffi_bindgen/src/interface/ffi.rs b/uniffi_bindgen/src/interface/ffi.rs index 19354e16dc..cfab6ed823 100644 --- a/uniffi_bindgen/src/interface/ffi.rs +++ b/uniffi_bindgen/src/interface/ffi.rs @@ -129,7 +129,8 @@ impl From<&Type> for FfiType { | Type::Sequence { .. } | Type::Map { .. } | Type::Timestamp - | Type::Duration => FfiType::RustBuffer(None), + | Type::Duration + | Type::BlockingTaskQueue => FfiType::RustBuffer(None), Type::External { name, kind: ExternalKind::Interface, diff --git a/uniffi_bindgen/src/interface/mod.rs b/uniffi_bindgen/src/interface/mod.rs index a91656caab..f036eb27b5 100644 --- a/uniffi_bindgen/src/interface/mod.rs +++ b/uniffi_bindgen/src/interface/mod.rs @@ -67,9 +67,7 @@ mod record; pub use record::{Field, Record}; pub mod ffi; -pub use ffi::{ - FfiArgument, FfiCallbackFunction, FfiDefinition, FfiField, FfiFunction, FfiStruct, FfiType, -}; +pub use ffi::{FfiArgument, FfiCallbackFunction, FfiDefinition, FfiField, FfiFunction, FfiStruct, FfiType}; pub use uniffi_meta::Radix; use uniffi_meta::{ ConstructorMetadata, LiteralMetadata, NamespaceMetadata, ObjectMetadata, TraitMethodMetadata, @@ -224,10 +222,16 @@ impl ComponentInterface { .chain(self.objects.iter().flat_map(|o| o.ffi_callbacks())) } + } + /// Get the definitions for callback FFI functions /// /// These are defined by the foreign code and invoked by Rust. fn callback_interface_vtable_definitions(&self) -> impl IntoIterator + '_ { + + pub fn callback_interface_vtable_definitions( + &self, + ) -> impl IntoIterator + '_ { self.callback_interface_definitions() .iter() .map(|cbi| cbi.vtable_definition()) @@ -503,6 +507,10 @@ impl ComponentInterface { name: "callback_data".to_owned(), type_: FfiType::Handle, }, + FfiArgument { + name: "blocking_task_queue_handle".to_owned(), + type_: FfiType::UInt64, + }, ], return_type: None, has_rust_call_status_arg: false, @@ -632,6 +640,18 @@ impl ComponentInterface { has_rust_call_status_arg: false, } .into(), + FfiCallbackFunction { + name: "BlockingTaskQueueClone".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: Some(FfiType::UInt64), + has_rust_call_status_arg: false, + }, + FfiCallbackFunction { + name: "BlockingTaskQueueFree".to_owned(), + arguments: vec![FfiArgument::new("handle", FfiType::UInt64)], + return_type: None, + has_rust_call_status_arg: false, + }, FfiStruct { name: "ForeignFuture".to_owned(), fields: vec![ @@ -640,6 +660,19 @@ impl ComponentInterface { ], } .into(), + FfiStruct { + name: "BlockingTaskQueueVTable".to_owned(), + fields: vec![ + FfiField::new( + "clone", + FfiType::Callback("BlockingTaskQueueClone".to_owned()), + ), + FfiField::new( + "free", + FfiType::Callback("BlockingTaskQueueFree".to_owned()), + ), + ], + } ] .into_iter() } @@ -843,6 +876,12 @@ impl ComponentInterface { .map(|n| self.errors.insert(n.to_string())); self.functions.push(defn); + if defn.is_async() { + self.types + .add_known_type(&uniffi_meta::Type::BlockingTaskQueue)?; + } + + self.functions.push(defn); Ok(()) } diff --git a/uniffi_bindgen/src/interface/universe.rs b/uniffi_bindgen/src/interface/universe.rs index 70bc61f8a9..2faef72fd6 100644 --- a/uniffi_bindgen/src/interface/universe.rs +++ b/uniffi_bindgen/src/interface/universe.rs @@ -84,6 +84,7 @@ impl TypeUniverse { Type::Bytes => self.add_type_definition("bytes", type_)?, Type::Timestamp => self.add_type_definition("timestamp", type_)?, Type::Duration => self.add_type_definition("duration", type_)?, + Type::BlockingTaskQueue => self.add_type_definition("BlockingTaskQueue", type_)?, Type::Object { name, .. } | Type::Record { name, .. } | Type::Enum { name, .. } diff --git a/uniffi_bindgen/src/scaffolding/mod.rs b/uniffi_bindgen/src/scaffolding/mod.rs index 7fd81831aa..231e0495d3 100644 --- a/uniffi_bindgen/src/scaffolding/mod.rs +++ b/uniffi_bindgen/src/scaffolding/mod.rs @@ -45,6 +45,7 @@ mod filters { format!("std::sync::Arc<{}>", imp.rust_name_for(name)) } Type::CallbackInterface { name, .. } => format!("Box"), + Type::BlockingTaskQueue => "::uniffi::BlockingTaskQueue".to_owned(), Type::Optional { inner_type } => { format!("std::option::Option<{}>", type_rs(inner_type)?) } diff --git a/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs new file mode 100644 index 0000000000..50abe5a66f --- /dev/null +++ b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs @@ -0,0 +1,69 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +//! Defines the BlockingTaskQueue struct +//! +//! This module is responsible for the general handling of BlockingTaskQueue instances (cloning, droping, etc). +//! See `scheduler.rs` and the foreign bindings code for how the async functionality is implemented. + +use super::scheduler::schedule_in_blocking_task_queue; +use std::num::NonZeroU64; + +/// Foreign-managed blocking task queue that we can use to schedule futures +/// +/// On the foreign side this is a Kotlin `CoroutineContext`, Python `Executor` or Swift +/// `DispatchQueue`. UniFFI converts those objects into this struct for the Rust code to use. +/// +/// Rust async code can call [BlockingTaskQueue::execute] to run a closure in that +/// blocking task queue. Use this for functions with blocking operations that should not be executed +/// in a normal async context. Some examples are non-async file/network operations, long-running +/// CPU-bound tasks, blocking database operations, etc. +#[repr(C)] +pub struct BlockingTaskQueue { + /// Opaque handle for the task queue + pub handle: NonZeroU64, + /// Method VTable + /// + /// This is simply a C struct where each field is a function pointer that inputs a + /// BlockingTaskQueue handle + pub vtable: &'static BlockingTaskQueueVTable, +} + +#[repr(C)] +#[derive(Debug)] +pub struct BlockingTaskQueueVTable { + clone: extern "C" fn(u64) -> u64, + drop: extern "C" fn(u64), +} + +// Note: see `scheduler.rs` for details on how BlockingTaskQueue is used. +impl BlockingTaskQueue { + /// Run a closure in a blocking task queue + pub async fn execute(&self, f: F) -> R + where + F: FnOnce() -> R, + { + schedule_in_blocking_task_queue(self.handle).await; + f() + } +} + +impl Clone for BlockingTaskQueue { + fn clone(&self) -> Self { + let raw_handle = (self.vtable.clone)(self.handle.into()); + let handle = raw_handle + .try_into() + .expect("BlockingTaskQueue.clone() returned 0"); + Self { + handle, + vtable: self.vtable, + } + } +} + +impl Drop for BlockingTaskQueue { + fn drop(&mut self) { + (self.vtable.drop)(self.handle.into()) + } +} diff --git a/uniffi_core/src/ffi/rustfuture/future.rs b/uniffi_core/src/ffi/rustfuture/future.rs index 93c34e7543..26781206d2 100644 --- a/uniffi_core/src/ffi/rustfuture/future.rs +++ b/uniffi_core/src/ffi/rustfuture/future.rs @@ -21,6 +21,10 @@ //! 2b. If the async function is cancelled, then call [rust_future_cancel]. This causes the //! continuation function to be called with [RustFuturePoll::Ready] and the [RustFuture] to //! enter a cancelled state. +//! 2c. If the Rust code wants schedule work to be run in a `BlockingTaskQueue`, then the +//! continuation is called with [RustFuturePoll::MaybeReady] and the blocking task queue handle. +//! The foreign code is responsible for ensuring the next [rust_future_poll] call happens in +//! that blocking task queue and the handle is passed to [rust_future_poll]. //! 3. Call [rust_future_complete] to get the result of the future. //! 4. Call [rust_future_free] to free the future, ideally in a finally block. This: //! - Releases any resources held by the future @@ -78,6 +82,7 @@ use std::{ future::Future, marker::PhantomData, + num::NonZeroU64, ops::Deref, panic, pin::Pin, @@ -85,8 +90,8 @@ use std::{ task::{Context, Poll, Wake}, }; -use super::{RustFutureContinuationCallback, RustFuturePoll, Scheduler}; -use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus}; +use super::{scheduler, RustFutureContinuationCallback, Scheduler}; +use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus, RustFuturePoll}; /// Wraps the actual future we're polling struct WrappedFuture @@ -223,17 +228,26 @@ where }) } - pub(super) fn poll(self: Arc, callback: RustFutureContinuationCallback, data: u64) { + pub(super) fn poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ) { + scheduler::on_poll_start(blocking_task_queue_handle); + // Clear out the waked flag, since we're about to poll right now. + self.scheduler.lock().unwrap().clear_wake_flag(); let ready = self.is_cancelled() || { let mut locked = self.future.lock().unwrap(); let waker: std::task::Waker = Arc::clone(&self).into(); locked.poll(&mut Context::from_waker(&waker)) }; if ready { - callback(data, RustFuturePoll::Ready) + callback(data, RustFuturePoll::Ready, 0) } else { self.scheduler.lock().unwrap().store(callback, data); } + scheduler::on_poll_end(); } pub(super) fn is_cancelled(&self) -> bool { @@ -289,7 +303,12 @@ where /// only create those functions for each of the 13 possible FFI return types. #[doc(hidden)] pub trait RustFutureFfi: Send + Sync { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: u64); + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ); fn ffi_cancel(&self); fn ffi_complete(&self, call_status: &mut RustCallStatus) -> ReturnType; fn ffi_free(self: Arc); @@ -302,8 +321,13 @@ where T: LowerReturn + Send + 'static, UT: Send + 'static, { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: u64) { - self.poll(callback, data) + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ) { + self.poll(callback, data, blocking_task_queue_handle) } fn ffi_cancel(&self) { diff --git a/uniffi_core/src/ffi/rustfuture/mod.rs b/uniffi_core/src/ffi/rustfuture/mod.rs index 39529f2db1..bb064d2e65 100644 --- a/uniffi_core/src/ffi/rustfuture/mod.rs +++ b/uniffi_core/src/ffi/rustfuture/mod.rs @@ -4,8 +4,11 @@ use std::{future::Future, sync::Arc}; +mod blocking_task_queue; mod future; mod scheduler; + +pub use blocking_task_queue::*; use future::*; use scheduler::*; @@ -28,7 +31,19 @@ pub enum RustFuturePoll { /// /// The Rust side of things calls this when the foreign side should call [rust_future_poll] again /// to continue progress on the future. -pub type RustFutureContinuationCallback = extern "C" fn(callback_data: u64, RustFuturePoll); +/// +/// WARNING: the call to [rust_future_poll] must be scheduled to happen soon after the callback is +/// called, but not inside the callback itself. If [rust_future_poll] is called inside the +/// callback, some futures will deadlock and our scheduler code might as well. +/// +/// * `callback_data` is the handle that the foreign code passed to `poll()` +/// * `poll_result` is the result of the poll +/// * If `blocking_task_task_queue` is non-zero, it's the BlockingTaskQueue handle that the next `poll()` should run on +pub type RustFutureContinuationCallback = extern "C" fn( + callback_data: u64, + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +); // === Public FFI API === @@ -62,6 +77,9 @@ where /// a [RustFuturePoll] value. For each [rust_future_poll] call the continuation will be called /// exactly once. /// +/// If this is running in a BlockingTaskQueue, then `blocking_task_queue_handle` must be the handle +/// for it. If not, `blocking_task_queue_handle` must be `0`. +/// /// # Safety /// /// The [Handle] must not previously have been passed to [rust_future_free] @@ -69,10 +87,14 @@ pub unsafe fn rust_future_poll( handle: Handle, callback: RustFutureContinuationCallback, data: u64, + blocking_task_queue_handle: u64, ) where dyn RustFutureFfi: HandleAlloc, { - as HandleAlloc>::get_arc(handle).ffi_poll(callback, data) + let future = &*(future.0 as *mut Arc>); + future + .clone() + .ffi_poll(callback, data, blocking_task_queue_handle.try_into().ok()) } /// Cancel a Rust future diff --git a/uniffi_core/src/ffi/rustfuture/scheduler.rs b/uniffi_core/src/ffi/rustfuture/scheduler.rs index 629ee0c109..26526701d3 100644 --- a/uniffi_core/src/ffi/rustfuture/scheduler.rs +++ b/uniffi_core/src/ffi/rustfuture/scheduler.rs @@ -2,10 +2,89 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use std::mem; +use std::{cell::RefCell, future::poll_fn, mem, num::NonZeroU64, task::Poll, thread_local}; use super::{RustFutureContinuationCallback, RustFuturePoll}; +/// Context of the current `RustFuture::poll` call +struct RustFutureContext { + /// Blocking task queue that the future is being polled on + current_blocking_task_queue_handle: Option, + /// Blocking task queue that we've been asked to schedule the next poll on + scheduled_blocking_task_queue_handle: Option, +} + +thread_local! { + static CONTEXT: RefCell = RefCell::new(RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + }); +} + +fn with_context R, R>(operation: F) -> R { + CONTEXT.with(|context| operation(&mut context.borrow_mut())) +} + +pub fn on_poll_start(current_blocking_task_queue_handle: Option) { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +pub fn on_poll_end() { + with_context(|context| { + *context = RustFutureContext { + current_blocking_task_queue_handle: None, + scheduled_blocking_task_queue_handle: None, + } + }); +} + +/// Schedule work in a blocking task queue +/// +/// The returned future will attempt to arrange for [RustFuture::poll] to be called in the +/// blocking task queue. Once [RustFuture::poll] is running in the blocking task queue, then the future +/// will be ready. +/// +/// There's one tricky issue here: how can we ensure that when the top-level task is run in the +/// blocking task queue, this future will be polled? What happens this future is a child of `join!`, +/// `FuturesUnordered` or some other Future that handles its own polling? +/// +/// We start with an assumption: if we notify the waker then this future will be polled when the +/// top-level task is polled next. If a future does not honor this then we consider it a broken +/// future. This seems fair, since that future would almost certainly break a lot of other future +/// code. +/// +/// Based on that, we can have a simple system. When we're polled: +/// * If we're running in the blocking task queue, then we return `Poll::Ready`. +/// * If not, we return `Poll::Pending` and notify the waker so that the future polls again on +/// the next top-level poll. +/// +/// Note that this can be inefficient if the code awaits multiple blocking task queues at once. We +/// can only run the next poll on one of them, but all futures will be woken up. This seems okay +/// for our intended use cases, it would be pretty odd for a library to use multiple blocking task +/// queues. The alternative would be to store the set of all pending blocking task queues, which +/// seems like complete overkill for our purposes. +pub(super) async fn schedule_in_blocking_task_queue(handle: NonZeroU64) { + poll_fn(|future_context| { + with_context(|poll_context| { + if poll_context.current_blocking_task_queue_handle == Some(handle) { + Poll::Ready(()) + } else { + poll_context + .scheduled_blocking_task_queue_handle + .get_or_insert(handle); + future_context.waker().wake_by_ref(); + Poll::Pending + } + }) + }) + .await +} + /// Schedules a [crate::RustFuture] by managing the continuation data /// /// This struct manages the continuation callback and data that comes from the foreign side. It @@ -41,21 +120,34 @@ impl Scheduler { /// Store new continuation data if we are in the `Empty` state. If we are in the `Waked` or /// `Cancelled` state, call the continuation immediately with the data. pub(super) fn store(&mut self, callback: RustFutureContinuationCallback, data: u64) { + if let Some(blocking_task_queue_handle) = + with_context(|context| context.scheduled_blocking_task_queue_handle) + { + // We were asked to schedule the future in a blocking task queue, call the callback + // rather than storing it + callback( + data, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle.into(), + ); + return; + } + match self { Self::Empty => *self = Self::Set(callback, data), Self::Set(old_callback, old_data) => { log::error!( "store: observed `Self::Set` state. Is poll() being called from multiple threads at once?" ); - old_callback(*old_data, RustFuturePoll::Ready); + old_callback(*old_data, RustFuturePoll::Ready, 0); *self = Self::Set(callback, data); } Self::Waked => { *self = Self::Empty; - callback(data, RustFuturePoll::MaybeReady); + callback(data, RustFuturePoll::MaybeReady, 0); } Self::Cancelled => { - callback(data, RustFuturePoll::Ready); + callback(data, RustFuturePoll::Ready, 0); } } } @@ -67,7 +159,7 @@ impl Scheduler { let old_data = *old_data; let callback = *callback; *self = Self::Empty; - callback(old_data, RustFuturePoll::MaybeReady); + callback(old_data, RustFuturePoll::MaybeReady, 0); } // If we were in the `Empty` state, then transition to `Waked`. The next time `store` // is called, we will immediately call the continuation. @@ -79,7 +171,13 @@ impl Scheduler { pub(super) fn cancel(&mut self) { if let Self::Set(callback, old_data) = mem::replace(self, Self::Cancelled) { - callback(old_data, RustFuturePoll::Ready); + callback(old_data, RustFuturePoll::Ready, 0); + } + } + + pub(super) fn clear_wake_flag(&mut self) { + if let Self::Waked = self { + *self = Self::Empty } } diff --git a/uniffi_core/src/ffi/rustfuture/tests.rs b/uniffi_core/src/ffi/rustfuture/tests.rs index 886ee27c71..70a49179b0 100644 --- a/uniffi_core/src/ffi/rustfuture/tests.rs +++ b/uniffi_core/src/ffi/rustfuture/tests.rs @@ -65,16 +65,38 @@ fn channel() -> (Sender, Arc>) { } /// Poll a Rust future and get an OnceCell that's set when the continuation is called -fn poll(rust_future: &Arc>) -> Arc> { +fn poll(rust_future: &Arc>) -> Arc> { let cell = Arc::new(OnceCell::new()); - let handle = Arc::into_raw(cell.clone()) as u64; - rust_future.clone().ffi_poll(poll_continuation, handle); + let cell_ptr = Arc::into_raw(cell.clone()) as u64; + rust_future + .clone() + .ffi_poll(poll_continuation, cell_ptr, None); cell } -extern "C" fn poll_continuation(data: u64, code: RustFuturePoll) { - let cell = unsafe { Arc::from_raw(data as *const OnceCell) }; - cell.set(code).expect("Error setting OnceCell"); +/// Like poll, but simulate `poll()` being called from a blocking task queue +fn poll_from_blocking_task_queue( + rust_future: &Arc>, + blocking_task_queue_handle: u64, +) -> Arc> { + let cell = Arc::new(OnceCell::new()); + let cell_ptr = Arc::into_raw(cell.clone()) as *const (); + rust_future.clone().ffi_poll( + poll_continuation, + cell_ptr, + Some(blocking_task_queue_handle.try_into().unwrap()), + ); + cell +} + +extern "C" fn poll_continuation( + data: *const (), + code: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + let cell = unsafe { Arc::from_raw(data as *const OnceCell<(RustFuturePoll, u64)>) }; + cell.set((code, blocking_task_queue_handle)) + .expect("Error setting OnceCell"); } fn complete(rust_future: Arc>) -> (RustBuffer, RustCallStatus) { @@ -83,25 +105,47 @@ fn complete(rust_future: Arc>) -> (RustBuffer, Rus (return_value, out_status_code) } +fn check_continuation_not_called(once_cell: &OnceCell<(RustFuturePoll, u64)>) { + assert_eq!(once_cell.get(), None); +} + +fn check_continuation_called( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, +) { + assert_eq!(once_cell.get(), Some(&(poll_result, 0))); +} + +fn check_continuation_called_with_blocking_task_queue_handle( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + assert_eq!( + once_cell.get(), + Some(&(poll_result, blocking_task_queue_handle)) + ) +} + #[test] fn test_success() { let (sender, rust_future) = channel(); // Test polling the rust future before it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.wake(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Test polling the rust future when it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Ok("All done".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Future polls should immediately return ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Complete the future let (return_buf, call_status) = complete(rust_future); @@ -117,12 +161,12 @@ fn test_error() { let (sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Err("Something went wrong".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Error); @@ -144,14 +188,14 @@ fn test_cancel() { let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); rust_future.ffi_cancel(); // Cancellation should immediately invoke the callback with RustFuturePoll::Ready - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Future polls should immediately invoke the callback with RustFuturePoll::Ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Cancelled); @@ -187,7 +231,7 @@ fn test_complete_with_stored_continuation() { let continuation_result = poll(&rust_future); rust_future.ffi_free(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); } // Test what happens if we see a `wake()` call while we're polling the future. This can @@ -210,10 +254,47 @@ fn test_wake_during_poll() { let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); let continuation_result = poll(&rust_future); // The continuation function should called immediately - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // A second poll should finish the future let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + let (return_buf, call_status) = complete(rust_future); + assert_eq!(call_status.code, RustCallStatusCode::Success); + assert_eq!( + >::try_lift(return_buf).unwrap(), + "All done" + ); +} + +#[test] +fn test_blocking_task() { + let blocking_task_queue_handle = 1001; + let future = async move { + schedule_in_blocking_task_queue(blocking_task_queue_handle.try_into().unwrap()).await; + "All done".to_owned() + }; + let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); + // On the first poll, the future should not be ready and it should ask to be scheduled in the + // blocking task queue + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // If we poll it again not in a blocking task queue, then we get the same result + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // When we poll it in the blocking task queue, then the future is ready + let continuation_result = + poll_from_blocking_task_queue(&rust_future, blocking_task_queue_handle); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + + // Complete the future let (return_buf, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Success); assert_eq!( diff --git a/uniffi_core/src/ffi_converter_impls.rs b/uniffi_core/src/ffi_converter_impls.rs index aec093154a..3866f84eda 100644 --- a/uniffi_core/src/ffi_converter_impls.rs +++ b/uniffi_core/src/ffi_converter_impls.rs @@ -23,8 +23,9 @@ /// "UT" means an arbitrary `UniFfiTag` type. use crate::{ check_remaining, derive_ffi_traits, ffi_converter_rust_buffer_lift_and_lower, metadata, - ConvertError, FfiConverter, Lift, LiftRef, LiftReturn, Lower, LowerReturn, MetadataBuffer, - Result, RustBuffer, UnexpectedUniFFICallbackError, + BlockingTaskQueue, BlockingTaskQueueVTable, ConvertError, FfiConverter, Lift, LiftRef, + LiftReturn, Lower, LowerReturn, MetadataBuffer, Result, RustBuffer, + UnexpectedUniFFICallbackError, }; use anyhow::bail; use bytes::buf::{Buf, BufMut}; @@ -244,6 +245,35 @@ unsafe impl FfiConverter for Duration { const TYPE_ID_META: MetadataBuffer = MetadataBuffer::from_code(metadata::codes::TYPE_DURATION); } +/// Support for passing [BlockingTaskQueue] across the FFI +/// +/// Both fields of [BlockingTaskQueue] are serialized into a RustBuffer. The vtable pointer is +/// casted to a u64. +unsafe impl FfiConverter for BlockingTaskQueue { + ffi_converter_rust_buffer_lift_and_lower!(UT); + + fn write(obj: BlockingTaskQueue, buf: &mut Vec) { + let obj = obj.clone(); + buf.put_u64(obj.handle.into()); + buf.put_u64(obj.vtable as *const BlockingTaskQueueVTable as u64); + } + + fn try_read(buf: &mut &[u8]) -> Result { + check_remaining(buf, 16)?; + let handle = buf + .get_u64() + .try_into() + .expect("handle = 0 when reading BlockingTaskQueue"); + let vtable = unsafe { + &*(buf.get_u64() as *const BlockingTaskQueueVTable) as &'static BlockingTaskQueueVTable + }; + Ok(Self { handle, vtable }) + } + + const TYPE_ID_META: MetadataBuffer = + MetadataBuffer::from_code(metadata::codes::TYPE_BLOCKING_TASK_QUEUE); +} + // Support for passing optional values via the FFI. // // Optional values are currently always passed by serializing to a buffer. @@ -419,6 +449,7 @@ derive_ffi_traits!(blanket bool); derive_ffi_traits!(blanket String); derive_ffi_traits!(blanket Duration); derive_ffi_traits!(blanket SystemTime); +derive_ffi_traits!(blanket BlockingTaskQueue); // For composite types, derive LowerReturn, LiftReturn, etc, from Lift/Lower. // diff --git a/uniffi_core/src/metadata.rs b/uniffi_core/src/metadata.rs index 934d09cf87..4b9f9b861f 100644 --- a/uniffi_core/src/metadata.rs +++ b/uniffi_core/src/metadata.rs @@ -69,6 +69,7 @@ pub mod codes { pub const TYPE_RESULT: u8 = 23; pub const TYPE_TRAIT_INTERFACE: u8 = 24; pub const TYPE_CALLBACK_TRAIT_INTERFACE: u8 = 25; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 26; pub const TYPE_UNIT: u8 = 255; // Literal codes for LiteralMetadata - note that we don't support diff --git a/uniffi_macros/src/setup_scaffolding.rs b/uniffi_macros/src/setup_scaffolding.rs index 08bb2cc568..05a7b7ea49 100644 --- a/uniffi_macros/src/setup_scaffolding.rs +++ b/uniffi_macros/src/setup_scaffolding.rs @@ -166,8 +166,13 @@ fn rust_future_scaffolding_fns(module_path: &str) -> TokenStream { #[allow(clippy::missing_safety_doc, missing_docs)] #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn #ffi_rust_future_poll(handle: ::uniffi::Handle, callback: ::uniffi::RustFutureContinuationCallback, data: u64) { - ::uniffi::ffi::rust_future_poll::<#return_type, crate::UniFfiTag>(handle, callback, data); + pub unsafe extern "C" fn #ffi_rust_future_poll( + handle: ::uniffi::RustFutureHandle, + callback: ::uniffi::RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: u64, + ) { + ::uniffi::ffi::rust_future_poll::<#return_type>(handle, callback, data, blocking_task_queue_handle); } #[allow(clippy::missing_safety_doc, missing_docs)] diff --git a/uniffi_meta/src/metadata.rs b/uniffi_meta/src/metadata.rs index 66c2c63952..b553e0060e 100644 --- a/uniffi_meta/src/metadata.rs +++ b/uniffi_meta/src/metadata.rs @@ -52,6 +52,7 @@ pub mod codes { pub const TYPE_RESULT: u8 = 23; pub const TYPE_TRAIT_INTERFACE: u8 = 24; pub const TYPE_CALLBACK_TRAIT_INTERFACE: u8 = 25; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 26; pub const TYPE_UNIT: u8 = 255; // Literal codes diff --git a/uniffi_meta/src/reader.rs b/uniffi_meta/src/reader.rs index 5a09d9dd7c..5341f1b24a 100644 --- a/uniffi_meta/src/reader.rs +++ b/uniffi_meta/src/reader.rs @@ -145,6 +145,7 @@ impl<'a> MetadataReader<'a> { codes::TYPE_STRING => Type::String, codes::TYPE_DURATION => Type::Duration, codes::TYPE_SYSTEM_TIME => Type::Timestamp, + codes::TYPE_BLOCKING_TASK_QUEUE => Type::BlockingTaskQueue, codes::TYPE_RECORD => Type::Record { module_path: self.read_string()?, name: self.read_string()?, diff --git a/uniffi_meta/src/types.rs b/uniffi_meta/src/types.rs index 51bf156b50..5003dd9f77 100644 --- a/uniffi_meta/src/types.rs +++ b/uniffi_meta/src/types.rs @@ -88,6 +88,7 @@ pub enum Type { // How the object is implemented. imp: ObjectImpl, }, + BlockingTaskQueue, // Types defined in the component API, each of which has a string name. Record { module_path: String, diff --git a/uniffi_udl/src/resolver.rs b/uniffi_udl/src/resolver.rs index ea98cd7a99..1409c3a6ff 100644 --- a/uniffi_udl/src/resolver.rs +++ b/uniffi_udl/src/resolver.rs @@ -209,6 +209,7 @@ pub(crate) fn resolve_builtin_type(name: &str) -> Option { "f64" => Some(Type::Float64), "timestamp" => Some(Type::Timestamp), "duration" => Some(Type::Duration), + "BlockingTaskQueue" => Some(Type::BlockingTaskQueue), _ => None, } }