From bf717462d969a1b024d731261c29a27e3965d390 Mon Sep 17 00:00:00 2001 From: hank121314 Date: Fri, 2 Aug 2024 20:49:53 +0800 Subject: [PATCH] Refactor internal `UserDefaults` observation class (#181) Co-authored-by: Sindre Sorhus --- Sources/Defaults/Defaults+iCloud.swift | 10 +- Sources/Defaults/Defaults.swift | 4 +- Sources/Defaults/Observation+Combine.swift | 8 +- Sources/Defaults/Observation.swift | 313 ++++++------------ .../DefaultsTests/DefaultsSwiftUITests.swift | 1 + 5 files changed, 113 insertions(+), 223 deletions(-) diff --git a/Sources/Defaults/Defaults+iCloud.swift b/Sources/Defaults/Defaults+iCloud.swift index 2f24c27..6efe552 100644 --- a/Sources/Defaults/Defaults+iCloud.swift +++ b/Sources/Defaults/Defaults+iCloud.swift @@ -249,11 +249,11 @@ final class iCloudSynchronizer { @Atomic(value: []) private var remoteSyncingKeys: Set // TODO: Replace it with async stream when Swift supports custom executors. - private lazy var localKeysMonitor: Defaults.CompositeUserDefaultsAnyKeyObservation = .init { [weak self] observable in + private lazy var localKeysMonitor: Defaults.CompositeDefaultsObservation = .init { [weak self] pair, _ in guard let self, - let suite = observable.suite, - let key = keys.first(where: { $0.name == observable.key && $0.suite == suite }), + let suite = pair.suite, + let key = keys.first(where: { $0.name == pair.key && $0.suite == suite }), // Prevent triggering local observation when syncing from remote. !remoteSyncingKeys.contains(key) else { @@ -273,7 +273,7 @@ final class iCloudSynchronizer { self.keys.formUnion(keys) syncWithoutWaiting(keys) for key in keys { - localKeysMonitor.addObserver(key) + localKeysMonitor.add(key: key) } } @@ -283,7 +283,7 @@ final class iCloudSynchronizer { func remove(_ keys: [Defaults.Keys]) { self.keys.subtract(keys) for key in keys { - localKeysMonitor.removeObserver(key) + localKeysMonitor.remove(key: key) } } diff --git a/Sources/Defaults/Defaults.swift b/Sources/Defaults/Defaults.swift index 88b040d..6c37700 100644 --- a/Sources/Defaults/Defaults.swift +++ b/Sources/Defaults/Defaults.swift @@ -239,7 +239,7 @@ extension Defaults { initial: Bool = true ) -> AsyncStream { // TODO: Make this `some AsyncSequence` when Swift 6 is out. .init { continuation in - let observation = UserDefaultsKeyObservation2(object: key.suite, key: key.name) { change in + let observation = DefaultsObservation(object: key.suite, key: key.name) { _, change in // TODO: Use the `.deserialize` method directly. let value = KeyChange(change: change, defaultValue: key.defaultValue).newValue continuation.yield(value) @@ -275,7 +275,7 @@ extension Defaults { ) -> AsyncStream { // TODO: Make this `some AsyncSequence` when Swift 6 is out. .init { continuation in let observations = keys.indexed().map { index, key in - let observation = UserDefaultsKeyObservation2(object: key.suite, key: key.name) { _ in + let observation = DefaultsObservation(object: key.suite, key: key.name) { _, _ in continuation.yield() } diff --git a/Sources/Defaults/Observation+Combine.swift b/Sources/Defaults/Observation+Combine.swift index 18a99c5..5a1feaf 100644 --- a/Sources/Defaults/Observation+Combine.swift +++ b/Sources/Defaults/Observation+Combine.swift @@ -7,16 +7,16 @@ extension Defaults { */ final class DefaultsSubscription: Subscription where SubscriberType.Input == BaseChange { private var subscriber: SubscriberType? - private var observation: UserDefaultsKeyObservation? + private var observation: DefaultsObservationWithLifeTime? private let options: ObservationOptions init(subscriber: SubscriberType, suite: UserDefaults, key: String, options: ObservationOptions) { self.subscriber = subscriber self.options = options - self.observation = UserDefaultsKeyObservation( + self.observation = DefaultsObservationWithLifeTime( object: suite, key: key, - callback: observationCallback(_:) + observationCallback ) } @@ -33,7 +33,7 @@ extension Defaults { observation?.start(options: options) } - private func observationCallback(_ change: BaseChange) { + private func observationCallback(_: SuiteKeyPair, change: BaseChange) { _ = subscriber?.receive(change) } } diff --git a/Sources/Defaults/Observation.swift b/Sources/Defaults/Observation.swift index 0d9a98d..89375cd 100644 --- a/Sources/Defaults/Observation.swift +++ b/Sources/Defaults/Observation.swift @@ -119,95 +119,49 @@ extension Defaults { Thread.current.threadDictionary[key] = false } - final class UserDefaultsKeyObservation: NSObject, Observation { - typealias Callback = (BaseChange) -> Void - - private weak var object: UserDefaults? - private let key: String - private let callback: Callback - private var isObserving = false + final class SuiteKeyPair: Hashable { + weak var suite: UserDefaults? + let key: String - init(object: UserDefaults, key: String, callback: @escaping Callback) { - self.object = object + init(suite: UserDefaults, key: String) { + self.suite = suite self.key = key - self.callback = callback - } - - deinit { - invalidate() - } - - func start(options: ObservationOptions) { - object?.addObserver(self, forKeyPath: key, options: options.toNSKeyValueObservingOptions, context: nil) - isObserving = true - } - - func invalidate() { - if isObserving { - object?.removeObserver(self, forKeyPath: key, context: nil) - isObserving = false - } - - object = nil - lifetimeAssociation?.cancel() } - private var lifetimeAssociation: LifetimeAssociation? - - func tieToLifetime(of weaklyHeldObject: AnyObject) -> Self { - // swiftlint:disable:next trailing_closure - lifetimeAssociation = LifetimeAssociation(of: self, with: weaklyHeldObject, deinitHandler: { [weak self] in - self?.invalidate() - }) - - return self - } - - func removeLifetimeTie() { - lifetimeAssociation?.cancel() + func hash(into hasher: inout Hasher) { + hasher.combine(key) + hasher.combine(suite) } - // swiftlint:disable:next block_based_kvo - override func observeValue( - forKeyPath keyPath: String?, - of object: Any?, - change: [NSKeyValueChangeKey: Any]?, // swiftlint:disable:this discouraged_optional_collection - context: UnsafeMutableRawPointer? - ) { - guard let selfObject = self.object else { - invalidate() - return - } - - guard - selfObject == (object as? NSObject), - let change - else { - return - } - - let key = preventPropagationThreadDictionaryKey - let updatingValuesFlag = (Thread.current.threadDictionary[key] as? Bool) ?? false - guard !updatingValuesFlag else { - return - } - - callback(BaseChange(change: change)) + static func == (lhs: SuiteKeyPair, rhs: SuiteKeyPair) -> Bool { + lhs.key == rhs.key + && lhs.suite == rhs.suite } } - // Same as the above, but without the lifetime utilities, which slows down invalidation and we don't need them for `.updates()`. - final class UserDefaultsKeyObservation2: NSObject { - typealias Callback = (BaseChange) -> Void + /** + Standard observation for `Defaults`. + The only class which handle the low level observation. + */ + final class DefaultsObservation: NSObject { + typealias Callback = (SuiteKeyPair, BaseChange) -> Void - private weak var object: UserDefaults? - private let key: String + static var observationContext = 0 + private weak var suite: UserDefaults? + private let name: String private let callback: Callback private var isObserving = false + private let lock: Lock = .make() - init(object: UserDefaults, key: String, callback: @escaping Callback) { - self.object = object - self.key = key + init(object: UserDefaults, key: String, _ callback: @escaping Callback) { + self.suite = object + self.name = key + self.callback = callback + } + + init(key: Defaults._AnyKey, _ callback: @escaping Callback) { + self.suite = key.suite + self.name = key.name self.callback = callback } @@ -216,17 +170,24 @@ extension Defaults { } func start(options: ObservationOptions) { - object?.addObserver(self, forKeyPath: key, options: options.toNSKeyValueObservingOptions, context: nil) - isObserving = true + lock.with { + guard !isObserving else { + return + } + suite?.addObserver(self, forKeyPath: name, options: options.toNSKeyValueObservingOptions, context: &Self.observationContext) + isObserving = true + } } func invalidate() { - if isObserving { - object?.removeObserver(self, forKeyPath: key, context: nil) + lock.with { + guard isObserving else { + return + } + suite?.removeObserver(self, forKeyPath: name) isObserving = false + suite = nil } - - object = nil } // swiftlint:disable:next block_based_kvo @@ -236,13 +197,20 @@ extension Defaults { change: [NSKeyValueChangeKey: Any]?, // swiftlint:disable:this discouraged_optional_collection context: UnsafeMutableRawPointer? ) { - guard let selfObject = self.object else { + guard + context == &Self.observationContext + else { + super.observeValue(forKeyPath: keyPath, of: object, change: change, context: context) + return + } + + guard let selfSuite = suite else { invalidate() return } guard - selfObject == (object as? NSObject), + selfSuite == (object as? UserDefaults), let change else { return @@ -250,45 +218,27 @@ extension Defaults { let key = preventPropagationThreadDictionaryKey let updatingValuesFlag = (Thread.current.threadDictionary[key] as? Bool) ?? false - guard !updatingValuesFlag else { + if updatingValuesFlag { return } - callback(BaseChange(change: change)) + callback(SuiteKeyPair(suite: selfSuite, key: name), BaseChange(change: change)) } } - final class SuiteKeyPair: Hashable { - weak var suite: UserDefaults? - let key: String - - init(suite: UserDefaults, key: String) { - self.suite = suite - self.key = key - } - - func hash(into hasher: inout Hasher) { - hasher.combine(key) - hasher.combine(suite) - } + /** + Observation that wraps `DefaultsObservation` and adds a lifetime association. + */ + final class DefaultsObservationWithLifeTime: Observation { + private var observation: DefaultsObservation + private var lifetimeAssociation: LifetimeAssociation? - static func == (lhs: SuiteKeyPair, rhs: SuiteKeyPair) -> Bool { - lhs.key == rhs.key - && lhs.suite == rhs.suite + init(object: UserDefaults, key: String, _ callback: @escaping DefaultsObservation.Callback) { + self.observation = .init(object: object, key: key, callback) } - } - - private final class CompositeUserDefaultsKeyObservation: NSObject, Observation { - private static var observationContext = 0 - private var observables: [SuiteKeyPair] - private var lifetimeAssociation: LifetimeAssociation? - private let callback: UserDefaultsKeyObservation.Callback - - init(observables: [(suite: UserDefaults, key: String)], callback: @escaping UserDefaultsKeyObservation.Callback) { - self.observables = observables.map { SuiteKeyPair(suite: $0.suite, key: $0.key) } - self.callback = callback - super.init() + init(key: Defaults._AnyKey, _ callback: @escaping DefaultsObservation.Callback) { + self.observation = .init(key: key, callback) } deinit { @@ -296,25 +246,15 @@ extension Defaults { } func start(options: ObservationOptions) { - for observable in observables { - observable.suite?.addObserver( - self, - forKeyPath: observable.key, - options: options.toNSKeyValueObservingOptions, - context: &Self.observationContext - ) - } + observation.start(options: options) } func invalidate() { - for observable in observables { - observable.suite?.removeObserver(self, forKeyPath: observable.key, context: &Self.observationContext) - observable.suite = nil - } - + observation.invalidate() lifetimeAssociation?.cancel() } + @discardableResult func tieToLifetime(of weaklyHeldObject: AnyObject) -> Self { // swiftlint:disable:next trailing_closure lifetimeAssociation = LifetimeAssociation(of: self, with: weaklyHeldObject, deinitHandler: { [weak self] in @@ -327,67 +267,49 @@ extension Defaults { func removeLifetimeTie() { lifetimeAssociation?.cancel() } + } - // swiftlint:disable:next block_based_kvo - override func observeValue( - forKeyPath keyPath: String?, - of object: Any?, - change: [NSKeyValueChangeKey: Any]?, // swiftlint:disable:this discouraged_optional_collection - context: UnsafeMutableRawPointer? - ) { - guard - context == &Self.observationContext - else { - super.observeValue(forKeyPath: keyPath, of: object, change: change, context: context) - return - } - - guard - object is UserDefaults, - let change - else { - return - } - - let key = preventPropagationThreadDictionaryKey - let updatingValuesFlag = (Thread.current.threadDictionary[key] as? Bool) ?? false - if updatingValuesFlag { - return - } + /** + Observation that manages multiple `DefaultsObservation`. + Can add or remove the observed key dynamically. + */ + final class CompositeDefaultsObservation: Observation { + private var observations: Set = [] + private let callback: DefaultsObservation.Callback + private var lifetimeAssociation: LifetimeAssociation? - callback(BaseChange(change: change)) + init(_ callback: @escaping DefaultsObservation.Callback) { + self.callback = callback } - } - final class CompositeUserDefaultsAnyKeyObservation: NSObject, Observation { - typealias Callback = (SuiteKeyPair) -> Void - private static var observationContext = 1 + deinit { + invalidate() + } - private var observables: Set = [] - private var lifetimeAssociation: LifetimeAssociation? - private let callback: CompositeUserDefaultsAnyKeyObservation.Callback + func start(options: ObservationOptions) { + observations.forEach { $0.start(options: options) } + } - init(_ callback: @escaping CompositeUserDefaultsAnyKeyObservation.Callback) { - self.callback = callback + func invalidate() { + observations.forEach { $0.invalidate() } + lifetimeAssociation?.cancel() } - func addObserver(_ key: Defaults._AnyKey, options: ObservationOptions = []) { - let keyPair: SuiteKeyPair = .init(suite: key.suite, key: key.name) - let (inserted, observable) = observables.insert(keyPair) - guard inserted else { + func add(key: Defaults._AnyKey, options: ObservationOptions = []) { + let (isInserted, observation) = observations.insert(DefaultsObservation(key: key, callback)) + guard isInserted else { return } - observable.suite?.addObserver(self, forKeyPath: observable.key, options: options.toNSKeyValueObservingOptions, context: &Self.observationContext) + observation.start(options: options) } - func removeObserver(_ key: Defaults._AnyKey) { - let keyPair: SuiteKeyPair = .init(suite: key.suite, key: key.name) - guard let observable = observables.remove(keyPair) else { + func remove(key: Defaults._AnyKey) { + guard let observation = observations.remove(DefaultsObservation(key: key, self.callback)) else { return } - observable.suite?.removeObserver(self, forKeyPath: observable.key, context: &Self.observationContext) + observation.invalidate() } @discardableResult @@ -403,41 +325,6 @@ extension Defaults { func removeLifetimeTie() { lifetimeAssociation?.cancel() } - - func invalidate() { - for observable in observables { - observable.suite?.removeObserver(self, forKeyPath: observable.key, context: &Self.observationContext) - observable.suite = nil - } - - observables.removeAll() - lifetimeAssociation?.cancel() - } - - // swiftlint:disable:next block_based_kvo - override func observeValue( - forKeyPath keyPath: String?, - of object: Any?, - change: [NSKeyValueChangeKey: Any]?, // swiftlint:disable:this discouraged_optional_collection - context: UnsafeMutableRawPointer? - ) { - guard - context == &Self.observationContext - else { - super.observeValue(forKeyPath: keyPath, of: object, change: change, context: context) - return - } - - guard - let object = object as? UserDefaults, - let keyPath, - let observable = observables.first(where: { $0.key == keyPath && $0.suite == object }) - else { - return - } - - callback(observable) - } } /** @@ -461,7 +348,7 @@ extension Defaults { options: ObservationOptions = [.initial], handler: @escaping (KeyChange) -> Void ) -> some Observation { - let observation = UserDefaultsKeyObservation(object: key.suite, key: key.name) { change in + let observation = DefaultsObservationWithLifeTime(key: key) { _, change in handler( KeyChange(change: change, defaultValue: key.defaultValue) ) @@ -491,12 +378,14 @@ extension Defaults { options: ObservationOptions = [.initial], handler: @escaping () -> Void ) -> some Observation { - let pairs = keys.map { - (suite: $0.suite, key: $0.name) - } - let compositeObservation = CompositeUserDefaultsKeyObservation(observables: pairs) { _ in + let compositeObservation = CompositeDefaultsObservation { _, _ in handler() } + + for key in keys { + compositeObservation.add(key: key) + } + compositeObservation.start(options: options) return compositeObservation diff --git a/Tests/DefaultsTests/DefaultsSwiftUITests.swift b/Tests/DefaultsTests/DefaultsSwiftUITests.swift index cff0ebd..c189b0a 100644 --- a/Tests/DefaultsTests/DefaultsSwiftUITests.swift +++ b/Tests/DefaultsTests/DefaultsSwiftUITests.swift @@ -29,6 +29,7 @@ struct ContentView: View { } } + final class DefaultsSwiftUITests: XCTestCase { override func setUp() { super.setUp()