diff --git a/Sources/AsyncObjects/CancellationSource/CancellationSource.swift b/Sources/AsyncObjects/CancellationSource/CancellationSource.swift index b538296..0ecf636 100644 --- a/Sources/AsyncObjects/CancellationSource/CancellationSource.swift +++ b/Sources/AsyncObjects/CancellationSource/CancellationSource.swift @@ -1,4 +1,5 @@ import Foundation +import AsyncAlgorithms /// An object that controls cooperative cancellation of multiple registered tasks and linked object registered tasks. /// @@ -36,16 +37,19 @@ public struct CancellationSource: AsyncObject, Cancellable, Loggable { internal typealias Continuation = GlobalContinuation /// The cancellable work with invocation context. internal typealias WorkItem = ( - Cancellable, id: UUID, file: String, function: String, line: UInt + any Cancellable, id: UUID, file: String, function: String, line: UInt ) /// The lifetime task that is cancelled when /// `CancellationSource` is cancelled. @usableFromInline - var lifetime: Task! + let lifetime: Task /// The stream continuation used to register work items /// for cooperative cancellation. - var pipe: AsyncStream.Continuation! + let pipe: AsyncStream.Continuation + /// The channel that controls waiting on the `CancellationSource`. + /// Once `CancellationSource` is cancelled, channel finishes. + let waiter: AsyncChannel /// A Boolean value that indicates whether cancellation is already /// invoked on the source. @@ -61,24 +65,57 @@ public struct CancellationSource: AsyncObject, Cancellable, Loggable { /// /// - Returns: The newly created cancellation source. public init() { - let stream = AsyncStream { self.pipe = $0 } - self.lifetime = Task.detached { - try await withThrowingTaskGroup(of: Void.self) { group in - for await item in stream { - group.addTask { - try? await waitHandlingCancelation( - for: item.0, associatedId: item.id, - file: item.file, - function: item.function, - line: item.line - ) + var continuation: AsyncStream.Continuation! + let stream = AsyncStream { continuation = $0 } + let channel = AsyncChannel() + self.pipe = continuation + self.waiter = channel + + func lifetime() -> Task { + return Task.detached { + await withThrowingTaskGroup(of: Void.self) { group in + for await item in stream { + group.addTask { + try? await waitHandlingCancelation( + for: item.0, associatedId: item.id, + file: item.file, + function: item.function, + line: item.line + ) + } } + + group.cancelAll() } + channel.finish() + } + } - group.cancelAll() - try await group.waitForAll() + #if swift(>=5.8) + if #available(macOS 13.3, iOS 16.4, tvOS 16.4, watchOS 9.4, *) { + self.lifetime = Task.detached { + await withDiscardingTaskGroup { group in + for await item in stream { + group.addTask { + try? await waitHandlingCancelation( + for: item.0, associatedId: item.id, + file: item.file, + function: item.function, + line: item.line + ) + } + } + + group.cancelAll() + } + channel.finish() } + } else { + self.lifetime = lifetime() } + #else + self.lifetime = lifetime() + #endif } /// Register cancellable work for cooperative cancellation @@ -163,11 +200,17 @@ public struct CancellationSource: AsyncObject, Cancellable, Loggable { file: String = #fileID, function: String = #function, line: UInt = #line - ) async { + ) async throws { let id = UUID() log("Waiting", id: id, file: file, function: function, line: line) - let _ = await lifetime.result - log("Completed", id: id, file: file, function: function, line: line) + await waiter.send(()) + do { + try Task.checkCancellation() + log("Completed", id: id, file: file, function: function, line: line) + } catch { + log("Cancelled", id: id, file: file, function: function, line: line) + throw error + } } } diff --git a/Tests/AsyncObjectsTests/CancellationSourceTests.swift b/Tests/AsyncObjectsTests/CancellationSourceTests.swift index 5d678b3..052fee7 100644 --- a/Tests/AsyncObjectsTests/CancellationSourceTests.swift +++ b/Tests/AsyncObjectsTests/CancellationSourceTests.swift @@ -184,3 +184,39 @@ class CancellationSourceInitializationTests: XCTestCase { XCTAssertTrue(task.isCancelled) } } + +@MainActor +class CancellationSourceWaitTests: XCTestCase { + + func testWithoutCancellation() async throws { + let source = CancellationSource() + let task = Task.detached { + try await Task.sleep(seconds: 10) + XCTFail("Unexpected task progression") + } + source.register(task: task) + do { + try await source.wait(forSeconds: 3) + XCTFail("Unexpected task progression") + } catch is DurationTimeoutError {} + XCTAssertFalse(source.isCancelled) + XCTAssertFalse(task.isCancelled) + } + + func testCooperativeCancellation() async throws { + let source = CancellationSource() + Task.detached(cancellationSource: source) { + try await Task.sleep(seconds: 20) + XCTFail("Unexpected task progression") + } + let task = Task.detached { + do { + try await source.wait(forSeconds: 5) + XCTFail("Unexpected task progression") + } catch is CancellationError {} + } + task.cancel() + try await task.value + XCTAssertFalse(source.isCancelled) + } +}