Skip to content

Commit

Permalink
EventStreams: Customisable Terminating Byte Sequence (#115)
Browse files Browse the repository at this point in the history
### Motivation

As discussed in
apple/swift-openapi-generator#622, some APIs,
e.g., ChatGPT or Claude, may return a non-JSON byte sequence to
terminate a stream of events.
If not handled with a workaround (see below)such non-JSON terminating
byte sequences cause a decoding error.

### Modifications

This PR adds the ability to customise the terminating byte sequence by
providing a closure to `asDecodedServerSentEvents()` as well as
`asDecodedServerSentEventsWithJSONData()` that can match incoming data
for the terminating byte sequence before it is decoded into JSON, for
instance.

### Result

Instead of having to decode and re-encode incoming events to filter out
the terminating byte sequence - as seen in
apple/swift-openapi-generator#622 (comment)
- terminating byte sequences can now be cleanly caught by either
providing a closure or providing the terminating byte sequence directly
when calling `asDecodedServerSentEvents()` and
`asDecodedServerSentEventsWithJSONData()`.

### Test Plan

This PR includes unit tests that test the new function parameters as
part of the existing tests for `asDecodedServerSentEvents()` as well as
`asDecodedServerSentEventsWithJSONData()`.

---------

Co-authored-by: Honza Dvorsky <[email protected]>
  • Loading branch information
paulhdk and czechboy0 authored Oct 3, 2024
1 parent da2e5b8 commit d604dd0
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 27 deletions.
34 changes: 34 additions & 0 deletions Sources/OpenAPIRuntime/Deprecated/Deprecated.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,37 @@ extension Configuration {
)
}
}

extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
/// - Returns: A sequence that provides the events.
@available(*, deprecated, renamed: "asDecodedServerSentEvents(while:)") @_disfavoredOverload
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
ServerSentEventsLineDeserializationSequence<Self>
> { asDecodedServerSentEvents(while: { _ in true }) }
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is JSON.
/// - Parameters:
/// - dataType: The type to decode the JSON data into.
/// - decoder: The JSON decoder to use.
/// - Returns: A sequence that provides the events with the decoded JSON data.
@available(*, deprecated, renamed: "asDecodedServerSentEventsWithJSONData(of:decoder:while:)") @_disfavoredOverload
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
of dataType: JSONDataType.Type = JSONDataType.self,
decoder: JSONDecoder = .init()
) -> AsyncThrowingMapSequence<
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
ServerSentEventWithJSONData<JSONDataType>
> { asDecodedServerSentEventsWithJSONData(of: dataType, decoder: decoder, while: { _ in true }) }
}

extension ServerSentEventsDeserializationSequence {
/// Creates a new sequence.
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
@available(*, deprecated, renamed: "init(upstream:while:)") @_disfavoredOverload public init(upstream: Upstream) {
self.init(upstream: upstream, while: { _ in true })
}
}
71 changes: 50 additions & 21 deletions Sources/OpenAPIRuntime/EventStreams/ServerSentEventsDecoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,19 @@ where Upstream.Element == ArraySlice<UInt8> {
/// The upstream sequence.
private let upstream: Upstream

/// A closure that determines whether the given byte chunk should be forwarded to the consumer.
/// - Parameter: A byte chunk.
/// - Returns: `true` if the byte chunk should be forwarded, `false` if this byte chunk is the terminating sequence.
private let predicate: @Sendable (ArraySlice<UInt8>) -> Bool

/// Creates a new sequence.
/// - Parameter upstream: The upstream sequence of arbitrary byte chunks.
public init(upstream: Upstream) { self.upstream = upstream }
/// - Parameters:
/// - upstream: The upstream sequence of arbitrary byte chunks.
/// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
public init(upstream: Upstream, while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool) {
self.upstream = upstream
self.predicate = predicate
}
}

extension ServerSentEventsDeserializationSequence: AsyncSequence {
Expand All @@ -46,7 +56,16 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
var upstream: UpstreamIterator

/// The state machine of the iterator.
var stateMachine: StateMachine = .init()
var stateMachine: StateMachine

/// Creates a new sequence.
/// - Parameters:
/// - upstream: The upstream sequence of arbitrary byte chunks.
/// - predicate: A closure that determines whether the given byte chunk should be forwarded to the consumer.
init(upstream: UpstreamIterator, while predicate: @escaping ((ArraySlice<UInt8>) -> Bool)) {
self.upstream = upstream
self.stateMachine = .init(while: predicate)
}

/// Asynchronously advances to the next element and returns it, or ends the
/// sequence if there is no next element.
Expand All @@ -70,7 +89,7 @@ extension ServerSentEventsDeserializationSequence: AsyncSequence {
/// Creates the asynchronous iterator that produces elements of this
/// asynchronous sequence.
public func makeAsyncIterator() -> Iterator<Upstream.AsyncIterator> {
Iterator(upstream: upstream.makeAsyncIterator())
Iterator(upstream: upstream.makeAsyncIterator(), while: predicate)
}
}

Expand All @@ -79,26 +98,30 @@ extension AsyncSequence where Element == ArraySlice<UInt8>, Self: Sendable {
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is not JSON, or if you don't want to parse it using `asDecodedServerSentEventsWithJSONData`.
/// - Parameter: A closure that determines whether the given byte chunk should be forwarded to the consumer.
/// - Returns: A sequence that provides the events.
public func asDecodedServerSentEvents() -> ServerSentEventsDeserializationSequence<
ServerSentEventsLineDeserializationSequence<Self>
> { .init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self)) }

public func asDecodedServerSentEvents(
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) -> ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>> {
.init(upstream: ServerSentEventsLineDeserializationSequence(upstream: self), while: predicate)
}
/// Returns another sequence that decodes each event's data as the provided type using the provided decoder.
///
/// Use this method if the event's `data` field is JSON.
/// - Parameters:
/// - dataType: The type to decode the JSON data into.
/// - decoder: The JSON decoder to use.
/// - predicate: A closure that determines whether the given byte sequence is the terminating byte sequence defined by the API.
/// - Returns: A sequence that provides the events with the decoded JSON data.
public func asDecodedServerSentEventsWithJSONData<JSONDataType: Decodable>(
of dataType: JSONDataType.Type = JSONDataType.self,
decoder: JSONDecoder = .init()
decoder: JSONDecoder = .init(),
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) -> AsyncThrowingMapSequence<
ServerSentEventsDeserializationSequence<ServerSentEventsLineDeserializationSequence<Self>>,
ServerSentEventWithJSONData<JSONDataType>
> {
asDecodedServerSentEvents()
asDecodedServerSentEvents(while: predicate)
.map { event in
ServerSentEventWithJSONData(
event: event.event,
Expand All @@ -118,10 +141,10 @@ extension ServerSentEventsDeserializationSequence.Iterator {
struct StateMachine {

/// The possible states of the state machine.
enum State: Hashable {
enum State {

/// Accumulating an event, which hasn't been emitted yet.
case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice<UInt8>])
case accumulatingEvent(ServerSentEvent, buffer: [ArraySlice<UInt8>], predicate: (ArraySlice<UInt8>) -> Bool)

/// Finished, the terminal state.
case finished
Expand All @@ -134,7 +157,9 @@ extension ServerSentEventsDeserializationSequence.Iterator {
private(set) var state: State

/// Creates a new state machine.
init() { self.state = .accumulatingEvent(.init(), buffer: []) }
init(while predicate: @escaping (ArraySlice<UInt8>) -> Bool) {
self.state = .accumulatingEvent(.init(), buffer: [], predicate: predicate)
}

/// An action returned by the `next` method.
enum NextAction {
Expand All @@ -156,20 +181,24 @@ extension ServerSentEventsDeserializationSequence.Iterator {
/// - Returns: An action to perform.
mutating func next() -> NextAction {
switch state {
case .accumulatingEvent(var event, var buffer):
case .accumulatingEvent(var event, var buffer, let predicate):
guard let line = buffer.first else { return .needsMore }
state = .mutating
buffer.removeFirst()
if line.isEmpty {
// Dispatch the accumulated event.
state = .accumulatingEvent(.init(), buffer: buffer)
// If the last character of data is a newline, strip it.
if event.data?.hasSuffix("\n") ?? false { event.data?.removeLast() }
if let data = event.data, !predicate(ArraySlice(data.utf8)) {
state = .finished
return .returnNil
}
state = .accumulatingEvent(.init(), buffer: buffer, predicate: predicate)
return .emitEvent(event)
}
if line.first! == ASCII.colon {
// A comment, skip this line.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
}
// Parse the field name and value.
Expand All @@ -193,7 +222,7 @@ extension ServerSentEventsDeserializationSequence.Iterator {
}
guard let value else {
// An unknown type of event, skip.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
}
// Process the field.
Expand All @@ -214,11 +243,11 @@ extension ServerSentEventsDeserializationSequence.Iterator {
}
default:
// An unknown or invalid field, skip.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
}
// Processed the field, continue.
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
case .finished: return .returnNil
case .mutating: preconditionFailure("Invalid state")
Expand All @@ -240,11 +269,11 @@ extension ServerSentEventsDeserializationSequence.Iterator {
/// - Returns: An action to perform.
mutating func receivedValue(_ value: ArraySlice<UInt8>?) -> ReceivedValueAction {
switch state {
case .accumulatingEvent(let event, var buffer):
case .accumulatingEvent(let event, var buffer, let predicate):
if let value {
state = .mutating
buffer.append(value)
state = .accumulatingEvent(event, buffer: buffer)
state = .accumulatingEvent(event, buffer: buffer, predicate: predicate)
return .noop
} else {
// If no value is received, drop the existing event on the floor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@ import XCTest
import Foundation

final class Test_ServerSentEventsDecoding: Test_Runtime {
func _test(input: String, output: [ServerSentEvent], file: StaticString = #filePath, line: UInt = #line)
async throws
{
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents()
func _test(
input: String,
output: [ServerSentEvent],
file: StaticString = #filePath,
line: UInt = #line,
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) async throws {
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8)).asDecodedServerSentEvents(while: predicate)
let events = try await [ServerSentEvent](collecting: sequence)
XCTAssertEqual(events.count, output.count, file: file, line: line)
for (index, linePair) in zip(events, output).enumerated() {
let (actualEvent, expectedEvent) = linePair
XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line)
}
}

func test() async throws {
// Simple event.
try await _test(
Expand Down Expand Up @@ -83,22 +88,40 @@ final class Test_ServerSentEventsDecoding: Test_Runtime {
.init(id: "123", data: "This is a message with an ID."),
]
)

try await _test(
input: #"""
data: hello
data: world
data: [DONE]
data: hello2
data: world2
"""#,
output: [.init(data: "hello\nworld")],
while: { incomingData in incomingData != ArraySlice<UInt8>(Data("[DONE]".utf8)) }
)
}
func _testJSONData<JSONType: Decodable & Hashable & Sendable>(
input: String,
output: [ServerSentEventWithJSONData<JSONType>],
file: StaticString = #filePath,
line: UInt = #line
line: UInt = #line,
while predicate: @escaping @Sendable (ArraySlice<UInt8>) -> Bool = { _ in true }
) async throws {
let sequence = asOneBytePerElementSequence(ArraySlice(input.utf8))
.asDecodedServerSentEventsWithJSONData(of: JSONType.self)
.asDecodedServerSentEventsWithJSONData(of: JSONType.self, while: predicate)
let events = try await [ServerSentEventWithJSONData<JSONType>](collecting: sequence)
XCTAssertEqual(events.count, output.count, file: file, line: line)
for (index, linePair) in zip(events, output).enumerated() {
let (actualEvent, expectedEvent) = linePair
XCTAssertEqual(actualEvent, expectedEvent, "Event: \(index)", file: file, line: line)
}
}

struct TestEvent: Decodable, Hashable, Sendable { var index: Int }
func testJSONData() async throws {
// Simple event.
Expand All @@ -121,6 +144,33 @@ final class Test_ServerSentEventsDecoding: Test_Runtime {
.init(event: "event2", data: TestEvent(index: 2), id: "2"),
]
)

try await _testJSONData(
input: #"""
event: event1
id: 1
data: {"index":1}
event: event2
id: 2
data: {
data: "index": 2
data: }
data: [DONE]
event: event3
id: 1
data: {"index":3}
"""#,
output: [
.init(event: "event1", data: TestEvent(index: 1), id: "1"),
.init(event: "event2", data: TestEvent(index: 2), id: "2"),
],
while: { incomingData in incomingData != ArraySlice<UInt8>(Data("[DONE]".utf8)) }
)
}
}

Expand Down

0 comments on commit d604dd0

Please sign in to comment.