Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel DispatchSource before closing socket (#4791) #4859

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CoreFoundation/URL.subproj/CFURLSessionInterface.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ CFURLSessionEasyCode CFURLSession_easy_setopt_tc(CFURLSessionEasyHandle _Nonnull
return MakeEasyCode(curl_easy_setopt(curl, option.value, a));
}

CFURLSessionEasyCode CFURLSession_easy_setopt_scl(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionCloseSocketCallback * _Nullable a) {
return MakeEasyCode(curl_easy_setopt(curl, option.value, a));
}

CFURLSessionEasyCode CFURLSession_easy_getinfo_long(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionInfo info, long *_Nonnull a) {
return MakeEasyCode(curl_easy_getinfo(curl, info.value, a));
}
Expand Down
2 changes: 2 additions & 0 deletions CoreFoundation/URL.subproj/CFURLSessionInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,8 @@ typedef int (CFURLSessionSeekCallback)(void *_Nullable userp, long long offset,
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_setopt_seek(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionSeekCallback * _Nullable a);
typedef int (CFURLSessionTransferInfoCallback)(void *_Nullable userp, long long dltotal, long long dlnow, long long ultotal, long long ulnow);
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_setopt_tc(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionTransferInfoCallback * _Nullable a);
typedef int (CFURLSessionCloseSocketCallback)(void *_Nullable clientp, CFURLSession_socket_t item);
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_setopt_scl(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionOption option, CFURLSessionCloseSocketCallback * _Nullable a);

CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_getinfo_long(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionInfo info, long *_Nonnull a);
CF_EXPORT CFURLSessionEasyCode CFURLSession_easy_getinfo_double(CFURLSessionEasyHandle _Nonnull curl, CFURLSessionInfo info, double *_Nonnull a);
Expand Down
165 changes: 146 additions & 19 deletions Sources/FoundationNetworking/URLSession/libcurl/MultiHandle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ extension URLSession {
let queue: DispatchQueue
let group = DispatchGroup()
fileprivate var easyHandles: [_EasyHandle] = []
fileprivate var socketReferences: [CFURLSession_socket_t: _SocketReference] = [:]
fileprivate var timeoutSource: _TimeoutSource? = nil
private var reentrantInUpdateTimeoutTimer = false

Expand Down Expand Up @@ -127,13 +128,14 @@ fileprivate extension URLSession._MultiHandle {
if let opaque = socketSourcePtr {
Unmanaged<_SocketSources>.fromOpaque(opaque).release()
}
socketSources?.tearDown(handle: self, socket: socket, queue: queue)
socketSources = nil
}
if let ss = socketSources {
let handler = DispatchWorkItem { [weak self] in
self?.performAction(for: socket)
}
ss.createSources(with: action, socket: socket, queue: queue, handler: handler)
ss.createSources(with: action, handle: self, socket: socket, queue: queue, handler: handler)
}
return 0
}
Expand Down Expand Up @@ -161,9 +163,104 @@ extension Collection where Element == _EasyHandle {
}
}

private extension URLSession._MultiHandle {
class _SocketReference {
let socket: CFURLSession_socket_t
var shouldClose: Bool
var workItem: DispatchWorkItem?

init(socket: CFURLSession_socket_t) {
self.socket = socket
shouldClose = false
}

deinit {
if shouldClose {
#if os(Windows)
closesocket(socket)
#else
close(socket)
#endif
}
}
}

/// Creates and stores socket reference. Reentrancy is not supported.
/// Trying to begin operation for same socket twice would mean something
/// went horribly wrong, or our assumptions about CURL register/unregister
/// action flow are nor correct.
func beginOperation(for socket: CFURLSession_socket_t) -> _SocketReference {
let reference = _SocketReference(socket: socket)
precondition(socketReferences.updateValue(reference, forKey: socket) == nil, "Reentrancy is not supported for socket operations")
return reference
}

/// Removes socket reference from the shared store. If there is work item scheduled,
/// executes it on the current thread.
func endOperation(for socketReference: _SocketReference) {
precondition(socketReferences.removeValue(forKey: socketReference.socket) != nil, "No operation associated with the socket")
if let workItem = socketReference.workItem, !workItem.isCancelled {
// CURL never asks for socket close without unregistering first, and
// we should cancel pending work when unregister action is requested.
precondition(!socketReference.shouldClose, "Socket close was scheduled, but there is some pending work left")
workItem.perform()
}
}

/// Marks this reference to close socket on deinit. This allows us
/// to extend socket lifecycle by keeping the reference alive.
func scheduleClose(for socket: CFURLSession_socket_t) {
let reference = socketReferences[socket] ?? _SocketReference(socket: socket)
reference.shouldClose = true
}

/// Schedules work to be performed when an operation ends for the socket,
/// or performs it immediately if there is no operation in progress.
///
/// We're using this to postpone Dispatch Source creation when
/// previous Dispatch Source is not cancelled yet.
func schedule(_ workItem: DispatchWorkItem, for socket: CFURLSession_socket_t) {
guard let socketReference = socketReferences[socket] else {
workItem.perform()
return
}
// CURL never asks for register without pairing it with unregister later,
// and we're cancelling pending work item on unregister.
// But it is safe to just drop existing work item anyway,
// and replace it with the new one.
socketReference.workItem = workItem
}

/// Cancels pending work for socket operation. Does nothing if
/// there is no operation in progress or no pending work item.
///
/// CURL may become not interested in Dispatch Sources
/// we have planned to create. In this case we should just cancel
/// scheduled work.
func cancelWorkItem(for socket: CFURLSession_socket_t) {
guard let socketReference = socketReferences[socket] else {
return
}
socketReference.workItem?.cancel()
socketReference.workItem = nil
}

}

internal extension URLSession._MultiHandle {
/// Add an easy handle -- start its transfer.
func add(_ handle: _EasyHandle) {
// Set CLOSESOCKETFUNCTION. Note that while the option belongs to easy_handle,
// the connection cache is managed by CURL multi_handle, and sockets can actually
// outlive easy_handle (even after curl_easy_cleanup call). That's why
// socket management lives in _MultiHandle.
try! CFURLSession_easy_setopt_ptr(handle.rawHandle, CFURLSessionOptionCLOSESOCKETDATA, UnsafeMutableRawPointer(Unmanaged.passUnretained(self).toOpaque())).asError()
try! CFURLSession_easy_setopt_scl(handle.rawHandle, CFURLSessionOptionCLOSESOCKETFUNCTION) { (clientp: UnsafeMutableRawPointer?, item: CFURLSession_socket_t) in
guard let handle = URLSession._MultiHandle.from(callbackUserData: clientp) else { fatalError() }
handle.scheduleClose(for: item)
return 0
}.asError()

// If this is the first handle being added, we need to `kick` the
// underlying multi handle by calling `timeoutTimerFired` as
// described in
Expand Down Expand Up @@ -359,7 +456,7 @@ class _TimeoutSource {
let delay = UInt64(max(1, milliseconds - 1))
let start = DispatchTime.now() + DispatchTimeInterval.milliseconds(Int(delay))

rawSource.schedule(deadline: start, repeating: .milliseconds(Int(delay)), leeway: (milliseconds == 1) ? .microseconds(Int(1)) : .milliseconds(Int(1)))
rawSource.schedule(deadline: start, repeating: .never, leeway: (milliseconds == 1) ? .microseconds(Int(1)) : .milliseconds(Int(1)))
rawSource.setEventHandler(handler: handler)
rawSource.resume()
}
Expand All @@ -384,13 +481,12 @@ fileprivate extension URLSession._MultiHandle {
timeoutSource = nil
queue.async { self.timeoutTimerFired() }
case .milliseconds(let milliseconds):
if (timeoutSource == nil) || timeoutSource!.milliseconds != milliseconds {
//TODO: Could simply change the existing timer by using DispatchSourceTimer again.
let block = DispatchWorkItem { [weak self] in
self?.timeoutTimerFired()
}
timeoutSource = _TimeoutSource(queue: queue, milliseconds: milliseconds, handler: block)
//TODO: Could simply change the existing timer by using DispatchSourceTimer again.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it has to be done or one issue will be created ?

Suggested change
//TODO: Could simply change the existing timer by using DispatchSourceTimer again.
// TODO: Could simply change the existing timer by using DispatchSourceTimer again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't really say on behalf on this comment author😞 This was here from the beginning. Guess it is better to leave as is to keep the diff less noisy.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let the Todo, I thought you've added it

let block = DispatchWorkItem { [weak self] in
self?.timeoutTimerFired()
}
// Note: Previous timer instance would cancel internal Dispatch timer in deinit
timeoutSource = _TimeoutSource(queue: queue, milliseconds: milliseconds, handler: block)
}
}
enum _Timeout {
Expand Down Expand Up @@ -449,25 +545,56 @@ fileprivate class _SocketSources {
s.resume()
}

func tearDown() {
if let s = readSource {
s.cancel()
func tearDown(handle: URLSession._MultiHandle, socket: CFURLSession_socket_t, queue: DispatchQueue) {
handle.cancelWorkItem(for: socket) // There could be pending register action which needs to be cancelled

guard readSource != nil || writeSource != nil else {
// This means that we have posponed (and already abandoned)
// sources creation.
return
}
readSource = nil
if let s = writeSource {
s.cancel()

// Socket is guaranteed to not to be closed as long as we keeping
// the reference.
let socketReference = handle.beginOperation(for: socket)
let cancelHandlerGroup = DispatchGroup()
[readSource, writeSource].compactMap({ $0 }).forEach { source in
cancelHandlerGroup.enter()
source.setCancelHandler {
cancelHandlerGroup.leave()
}
source.cancel()
}
cancelHandlerGroup.notify(queue: queue) {
handle.endOperation(for: socketReference)
}

readSource = nil
writeSource = nil
}
}
extension _SocketSources {
/// Create a read and/or write source as specified by the action.
func createSources(with action: URLSession._MultiHandle._SocketRegisterAction, socket: CFURLSession_socket_t, queue: DispatchQueue, handler: DispatchWorkItem) {
if action.needsReadSource {
createReadSource(socket: socket, queue: queue, handler: handler)
func createSources(with action: URLSession._MultiHandle._SocketRegisterAction, handle: URLSession._MultiHandle, socket: CFURLSession_socket_t, queue: DispatchQueue, handler: DispatchWorkItem) {
// CURL casually requests to unregister and register handlers for same
// socket in a row. There is (pretty low) chance of overlapping tear-down operation
// with "register" request. Bad things could happen if we create
// a new Dispatch Source while other is being cancelled for the same socket.
// We're using `_MultiHandle.schedule(_:for:)` here to postpone sources creation until
// pending operation is finished (if there is none, submitted work item is performed
// immediately).
// Also, CURL may request unregister even before we perform any postponed work,
// so we have to cancel such work in such case. See
let createSources = DispatchWorkItem {
if action.needsReadSource {
self.createReadSource(socket: socket, queue: queue, handler: handler)
}
if action.needsWriteSource {
self.createWriteSource(socket: socket, queue: queue, handler: handler)
}
}
if action.needsWriteSource {
createWriteSource(socket: socket, queue: queue, handler: handler)
if action.needsReadSource || action.needsWriteSource {
handle.schedule(createSources, for: socket)
}
}
}
Expand Down
54 changes: 43 additions & 11 deletions Tests/Foundation/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class _TCPSocket: CustomStringConvertible {
listening = false
}

init(port: UInt16?) throws {
init(port: UInt16?, backlog: Int32) throws {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you choose Int32 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propagated backlog parameter from listen function signature, and it is Int32 there. I guess it is imported from C int, which is CInt in Swift, which, in turn, is the alias for Int32. 🤔 Do you think it is better to use CInt here instead?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation 🙌

No for me Int32 is the good choice, I asked only for know how did you make your choice.

listening = true
self.port = 0

Expand All @@ -124,7 +124,7 @@ class _TCPSocket: CustomStringConvertible {
try socketAddress.withMemoryRebound(to: sockaddr.self, capacity: MemoryLayout<sockaddr>.size, {
let addr = UnsafePointer<sockaddr>($0)
_ = try attempt("bind", valid: isZero, bind(_socket, addr, socklen_t(MemoryLayout<sockaddr>.size)))
_ = try attempt("listen", valid: isZero, listen(_socket, SOMAXCONN))
_ = try attempt("listen", valid: isZero, listen(_socket, backlog))
})

var actualSA = sockaddr_in()
Expand Down Expand Up @@ -295,8 +295,8 @@ class _HTTPServer: CustomStringConvertible {
let tcpSocket: _TCPSocket
var port: UInt16 { tcpSocket.port }

init(port: UInt16?) throws {
tcpSocket = try _TCPSocket(port: port)
init(port: UInt16?, backlog: Int32 = SOMAXCONN) throws {
tcpSocket = try _TCPSocket(port: port, backlog: backlog)
}

init(socket: _TCPSocket) {
Expand Down Expand Up @@ -1094,15 +1094,32 @@ enum InternalServerError : Error {
case badHeaders
}

extension LoopbackServerTest {
struct Options {
var serverBacklog: Int32
var isAsynchronous: Bool

static let `default` = Options(serverBacklog: SOMAXCONN, isAsynchronous: true)
}
}

class LoopbackServerTest : XCTestCase {
private static let staticSyncQ = DispatchQueue(label: "org.swift.TestFoundation.HTTPServer.StaticSyncQ")

private static var _serverPort: Int = -1
private static var _serverActive = false
private static var testServer: _HTTPServer? = nil


private static var _options: Options = .default

static var options: Options {
get {
return staticSyncQ.sync { _options }
}
set {
staticSyncQ.sync { _options = newValue }
}
}

static var serverPort: Int {
get {
return staticSyncQ.sync { _serverPort }
Expand All @@ -1119,27 +1136,42 @@ class LoopbackServerTest : XCTestCase {

override class func setUp() {
super.setUp()
Self.startServer()
}

override class func tearDown() {
Self.stopServer()
super.tearDown()
}

static func startServer() {
var _serverPort = 0
let dispatchGroup = DispatchGroup()

func runServer() throws {
testServer = try _HTTPServer(port: nil)
testServer = try _HTTPServer(port: nil, backlog: options.serverBacklog)
_serverPort = Int(testServer!.port)
serverActive = true
dispatchGroup.leave()

while serverActive {
do {
let httpServer = try testServer!.listen()
globalDispatchQueue.async {

func handleRequest() {
let subServer = TestURLSessionServer(httpServer: httpServer)
do {
try subServer.readAndRespond()
} catch {
NSLog("readAndRespond: \(error)")
}
}

if options.isAsynchronous {
globalDispatchQueue.async(execute: handleRequest)
} else {
handleRequest()
}
} catch {
if (serverActive) { // Ignore errors thrown on shutdown
NSLog("httpServer: \(error)")
Expand All @@ -1165,11 +1197,11 @@ class LoopbackServerTest : XCTestCase {
fatalError("Timedout waiting for server to be ready")
}
serverPort = _serverPort
debugLog("Listening on \(serverPort)")
}

override class func tearDown() {
static func stopServer() {
serverActive = false
try? testServer?.stop()
super.tearDown()
}
}
Loading