Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(AsyncEvent): switch to using AsyncChannel and AsyncStream from continuations #20

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import class Foundation.ProcessInfo

var dependencies: [Target.Dependency] = {
var dependencies: [Target.Dependency] = [
.product(name: "OrderedCollections", package: "swift-collections")
.product(name: "OrderedCollections", package: "swift-collections"),
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
]

if ProcessInfo.processInfo.environment["ASYNCOBJECTS_ENABLE_LOGGING_LEVEL"] != nil {
Expand Down Expand Up @@ -64,6 +65,7 @@ let package = Package(
],
dependencies: [
.package(url: "https://github.com/apple/swift-collections.git", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-async-algorithms", from: "0.1.0"),
.package(url: "https://github.com/apple/swift-docc-plugin", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-format", from: "0.50700.0"),
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
Expand Down
191 changes: 42 additions & 149 deletions Sources/AsyncObjects/AsyncEvent.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Foundation
import AsyncAlgorithms

/// An object that controls execution of tasks depending on the signal state.
///
Expand All @@ -21,148 +22,44 @@ import Foundation
/// // signal event after completing some task
/// event.signal()
/// ```
public actor AsyncEvent: AsyncObject, ContinuableCollectionActor, LoggableActor
{
/// The suspended tasks continuation type.
public final class AsyncEvent: AsyncObject, Loggable {
/// The stream continuation that updates state change
/// info for `AsyncEvent`.
@usableFromInline
internal typealias Continuation = TrackedContinuation<
GlobalContinuation<Void, Error>
>

/// The continuations stored with an associated key for all the suspended task that are waiting for event signal.
@usableFromInline
internal private(set) var continuations: [UUID: Continuation] = [:]
/// Indicates whether current state of event is signalled.
@usableFromInline
internal private(set) var signalled: Bool

// MARK: Internal

/// Add continuation with the provided key in `continuations` map.
///
/// - Parameters:
/// - continuation: The `continuation` to add.
/// - key: The key in the map.
/// - file: The file add request originates from (there's usually no need to pass it
/// explicitly as it defaults to `#fileID`).
/// - function: The function add request originates from (there's usually no need to
/// pass it explicitly as it defaults to `#function`).
/// - line: The line add request originates from (there's usually no need to pass it
/// explicitly as it defaults to `#line`).
/// - preinit: The pre-initialization handler to run
/// in the beginning of this method.
///
/// - Important: The pre-initialization handler must run
/// before any logic in this method.
@inlinable
internal func addContinuation(
_ continuation: Continuation,
withKey key: UUID,
file: String, function: String, line: UInt,
preinit: @Sendable () -> Void
) {
preinit()
log("Adding", id: key, file: file, function: function, line: line)
guard !continuation.resumed else {
log(
"Already resumed, not tracking", id: key,
file: file, function: function, line: line
)
return
}

guard !signalled else {
continuation.resume()
log("Resumed", id: key, file: file, function: function, line: line)
return
}

continuations[key] = continuation
log("Tracking", id: key, file: file, function: function, line: line)
}

/// Remove continuation associated with provided key
/// from `continuations` map and resumes with `CancellationError`.
///
/// - Parameters:
/// - continuation: The continuation to remove and cancel.
/// - key: The key in the map.
/// - file: The file remove request originates from (there's usually no need to pass it
/// explicitly as it defaults to `#fileID`).
/// - function: The function remove request originates from (there's usually no need to
/// pass it explicitly as it defaults to `#function`).
/// - line: The line remove request originates from (there's usually no need to pass it
/// explicitly as it defaults to `#line`).
@inlinable
internal func removeContinuation(
_ continuation: Continuation,
withKey key: UUID,
file: String, function: String, line: UInt
) {
log("Removing", id: key, file: file, function: function, line: line)
continuations.removeValue(forKey: key)
guard !continuation.resumed else {
log(
"Already resumed, not cancelling", id: key,
file: file, function: function, line: line
)
return
}

continuation.cancel()
log("Cancelled", id: key, file: file, function: function, line: line)
}

/// Resets signal of event.
///
/// - Parameters:
/// - file: The file reset originates from (there's usually no need to pass it
/// explicitly as it defaults to `#fileID`).
/// - function: The function reset originates from (there's usually no need to
/// pass it explicitly as it defaults to `#function`).
/// - line: The line reset originates from (there's usually no need to pass it
/// explicitly as it defaults to `#line`).
@inlinable
internal func resetEvent(file: String, function: String, line: UInt) {
signalled = false
log("Reset", file: file, function: function, line: line)
}

/// Signals the event and resumes all the tasks
/// suspended and waiting for signal.
///
/// - Parameters:
/// - file: The file signal originates from (there's usually no need to pass it
/// explicitly as it defaults to `#fileID`).
/// - function: The function signal originates from (there's usually no need to
/// pass it explicitly as it defaults to `#function`).
/// - line: The line signal originates from (there's usually no need to pass it
/// explicitly as it defaults to `#line`).
@inlinable
internal func signalEvent(file: String, function: String, line: UInt) {
log("Signalling", file: file, function: function, line: line)
continuations.forEach { key, value in
value.resume()
log("Resumed", id: key, file: file, function: function, line: line)
}
continuations = [:]
signalled = true
log("Signalled", file: file, function: function, line: line)
}

// MARK: Public
let transmitter: AsyncStream<Bool>.Continuation
/// The channel that controls waiting on the `AsyncEvent`.
/// The waiting completes when `AsyncEvent` is signalled.
let waiter: AsyncChannel<Void>

/// Creates a new event with signal state provided.
/// By default, event is initially in signalled state.
///
/// - Parameter signalled: The signal state for event.
/// - Returns: The newly created event.
public init(signaledInitially signalled: Bool = true) {
self.signalled = signalled
let channel = AsyncChannel<Void>()
var continuation: AsyncStream<Bool>.Continuation!
let stream = AsyncStream<Bool>(
bufferingPolicy: .bufferingNewest(1)
) { continuation = $0 }

self.transmitter = continuation
self.waiter = channel

Task.detached {
var state = signalled
var wt = state ? Task { for await _ in channel { continue } } : nil
for await signal in stream {
guard state != signal else { continue }
state = signal
guard state else { wt?.cancel(); continue }
wt = Task { for await _ in channel { continue } }
}
wt?.cancel()
}
}

// TODO: Explore alternative cleanup for actor
// deinit { self.continuations.forEach { $1.cancel() } }
deinit { self.transmitter.finish() }

/// Resets signal of event.
///
Expand All @@ -176,13 +73,12 @@ public actor AsyncEvent: AsyncObject, ContinuableCollectionActor, LoggableActor
/// - line: The line reset originates from (there's usually no need to pass it
/// explicitly as it defaults to `#line`).
@Sendable
@inlinable
public nonisolated func reset(
file: String = #fileID,
function: String = #function,
line: UInt = #line
) {
Task { await resetEvent(file: file, function: function, line: line) }
}
) { transmitter.yield(false) }

/// Signals the event.
///
Expand All @@ -196,13 +92,12 @@ public actor AsyncEvent: AsyncObject, ContinuableCollectionActor, LoggableActor
/// - line: The line signal originates from (there's usually no need to pass it
/// explicitly as it defaults to `#line`).
@Sendable
@inlinable
public nonisolated func signal(
file: String = #fileID,
function: String = #function,
line: UInt = #line
) {
Task { await signalEvent(file: file, function: function, line: line) }
}
) { transmitter.yield(true) }

/// Waits for event signal, or proceeds if already signalled.
///
Expand All @@ -224,18 +119,16 @@ public actor AsyncEvent: AsyncObject, ContinuableCollectionActor, LoggableActor
function: String = #function,
line: UInt = #line
) async throws {
guard !signalled else {
log("Acquired", file: file, function: function, line: line)
return
let id = UUID()
log("Waiting", id: id, file: file, function: function, line: line)
await waiter.send(())
do {
try Task.checkCancellation()
log("Completed", id: id, file: file, function: function, line: line)
} catch {
log("Cancelled", id: id, file: file, function: function, line: line)
throw error
}

let key = UUID()
log("Waiting", id: key, file: file, function: function, line: line)
try await withPromisedContinuation(
withKey: key,
file: file, function: function, line: line
)
log("Received", id: key, file: file, function: function, line: line)
}
}

Expand Down
13 changes: 8 additions & 5 deletions Tests/AsyncObjectsTests/AsyncEventTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class AsyncEventTests: XCTestCase {
func testResetSignal() async throws {
let event = AsyncEvent()
event.reset()
try await waitUntil(event, timeout: 3) { !$0.signalled }
event.signal()
try await event.wait(forSeconds: 3)
}
Expand Down Expand Up @@ -73,9 +72,11 @@ class AsyncEventTimeoutTests: XCTestCase {
func testResetSignal() async throws {
let event = AsyncEvent()
event.reset()
try await waitUntil(event, timeout: 3) { !$0.signalled }
do {
try await event.wait(forSeconds: 3)
let start = DispatchTime.now()
while DispatchTime.now() < start + .seconds(5) {
try await event.wait(forSeconds: 3)
}
XCTFail("Unexpected task progression")
} catch is DurationTimeoutError {}
}
Expand Down Expand Up @@ -108,9 +109,11 @@ class AsyncEventClockTimeoutTests: XCTestCase {
let clock: ContinuousClock = .continuous
let event = AsyncEvent()
event.reset()
try await waitUntil(event, timeout: 3) { !$0.signalled }
do {
try await event.wait(forSeconds: 3, clock: clock)
let start = clock.now
while clock.now < start + .seconds(5) {
try await event.wait(forSeconds: 3, clock: clock)
}
XCTFail("Unexpected task progression")
} catch is TimeoutError<ContinuousClock> {}
}
Expand Down