From 9ebf20e181420635dc2d06b9e8850bc2d5aa3049 Mon Sep 17 00:00:00 2001 From: Ben Dean-Kawamura Date: Fri, 8 Dec 2023 11:40:02 -0500 Subject: [PATCH] Async blocking task support Added the `BlockingTaskQueue` type BlockingTaskQueue allows a Rust closure to be scheduled on a foreign thread where blocking operations are okay. The closure runs inside the parent future, which is nice because it allows the closure to reference its outside scope. On the foreign side, a `BlockingTaskQueue` is a native type that runs a task in some sort of thread queue (`DispatchQueue`, `CoroutineContext`, `futures.Executor`, etc.). Updated some of the foreign code so that the handle-maps used for callback interfaces can also be used for this. It's quite messy right now, but #1823 should clean things up. Renamed `Handle` to `UniffiHandle` on Kotlin, since `Handle` can easily conflict with user names. In fact it did on the `custom_types` fixture, the only reason that this worked before was that that fixture didn't use callback interfaces. Added new tests for this in the futures fixtures. Updated the tests to check that handles are being released properly. --- Cargo.lock | 81 ++++++++- docs/manual/src/futures.md | 60 +++++++ fixtures/futures/Cargo.toml | 1 + fixtures/futures/src/lib.rs | 55 ++++++ .../futures/tests/bindings/test_futures.kts | 73 ++++++-- .../futures/tests/bindings/test_futures.py | 59 +++++++ .../futures/tests/bindings/test_futures.swift | 129 ++++++-------- fixtures/metadata/src/tests.rs | 1 + .../kotlin/gen_kotlin/blocking_task_queue.rs | 19 +++ .../src/bindings/kotlin/gen_kotlin/mod.rs | 5 +- .../src/bindings/kotlin/templates/Async.kt | 40 +++-- .../templates/BlockingTaskQueueTemplate.kt | 44 +++++ .../templates/CallbackInterfaceRuntime.kt | 4 +- .../kotlin/templates/ObjectRuntime.kt | 159 ++++++++++++++++++ .../src/bindings/kotlin/templates/Types.kt | 5 + .../src/bindings/kotlin/templates/wrapper.kt | 1 + .../python/gen_python/blocking_task_queue.rs | 19 +++ .../src/bindings/python/gen_python/mod.rs | 3 + .../src/bindings/python/templates/Async.py | 50 +++++- .../templates/BlockingTaskQueueTemplate.py | 39 +++++ .../src/bindings/python/templates/Helpers.py | 40 +++++ .../src/bindings/python/templates/Types.py | 3 + .../src/bindings/python/templates/wrapper.py | 1 + .../src/bindings/ruby/gen_ruby/mod.rs | 4 + .../swift/gen_swift/blocking_task_queue.rs | 19 +++ .../src/bindings/swift/gen_swift/mod.rs | 3 + .../src/bindings/swift/templates/Async.swift | 64 ++++++- .../templates/BlockingTaskQueueTemplate.swift | 37 ++++ .../templates/CallbackInterfaceTemplate.swift | 2 +- .../bindings/swift/templates/HandleMap.swift | 6 + .../bindings/swift/templates/Helpers.swift | 63 +++++++ .../src/bindings/swift/templates/Types.swift | 3 + uniffi_bindgen/src/interface/ffi.rs | 3 +- uniffi_bindgen/src/interface/mod.rs | 45 ++++- uniffi_bindgen/src/interface/universe.rs | 1 + uniffi_bindgen/src/scaffolding/mod.rs | 1 + .../src/ffi/rustfuture/blocking_task_queue.rs | 69 ++++++++ uniffi_core/src/ffi/rustfuture/future.rs | 38 ++++- uniffi_core/src/ffi/rustfuture/mod.rs | 26 ++- uniffi_core/src/ffi/rustfuture/scheduler.rs | 110 +++++++++++- uniffi_core/src/ffi/rustfuture/tests.rs | 121 ++++++++++--- uniffi_core/src/ffi_converter_impls.rs | 35 +++- uniffi_core/src/metadata.rs | 1 + uniffi_macros/src/setup_scaffolding.rs | 9 +- uniffi_meta/src/metadata.rs | 1 + uniffi_meta/src/reader.rs | 1 + uniffi_meta/src/types.rs | 1 + uniffi_udl/src/resolver.rs | 1 + 48 files changed, 1386 insertions(+), 169 deletions(-) create mode 
100644 uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs create mode 100644 uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt create mode 100644 uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt create mode 100644 uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs create mode 100644 uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py create mode 100644 uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs create mode 100644 uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift create mode 100644 uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs diff --git a/Cargo.lock b/Cargo.lock index 12e04864f7..004d1532c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -623,26 +623,53 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0845fa252299212f0389d64ba26f34fa32cfe41588355f21ed507c59a0f64541" +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" [[package]] name = "futures-lite" @@ -659,6 +686,47 @@ dependencies = [ "waker-fn", ] +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + 
"futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.28.0" @@ -1661,6 +1729,7 @@ name = "uniffi-fixture-futures" version = "0.21.0" dependencies = [ "async-trait", + "futures", "once_cell", "thiserror", "tokio", diff --git a/docs/manual/src/futures.md b/docs/manual/src/futures.md index 0e1d22ea85..07996a4d6e 100644 --- a/docs/manual/src/futures.md +++ b/docs/manual/src/futures.md @@ -71,3 +71,63 @@ pub trait SayAfterTrait: Send + Sync { async fn say_after(&self, ms: u16, who: String) -> String; } ``` + +## Blocking tasks + +Rust executors are designed around an assumption that the `Future::poll` function will return quickly. +This assumption, combined with cooperative scheduling, allows for a large number of futures to be handled by a small number of threads. +Foreign executors make similar assumptions and sometimes more extreme ones. +For example, the Python eventloop is single threaded -- if any task spends a long time between `await` points, then it will block all other tasks from progressing. + +This raises the question of how async code can interact with blocking code that performs blocking IO, long-running computations without `await` breaks, etc. +UniFFI defines the `BlockingTaskQueue` type, which is a foreign object that schedules work on a thread where it's okay to block. + +On Rust, `BlockingTaskQueue` is a UniFFI type that can safely run blocking code. +It's `execute` method works like tokio's [block_in_place](https://docs.rs/tokio/latest/tokio/task/fn.block_in_place.html) function. +It inputs a closure and runs it in the `BlockingTaskQueue`. +This closure can reference the outside scope (i.e. it does not need to be `'static`). +For example: + +```rust +#[derive(uniffi::Object)] +struct DataStore { + // Used to run blocking tasks + queue: uniffi::BlockingTaskQueue, + // Low-level DB object with blocking methods + db: Mutex, +} + +#[uniffi::export] +impl DataStore { + #[uniffi::constructor] + fn new(queue: uniffi::BlockingTaskQueue) -> Self { + Self { + queue, + db: Mutex::new(Database::new()) + } + } + + fn fetch_all_items(&self) -> Vec { + self.queue.execute(|| self.db.lock().fetch_all_items()) + } +} +``` + +On the foreign side `BlockingTaskQueue` corresponds to a language-dependent class. + +### Kotlin +Kotlin uses `CoroutineContext` for its `BlockingTaskQueue`. +Any `CoroutineContext` will work, but `Dispatchers.IO` is usually a good choice. +A DataStore from the example above can be created with `DataStore(Dispatchers.IO)`. + +### Swift +Swift uses `DispatchQueue` for its `BlockingTaskQueue`. +The user-initiated global queue is normally a good choice. +A DataStore from the example above can be created with `DataStore(queue: DispatchQueue.global(qos: .userInitiated)`. +The `DispatchQueue` should be concurrent. + +### Python + +Python uses a `futures.Executor` for its `BlockingTaskQueue`. +`ThreadPoolExecutor` is typically a good choice. +A DataStore from the example above can be created with `DataStore(ThreadPoolExecutor())`. 
diff --git a/fixtures/futures/Cargo.toml b/fixtures/futures/Cargo.toml index f386c6d85c..0bb6f1c04d 100644 --- a/fixtures/futures/Cargo.toml +++ b/fixtures/futures/Cargo.toml @@ -17,6 +17,7 @@ path = "src/bin.rs" [dependencies] uniffi = { workspace = true, features = ["tokio", "cli"] } async-trait = "0.1" +futures = "0.3.29" thiserror = "1.0" tokio = { version = "1.24.1", features = ["time", "sync"] } once_cell = "1.18.0" diff --git a/fixtures/futures/src/lib.rs b/fixtures/futures/src/lib.rs index 4b4ed1cca9..8e58639f66 100644 --- a/fixtures/futures/src/lib.rs +++ b/fixtures/futures/src/lib.rs @@ -11,6 +11,8 @@ use std::{ time::Duration, }; +use futures::stream::{FuturesUnordered, StreamExt}; + /// Non-blocking timer future. pub struct TimerFuture { shared_state: Arc>, @@ -385,6 +387,59 @@ impl SayAfterUdlTrait for SayAfterImpl2 { #[uniffi::export] fn get_say_after_udl_traits() -> Vec> { vec![Arc::new(SayAfterImpl1), Arc::new(SayAfterImpl2)] + +/// Async function that uses a blocking task queue to do its work +#[uniffi::export] +pub async fn calc_square(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.execute(|| value * value).await +} + +/// Same as before, but this one runs multiple tasks +#[uniffi::export] +pub async fn calc_squares(queue: uniffi::BlockingTaskQueue, items: Vec) -> Vec { + // Use `FuturesUnordered` to test our blocking task queue code which is known to be a tricky API to work with. + // In particular, if we don't notify the waker then FuturesUnordered will not poll again. + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queue.execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// ...and this one uses multiple BlockingTaskQueues +#[uniffi::export] +pub async fn calc_squares_multi_queue( + queues: Vec, + items: Vec, +) -> Vec { + let mut futures: FuturesUnordered<_> = (0..items.len()) + .map(|i| { + // Test that we can use references from the surrounding scope + let items = &items; + queues[i].execute(move || items[i] * items[i]) + }) + .collect(); + let mut results = vec![]; + while let Some(result) = futures.next().await { + results.push(result); + } + results.sort(); + results +} + +/// Like calc_square, but it clones the BlockingTaskQueue first then drops both copies. Used to +/// test that a) the clone works and b) we correctly drop the references. 
+#[uniffi::export] +pub async fn calc_square_with_clone(queue: uniffi::BlockingTaskQueue, value: i32) -> i32 { + queue.clone().execute(|| value * value).await } uniffi::include_scaffolding!("futures"); diff --git a/fixtures/futures/tests/bindings/test_futures.kts b/fixtures/futures/tests/bindings/test_futures.kts index 175f4a619a..f8cf0cb4e3 100644 --- a/fixtures/futures/tests/bindings/test_futures.kts +++ b/fixtures/futures/tests/bindings/test_futures.kts @@ -1,9 +1,22 @@ import uniffi.fixture.futures.* +import java.util.concurrent.Executors import kotlinx.coroutines.* import kotlin.system.* +fun runAsyncTest(test: suspend CoroutineScope.() -> Unit) { + val initialBlockingTaskQueueHandleCount = uniffiBlockingTaskQueueHandleCount() + val initialPollHandleCount = uniffiPollHandleCount() + val time = runBlocking { + measureTimeMillis { + test() + } + } + assert(uniffiBlockingTaskQueueHandleCount() == initialBlockingTaskQueueHandleCount) + assert(uniffiPollHandleCount() == initialPollHandleCount) +} + // init UniFFI to get good measurements after that -runBlocking { +runAsyncTest { val time = measureTimeMillis { alwaysReady() } @@ -24,7 +37,7 @@ fun assertApproximateTime(actualTime: Long, expectedTime: Int, testName: String } // Test `always_ready`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = alwaysReady() @@ -35,7 +48,7 @@ runBlocking { } // Test `void`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = void() @@ -46,7 +59,7 @@ runBlocking { } // Test `sleep`. -runBlocking { +runAsyncTest { val time = measureTimeMillis { sleep(200U) } @@ -55,7 +68,7 @@ runBlocking { } // Test sequential futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfter(100U, "Alice") val resultBob = sayAfter(200U, "Bob") @@ -68,7 +81,7 @@ runBlocking { } // Test concurrent futures. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = async { sayAfter(100U, "Alice") } val resultBob = async { sayAfter(200U, "Bob") } @@ -81,7 +94,7 @@ runBlocking { } // Test async methods. -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = megaphone.sayAfter(200U, "Alice") @@ -92,7 +105,7 @@ runBlocking { assertApproximateTime(time, 200, "async methods") } -runBlocking { +runAsyncTest { val megaphone = newMegaphone() val time = measureTimeMillis { val resultAlice = sayAfterWithMegaphone(megaphone, 200U, "Alice") @@ -104,7 +117,7 @@ runBlocking { } // Test async method returning optional object -runBlocking { +runAsyncTest { val megaphone = asyncMaybeNewMegaphone(true) assert(megaphone != null) @@ -141,7 +154,7 @@ runBlocking { } // Test with the Tokio runtime. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val resultAlice = sayAfterWithTokio(200U, "Alice") @@ -152,7 +165,7 @@ runBlocking { } // Test fallible function/method. -runBlocking { +runAsyncTest { val time1 = measureTimeMillis { try { fallibleMe(false) @@ -217,7 +230,7 @@ runBlocking { } // Test record. -runBlocking { +runAsyncTest { val time = measureTimeMillis { val result = newMyRecord("foo", 42U) @@ -231,7 +244,7 @@ runBlocking { } // Test a broken sleep. -runBlocking { +runAsyncTest { val time = measureTimeMillis { brokenSleep(100U, 0U) // calls the waker twice immediately sleep(100U) // wait for possible failure @@ -245,7 +258,7 @@ runBlocking { // Test a future that uses a lock and that is cancelled. 
-runBlocking { +runAsyncTest { val time = measureTimeMillis { val job = launch { useSharedResource(SharedResourceOptions(releaseAfterMs=5000U, timeoutMs=100U)) @@ -264,7 +277,7 @@ runBlocking { } // Test a future that uses a lock and that is not cancelled. -runBlocking { +runAsyncTest { val time = measureTimeMillis { useSharedResource(SharedResourceOptions(releaseAfterMs=100U, timeoutMs=1000U)) @@ -272,3 +285,33 @@ runBlocking { } println("useSharedResource (not canceled): ${time}ms") } + +// Test blocking task queues +runAsyncTest { + withTimeout(1000) { + assert(calcSquare(Dispatchers.IO, 20) == 400) + } + + withTimeout(1000) { + assert(calcSquares(Dispatchers.IO, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + + val executors = listOf( + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + Executors.newSingleThreadExecutor(), + ) + withTimeout(1000) { + assert(calcSquaresMultiQueue(executors.map { it.asCoroutineDispatcher() }, listOf(1, -2, 3)) == listOf(1, 4, 9)) + } + for (executor in executors) { + executor.shutdown() + } +} + +// Test blocking task queue cloning +runAsyncTest { + withTimeout(1000) { + assert(calcSquareWithClone(Dispatchers.IO, 20) == 400) + } +} diff --git a/fixtures/futures/tests/bindings/test_futures.py b/fixtures/futures/tests/bindings/test_futures.py index 7b2e12324f..4365689aab 100644 --- a/fixtures/futures/tests/bindings/test_futures.py +++ b/fixtures/futures/tests/bindings/test_futures.py @@ -1,26 +1,32 @@ +import futures from futures import * +import contextlib import unittest from datetime import datetime import asyncio import typing +from concurrent.futures import ThreadPoolExecutor def now(): return datetime.now() class TestFutures(unittest.TestCase): def test_always_ready(self): + @self.check_handle_counts() async def test(): self.assertEqual(await always_ready(), True) asyncio.run(test()) def test_void(self): + @self.check_handle_counts() async def test(): self.assertEqual(await void(), None) asyncio.run(test()) def test_sleep(self): + @self.check_handle_counts() async def test(): t0 = now() await sleep(2000) @@ -32,6 +38,7 @@ async def test(): asyncio.run(test()) def test_sequential_futures(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after(100, 'Alice') @@ -46,6 +53,7 @@ async def test(): asyncio.run(test()) def test_concurrent_tasks(self): + @self.check_handle_counts() async def test(): alice = asyncio.create_task(say_after(100, 'Alice')) bob = asyncio.create_task(say_after(200, 'Bob')) @@ -63,6 +71,7 @@ async def test(): asyncio.run(test()) def test_async_methods(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -106,6 +115,7 @@ async def test(): asyncio.run(test()) def test_async_object_param(self): + @self.check_handle_counts() async def test(): megaphone = new_megaphone() t0 = now() @@ -119,6 +129,7 @@ async def test(): asyncio.run(test()) def test_with_tokio_runtime(self): + @self.check_handle_counts() async def test(): t0 = now() result_alice = await say_after_with_tokio(200, 'Alice') @@ -131,6 +142,7 @@ async def test(): asyncio.run(test()) def test_fallible(self): + @self.check_handle_counts() async def test(): result = await fallible_me(False) self.assertEqual(result, 42) @@ -155,6 +167,7 @@ async def test(): asyncio.run(test()) def test_fallible_struct(self): + @self.check_handle_counts() async def test(): megaphone = await fallible_struct(False) self.assertEqual(await megaphone.fallible_me(False), 42) @@ -168,6 +181,7 @@ async 
def test(): asyncio.run(test()) def test_record(self): + @self.check_handle_counts() async def test(): result = await new_my_record("foo", 42) self.assertEqual(result.__class__, MyRecord) @@ -177,6 +191,7 @@ async def test(): asyncio.run(test()) def test_cancel(self): + @self.check_handle_counts() async def test(): # Create a task task = asyncio.create_task(say_after(200, 'Alice')) @@ -194,6 +209,7 @@ async def test(): # Test a future that uses a lock and that is cancelled. def test_shared_resource_cancellation(self): + @self.check_handle_counts() async def test(): task = asyncio.create_task(use_shared_resource( SharedResourceOptions(release_after_ms=5000, timeout_ms=100))) @@ -204,6 +220,7 @@ async def test(): asyncio.run(test()) def test_shared_resource_no_cancellation(self): + @self.check_handle_counts() async def test(): await use_shared_resource(SharedResourceOptions(release_after_ms=100, timeout_ms=1000)) await use_shared_resource(SharedResourceOptions(release_after_ms=0, timeout_ms=1000)) @@ -215,5 +232,47 @@ async def test(): self.assertEqual(typing.get_type_hints(sleep_no_return), {"ms": int, "return": type(None)}) asyncio.run(test()) + # blocking task queue tests + + def test_calc_square(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square(executor, 20), 400) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_square_with_clone(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_square_with_clone(executor, 20), 400) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_squares(self): + @self.check_handle_counts() + async def test(): + executor = ThreadPoolExecutor() + self.assertEqual(await calc_squares(executor, [1, -2, 3]), [1, 4, 9]) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + def test_calc_squares_multi_queue(self): + @self.check_handle_counts() + async def test(): + executors = [ + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ThreadPoolExecutor(), + ] + self.assertEqual(await calc_squares_multi_queue(executors, [1, -2, 3]), [1, 4, 9]) + asyncio.run(asyncio.wait_for(test(), timeout=1)) + + @contextlib.asynccontextmanager + async def check_handle_counts(self): + initial_poll_handle_count = len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER) + initial_blocking_task_queue_handle_count = len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP) + yield + self.assertEqual(len(futures.UNIFFI_POLL_DATA_POINTER_MANAGER), initial_poll_handle_count) + self.assertEqual(len(futures.UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP), initial_blocking_task_queue_handle_count) + if __name__ == '__main__': unittest.main() diff --git a/fixtures/futures/tests/bindings/test_futures.swift b/fixtures/futures/tests/bindings/test_futures.swift index 2fd413a8ab..e70cdbc5b7 100644 --- a/fixtures/futures/tests/bindings/test_futures.swift +++ b/fixtures/futures/tests/bindings/test_futures.swift @@ -3,10 +3,21 @@ import Foundation // To get `DispatchGroup` and `Date` types. var counter = DispatchGroup() -// Test `alwaysReady` -counter.enter() +func asyncTest(test: @escaping () async throws -> ()) { + let initialBlockingTaskQueueCount = uniffiBlockingTaskQueueHandleCountFutures() + let initialPollDataHandleCount = uniffiPollDataHandleCountFutures() + counter.enter() + Task { + try! 
await test() + counter.leave() + } + counter.wait() + assert(uniffiBlockingTaskQueueHandleCountFutures() == initialBlockingTaskQueueCount) + assert(uniffiPollDataHandleCountFutures() == initialPollDataHandleCount) +} -Task { +// Test `alwaysReady` +asyncTest { let t0 = Date() let result = await alwaysReady() let t1 = Date() @@ -14,40 +25,28 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) assert(result == true) - - counter.leave() } // Test record. -counter.enter() - -Task { +asyncTest { let result = await newMyRecord(a: "foo", b: 42) assert(result.a == "foo") assert(result.b == 42) - - counter.leave() } // Test `void` -counter.enter() - -Task { +asyncTest { let t0 = Date() await void() let t1 = Date() let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration < 0.1) - - counter.leave() } // Test `Sleep` -counter.enter() - -Task { +asyncTest { let t0 = Date() let result = await sleep(ms: 2000) let t1 = Date() @@ -55,14 +54,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result == true) - - counter.leave() } // Test sequential futures. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfter(ms: 1000, who: "Alice") let result_bob = await sayAfter(ms: 2000, who: "Bob") @@ -72,14 +67,10 @@ Task { assert(tDelta.duration > 3 && tDelta.duration < 3.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test concurrent futures. -counter.enter() - -Task { +asyncTest { async let alice = sayAfter(ms: 1000, who: "Alice") async let bob = sayAfter(ms: 2000, who: "Bob") @@ -91,14 +82,10 @@ Task { assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice!") assert(result_bob == "Hello, Bob!") - - counter.leave() } // Test async methods -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -108,8 +95,6 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "HELLO, ALICE!") - - counter.leave() } // Test async trait interface methods @@ -151,21 +136,15 @@ Task { } // Test async function returning an object -counter.enter() - -Task { +asyncTest { let megaphone = await asyncNewMegaphone() let result = try await megaphone.fallibleMe(doFail: false) assert(result == 42) - - counter.leave() } // Test with the Tokio runtime. -counter.enter() - -Task { +asyncTest { let t0 = Date() let result_alice = await sayAfterWithTokio(ms: 2000, who: "Alice") let t1 = Date() @@ -173,15 +152,11 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 2 && tDelta.duration < 2.1) assert(result_alice == "Hello, Alice (with Tokio)!") - - counter.leave() } // Test fallible function/method… // … which doesn't throw. 
-counter.enter() - -Task { +asyncTest { let t0 = Date() let result = try await fallibleMe(doFail: false) let t1 = Date() @@ -189,19 +164,15 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } -Task { +asyncTest { let m = try await fallibleStruct(doFail: false) let result = try await m.fallibleMe(doFail: false) assert(result == 42) } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -211,14 +182,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) assert(result == 42) - - counter.leave() } // … which does throw. -counter.enter() - -Task { +asyncTest { let t0 = Date() do { @@ -233,11 +200,9 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } -Task { +asyncTest { do { let _ = try await fallibleStruct(doFail: true) } catch MyError.Foo { @@ -247,9 +212,7 @@ Task { } } -counter.enter() - -Task { +asyncTest { let megaphone = newMegaphone() let t0 = Date() @@ -266,13 +229,10 @@ Task { let tDelta = DateInterval(start: t0, end: t1) assert(tDelta.duration > 0 && tDelta.duration < 0.1) - - counter.leave() } // Test a future that uses a lock and that is cancelled. -counter.enter() -Task { +asyncTest { let task = Task { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) } @@ -288,15 +248,36 @@ Task { // Try accessing the shared resource again. The initial task should release the shared resource // before the timeout expires. try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } // Test a future that uses a lock and that is not cancelled. -counter.enter() -Task { +asyncTest { try! await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 100, timeoutMs: 1000)) try! 
await useSharedResource(options: SharedResourceOptions(releaseAfterMs: 0, timeoutMs: 1000)) - counter.leave() } -counter.wait() +// Test blocking task queues +asyncTest { + let calcSquareResult = await calcSquare(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) + + let calcSquaresResult = await calcSquares(queue: DispatchQueue.global(qos: .userInitiated), items: [1, -2, 3]) + assert(calcSquaresResult == [1, 4, 9]) + + let calcSquaresMultiQueueResult = await calcSquaresMultiQueue( + queues: [ + DispatchQueue(label: "test-queue1", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue2", attributes: DispatchQueue.Attributes.concurrent), + DispatchQueue(label: "test-queue3", attributes: DispatchQueue.Attributes.concurrent) + ], + items: [1, -2, 3] + ) + assert(calcSquaresMultiQueueResult == [1, 4, 9]) +} + +// Test blocking task queue cloning +asyncTest { + let calcSquareResult = await calcSquareWithClone(queue: DispatchQueue.global(qos: .userInitiated), value: 20) + assert(calcSquareResult == 400) +} + diff --git a/fixtures/metadata/src/tests.rs b/fixtures/metadata/src/tests.rs index 2a280597ea..1a54cbe067 100644 --- a/fixtures/metadata/src/tests.rs +++ b/fixtures/metadata/src/tests.rs @@ -129,6 +129,7 @@ mod test_type_ids { check_type_id::(Type::Float64); check_type_id::(Type::Boolean); check_type_id::(Type::String); + check_type_id::(Type::BlockingTaskQueue); } #[test] diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs new file mode 100644 index 0000000000..a664f1fcd3 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self, _ci: &crate::ComponentInterface) -> String { + // Kotlin uses CoroutineContext for BlockingTaskQueue + "CoroutineContext".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs index fecf303913..4d0d7a314b 100644 --- a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -466,6 +467,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)), Type::Record { name, .. 
} => Box::new(record::RecordCodeType::new(name)),
@@ -637,7 +640,7 @@ mod filters {
     ) -> Result {
         let ffi_func = callable.ffi_rust_future_poll(ci);
         Ok(format!(
-            "{{ future, callback, continuation -> UniffiLib.INSTANCE.{ffi_func}(future, callback, continuation) }}"
+            "{{ future, callback, continuation, blockingTaskQueueHandle -> UniffiLib.INSTANCE.{ffi_func}(future, callback, continuation, blockingTaskQueueHandle) }}"
         ))
     }
 
diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt
index dc547d4ddf..906646d944 100644
--- a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt
+++ b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt
@@ -3,19 +3,38 @@
 internal const val UNIFFI_RUST_FUTURE_POLL_READY = 0.toByte()
 internal const val UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1.toByte()
 
-internal val uniffiContinuationHandleMap = UniffiHandleMap<CancellableContinuation<Byte>>()
+/**
+ * Data for an in-progress poll of a RustFuture
+ */
+internal data class UniffiPollData(
+    val continuation: CancellableContinuation<Byte>,
+    val rustFuture: Long,
+    val pollFunc: (Long, UniffiRustFutureContinuationCallback, Long, Long) -> Unit,
+)
+
+internal val uniffiPollDataHandleMap = UniffiHandleMap<UniffiPollData>()
 
 // FFI type for Rust future continuations
 internal object uniffiRustFutureContinuationCallbackImpl: UniffiRustFutureContinuationCallback {
-    override fun callback(data: Long, pollResult: Byte) {
-        uniffiContinuationHandleMap.remove(data).resume(pollResult)
+    override fun callback(data: Long, pollResult: Byte, blockingTaskQueueHandle: Long) {
+        if (blockingTaskQueueHandle == 0L) {
+            // Complete the Kotlin continuation
+            uniffiPollDataHandleMap.remove(data)!!.continuation.resume(pollResult)
+        } else {
+            // Call the poll function again, but inside the BlockingTaskQueue coroutine context
+            val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(blockingTaskQueueHandle)
+            val pollData = uniffiPollDataHandleMap.get(data)!!
+            CoroutineScope(coroutineContext).launch {
+                pollData.pollFunc(pollData.rustFuture, uniffiRustFutureContinuationCallbackImpl, data, blockingTaskQueueHandle)
+            }
+        }
     }
 }
 
 internal suspend fun uniffiRustCallAsync(
     rustFuture: Long,
-    pollFunc: (Long, UniffiRustFutureContinuationCallback, Long) -> Unit,
-    completeFunc: (Long, UniffiRustCallStatus) -> F,
+    pollFunc: (Long, UniffiRustFutureContinuationCallback, Long, Long) -> Unit,
+    completeFunc: (Long, UniffiRustCallStatus) -> F,
     freeFunc: (Long) -> Unit,
     liftFunc: (F) -> T,
     errorHandler: UniffiRustCallStatusErrorHandler
@@ -23,11 +42,9 @@ internal suspend fun uniffiRustCallAsync(
     try {
         do {
             val pollResult = suspendCancellableCoroutine { continuation ->
-                pollFunc(
-                    rustFuture,
-                    uniffiRustFutureContinuationCallbackImpl,
-                    uniffiContinuationHandleMap.insert(continuation)
-                )
+                val pollData = UniffiPollData(continuation, rustFuture, pollFunc)
+                val pollDataHandle = uniffiPollDataHandleMap.insert(pollData)
+                pollFunc(rustFuture, uniffiRustFutureContinuationCallbackImpl, pollDataHandle, 0L)
             }
         } while (pollResult != UNIFFI_RUST_FUTURE_POLL_READY);
@@ -39,3 +56,6 @@
     }
 }
 
+// For testing
+public fun uniffiPollHandleCount() = uniffiPollDataHandleMap.size
+
diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt
new file mode 100644
index 0000000000..717ff62df3
--- /dev/null
+++ b/uniffi_bindgen/src/bindings/kotlin/templates/BlockingTaskQueueTemplate.kt
@@ -0,0 +1,44 @@
+{{ self.add_import("kotlin.coroutines.CoroutineContext") }}
+
+object uniffiBlockingTaskQueueClone : UniffiBlockingTaskQueueClone {
+    override fun callback(handle: Long): Long {
+        val coroutineContext = uniffiBlockingTaskQueueHandleMap.get(handle)
+        return uniffiBlockingTaskQueueHandleMap.insert(coroutineContext)
+    }
+}
+
+object uniffiBlockingTaskQueueFree : UniffiBlockingTaskQueueFree {
+    override fun callback(handle: Long) {
+        uniffiBlockingTaskQueueHandleMap.remove(handle)
+    }
+}
+
+internal val uniffiBlockingTaskQueueVTable = UniffiBlockingTaskQueueVTable(
+    uniffiBlockingTaskQueueClone,
+    uniffiBlockingTaskQueueFree,
+)
+internal val uniffiBlockingTaskQueueHandleMap = ConcurrentHandleMap<CoroutineContext>()
+
+public object {{ ffi_converter_name }}: FfiConverterRustBuffer<CoroutineContext> {
+    override fun allocationSize(value: {{ type_name }}) = 16
+
+    override fun write(value: CoroutineContext, buf: ByteBuffer) {
+        // Call `write()` to make sure the data is written to the JNA backing data
+        uniffiBlockingTaskQueueVTable.write()
+        val handle = uniffiBlockingTaskQueueHandleMap.insert(value)
+        buf.putLong(handle)
+        buf.putLong(Pointer.nativeValue(uniffiBlockingTaskQueueVTable.getPointer()))
+    }
+
+    override fun read(buf: ByteBuffer): CoroutineContext {
+        val handle = buf.getLong()
+        val coroutineContext = uniffiBlockingTaskQueueHandleMap.remove(handle)!!
+        // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the
+        // same value.
+ buf.getLong() + return coroutineContext + } +} + +// For testing +public fun uniffiBlockingTaskQueueHandleCount() = uniffiBlockingTaskQueueHandleMap.size diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt b/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt index 81b0493d65..2bcbad39bb 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/CallbackInterfaceRuntime.kt @@ -6,8 +6,8 @@ internal const val UNIFFI_CALLBACK_SUCCESS = 0 internal const val UNIFFI_CALLBACK_ERROR = 1 internal const val UNIFFI_CALLBACK_UNEXPECTED_ERROR = 2 -public abstract class FfiConverterCallbackInterface: FfiConverter { - internal val handleMap = UniffiHandleMap() +public abstract class FfiConverterCallbackInterface: FfiConverter { + internal val handleMap = ConcurrentHandleMap() internal fun drop(handle: Long) { handleMap.remove(handle) diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt b/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt new file mode 100644 index 0000000000..ed2d6c8f71 --- /dev/null +++ b/uniffi_bindgen/src/bindings/kotlin/templates/ObjectRuntime.kt @@ -0,0 +1,159 @@ +{{- self.add_import("java.util.concurrent.atomic.AtomicBoolean") }} +// The base class for all UniFFI Object types. +// +// This class provides core operations for working with the Rust `Arc` pointer to +// the live Rust struct on the other side of the FFI. +// +// There's some subtlety here, because we have to be careful not to operate on a Rust +// struct after it has been dropped, and because we must expose a public API for freeing +// the Kotlin wrapper object in lieu of reliable finalizers. The core requirements are: +// +// * Each `FFIObject` instance holds an opaque pointer to the underlying Rust struct. +// Method calls need to read this pointer from the object's state and pass it in to +// the Rust FFI. +// +// * When an `FFIObject` is no longer needed, its pointer should be passed to a +// special destructor function provided by the Rust FFI, which will drop the +// underlying Rust struct. +// +// * Given an `FFIObject` instance, calling code is expected to call the special +// `destroy` method in order to free it after use, either by calling it explicitly +// or by using a higher-level helper like the `use` method. Failing to do so will +// leak the underlying Rust struct. +// +// * We can't assume that calling code will do the right thing, and must be prepared +// to handle Kotlin method calls executing concurrently with or even after a call to +// `destroy`, and to handle multiple (possibly concurrent!) calls to `destroy`. +// +// * We must never allow Rust code to operate on the underlying Rust struct after +// the destructor has been called, and must never call the destructor more than once. +// Doing so may trigger memory unsafety. +// +// If we try to implement this with mutual exclusion on access to the pointer, there is the +// possibility of a race between a method call and a concurrent call to `destroy`: +// +// * Thread A starts a method call, reads the value of the pointer, but is interrupted +// before it can pass the pointer over the FFI to Rust. +// * Thread B calls `destroy` and frees the underlying Rust struct. +// * Thread A resumes, passing the already-read pointer value to Rust and triggering +// a use-after-free. 
+// +// One possible solution would be to use a `ReadWriteLock`, with each method call taking +// a read lock (and thus allowed to run concurrently) and the special `destroy` method +// taking a write lock (and thus blocking on live method calls). However, we aim not to +// generate methods with any hidden blocking semantics, and a `destroy` method that might +// block if called incorrectly seems to meet that bar. +// +// So, we achieve our goals by giving each `FFIObject` an associated `AtomicLong` counter to track +// the number of in-flight method calls, and an `AtomicBoolean` flag to indicate whether `destroy` +// has been called. These are updated according to the following rules: +// +// * The initial value of the counter is 1, indicating a live object with no in-flight calls. +// The initial value for the flag is false. +// +// * At the start of each method call, we atomically check the counter. +// If it is 0 then the underlying Rust struct has already been destroyed and the call is aborted. +// If it is nonzero them we atomically increment it by 1 and proceed with the method call. +// +// * At the end of each method call, we atomically decrement and check the counter. +// If it has reached zero then we destroy the underlying Rust struct. +// +// * When `destroy` is called, we atomically flip the flag from false to true. +// If the flag was already true we silently fail. +// Otherwise we atomically decrement and check the counter. +// If it has reached zero then we destroy the underlying Rust struct. +// +// Astute readers may observe that this all sounds very similar to the way that Rust's `Arc` works, +// and indeed it is, with the addition of a flag to guard against multiple calls to `destroy`. +// +// The overall effect is that the underlying Rust struct is destroyed only when `destroy` has been +// called *and* all in-flight method calls have completed, avoiding violating any of the expectations +// of the underlying Rust code. +// +// In the future we may be able to replace some of this with automatic finalization logic, such as using +// the new "Cleaner" functionaility in Java 9. The above scheme has been designed to work even if `destroy` is +// invoked by garbage-collection machinery rather than by calling code (which by the way, it's apparently also +// possible for the JVM to finalize an object while there is an in-flight call to one of its methods [1], +// so there would still be some complexity here). +// +// Sigh...all of this for want of a robust finalization mechanism. +// +// [1] https://stackoverflow.com/questions/24376768/can-java-finalize-an-object-when-it-is-still-in-scope/24380219 +// +abstract class FFIObject: Disposable, AutoCloseable { + + constructor(pointer: Pointer) { + this.pointer = pointer + } + + /** + * This constructor can be used to instantiate a fake object. + * + * **WARNING: Any object instantiated with this constructor cannot be passed to an actual Rust-backed object.** + * Since there isn't a backing [Pointer] the FFI lower functions will crash. + * @param noPointer Placeholder value so we can have a constructor separate from the default empty one that may be + * implemented for classes extending [FFIObject]. + */ + @Suppress("UNUSED_PARAMETER") + constructor(noPointer: NoPointer) { + this.pointer = null + } + + protected val pointer: Pointer? 
+ + private val wasDestroyed = AtomicBoolean(false) + private val callCounter = AtomicLong(1) + + open protected fun freeRustArcPtr() { + // Overridden by generated subclasses, the default method exists to allow users to manually + // implement the interface + } + + open fun uniffiClonePointer(): Pointer { + // Overridden by generated subclasses, the default method exists to allow users to manually + // implement the interface + throw RuntimeException("uniffiClonePointer not implemented") + } + + override fun destroy() { + // Only allow a single call to this method. + // TODO: maybe we should log a warning if called more than once? + if (this.wasDestroyed.compareAndSet(false, true)) { + // This decrement always matches the initial count of 1 given at creation time. + if (this.callCounter.decrementAndGet() == 0L) { + this.freeRustArcPtr() + } + } + } + + @Synchronized + override fun close() { + this.destroy() + } + + internal inline fun callWithPointer(block: (ptr: Pointer) -> R): R { + // Check and increment the call counter, to keep the object alive. + // This needs a compare-and-set retry loop in case of concurrent updates. + do { + val c = this.callCounter.get() + if (c == 0L) { + throw IllegalStateException("${this.javaClass.simpleName} object has already been destroyed") + } + if (c == Long.MAX_VALUE) { + throw IllegalStateException("${this.javaClass.simpleName} call counter would overflow") + } + } while (! this.callCounter.compareAndSet(c, c + 1L)) + // Now we can safely do the method call without the pointer being freed concurrently. + try { + return block(this.uniffiClonePointer()) + } finally { + // This decrement always matches the increment we performed above. + if (this.callCounter.decrementAndGet() == 0L) { + this.freeRustArcPtr() + } + } + } +} + +/** Used to instantiate a [FFIObject] without an actual pointer, for fakes in tests, mostly. 
*/ +object NoPointer diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt index ba56716401..552e269ab9 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Types.kt @@ -89,6 +89,9 @@ object NoPointer {%- when Type::Bytes %} {%- include "ByteArrayHelper.kt" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.kt" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {%- if !ci.is_name_used_as_error(name) %} @@ -134,6 +137,8 @@ object NoPointer {%- if ci.has_async_fns() %} {# Import types needed for async support #} {{ self.add_import("kotlin.coroutines.resume") }} +{{ self.add_import("kotlinx.coroutines.launch") }} {{ self.add_import("kotlinx.coroutines.suspendCancellableCoroutine") }} {{ self.add_import("kotlinx.coroutines.CancellableContinuation") }} +{{ self.add_import("kotlinx.coroutines.CoroutineScope") }} {%- endif %} diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt index 2cdc72a5e2..ea326bb242 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/wrapper.kt @@ -32,6 +32,7 @@ import java.nio.CharBuffer import java.nio.charset.CodingErrorAction import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.withLock {%- for req in self.imports() %} {{ req.render() }} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs new file mode 100644 index 0000000000..2f5a74fda0 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/gen_python/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + // On python we use an concurrent.futures.Executor for a BlockingTaskQueue + fn type_label(&self) -> String { + "concurrent.futures.Executor".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs index 6c1996eb5c..cc10d0cae5 100644 --- a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs +++ b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs @@ -16,6 +16,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -426,6 +427,8 @@ impl AsCodeType for T { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. } => Box::new(enum_::EnumCodeType::new(name)), Type::Object { name, .. } => Box::new(object::ObjectCodeType::new(name)), Type::Record { name, .. 
} => Box::new(record::RecordCodeType::new(name)), diff --git a/uniffi_bindgen/src/bindings/python/templates/Async.py b/uniffi_bindgen/src/bindings/python/templates/Async.py index 4a230112ea..e4fc81d562 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Async.py +++ b/uniffi_bindgen/src/bindings/python/templates/Async.py @@ -2,33 +2,67 @@ _UNIFFI_RUST_FUTURE_POLL_READY = 0 _UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1 +""" +Data for an in-progress poll of a RustFuture +""" +class UniffiPoll(typing.NamedTuple): + eventloop: asyncio.AbstractEventLoop + py_future: asyncio.Future + rust_future: int + # Must be UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK, but it's not clear how to specify as valid + # type for mypy and our current Python version + poll_fn: object + # Stores futures for _uniffi_continuation_callback _UniffiContinuationHandleMap = _UniffiHandleMap() +_UniffiPollDataHandleMap = _UniffiHandleMap() # Continuation callback for async functions # lift the return value or error and resolve the future, causing the async function to resume. @UNIFFI_RUST_FUTURE_CONTINUATION_CALLBACK -def _uniffi_continuation_callback(future_ptr, poll_code): - (eventloop, future) = _UniffiContinuationHandleMap.remove(future_ptr) - eventloop.call_soon_threadsafe(_uniffi_set_future_result, future, poll_code) +def _uniffi_continuation_callback(poll_data_handle, poll_code, blocking_task_queue_handle): + if blocking_task_queue_handle == 0: + # Complete the Python Future + poll_data = _UniffiPollDataHandleMap.release_pointer(poll_data_handle) + poll_data.eventloop.call_soon_threadsafe(_uniffi_set_future_result, poll_data.py_future, poll_code) + else: + # Call the poll function again, but inside the executor + poll_data = _UniffiPollDataHandleMap.lookup(poll_data_handle) + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(blocking_task_queue_handle) + executor.submit( + poll_data.poll_fn, + poll_data.rust_future, + _uniffi_continuation_callback, + poll_data_handle, + blocking_task_queue_handle + ) def _uniffi_set_future_result(future, poll_code): if not future.cancelled(): future.set_result(poll_code) -async def _uniffi_rust_call_async(rust_future, ffi_poll, ffi_complete, ffi_free, lift_func, error_ffi_converter): +async def _uniffi_rust_call_async(rust_future, poll_fn, ffi_complete, ffi_free, lift_func, error_ffi_converter): try: eventloop = asyncio.get_running_loop() - # Loop and poll until we see a _UNIFFI_RUST_FUTURE_POLL_READY value + # Loop and poll until we see a UNIFFI_RUST_FUTURE_POLL_READY value while True: - future = eventloop.create_future() - ffi_poll( + py_future = eventloop.create_future() + poll_data = UniffiPoll( + eventloop=eventloop, + py_future=py_future, + rust_future=rust_future, + poll_fn=poll_fn, + ) + poll_handle = _UniffiPollDataHandleMap.new_pointer(poll_data) + poll_fn( rust_future, _uniffi_continuation_callback, _UniffiContinuationHandleMap.insert((eventloop, future)), + poll_handle, + 0, ) - poll_code = await future + poll_code = await py_future if poll_code == _UNIFFI_RUST_FUTURE_POLL_READY: break diff --git a/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py new file mode 100644 index 0000000000..329ff37906 --- /dev/null +++ b/uniffi_bindgen/src/bindings/python/templates/BlockingTaskQueueTemplate.py @@ -0,0 +1,39 @@ +{{ self.add_import("concurrent.futures") }} + +@UNIFFI_BLOCKING_TASK_QUEUE_CLONE +def uniffi_blocking_task_queue_clone(handle): + executor = 
UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.get(handle) + return UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(executor) + +@UNIFFI_BLOCKING_TASK_QUEUE_FREE +def uniffi_blocking_task_queue_free(handle): + UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + +UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable( + uniffi_blocking_task_queue_clone, + uniffi_blocking_task_queue_free, +) + +UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP = ConcurrentHandleMap() + +class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): + @staticmethod + def check_lower(value): + if not isinstance(value, concurrent.futures.Executor): + raise TypeError("Expected concurrent.futures.Executor instance, {} found".format(type(value).__name__)) + + @staticmethod + def write(value, buf): + handle = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.insert(value) + buf.write_u64(handle) + buf.write_u64(ctypes.addressof(UNIFFI_BLOCKING_TASK_QUEUE_VTABLE)) + + @staticmethod + def read(buf): + handle = buf.read_u64() + executor = UNIFFI_BLOCKING_TASK_QUEUE_HANDLE_MAP.remove(handle) + # Read the VTable pointer and throw it out. The vtable is only used by Rust and always the + # same value. + buf.read_u64() + return executor + diff --git a/uniffi_bindgen/src/bindings/python/templates/Helpers.py b/uniffi_bindgen/src/bindings/python/templates/Helpers.py index b349e3ac29..889d045e41 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Helpers.py +++ b/uniffi_bindgen/src/bindings/python/templates/Helpers.py @@ -83,3 +83,43 @@ def _uniffi_trait_interface_call_with_error(call_status, make_call, write_return except Exception as e: call_status.code = _UniffiRustCallStatus.CALL_UNEXPECTED_ERROR call_status.error_buf = {{ Type::String.borrow()|lower_fn }}(repr(e)) + +class ConcurrentHandleMap: + """ + A map where inserting, getting and removing data is synchronized with a lock. + + TODO: consolidate this with `PointerManager` + https://github.com/mozilla/uniffi-rs/pull/1823 + """ + + def __init__(self): + # type Handle = int + self._left_map = {} # type: Dict[Handle, Any] + + self._lock = threading.Lock() + # Start with 1 so that 0 can be special-cased as the null value. + self._current_handle = 1 + self._stride = 1 + + def insert(self, obj): + with self._lock: + handle = self._current_handle + self._current_handle += self._stride + self._left_map[handle] = obj + return handle + + def get(self, handle): + with self._lock: + obj = self._left_map.get(handle) + if not obj: + raise InternalError("No callback in handlemap; this is a uniffi bug") + return obj + + def remove(self, handle): + with self._lock: + if handle in self._left_map: + obj = self._left_map.pop(handle) + return obj + + def __len__(self): + return len(self._left_map) diff --git a/uniffi_bindgen/src/bindings/python/templates/Types.py b/uniffi_bindgen/src/bindings/python/templates/Types.py index 4aaed253e0..9f2840fefb 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Types.py +++ b/uniffi_bindgen/src/bindings/python/templates/Types.py @@ -55,6 +55,9 @@ {%- when Type::Bytes %} {%- include "BytesHelper.py" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.py" %} + {%- when Type::Enum { name, module_path } %} {%- let e = ci.get_enum_definition(name).unwrap() %} {# For enums, there are either an error *or* an enum, they can't be both. 
#} diff --git a/uniffi_bindgen/src/bindings/python/templates/wrapper.py b/uniffi_bindgen/src/bindings/python/templates/wrapper.py index 2050b8d589..0c2a7a3cc3 100644 --- a/uniffi_bindgen/src/bindings/python/templates/wrapper.py +++ b/uniffi_bindgen/src/bindings/python/templates/wrapper.py @@ -26,6 +26,7 @@ import threading import itertools import typing +import threading {%- if ci.has_async_fns() %} import asyncio {%- endif %} diff --git a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs index 07da3882b6..cb60300d65 100644 --- a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs +++ b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs @@ -57,6 +57,7 @@ pub fn canonical_name(t: &Type) -> String { Type::CallbackInterface { name, .. } => format!("CallbackInterface{name}"), Type::Timestamp => "Timestamp".into(), Type::Duration => "Duration".into(), + Type::BlockingTaskQueue => "BlockingTaskQueue".into(), // Recursive types. // These add a prefix to the name of the underlying type. // The component API definition cannot give names to recursive types, so as long as the @@ -262,6 +263,7 @@ mod filters { } Type::External { .. } => panic!("No support for external types, yet"), Type::Custom { .. } => panic!("No support for custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -315,6 +317,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lowering external types, yet"), Type::Custom { .. } => panic!("No support for lowering custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } @@ -355,6 +358,7 @@ mod filters { ), Type::External { .. } => panic!("No support for lifting external types, yet"), Type::Custom { .. } => panic!("No support for lifting custom types, yet"), + Type::BlockingTaskQueue => panic!("No support for async functions, yet"), }) } } diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs new file mode 100644 index 0000000000..ab4df07f9b --- /dev/null +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/blocking_task_queue.rs @@ -0,0 +1,19 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::CodeType; + +#[derive(Debug)] +pub struct BlockingTaskQueueCodeType; + +impl CodeType for BlockingTaskQueueCodeType { + fn type_label(&self) -> String { + // On Swift, we use a DispatchQueue for BlockingTaskQueue + "DispatchQueue".into() + } + + fn canonical_name(&self) -> String { + "BlockingTaskQueue".into() + } +} diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs index 5717041b32..51b9a26461 100644 --- a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs @@ -18,6 +18,7 @@ use crate::backend::TemplateExpression; use crate::interface::*; use crate::BindingsConfig; +mod blocking_task_queue; mod callback_interface; mod compounds; mod custom; @@ -462,6 +463,8 @@ impl SwiftCodeOracle { Type::Timestamp => Box::new(miscellany::TimestampCodeType), Type::Duration => Box::new(miscellany::DurationCodeType), + Type::BlockingTaskQueue => Box::new(blocking_task_queue::BlockingTaskQueueCodeType), + Type::Enum { name, .. 
} => Box::new(enum_::EnumCodeType::new(name)),
            Type::Object { name, imp, .. } => Box::new(object::ObjectCodeType::new(name, imp)),
            Type::Record { name, .. } => Box::new(record::RecordCodeType::new(name)),
diff --git a/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/uniffi_bindgen/src/bindings/swift/templates/Async.swift
index 761e0dd70e..b8675bc653 100644
--- a/uniffi_bindgen/src/bindings/swift/templates/Async.swift
+++ b/uniffi_bindgen/src/bindings/swift/templates/Async.swift
@@ -1,13 +1,30 @@
 private let UNIFFI_RUST_FUTURE_POLL_READY: Int8 = 0
 private let UNIFFI_RUST_FUTURE_POLL_MAYBE_READY: Int8 = 1

-fileprivate let uniffiContinuationHandleMap = UniffiHandleMap<UnsafeContinuation<Int8, Never>>()
+// Data for an in-progress poll of a RustFuture
+fileprivate class UniffiPollData {
+    let continuation: UnsafeContinuation<Int8, Never>
+    let rustFuture: UInt64
+    let pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> ()
+
+    init(
+        continuation: UnsafeContinuation<Int8, Never>,
+        rustFuture: UInt64,
+        pollFunc: @escaping (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> ()
+    ) {
+        self.continuation = continuation
+        self.rustFuture = rustFuture
+        self.pollFunc = pollFunc
+    }
+}
+
+fileprivate let uniffiPollDataHandleMap = UniffiHandleMap<UniffiPollData>()

 fileprivate func uniffiRustCallAsync<F, T>(
     rustFutureFunc: () -> UInt64,
-    pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64) -> (),
+    pollFunc: @escaping (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64, UInt64) -> (),
     completeFunc: (UInt64, UnsafeMutablePointer<RustCallStatus>) -> F,
     freeFunc: (UInt64) -> (),
     liftFunc: (F) throws -> T,
     errorHandler: ((RustBuffer) throws -> Error)?
 ) async throws -> T {
@@ -21,11 +38,20 @@ fileprivate func uniffiRustCallAsync(
     var pollResult: Int8;
     repeat {
         pollResult = await withUnsafeContinuation {
+            let pollData = UniffiPollData(
+                continuation: $0,
+                rustFuture: rustFuture,
+                pollFunc: pollFunc
+            )
             pollFunc(
                 rustFuture,
                 uniffiFutureContinuationCallback,
-                uniffiContinuationHandleMap.insert(obj: $0)
+                uniffiPollDataHandleMap.insert(obj: pollData),
+                0
             )
         }
     } while pollResult != UNIFFI_RUST_FUTURE_POLL_READY

@@ -37,10 +63,30 @@ fileprivate func uniffiRustCallAsync(

 // Callback handlers for an async calls. These are invoked by Rust when the future is ready. They
 // lift the return value or error and resume the suspended function.
-fileprivate func uniffiFutureContinuationCallback(handle: UInt64, pollResult: Int8) {
-    if let continuation = try? uniffiContinuationHandleMap.remove(handle: handle) {
-        continuation.resume(returning: pollResult)
+fileprivate func uniffiFutureContinuationCallback(
+    pollDataHandle: UInt64,
+    pollResult: Int8,
+    blockingTaskQueueHandle: UInt64
+) {
+    if let pollData = try? uniffiPollDataHandleMap.remove(handle: pollDataHandle) {
+        if (blockingTaskQueueHandle == 0) {
+            // Complete the Swift continuation
+            pollData.continuation.resume(returning: pollResult)
+        } else {
+            // Call the poll function again, but inside the DispatchQueue
+            let queue = uniffiBlockingTaskQueueHandleMap.get(handle: blockingTaskQueueHandle)!
+            queue.async {
+                pollData.pollFunc(pollData.rustFuture, uniffiFutureContinuationCallback, uniffiPollDataHandleMap.insert(obj: pollData), blockingTaskQueueHandle)
+            }
+        }
     } else {
         print("uniffiFutureContinuationCallback invalid handle")
     }
 }
+
+
+// For testing
+public func uniffiPollDataHandleCount{{ ci.namespace()|class_name }}() -> Int {
+    return uniffiPollDataHandleMap.count
+}
diff --git a/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift
new file mode 100644
index 0000000000..3b02a7c403
--- /dev/null
+++ b/uniffi_bindgen/src/bindings/swift/templates/BlockingTaskQueueTemplate.swift
@@ -0,0 +1,37 @@
+fileprivate var UNIFFI_BLOCKING_TASK_QUEUE_VTABLE = UniffiBlockingTaskQueueVTable(
+    clone: { (handle: UInt64) -> UInt64 in
+        let dispatchQueue = uniffiBlockingTaskQueueHandleMap.get(handle: handle)!
+        return uniffiBlockingTaskQueueHandleMap.insert(obj: dispatchQueue)
+    },
+    free: { (handle: UInt64) in
+        uniffiBlockingTaskQueueHandleMap.remove(handle: handle)
+    }
+)
+fileprivate var uniffiBlockingTaskQueueHandleMap = UniffiHandleMap<DispatchQueue>()
+
+fileprivate struct {{ ffi_converter_name }}: FfiConverterRustBuffer {
+    typealias SwiftType = DispatchQueue
+
+    public static func write(_ value: DispatchQueue, into buf: inout [UInt8]) {
+        let handle = uniffiBlockingTaskQueueHandleMap.insert(obj: value)
+        writeInt(&buf, handle)
+        // From Apple: "You can safely use the address of a global variable as a persistent unique
+        // pointer value" (https://developer.apple.com/swift/blog/?id=6)
+        let vtablePointer = UnsafeMutablePointer(&UNIFFI_BLOCKING_TASK_QUEUE_VTABLE)
+        // Convert the pointer to a word-sized Int then to a 64-bit int then write it out.
+        writeInt(&buf, Int64(Int(bitPattern: vtablePointer)))
+    }
+
+    public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> DispatchQueue {
+        let handle: UInt64 = try readInt(&buf)
+        // Read the VTable pointer and throw it out. The vtable is only used by Rust and always the
+        // same value.
+        let _: UInt64 = try readInt(&buf)
+        return uniffiBlockingTaskQueueHandleMap.remove(handle: handle)!
+ } +} + +// For testing +public func uniffiBlockingTaskQueueHandleCount{{ ci.namespace()|class_name }}() -> Int { + uniffiBlockingTaskQueueHandleMap.count +} diff --git a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift index 7aa1cca9b2..7c6c8c4750 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/CallbackInterfaceTemplate.swift @@ -21,7 +21,7 @@ extension {{ ffi_converter_name }} : FfiConverter { typealias FfiType = UInt64 public static func lift(_ handle: UInt64) throws -> SwiftType { - try handleMap.get(handle: handle) + return try handleMap.get(handle: handle) } public static func read(from buf: inout (data: Data, offset: Data.Index)) throws -> SwiftType { diff --git a/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift b/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift index af0305872b..85dfca540f 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/HandleMap.swift @@ -30,5 +30,11 @@ fileprivate class UniffiHandleMap { return obj } } + + var count: Int { + get { + return map.count + } + } } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift index cfddf7b313..5cd149464d 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Helpers.swift @@ -52,6 +52,69 @@ fileprivate extension RustCallStatus { } } +fileprivate class UniffiHandleMap { + private var leftMap: [UInt64: T] = [:] + private var counter: [UInt64: UInt64] = [:] + private var rightMap: [ObjectIdentifier: UInt64] = [:] + + private let lock = NSLock() + // Start with 1 so that 0 can be special-cased as the null value. + private var currentHandle: UInt64 = 1 + private let stride: UInt64 = 1 + + func insert(obj: T) -> UInt64 { + lock.withLock { + let id = ObjectIdentifier(obj as AnyObject) + let handle = rightMap[id] ?? { + currentHandle += stride + let handle = currentHandle + leftMap[handle] = obj + rightMap[id] = handle + return handle + }() + counter[handle] = (counter[handle] ?? 0) + 1 + return handle + } + } + + func get(handle: UInt64) -> T? { + lock.withLock { + leftMap[handle] + } + } + + func delete(handle: UInt64) { + remove(handle: handle) + } + + @discardableResult + func remove(handle: UInt64) -> T? { + lock.withLock { + defer { counter[handle] = (counter[handle] ?? 
1) - 1 } + guard counter[handle] == 1 else { return leftMap[handle] } + let obj = leftMap.removeValue(forKey: handle) + if let obj = obj { + rightMap.removeValue(forKey: ObjectIdentifier(obj as AnyObject)) + } + return obj + } + } + + var count: Int { + get { + leftMap.count + } + } +} + +fileprivate extension NSLock { + func withLock(f: () throws -> T) rethrows -> T { + self.lock() + defer { self.unlock() } + return try f() + } +} + private func rustCall(_ callback: (UnsafeMutablePointer) -> T) throws -> T { try makeRustCall(callback, errorHandler: nil) } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Types.swift b/uniffi_bindgen/src/bindings/swift/templates/Types.swift index 5e26758f3c..ba4d2059c8 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Types.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Types.swift @@ -64,6 +64,9 @@ {%- when Type::CallbackInterface { name, module_path } %} {%- include "CallbackInterfaceTemplate.swift" %} +{%- when Type::BlockingTaskQueue %} +{%- include "BlockingTaskQueueTemplate.swift" %} + {%- when Type::Custom { name, module_path, builtin } %} {%- include "CustomType.swift" %} diff --git a/uniffi_bindgen/src/interface/ffi.rs b/uniffi_bindgen/src/interface/ffi.rs index 19354e16dc..cfab6ed823 100644 --- a/uniffi_bindgen/src/interface/ffi.rs +++ b/uniffi_bindgen/src/interface/ffi.rs @@ -129,7 +129,8 @@ impl From<&Type> for FfiType { | Type::Sequence { .. } | Type::Map { .. } | Type::Timestamp - | Type::Duration => FfiType::RustBuffer(None), + | Type::Duration + | Type::BlockingTaskQueue => FfiType::RustBuffer(None), Type::External { name, kind: ExternalKind::Interface, diff --git a/uniffi_bindgen/src/interface/mod.rs b/uniffi_bindgen/src/interface/mod.rs index a91656caab..f036eb27b5 100644 --- a/uniffi_bindgen/src/interface/mod.rs +++ b/uniffi_bindgen/src/interface/mod.rs @@ -67,9 +67,7 @@ mod record; pub use record::{Field, Record}; pub mod ffi; -pub use ffi::{ - FfiArgument, FfiCallbackFunction, FfiDefinition, FfiField, FfiFunction, FfiStruct, FfiType, -}; +pub use ffi::{FfiArgument, FfiCallbackFunction, FfiDefinition, FfiField, FfiFunction, FfiStruct, FfiType}; pub use uniffi_meta::Radix; use uniffi_meta::{ ConstructorMetadata, LiteralMetadata, NamespaceMetadata, ObjectMetadata, TraitMethodMetadata, @@ -224,10 +222,16 @@ impl ComponentInterface { .chain(self.objects.iter().flat_map(|o| o.ffi_callbacks())) } + } + /// Get the definitions for callback FFI functions /// /// These are defined by the foreign code and invoked by Rust. 
-    fn callback_interface_vtable_definitions(&self) -> impl IntoIterator<Item = FfiDefinition> + '_ {
+
+    pub fn callback_interface_vtable_definitions(
+        &self,
+    ) -> impl IntoIterator<Item = FfiDefinition> + '_ {
         self.callback_interface_definitions()
             .iter()
             .map(|cbi| cbi.vtable_definition())
@@ -503,6 +507,10 @@ impl ComponentInterface {
                     name: "callback_data".to_owned(),
                     type_: FfiType::Handle,
                 },
+                FfiArgument {
+                    name: "blocking_task_queue_handle".to_owned(),
+                    type_: FfiType::UInt64,
+                },
             ],
             return_type: None,
             has_rust_call_status_arg: false,
@@ -632,6 +640,18 @@ impl ComponentInterface {
                 has_rust_call_status_arg: false,
             }
             .into(),
+            FfiCallbackFunction {
+                name: "BlockingTaskQueueClone".to_owned(),
+                arguments: vec![FfiArgument::new("handle", FfiType::UInt64)],
+                return_type: Some(FfiType::UInt64),
+                has_rust_call_status_arg: false,
+            }
+            .into(),
+            FfiCallbackFunction {
+                name: "BlockingTaskQueueFree".to_owned(),
+                arguments: vec![FfiArgument::new("handle", FfiType::UInt64)],
+                return_type: None,
+                has_rust_call_status_arg: false,
+            }
+            .into(),
             FfiStruct {
                 name: "ForeignFuture".to_owned(),
                 fields: vec![
                 ],
             }
             .into(),
+            FfiStruct {
+                name: "BlockingTaskQueueVTable".to_owned(),
+                fields: vec![
+                    FfiField::new(
+                        "clone",
+                        FfiType::Callback("BlockingTaskQueueClone".to_owned()),
+                    ),
+                    FfiField::new(
+                        "free",
+                        FfiType::Callback("BlockingTaskQueueFree".to_owned()),
+                    ),
+                ],
+            }
+            .into(),
         ]
         .into_iter()
     }
@@ -843,6 +876,12 @@ impl ComponentInterface {
             .map(|n| self.errors.insert(n.to_string()));
-        self.functions.push(defn);
+        if defn.is_async() {
+            self.types
+                .add_known_type(&uniffi_meta::Type::BlockingTaskQueue)?;
+        }
+
+        self.functions.push(defn);
         Ok(())
     }
diff --git a/uniffi_bindgen/src/interface/universe.rs b/uniffi_bindgen/src/interface/universe.rs
index 70bc61f8a9..2faef72fd6 100644
--- a/uniffi_bindgen/src/interface/universe.rs
+++ b/uniffi_bindgen/src/interface/universe.rs
@@ -84,6 +84,7 @@ impl TypeUniverse {
             Type::Bytes => self.add_type_definition("bytes", type_)?,
             Type::Timestamp => self.add_type_definition("timestamp", type_)?,
             Type::Duration => self.add_type_definition("duration", type_)?,
+            Type::BlockingTaskQueue => self.add_type_definition("BlockingTaskQueue", type_)?,
             Type::Object { name, .. }
             | Type::Record { name, .. }
             | Type::Enum { name, .. }
diff --git a/uniffi_bindgen/src/scaffolding/mod.rs b/uniffi_bindgen/src/scaffolding/mod.rs
index 7fd81831aa..231e0495d3 100644
--- a/uniffi_bindgen/src/scaffolding/mod.rs
+++ b/uniffi_bindgen/src/scaffolding/mod.rs
@@ -45,6 +45,7 @@ mod filters {
                 format!("std::sync::Arc<{}>", imp.rust_name_for(name))
             }
             Type::CallbackInterface { name, .. } => format!("Box<dyn r#{name}>"),
+            Type::BlockingTaskQueue => "::uniffi::BlockingTaskQueue".to_owned(),
             Type::Optional { inner_type } => {
                 format!("std::option::Option<{}>", type_rs(inner_type)?)
             }
diff --git a/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs
new file mode 100644
index 0000000000..50abe5a66f
--- /dev/null
+++ b/uniffi_core/src/ffi/rustfuture/blocking_task_queue.rs
@@ -0,0 +1,69 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+//! Defines the BlockingTaskQueue struct
+//!
+//! This module is responsible for the general handling of BlockingTaskQueue instances (cloning, dropping, etc.).
+//! See `scheduler.rs` and the foreign bindings code for how the async functionality is implemented.
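For illustration, a minimal sketch of how exported Rust async code might use this type once the change is in place; it is not part of the generated code or of this patch, and the `read_file` name and the blocking `std::fs` call are assumptions chosen for the example:

    // Hypothetical exported async function that does blocking I/O on a foreign task queue.
    #[uniffi::export]
    async fn read_file(queue: uniffi::BlockingTaskQueue, path: String) -> String {
        // The closure runs on the foreign task queue, so blocking here is acceptable,
        // and it can borrow `path` from the surrounding scope.
        queue.execute(|| std::fs::read_to_string(&path).unwrap_or_default()).await
    }
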
+ +use super::scheduler::schedule_in_blocking_task_queue; +use std::num::NonZeroU64; + +/// Foreign-managed blocking task queue that we can use to schedule futures +/// +/// On the foreign side this is a Kotlin `CoroutineContext`, Python `Executor` or Swift +/// `DispatchQueue`. UniFFI converts those objects into this struct for the Rust code to use. +/// +/// Rust async code can call [BlockingTaskQueue::execute] to run a closure in that +/// blocking task queue. Use this for functions with blocking operations that should not be executed +/// in a normal async context. Some examples are non-async file/network operations, long-running +/// CPU-bound tasks, blocking database operations, etc. +#[repr(C)] +pub struct BlockingTaskQueue { + /// Opaque handle for the task queue + pub handle: NonZeroU64, + /// Method VTable + /// + /// This is simply a C struct where each field is a function pointer that inputs a + /// BlockingTaskQueue handle + pub vtable: &'static BlockingTaskQueueVTable, +} + +#[repr(C)] +#[derive(Debug)] +pub struct BlockingTaskQueueVTable { + clone: extern "C" fn(u64) -> u64, + drop: extern "C" fn(u64), +} + +// Note: see `scheduler.rs` for details on how BlockingTaskQueue is used. +impl BlockingTaskQueue { + /// Run a closure in a blocking task queue + pub async fn execute(&self, f: F) -> R + where + F: FnOnce() -> R, + { + schedule_in_blocking_task_queue(self.handle).await; + f() + } +} + +impl Clone for BlockingTaskQueue { + fn clone(&self) -> Self { + let raw_handle = (self.vtable.clone)(self.handle.into()); + let handle = raw_handle + .try_into() + .expect("BlockingTaskQueue.clone() returned 0"); + Self { + handle, + vtable: self.vtable, + } + } +} + +impl Drop for BlockingTaskQueue { + fn drop(&mut self) { + (self.vtable.drop)(self.handle.into()) + } +} diff --git a/uniffi_core/src/ffi/rustfuture/future.rs b/uniffi_core/src/ffi/rustfuture/future.rs index 93c34e7543..26781206d2 100644 --- a/uniffi_core/src/ffi/rustfuture/future.rs +++ b/uniffi_core/src/ffi/rustfuture/future.rs @@ -21,6 +21,10 @@ //! 2b. If the async function is cancelled, then call [rust_future_cancel]. This causes the //! continuation function to be called with [RustFuturePoll::Ready] and the [RustFuture] to //! enter a cancelled state. +//! 2c. If the Rust code wants schedule work to be run in a `BlockingTaskQueue`, then the +//! continuation is called with [RustFuturePoll::MaybeReady] and the blocking task queue handle. +//! The foreign code is responsible for ensuring the next [rust_future_poll] call happens in +//! that blocking task queue and the handle is passed to [rust_future_poll]. //! 3. Call [rust_future_complete] to get the result of the future. //! 4. Call [rust_future_free] to free the future, ideally in a finally block. This: //! 
- Releases any resources held by the future @@ -78,6 +82,7 @@ use std::{ future::Future, marker::PhantomData, + num::NonZeroU64, ops::Deref, panic, pin::Pin, @@ -85,8 +90,8 @@ use std::{ task::{Context, Poll, Wake}, }; -use super::{RustFutureContinuationCallback, RustFuturePoll, Scheduler}; -use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus}; +use super::{scheduler, RustFutureContinuationCallback, Scheduler}; +use crate::{rust_call_with_out_status, FfiDefault, LowerReturn, RustCallStatus, RustFuturePoll}; /// Wraps the actual future we're polling struct WrappedFuture @@ -223,17 +228,26 @@ where }) } - pub(super) fn poll(self: Arc, callback: RustFutureContinuationCallback, data: u64) { + pub(super) fn poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ) { + scheduler::on_poll_start(blocking_task_queue_handle); + // Clear out the waked flag, since we're about to poll right now. + self.scheduler.lock().unwrap().clear_wake_flag(); let ready = self.is_cancelled() || { let mut locked = self.future.lock().unwrap(); let waker: std::task::Waker = Arc::clone(&self).into(); locked.poll(&mut Context::from_waker(&waker)) }; if ready { - callback(data, RustFuturePoll::Ready) + callback(data, RustFuturePoll::Ready, 0) } else { self.scheduler.lock().unwrap().store(callback, data); } + scheduler::on_poll_end(); } pub(super) fn is_cancelled(&self) -> bool { @@ -289,7 +303,12 @@ where /// only create those functions for each of the 13 possible FFI return types. #[doc(hidden)] pub trait RustFutureFfi: Send + Sync { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: u64); + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ); fn ffi_cancel(&self); fn ffi_complete(&self, call_status: &mut RustCallStatus) -> ReturnType; fn ffi_free(self: Arc); @@ -302,8 +321,13 @@ where T: LowerReturn + Send + 'static, UT: Send + 'static, { - fn ffi_poll(self: Arc, callback: RustFutureContinuationCallback, data: u64) { - self.poll(callback, data) + fn ffi_poll( + self: Arc, + callback: RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: Option, + ) { + self.poll(callback, data, blocking_task_queue_handle) } fn ffi_cancel(&self) { diff --git a/uniffi_core/src/ffi/rustfuture/mod.rs b/uniffi_core/src/ffi/rustfuture/mod.rs index 39529f2db1..bb064d2e65 100644 --- a/uniffi_core/src/ffi/rustfuture/mod.rs +++ b/uniffi_core/src/ffi/rustfuture/mod.rs @@ -4,8 +4,11 @@ use std::{future::Future, sync::Arc}; +mod blocking_task_queue; mod future; mod scheduler; + +pub use blocking_task_queue::*; use future::*; use scheduler::*; @@ -28,7 +31,19 @@ pub enum RustFuturePoll { /// /// The Rust side of things calls this when the foreign side should call [rust_future_poll] again /// to continue progress on the future. -pub type RustFutureContinuationCallback = extern "C" fn(callback_data: u64, RustFuturePoll); +/// +/// WARNING: the call to [rust_future_poll] must be scheduled to happen soon after the callback is +/// called, but not inside the callback itself. If [rust_future_poll] is called inside the +/// callback, some futures will deadlock and our scheduler code might as well. 
+///
+/// * `callback_data` is the handle that the foreign code passed to `poll()`
+/// * `poll_result` is the result of the poll
+/// * If `blocking_task_queue_handle` is non-zero, it's the BlockingTaskQueue handle that the next `poll()` should run on
+pub type RustFutureContinuationCallback = extern "C" fn(
+    callback_data: u64,
+    poll_result: RustFuturePoll,
+    blocking_task_queue_handle: u64,
+);

 // === Public FFI API ===

@@ -62,6 +77,9 @@ where
 /// a [RustFuturePoll] value. For each [rust_future_poll] call the continuation will be called
 /// exactly once.
 ///
+/// If this is running in a BlockingTaskQueue, then `blocking_task_queue_handle` must be the handle
+/// for it. If not, `blocking_task_queue_handle` must be `0`.
+///
 /// # Safety
 ///
 /// The [Handle] must not previously have been passed to [rust_future_free]
@@ -69,10 +87,14 @@
 pub unsafe fn rust_future_poll<ReturnType, UT>(
     handle: Handle,
     callback: RustFutureContinuationCallback,
     data: u64,
+    blocking_task_queue_handle: u64,
 ) where
     dyn RustFutureFfi<ReturnType>: HandleAlloc<UT>,
 {
-    <dyn RustFutureFfi<ReturnType> as HandleAlloc<UT>>::get_arc(handle).ffi_poll(callback, data)
+    <dyn RustFutureFfi<ReturnType> as HandleAlloc<UT>>::get_arc(handle)
+        .ffi_poll(callback, data, blocking_task_queue_handle.try_into().ok())
 }

 /// Cancel a Rust future
diff --git a/uniffi_core/src/ffi/rustfuture/scheduler.rs b/uniffi_core/src/ffi/rustfuture/scheduler.rs
index 629ee0c109..26526701d3 100644
--- a/uniffi_core/src/ffi/rustfuture/scheduler.rs
+++ b/uniffi_core/src/ffi/rustfuture/scheduler.rs
@@ -2,10 +2,89 @@
  * License, v. 2.0. If a copy of the MPL was not distributed with this
  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

-use std::mem;
+use std::{cell::RefCell, future::poll_fn, mem, num::NonZeroU64, task::Poll, thread_local};

 use super::{RustFutureContinuationCallback, RustFuturePoll};

+/// Context of the current `RustFuture::poll` call
+struct RustFutureContext {
+    /// Blocking task queue that the future is being polled on
+    current_blocking_task_queue_handle: Option<NonZeroU64>,
+    /// Blocking task queue that we've been asked to schedule the next poll on
+    scheduled_blocking_task_queue_handle: Option<NonZeroU64>,
+}
+
+thread_local! {
+    static CONTEXT: RefCell<RustFutureContext> = RefCell::new(RustFutureContext {
+        current_blocking_task_queue_handle: None,
+        scheduled_blocking_task_queue_handle: None,
+    });
+}
+
+fn with_context<F: FnOnce(&mut RustFutureContext) -> R, R>(operation: F) -> R {
+    CONTEXT.with(|context| operation(&mut context.borrow_mut()))
+}
+
+pub fn on_poll_start(current_blocking_task_queue_handle: Option<NonZeroU64>) {
+    with_context(|context| {
+        *context = RustFutureContext {
+            current_blocking_task_queue_handle,
+            scheduled_blocking_task_queue_handle: None,
+        }
+    });
+}
+
+pub fn on_poll_end() {
+    with_context(|context| {
+        *context = RustFutureContext {
+            current_blocking_task_queue_handle: None,
+            scheduled_blocking_task_queue_handle: None,
+        }
+    });
+}
+
+/// Schedule work in a blocking task queue
+///
+/// The returned future will attempt to arrange for [RustFuture::poll] to be called in the
+/// blocking task queue. Once [RustFuture::poll] is running in the blocking task queue, the future
+/// will be ready.
+///
+/// There's one tricky issue here: how can we ensure that when the top-level task is run in the
+/// blocking task queue, this future will be polled? What happens if this future is a child of `join!`,
+/// `FuturesUnordered` or some other Future that handles its own polling?
+///
+/// We start with an assumption: if we notify the waker then this future will be polled when the
+/// top-level task is polled next.
If a future does not honor this then we consider it a broken +/// future. This seems fair, since that future would almost certainly break a lot of other future +/// code. +/// +/// Based on that, we can have a simple system. When we're polled: +/// * If we're running in the blocking task queue, then we return `Poll::Ready`. +/// * If not, we return `Poll::Pending` and notify the waker so that the future polls again on +/// the next top-level poll. +/// +/// Note that this can be inefficient if the code awaits multiple blocking task queues at once. We +/// can only run the next poll on one of them, but all futures will be woken up. This seems okay +/// for our intended use cases, it would be pretty odd for a library to use multiple blocking task +/// queues. The alternative would be to store the set of all pending blocking task queues, which +/// seems like complete overkill for our purposes. +pub(super) async fn schedule_in_blocking_task_queue(handle: NonZeroU64) { + poll_fn(|future_context| { + with_context(|poll_context| { + if poll_context.current_blocking_task_queue_handle == Some(handle) { + Poll::Ready(()) + } else { + poll_context + .scheduled_blocking_task_queue_handle + .get_or_insert(handle); + future_context.waker().wake_by_ref(); + Poll::Pending + } + }) + }) + .await +} + /// Schedules a [crate::RustFuture] by managing the continuation data /// /// This struct manages the continuation callback and data that comes from the foreign side. It @@ -41,21 +120,34 @@ impl Scheduler { /// Store new continuation data if we are in the `Empty` state. If we are in the `Waked` or /// `Cancelled` state, call the continuation immediately with the data. pub(super) fn store(&mut self, callback: RustFutureContinuationCallback, data: u64) { + if let Some(blocking_task_queue_handle) = + with_context(|context| context.scheduled_blocking_task_queue_handle) + { + // We were asked to schedule the future in a blocking task queue, call the callback + // rather than storing it + callback( + data, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle.into(), + ); + return; + } + match self { Self::Empty => *self = Self::Set(callback, data), Self::Set(old_callback, old_data) => { log::error!( "store: observed `Self::Set` state. Is poll() being called from multiple threads at once?" ); - old_callback(*old_data, RustFuturePoll::Ready); + old_callback(*old_data, RustFuturePoll::Ready, 0); *self = Self::Set(callback, data); } Self::Waked => { *self = Self::Empty; - callback(data, RustFuturePoll::MaybeReady); + callback(data, RustFuturePoll::MaybeReady, 0); } Self::Cancelled => { - callback(data, RustFuturePoll::Ready); + callback(data, RustFuturePoll::Ready, 0); } } } @@ -67,7 +159,7 @@ impl Scheduler { let old_data = *old_data; let callback = *callback; *self = Self::Empty; - callback(old_data, RustFuturePoll::MaybeReady); + callback(old_data, RustFuturePoll::MaybeReady, 0); } // If we were in the `Empty` state, then transition to `Waked`. The next time `store` // is called, we will immediately call the continuation. 
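To make the earlier note about awaiting more than one blocking task queue concrete, here is an illustrative sketch; it is not part of this patch, it assumes the `futures` crate for `join!`, and `expensive_hash` is a hypothetical stand-in for CPU-bound work:

    use uniffi::BlockingTaskQueue;

    // Hypothetical CPU-bound helper.
    fn expensive_hash(input: &str) -> u64 {
        input.len() as u64
    }

    async fn hash_on_two_queues(q1: BlockingTaskQueue, q2: BlockingTaskQueue) -> (u64, u64) {
        // Both execute() futures are woken together, but each top-level poll can only run
        // inside one queue; the branch waiting on the other queue simply returns Pending
        // again until its own queue gets to run a poll.
        futures::join!(q1.execute(|| expensive_hash("a")), q2.execute(|| expensive_hash("b")))
    }
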
@@ -79,7 +171,13 @@ impl Scheduler { pub(super) fn cancel(&mut self) { if let Self::Set(callback, old_data) = mem::replace(self, Self::Cancelled) { - callback(old_data, RustFuturePoll::Ready); + callback(old_data, RustFuturePoll::Ready, 0); + } + } + + pub(super) fn clear_wake_flag(&mut self) { + if let Self::Waked = self { + *self = Self::Empty } } diff --git a/uniffi_core/src/ffi/rustfuture/tests.rs b/uniffi_core/src/ffi/rustfuture/tests.rs index 886ee27c71..70a49179b0 100644 --- a/uniffi_core/src/ffi/rustfuture/tests.rs +++ b/uniffi_core/src/ffi/rustfuture/tests.rs @@ -65,16 +65,38 @@ fn channel() -> (Sender, Arc>) { } /// Poll a Rust future and get an OnceCell that's set when the continuation is called -fn poll(rust_future: &Arc>) -> Arc> { +fn poll(rust_future: &Arc>) -> Arc> { let cell = Arc::new(OnceCell::new()); - let handle = Arc::into_raw(cell.clone()) as u64; - rust_future.clone().ffi_poll(poll_continuation, handle); + let cell_ptr = Arc::into_raw(cell.clone()) as u64; + rust_future + .clone() + .ffi_poll(poll_continuation, cell_ptr, None); cell } -extern "C" fn poll_continuation(data: u64, code: RustFuturePoll) { - let cell = unsafe { Arc::from_raw(data as *const OnceCell) }; - cell.set(code).expect("Error setting OnceCell"); +/// Like poll, but simulate `poll()` being called from a blocking task queue +fn poll_from_blocking_task_queue( + rust_future: &Arc>, + blocking_task_queue_handle: u64, +) -> Arc> { + let cell = Arc::new(OnceCell::new()); + let cell_ptr = Arc::into_raw(cell.clone()) as *const (); + rust_future.clone().ffi_poll( + poll_continuation, + cell_ptr, + Some(blocking_task_queue_handle.try_into().unwrap()), + ); + cell +} + +extern "C" fn poll_continuation( + data: *const (), + code: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + let cell = unsafe { Arc::from_raw(data as *const OnceCell<(RustFuturePoll, u64)>) }; + cell.set((code, blocking_task_queue_handle)) + .expect("Error setting OnceCell"); } fn complete(rust_future: Arc>) -> (RustBuffer, RustCallStatus) { @@ -83,25 +105,47 @@ fn complete(rust_future: Arc>) -> (RustBuffer, Rus (return_value, out_status_code) } +fn check_continuation_not_called(once_cell: &OnceCell<(RustFuturePoll, u64)>) { + assert_eq!(once_cell.get(), None); +} + +fn check_continuation_called( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, +) { + assert_eq!(once_cell.get(), Some(&(poll_result, 0))); +} + +fn check_continuation_called_with_blocking_task_queue_handle( + once_cell: &OnceCell<(RustFuturePoll, u64)>, + poll_result: RustFuturePoll, + blocking_task_queue_handle: u64, +) { + assert_eq!( + once_cell.get(), + Some(&(poll_result, blocking_task_queue_handle)) + ) +} + #[test] fn test_success() { let (sender, rust_future) = channel(); // Test polling the rust future before it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.wake(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Test polling the rust future when it's ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Ok("All done".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // Future polls 
should immediately return ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Complete the future let (return_buf, call_status) = complete(rust_future); @@ -117,12 +161,12 @@ fn test_error() { let (sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); sender.send(Err("Something went wrong".into())); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Error); @@ -144,14 +188,14 @@ fn test_cancel() { let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), None); + check_continuation_not_called(&continuation_result); rust_future.ffi_cancel(); // Cancellation should immediately invoke the callback with RustFuturePoll::Ready - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); // Future polls should immediately invoke the callback with RustFuturePoll::Ready let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); let (_, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Cancelled); @@ -187,7 +231,7 @@ fn test_complete_with_stored_continuation() { let continuation_result = poll(&rust_future); rust_future.ffi_free(); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); } // Test what happens if we see a `wake()` call while we're polling the future. 
This can @@ -210,10 +254,47 @@ fn test_wake_during_poll() { let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); let continuation_result = poll(&rust_future); // The continuation function should called immediately - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::MaybeReady)); + check_continuation_called(&continuation_result, RustFuturePoll::MaybeReady); // A second poll should finish the future let continuation_result = poll(&rust_future); - assert_eq!(continuation_result.get(), Some(&RustFuturePoll::Ready)); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + let (return_buf, call_status) = complete(rust_future); + assert_eq!(call_status.code, RustCallStatusCode::Success); + assert_eq!( + >::try_lift(return_buf).unwrap(), + "All done" + ); +} + +#[test] +fn test_blocking_task() { + let blocking_task_queue_handle = 1001; + let future = async move { + schedule_in_blocking_task_queue(blocking_task_queue_handle.try_into().unwrap()).await; + "All done".to_owned() + }; + let rust_future: Arc> = RustFuture::new(future, crate::UniFfiTag); + // On the first poll, the future should not be ready and it should ask to be scheduled in the + // blocking task queue + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // If we poll it again not in a blocking task queue, then we get the same result + let continuation_result = poll(&rust_future); + check_continuation_called_with_blocking_task_queue_handle( + &continuation_result, + RustFuturePoll::MaybeReady, + blocking_task_queue_handle, + ); + // When we poll it in the blocking task queue, then the future is ready + let continuation_result = + poll_from_blocking_task_queue(&rust_future, blocking_task_queue_handle); + check_continuation_called(&continuation_result, RustFuturePoll::Ready); + + // Complete the future let (return_buf, call_status) = complete(rust_future); assert_eq!(call_status.code, RustCallStatusCode::Success); assert_eq!( diff --git a/uniffi_core/src/ffi_converter_impls.rs b/uniffi_core/src/ffi_converter_impls.rs index aec093154a..3866f84eda 100644 --- a/uniffi_core/src/ffi_converter_impls.rs +++ b/uniffi_core/src/ffi_converter_impls.rs @@ -23,8 +23,9 @@ /// "UT" means an arbitrary `UniFfiTag` type. use crate::{ check_remaining, derive_ffi_traits, ffi_converter_rust_buffer_lift_and_lower, metadata, - ConvertError, FfiConverter, Lift, LiftRef, LiftReturn, Lower, LowerReturn, MetadataBuffer, - Result, RustBuffer, UnexpectedUniFFICallbackError, + BlockingTaskQueue, BlockingTaskQueueVTable, ConvertError, FfiConverter, Lift, LiftRef, + LiftReturn, Lower, LowerReturn, MetadataBuffer, Result, RustBuffer, + UnexpectedUniFFICallbackError, }; use anyhow::bail; use bytes::buf::{Buf, BufMut}; @@ -244,6 +245,35 @@ unsafe impl FfiConverter for Duration { const TYPE_ID_META: MetadataBuffer = MetadataBuffer::from_code(metadata::codes::TYPE_DURATION); } +/// Support for passing [BlockingTaskQueue] across the FFI +/// +/// Both fields of [BlockingTaskQueue] are serialized into a RustBuffer. The vtable pointer is +/// casted to a u64. 
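As an aside, the 16-byte layout that the converter below relies on can be sketched like this; the helper is illustrative only (not part of the patch) and assumes the `bytes` crate that `uniffi_core` already depends on:

    use bytes::BufMut;

    // Mirrors the layout used by `write`/`try_read`:
    //   bytes 0..8   the foreign handle (u64, never 0)
    //   bytes 8..16  the address of the &'static BlockingTaskQueueVTable, as a u64
    fn lower_blocking_task_queue(handle: u64, vtable: &'static BlockingTaskQueueVTable) -> Vec<u8> {
        let mut buf = Vec::with_capacity(16);
        buf.put_u64(handle);
        buf.put_u64(vtable as *const BlockingTaskQueueVTable as u64);
        buf
    }
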
+unsafe impl FfiConverter for BlockingTaskQueue { + ffi_converter_rust_buffer_lift_and_lower!(UT); + + fn write(obj: BlockingTaskQueue, buf: &mut Vec) { + let obj = obj.clone(); + buf.put_u64(obj.handle.into()); + buf.put_u64(obj.vtable as *const BlockingTaskQueueVTable as u64); + } + + fn try_read(buf: &mut &[u8]) -> Result { + check_remaining(buf, 16)?; + let handle = buf + .get_u64() + .try_into() + .expect("handle = 0 when reading BlockingTaskQueue"); + let vtable = unsafe { + &*(buf.get_u64() as *const BlockingTaskQueueVTable) as &'static BlockingTaskQueueVTable + }; + Ok(Self { handle, vtable }) + } + + const TYPE_ID_META: MetadataBuffer = + MetadataBuffer::from_code(metadata::codes::TYPE_BLOCKING_TASK_QUEUE); +} + // Support for passing optional values via the FFI. // // Optional values are currently always passed by serializing to a buffer. @@ -419,6 +449,7 @@ derive_ffi_traits!(blanket bool); derive_ffi_traits!(blanket String); derive_ffi_traits!(blanket Duration); derive_ffi_traits!(blanket SystemTime); +derive_ffi_traits!(blanket BlockingTaskQueue); // For composite types, derive LowerReturn, LiftReturn, etc, from Lift/Lower. // diff --git a/uniffi_core/src/metadata.rs b/uniffi_core/src/metadata.rs index 934d09cf87..4b9f9b861f 100644 --- a/uniffi_core/src/metadata.rs +++ b/uniffi_core/src/metadata.rs @@ -69,6 +69,7 @@ pub mod codes { pub const TYPE_RESULT: u8 = 23; pub const TYPE_TRAIT_INTERFACE: u8 = 24; pub const TYPE_CALLBACK_TRAIT_INTERFACE: u8 = 25; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 26; pub const TYPE_UNIT: u8 = 255; // Literal codes for LiteralMetadata - note that we don't support diff --git a/uniffi_macros/src/setup_scaffolding.rs b/uniffi_macros/src/setup_scaffolding.rs index 08bb2cc568..05a7b7ea49 100644 --- a/uniffi_macros/src/setup_scaffolding.rs +++ b/uniffi_macros/src/setup_scaffolding.rs @@ -166,8 +166,13 @@ fn rust_future_scaffolding_fns(module_path: &str) -> TokenStream { #[allow(clippy::missing_safety_doc, missing_docs)] #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn #ffi_rust_future_poll(handle: ::uniffi::Handle, callback: ::uniffi::RustFutureContinuationCallback, data: u64) { - ::uniffi::ffi::rust_future_poll::<#return_type, crate::UniFfiTag>(handle, callback, data); + pub unsafe extern "C" fn #ffi_rust_future_poll( + handle: ::uniffi::RustFutureHandle, + callback: ::uniffi::RustFutureContinuationCallback, + data: u64, + blocking_task_queue_handle: u64, + ) { + ::uniffi::ffi::rust_future_poll::<#return_type>(handle, callback, data, blocking_task_queue_handle); } #[allow(clippy::missing_safety_doc, missing_docs)] diff --git a/uniffi_meta/src/metadata.rs b/uniffi_meta/src/metadata.rs index 66c2c63952..b553e0060e 100644 --- a/uniffi_meta/src/metadata.rs +++ b/uniffi_meta/src/metadata.rs @@ -52,6 +52,7 @@ pub mod codes { pub const TYPE_RESULT: u8 = 23; pub const TYPE_TRAIT_INTERFACE: u8 = 24; pub const TYPE_CALLBACK_TRAIT_INTERFACE: u8 = 25; + pub const TYPE_BLOCKING_TASK_QUEUE: u8 = 26; pub const TYPE_UNIT: u8 = 255; // Literal codes diff --git a/uniffi_meta/src/reader.rs b/uniffi_meta/src/reader.rs index 5a09d9dd7c..5341f1b24a 100644 --- a/uniffi_meta/src/reader.rs +++ b/uniffi_meta/src/reader.rs @@ -145,6 +145,7 @@ impl<'a> MetadataReader<'a> { codes::TYPE_STRING => Type::String, codes::TYPE_DURATION => Type::Duration, codes::TYPE_SYSTEM_TIME => Type::Timestamp, + codes::TYPE_BLOCKING_TASK_QUEUE => Type::BlockingTaskQueue, codes::TYPE_RECORD => Type::Record { module_path: self.read_string()?, name: self.read_string()?, diff --git 
a/uniffi_meta/src/types.rs b/uniffi_meta/src/types.rs index 51bf156b50..5003dd9f77 100644 --- a/uniffi_meta/src/types.rs +++ b/uniffi_meta/src/types.rs @@ -88,6 +88,7 @@ pub enum Type { // How the object is implemented. imp: ObjectImpl, }, + BlockingTaskQueue, // Types defined in the component API, each of which has a string name. Record { module_path: String, diff --git a/uniffi_udl/src/resolver.rs b/uniffi_udl/src/resolver.rs index ea98cd7a99..1409c3a6ff 100644 --- a/uniffi_udl/src/resolver.rs +++ b/uniffi_udl/src/resolver.rs @@ -209,6 +209,7 @@ pub(crate) fn resolve_builtin_type(name: &str) -> Option { "f64" => Some(Type::Float64), "timestamp" => Some(Type::Timestamp), "duration" => Some(Type::Duration), + "BlockingTaskQueue" => Some(Type::BlockingTaskQueue), _ => None, } }
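Finally, to show what the clone/free contract of `BlockingTaskQueueVTable` amounts to on the Rust side, here is an illustrative test double; it is not part of this patch, every name is hypothetical, and it assumes the code lives inside `uniffi_core`'s `rustfuture` module so the vtable's private fields are visible:

    // A trivial vtable for unit tests: handles are never shared between queues,
    // so clone just hands back the same handle and free is a no-op.
    extern "C" fn test_clone(handle: u64) -> u64 {
        handle
    }

    extern "C" fn test_free(_handle: u64) {}

    static TEST_VTABLE: BlockingTaskQueueVTable = BlockingTaskQueueVTable {
        clone: test_clone,
        drop: test_free,
    };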